From 7689a988b12df6e54cb4d0b540057a376ad55d70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sybren=20A=2E=20St=C3=BCvel?= Date: Mon, 28 Feb 2022 12:06:50 +0100 Subject: [PATCH] Manager: re-queue tasks of worker when signing off --- .../api_impl/mocks/api_impl_mock.gen.go | 15 +++++ internal/manager/api_impl/workers.go | 45 +++++++++++-- internal/manager/api_impl/workers_test.go | 51 ++++++++++++++ internal/manager/persistence/jobs.go | 28 ++++++++ internal/manager/persistence/jobs_test.go | 67 +++++++++++++++++++ 5 files changed, 202 insertions(+), 4 deletions(-) diff --git a/internal/manager/api_impl/mocks/api_impl_mock.gen.go b/internal/manager/api_impl/mocks/api_impl_mock.gen.go index ce4f277c..2c4cd764 100644 --- a/internal/manager/api_impl/mocks/api_impl_mock.gen.go +++ b/internal/manager/api_impl/mocks/api_impl_mock.gen.go @@ -83,6 +83,21 @@ func (mr *MockPersistenceServiceMockRecorder) FetchTask(arg0, arg1 interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchTask", reflect.TypeOf((*MockPersistenceService)(nil).FetchTask), arg0, arg1) } +// FetchTasksOfWorkerInStatus mocks base method. +func (m *MockPersistenceService) FetchTasksOfWorkerInStatus(arg0 context.Context, arg1 *persistence.Worker, arg2 api.TaskStatus) ([]*persistence.Task, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FetchTasksOfWorkerInStatus", arg0, arg1, arg2) + ret0, _ := ret[0].([]*persistence.Task) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FetchTasksOfWorkerInStatus indicates an expected call of FetchTasksOfWorkerInStatus. +func (mr *MockPersistenceServiceMockRecorder) FetchTasksOfWorkerInStatus(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchTasksOfWorkerInStatus", reflect.TypeOf((*MockPersistenceService)(nil).FetchTasksOfWorkerInStatus), arg0, arg1, arg2) +} + // FetchWorker mocks base method. 
func (m *MockPersistenceService) FetchWorker(arg0 context.Context, arg1 string) (*persistence.Worker, error) { m.ctrl.T.Helper() diff --git a/internal/manager/api_impl/workers.go b/internal/manager/api_impl/workers.go index 3a45eff2..ab57e3f1 100644 --- a/internal/manager/api_impl/workers.go +++ b/internal/manager/api_impl/workers.go @@ -21,12 +21,15 @@ package api_impl * ***** END GPL LICENSE BLOCK ***** */ import ( + "context" "fmt" "net/http" "strings" + "time" "github.com/google/uuid" "github.com/labstack/echo/v4" + "github.com/rs/zerolog" "golang.org/x/crypto/bcrypt" "gitlab.com/blender/flamenco-ng-poc/internal/manager/persistence" @@ -148,10 +151,12 @@ func (f *Flamenco) SignOff(e echo.Context) error { w.StatusRequested = "" } - // TODO: check whether we should pass the request context here, or a generic - // background context, as this should be stored even when the HTTP connection - // is aborted. - err = f.persist.SaveWorkerStatus(e.Request().Context(), w) + // Pass a generic background context, as these changes should be stored even + // when the HTTP connection is aborted. + ctx, ctxCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer ctxCancel() + + err = f.persist.SaveWorkerStatus(ctx, w) if err != nil { logger.Warn(). Err(err). @@ -160,9 +165,41 @@ func (f *Flamenco) SignOff(e echo.Context) error { return sendAPIError(e, http.StatusInternalServerError, "error storing new status in database") } + // Re-queue all tasks (should be only one) this worker is now working on. + err = f.workerRequeueActiveTasks(ctx, logger, w) + if err != nil { + return sendAPIError(e, http.StatusInternalServerError, "error re-queueing your tasks") + } + return e.String(http.StatusNoContent, "") } +// workerRequeueActiveTasks re-queues all active tasks (should be max one) of this worker. +func (f *Flamenco) workerRequeueActiveTasks(ctx context.Context, logger zerolog.Logger, worker *persistence.Worker) error { + // Fetch the tasks to update. 
+ tasks, err := f.persist.FetchTasksOfWorkerInStatus(ctx, worker, api.TaskStatusActive) + if err != nil { + return fmt.Errorf("fetching tasks of worker %s in status %q: %w", worker.UUID, api.TaskStatusActive, err) + } + + // Run each task change through the task state machine. + var lastErr error + for _, task := range tasks { + logger.Info(). + Str("task", task.UUID). + Msg("re-queueing task") + err := f.stateMachine.TaskStatusChange(ctx, task, api.TaskStatusQueued) + if err != nil { + logger.Warn().Err(err). + Str("task", task.UUID). + Msg("error queueing task on worker sign-off") + lastErr = err + } + } + + return lastErr +} + // (GET /api/worker/state) func (f *Flamenco) WorkerState(e echo.Context) error { worker := requestWorkerOrPanic(e) diff --git a/internal/manager/api_impl/workers_test.go b/internal/manager/api_impl/workers_test.go index b0da8839..7044e756 100644 --- a/internal/manager/api_impl/workers_test.go +++ b/internal/manager/api_impl/workers_test.go @@ -102,3 +102,54 @@ func TestTaskScheduleOtherStatusRequested(t *testing.T) { assert.NoError(t, err) assert.Equal(t, worker.StatusRequested, responseBody.StatusRequested) } + +func TestWorkerSignoffTaskRequeue(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + mf := newMockedFlamenco(mockCtrl) + worker := testWorker() + + job := persistence.Job{ + UUID: "583a7d59-887a-4c6c-b3e4-a753018f71b0", + } + // Mock that the worker has two active tasks. It shouldn't happen, but even + // when it does, both should be requeued when the worker signs off. + task1 := persistence.Task{ + UUID: "4107c7aa-e86d-4244-858b-6c4fce2af503", + Job: &job, + Status: api.TaskStatusActive, + } + task2 := persistence.Task{ + UUID: "beb3f39b-57a5-44bf-a0ad-533e3513a0b6", + Job: &job, + Status: api.TaskStatusActive, + } + workerTasks := []*persistence.Task{&task1, &task2} + + // Signing off should be handled completely, even when the HTTP connection + // breaks. 
This means using a different context than the one passed by Echo. + echo := mf.prepareMockedRequest(&worker, nil) + expectCtx := gomock.Not(gomock.Eq(echo.Request().Context())) + + // Expect worker's tasks to be re-queued. + mf.persistence.EXPECT(). + FetchTasksOfWorkerInStatus(expectCtx, &worker, api.TaskStatusActive). + Return(workerTasks, nil) + mf.stateMachine.EXPECT().TaskStatusChange(expectCtx, &task1, api.TaskStatusQueued) + mf.stateMachine.EXPECT().TaskStatusChange(expectCtx, &task2, api.TaskStatusQueued) + + // Expect worker to be saved as 'offline'. + mf.persistence.EXPECT(). + SaveWorkerStatus(expectCtx, &worker). + Do(func(ctx context.Context, w *persistence.Worker) error { + assert.Equal(t, api.WorkerStatusOffline, w.Status) + return nil + }) + + err := mf.flamenco.SignOff(echo) + assert.NoError(t, err) + + resp := echo.Response().Writer.(*httptest.ResponseRecorder) + assert.Equal(t, http.StatusNoContent, resp.Code) +} diff --git a/internal/manager/persistence/jobs.go b/internal/manager/persistence/jobs.go index f9cdef54..8706f912 100644 --- a/internal/manager/persistence/jobs.go +++ b/internal/manager/persistence/jobs.go @@ -231,6 +231,34 @@ func (db *DB) SaveTaskActivity(ctx context.Context, t *Task) error { return nil } +func (db *DB) TaskAssignToWorker(ctx context.Context, t *Task, w *Worker) error { + tx := db.gormDB.WithContext(ctx). + Model(t).Updates(Task{WorkerID: &w.ID}) + if tx.Error != nil { + return fmt.Errorf("assigning task %s to worker %s: %w", t.UUID, w.UUID, tx.Error) + } + + // Gorm updates t.WorkerID itself, but not t.Worker (even when it's added to + // the Updates() call above). + t.Worker = w + + return nil +} + +func (db *DB) FetchTasksOfWorkerInStatus(ctx context.Context, worker *Worker, taskStatus api.TaskStatus) ([]*Task, error) { + result := []*Task{} + tx := db.gormDB.WithContext(ctx). + Model(&Task{}). + Joins("Job"). + Where("tasks.worker_id = ?", worker.ID). + Where("tasks.status = ?", taskStatus). 
+ Scan(&result) + if tx.Error != nil { + return nil, fmt.Errorf("finding tasks of worker %s in status %q: %w", worker.UUID, taskStatus, tx.Error) + } + return result, nil +} + func (db *DB) JobHasTasksInStatus(ctx context.Context, job *Job, taskStatus api.TaskStatus) (bool, error) { var numTasksInStatus int64 tx := db.gormDB.WithContext(ctx). diff --git a/internal/manager/persistence/jobs_test.go b/internal/manager/persistence/jobs_test.go index c6dc21db..83e77970 100644 --- a/internal/manager/persistence/jobs_test.go +++ b/internal/manager/persistence/jobs_test.go @@ -188,7 +188,42 @@ func TestUpdateJobsTaskStatusesConditional(t *testing.T) { assert.Equal(t, api.TaskStatusCancelRequested, getTask(0).Status) assert.Equal(t, api.TaskStatusCompleted, getTask(1).Status) assert.Equal(t, api.TaskStatusCancelRequested, getTask(2).Status) +} +func TestTaskAssignToWorker(t *testing.T) { + ctx, ctxCancel, db, _, authoredJob := jobTasksTestFixtures(t) + defer ctxCancel() + + task, err := db.FetchTask(ctx, authoredJob.Tasks[1].UUID) + assert.NoError(t, err) + + w := createWorker(t, db) + assert.NoError(t, db.TaskAssignToWorker(ctx, task, w)) + + assert.Equal(t, w, task.Worker) + assert.Equal(t, w.ID, *task.WorkerID) +} + +func TestFetchTasksOfWorkerInStatus(t *testing.T) { + ctx, ctxCancel, db, _, authoredJob := jobTasksTestFixtures(t) + defer ctxCancel() + + task, err := db.FetchTask(ctx, authoredJob.Tasks[1].UUID) + assert.NoError(t, err) + + w := createWorker(t, db) + assert.NoError(t, db.TaskAssignToWorker(ctx, task, w)) + + tasks, err := db.FetchTasksOfWorkerInStatus(ctx, w, task.Status) + assert.NoError(t, err) + assert.Len(t, tasks, 1, "worker should have one task in status %q", task.Status) + assert.Equal(t, task.ID, tasks[0].ID) + assert.Equal(t, task.UUID, tasks[0].UUID) + + assert.NotEqual(t, api.TaskStatusCancelRequested, task.Status) + tasks, err = db.FetchTasksOfWorkerInStatus(ctx, w, api.TaskStatusCancelRequested) + assert.NoError(t, err) + assert.Empty(t, 
tasks, "worker should have no task in status %q", api.TaskStatusCancelRequested) } func createTestAuthoredJobWithTasks() job_compilers.AuthoredJob { @@ -272,3 +307,35 @@ func jobTasksTestFixtures(t *testing.T) (context.Context, context.CancelFunc, *D return ctx, cancel, db, dbJob, authoredJob } + +func createWorker(t *testing.T, db *DB) *Worker { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + w := Worker{ + UUID: "f0a123a9-ab05-4ce2-8577-94802cfe74a4", + Name: "дрон", + Address: "fe80::5054:ff:fede:2ad7", + LastActivity: "", + Platform: "linux", + Software: "3.0", + Status: api.WorkerStatusAwake, + SupportedTaskTypes: "blender,ffmpeg,file-management", + } + + err := db.CreateWorker(ctx, &w) + if err != nil { + t.Fatalf("error creating worker: %v", err) + } + assert.NoError(t, err) + + fetchedWorker, err := db.FetchWorker(ctx, w.UUID) + if err != nil { + t.Fatalf("error fetching worker: %v", err) + } + if fetchedWorker == nil { + t.Fatal("fetched worker is nil, but no error returned") + } + + return fetchedWorker +}