From 9a5bbb41313b836c31072f093695f5d2c373e18f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sybren=20A=2E=20St=C3=BCvel?= Date: Fri, 25 Feb 2022 14:34:29 +0100 Subject: [PATCH] Manager: implement persistence layer interface for task status machine Implement the functions used by the task status machine in the DB persistence layer. --- internal/manager/persistence/jobs.go | 77 +++++++ internal/manager/persistence/jobs_test.go | 197 +++++++++++++++--- .../task_state_machine/task_state_machine.go | 5 +- 3 files changed, 251 insertions(+), 28 deletions(-) diff --git a/internal/manager/persistence/jobs.go b/internal/manager/persistence/jobs.go index b898a91f..7b079c27 100644 --- a/internal/manager/persistence/jobs.go +++ b/internal/manager/persistence/jobs.go @@ -219,3 +219,80 @@ func (db *DB) SaveTask(ctx context.Context, t *Task) error { } return nil } + +func (db *DB) JobHasTasksInStatus(ctx context.Context, job *Job, taskStatus api.TaskStatus) (bool, error) { + var numTasksInStatus int64 + tx := db.gormDB.Model(&Task{}). + Where("job_id", job.ID). + Where("status", taskStatus). + Count(&numTasksInStatus) + if tx.Error != nil { + return false, tx.Error + } + return numTasksInStatus > 0, nil +} + +func (db *DB) CountTasksOfJobInStatus(ctx context.Context, job *Job, taskStatus api.TaskStatus) (numInStatus, numTotal int, err error) { + type Result struct { + Status api.TaskStatus + NumTasks int + } + var results []Result + + tx := db.gormDB.Debug().Model(&Task{}). + Select("status, count(*) as num_tasks"). + Where("job_id", job.ID). + Group("status"). + Scan(&results) + + if tx.Error != nil { + return 0, 0, fmt.Errorf("count tasks of job %s in status %q: %w", job.UUID, taskStatus, tx.Error) + } + + for _, result := range results { + if result.Status == taskStatus { + numInStatus += result.NumTasks + } + numTotal += result.NumTasks + } + + return +} + +// UpdateJobsTaskStatuses updates the status & activity of the tasks of `job`. +func (db *DB) UpdateJobsTaskStatuses(ctx context.Context, job *Job, + taskStatus api.TaskStatus, activity string) error { + + if taskStatus == "" { + return errors.New("empty status not allowed") + } + + tx := db.gormDB.Model(Task{}). + Where("job_Id = ?", job.ID). + Updates(Task{Status: taskStatus, Activity: activity}) + + if tx.Error != nil { + return tx.Error + } + return nil +} + +// UpdateJobsTaskStatusesConditional updates the status & activity of the tasks of `job`, +// limited to those tasks with status in `statusesToUpdate`. +func (db *DB) UpdateJobsTaskStatusesConditional(ctx context.Context, job *Job, + statusesToUpdate []api.TaskStatus, taskStatus api.TaskStatus, activity string) error { + + if taskStatus == "" { + return errors.New("empty status not allowed") + } + + tx := db.gormDB.Debug().Model(Task{}). + Where("job_Id = ?", job.ID). + Where("status in ?", statusesToUpdate). + Updates(Task{Status: taskStatus, Activity: activity}) + + if tx.Error != nil { + return tx.Error + } + return nil +} diff --git a/internal/manager/persistence/jobs_test.go b/internal/manager/persistence/jobs_test.go index 4e0924a7..657a16f0 100644 --- a/internal/manager/persistence/jobs_test.go +++ b/internal/manager/persistence/jobs_test.go @@ -37,6 +37,157 @@ func TestStoreAuthoredJob(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() + job := createTestAuthoredJobWithTasks() + err := db.StoreAuthoredJob(ctx, job) + assert.NoError(t, err) + + fetchedJob, err := db.FetchJob(ctx, job.JobID) + assert.NoError(t, err) + assert.NotNil(t, fetchedJob) + + // Test contents of fetched job + assert.Equal(t, job.JobID, fetchedJob.UUID) + assert.Equal(t, job.Name, fetchedJob.Name) + assert.Equal(t, job.JobType, fetchedJob.JobType) + assert.Equal(t, job.Priority, fetchedJob.Priority) + assert.Equal(t, api.JobStatusUnderConstruction, fetchedJob.Status) + assert.EqualValues(t, map[string]interface{}(job.Settings), fetchedJob.Settings) + assert.EqualValues(t, map[string]string(job.Metadata), fetchedJob.Metadata) + + // Fetch tasks of job. + var dbJob Job + tx := db.gormDB.Where(&Job{UUID: job.JobID}).Find(&dbJob) + assert.NoError(t, tx.Error) + var tasks []Task + tx = db.gormDB.Where("job_id = ?", dbJob.ID).Find(&tasks) + assert.NoError(t, tx.Error) + + if len(tasks) != 3 { + t.Fatalf("expected 3 tasks, got %d", len(tasks)) + } + + // TODO: test task contents. + assert.Equal(t, api.TaskStatusQueued, tasks[0].Status) + assert.Equal(t, api.TaskStatusQueued, tasks[1].Status) + assert.Equal(t, api.TaskStatusQueued, tasks[2].Status) +} + +func TestJobHasTasksInStatus(t *testing.T) { + ctx, db, job, _ := jobTasksTestFixtures(t) + + hasTasks, err := db.JobHasTasksInStatus(ctx, job, api.TaskStatusQueued) + assert.NoError(t, err) + assert.True(t, hasTasks, "expected freshly-created job to have queued tasks") + + hasTasks, err = db.JobHasTasksInStatus(ctx, job, api.TaskStatusActive) + assert.NoError(t, err) + assert.False(t, hasTasks, "expected freshly-created job to have no active tasks") +} + +func TestCountTasksOfJobInStatus(t *testing.T) { + ctx, db, job, authoredJob := jobTasksTestFixtures(t) + + numQueued, numTotal, err := db.CountTasksOfJobInStatus(ctx, job, api.TaskStatusQueued) + assert.NoError(t, err) + assert.Equal(t, 3, numQueued) + assert.Equal(t, 3, numTotal) + + // Make one task failed. + task, err := db.FetchTask(ctx, authoredJob.Tasks[0].UUID) + assert.NoError(t, err) + task.Status = api.TaskStatusFailed + assert.NoError(t, db.SaveTask(ctx, task)) + + numQueued, numTotal, err = db.CountTasksOfJobInStatus(ctx, job, api.TaskStatusQueued) + assert.NoError(t, err) + assert.Equal(t, 2, numQueued) + assert.Equal(t, 3, numTotal) + + numFailed, numTotal, err := db.CountTasksOfJobInStatus(ctx, job, api.TaskStatusFailed) + assert.NoError(t, err) + assert.Equal(t, 1, numFailed) + assert.Equal(t, 3, numTotal) + + numActive, numTotal, err := db.CountTasksOfJobInStatus(ctx, job, api.TaskStatusActive) + assert.NoError(t, err) + assert.Equal(t, 0, numActive) + assert.Equal(t, 3, numTotal) +} + +func TestUpdateJobsTaskStatuses(t *testing.T) { + ctx, db, job, authoredJob := jobTasksTestFixtures(t) + + err := db.UpdateJobsTaskStatuses(ctx, job, api.TaskStatusSoftFailed, "testing æctivity") + assert.NoError(t, err) + + numSoftFailed, numTotal, err := db.CountTasksOfJobInStatus(ctx, job, api.TaskStatusSoftFailed) + assert.NoError(t, err) + assert.Equal(t, 3, numSoftFailed, "all tasks should have had their status changed") + assert.Equal(t, 3, numTotal) + + task, err := db.FetchTask(ctx, authoredJob.Tasks[0].UUID) + assert.NoError(t, err) + assert.Equal(t, "testing æctivity", task.Activity) + + // Empty status should be rejected. + err = db.UpdateJobsTaskStatuses(ctx, job, "", "testing empty status") + assert.Error(t, err) + + numEmpty, _, err := db.CountTasksOfJobInStatus(ctx, job, "") + assert.NoError(t, err) + assert.Equal(t, 0, numEmpty, "tasks should not have their status changed") + + numSoftFailed, _, err = db.CountTasksOfJobInStatus(ctx, job, api.TaskStatusSoftFailed) + assert.NoError(t, err) + assert.Equal(t, 3, numSoftFailed, "all tasks should still be soft-failed") +} + +func TestUpdateJobsTaskStatusesConditional(t *testing.T) { + ctx, db, job, authoredJob := jobTasksTestFixtures(t) + + getTask := func(taskIndex int) *Task { + task, err := db.FetchTask(ctx, authoredJob.Tasks[taskIndex].UUID) + if err != nil { + t.Fatalf("Fetching task %d: %v", taskIndex, err) + } + return task + } + + setTaskStatus := func(taskIndex int, taskStatus api.TaskStatus) { + task := getTask(taskIndex) + task.Status = taskStatus + if err := db.SaveTask(ctx, task); err != nil { + t.Fatalf("Setting task %d to status %s: %v", taskIndex, taskStatus, err) + } + } + + setTaskStatus(0, api.TaskStatusFailed) + setTaskStatus(1, api.TaskStatusCompleted) + setTaskStatus(2, api.TaskStatusActive) + + err := db.UpdateJobsTaskStatusesConditional(ctx, job, + []api.TaskStatus{api.TaskStatusFailed, api.TaskStatusActive}, + api.TaskStatusCancelRequested, "some activity") + assert.NoError(t, err) + + // Task statuses should have updated for tasks 0 and 2. + assert.Equal(t, api.TaskStatusCancelRequested, getTask(0).Status) + assert.Equal(t, api.TaskStatusCompleted, getTask(1).Status) + assert.Equal(t, api.TaskStatusCancelRequested, getTask(2).Status) + + err = db.UpdateJobsTaskStatusesConditional(ctx, job, + []api.TaskStatus{api.TaskStatusFailed, api.TaskStatusActive}, + "", "empty task status should be disallowed") + assert.Error(t, err) + + // Task statuses should remain unchanged. + assert.Equal(t, api.TaskStatusCancelRequested, getTask(0).Status) + assert.Equal(t, api.TaskStatusCompleted, getTask(1).Status) + assert.Equal(t, api.TaskStatusCancelRequested, getTask(2).Status) + +} + +func createTestAuthoredJobWithTasks() job_compilers.AuthoredJob { task1 := job_compilers.AuthoredTask{ Name: "render-1-3", Type: "blender", @@ -93,36 +244,28 @@ func TestStoreAuthoredJob(t *testing.T) { Tasks: []job_compilers.AuthoredTask{task1, task2, task3}, } - err := db.StoreAuthoredJob(ctx, job) - assert.NoError(t, err) + return job +} - fetchedJob, err := db.FetchJob(ctx, job.JobID) - assert.NoError(t, err) - assert.NotNil(t, fetchedJob) +func jobTasksTestFixtures(t *testing.T) (context.Context, *DB, *Job, job_compilers.AuthoredJob) { + db := CreateTestDB(t) - // Test contents of fetched job - assert.Equal(t, job.JobID, fetchedJob.UUID) - assert.Equal(t, job.Name, fetchedJob.Name) - assert.Equal(t, job.JobType, fetchedJob.JobType) - assert.Equal(t, job.Priority, fetchedJob.Priority) - assert.Equal(t, api.JobStatusUnderConstruction, fetchedJob.Status) - assert.EqualValues(t, map[string]interface{}(job.Settings), fetchedJob.Settings) - assert.EqualValues(t, map[string]string(job.Metadata), fetchedJob.Metadata) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() - // Fetch tasks of job. - var dbJob Job - tx := db.gormDB.Where(&Job{UUID: job.JobID}).Find(&dbJob) - assert.NoError(t, tx.Error) - var tasks []Task - tx = db.gormDB.Where("job_id = ?", dbJob.ID).Find(&tasks) - assert.NoError(t, tx.Error) - - if len(tasks) != 3 { - t.Fatalf("expected 3 tasks, got %d", len(tasks)) + authoredJob := createTestAuthoredJobWithTasks() + err := db.StoreAuthoredJob(ctx, authoredJob) + if err != nil { + t.Fatalf("error storing authored job in DB: %v", err) } - // TODO: test task contents. - assert.Equal(t, api.TaskStatusQueued, tasks[0].Status) - assert.Equal(t, api.TaskStatusQueued, tasks[1].Status) - assert.Equal(t, api.TaskStatusQueued, tasks[2].Status) + dbJob, err := db.FetchJob(ctx, authoredJob.JobID) + if err != nil { + t.Fatalf("error fetching job from DB: %v", err) + } + if dbJob == nil { + t.Fatalf("nil job obtained from DB but with no error!") + } + + return ctx, db, dbJob, authoredJob } diff --git a/internal/manager/task_state_machine/task_state_machine.go b/internal/manager/task_state_machine/task_state_machine.go index 9b20ba3c..3e6c3288 100644 --- a/internal/manager/task_state_machine/task_state_machine.go +++ b/internal/manager/task_state_machine/task_state_machine.go @@ -42,7 +42,7 @@ type StateMachine struct { // Generate mock implementations of these interfaces. //go:generate go run github.com/golang/mock/mockgen -destination mocks/interfaces_mock.gen.go -package mocks gitlab.com/blender/flamenco-ng-poc/internal/manager/task_state_machine PersistenceService -type PersistenceService interface { // Subset of persistence.DB +type PersistenceService interface { SaveTask(ctx context.Context, task *persistence.Task) error SaveJobStatus(ctx context.Context, j *persistence.Job) error @@ -59,6 +59,9 @@ type PersistenceService interface { // Subset of persistence.DB statusesToUpdate []api.TaskStatus, taskStatus api.TaskStatus, activity string) error } +// PersistenceService should be a subset of persistence.DB +var _ PersistenceService = (*persistence.DB)(nil) + func NewStateMachine(persist PersistenceService) *StateMachine { return &StateMachine{ persist: persist,