From 67bf77de13d99b1bc5d7344951068822c4fadd88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sybren=20A=2E=20St=C3=BCvel?= Date: Tue, 3 May 2022 16:06:15 +0200 Subject: [PATCH] Manager: rework mass updates to task statuses When the job status changes, it impacts the task statuses as well. These status changes are now no longer done with a single database query, but instead each affected task is fetched, changed, and saved. This unifies the regular & mass updates to the tasks, and causes the resulting task changes to be broadcast to SocketIO clients. --- internal/manager/persistence/jobs.go | 56 +++--- internal/manager/persistence/jobs_test.go | 79 ++------ .../mocks/interfaces_mock.gen.go | 63 +++--- .../task_state_machine/task_state_machine.go | 87 ++++++--- .../task_state_machine_test.go | 179 +++++++++++++----- 5 files changed, 274 insertions(+), 190 deletions(-) diff --git a/internal/manager/persistence/jobs.go b/internal/manager/persistence/jobs.go index 4b1972a7..7b551ad6 100644 --- a/internal/manager/persistence/jobs.go +++ b/internal/manager/persistence/jobs.go @@ -285,41 +285,39 @@ func (db *DB) CountTasksOfJobInStatus(ctx context.Context, job *Job, taskStatus return } -// UpdateJobsTaskStatuses updates the status & activity of all tasks of `job`. -func (db *DB) UpdateJobsTaskStatuses(ctx context.Context, job *Job, - taskStatus api.TaskStatus, activity string) error { - - if taskStatus == "" { - return taskError(nil, "empty status not allowed") - } - +// FetchTaskIDsOfJob returns all tasks of the given job. +func (db *DB) FetchTasksOfJob(ctx context.Context, job *Job) ([]*Task, error) { + var tasks []*Task tx := db.gormDB.WithContext(ctx). - Model(Task{}). - Where("job_Id = ?", job.ID). - Updates(Task{Status: taskStatus, Activity: activity}) - + Model(&Task{}). + Where("job_id", job.ID). + Scan(&tasks) if tx.Error != nil { - return taskError(tx.Error, "updating status of all tasks of job %s", job.UUID) + return nil, taskError(tx.Error, "fetching tasks of job %s", job.UUID) } - return nil + + for i := range tasks { + tasks[i].Job = job + } + + return tasks, nil } -// UpdateJobsTaskStatusesConditional updates the status & activity of the tasks of `job`, -// limited to those tasks with status in `statusesToUpdate`. -func (db *DB) UpdateJobsTaskStatusesConditional(ctx context.Context, job *Job, - statusesToUpdate []api.TaskStatus, taskStatus api.TaskStatus, activity string) error { - - if taskStatus == "" { - return taskError(nil, "empty status not allowed") - } - +// FetchTasksOfJobInStatus returns those tasks of the given job that have any of the given statuses. +func (db *DB) FetchTasksOfJobInStatus(ctx context.Context, job *Job, taskStatuses ...api.TaskStatus) ([]*Task, error) { + var tasks []*Task tx := db.gormDB.WithContext(ctx). - Model(Task{}). - Where("job_Id = ?", job.ID). - Where("status in ?", statusesToUpdate). - Updates(Task{Status: taskStatus, Activity: activity}) + Model(&Task{}). + Where("job_id", job.ID). + Where("status in ?", taskStatuses). + Scan(&tasks) if tx.Error != nil { - return taskError(tx.Error, "updating status of all tasks in status %v of job %s", statusesToUpdate, job.UUID) + return nil, taskError(tx.Error, "fetching tasks of job %s in status %q", job.UUID, taskStatuses) } - return nil + + for i := range tasks { + tasks[i].Job = job + } + + return tasks, nil } diff --git a/internal/manager/persistence/jobs_test.go b/internal/manager/persistence/jobs_test.go index 3e5c921d..b3f634c7 100644 --- a/internal/manager/persistence/jobs_test.go +++ b/internal/manager/persistence/jobs_test.go @@ -97,78 +97,41 @@ func TestCountTasksOfJobInStatus(t *testing.T) { assert.Equal(t, 3, numTotal) } -func TestUpdateJobsTaskStatuses(t *testing.T) { +func TestFetchTasksOfJobInStatus(t *testing.T) { ctx, close, db, job, authoredJob := jobTasksTestFixtures(t) defer close() - err := db.UpdateJobsTaskStatuses(ctx, job, api.TaskStatusSoftFailed, "testing æctivity") - assert.NoError(t, err) + allTasks, err := db.FetchTasksOfJob(ctx, job) + if !assert.NoError(t, err) { + return + } + assert.Equal(t, job, allTasks[0].Job, "FetchTasksOfJob should set job pointer") - numSoftFailed, numTotal, err := db.CountTasksOfJobInStatus(ctx, job, api.TaskStatusSoftFailed) + tasks, err := db.FetchTasksOfJobInStatus(ctx, job, api.TaskStatusQueued) assert.NoError(t, err) - assert.Equal(t, 3, numSoftFailed, "all tasks should have had their status changed") - assert.Equal(t, 3, numTotal) + assert.Equal(t, allTasks, tasks) + assert.Equal(t, job, tasks[0].Job, "FetchTasksOfJobInStatus should set job pointer") + // Make one task failed. task, err := db.FetchTask(ctx, authoredJob.Tasks[0].UUID) assert.NoError(t, err) - assert.Equal(t, "testing æctivity", task.Activity) + task.Status = api.TaskStatusFailed + assert.NoError(t, db.SaveTask(ctx, task)) - // Empty status should be rejected. - err = db.UpdateJobsTaskStatuses(ctx, job, "", "testing empty status") - assert.Error(t, err) - - numEmpty, _, err := db.CountTasksOfJobInStatus(ctx, job, "") + tasks, err = db.FetchTasksOfJobInStatus(ctx, job, api.TaskStatusQueued) assert.NoError(t, err) - assert.Equal(t, 0, numEmpty, "tasks should not have their status changed") + assert.Equal(t, []*Task{allTasks[1], allTasks[2]}, tasks) - numSoftFailed, _, err = db.CountTasksOfJobInStatus(ctx, job, api.TaskStatusSoftFailed) + // Check the failed task. This cannot directly compare to `allTasks[0]` + // because saving the task above changed some of its fields. + tasks, err = db.FetchTasksOfJobInStatus(ctx, job, api.TaskStatusFailed) assert.NoError(t, err) - assert.Equal(t, 3, numSoftFailed, "all tasks should still be soft-failed") -} + assert.Len(t, tasks, 1) + assert.Equal(t, allTasks[0].ID, tasks[0].ID) -func TestUpdateJobsTaskStatusesConditional(t *testing.T) { - ctx, close, db, job, authoredJob := jobTasksTestFixtures(t) - defer close() - - getTask := func(taskIndex int) *Task { - task, err := db.FetchTask(ctx, authoredJob.Tasks[taskIndex].UUID) - if err != nil { - t.Fatalf("Fetching task %d: %v", taskIndex, err) - } - return task - } - - setTaskStatus := func(taskIndex int, taskStatus api.TaskStatus) { - task := getTask(taskIndex) - task.Status = taskStatus - if err := db.SaveTask(ctx, task); err != nil { - t.Fatalf("Setting task %d to status %s: %v", taskIndex, taskStatus, err) - } - } - - setTaskStatus(0, api.TaskStatusFailed) - setTaskStatus(1, api.TaskStatusCompleted) - setTaskStatus(2, api.TaskStatusActive) - - err := db.UpdateJobsTaskStatusesConditional(ctx, job, - []api.TaskStatus{api.TaskStatusFailed, api.TaskStatusActive}, - api.TaskStatusCancelRequested, "some activity") + tasks, err = db.FetchTasksOfJobInStatus(ctx, job, api.TaskStatusActive) assert.NoError(t, err) - - // Task statuses should have updated for tasks 0 and 2. - assert.Equal(t, api.TaskStatusCancelRequested, getTask(0).Status) - assert.Equal(t, api.TaskStatusCompleted, getTask(1).Status) - assert.Equal(t, api.TaskStatusCancelRequested, getTask(2).Status) - - err = db.UpdateJobsTaskStatusesConditional(ctx, job, - []api.TaskStatus{api.TaskStatusFailed, api.TaskStatusActive}, - "", "empty task status should be disallowed") - assert.Error(t, err) - - // Task statuses should remain unchanged. - assert.Equal(t, api.TaskStatusCancelRequested, getTask(0).Status) - assert.Equal(t, api.TaskStatusCompleted, getTask(1).Status) - assert.Equal(t, api.TaskStatusCancelRequested, getTask(2).Status) + assert.Empty(t, tasks) } func TestTaskAssignToWorker(t *testing.T) { diff --git a/internal/manager/task_state_machine/mocks/interfaces_mock.gen.go b/internal/manager/task_state_machine/mocks/interfaces_mock.gen.go index 5b1ef49e..0d0abef3 100644 --- a/internal/manager/task_state_machine/mocks/interfaces_mock.gen.go +++ b/internal/manager/task_state_machine/mocks/interfaces_mock.gen.go @@ -52,6 +52,41 @@ func (mr *MockPersistenceServiceMockRecorder) CountTasksOfJobInStatus(arg0, arg1 return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountTasksOfJobInStatus", reflect.TypeOf((*MockPersistenceService)(nil).CountTasksOfJobInStatus), arg0, arg1, arg2) } +// FetchTasksOfJob mocks base method. +func (m *MockPersistenceService) FetchTasksOfJob(arg0 context.Context, arg1 *persistence.Job) ([]*persistence.Task, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FetchTasksOfJob", arg0, arg1) + ret0, _ := ret[0].([]*persistence.Task) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FetchTasksOfJob indicates an expected call of FetchTasksOfJob. +func (mr *MockPersistenceServiceMockRecorder) FetchTasksOfJob(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchTasksOfJob", reflect.TypeOf((*MockPersistenceService)(nil).FetchTasksOfJob), arg0, arg1) +} + +// FetchTasksOfJobInStatus mocks base method. +func (m *MockPersistenceService) FetchTasksOfJobInStatus(arg0 context.Context, arg1 *persistence.Job, arg2 ...api.TaskStatus) ([]*persistence.Task, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "FetchTasksOfJobInStatus", varargs...) + ret0, _ := ret[0].([]*persistence.Task) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FetchTasksOfJobInStatus indicates an expected call of FetchTasksOfJobInStatus. +func (mr *MockPersistenceServiceMockRecorder) FetchTasksOfJobInStatus(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchTasksOfJobInStatus", reflect.TypeOf((*MockPersistenceService)(nil).FetchTasksOfJobInStatus), varargs...) +} + // JobHasTasksInStatus mocks base method. func (m *MockPersistenceService) JobHasTasksInStatus(arg0 context.Context, arg1 *persistence.Job, arg2 api.TaskStatus) (bool, error) { m.ctrl.T.Helper() @@ -95,34 +130,6 @@ func (mr *MockPersistenceServiceMockRecorder) SaveTask(arg0, arg1 interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveTask", reflect.TypeOf((*MockPersistenceService)(nil).SaveTask), arg0, arg1) } -// UpdateJobsTaskStatuses mocks base method. -func (m *MockPersistenceService) UpdateJobsTaskStatuses(arg0 context.Context, arg1 *persistence.Job, arg2 api.TaskStatus, arg3 string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateJobsTaskStatuses", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(error) - return ret0 -} - -// UpdateJobsTaskStatuses indicates an expected call of UpdateJobsTaskStatuses. -func (mr *MockPersistenceServiceMockRecorder) UpdateJobsTaskStatuses(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateJobsTaskStatuses", reflect.TypeOf((*MockPersistenceService)(nil).UpdateJobsTaskStatuses), arg0, arg1, arg2, arg3) -} - -// UpdateJobsTaskStatusesConditional mocks base method. -func (m *MockPersistenceService) UpdateJobsTaskStatusesConditional(arg0 context.Context, arg1 *persistence.Job, arg2 []api.TaskStatus, arg3 api.TaskStatus, arg4 string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateJobsTaskStatusesConditional", arg0, arg1, arg2, arg3, arg4) - ret0, _ := ret[0].(error) - return ret0 -} - -// UpdateJobsTaskStatusesConditional indicates an expected call of UpdateJobsTaskStatusesConditional. -func (mr *MockPersistenceServiceMockRecorder) UpdateJobsTaskStatusesConditional(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateJobsTaskStatusesConditional", reflect.TypeOf((*MockPersistenceService)(nil).UpdateJobsTaskStatusesConditional), arg0, arg1, arg2, arg3, arg4) -} - // MockChangeBroadcaster is a mock of ChangeBroadcaster interface. type MockChangeBroadcaster struct { ctrl *gomock.Controller diff --git a/internal/manager/task_state_machine/task_state_machine.go b/internal/manager/task_state_machine/task_state_machine.go index bdf81285..e53d5bef 100644 --- a/internal/manager/task_state_machine/task_state_machine.go +++ b/internal/manager/task_state_machine/task_state_machine.go @@ -34,14 +34,8 @@ type PersistenceService interface { JobHasTasksInStatus(ctx context.Context, job *persistence.Job, taskStatus api.TaskStatus) (bool, error) CountTasksOfJobInStatus(ctx context.Context, job *persistence.Job, taskStatus api.TaskStatus) (numInStatus, numTotal int, err error) - // UpdateJobsTaskStatuses updates the status & activity of the tasks of `job`. - UpdateJobsTaskStatuses(ctx context.Context, job *persistence.Job, - taskStatus api.TaskStatus, activity string) error - - // UpdateJobsTaskStatusesConditional updates the status & activity of the tasks of `job`, - // limited to those tasks with status in `statusesToUpdate`. - UpdateJobsTaskStatusesConditional(ctx context.Context, job *persistence.Job, - statusesToUpdate []api.TaskStatus, taskStatus api.TaskStatus, activity string) error + FetchTasksOfJob(ctx context.Context, job *persistence.Job) ([]*persistence.Task, error) + FetchTasksOfJobInStatus(ctx context.Context, job *persistence.Job, taskStatuses ...api.TaskStatus) ([]*persistence.Task, error) } // PersistenceService should be a subset of persistence.DB @@ -71,6 +65,25 @@ func (sm *StateMachine) TaskStatusChange( ctx context.Context, task *persistence.Task, newTaskStatus api.TaskStatus, +) error { + oldTaskStatus := task.Status + + if err := sm.taskStatusChangeOnly(ctx, task, newTaskStatus); err != nil { + return err + } + + if err := sm.updateJobAfterTaskStatusChange(ctx, task, oldTaskStatus); err != nil { + return fmt.Errorf("updating job after task status change: %w", err) + } + return nil +} + +// taskStatusChangeOnly updates the task's status to the new one, but does not "ripple" the change to the job. +// `task` is expected to still have its original status, and have a filled `Job` pointer. +func (sm *StateMachine) taskStatusChangeOnly( + ctx context.Context, + task *persistence.Task, + newTaskStatus api.TaskStatus, ) error { job := task.Job if job == nil { @@ -98,9 +111,6 @@ func (sm *StateMachine) TaskStatusChange( taskUpdate.PreviousStatus = &oldTaskStatus sm.broadcaster.BroadcastTaskUpdate(taskUpdate) - if err := sm.updateJobAfterTaskStatusChange(ctx, task, oldTaskStatus); err != nil { - return fmt.Errorf("updating job after task status change: %w", err) - } return nil } @@ -333,15 +343,16 @@ func (sm *StateMachine) cancelTasks( logger.Info().Msg("cancelling tasks of job") // Any task that is running or might run in the future should get cancelled. - taskStatusesToCancel := []api.TaskStatus{ + tasks, err := sm.persist.FetchTasksOfJobInStatus(ctx, job, api.TaskStatusActive, api.TaskStatusQueued, api.TaskStatusSoftFailed, - } - err := sm.persist.UpdateJobsTaskStatusesConditional( - ctx, job, taskStatusesToCancel, api.TaskStatusCanceled, - fmt.Sprintf("Manager cancelled this task because the job got status %q.", job.Status), ) + if err != nil { + return "", err + } + activity := fmt.Sprintf("Manager cancelled this task because the job got status %q.", job.Status) + err = sm.massUpdateTaskStatus(ctx, tasks, api.TaskStatusCanceled, activity) if err != nil { return "", fmt.Errorf("cancelling tasks of job %s: %w", job.UUID, err) } @@ -366,12 +377,13 @@ func (sm *StateMachine) cancelTasks( func (sm *StateMachine) requeueTasks( ctx context.Context, logger zerolog.Logger, job *persistence.Job, oldJobStatus api.JobStatus, ) (api.JobStatus, error) { - var err error - if job.Status != api.JobStatusRequeued { logger.Warn().Msg("unexpected job status in StateMachine::requeueTasks()") } + var err error + var tasks []*persistence.Task + switch oldJobStatus { case api.JobStatusUnderConstruction: // Nothing to do, the job compiler has just finished its work; the tasks have @@ -380,24 +392,25 @@ func (sm *StateMachine) requeueTasks( return "", nil case api.JobStatusCompleted: // Re-queue all tasks. - err = sm.persist.UpdateJobsTaskStatuses(ctx, job, api.TaskStatusQueued, - fmt.Sprintf("Queued because job transitioned status from %q to %q", oldJobStatus, job.Status)) + tasks, err = sm.persist.FetchTasksOfJob(ctx, job) default: // Re-queue only the non-completed tasks. - statusesToUpdate := []api.TaskStatus{ + tasks, err = sm.persist.FetchTasksOfJobInStatus(ctx, job, api.TaskStatusCancelRequested, api.TaskStatusCanceled, api.TaskStatusFailed, api.TaskStatusPaused, api.TaskStatusSoftFailed, - } - err = sm.persist.UpdateJobsTaskStatusesConditional(ctx, job, - statusesToUpdate, api.TaskStatusQueued, - fmt.Sprintf("Queued because job transitioned status from %q to %q", oldJobStatus, job.Status)) + ) + } + if err != nil { + return "", err } + activity := fmt.Sprintf("Queued because job transitioned status from %q to %q", oldJobStatus, job.Status) + err = sm.massUpdateTaskStatus(ctx, tasks, api.TaskStatusQueued, activity) if err != nil { - return "", fmt.Errorf("queueing tasks of job %s: %w", job.UUID, err) + return "", err } // TODO: also reset the 'failed by workers' blacklist. @@ -406,6 +419,28 @@ func (sm *StateMachine) requeueTasks( return api.JobStatusQueued, nil } +// massUpdateTaskStatus updates the status of all the given tasks. +// NOTE: these task statuses do NOT affect the job status. +// Tasks that are passed in the `tasks` parameter but already have the given status will be skipped. +func (sm *StateMachine) massUpdateTaskStatus( + ctx context.Context, + tasks []*persistence.Task, + status api.TaskStatus, + activity string, +) error { + for _, task := range tasks { + if task.Status == status { + continue + } + task.Activity = activity + err := sm.taskStatusChangeOnly(ctx, task, status) + if err != nil { + return err + } + } + return nil +} + // checkTaskCompletion returns "completed" as next job status when all tasks of // the job are completed. // diff --git a/internal/manager/task_state_machine/task_state_machine_test.go b/internal/manager/task_state_machine/task_state_machine_test.go index 27ea6880..c27e149f 100644 --- a/internal/manager/task_state_machine/task_state_machine_test.go +++ b/internal/manager/task_state_machine/task_state_machine_test.go @@ -34,6 +34,7 @@ func TestTaskStatusChangeQueuedToActive(t *testing.T) { mocks.expectSaveTaskWithStatus(t, task, api.TaskStatusActive) mocks.expectSaveJobWithStatus(t, task.Job, api.JobStatusActive) mocks.expectBroadcastJobChange(task.Job, api.JobStatusQueued, api.JobStatusActive) + mocks.expectBroadcastTaskChange(task, api.TaskStatusQueued, api.TaskStatusActive) assert.NoError(t, sm.TaskStatusChange(ctx, task, api.TaskStatusActive)) } @@ -52,6 +53,7 @@ func TestTaskStatusChangeSaveTaskAfterJobChangeFailure(t *testing.T) { // Expect a call to save the task in the persistence layer, regardless of the above error. mocks.expectSaveTaskWithStatus(t, task, api.TaskStatusActive) + mocks.expectBroadcastTaskChange(task, api.TaskStatusQueued, api.TaskStatusActive) returnedErr := sm.TaskStatusChange(ctx, task, api.TaskStatusActive) assert.ErrorIs(t, returnedErr, jobSaveErr, "the returned error should wrap the persistence layer error") @@ -68,20 +70,24 @@ func TestTaskStatusChangeActiveToCompleted(t *testing.T) { // First task completing: T: active > completed --> J: active > active mocks.expectSaveTaskWithStatus(t, task, api.TaskStatusCompleted) + mocks.expectBroadcastTaskChange(task, api.TaskStatusActive, api.TaskStatusCompleted) mocks.persist.EXPECT().CountTasksOfJobInStatus(ctx, task.Job, api.TaskStatusCompleted).Return(1, 3, nil) // 1 of 3 complete. assert.NoError(t, sm.TaskStatusChange(ctx, task, api.TaskStatusCompleted)) // Second task hickup: T: active > soft-failed --> J: active > active mocks.expectSaveTaskWithStatus(t, task2, api.TaskStatusSoftFailed) + mocks.expectBroadcastTaskChange(task2, api.TaskStatusActive, api.TaskStatusSoftFailed) assert.NoError(t, sm.TaskStatusChange(ctx, task2, api.TaskStatusSoftFailed)) // Second task completing: T: soft-failed > completed --> J: active > active mocks.expectSaveTaskWithStatus(t, task2, api.TaskStatusCompleted) + mocks.expectBroadcastTaskChange(task2, api.TaskStatusSoftFailed, api.TaskStatusCompleted) mocks.persist.EXPECT().CountTasksOfJobInStatus(ctx, task.Job, api.TaskStatusCompleted).Return(2, 3, nil) // 2 of 3 complete. assert.NoError(t, sm.TaskStatusChange(ctx, task2, api.TaskStatusCompleted)) // Third task completing: T: active > completed --> J: active > completed mocks.expectSaveTaskWithStatus(t, task3, api.TaskStatusCompleted) + mocks.expectBroadcastTaskChange(task3, api.TaskStatusActive, api.TaskStatusCompleted) mocks.persist.EXPECT().CountTasksOfJobInStatus(ctx, task.Job, api.TaskStatusCompleted).Return(3, 3, nil) // 3 of 3 complete. mocks.expectSaveJobWithStatus(t, task.Job, api.JobStatusCompleted) mocks.expectBroadcastJobChange(task.Job, api.JobStatusActive, api.JobStatusCompleted) @@ -96,6 +102,7 @@ func TestTaskStatusChangeQueuedToFailed(t *testing.T) { // T: queued > failed (1% task failure) --> J: queued > active task := taskWithStatus(api.JobStatusQueued, api.TaskStatusQueued) mocks.expectSaveTaskWithStatus(t, task, api.TaskStatusFailed) + mocks.expectBroadcastTaskChange(task, api.TaskStatusQueued, api.TaskStatusFailed) mocks.expectSaveJobWithStatus(t, task.Job, api.JobStatusActive) mocks.persist.EXPECT().CountTasksOfJobInStatus(ctx, task.Job, api.TaskStatusFailed).Return(1, 100, nil) // 1 out of 100 failed. mocks.expectBroadcastJobChange(task.Job, api.JobStatusQueued, api.JobStatusActive) @@ -107,25 +114,32 @@ func TestTaskStatusChangeActiveToFailedFailJob(t *testing.T) { mockCtrl, ctx, sm, mocks := taskStateMachineTestFixtures(t) defer mockCtrl.Finish() - // T: active > failed (10% task failure) --> J: active > failed + cancellation of any runnable tasks. - task := taskWithStatus(api.JobStatusActive, api.TaskStatusActive) - mocks.expectSaveTaskWithStatus(t, task, api.TaskStatusFailed) - mocks.expectSaveJobWithStatus(t, task.Job, api.JobStatusFailed) - mocks.expectBroadcastJobChange(task.Job, api.JobStatusActive, api.JobStatusFailed) + // T: active > failed (10% task1 failure) --> J: active > failed + cancellation of any runnable tasks. + task1 := taskWithStatus(api.JobStatusActive, api.TaskStatusActive) + task2 := taskOfSameJob(task1, api.TaskStatusFailed) + task3 := taskOfSameJob(task2, api.TaskStatusSoftFailed) + remainingTasks := []*persistence.Task{task2, task3} - mocks.persist.EXPECT().CountTasksOfJobInStatus(ctx, task.Job, api.TaskStatusFailed).Return(10, 100, nil) // 10 out of 100 failed. + mocks.expectSaveTaskWithStatus(t, task1, api.TaskStatusFailed) + mocks.expectBroadcastTaskChange(task1, api.TaskStatusActive, api.TaskStatusFailed) + mocks.expectSaveJobWithStatus(t, task1.Job, api.JobStatusFailed) + mocks.expectBroadcastJobChange(task1.Job, api.JobStatusActive, api.JobStatusFailed) + + mocks.persist.EXPECT().CountTasksOfJobInStatus(ctx, task1.Job, api.TaskStatusFailed).Return(10, 100, nil) // 10 out of 100 failed. // Expect failure of the job to trigger cancellation of remaining tasks. - taskStatusesToCancel := []api.TaskStatus{ + mocks.persist.EXPECT().FetchTasksOfJobInStatus(ctx, task1.Job, api.TaskStatusActive, api.TaskStatusQueued, api.TaskStatusSoftFailed, - } - mocks.persist.EXPECT().UpdateJobsTaskStatusesConditional(ctx, task.Job, taskStatusesToCancel, api.TaskStatusCanceled, - "Manager cancelled this task because the job got status \"failed\".", - ) + ).Return(remainingTasks, nil) + mocks.expectSaveTaskWithStatus(t, task2, api.TaskStatusCanceled) + mocks.expectSaveTaskWithStatus(t, task3, api.TaskStatusCanceled) - assert.NoError(t, sm.TaskStatusChange(ctx, task, api.TaskStatusFailed)) + mocks.expectBroadcastTaskChange(task2, api.TaskStatusFailed, api.TaskStatusCanceled) + mocks.expectBroadcastTaskChange(task3, api.TaskStatusSoftFailed, api.TaskStatusCanceled) + + assert.NoError(t, sm.TaskStatusChange(ctx, task1, api.TaskStatusFailed)) } func TestTaskStatusChangeRequeueOnCompletedJob(t *testing.T) { @@ -133,21 +147,30 @@ func TestTaskStatusChangeRequeueOnCompletedJob(t *testing.T) { defer mockCtrl.Finish() // T: completed > queued --> J: completed > requeued > queued - task := taskWithStatus(api.JobStatusCompleted, api.TaskStatusCompleted) - mocks.expectSaveTaskWithStatus(t, task, api.TaskStatusQueued) - mocks.expectSaveJobWithStatus(t, task.Job, api.JobStatusRequeued) - mocks.expectBroadcastJobChange(task.Job, api.JobStatusCompleted, api.JobStatusRequeued) - mocks.expectBroadcastJobChange(task.Job, api.JobStatusRequeued, api.JobStatusQueued) + task1 := taskWithStatus(api.JobStatusCompleted, api.TaskStatusCompleted) + task2 := taskOfSameJob(task1, api.TaskStatusCompleted) + task3 := taskOfSameJob(task2, api.TaskStatusCompleted) + allTaskIDs := []*persistence.Task{task1, task2, task3} + + mocks.expectSaveTaskWithStatus(t, task1, api.TaskStatusQueued) + mocks.expectBroadcastTaskChange(task1, api.TaskStatusCompleted, api.TaskStatusQueued) + mocks.expectSaveJobWithStatus(t, task1.Job, api.JobStatusRequeued) + mocks.expectBroadcastJobChange(task1.Job, api.JobStatusCompleted, api.JobStatusRequeued) + mocks.expectBroadcastJobChange(task1.Job, api.JobStatusRequeued, api.JobStatusQueued) // Expect queueing of the job to trigger queueing of all its tasks, if those tasks were all completed before. // 2 out of 3 completed, because one was just queued. - mocks.persist.EXPECT().CountTasksOfJobInStatus(ctx, task.Job, api.TaskStatusCompleted).Return(2, 3, nil) - mocks.persist.EXPECT().UpdateJobsTaskStatuses(ctx, task.Job, api.TaskStatusQueued, - "Queued because job transitioned status from \"completed\" to \"requeued\"", - ) - mocks.expectSaveJobWithStatus(t, task.Job, api.JobStatusQueued) + mocks.persist.EXPECT().CountTasksOfJobInStatus(ctx, task1.Job, api.TaskStatusCompleted).Return(2, 3, nil) + fetchCall := mocks.persist.EXPECT().FetchTasksOfJob(ctx, task1.Job).Return(allTaskIDs, nil) + mocks.expectSaveTaskWithStatus(t, task2, api.TaskStatusQueued).After(fetchCall) + mocks.expectSaveTaskWithStatus(t, task3, api.TaskStatusQueued).After(fetchCall) - assert.NoError(t, sm.TaskStatusChange(ctx, task, api.TaskStatusQueued)) + mocks.expectBroadcastTaskChange(task2, api.TaskStatusCompleted, api.TaskStatusQueued) + mocks.expectBroadcastTaskChange(task3, api.TaskStatusCompleted, api.TaskStatusQueued) + + mocks.expectSaveJobWithStatus(t, task1.Job, api.JobStatusQueued) + + assert.NoError(t, sm.TaskStatusChange(ctx, task1, api.TaskStatusQueued)) } func TestTaskStatusChangeCancelSingleTask(t *testing.T) { @@ -160,11 +183,13 @@ func TestTaskStatusChangeCancelSingleTask(t *testing.T) { // T1: cancel-requested > cancelled --> J: cancel-requested > cancel-requested mocks.expectSaveTaskWithStatus(t, task, api.TaskStatusCanceled) + mocks.expectBroadcastTaskChange(task, api.TaskStatusCancelRequested, api.TaskStatusCanceled) mocks.persist.EXPECT().JobHasTasksInStatus(ctx, job, api.TaskStatusCancelRequested).Return(true, nil) assert.NoError(t, sm.TaskStatusChange(ctx, task, api.TaskStatusCanceled)) // T2: cancel-requested > cancelled --> J: cancel-requested > canceled mocks.expectSaveTaskWithStatus(t, task2, api.TaskStatusCanceled) + mocks.expectBroadcastTaskChange(task2, api.TaskStatusCancelRequested, api.TaskStatusCanceled) mocks.persist.EXPECT().JobHasTasksInStatus(ctx, job, api.TaskStatusCancelRequested).Return(false, nil) mocks.expectSaveJobWithStatus(t, job, api.JobStatusCanceled) mocks.expectBroadcastJobChange(task.Job, api.JobStatusCancelRequested, api.JobStatusCanceled) @@ -176,9 +201,11 @@ func TestTaskStatusChangeUnknownStatus(t *testing.T) { mockCtrl, ctx, sm, mocks := taskStateMachineTestFixtures(t) defer mockCtrl.Finish() - // T: queued > borked --> saved to DB but otherwise ignored + // T: queued > borked --> saved to DB but otherwise ignored w.r.t. job status changes. task := taskWithStatus(api.JobStatusQueued, api.TaskStatusQueued) mocks.expectSaveTaskWithStatus(t, task, api.TaskStatus("borked")) + mocks.expectBroadcastTaskChange(task, api.TaskStatusQueued, api.TaskStatus("borked")) + assert.NoError(t, sm.TaskStatusChange(ctx, task, api.TaskStatus("borked"))) } @@ -187,32 +214,33 @@ func TestJobRequeueWithSomeCompletedTasks(t *testing.T) { defer mockCtrl.Finish() task1 := taskWithStatus(api.JobStatusActive, api.TaskStatusCompleted) - // These are not necessary to create for this test, but just imagine these tasks are there too. - // This is mimicked by returning (1, 3, nil) when counting the tasks (1 of 3 completed). - // task2 := taskOfSameJob(task1, api.TaskStatusFailed) - // task3 := taskOfSameJob(task1, api.TaskStatusSoftFailed) + task2 := taskOfSameJob(task1, api.TaskStatusFailed) + task3 := taskOfSameJob(task2, api.TaskStatusSoftFailed) + notCompleteTasks := []*persistence.Task{task2, task3} job := task1.Job mocks.expectSaveJobWithStatus(t, job, api.JobStatusRequeued) // Expect queueing of the job to trigger queueing of all its not-yet-completed tasks. mocks.persist.EXPECT().CountTasksOfJobInStatus(ctx, job, api.TaskStatusCompleted).Return(1, 3, nil) - mocks.persist.EXPECT().UpdateJobsTaskStatusesConditional(ctx, job, - []api.TaskStatus{ - api.TaskStatusCancelRequested, - api.TaskStatusCanceled, - api.TaskStatusFailed, - api.TaskStatusPaused, - api.TaskStatusSoftFailed, - }, - api.TaskStatusQueued, - "Queued because job transitioned status from \"active\" to \"requeued\"", - ) + mocks.persist.EXPECT().FetchTasksOfJobInStatus(ctx, job, + api.TaskStatusCancelRequested, + api.TaskStatusCanceled, + api.TaskStatusFailed, + api.TaskStatusPaused, + api.TaskStatusSoftFailed, + ).Return(notCompleteTasks, nil) + + mocks.expectSaveTaskWithStatus(t, task2, api.TaskStatusQueued) + mocks.expectSaveTaskWithStatus(t, task3, api.TaskStatusQueued) mocks.expectSaveJobWithStatus(t, job, api.JobStatusQueued) mocks.expectBroadcastJobChange(job, api.JobStatusActive, api.JobStatusRequeued) mocks.expectBroadcastJobChange(job, api.JobStatusRequeued, api.JobStatusQueued) + mocks.expectBroadcastTaskChange(task2, api.TaskStatusFailed, api.TaskStatusQueued) + mocks.expectBroadcastTaskChange(task3, api.TaskStatusSoftFailed, api.TaskStatusQueued) + assert.NoError(t, sm.JobStatusChange(ctx, job, api.JobStatusRequeued, "someone wrote a unittest")) } @@ -221,32 +249,70 @@ func TestJobRequeueWithAllCompletedTasks(t *testing.T) { defer mockCtrl.Finish() task1 := taskWithStatus(api.JobStatusCompleted, api.TaskStatusCompleted) - // These are not necessary to create for this test, but just imagine these tasks are there too. - // This is mimicked by returning (3, 3, nil) when counting the tasks (3 of 3 completed). - // task2 := taskOfSameJob(task1, api.TaskStatusCompleted) - // task3 := taskOfSameJob(task1, api.TaskStatusCompleted) + task2 := taskOfSameJob(task1, api.TaskStatusCompleted) + task3 := taskOfSameJob(task2, api.TaskStatusCompleted) + allTasks := []*persistence.Task{task1, task2, task3} job := task1.Job call1 := mocks.expectSaveJobWithStatus(t, job, api.JobStatusRequeued) // Expect queueing of the job to trigger queueing of all its not-yet-completed tasks. - call2 := mocks.persist.EXPECT(). - UpdateJobsTaskStatuses(ctx, job, api.TaskStatusQueued, "Queued because job transitioned status from \"completed\" to \"requeued\""). + fetchCall := mocks.persist.EXPECT().FetchTasksOfJob(ctx, job). + Return(allTasks, nil). After(call1) - call3 := mocks.expectSaveJobWithStatus(t, job, api.JobStatusQueued).After(call2) + mocks.expectSaveTaskWithStatus(t, task1, api.TaskStatusQueued).After(fetchCall) + mocks.expectSaveTaskWithStatus(t, task2, api.TaskStatusQueued).After(fetchCall) + mocks.expectSaveTaskWithStatus(t, task3, api.TaskStatusQueued).After(fetchCall) + + saveJobCall := mocks.expectSaveJobWithStatus(t, job, api.JobStatusQueued).After(fetchCall) mocks.persist.EXPECT(). CountTasksOfJobInStatus(ctx, job, api.TaskStatusCompleted). Return(0, 3, nil). // By now all tasks are queued. - After(call3) + After(saveJobCall) mocks.expectBroadcastJobChange(job, api.JobStatusCompleted, api.JobStatusRequeued) mocks.expectBroadcastJobChange(job, api.JobStatusRequeued, api.JobStatusQueued) + mocks.expectBroadcastTaskChange(task1, api.TaskStatusCompleted, api.TaskStatusQueued) + mocks.expectBroadcastTaskChange(task2, api.TaskStatusCompleted, api.TaskStatusQueued) + mocks.expectBroadcastTaskChange(task3, api.TaskStatusCompleted, api.TaskStatusQueued) + assert.NoError(t, sm.JobStatusChange(ctx, job, api.JobStatusRequeued, "someone wrote a unit test")) } +func TestJobCancelWithSomeCompletedTasks(t *testing.T) { + mockCtrl, ctx, sm, mocks := taskStateMachineTestFixtures(t) + defer mockCtrl.Finish() + + task1 := taskWithStatus(api.JobStatusActive, api.TaskStatusCompleted) + task2 := taskOfSameJob(task1, api.TaskStatusFailed) + task3 := taskOfSameJob(task2, api.TaskStatusSoftFailed) + job := task1.Job + potentialRunTasks := []*persistence.Task{task2, task3} + + mocks.expectSaveJobWithStatus(t, job, api.JobStatusCancelRequested) + + // Expect cancelling of the job to trigger cancelling of all its could-potentially-still-run tasks. + fetchCall := mocks.persist.EXPECT().FetchTasksOfJobInStatus(ctx, job, + api.TaskStatusActive, + api.TaskStatusQueued, + api.TaskStatusSoftFailed, + ).Return(potentialRunTasks, nil) + mocks.expectSaveTaskWithStatus(t, task2, api.TaskStatusCanceled).After(fetchCall) + mocks.expectSaveTaskWithStatus(t, task3, api.TaskStatusCanceled).After(fetchCall) + mocks.expectSaveJobWithStatus(t, job, api.JobStatusCanceled).After(fetchCall) + + mocks.expectBroadcastJobChange(job, api.JobStatusActive, api.JobStatusCancelRequested) + mocks.expectBroadcastJobChange(job, api.JobStatusCancelRequested, api.JobStatusCanceled) + + mocks.expectBroadcastTaskChange(task2, api.TaskStatusFailed, api.TaskStatusCanceled) + mocks.expectBroadcastTaskChange(task3, api.TaskStatusSoftFailed, api.TaskStatusCanceled) + + assert.NoError(t, sm.JobStatusChange(ctx, job, api.JobStatusCancelRequested, "someone wrote a unittest")) +} + func mockedTaskStateMachine(mockCtrl *gomock.Controller) (*StateMachine, *StateMachineMocks) { mocks := StateMachineMocks{ persist: mocks.NewMockPersistenceService(mockCtrl), @@ -260,8 +326,8 @@ func (m *StateMachineMocks) expectSaveTaskWithStatus( t *testing.T, task *persistence.Task, expectTaskStatus api.TaskStatus, -) { - m.persist.EXPECT(). +) *gomock.Call { + return m.persist.EXPECT(). SaveTask(gomock.Any(), task). DoAndReturn(func(ctx context.Context, savedTask *persistence.Task) error { assert.Equal(t, expectTaskStatus, savedTask.Status) @@ -296,6 +362,21 @@ func (m *StateMachineMocks) expectBroadcastJobChange( return m.broadcaster.EXPECT().BroadcastJobUpdate(expectUpdate) } +func (m *StateMachineMocks) expectBroadcastTaskChange( + task *persistence.Task, + fromStatus, toStatus api.TaskStatus, +) *gomock.Call { + expectUpdate := api.SocketIOTaskUpdate{ + Id: task.UUID, + JobId: task.Job.UUID, + Name: task.Name, + Updated: task.UpdatedAt, + PreviousStatus: &fromStatus, + Status: toStatus, + } + return m.broadcaster.EXPECT().BroadcastTaskUpdate(expectUpdate) +} + /* taskWithStatus() creates a task of a certain status, with a job of a certain status. */ func taskWithStatus(jobStatus api.JobStatus, taskStatus api.TaskStatus) *persistence.Task { job := persistence.Job{