From 32af1ffaef93b816626c995bd4bb04eff36adb05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sybren=20A=2E=20St=C3=BCvel?= Date: Mon, 28 Feb 2022 11:47:55 +0100 Subject: [PATCH] Manager: actually pass context to Gorm queries --- internal/manager/api_impl/api_impl.go | 3 +- .../api_impl/mocks/api_impl_mock.gen.go | 8 ++-- internal/manager/api_impl/workers.go | 2 +- internal/manager/api_impl/workers_test.go | 10 +++-- internal/manager/persistence/jobs.go | 44 ++++++++++--------- internal/manager/persistence/jobs_test.go | 17 ++++--- .../manager/persistence/task_scheduler.go | 5 ++- .../persistence/task_scheduler_test.go | 25 ++++++++--- internal/manager/persistence/workers.go | 13 +++--- 9 files changed, 77 insertions(+), 50 deletions(-) diff --git a/internal/manager/api_impl/api_impl.go b/internal/manager/api_impl/api_impl.go index 79f3467a..a6335a47 100644 --- a/internal/manager/api_impl/api_impl.go +++ b/internal/manager/api_impl/api_impl.go @@ -53,6 +53,7 @@ type PersistenceService interface { FetchTask(ctx context.Context, taskID string) (*persistence.Task, error) SaveTask(ctx context.Context, task *persistence.Task) error SaveTaskActivity(ctx context.Context, t *persistence.Task) error + FetchTasksOfWorkerInStatus(context.Context, *persistence.Worker, api.TaskStatus) ([]*persistence.Task, error) CreateWorker(ctx context.Context, w *persistence.Worker) error FetchWorker(ctx context.Context, uuid string) (*persistence.Worker, error) @@ -61,7 +62,7 @@ type PersistenceService interface { // ScheduleTask finds a task to execute by the given worker, and assigns it to that worker. // If no task is available, (nil, nil) is returned, as this is not an error situation. - ScheduleTask(w *persistence.Worker) (*persistence.Task, error) + ScheduleTask(ctx context.Context, w *persistence.Worker) (*persistence.Task, error) } var _ PersistenceService = (*persistence.DB)(nil) diff --git a/internal/manager/api_impl/mocks/api_impl_mock.gen.go b/internal/manager/api_impl/mocks/api_impl_mock.gen.go index 22da51f3..ce4f277c 100644 --- a/internal/manager/api_impl/mocks/api_impl_mock.gen.go +++ b/internal/manager/api_impl/mocks/api_impl_mock.gen.go @@ -155,18 +155,18 @@ func (mr *MockPersistenceServiceMockRecorder) SaveWorkerStatus(arg0, arg1 interf } // ScheduleTask mocks base method. -func (m *MockPersistenceService) ScheduleTask(arg0 *persistence.Worker) (*persistence.Task, error) { +func (m *MockPersistenceService) ScheduleTask(arg0 context.Context, arg1 *persistence.Worker) (*persistence.Task, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ScheduleTask", arg0) + ret := m.ctrl.Call(m, "ScheduleTask", arg0, arg1) ret0, _ := ret[0].(*persistence.Task) ret1, _ := ret[1].(error) return ret0, ret1 } // ScheduleTask indicates an expected call of ScheduleTask. -func (mr *MockPersistenceServiceMockRecorder) ScheduleTask(arg0 interface{}) *gomock.Call { +func (mr *MockPersistenceServiceMockRecorder) ScheduleTask(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ScheduleTask", reflect.TypeOf((*MockPersistenceService)(nil).ScheduleTask), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ScheduleTask", reflect.TypeOf((*MockPersistenceService)(nil).ScheduleTask), arg0, arg1) } // StoreAuthoredJob mocks base method. diff --git a/internal/manager/api_impl/workers.go b/internal/manager/api_impl/workers.go index 6f472b90..3a45eff2 100644 --- a/internal/manager/api_impl/workers.go +++ b/internal/manager/api_impl/workers.go @@ -243,7 +243,7 @@ func (f *Flamenco) ScheduleTask(e echo.Context) error { } // Get a task to execute: - dbTask, err := f.persist.ScheduleTask(worker) + dbTask, err := f.persist.ScheduleTask(e.Request().Context(), worker) if err != nil { logger.Warn().Err(err).Msg("error scheduling task for worker") return sendAPIError(e, http.StatusInternalServerError, "internal error finding a task for you: %v", err) diff --git a/internal/manager/api_impl/workers_test.go b/internal/manager/api_impl/workers_test.go index 2616f2de..b0da8839 100644 --- a/internal/manager/api_impl/workers_test.go +++ b/internal/manager/api_impl/workers_test.go @@ -21,6 +21,7 @@ package api_impl * ***** END GPL LICENSE BLOCK ***** */ import ( + "context" "encoding/json" "net/http" "net/http/httptest" @@ -39,6 +40,8 @@ func TestTaskScheduleHappy(t *testing.T) { mf := newMockedFlamenco(mockCtrl) worker := testWorker() + echo := mf.prepareMockedRequest(&worker, nil) + // Expect a call into the persistence layer, which should return a scheduled task. job := persistence.Job{ UUID: "583a7d59-887a-4c6c-b3e4-a753018f71b0", @@ -47,13 +50,12 @@ func TestTaskScheduleHappy(t *testing.T) { UUID: "4107c7aa-e86d-4244-858b-6c4fce2af503", Job: &job, } - mf.persistence.EXPECT().ScheduleTask(&worker).Return(&task, nil) + mf.persistence.EXPECT().ScheduleTask(echo.Request().Context(), &worker).Return(&task, nil) - echoCtx := mf.prepareMockedRequest(&worker, nil) - err := mf.flamenco.ScheduleTask(echoCtx) + err := mf.flamenco.ScheduleTask(echo) assert.NoError(t, err) - resp := echoCtx.Response().Writer.(*httptest.ResponseRecorder) + resp := echo.Response().Writer.(*httptest.ResponseRecorder) assert.Equal(t, http.StatusOK, resp.Code) // TODO: check that the returned JSON actually matches what we expect. } diff --git a/internal/manager/persistence/jobs.go b/internal/manager/persistence/jobs.go index e68489ac..f9cdef54 100644 --- a/internal/manager/persistence/jobs.go +++ b/internal/manager/persistence/jobs.go @@ -125,7 +125,7 @@ func (db *DB) StoreAuthoredJob(ctx context.Context, authoredJob job_compilers.Au Metadata: StringStringMap(authoredJob.Metadata), } - if err := db.gormDB.Create(&dbJob).Error; err != nil { + if err := db.gormDB.WithContext(ctx).Create(&dbJob).Error; err != nil { return fmt.Errorf("error storing job: %v", err) } @@ -149,7 +149,7 @@ func (db *DB) StoreAuthoredJob(ctx context.Context, authoredJob job_compilers.Au Commands: commands, // dependencies are stored below. } - if err := db.gormDB.Create(&dbTask).Error; err != nil { + if err := db.gormDB.WithContext(ctx).Create(&dbTask).Error; err != nil { return fmt.Errorf("error storing task: %v", err) } @@ -177,7 +177,7 @@ func (db *DB) StoreAuthoredJob(ctx context.Context, authoredJob job_compilers.Au } dbTask.Dependencies = deps - if err := db.gormDB.Save(dbTask).Error; err != nil { + if err := db.gormDB.WithContext(ctx).Save(dbTask).Error; err != nil { return fmt.Errorf("unable to store dependencies of task %q: %w", authoredTask.UUID, err) } } @@ -188,7 +188,7 @@ func (db *DB) StoreAuthoredJob(ctx context.Context, authoredJob job_compilers.Au func (db *DB) FetchJob(ctx context.Context, jobUUID string) (*Job, error) { dbJob := Job{} - findResult := db.gormDB.First(&dbJob, "uuid = ?", jobUUID) + findResult := db.gormDB.WithContext(ctx).First(&dbJob, "uuid = ?", jobUUID) if findResult.Error != nil { return nil, findResult.Error } @@ -197,24 +197,28 @@ func (db *DB) FetchJob(ctx context.Context, jobUUID string) (*Job, error) { } func (db *DB) SaveJobStatus(ctx context.Context, j *Job) error { - if err := db.gormDB.Model(j).Updates(Job{Status: j.Status}).Error; err != nil { - return fmt.Errorf("error saving job status: %w", err) + tx := db.gormDB.WithContext(ctx). + Model(j). + Updates(Job{Status: j.Status}) + if tx.Error != nil { + return fmt.Errorf("error saving job status: %w", tx.Error) } return nil } func (db *DB) FetchTask(ctx context.Context, taskUUID string) (*Task, error) { dbTask := Task{} - findResult := db.gormDB.Joins("Job").First(&dbTask, "tasks.uuid = ?", taskUUID) - if findResult.Error != nil { - return nil, findResult.Error + tx := db.gormDB.WithContext(ctx). + Joins("Job"). + First(&dbTask, "tasks.uuid = ?", taskUUID) + if tx.Error != nil { + return nil, tx.Error } - return &dbTask, nil } func (db *DB) SaveTask(ctx context.Context, t *Task) error { - if err := db.gormDB.Save(t).Error; err != nil { + if err := db.gormDB.WithContext(ctx).Save(t).Error; err != nil { return fmt.Errorf("error saving task: %w", err) } return nil @@ -229,7 +233,8 @@ func (db *DB) SaveTaskActivity(ctx context.Context, t *Task) error { func (db *DB) JobHasTasksInStatus(ctx context.Context, job *Job, taskStatus api.TaskStatus) (bool, error) { var numTasksInStatus int64 - tx := db.gormDB.Model(&Task{}). + tx := db.gormDB.WithContext(ctx). + Model(&Task{}). Where("job_id", job.ID). Where("status", taskStatus). Count(&numTasksInStatus) @@ -246,7 +251,8 @@ func (db *DB) CountTasksOfJobInStatus(ctx context.Context, job *Job, taskStatus } var results []Result - tx := db.gormDB.Debug().Model(&Task{}). + tx := db.gormDB.WithContext(ctx). + Model(&Task{}). Select("status, count(*) as num_tasks"). Where("job_id", job.ID). Group("status"). @@ -274,7 +280,8 @@ func (db *DB) UpdateJobsTaskStatuses(ctx context.Context, job *Job, return errors.New("empty status not allowed") } - tx := db.gormDB.Model(Task{}). + tx := db.gormDB.WithContext(ctx). + Model(Task{}). Where("job_Id = ?", job.ID). Updates(Task{Status: taskStatus, Activity: activity}) @@ -293,13 +300,10 @@ func (db *DB) UpdateJobsTaskStatusesConditional(ctx context.Context, job *Job, return errors.New("empty status not allowed") } - tx := db.gormDB.Debug().Model(Task{}). + tx := db.gormDB.WithContext(ctx). + Model(Task{}). Where("job_Id = ?", job.ID). Where("status in ?", statusesToUpdate). Updates(Task{Status: taskStatus, Activity: activity}) - - if tx.Error != nil { - return tx.Error - } - return nil + return tx.Error } diff --git a/internal/manager/persistence/jobs_test.go b/internal/manager/persistence/jobs_test.go index 657a16f0..c6dc21db 100644 --- a/internal/manager/persistence/jobs_test.go +++ b/internal/manager/persistence/jobs_test.go @@ -73,7 +73,8 @@ func TestStoreAuthoredJob(t *testing.T) { } func TestJobHasTasksInStatus(t *testing.T) { - ctx, db, job, _ := jobTasksTestFixtures(t) + ctx, ctxCancel, db, job, _ := jobTasksTestFixtures(t) + defer ctxCancel() hasTasks, err := db.JobHasTasksInStatus(ctx, job, api.TaskStatusQueued) assert.NoError(t, err) @@ -85,7 +86,8 @@ func TestJobHasTasksInStatus(t *testing.T) { } func TestCountTasksOfJobInStatus(t *testing.T) { - ctx, db, job, authoredJob := jobTasksTestFixtures(t) + ctx, ctxCancel, db, job, authoredJob := jobTasksTestFixtures(t) + defer ctxCancel() numQueued, numTotal, err := db.CountTasksOfJobInStatus(ctx, job, api.TaskStatusQueued) assert.NoError(t, err) @@ -115,7 +117,8 @@ func TestCountTasksOfJobInStatus(t *testing.T) { } func TestUpdateJobsTaskStatuses(t *testing.T) { - ctx, db, job, authoredJob := jobTasksTestFixtures(t) + ctx, ctxCancel, db, job, authoredJob := jobTasksTestFixtures(t) + defer ctxCancel() err := db.UpdateJobsTaskStatuses(ctx, job, api.TaskStatusSoftFailed, "testing æctivity") assert.NoError(t, err) @@ -143,7 +146,8 @@ func TestUpdateJobsTaskStatuses(t *testing.T) { } func TestUpdateJobsTaskStatusesConditional(t *testing.T) { - ctx, db, job, authoredJob := jobTasksTestFixtures(t) + ctx, ctxCancel, db, job, authoredJob := jobTasksTestFixtures(t) + defer ctxCancel() getTask := func(taskIndex int) *Task { task, err := db.FetchTask(ctx, authoredJob.Tasks[taskIndex].UUID) @@ -247,11 +251,10 @@ func createTestAuthoredJobWithTasks() job_compilers.AuthoredJob { return job } -func jobTasksTestFixtures(t *testing.T) (context.Context, *DB, *Job, job_compilers.AuthoredJob) { +func jobTasksTestFixtures(t *testing.T) (context.Context, context.CancelFunc, *DB, *Job, job_compilers.AuthoredJob) { db := CreateTestDB(t) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() authoredJob := createTestAuthoredJobWithTasks() err := db.StoreAuthoredJob(ctx, authoredJob) @@ -267,5 +270,5 @@ func jobTasksTestFixtures(t *testing.T) (context.Context, *DB, *Job, job_compile t.Fatalf("nil job obtained from DB but with no error!") } - return ctx, db, dbJob, authoredJob + return ctx, cancel, db, dbJob, authoredJob } diff --git a/internal/manager/persistence/task_scheduler.go b/internal/manager/persistence/task_scheduler.go index 3971362d..7c5bcba0 100644 --- a/internal/manager/persistence/task_scheduler.go +++ b/internal/manager/persistence/task_scheduler.go @@ -21,6 +21,7 @@ package persistence * ***** END GPL LICENSE BLOCK ***** */ import ( + "context" "fmt" "github.com/rs/zerolog/log" @@ -37,7 +38,7 @@ var ( // ScheduleTask finds a task to execute by the given worker. // If no task is available, (nil, nil) is returned, as this is not an error situation. // NOTE: this does not also fetch returnedTask.Worker, but returnedTask.WorkerID is set. -func (db *DB) ScheduleTask(w *Worker) (*Task, error) { +func (db *DB) ScheduleTask(ctx context.Context, w *Worker) (*Task, error) { logger := log.With().Str("worker", w.UUID).Logger() logger.Debug().Msg("finding task for worker") @@ -45,7 +46,7 @@ func (db *DB) ScheduleTask(w *Worker) (*Task, error) { // 1. find task, and // 2. assign the task to the worker. var task *Task - txErr := db.gormDB.Transaction(func(tx *gorm.DB) error { + txErr := db.gormDB.WithContext(ctx).Transaction(func(tx *gorm.DB) error { var err error task, err = findTaskForWorker(tx, w) if err == gorm.ErrRecordNotFound { diff --git a/internal/manager/persistence/task_scheduler_test.go b/internal/manager/persistence/task_scheduler_test.go index 69aa641f..488a21f4 100644 --- a/internal/manager/persistence/task_scheduler_test.go +++ b/internal/manager/persistence/task_scheduler_test.go @@ -33,22 +33,28 @@ import ( func TestNoTasks(t *testing.T) { db := CreateTestDB(t) + ctx, ctxCancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer ctxCancel() + w := linuxWorker(t, db) - task, err := db.ScheduleTask(&w) + task, err := db.ScheduleTask(ctx, &w) assert.Nil(t, task) assert.NoError(t, err) } func TestOneJobOneTask(t *testing.T) { db := CreateTestDB(t) + ctx, ctxCancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer ctxCancel() + w := linuxWorker(t, db) authTask := authorTestTask("the task", "blender") atj := authorTestJob("b6a1d859-122f-4791-8b78-b943329a9989", "simple-blender-render", authTask) job := constructTestJob(t, db, atj) - task, err := db.ScheduleTask(&w) + task, err := db.ScheduleTask(ctx, &w) assert.NoError(t, err) // Check the returned task. @@ -75,6 +81,9 @@ func TestOneJobOneTask(t *testing.T) { func TestOneJobThreeTasksByPrio(t *testing.T) { db := CreateTestDB(t) + ctx, ctxCancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer ctxCancel() + w := linuxWorker(t, db) att1 := authorTestTask("1 low-prio task", "blender") @@ -88,7 +97,7 @@ func TestOneJobThreeTasksByPrio(t *testing.T) { job := constructTestJob(t, db, atj) - task, err := db.ScheduleTask(&w) + task, err := db.ScheduleTask(ctx, &w) assert.NoError(t, err) if task == nil { t.Fatal("task is nil") @@ -104,6 +113,9 @@ func TestOneJobThreeTasksByPrio(t *testing.T) { func TestOneJobThreeTasksByDependencies(t *testing.T) { db := CreateTestDB(t) + ctx, ctxCancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer ctxCancel() + w := linuxWorker(t, db) att1 := authorTestTask("1 low-prio task", "blender") @@ -117,7 +129,7 @@ func TestOneJobThreeTasksByDependencies(t *testing.T) { att1, att2, att3) job := constructTestJob(t, db, atj) - task, err := db.ScheduleTask(&w) + task, err := db.ScheduleTask(ctx, &w) assert.NoError(t, err) if task == nil { t.Fatal("task is nil") @@ -128,6 +140,9 @@ func TestOneJobThreeTasksByDependencies(t *testing.T) { func TestTwoJobsThreeTasks(t *testing.T) { db := CreateTestDB(t) + ctx, ctxCancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer ctxCancel() + w := linuxWorker(t, db) att1_1 := authorTestTask("1.1 low-prio task", "blender") @@ -155,7 +170,7 @@ func TestTwoJobsThreeTasks(t *testing.T) { constructTestJob(t, db, atj1) job2 := constructTestJob(t, db, atj2) - task, err := db.ScheduleTask(&w) + task, err := db.ScheduleTask(ctx, &w) assert.NoError(t, err) if task == nil { t.Fatal("task is nil") diff --git a/internal/manager/persistence/workers.go b/internal/manager/persistence/workers.go index 37ecd8b4..04a19374 100644 --- a/internal/manager/persistence/workers.go +++ b/internal/manager/persistence/workers.go @@ -51,7 +51,7 @@ func (w *Worker) TaskTypes() []string { } func (db *DB) CreateWorker(ctx context.Context, w *Worker) error { - if err := db.gormDB.Create(w).Error; err != nil { + if err := db.gormDB.WithContext(ctx).Create(w).Error; err != nil { return fmt.Errorf("error creating new worker: %w", err) } return nil @@ -59,15 +59,16 @@ func (db *DB) CreateWorker(ctx context.Context, w *Worker) error { func (db *DB) FetchWorker(ctx context.Context, uuid string) (*Worker, error) { w := Worker{} - findResult := db.gormDB.First(&w, "uuid = ?", uuid) - if findResult.Error != nil { - return nil, findResult.Error + tx := db.gormDB.WithContext(ctx). + First(&w, "uuid = ?", uuid) + if tx.Error != nil { + return nil, tx.Error } return &w, nil } func (db *DB) SaveWorkerStatus(ctx context.Context, w *Worker) error { - err := db.gormDB. + err := db.gormDB.WithContext(ctx). Model(w). Select("status", "status_requested"). Updates(Worker{ @@ -81,7 +82,7 @@ func (db *DB) SaveWorkerStatus(ctx context.Context, w *Worker) error { } func (db *DB) SaveWorker(ctx context.Context, w *Worker) error { - if err := db.gormDB.Save(w).Error; err != nil { + if err := db.gormDB.WithContext(ctx).Save(w).Error; err != nil { return fmt.Errorf("error saving worker: %w", err) } return nil