diff --git a/internal/manager/persistence/jobs.go b/internal/manager/persistence/jobs.go index 250d8e7a..ab9cf465 100644 --- a/internal/manager/persistence/jobs.go +++ b/internal/manager/persistence/jobs.go @@ -466,7 +466,7 @@ func (db *DB) FetchTask(ctx context.Context, taskUUID string) (*Task, error) { return nil, taskError(err, "fetching task %s", taskUUID) } - convertedTask, err := convertSqlcTask(taskRow) + convertedTask, err := convertSqlcTask(taskRow.Task, taskRow.JobUUID.String, taskRow.WorkerUUID.String) if err != nil { return nil, err } @@ -605,32 +605,60 @@ func (db *DB) SaveTaskActivity(ctx context.Context, t *Task) error { return nil } +// TaskAssignToWorker assigns the given task to the given worker. +// This function is only used by unit tests. During normal operation, Flamenco +// uses the code in task_scheduler.go to assign tasks to workers. func (db *DB) TaskAssignToWorker(ctx context.Context, t *Task, w *Worker) error { - tx := db.gormDB.WithContext(ctx). - Model(t). - Select("WorkerID"). - Updates(Task{WorkerID: &w.ID}) - if tx.Error != nil { - return taskError(tx.Error, "assigning task %s to worker %s", t.UUID, w.UUID) + queries, err := db.queries() + if err != nil { + return err } - // Gorm updates t.WorkerID itself, but not t.Worker (even when it's added to - // the Updates() call above). + err = queries.TaskAssignToWorker(ctx, sqlc.TaskAssignToWorkerParams{ + UpdatedAt: db.now(), + WorkerID: sql.NullInt64{ + Int64: int64(w.ID), + Valid: true, + }, + ID: int64(t.ID), + }) + if err != nil { + return taskError(err, "assigning task %s to worker %s", t.UUID, w.UUID) + } + + // Update the task itself. t.Worker = w + t.WorkerID = &w.ID return nil } func (db *DB) FetchTasksOfWorkerInStatus(ctx context.Context, worker *Worker, taskStatus api.TaskStatus) ([]*Task, error) { - result := []*Task{} - tx := db.gormDB.WithContext(ctx). - Model(&Task{}). - Joins("Job"). - Where("tasks.worker_id = ?", worker.ID). - Where("tasks.status = ?", taskStatus). - Scan(&result) - if tx.Error != nil { - return nil, taskError(tx.Error, "finding tasks of worker %s in status %q", worker.UUID, taskStatus) + queries, err := db.queries() + if err != nil { + return nil, err + } + + rows, err := queries.FetchTasksOfWorkerInStatus(ctx, sqlc.FetchTasksOfWorkerInStatusParams{ + WorkerID: sql.NullInt64{ + Int64: int64(worker.ID), + Valid: true, + }, + TaskStatus: string(taskStatus), + }) + if err != nil { + return nil, taskError(err, "finding tasks of worker %s in status %q", worker.UUID, taskStatus) + } + + result := make([]*Task, len(rows)) + for i := range rows { + gormTask, err := convertSqlcTask(rows[i].Task, rows[i].JobUUID.String, worker.UUID) + if err != nil { + return nil, err + } + gormTask.Worker = worker + gormTask.WorkerID = &worker.ID + result[i] = gormTask } return result, nil } @@ -902,37 +930,37 @@ func convertSqlcJob(job sqlc.Job) (*Job, error) { // model expected by the rest of the code. This is mostly in place to aid in the // GORM to SQLC migration. It is intended that eventually the rest of the code // will use the same SQLC-generated model. -func convertSqlcTask(taskRow sqlc.FetchTaskRow) (*Task, error) { +func convertSqlcTask(task sqlc.Task, jobUUID string, workerUUID string) (*Task, error) { dbTask := Task{ Model: Model{ - ID: uint(taskRow.Task.ID), - CreatedAt: taskRow.Task.CreatedAt, - UpdatedAt: taskRow.Task.UpdatedAt.Time, + ID: uint(task.ID), + CreatedAt: task.CreatedAt, + UpdatedAt: task.UpdatedAt.Time, }, - UUID: taskRow.Task.UUID, - Name: taskRow.Task.Name, - Type: taskRow.Task.Type, - Priority: int(taskRow.Task.Priority), - Status: api.TaskStatus(taskRow.Task.Status), - LastTouchedAt: taskRow.Task.LastTouchedAt.Time, - Activity: taskRow.Task.Activity, + UUID: task.UUID, + Name: task.Name, + Type: task.Type, + Priority: int(task.Priority), + Status: api.TaskStatus(task.Status), + LastTouchedAt: task.LastTouchedAt.Time, + Activity: task.Activity, - JobID: uint(taskRow.Task.JobID), - JobUUID: taskRow.JobUUID.String, - WorkerUUID: taskRow.WorkerUUID.String, + JobID: uint(task.JobID), + JobUUID: jobUUID, + WorkerUUID: workerUUID, } // TODO: convert dependencies? - if taskRow.Task.WorkerID.Valid { - workerID := uint(taskRow.Task.WorkerID.Int64) + if task.WorkerID.Valid { + workerID := uint(task.WorkerID.Int64) dbTask.WorkerID = &workerID } - if err := json.Unmarshal(taskRow.Task.Commands, &dbTask.Commands); err != nil { + if err := json.Unmarshal(task.Commands, &dbTask.Commands); err != nil { return nil, taskError(err, fmt.Sprintf("task %s of job %s has invalid commands: %v", - taskRow.Task.UUID, taskRow.JobUUID.String, err)) + task.UUID, jobUUID, err)) } return &dbTask, nil diff --git a/internal/manager/persistence/sqlc/query_jobs.sql b/internal/manager/persistence/sqlc/query_jobs.sql index bb37b454..8fe44f6b 100644 --- a/internal/manager/persistence/sqlc/query_jobs.sql +++ b/internal/manager/persistence/sqlc/query_jobs.sql @@ -69,6 +69,13 @@ LEFT JOIN jobs ON (tasks.job_id = jobs.id) LEFT JOIN workers ON (tasks.worker_id = workers.id) WHERE tasks.uuid = @uuid; +-- name: FetchTasksOfWorkerInStatus :many +SELECT sqlc.embed(tasks), jobs.UUID as jobUUID +FROM tasks +LEFT JOIN jobs ON (tasks.job_id = jobs.id) +WHERE tasks.worker_id = @worker_id + AND tasks.status = @task_status; + -- name: FetchTaskJobUUID :one SELECT jobs.UUID as jobUUID FROM tasks @@ -100,3 +107,9 @@ UPDATE tasks SET updated_at = @updated_at, activity = @activity WHERE id=@id; + +-- name: TaskAssignToWorker :exec +UPDATE tasks SET + updated_at = @updated_at, + worker_id = @worker_id +WHERE id=@id; diff --git a/internal/manager/persistence/sqlc/query_jobs.sql.go b/internal/manager/persistence/sqlc/query_jobs.sql.go index 1ebe76a1..6f9e8d16 100644 --- a/internal/manager/persistence/sqlc/query_jobs.sql.go +++ b/internal/manager/persistence/sqlc/query_jobs.sql.go @@ -285,6 +285,62 @@ func (q *Queries) FetchTaskJobUUID(ctx context.Context, uuid string) (sql.NullSt return jobuuid, err } +const fetchTasksOfWorkerInStatus = `-- name: FetchTasksOfWorkerInStatus :many +SELECT tasks.id, tasks.created_at, tasks.updated_at, tasks.uuid, tasks.name, tasks.type, tasks.job_id, tasks.priority, tasks.status, tasks.worker_id, tasks.last_touched_at, tasks.commands, tasks.activity, jobs.UUID as jobUUID +FROM tasks +LEFT JOIN jobs ON (tasks.job_id = jobs.id) +WHERE tasks.worker_id = ?1 + AND tasks.status = ?2 +` + +type FetchTasksOfWorkerInStatusParams struct { + WorkerID sql.NullInt64 + TaskStatus string +} + +type FetchTasksOfWorkerInStatusRow struct { + Task Task + JobUUID sql.NullString +} + +func (q *Queries) FetchTasksOfWorkerInStatus(ctx context.Context, arg FetchTasksOfWorkerInStatusParams) ([]FetchTasksOfWorkerInStatusRow, error) { + rows, err := q.db.QueryContext(ctx, fetchTasksOfWorkerInStatus, arg.WorkerID, arg.TaskStatus) + if err != nil { + return nil, err + } + defer rows.Close() + var items []FetchTasksOfWorkerInStatusRow + for rows.Next() { + var i FetchTasksOfWorkerInStatusRow + if err := rows.Scan( + &i.Task.ID, + &i.Task.CreatedAt, + &i.Task.UpdatedAt, + &i.Task.UUID, + &i.Task.Name, + &i.Task.Type, + &i.Task.JobID, + &i.Task.Priority, + &i.Task.Status, + &i.Task.WorkerID, + &i.Task.LastTouchedAt, + &i.Task.Commands, + &i.Task.Activity, + &i.JobUUID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const requestJobDeletion = `-- name: RequestJobDeletion :exec UPDATE jobs SET updated_at = ?1, @@ -380,6 +436,24 @@ func (q *Queries) SaveJobStorageInfo(ctx context.Context, arg SaveJobStorageInfo return err } +const taskAssignToWorker = `-- name: TaskAssignToWorker :exec +UPDATE tasks SET + updated_at = ?1, + worker_id = ?2 +WHERE id=?3 +` + +type TaskAssignToWorkerParams struct { + UpdatedAt sql.NullTime + WorkerID sql.NullInt64 + ID int64 +} + +func (q *Queries) TaskAssignToWorker(ctx context.Context, arg TaskAssignToWorkerParams) error { + _, err := q.db.ExecContext(ctx, taskAssignToWorker, arg.UpdatedAt, arg.WorkerID, arg.ID) + return err +} + const updateTask = `-- name: UpdateTask :exec UPDATE tasks SET updated_at = ?1,