diff --git a/internal/manager/task_state_machine/mocks/interfaces_mock.gen.go b/internal/manager/task_state_machine/mocks/interfaces_mock.gen.go index 94eca89d..5a0b7d8b 100644 --- a/internal/manager/task_state_machine/mocks/interfaces_mock.gen.go +++ b/internal/manager/task_state_machine/mocks/interfaces_mock.gen.go @@ -94,3 +94,31 @@ func (mr *MockPersistenceServiceMockRecorder) SaveTask(arg0, arg1 interface{}) * mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveTask", reflect.TypeOf((*MockPersistenceService)(nil).SaveTask), arg0, arg1) } + +// UpdateJobsTaskStatuses mocks base method. +func (m *MockPersistenceService) UpdateJobsTaskStatuses(arg0 context.Context, arg1 *persistence.Job, arg2 api.TaskStatus, arg3 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateJobsTaskStatuses", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateJobsTaskStatuses indicates an expected call of UpdateJobsTaskStatuses. +func (mr *MockPersistenceServiceMockRecorder) UpdateJobsTaskStatuses(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateJobsTaskStatuses", reflect.TypeOf((*MockPersistenceService)(nil).UpdateJobsTaskStatuses), arg0, arg1, arg2, arg3) +} + +// UpdateJobsTaskStatusesConditional mocks base method. +func (m *MockPersistenceService) UpdateJobsTaskStatusesConditional(arg0 context.Context, arg1 *persistence.Job, arg2 []api.TaskStatus, arg3 api.TaskStatus, arg4 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateJobsTaskStatusesConditional", arg0, arg1, arg2, arg3, arg4) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateJobsTaskStatusesConditional indicates an expected call of UpdateJobsTaskStatusesConditional. +func (mr *MockPersistenceServiceMockRecorder) UpdateJobsTaskStatusesConditional(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateJobsTaskStatusesConditional", reflect.TypeOf((*MockPersistenceService)(nil).UpdateJobsTaskStatusesConditional), arg0, arg1, arg2, arg3, arg4) +} diff --git a/internal/manager/task_state_machine/task_state_machine.go b/internal/manager/task_state_machine/task_state_machine.go index 1e028cce..9b20ba3c 100644 --- a/internal/manager/task_state_machine/task_state_machine.go +++ b/internal/manager/task_state_machine/task_state_machine.go @@ -24,6 +24,7 @@ import ( "context" "fmt" + "github.com/rs/zerolog" "github.com/rs/zerolog/log" "gitlab.com/blender/flamenco-ng-poc/internal/manager/persistence" "gitlab.com/blender/flamenco-ng-poc/pkg/api" @@ -47,6 +48,15 @@ type PersistenceService interface { // Subset of persistence.DB JobHasTasksInStatus(ctx context.Context, job *persistence.Job, taskStatus api.TaskStatus) (bool, error) CountTasksOfJobInStatus(ctx context.Context, job *persistence.Job, taskStatus api.TaskStatus) (numInStatus, numTotal int, err error) + + // UpdateJobsTaskStatuses updates the status & activity of the tasks of `job`. + UpdateJobsTaskStatuses(ctx context.Context, job *persistence.Job, + taskStatus api.TaskStatus, activity string) error + + // UpdateJobsTaskStatusesConditional updates the status & activity of the tasks of `job`, + // limited to those tasks with status in `statusesToUpdate`. + UpdateJobsTaskStatusesConditional(ctx context.Context, job *persistence.Job, + statusesToUpdate []api.TaskStatus, taskStatus api.TaskStatus, activity string) error } func NewStateMachine(persist PersistenceService) *StateMachine { @@ -57,26 +67,32 @@ func NewStateMachine(persist PersistenceService) *StateMachine { // TaskStatusChange updates the task's status to the new one. // `task` is expected to still have its original status, and have a filled `Job` pointer. -func (sm *StateMachine) TaskStatusChange(ctx context.Context, task *persistence.Task, newTaskStatus api.TaskStatus) error { +func (sm *StateMachine) TaskStatusChange( + ctx context.Context, + task *persistence.Task, + newTaskStatus api.TaskStatus, +) error { job := task.Job if job == nil { log.Panic().Str("task", task.UUID).Msg("task without job, cannot handle this") return nil // Will not run because of the panic. } + oldTaskStatus := task.Status + task.Status = newTaskStatus + logger := log.With(). Str("task", task.UUID). Str("job", job.UUID). - Str("taskStatusOld", string(task.Status)). + Str("taskStatusOld", string(oldTaskStatus)). Str("taskStatusNew", string(newTaskStatus)). Logger() logger.Debug().Msg("task state changed") - task.Status = newTaskStatus if err := sm.persist.SaveTask(ctx, task); err != nil { return fmt.Errorf("error saving task to database: %w", err) } - if err := sm.updateJobAfterTaskStatusChange(ctx, task, newTaskStatus); err != nil { + if err := sm.updateJobAfterTaskStatusChange(ctx, task, oldTaskStatus); err != nil { return fmt.Errorf("error updating job after task status change: %w", err) } return nil @@ -85,16 +101,15 @@ func (sm *StateMachine) TaskStatusChange(ctx context.Context, task *persistence. // updateJobAfterTaskStatusChange updates the job status based on the status of // this task and other tasks in the job. func (sm *StateMachine) updateJobAfterTaskStatusChange( - ctx context.Context, task *persistence.Task, newTaskStatus api.TaskStatus, + ctx context.Context, task *persistence.Task, oldTaskStatus api.TaskStatus, ) error { - job := task.Job logger := log.With(). Str("job", job.UUID). Str("task", task.UUID). - Str("taskStatusOld", string(task.Status)). - Str("taskStatusNew", string(newTaskStatus)). + Str("taskStatusOld", string(oldTaskStatus)). + Str("taskStatusNew", string(task.Status)). Logger() // If the job has status 'ifStatus', move it to status 'thenStatus'. @@ -105,15 +120,15 @@ func (sm *StateMachine) updateJobAfterTaskStatusChange( logger.Info(). Str("jobStatusOld", string(ifStatus)). Str("jobStatusNew", string(thenStatus)). - Msg("Job changed status because one of its task changed status") + Msg("Job will change status because one of its task changed status") return sm.JobStatusChange(ctx, job, thenStatus) } // Every 'case' in this switch MUST return. Just for sanity's sake. - switch newTaskStatus { + switch task.Status { case api.TaskStatusQueued: // Re-queueing a task on a completed job should re-queue the job too. - return jobStatusIfAThenB(api.JobStatusCompleted, api.JobStatusQueued) + return jobStatusIfAThenB(api.JobStatusCompleted, api.JobStatusRequeued) case api.TaskStatusCancelRequested: // Requesting cancellation of a single task has no influence on the job itself. @@ -195,16 +210,182 @@ func (sm *StateMachine) updateJobAfterTaskStatusChange( } func (sm *StateMachine) JobStatusChange(ctx context.Context, job *persistence.Job, newJobStatus api.JobStatus) error { - logger := log.With(). - Str("job", job.UUID). - Str("jobStatusOld", string(job.Status)). - Str("jobStatusNew", string(newJobStatus)). - Logger() + // Job status changes can trigger task status changes, which can trigger the + // next job status change. Keep looping over these job status changes until + // there is no more change left to do. + var err error + for newJobStatus != "" && newJobStatus != job.Status { + oldJobStatus := job.Status + job.Status = newJobStatus - logger.Info().Msg("job status changed") + logger := log.With(). + Str("job", job.UUID). + Str("jobStatusOld", string(oldJobStatus)). + Str("jobStatusNew", string(newJobStatus)). + Logger() + logger.Info().Msg("job status changed") - // TODO: actually respond to status change, instead of just saving the new job state. + // Persist the new job status. + err = sm.persist.SaveJobStatus(ctx, job) + if err != nil { + return fmt.Errorf("error saving job status change %q to %q to database: %w", + oldJobStatus, newJobStatus, err) + } - job.Status = newJobStatus - return sm.persist.SaveJobStatus(ctx, job) + // Handle the status change. + newJobStatus, err = sm.updateTasksAfterJobStatusChange(ctx, logger, job, oldJobStatus) + if err != nil { + return fmt.Errorf("error updating job's tasks after job status change: %w", err) + } + } + + return nil +} + +// updateTasksAfterJobStatusChange updates the status of its tasks based on the +// new status of this job. +// +// NOTE: this function assumes that the job already has its new status. +// +// Returns the new state the job should go into after this change, or an empty +// string if there is no subsequent change necessary. +func (sm *StateMachine) updateTasksAfterJobStatusChange( + ctx context.Context, + logger zerolog.Logger, + job *persistence.Job, + oldJobStatus api.JobStatus, +) (api.JobStatus, error) { + + // Every case in this switch MUST return, for sanity sake. + switch job.Status { + case api.JobStatusCompleted, api.JobStatusCanceled: + // Nothing to do; this will happen as a response to all tasks receiving this status. + return "", nil + + case api.JobStatusActive: + // Nothing to do; this happens when a task gets started, which has nothing to + // do with other tasks in the job. + return "", nil + + case api.JobStatusCancelRequested, api.JobStatusFailed: + return sm.cancelTasks(ctx, logger, job) + + case api.JobStatusRequeued: + return sm.requeueTasks(ctx, logger, job, oldJobStatus) + + case api.JobStatusQueued: + return sm.checkTaskCompletion(ctx, logger, job) + + default: + logger.Warn().Msg("unknown job status change, ignoring") + return "", nil + } +} + +// Directly cancel any task that might run in the future. +// +// Returns the next job status, if a status change is required. +func (sm *StateMachine) cancelTasks( + ctx context.Context, logger zerolog.Logger, job *persistence.Job, +) (api.JobStatus, error) { + logger.Info().Msg("cancelling tasks of job") + + // Any task that is running or might run in the future should get cancelled. + taskStatusesToCancel := []api.TaskStatus{ + api.TaskStatusActive, + api.TaskStatusQueued, + api.TaskStatusSoftFailed, + } + err := sm.persist.UpdateJobsTaskStatusesConditional( + ctx, job, taskStatusesToCancel, api.TaskStatusCanceled, + fmt.Sprintf("Manager cancelled this task because the job got status %q.", job.Status), + ) + if err != nil { + return "", fmt.Errorf("error cancelling tasks of job %s: %w", job.UUID, err) + } + + // If cancellation was requested, it has now happened, so the job can transition. + if job.Status == api.JobStatusCancelRequested { + logger.Info().Msg("all tasks of job cancelled, job can go to 'cancelled' status") + return api.JobStatusCanceled, nil + } + + // This could mean cancellation was triggered by failure of the job, in which + // case the job is already in the correct status. + return "", nil +} + +// requeueTasks re-queues all tasks of the job. +// +// This function assumes that the current job status is "requeued". +// +// Returns the new job status, if this status transition should be followed by +// another one. +func (sm *StateMachine) requeueTasks( + ctx context.Context, logger zerolog.Logger, job *persistence.Job, oldJobStatus api.JobStatus, +) (api.JobStatus, error) { + var err error + + if job.Status != api.JobStatusRequeued { + logger.Warn().Msg("unexpected job status in StateMachine::requeueTasks()") + } + + switch oldJobStatus { + case api.JobStatusUnderConstruction: + // Nothing to do, the job compiler has just finished its work; the tasks have + // already been set to 'queued' status. + logger.Debug().Msg("ignoring job status change") + return "", nil + case api.JobStatusCompleted: + // Re-queue all tasks. + err = sm.persist.UpdateJobsTaskStatuses(ctx, job, api.TaskStatusQueued, + fmt.Sprintf("Queued because job transitioned status from %q to %q", oldJobStatus, job.Status)) + default: + // Re-queue only the non-completed tasks. + statusesToUpdate := []api.TaskStatus{ + api.TaskStatusCancelRequested, + api.TaskStatusCanceled, + api.TaskStatusFailed, + api.TaskStatusPaused, + api.TaskStatusSoftFailed, + } + err = sm.persist.UpdateJobsTaskStatusesConditional(ctx, job, + statusesToUpdate, api.TaskStatusQueued, + fmt.Sprintf("Queued because job transitioned status from %q to %q", oldJobStatus, job.Status)) + } + + if err != nil { + return "", fmt.Errorf("error queueing tasks of job %s: %w", job.UUID, err) + } + + // TODO: also reset the 'failed by workers' blacklist. + + // The appropriate tasks have been requeued, so now the job can go from "requeued" to "queued". + return api.JobStatusQueued, nil +} + +// checkTaskCompletion returns "completed" as next job status when all tasks of +// the job are completed. +// +// Returns the new job status, if this status transition should be followed by +// another one. +func (sm *StateMachine) checkTaskCompletion( + ctx context.Context, logger zerolog.Logger, job *persistence.Job, +) (api.JobStatus, error) { + + numCompleted, numTotal, err := sm.persist.CountTasksOfJobInStatus(ctx, job, api.TaskStatusCompleted) + if err != nil { + return "", fmt.Errorf("checking task completion of job %s: %w", job.UUID, err) + } + + if numCompleted < numTotal { + logger.Debug(). + Int("numTasksCompleted", numCompleted). + Int("numTasksTotal", numTotal). + Msg("not all tasks of job are completed") + return "", nil + } + + logger.Info().Msg("job has all tasks completed, transition job to 'completed'") + return api.JobStatusCompleted, nil } diff --git a/internal/manager/task_state_machine/task_state_machine_test.go b/internal/manager/task_state_machine/task_state_machine_test.go index 20aafa4e..1471daa2 100644 --- a/internal/manager/task_state_machine/task_state_machine_test.go +++ b/internal/manager/task_state_machine/task_state_machine_test.go @@ -118,11 +118,22 @@ func TestTaskStatusChangeActiveToFailedFailJob(t *testing.T) { mockCtrl, ctx, sm, mocks := taskStateMachineTestFixtures(t) defer mockCtrl.Finish() - // T: active > failed (10% task failure) --> J: active > failed + // T: active > failed (10% task failure) --> J: active > failed + cancellation of any runnable tasks. task := taskWithStatus(api.JobStatusActive, api.TaskStatusActive) mocks.expectSaveTaskWithStatus(t, task, api.TaskStatusFailed) mocks.expectSaveJobWithStatus(t, task.Job, api.JobStatusFailed) mocks.persist.EXPECT().CountTasksOfJobInStatus(ctx, task.Job, api.TaskStatusFailed).Return(10, 100, nil) // 10 out of 100 failed. + + // Expect failure of the job to trigger cancellation of remaining tasks. + taskStatusesToCancel := []api.TaskStatus{ + api.TaskStatusActive, + api.TaskStatusQueued, + api.TaskStatusSoftFailed, + } + mocks.persist.EXPECT().UpdateJobsTaskStatusesConditional(ctx, task.Job, taskStatusesToCancel, api.TaskStatusCanceled, + "Manager cancelled this task because the job got status \"failed\".", + ) + assert.NoError(t, sm.TaskStatusChange(ctx, task, api.TaskStatusFailed)) } @@ -130,10 +141,19 @@ func TestTaskStatusChangeRequeueOnCompletedJob(t *testing.T) { mockCtrl, ctx, sm, mocks := taskStateMachineTestFixtures(t) defer mockCtrl.Finish() - // T: completed > queued --> J: completed > queued + // T: completed > queued --> J: completed > requeued > queued task := taskWithStatus(api.JobStatusCompleted, api.TaskStatusCompleted) mocks.expectSaveTaskWithStatus(t, task, api.TaskStatusQueued) + mocks.expectSaveJobWithStatus(t, task.Job, api.JobStatusRequeued) + + // Expect queueing of the job to trigger queueing of all its tasks, if those tasks were all completed before. + // 2 out of 3 completed, because one was just queued. + mocks.persist.EXPECT().CountTasksOfJobInStatus(ctx, task.Job, api.TaskStatusCompleted).Return(2, 3, nil) + mocks.persist.EXPECT().UpdateJobsTaskStatuses(ctx, task.Job, api.TaskStatusQueued, + "Queued because job transitioned status from \"completed\" to \"requeued\"", + ) mocks.expectSaveJobWithStatus(t, task.Job, api.JobStatusQueued) + assert.NoError(t, sm.TaskStatusChange(ctx, task, api.TaskStatusQueued)) } @@ -167,6 +187,64 @@ func TestTaskStatusChangeUnknownStatus(t *testing.T) { assert.NoError(t, sm.TaskStatusChange(ctx, task, api.TaskStatus("borked"))) } +func TestJobRequeueWithSomeCompletedTasks(t *testing.T) { + mockCtrl, ctx, sm, mocks := taskStateMachineTestFixtures(t) + defer mockCtrl.Finish() + + task1 := taskWithStatus(api.JobStatusActive, api.TaskStatusCompleted) + // These are not necessary to create for this test, but just imagine these tasks are there too. + // This is mimicked by returning (1, 3, nil) when counting the tasks (1 of 3 completed). + // task2 := taskOfSameJob(task1, api.TaskStatusFailed) + // task3 := taskOfSameJob(task1, api.TaskStatusSoftFailed) + job := task1.Job + + mocks.expectSaveJobWithStatus(t, job, api.JobStatusRequeued) + + // Expect queueing of the job to trigger queueing of all its not-yet-completed tasks. + mocks.persist.EXPECT().CountTasksOfJobInStatus(ctx, job, api.TaskStatusCompleted).Return(1, 3, nil) + mocks.persist.EXPECT().UpdateJobsTaskStatusesConditional(ctx, job, + []api.TaskStatus{ + api.TaskStatusCancelRequested, + api.TaskStatusCanceled, + api.TaskStatusFailed, + api.TaskStatusPaused, + api.TaskStatusSoftFailed, + }, + api.TaskStatusQueued, + "Queued because job transitioned status from \"active\" to \"requeued\"", + ) + mocks.expectSaveJobWithStatus(t, job, api.JobStatusQueued) + assert.NoError(t, sm.JobStatusChange(ctx, job, api.JobStatusRequeued)) +} + +func TestJobRequeueWithAllCompletedTasks(t *testing.T) { + mockCtrl, ctx, sm, mocks := taskStateMachineTestFixtures(t) + defer mockCtrl.Finish() + + task1 := taskWithStatus(api.JobStatusCompleted, api.TaskStatusCompleted) + // These are not necessary to create for this test, but just imagine these tasks are there too. + // This is mimicked by returning (3, 3, nil) when counting the tasks (3 of 3 completed). + // task2 := taskOfSameJob(task1, api.TaskStatusCompleted) + // task3 := taskOfSameJob(task1, api.TaskStatusCompleted) + job := task1.Job + + call1 := mocks.expectSaveJobWithStatus(t, job, api.JobStatusRequeued) + + // Expect queueing of the job to trigger queueing of all its not-yet-completed tasks. + call2 := mocks.persist.EXPECT(). + UpdateJobsTaskStatuses(ctx, job, api.TaskStatusQueued, "Queued because job transitioned status from \"completed\" to \"requeued\""). + After(call1) + + call3 := mocks.expectSaveJobWithStatus(t, job, api.JobStatusQueued).After(call2) + + mocks.persist.EXPECT(). + CountTasksOfJobInStatus(ctx, job, api.TaskStatusCompleted). + Return(0, 3, nil). // By now all tasks are queued. + After(call3) + + assert.NoError(t, sm.JobStatusChange(ctx, job, api.JobStatusRequeued)) +} + func mockedTaskStateMachine(mockCtrl *gomock.Controller) (*StateMachine, *StateMachineMocks) { mocks := StateMachineMocks{ persist: mocks.NewMockPersistenceService(mockCtrl), @@ -192,8 +270,8 @@ func (m *StateMachineMocks) expectSaveJobWithStatus( t *testing.T, job *persistence.Job, expectJobStatus api.JobStatus, -) { - m.persist.EXPECT(). +) *gomock.Call { + return m.persist.EXPECT(). SaveJobStatus(gomock.Any(), job). DoAndReturn(func(ctx context.Context, savedJob *persistence.Job) error { assert.Equal(t, expectJobStatus, savedJob.Status)