Manager: actually pass context to Gorm queries
This commit is contained in:
parent
3d854078ba
commit
32af1ffaef
@ -53,6 +53,7 @@ type PersistenceService interface {
|
||||
FetchTask(ctx context.Context, taskID string) (*persistence.Task, error)
|
||||
SaveTask(ctx context.Context, task *persistence.Task) error
|
||||
SaveTaskActivity(ctx context.Context, t *persistence.Task) error
|
||||
FetchTasksOfWorkerInStatus(context.Context, *persistence.Worker, api.TaskStatus) ([]*persistence.Task, error)
|
||||
|
||||
CreateWorker(ctx context.Context, w *persistence.Worker) error
|
||||
FetchWorker(ctx context.Context, uuid string) (*persistence.Worker, error)
|
||||
@ -61,7 +62,7 @@ type PersistenceService interface {
|
||||
|
||||
// ScheduleTask finds a task to execute by the given worker, and assigns it to that worker.
|
||||
// If no task is available, (nil, nil) is returned, as this is not an error situation.
|
||||
ScheduleTask(w *persistence.Worker) (*persistence.Task, error)
|
||||
ScheduleTask(ctx context.Context, w *persistence.Worker) (*persistence.Task, error)
|
||||
}
|
||||
|
||||
var _ PersistenceService = (*persistence.DB)(nil)
|
||||
|
@ -155,18 +155,18 @@ func (mr *MockPersistenceServiceMockRecorder) SaveWorkerStatus(arg0, arg1 interf
|
||||
}
|
||||
|
||||
// ScheduleTask mocks base method.
|
||||
func (m *MockPersistenceService) ScheduleTask(arg0 *persistence.Worker) (*persistence.Task, error) {
|
||||
func (m *MockPersistenceService) ScheduleTask(arg0 context.Context, arg1 *persistence.Worker) (*persistence.Task, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ScheduleTask", arg0)
|
||||
ret := m.ctrl.Call(m, "ScheduleTask", arg0, arg1)
|
||||
ret0, _ := ret[0].(*persistence.Task)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ScheduleTask indicates an expected call of ScheduleTask.
|
||||
func (mr *MockPersistenceServiceMockRecorder) ScheduleTask(arg0 interface{}) *gomock.Call {
|
||||
func (mr *MockPersistenceServiceMockRecorder) ScheduleTask(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ScheduleTask", reflect.TypeOf((*MockPersistenceService)(nil).ScheduleTask), arg0)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ScheduleTask", reflect.TypeOf((*MockPersistenceService)(nil).ScheduleTask), arg0, arg1)
|
||||
}
|
||||
|
||||
// StoreAuthoredJob mocks base method.
|
||||
|
@ -243,7 +243,7 @@ func (f *Flamenco) ScheduleTask(e echo.Context) error {
|
||||
}
|
||||
|
||||
// Get a task to execute:
|
||||
dbTask, err := f.persist.ScheduleTask(worker)
|
||||
dbTask, err := f.persist.ScheduleTask(e.Request().Context(), worker)
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("error scheduling task for worker")
|
||||
return sendAPIError(e, http.StatusInternalServerError, "internal error finding a task for you: %v", err)
|
||||
|
@ -21,6 +21,7 @@ package api_impl
|
||||
* ***** END GPL LICENSE BLOCK ***** */
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -39,6 +40,8 @@ func TestTaskScheduleHappy(t *testing.T) {
|
||||
mf := newMockedFlamenco(mockCtrl)
|
||||
worker := testWorker()
|
||||
|
||||
echo := mf.prepareMockedRequest(&worker, nil)
|
||||
|
||||
// Expect a call into the persistence layer, which should return a scheduled task.
|
||||
job := persistence.Job{
|
||||
UUID: "583a7d59-887a-4c6c-b3e4-a753018f71b0",
|
||||
@ -47,13 +50,12 @@ func TestTaskScheduleHappy(t *testing.T) {
|
||||
UUID: "4107c7aa-e86d-4244-858b-6c4fce2af503",
|
||||
Job: &job,
|
||||
}
|
||||
mf.persistence.EXPECT().ScheduleTask(&worker).Return(&task, nil)
|
||||
mf.persistence.EXPECT().ScheduleTask(echo.Request().Context(), &worker).Return(&task, nil)
|
||||
|
||||
echoCtx := mf.prepareMockedRequest(&worker, nil)
|
||||
err := mf.flamenco.ScheduleTask(echoCtx)
|
||||
err := mf.flamenco.ScheduleTask(echo)
|
||||
assert.NoError(t, err)
|
||||
|
||||
resp := echoCtx.Response().Writer.(*httptest.ResponseRecorder)
|
||||
resp := echo.Response().Writer.(*httptest.ResponseRecorder)
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
// TODO: check that the returned JSON actually matches what we expect.
|
||||
}
|
||||
|
@ -125,7 +125,7 @@ func (db *DB) StoreAuthoredJob(ctx context.Context, authoredJob job_compilers.Au
|
||||
Metadata: StringStringMap(authoredJob.Metadata),
|
||||
}
|
||||
|
||||
if err := db.gormDB.Create(&dbJob).Error; err != nil {
|
||||
if err := db.gormDB.WithContext(ctx).Create(&dbJob).Error; err != nil {
|
||||
return fmt.Errorf("error storing job: %v", err)
|
||||
}
|
||||
|
||||
@ -149,7 +149,7 @@ func (db *DB) StoreAuthoredJob(ctx context.Context, authoredJob job_compilers.Au
|
||||
Commands: commands,
|
||||
// dependencies are stored below.
|
||||
}
|
||||
if err := db.gormDB.Create(&dbTask).Error; err != nil {
|
||||
if err := db.gormDB.WithContext(ctx).Create(&dbTask).Error; err != nil {
|
||||
return fmt.Errorf("error storing task: %v", err)
|
||||
}
|
||||
|
||||
@ -177,7 +177,7 @@ func (db *DB) StoreAuthoredJob(ctx context.Context, authoredJob job_compilers.Au
|
||||
}
|
||||
|
||||
dbTask.Dependencies = deps
|
||||
if err := db.gormDB.Save(dbTask).Error; err != nil {
|
||||
if err := db.gormDB.WithContext(ctx).Save(dbTask).Error; err != nil {
|
||||
return fmt.Errorf("unable to store dependencies of task %q: %w", authoredTask.UUID, err)
|
||||
}
|
||||
}
|
||||
@ -188,7 +188,7 @@ func (db *DB) StoreAuthoredJob(ctx context.Context, authoredJob job_compilers.Au
|
||||
|
||||
func (db *DB) FetchJob(ctx context.Context, jobUUID string) (*Job, error) {
|
||||
dbJob := Job{}
|
||||
findResult := db.gormDB.First(&dbJob, "uuid = ?", jobUUID)
|
||||
findResult := db.gormDB.WithContext(ctx).First(&dbJob, "uuid = ?", jobUUID)
|
||||
if findResult.Error != nil {
|
||||
return nil, findResult.Error
|
||||
}
|
||||
@ -197,24 +197,28 @@ func (db *DB) FetchJob(ctx context.Context, jobUUID string) (*Job, error) {
|
||||
}
|
||||
|
||||
func (db *DB) SaveJobStatus(ctx context.Context, j *Job) error {
|
||||
if err := db.gormDB.Model(j).Updates(Job{Status: j.Status}).Error; err != nil {
|
||||
return fmt.Errorf("error saving job status: %w", err)
|
||||
tx := db.gormDB.WithContext(ctx).
|
||||
Model(j).
|
||||
Updates(Job{Status: j.Status})
|
||||
if tx.Error != nil {
|
||||
return fmt.Errorf("error saving job status: %w", tx.Error)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) FetchTask(ctx context.Context, taskUUID string) (*Task, error) {
|
||||
dbTask := Task{}
|
||||
findResult := db.gormDB.Joins("Job").First(&dbTask, "tasks.uuid = ?", taskUUID)
|
||||
if findResult.Error != nil {
|
||||
return nil, findResult.Error
|
||||
tx := db.gormDB.WithContext(ctx).
|
||||
Joins("Job").
|
||||
First(&dbTask, "tasks.uuid = ?", taskUUID)
|
||||
if tx.Error != nil {
|
||||
return nil, tx.Error
|
||||
}
|
||||
|
||||
return &dbTask, nil
|
||||
}
|
||||
|
||||
func (db *DB) SaveTask(ctx context.Context, t *Task) error {
|
||||
if err := db.gormDB.Save(t).Error; err != nil {
|
||||
if err := db.gormDB.WithContext(ctx).Save(t).Error; err != nil {
|
||||
return fmt.Errorf("error saving task: %w", err)
|
||||
}
|
||||
return nil
|
||||
@ -229,7 +233,8 @@ func (db *DB) SaveTaskActivity(ctx context.Context, t *Task) error {
|
||||
|
||||
func (db *DB) JobHasTasksInStatus(ctx context.Context, job *Job, taskStatus api.TaskStatus) (bool, error) {
|
||||
var numTasksInStatus int64
|
||||
tx := db.gormDB.Model(&Task{}).
|
||||
tx := db.gormDB.WithContext(ctx).
|
||||
Model(&Task{}).
|
||||
Where("job_id", job.ID).
|
||||
Where("status", taskStatus).
|
||||
Count(&numTasksInStatus)
|
||||
@ -246,7 +251,8 @@ func (db *DB) CountTasksOfJobInStatus(ctx context.Context, job *Job, taskStatus
|
||||
}
|
||||
var results []Result
|
||||
|
||||
tx := db.gormDB.Debug().Model(&Task{}).
|
||||
tx := db.gormDB.WithContext(ctx).
|
||||
Model(&Task{}).
|
||||
Select("status, count(*) as num_tasks").
|
||||
Where("job_id", job.ID).
|
||||
Group("status").
|
||||
@ -274,7 +280,8 @@ func (db *DB) UpdateJobsTaskStatuses(ctx context.Context, job *Job,
|
||||
return errors.New("empty status not allowed")
|
||||
}
|
||||
|
||||
tx := db.gormDB.Model(Task{}).
|
||||
tx := db.gormDB.WithContext(ctx).
|
||||
Model(Task{}).
|
||||
Where("job_Id = ?", job.ID).
|
||||
Updates(Task{Status: taskStatus, Activity: activity})
|
||||
|
||||
@ -293,13 +300,10 @@ func (db *DB) UpdateJobsTaskStatusesConditional(ctx context.Context, job *Job,
|
||||
return errors.New("empty status not allowed")
|
||||
}
|
||||
|
||||
tx := db.gormDB.Debug().Model(Task{}).
|
||||
tx := db.gormDB.WithContext(ctx).
|
||||
Model(Task{}).
|
||||
Where("job_Id = ?", job.ID).
|
||||
Where("status in ?", statusesToUpdate).
|
||||
Updates(Task{Status: taskStatus, Activity: activity})
|
||||
|
||||
if tx.Error != nil {
|
||||
return tx.Error
|
||||
}
|
||||
return nil
|
||||
return tx.Error
|
||||
}
|
||||
|
@ -73,7 +73,8 @@ func TestStoreAuthoredJob(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestJobHasTasksInStatus(t *testing.T) {
|
||||
ctx, db, job, _ := jobTasksTestFixtures(t)
|
||||
ctx, ctxCancel, db, job, _ := jobTasksTestFixtures(t)
|
||||
defer ctxCancel()
|
||||
|
||||
hasTasks, err := db.JobHasTasksInStatus(ctx, job, api.TaskStatusQueued)
|
||||
assert.NoError(t, err)
|
||||
@ -85,7 +86,8 @@ func TestJobHasTasksInStatus(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCountTasksOfJobInStatus(t *testing.T) {
|
||||
ctx, db, job, authoredJob := jobTasksTestFixtures(t)
|
||||
ctx, ctxCancel, db, job, authoredJob := jobTasksTestFixtures(t)
|
||||
defer ctxCancel()
|
||||
|
||||
numQueued, numTotal, err := db.CountTasksOfJobInStatus(ctx, job, api.TaskStatusQueued)
|
||||
assert.NoError(t, err)
|
||||
@ -115,7 +117,8 @@ func TestCountTasksOfJobInStatus(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUpdateJobsTaskStatuses(t *testing.T) {
|
||||
ctx, db, job, authoredJob := jobTasksTestFixtures(t)
|
||||
ctx, ctxCancel, db, job, authoredJob := jobTasksTestFixtures(t)
|
||||
defer ctxCancel()
|
||||
|
||||
err := db.UpdateJobsTaskStatuses(ctx, job, api.TaskStatusSoftFailed, "testing æctivity")
|
||||
assert.NoError(t, err)
|
||||
@ -143,7 +146,8 @@ func TestUpdateJobsTaskStatuses(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUpdateJobsTaskStatusesConditional(t *testing.T) {
|
||||
ctx, db, job, authoredJob := jobTasksTestFixtures(t)
|
||||
ctx, ctxCancel, db, job, authoredJob := jobTasksTestFixtures(t)
|
||||
defer ctxCancel()
|
||||
|
||||
getTask := func(taskIndex int) *Task {
|
||||
task, err := db.FetchTask(ctx, authoredJob.Tasks[taskIndex].UUID)
|
||||
@ -247,11 +251,10 @@ func createTestAuthoredJobWithTasks() job_compilers.AuthoredJob {
|
||||
return job
|
||||
}
|
||||
|
||||
func jobTasksTestFixtures(t *testing.T) (context.Context, *DB, *Job, job_compilers.AuthoredJob) {
|
||||
func jobTasksTestFixtures(t *testing.T) (context.Context, context.CancelFunc, *DB, *Job, job_compilers.AuthoredJob) {
|
||||
db := CreateTestDB(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
authoredJob := createTestAuthoredJobWithTasks()
|
||||
err := db.StoreAuthoredJob(ctx, authoredJob)
|
||||
@ -267,5 +270,5 @@ func jobTasksTestFixtures(t *testing.T) (context.Context, *DB, *Job, job_compile
|
||||
t.Fatalf("nil job obtained from DB but with no error!")
|
||||
}
|
||||
|
||||
return ctx, db, dbJob, authoredJob
|
||||
return ctx, cancel, db, dbJob, authoredJob
|
||||
}
|
||||
|
@ -21,6 +21,7 @@ package persistence
|
||||
* ***** END GPL LICENSE BLOCK ***** */
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
@ -37,7 +38,7 @@ var (
|
||||
// ScheduleTask finds a task to execute by the given worker.
|
||||
// If no task is available, (nil, nil) is returned, as this is not an error situation.
|
||||
// NOTE: this does not also fetch returnedTask.Worker, but returnedTask.WorkerID is set.
|
||||
func (db *DB) ScheduleTask(w *Worker) (*Task, error) {
|
||||
func (db *DB) ScheduleTask(ctx context.Context, w *Worker) (*Task, error) {
|
||||
logger := log.With().Str("worker", w.UUID).Logger()
|
||||
logger.Debug().Msg("finding task for worker")
|
||||
|
||||
@ -45,7 +46,7 @@ func (db *DB) ScheduleTask(w *Worker) (*Task, error) {
|
||||
// 1. find task, and
|
||||
// 2. assign the task to the worker.
|
||||
var task *Task
|
||||
txErr := db.gormDB.Transaction(func(tx *gorm.DB) error {
|
||||
txErr := db.gormDB.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
var err error
|
||||
task, err = findTaskForWorker(tx, w)
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
|
@ -33,22 +33,28 @@ import (
|
||||
|
||||
func TestNoTasks(t *testing.T) {
|
||||
db := CreateTestDB(t)
|
||||
ctx, ctxCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer ctxCancel()
|
||||
|
||||
w := linuxWorker(t, db)
|
||||
|
||||
task, err := db.ScheduleTask(&w)
|
||||
task, err := db.ScheduleTask(ctx, &w)
|
||||
assert.Nil(t, task)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestOneJobOneTask(t *testing.T) {
|
||||
db := CreateTestDB(t)
|
||||
ctx, ctxCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer ctxCancel()
|
||||
|
||||
w := linuxWorker(t, db)
|
||||
|
||||
authTask := authorTestTask("the task", "blender")
|
||||
atj := authorTestJob("b6a1d859-122f-4791-8b78-b943329a9989", "simple-blender-render", authTask)
|
||||
job := constructTestJob(t, db, atj)
|
||||
|
||||
task, err := db.ScheduleTask(&w)
|
||||
task, err := db.ScheduleTask(ctx, &w)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Check the returned task.
|
||||
@ -75,6 +81,9 @@ func TestOneJobOneTask(t *testing.T) {
|
||||
|
||||
func TestOneJobThreeTasksByPrio(t *testing.T) {
|
||||
db := CreateTestDB(t)
|
||||
ctx, ctxCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer ctxCancel()
|
||||
|
||||
w := linuxWorker(t, db)
|
||||
|
||||
att1 := authorTestTask("1 low-prio task", "blender")
|
||||
@ -88,7 +97,7 @@ func TestOneJobThreeTasksByPrio(t *testing.T) {
|
||||
|
||||
job := constructTestJob(t, db, atj)
|
||||
|
||||
task, err := db.ScheduleTask(&w)
|
||||
task, err := db.ScheduleTask(ctx, &w)
|
||||
assert.NoError(t, err)
|
||||
if task == nil {
|
||||
t.Fatal("task is nil")
|
||||
@ -104,6 +113,9 @@ func TestOneJobThreeTasksByPrio(t *testing.T) {
|
||||
|
||||
func TestOneJobThreeTasksByDependencies(t *testing.T) {
|
||||
db := CreateTestDB(t)
|
||||
ctx, ctxCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer ctxCancel()
|
||||
|
||||
w := linuxWorker(t, db)
|
||||
|
||||
att1 := authorTestTask("1 low-prio task", "blender")
|
||||
@ -117,7 +129,7 @@ func TestOneJobThreeTasksByDependencies(t *testing.T) {
|
||||
att1, att2, att3)
|
||||
job := constructTestJob(t, db, atj)
|
||||
|
||||
task, err := db.ScheduleTask(&w)
|
||||
task, err := db.ScheduleTask(ctx, &w)
|
||||
assert.NoError(t, err)
|
||||
if task == nil {
|
||||
t.Fatal("task is nil")
|
||||
@ -128,6 +140,9 @@ func TestOneJobThreeTasksByDependencies(t *testing.T) {
|
||||
|
||||
func TestTwoJobsThreeTasks(t *testing.T) {
|
||||
db := CreateTestDB(t)
|
||||
ctx, ctxCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer ctxCancel()
|
||||
|
||||
w := linuxWorker(t, db)
|
||||
|
||||
att1_1 := authorTestTask("1.1 low-prio task", "blender")
|
||||
@ -155,7 +170,7 @@ func TestTwoJobsThreeTasks(t *testing.T) {
|
||||
constructTestJob(t, db, atj1)
|
||||
job2 := constructTestJob(t, db, atj2)
|
||||
|
||||
task, err := db.ScheduleTask(&w)
|
||||
task, err := db.ScheduleTask(ctx, &w)
|
||||
assert.NoError(t, err)
|
||||
if task == nil {
|
||||
t.Fatal("task is nil")
|
||||
|
@ -51,7 +51,7 @@ func (w *Worker) TaskTypes() []string {
|
||||
}
|
||||
|
||||
func (db *DB) CreateWorker(ctx context.Context, w *Worker) error {
|
||||
if err := db.gormDB.Create(w).Error; err != nil {
|
||||
if err := db.gormDB.WithContext(ctx).Create(w).Error; err != nil {
|
||||
return fmt.Errorf("error creating new worker: %w", err)
|
||||
}
|
||||
return nil
|
||||
@ -59,15 +59,16 @@ func (db *DB) CreateWorker(ctx context.Context, w *Worker) error {
|
||||
|
||||
func (db *DB) FetchWorker(ctx context.Context, uuid string) (*Worker, error) {
|
||||
w := Worker{}
|
||||
findResult := db.gormDB.First(&w, "uuid = ?", uuid)
|
||||
if findResult.Error != nil {
|
||||
return nil, findResult.Error
|
||||
tx := db.gormDB.WithContext(ctx).
|
||||
First(&w, "uuid = ?", uuid)
|
||||
if tx.Error != nil {
|
||||
return nil, tx.Error
|
||||
}
|
||||
return &w, nil
|
||||
}
|
||||
|
||||
func (db *DB) SaveWorkerStatus(ctx context.Context, w *Worker) error {
|
||||
err := db.gormDB.
|
||||
err := db.gormDB.WithContext(ctx).
|
||||
Model(w).
|
||||
Select("status", "status_requested").
|
||||
Updates(Worker{
|
||||
@ -81,7 +82,7 @@ func (db *DB) SaveWorkerStatus(ctx context.Context, w *Worker) error {
|
||||
}
|
||||
|
||||
func (db *DB) SaveWorker(ctx context.Context, w *Worker) error {
|
||||
if err := db.gormDB.Save(w).Error; err != nil {
|
||||
if err := db.gormDB.WithContext(ctx).Save(w).Error; err != nil {
|
||||
return fmt.Errorf("error saving worker: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
Loading…
x
Reference in New Issue
Block a user