diff --git a/internal/manager/persistence/task_scheduler.go b/internal/manager/persistence/task_scheduler.go
index cf4a0b25..46378c99 100644
--- a/internal/manager/persistence/task_scheduler.go
+++ b/internal/manager/persistence/task_scheduler.go
@@ -71,16 +71,25 @@ func (db *DB) ScheduleTask(ctx context.Context, w *Worker) (*Task, error) {
 
 func findTaskForWorker(tx *gorm.DB, w *Worker) (*Task, error) {
 	task := Task{}
+
+	// Correlated subquery: returns the outer task's ID once for every dependency of
+	// that task whose status differs from api.TaskStatusCompleted; no rows means none.
+	// `tasks.id` is the task ID from the outer query that embeds this subquery.
+	incompleteDepsQuery := tx.Table("tasks as tasks2").
+		Select("tasks2.id").
+		Joins("left join task_dependencies td on tasks2.id = td.task_id").
+		Joins("left join tasks dep on dep.id = td.dependency_id").
+		Where("tasks2.id = tasks.id").
+		Where("dep.status is not NULL and dep.status != ?", api.TaskStatusCompleted)
+
 	findTaskResult := tx.
 		Model(&task).
 		Joins("left join jobs on tasks.job_id = jobs.id").
-		Joins("left join task_dependencies on tasks.id = task_dependencies.task_id").
-		Joins("left join tasks as tdeps on tdeps.id = task_dependencies.dependency_id").
-		Where("tasks.status in ?", schedulableTaskStatuses).                       // Schedulable task statuses
-		Where("tdeps.status in ? or tdeps.status is NULL", completedTaskStatuses). // Dependencies completed
-		Where("jobs.status in ?", schedulableJobStatuses).                         // Schedulable job statuses
-		Where("tasks.type in ?", w.TaskTypes()).                                   // Supported task types
-		Where("tasks.worker_id = ? or tasks.worker_id is NULL", w.ID).             // assigned to this worker or not assigned at all
+		Where("tasks.status in ?", schedulableTaskStatuses).           // Schedulable task statuses
+		Where("jobs.status in ?", schedulableJobStatuses).             // Schedulable job statuses
+		Where("tasks.type in ?", w.TaskTypes()).                       // Task types reported by w.TaskTypes()
+		Where("tasks.worker_id = ? or tasks.worker_id is NULL", w.ID). // Assigned to this worker or not assigned at all
+		Where("tasks.id not in (?)", incompleteDepsQuery).             // No dependency left incomplete (subquery above)
 		// TODO: Non-blacklisted
 		Order("jobs.priority desc").  // Highest job priority
 		Order("tasks.priority desc"). // Highest task priority
diff --git a/internal/manager/persistence/task_scheduler_test.go b/internal/manager/persistence/task_scheduler_test.go
index 7c1d14ac..b6982966 100644
--- a/internal/manager/persistence/task_scheduler_test.go
+++ b/internal/manager/persistence/task_scheduler_test.go
@@ -159,6 +159,32 @@ func TestTwoJobsThreeTasks(t *testing.T) {
 	assert.Equal(t, att2_3.Name, task.Name, "the 3rd task of the 2nd job should have been chosen")
 }
 
+func TestSomeButNotAllDependenciesCompleted(t *testing.T) {
+	// Regression test: the task scheduler query used to schedule a task as soon
+	// as any one of its dependencies was completed, instead of requiring all of them.
+	// Task 1.3 below depends on both 1.1 (marked completed) and 1.2 (left as-is).
+	ctx, cancel, db := persistenceTestFixtures(t, schedulerTestTimeout)
+	defer cancel()
+
+	att1 := authorTestTask("1.1 completed task", "blender")
+	att2 := authorTestTask("1.2 queued task of unsupported type", "unsupported")
+	att3 := authorTestTask("1.3 queued task with queued dependency", "ffmpeg")
+	att3.Dependencies = []*job_compilers.AuthoredTask{&att1, &att2}
+
+	atj := authorTestJob("1295757b-e668-4c49-8b89-f73db8270e42", "simple-blender-render", att1, att2, att3)
+	constructTestJob(ctx, t, db, atj)
+
+	// Complete only the first task; the other two are still `queued`.
+	setTaskStatus(t, db, att1.UUID, api.TaskStatusCompleted)
+
+	w := linuxWorker(t, db)
+	task, err := db.ScheduleTask(ctx, &w)
+	assert.NoError(t, err)
+	if task != nil { // any returned task means the dependency filter let something through
+		t.Fatalf("there should not be any task assigned, but received %q", task.Name)
+	}
+}
+
 // To test: blacklists
 // To test: variable replacement
 
@@ -211,6 +237,21 @@ func authorTestTask(name, taskType string, dependencies ...*job_compilers.Author
 	return task
 }
 
+func setTaskStatus(t *testing.T, db *DB, taskUUID string, status api.TaskStatus) { // Test helper: fetch the task by UUID, set its status, save it.
+	ctx := context.Background()
+	task, err := db.FetchTask(ctx, taskUUID)
+	if err != nil {
+		t.Fatal(err) // fetching the task failed
+	}
+
+	task.Status = status
+
+	err = db.SaveTask(ctx, task)
+	if err != nil {
+		t.Fatal(err) // saving the updated task failed
+	}
+}
+
 func linuxWorker(t *testing.T, db *DB) Worker {
 	w := Worker{
 		UUID: "b13b8322-3e96-41c3-940a-3d581008a5f8",