diff --git a/internal/manager/persistence/jobs_blocklist.go b/internal/manager/persistence/jobs_blocklist.go index 40403689..761efa6a 100644 --- a/internal/manager/persistence/jobs_blocklist.go +++ b/internal/manager/persistence/jobs_blocklist.go @@ -93,42 +93,33 @@ func (db *DB) RemoveFromJobBlocklist(ctx context.Context, jobUUID, workerUUID, t // NOTE: this does NOT consider the task failure list, which blocks individual // workers from individual tasks. This is ONLY concerning the job blocklist. func (db *DB) WorkersLeftToRun(ctx context.Context, job *Job, taskType string) (map[string]bool, error) { - // Find the IDs of the workers blocked on this job + tasktype combo. - blockedWorkers := db.gormDB. - Table("workers as blocked_workers"). - Select("blocked_workers.id"). - Joins("inner join job_blocks JB on blocked_workers.id = JB.worker_id"). - Where("JB.job_id = ?", job.ID). - Where("JB.task_type = ?", taskType) - - query := db.gormDB.WithContext(ctx). - Model(&Worker{}). - Select("uuid"). - Where("id not in (?)", blockedWorkers) + queries := db.queries() + var ( + workerUUIDs []string + err error + ) if job.WorkerTagID == nil { - // Count all workers, so no extra restrictions are necessary. + workerUUIDs, err = queries.WorkersLeftToRun(ctx, sqlc.WorkersLeftToRunParams{ + JobID: int64(job.ID), + TaskType: taskType, + }) } else { - // Only count workers in the job's tag. - jobTag := db.gormDB. - Table("worker_tag_membership"). - Select("worker_id"). - Where("worker_tag_id = ?", *job.WorkerTagID) - query = query. - Where("id in (?)", jobTag) + workerUUIDs, err = queries.WorkersLeftToRunWithWorkerTag(ctx, + sqlc.WorkersLeftToRunWithWorkerTagParams{ + JobID: int64(job.ID), + TaskType: taskType, + WorkerTagID: int64(*job.WorkerTagID), + }) + } + if err != nil { + return nil, err } - // Find the workers NOT blocked. - workers := []*Worker{} - tx := query.Scan(&workers) - if tx.Error != nil { - return nil, tx.Error - } - - // From the list of workers, construct the map of UUIDs. + // Construct a map of UUIDs. uuidMap := map[string]bool{} - for _, worker := range workers { - uuidMap[worker.UUID] = true + for _, uuid := range workerUUIDs { + uuidMap[uuid] = true } return uuidMap, nil @@ -138,17 +129,16 @@ func (db *DB) WorkersLeftToRun(ctx context.Context, job *Job, taskType string) ( func (db *DB) CountTaskFailuresOfWorker(ctx context.Context, job *Job, worker *Worker, taskType string) (int, error) { var numFailures int64 - tx := db.gormDB.WithContext(ctx). - Model(&TaskFailure{}). - Joins("inner join tasks T on task_failures.task_id = T.id"). - Where("task_failures.worker_id = ?", worker.ID). - Where("T.job_id = ?", job.ID). - Where("T.type = ?", taskType). - Count(&numFailures) + queries := db.queries() + numFailures, err := queries.CountTaskFailuresOfWorker(ctx, sqlc.CountTaskFailuresOfWorkerParams{ + JobID: int64(job.ID), + WorkerID: int64(worker.ID), + TaskType: taskType, + }) if numFailures > math.MaxInt { panic("overflow error in number of failures") } - return int(numFailures), tx.Error + return int(numFailures), err } diff --git a/internal/manager/persistence/sqlc/query_jobs.sql b/internal/manager/persistence/sqlc/query_jobs.sql index 2e5f35e1..3baa684c 100644 --- a/internal/manager/persistence/sqlc/query_jobs.sql +++ b/internal/manager/persistence/sqlc/query_jobs.sql @@ -269,3 +269,36 @@ WHERE job_blocks.job_id in (SELECT jobs.id FROM jobs WHERE jobs.uuid=@jobuuid) AND job_blocks.worker_id in (SELECT workers.id FROM workers WHERE workers.uuid=@workeruuid) AND job_blocks.task_type = @task_type; + +-- name: WorkersLeftToRun :many +SELECT workers.uuid FROM workers +WHERE id NOT IN ( + SELECT blocked_workers.id + FROM workers AS blocked_workers + INNER JOIN job_blocks JB on blocked_workers.id = JB.worker_id + WHERE + JB.job_id = @job_id + AND JB.task_type = @task_type +); + +-- name: WorkersLeftToRunWithWorkerTag :many +SELECT workers.uuid +FROM workers +INNER JOIN worker_tag_membership WTM ON workers.id = WTM.worker_id +WHERE id NOT IN ( + SELECT blocked_workers.id + FROM workers AS blocked_workers + INNER JOIN job_blocks JB ON blocked_workers.id = JB.worker_id + WHERE + JB.job_id = @job_id + AND JB.task_type = @task_type +) +AND WTM.worker_tag_id = @worker_tag_id; + +-- name: CountTaskFailuresOfWorker :one +SELECT count(TF.task_id) FROM task_failures TF +INNER JOIN tasks T ON TF.task_id = T.id +WHERE + TF.worker_id = @worker_id +AND T.job_id = @job_id +AND T.type = @task_type; diff --git a/internal/manager/persistence/sqlc/query_jobs.sql.go b/internal/manager/persistence/sqlc/query_jobs.sql.go index 5350c085..cb859501 100644 --- a/internal/manager/persistence/sqlc/query_jobs.sql.go +++ b/internal/manager/persistence/sqlc/query_jobs.sql.go @@ -84,6 +84,28 @@ func (q *Queries) ClearJobBlocklist(ctx context.Context, jobuuid string) error { return err } +const countTaskFailuresOfWorker = `-- name: CountTaskFailuresOfWorker :one +SELECT count(TF.task_id) FROM task_failures TF +INNER JOIN tasks T ON TF.task_id = T.id +WHERE + TF.worker_id = ?1 +AND T.job_id = ?2 +AND T.type = ?3 +` + +type CountTaskFailuresOfWorkerParams struct { + WorkerID int64 + JobID int64 + TaskType string +} + +func (q *Queries) CountTaskFailuresOfWorker(ctx context.Context, arg CountTaskFailuresOfWorkerParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countTaskFailuresOfWorker, arg.WorkerID, arg.JobID, arg.TaskType) + var count int64 + err := row.Scan(&count) + return count, err +} + const countWorkersFailingTask = `-- name: CountWorkersFailingTask :one SELECT count(*) as num_failed FROM task_failures WHERE task_id=?1 @@ -1178,3 +1200,87 @@ func (q *Queries) UpdateTaskStatus(ctx context.Context, arg UpdateTaskStatusPara _, err := q.db.ExecContext(ctx, updateTaskStatus, arg.UpdatedAt, arg.Status, arg.ID) return err } + +const workersLeftToRun = `-- name: WorkersLeftToRun :many +SELECT workers.uuid FROM workers +WHERE id NOT IN ( + SELECT blocked_workers.id + FROM workers AS blocked_workers + INNER JOIN job_blocks JB on blocked_workers.id = JB.worker_id + WHERE + JB.job_id = ?1 + AND JB.task_type = ?2 +) +` + +type WorkersLeftToRunParams struct { + JobID int64 + TaskType string +} + +func (q *Queries) WorkersLeftToRun(ctx context.Context, arg WorkersLeftToRunParams) ([]string, error) { + rows, err := q.db.QueryContext(ctx, workersLeftToRun, arg.JobID, arg.TaskType) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var uuid string + if err := rows.Scan(&uuid); err != nil { + return nil, err + } + items = append(items, uuid) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const workersLeftToRunWithWorkerTag = `-- name: WorkersLeftToRunWithWorkerTag :many +SELECT workers.uuid +FROM workers +INNER JOIN worker_tag_membership WTM ON workers.id = WTM.worker_id +WHERE id NOT IN ( + SELECT blocked_workers.id + FROM workers AS blocked_workers + INNER JOIN job_blocks JB ON blocked_workers.id = JB.worker_id + WHERE + JB.job_id = ?1 + AND JB.task_type = ?2 +) +AND WTM.worker_tag_id = ?3 +` + +type WorkersLeftToRunWithWorkerTagParams struct { + JobID int64 + TaskType string + WorkerTagID int64 +} + +func (q *Queries) WorkersLeftToRunWithWorkerTag(ctx context.Context, arg WorkersLeftToRunWithWorkerTagParams) ([]string, error) { + rows, err := q.db.QueryContext(ctx, workersLeftToRunWithWorkerTag, arg.JobID, arg.TaskType, arg.WorkerTagID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var uuid string + if err := rows.Scan(&uuid); err != nil { + return nil, err + } + items = append(items, uuid) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +}