diff --git a/internal/manager/persistence/task_scheduler.go b/internal/manager/persistence/task_scheduler.go index b8227100..ac564c5e 100644 --- a/internal/manager/persistence/task_scheduler.go +++ b/internal/manager/persistence/task_scheduler.go @@ -13,8 +13,9 @@ import ( ) var ( - schedulableTaskStatuses = []api.TaskStatus{api.TaskStatusQueued, api.TaskStatusSoftFailed, api.TaskStatusActive} - schedulableJobStatuses = []api.JobStatus{api.JobStatusActive, api.JobStatusQueued, api.JobStatusRequeued} + // Note that active tasks are not schedulable, because they're already running on some worker. + schedulableTaskStatuses = []api.TaskStatus{api.TaskStatusQueued, api.TaskStatusSoftFailed} + schedulableJobStatuses = []api.JobStatus{api.JobStatusActive, api.JobStatusQueued} // completedTaskStatuses = []api.TaskStatus{api.TaskStatusCompleted} ) @@ -81,6 +82,25 @@ func (db *DB) ScheduleTask(ctx context.Context, w *Worker) (*Task, error) { func findTaskForWorker(tx *gorm.DB, w *Worker) (*Task, error) { task := Task{} + // If a task is already active & assigned to this worker, return just that. + // Note that this task type could be blacklisted or no longer supported by the + // Worker, but since it's active that is unlikely. + assignedTaskResult := tx. + Model(&task). + Joins("left join jobs on tasks.job_id = jobs.id"). + Where("tasks.status = ?", api.TaskStatusActive). + Where("jobs.status in ?", schedulableJobStatuses). + Where("tasks.worker_id = ?", w.ID). // assigned to this worker + Limit(1). + Preload("Job"). + Find(&task) + if assignedTaskResult.Error != nil { + return nil, assignedTaskResult.Error + } + if assignedTaskResult.RowsAffected > 0 { + return &task, nil + } + // Produce the 'current task ID' by selecting all its incomplete dependencies. // This can then be used in a subquery to filter out such tasks. // `tasks.id` is the task ID from the outer query. 
@@ -91,14 +111,17 @@ func findTaskForWorker(tx *gorm.DB, w *Worker) (*Task, error) { Where("tasks2.id = tasks.id"). Where("dep.status is not NULL and dep.status != ?", api.TaskStatusCompleted) + // Note that this query doesn't check for the assigned worker. Tasks that have + // a 'schedulable' status might have been assigned to a worker, representing + // the last worker to touch it -- it's not meant to indicate "ownership" of + // the task. findTaskResult := tx. Model(&task). Joins("left join jobs on tasks.job_id = jobs.id"). - Where("tasks.status in ?", schedulableTaskStatuses). // Schedulable task statuses - Where("jobs.status in ?", schedulableJobStatuses). // Schedulable job statuses - Where("tasks.type in ?", w.TaskTypes()). // Supported task types - Where("tasks.worker_id = ? or tasks.worker_id is NULL", w.ID). // assigned to this worker or not assigned at all - Where("tasks.id not in (?)", incompleteDepsQuery). // Dependencies completed + Where("tasks.status in ?", schedulableTaskStatuses). // Schedulable task statuses + Where("jobs.status in ?", schedulableJobStatuses). // Schedulable job statuses + Where("tasks.type in ?", w.TaskTypes()). // Supported task types + Where("tasks.id not in (?)", incompleteDepsQuery). // Dependencies completed // TODO: Non-blacklisted Order("jobs.priority desc"). // Highest job priority Order("tasks.priority desc"). 
// Highest task priority diff --git a/internal/manager/persistence/task_scheduler_test.go b/internal/manager/persistence/task_scheduler_test.go index b6982966..1b78b9ae 100644 --- a/internal/manager/persistence/task_scheduler_test.go +++ b/internal/manager/persistence/task_scheduler_test.go @@ -185,6 +185,78 @@ func TestSomeButNotAllDependenciesCompleted(t *testing.T) { } } +func TestAlreadyAssigned(t *testing.T) { + ctx, cancel, db := persistenceTestFixtures(t, schedulerTestTimeout) + defer cancel() + + w := linuxWorker(t, db) + + att1 := authorTestTask("1 low-prio task", "blender") + att2 := authorTestTask("2 high-prio task", "ffmpeg") + att2.Priority = 100 + att3 := authorTestTask("3 low-prio task", "blender") + atj := authorTestJob( + "1295757b-e668-4c49-8b89-f73db8270e42", + "simple-blender-render", + att1, att2, att3) + + constructTestJob(ctx, t, db, atj) + + // Assign the task to the worker, and mark it as Active. + // This should make it get returned by the scheduler, even when there is + // another, higher-prio task to be done. + dbTask3, err := db.FetchTask(ctx, att3.UUID) + assert.NoError(t, err) + dbTask3.WorkerID = &w.ID + dbTask3.Status = api.TaskStatusActive + err = db.SaveTask(ctx, dbTask3) + assert.NoError(t, err) + + task, err := db.ScheduleTask(ctx, &w) + assert.NoError(t, err) + if task == nil { + t.Fatal("task is nil") + } + + assert.Equal(t, att3.Name, task.Name, "the already-assigned task should have been chosen") +} + +func TestAssignedToOtherWorker(t *testing.T) { + ctx, cancel, db := persistenceTestFixtures(t, schedulerTestTimeout) + defer cancel() + + w := linuxWorker(t, db) + w2 := windowsWorker(t, db) + + att1 := authorTestTask("1 low-prio task", "blender") + att2 := authorTestTask("2 high-prio task", "ffmpeg") + att2.Priority = 100 + atj := authorTestJob( + "1295757b-e668-4c49-8b89-f73db8270e42", + "simple-blender-render", + att1, att2) + + constructTestJob(ctx, t, db, atj) + + // Assign the high-prio task to the other worker. 
Because the task is queued, + // it shouldn't matter which worker it's assigned to. + dbTask2, err := db.FetchTask(ctx, att2.UUID) + assert.NoError(t, err) + dbTask2.WorkerID = &w2.ID + dbTask2.Status = api.TaskStatusQueued + err = db.SaveTask(ctx, dbTask2) + assert.NoError(t, err) + + task, err := db.ScheduleTask(ctx, &w) + assert.NoError(t, err) + if task == nil { + t.Fatal("task is nil") + } + + assert.Equal(t, att2.Name, task.Name, "the high-prio task should have been chosen") + assert.Equal(t, *task.WorkerID, w.ID, "the task should now be assigned to the worker it was scheduled for") // NOTE(review): testify's Equal takes (expected, actual), so w.ID should probably come first; this also dereferences task.WorkerID without a nil check -- confirm +} + // To test: blacklists // To test: variable replacement @@ -269,3 +341,21 @@ func linuxWorker(t *testing.T, db *DB) Worker { return w } + +func windowsWorker(t *testing.T, db *DB) Worker { + w := Worker{ + UUID: "4f6ee45e-c8fc-4c31-bf5c-922f2415deb1", + Name: "Windows", + Platform: "windows", + Status: api.WorkerStatusAwake, + SupportedTaskTypes: "blender,ffmpeg,file-management,misc", + } + + err := db.gormDB.Save(&w).Error + if err != nil { + t.Logf("cannot save Windows worker: %v", err) + t.FailNow() + } + + return w +}