diff --git a/internal/manager/api_impl/interfaces.go b/internal/manager/api_impl/interfaces.go index 34d2aa58..d6943d8d 100644 --- a/internal/manager/api_impl/interfaces.go +++ b/internal/manager/api_impl/interfaces.go @@ -32,23 +32,24 @@ import ( type PersistenceService interface { StoreAuthoredJob(ctx context.Context, authoredJob job_compilers.AuthoredJob) error // FetchJob fetches a single job, without fetching its tasks. - FetchJob(ctx context.Context, jobID string) (*persistence.Job, error) + FetchJob(ctx context.Context, jobUUID string) (*persistence.Job, error) FetchJobs(ctx context.Context) ([]*persistence.Job, error) + FetchJobByID(ctx context.Context, jobID int64) (*persistence.Job, error) SaveJobPriority(ctx context.Context, job *persistence.Job) error // FetchTask fetches the given task and the accompanying job. - FetchTask(ctx context.Context, taskID string) (*persistence.Task, error) + FetchTask(ctx context.Context, taskID string) (persistence.TaskJobWorker, error) // FetchTaskJobUUID fetches the UUID of the job this task belongs to. FetchTaskJobUUID(ctx context.Context, taskID string) (string, error) FetchTaskFailureList(context.Context, *persistence.Task) ([]*persistence.Worker, error) SaveTaskActivity(ctx context.Context, t *persistence.Task) error // TaskTouchedByWorker marks the task as 'touched' by a worker. This is used for timeout detection. - TaskTouchedByWorker(context.Context, *persistence.Task) error + TaskTouchedByWorker(ctx context.Context, taskUUID string) error CreateWorker(ctx context.Context, w *persistence.Worker) error FetchWorker(ctx context.Context, uuid string) (*persistence.Worker, error) FetchWorkers(ctx context.Context) ([]*persistence.Worker, error) - FetchWorkerTask(context.Context, *persistence.Worker) (*persistence.Task, error) + FetchWorkerTask(context.Context, *persistence.Worker) (*persistence.TaskJob, error) SaveWorker(ctx context.Context, w *persistence.Worker) error SaveWorkerStatus(ctx context.Context, w *persistence.Worker) error WorkerSeen(ctx context.Context, w *persistence.Worker) error @@ -56,7 +57,7 @@ type PersistenceService interface { // ScheduleTask finds a task to execute by the given worker, and assigns it to that worker. // If no task is available, (nil, nil) is returned, as this is not an error situation. - ScheduleTask(ctx context.Context, w *persistence.Worker) (*persistence.Task, error) + ScheduleTask(ctx context.Context, w *persistence.Worker) (*persistence.ScheduledTask, error) AddWorkerToTaskFailedList(context.Context, *persistence.Task, *persistence.Worker) (numFailed int, err error) // ClearFailureListOfTask clears the list of workers that failed this task. ClearFailureListOfTask(context.Context, *persistence.Task) error @@ -64,7 +65,7 @@ type PersistenceService interface { ClearFailureListOfJob(context.Context, *persistence.Job) error // AddWorkerToJobBlocklist prevents this Worker of getting any task, of this type, on this job, from the task scheduler. 
- AddWorkerToJobBlocklist(ctx context.Context, job *persistence.Job, worker *persistence.Worker, taskType string) error + AddWorkerToJobBlocklist(ctx context.Context, jobID int64, workerID int64, taskType string) error FetchJobBlocklist(ctx context.Context, jobUUID string) ([]persistence.JobBlockListEntry, error) RemoveFromJobBlocklist(ctx context.Context, jobUUID, workerUUID, taskType string) error ClearJobBlocklist(ctx context.Context, job *persistence.Job) error @@ -73,6 +74,7 @@ type PersistenceService interface { WorkerSetTags(ctx context.Context, worker *persistence.Worker, tagUUIDs []string) error CreateWorkerTag(ctx context.Context, tag *persistence.WorkerTag) error FetchWorkerTag(ctx context.Context, uuid string) (persistence.WorkerTag, error) + FetchWorkerTagByID(ctx context.Context, id int64) (persistence.WorkerTag, error) FetchWorkerTags(ctx context.Context) ([]persistence.WorkerTag, error) DeleteWorkerTag(ctx context.Context, uuid string) error SaveWorkerTag(ctx context.Context, tag *persistence.WorkerTag) error @@ -81,13 +83,13 @@ type PersistenceService interface { // WorkersLeftToRun returns a set of worker UUIDs that can run tasks of the given type on the given job. WorkersLeftToRun(ctx context.Context, job *persistence.Job, taskType string) (map[string]bool, error) // CountTaskFailuresOfWorker returns the number of task failures of this worker, on this particular job and task type. - CountTaskFailuresOfWorker(ctx context.Context, job *persistence.Job, worker *persistence.Worker, taskType string) (int, error) + CountTaskFailuresOfWorker(ctx context.Context, jobUUID string, workerID int64, taskType string) (int, error) // Database queries. - QueryJobTaskSummaries(ctx context.Context, jobUUID string) ([]*persistence.Task, error) + QueryJobTaskSummaries(ctx context.Context, jobUUID string) ([]persistence.TaskSummary, error) // SetLastRendered sets this job as the one with the most recent rendered image. - SetLastRendered(ctx context.Context, j *persistence.Job) error + SetLastRendered(ctx context.Context, jobUUID string) error // GetLastRendered returns the UUID of the job with the most recent rendered image. GetLastRenderedJobUUID(ctx context.Context) (string, error) } @@ -99,10 +101,10 @@ type TaskStateMachine interface { TaskStatusChange(ctx context.Context, task *persistence.Task, newStatus api.TaskStatus) error // JobStatusChange gives a Job a new status, and handles the resulting status changes on its tasks. - JobStatusChange(ctx context.Context, job *persistence.Job, newJobStatus api.JobStatus, reason string) error + JobStatusChange(ctx context.Context, jobUUID string, newJobStatus api.JobStatus, reason string) error RequeueActiveTasksOfWorker(ctx context.Context, worker *persistence.Worker, reason string) error - RequeueFailedTasksOfWorkerOfJob(ctx context.Context, worker *persistence.Worker, job *persistence.Job, reason string) error + RequeueFailedTasksOfWorkerOfJob(ctx context.Context, worker *persistence.Worker, jobUUID string, reason string) error } // TaskStateMachine should be a subset of task_state_machine.StateMachine. @@ -140,12 +142,12 @@ type JobCompiler interface { // LogStorage handles incoming task logs. 
type LogStorage interface { - Write(logger zerolog.Logger, jobID, taskID string, logText string) error - WriteTimestamped(logger zerolog.Logger, jobID, taskID string, logText string) error - RotateFile(logger zerolog.Logger, jobID, taskID string) - Tail(jobID, taskID string) (string, error) - TaskLogSize(jobID, taskID string) (int64, error) - Filepath(jobID, taskID string) string + Write(logger zerolog.Logger, jobUUID, taskUUID string, logText string) error + WriteTimestamped(logger zerolog.Logger, jobUUID, taskUUID string, logText string) error + RotateFile(logger zerolog.Logger, jobUUID, taskUUID string) + Tail(jobID, taskUUID string) (string, error) + TaskLogSize(jobUUID, taskUUID string) (int64, error) + Filepath(jobUUID, taskUUID string) string } // LastRendered processes the "last rendered" images. diff --git a/internal/manager/api_impl/jobs.go b/internal/manager/api_impl/jobs.go index 78b4b283..a303d888 100644 --- a/internal/manager/api_impl/jobs.go +++ b/internal/manager/api_impl/jobs.go @@ -4,6 +4,7 @@ package api_impl import ( "context" + "encoding/json" "errors" "fmt" "math" @@ -15,6 +16,7 @@ import ( "github.com/labstack/echo/v4" "github.com/rs/zerolog" + "github.com/rs/zerolog/log" "projects.blender.org/studio/flamenco/internal/manager/eventbus" "projects.blender.org/studio/flamenco/internal/manager/job_compilers" @@ -112,7 +114,7 @@ func (f *Flamenco) SubmitJob(e echo.Context) error { jobUpdate := eventbus.NewJobUpdate(dbJob) f.broadcaster.BroadcastNewJob(jobUpdate) - apiJob := jobDBtoAPI(dbJob) + apiJob := jobDBtoAPI(ctx, f.persist, dbJob) return e.JSON(http.StatusOK, apiJob) } @@ -179,7 +181,8 @@ func (f *Flamenco) DeleteJob(e echo.Context, jobID string) error { } logger = logger.With(). - Uint("dbID", dbJob.ID). + Int64("dbID", dbJob.ID). + Str("job", dbJob.UUID). Str("currentstatus", string(dbJob.Status)). Logger() logger.Info().Msg("job deletion requested") @@ -283,9 +286,9 @@ func (f *Flamenco) DeleteJobMass(e echo.Context) error { } // SetJobStatus is used by the web interface to change a job's status. -func (f *Flamenco) SetJobStatus(e echo.Context, jobID string) error { +func (f *Flamenco) SetJobStatus(e echo.Context, jobUUID string) error { logger := requestLogger(e).With(). - Str("job", jobID). + Str("job", jobUUID). Logger() var statusChange api.SetJobStatusJSONRequestBody @@ -294,7 +297,7 @@ func (f *Flamenco) SetJobStatus(e echo.Context, jobID string) error { return sendAPIError(e, http.StatusBadRequest, "invalid format") } - dbJob, err := f.fetchJob(e, logger, jobID) + dbJob, err := f.fetchJob(e, logger, jobUUID) if dbJob == nil { // f.fetchJob already sent a response. return err @@ -308,7 +311,7 @@ func (f *Flamenco) SetJobStatus(e echo.Context, jobID string) error { logger.Info().Msg("job status change requested") ctx := e.Request().Context() - err = f.stateMachine.JobStatusChange(ctx, dbJob, statusChange.Status, statusChange.Reason) + err = f.stateMachine.JobStatusChange(ctx, jobUUID, statusChange.Status, statusChange.Reason) if err != nil { logger.Error().Err(err).Msg("error changing job status") return sendAPIError(e, http.StatusInternalServerError, "unexpected error changing job status") @@ -352,7 +355,7 @@ func (f *Flamenco) SetJobPriority(e echo.Context, jobID string) error { logger = logger.With(). Str("jobName", dbJob.Name). - Int("prioCurrent", dbJob.Priority). + Int64("prioCurrent", dbJob.Priority). Int("prioRequested", prioChange.Priority). 
Logger() logger.Info().Msg("job priority change requested") @@ -361,7 +364,7 @@ func (f *Flamenco) SetJobPriority(e echo.Context, jobID string) error { bgCtx, bgCtxCancel := bgContext() defer bgCtxCancel() - dbJob.Priority = prioChange.Priority + dbJob.Priority = int64(prioChange.Priority) err = f.persist.SaveJobPriority(bgCtx, dbJob) if err != nil { logger.Error().Err(err).Msg("error changing job priority") @@ -388,7 +391,7 @@ func (f *Flamenco) SetTaskStatus(e echo.Context, taskID string) error { return sendAPIError(e, http.StatusBadRequest, "invalid format") } - dbTask, err := f.persist.FetchTask(ctx, taskID) + taskJobWorker, err := f.persist.FetchTask(ctx, taskID) if err != nil { if errors.Is(err, persistence.ErrTaskNotFound) { return sendAPIError(e, http.StatusNotFound, "no such task") @@ -396,6 +399,9 @@ func (f *Flamenco) SetTaskStatus(e echo.Context, taskID string) error { logger.Error().Err(err).Msg("error fetching task") return sendAPIError(e, http.StatusInternalServerError, "error fetching task") } + dbTask := &taskJobWorker.Task + + // TODO: do the rest of the processing in a background context. logger = logger.With(). Str("currentstatus", string(dbTask.Status)). @@ -660,80 +666,100 @@ func (f *Flamenco) lastRenderedInfoForJob(logger zerolog.Logger, jobUUID string) return &info, nil } -func jobDBtoAPI(dbJob *persistence.Job) api.Job { +// jobDBtoAPI converts the job from the database struct to the API struct. +// +// Note that this function does not connect to the database, and thus cannot +// find the job's worker tag. +func jobDBtoAPI(ctx context.Context, persist PersistenceService, dbJob *persistence.Job) api.Job { apiJob := api.Job{ SubmittedJob: api.SubmittedJob{ Name: dbJob.Name, - Priority: dbJob.Priority, + Priority: int(dbJob.Priority), Type: dbJob.JobType, }, Id: dbJob.UUID, Created: dbJob.CreatedAt, - Updated: dbJob.UpdatedAt, + Updated: dbJob.UpdatedAt.Time, Status: api.JobStatus(dbJob.Status), Activity: dbJob.Activity, } - apiJob.Settings = &api.JobSettings{AdditionalProperties: dbJob.Settings} - apiJob.Metadata = &api.JobMetadata{AdditionalProperties: dbJob.Metadata} + { // Parse job settings JSON. + settings := api.JobSettings{} + if err := json.Unmarshal(dbJob.Settings, &settings.AdditionalProperties); err != nil { + log.Error().Str("job", dbJob.UUID).AnErr("cause", err).Msg("could not parse job settings in database as JSON") + } else { + apiJob.Settings = &settings + } + } - if dbJob.Storage.ShamanCheckoutID != "" { + { // Parse job metadata JSON. + metadata := api.JobMetadata{} + if err := json.Unmarshal(dbJob.Metadata, &metadata.AdditionalProperties); err != nil { + log.Error().Str("job", dbJob.UUID).AnErr("cause", err).Msg("could not parse job metadata in database as JSON") + } else { + apiJob.Metadata = &metadata + } + } + + if dbJob.StorageShamanCheckoutID != "" { apiJob.Storage = &api.JobStorageInfo{ - ShamanCheckoutId: &dbJob.Storage.ShamanCheckoutID, + ShamanCheckoutId: &dbJob.StorageShamanCheckoutID, } } if dbJob.DeleteRequestedAt.Valid { apiJob.DeleteRequestedAt = &dbJob.DeleteRequestedAt.Time } - if dbJob.WorkerTag != nil { - apiJob.WorkerTag = &dbJob.WorkerTag.UUID + + if dbJob.WorkerTagID.Valid { + // TODO: see if this can be handled by the callers. 
+ tag, err := persist.FetchWorkerTagByID(ctx, dbJob.WorkerTagID.Int64) + if err != nil { + log.Error().Str("job", dbJob.UUID).AnErr("cause", err).Msg("could not find job's worker tag") + } + apiJob.WorkerTag = &tag.UUID } return apiJob } -func taskDBtoAPI(dbTask *persistence.Task) api.Task { +func taskJobWorkertoAPI(taskJobWorker persistence.TaskJobWorker) api.Task { + return taskToAPI(&taskJobWorker.Task, taskJobWorker.JobUUID, taskJobWorker.WorkerUUID) +} + +func taskToAPI(task *persistence.Task, jobUUID, workerUUID string) api.Task { apiTask := api.Task{ - Id: dbTask.UUID, - Name: dbTask.Name, - Priority: dbTask.Priority, - TaskType: dbTask.Type, - Created: dbTask.CreatedAt, - Updated: dbTask.UpdatedAt, - Status: dbTask.Status, - Activity: dbTask.Activity, - Commands: make([]api.Command, len(dbTask.Commands)), + Id: task.UUID, + IndexInJob: int(task.IndexInJob), + JobId: jobUUID, + Name: task.Name, + Priority: int(task.Priority), + TaskType: task.Type, + Created: task.CreatedAt, + Updated: task.UpdatedAt.Time, + Status: task.Status, + Activity: task.Activity, - // TODO: convert this to just store dbTask.WorkerUUID. - Worker: workerToTaskWorker(dbTask.Worker), - - JobId: dbTask.JobUUID, - IndexInJob: dbTask.IndexInJob, + // TODO: update the web frontend just use the UUID, as we have not enough + // info here to fill the name & address fields. + Worker: &api.TaskWorker{Id: workerUUID}, } - if dbTask.Job != nil { - apiTask.JobId = dbTask.Job.UUID + if err := json.Unmarshal(task.Commands, &apiTask.Commands); err != nil { + log.Error(). + Str("task", task.UUID). + AnErr("cause", err). + Msg("could not parse task commands JSON") } - if !dbTask.LastTouchedAt.IsZero() { - apiTask.LastTouched = &dbTask.LastTouchedAt - } - - for i := range dbTask.Commands { - apiTask.Commands[i] = commandDBtoAPI(dbTask.Commands[i]) + if task.LastTouchedAt.Valid { + apiTask.LastTouched = &task.LastTouchedAt.Time } return apiTask } -func commandDBtoAPI(dbCommand persistence.Command) api.Command { - return api.Command{ - Name: dbCommand.Name, - Parameters: dbCommand.Parameters, - } -} - // workerToTaskWorker is nil-safe. 
func workerToTaskWorker(worker *persistence.Worker) *api.TaskWorker { if worker == nil { diff --git a/internal/manager/api_impl/jobs_query.go b/internal/manager/api_impl/jobs_query.go index 5d2cb1a6..ab9b8186 100644 --- a/internal/manager/api_impl/jobs_query.go +++ b/internal/manager/api_impl/jobs_query.go @@ -55,7 +55,8 @@ func (f *Flamenco) FetchJob(e echo.Context, jobID string) error { return err } - apiJob := jobDBtoAPI(dbJob) + ctx := e.Request().Context() + apiJob := jobDBtoAPI(ctx, f.persist, dbJob) return e.JSON(http.StatusOK, apiJob) } @@ -75,7 +76,7 @@ func (f *Flamenco) FetchJobs(e echo.Context) error { apiJobs := make([]api.Job, len(dbJobs)) for i, dbJob := range dbJobs { - apiJobs[i] = jobDBtoAPI(dbJob) + apiJobs[i] = jobDBtoAPI(ctx, f.persist, dbJob) } result := api.JobsQueryResult{ Jobs: apiJobs, @@ -94,7 +95,7 @@ func (f *Flamenco) FetchJobTasks(e echo.Context, jobID string) error { return sendAPIError(e, http.StatusBadRequest, "job ID not valid") } - tasks, err := f.persist.QueryJobTaskSummaries(ctx, jobID) + dbSummaries, err := f.persist.QueryJobTaskSummaries(ctx, jobID) switch { case errors.Is(err, context.Canceled): logger.Debug().AnErr("cause", err).Msg("could not fetch job tasks, remote end probably closed connection") @@ -104,12 +105,12 @@ func (f *Flamenco) FetchJobTasks(e echo.Context, jobID string) error { return sendAPIError(e, http.StatusInternalServerError, "error fetching job tasks: %v", err) } - summaries := make([]api.TaskSummary, len(tasks)) - for i, task := range tasks { - summaries[i] = taskDBtoSummary(task) + apiSummaries := make([]api.TaskSummary, len(dbSummaries)) + for i, dbSummary := range dbSummaries { + apiSummaries[i] = taskSummaryDBtoAPI(dbSummary) } result := api.JobTasksSummary{ - Tasks: &summaries, + Tasks: &apiSummaries, } return e.JSON(http.StatusOK, result) } @@ -125,8 +126,8 @@ func (f *Flamenco) FetchTask(e echo.Context, taskID string) error { return sendAPIError(e, http.StatusBadRequest, "job ID not valid") } - // Fetch & convert the task. - task, err := f.persist.FetchTask(ctx, taskID) + // Fetch & convert the taskJobWorker. + taskJobWorker, err := f.persist.FetchTask(ctx, taskID) if errors.Is(err, persistence.ErrTaskNotFound) { logger.Debug().Msg("non-existent task requested") return sendAPIError(e, http.StatusNotFound, "no such task") @@ -135,10 +136,20 @@ func (f *Flamenco) FetchTask(e echo.Context, taskID string) error { logger.Warn().Err(err).Msg("error fetching task") return sendAPIError(e, http.StatusInternalServerError, "error fetching task") } - apiTask := taskDBtoAPI(task) + apiTask := taskJobWorkertoAPI(taskJobWorker) + + // Fetch the worker. TODO: get rid of this conversion, just include the + // worker's UUID and let the caller fetch the worker info themselves if + // necessary. + taskWorker, err := f.persist.FetchWorker(ctx, taskJobWorker.WorkerUUID) + if err != nil { + logger.Warn().Err(err).Msg("error fetching task worker") + return sendAPIError(e, http.StatusInternalServerError, "error fetching task worker") + } + apiTask.Worker = workerToTaskWorker(taskWorker) // Fetch & convert the failure list. 
- failedWorkers, err := f.persist.FetchTaskFailureList(ctx, task) + failedWorkers, err := f.persist.FetchTaskFailureList(ctx, &taskJobWorker.Task) if err != nil { logger.Warn().Err(err).Msg("error fetching task failure list") return sendAPIError(e, http.StatusInternalServerError, "error fetching task failure list") @@ -152,14 +163,26 @@ func (f *Flamenco) FetchTask(e echo.Context, taskID string) error { return e.JSON(http.StatusOK, apiTask) } -func taskDBtoSummary(task *persistence.Task) api.TaskSummary { +func taskSummaryDBtoAPI(task persistence.TaskSummary) api.TaskSummary { return api.TaskSummary{ Id: task.UUID, Name: task.Name, - IndexInJob: task.IndexInJob, - Priority: task.Priority, + IndexInJob: int(task.IndexInJob), + Priority: int(task.Priority), Status: task.Status, TaskType: task.Type, - Updated: task.UpdatedAt, + Updated: task.UpdatedAt.Time, + } +} + +func taskDBtoSummaryAPI(task persistence.Task) api.TaskSummary { + return api.TaskSummary{ + Id: task.UUID, + Name: task.Name, + IndexInJob: int(task.IndexInJob), + Priority: int(task.Priority), + Status: task.Status, + TaskType: task.Type, + Updated: task.UpdatedAt.Time, } } diff --git a/internal/manager/api_impl/jobs_query_test.go b/internal/manager/api_impl/jobs_query_test.go index c8aa0b13..97d941f5 100644 --- a/internal/manager/api_impl/jobs_query_test.go +++ b/internal/manager/api_impl/jobs_query_test.go @@ -25,12 +25,8 @@ func TestFetchJobs(t *testing.T) { JobType: "test", Priority: 50, Status: api.JobStatusActive, - Settings: persistence.StringInterfaceMap{ - "result": "/render/frames/exploding.kittens", - }, - Metadata: persistence.StringStringMap{ - "project": "/projects/exploding-kittens", - }, + Settings: []byte(`{"result": "/render/frames/exploding.kittens"}`), + Metadata: []byte(`{"project": "/projects/exploding-kittens"}`), } deletionRequestedAt := time.Now() @@ -76,8 +72,6 @@ func TestFetchJobs(t *testing.T) { Name: "уходить", Type: "test", Priority: 75, - Settings: &api.JobSettings{}, - Metadata: &api.JobMetadata{}, }, Id: "d912ac69-de48-48ba-8028-35d82cb41451", Status: api.JobStatusCompleted, @@ -96,26 +90,25 @@ func TestFetchJob(t *testing.T) { mf := newMockedFlamenco(mockCtrl) dbJob := persistence.Job{ - UUID: "afc47568-bd9d-4368-8016-e91d945db36d", - Name: "работа", - JobType: "test", - Priority: 50, - Status: api.JobStatusActive, - Settings: persistence.StringInterfaceMap{ - "result": "/render/frames/exploding.kittens", - }, - Metadata: persistence.StringStringMap{ - "project": "/projects/exploding-kittens", - }, - WorkerTag: &persistence.WorkerTag{ - UUID: "d86e1b84-5ee2-4784-a178-65963eeb484b", - Name: "Tikkie terug Kees!", - Description: "", - }, + UUID: "afc47568-bd9d-4368-8016-e91d945db36d", + Name: "работа", + JobType: "test", + Priority: 50, + Status: api.JobStatusActive, + Settings: []byte(`{"result": "/render/frames/exploding.kittens"}`), + Metadata: []byte(`{"project": "/projects/exploding-kittens"}`), + WorkerTagID: sql.NullInt64{Int64: 4477, Valid: true}, + } + + tag := persistence.WorkerTag{ + UUID: "d86e1b84-5ee2-4784-a178-65963eeb484b", + Name: "Tikkie terug Kees!", + Description: "", } echoCtx := mf.prepareMockedRequest(nil) mf.persistence.EXPECT().FetchJob(gomock.Any(), dbJob.UUID).Return(&dbJob, nil) + mf.persistence.EXPECT().FetchWorkerTagByID(gomock.Any(), dbJob.WorkerTagID.Int64).Return(tag, nil) require.NoError(t, mf.flamenco.FetchJob(echoCtx, dbJob.UUID)) @@ -152,30 +145,27 @@ func TestFetchTask(t *testing.T) { taskWorker := persistence.Worker{UUID: workerUUID, Name: "Radnik", Address: 
"Slapić"} dbTask := persistence.Task{ - Model: persistence.Model{ - ID: 327, - CreatedAt: mf.clock.Now().Add(-30 * time.Second), - UpdatedAt: mf.clock.Now(), - }, - UUID: taskUUID, - Name: "симпатичная задача", - Type: "misc", - JobID: 0, - Job: &persistence.Job{UUID: jobUUID}, - Priority: 47, - Status: api.TaskStatusQueued, - WorkerID: new(uint), - Worker: &taskWorker, - Dependencies: []*persistence.Task{}, - Activity: "used in unit test", + ID: 327, + CreatedAt: mf.clock.Now().Add(-30 * time.Second), + UpdatedAt: sql.NullTime{Time: mf.clock.Now(), Valid: true}, + UUID: taskUUID, + Name: "симпатичная задача", + Type: "misc", + JobID: 332277, + Priority: 47, + Status: api.TaskStatusQueued, + WorkerID: sql.NullInt64{Int64: taskWorker.ID, Valid: true}, + Activity: "used in unit test", - Commands: []persistence.Command{ - {Name: "move-directory", - Parameters: map[string]interface{}{ + Commands: []byte(`[ + { + "name": "move-directory", + "parameters": { "dest": "/render/_flamenco/tests/renders/2022-04-29 Weekly/2022-04-29_140531", - "src": "/render/_flamenco/tests/renders/2022-04-29 Weekly/2022-04-29_140531__intermediate-2022-04-29_140531", - }}, - }, + "src": "/render/_flamenco/tests/renders/2022-04-29 Weekly/2022-04-29_140531__intermediate-2022-04-29_140531" + } + } + ]`), } expectAPITask := api.Task{ @@ -187,7 +177,7 @@ func TestFetchTask(t *testing.T) { Priority: 47, Status: api.TaskStatusQueued, TaskType: "misc", - Updated: dbTask.UpdatedAt, + Updated: dbTask.UpdatedAt.Time, Worker: &api.TaskWorker{Id: workerUUID, Name: "Radnik", Address: "Slapić"}, Commands: []api.Command{ @@ -203,11 +193,18 @@ func TestFetchTask(t *testing.T) { }), } + taskJobWorker := persistence.TaskJobWorker{ + Task: dbTask, + JobUUID: jobUUID, + WorkerUUID: workerUUID, + } + echoCtx := mf.prepareMockedRequest(nil) ctx := echoCtx.Request().Context() - mf.persistence.EXPECT().FetchTask(ctx, taskUUID).Return(&dbTask, nil) + mf.persistence.EXPECT().FetchTask(ctx, taskUUID).Return(taskJobWorker, nil) mf.persistence.EXPECT().FetchTaskFailureList(ctx, &dbTask). 
Return([]*persistence.Worker{&taskWorker}, nil) + mf.persistence.EXPECT().FetchWorker(ctx, workerUUID).Return(&taskWorker, nil) err := mf.flamenco.FetchTask(echoCtx, taskUUID) require.NoError(t, err) diff --git a/internal/manager/api_impl/jobs_test.go b/internal/manager/api_impl/jobs_test.go index c1344a28..a99895be 100644 --- a/internal/manager/api_impl/jobs_test.go +++ b/internal/manager/api_impl/jobs_test.go @@ -3,6 +3,7 @@ package api_impl // SPDX-License-Identifier: GPL-3.0-or-later import ( + "database/sql" "errors" "fmt" "net/http" @@ -66,10 +67,8 @@ func TestSubmitJobWithoutSettings(t *testing.T) { UUID: queuedJob.JobID, Name: queuedJob.Name, JobType: queuedJob.JobType, - Priority: queuedJob.Priority, + Priority: int64(queuedJob.Priority), Status: queuedJob.Status, - Settings: persistence.StringInterfaceMap{}, - Metadata: persistence.StringStringMap{}, } mf.persistence.EXPECT().FetchJob(gomock.Any(), queuedJob.JobID).Return(&dbJob, nil) @@ -77,10 +76,10 @@ func TestSubmitJobWithoutSettings(t *testing.T) { jobUpdate := api.EventJobUpdate{ Id: dbJob.UUID, Name: &dbJob.Name, - Priority: dbJob.Priority, + Priority: int(dbJob.Priority), Status: dbJob.Status, Type: dbJob.JobType, - Updated: dbJob.UpdatedAt, + Updated: dbJob.UpdatedAt.Time, } mf.broadcaster.EXPECT().BroadcastNewJob(jobUpdate) @@ -155,10 +154,10 @@ func TestSubmitJobWithSettings(t *testing.T) { UUID: queuedJob.JobID, Name: queuedJob.Name, JobType: queuedJob.JobType, - Priority: queuedJob.Priority, + Priority: int64(queuedJob.Priority), Status: queuedJob.Status, - Settings: variableReplacedSettings, - Metadata: variableReplacedMetadata, + Settings: []byte(`{"result": "{frames}/exploding.kittens"}`), + Metadata: []byte(`{"project": "{projects}/exploding-kittens"}`), } mf.persistence.EXPECT().FetchJob(gomock.Any(), queuedJob.JobID).Return(&dbJob, nil) @@ -166,10 +165,10 @@ func TestSubmitJobWithSettings(t *testing.T) { jobUpdate := api.EventJobUpdate{ Id: dbJob.UUID, Name: &dbJob.Name, - Priority: dbJob.Priority, + Priority: int(dbJob.Priority), Status: dbJob.Status, Type: dbJob.JobType, - Updated: dbJob.UpdatedAt, + Updated: dbJob.UpdatedAt.Time, } mf.broadcaster.EXPECT().BroadcastNewJob(jobUpdate) @@ -226,10 +225,8 @@ func TestSubmitJobWithEtag(t *testing.T) { UUID: authoredJob.JobID, Name: authoredJob.Name, JobType: authoredJob.JobType, - Priority: authoredJob.Priority, + Priority: int64(authoredJob.Priority), Status: api.JobStatusQueued, - Settings: persistence.StringInterfaceMap{}, - Metadata: persistence.StringStringMap{}, } mf.persistence.EXPECT().FetchJob(gomock.Any(), authoredJob.JobID).Return(&dbJob, nil) @@ -290,16 +287,12 @@ func TestSubmitJobWithShamanCheckoutID(t *testing.T) { // Expect the job to be fetched from the database again: dbJob := persistence.Job{ - UUID: queuedJob.JobID, - Name: queuedJob.Name, - JobType: queuedJob.JobType, - Priority: queuedJob.Priority, - Status: queuedJob.Status, - Settings: persistence.StringInterfaceMap{}, - Metadata: persistence.StringStringMap{}, - Storage: persistence.JobStorageInfo{ - ShamanCheckoutID: "Весы/Синтел", - }, + UUID: queuedJob.JobID, + Name: queuedJob.Name, + JobType: queuedJob.JobType, + Priority: int64(queuedJob.Priority), + Status: queuedJob.Status, + StorageShamanCheckoutID: "Весы/Синтел", } mf.persistence.EXPECT().FetchJob(gomock.Any(), queuedJob.JobID).Return(&dbJob, nil) @@ -307,10 +300,10 @@ func TestSubmitJobWithShamanCheckoutID(t *testing.T) { jobUpdate := api.EventJobUpdate{ Id: dbJob.UUID, Name: &dbJob.Name, - Priority: dbJob.Priority, + Priority: 
int(dbJob.Priority), Status: dbJob.Status, Type: dbJob.JobType, - Updated: dbJob.UpdatedAt, + Updated: dbJob.UpdatedAt.Time, } mf.broadcaster.EXPECT().BroadcastNewJob(jobUpdate) @@ -368,34 +361,30 @@ func TestSubmitJobWithWorkerTag(t *testing.T) { queuedJob.Status = api.JobStatusQueued mf.persistence.EXPECT().StoreAuthoredJob(gomock.Any(), queuedJob).Return(nil) - // Expect the job to be fetched from the database again: + // Expect the job to be fetched from the database again, including its tag. dbJob := persistence.Job{ - Model: persistence.Model{ - ID: 47, - CreatedAt: mf.clock.Now(), - UpdatedAt: mf.clock.Now(), - }, - UUID: queuedJob.JobID, - Name: queuedJob.Name, - JobType: queuedJob.JobType, - Priority: queuedJob.Priority, - Status: queuedJob.Status, - Settings: persistence.StringInterfaceMap{}, - Metadata: persistence.StringStringMap{}, + ID: 47, + CreatedAt: mf.clock.Now(), + UpdatedAt: sql.NullTime{Time: mf.clock.Now(), Valid: true}, + UUID: queuedJob.JobID, + Name: queuedJob.Name, + JobType: queuedJob.JobType, + Priority: int64(queuedJob.Priority), + Status: queuedJob.Status, - WorkerTagID: ptr(uint(tag.ID)), - WorkerTag: &tag, + WorkerTagID: sql.NullInt64{Int64: tag.ID, Valid: true}, } mf.persistence.EXPECT().FetchJob(gomock.Any(), queuedJob.JobID).Return(&dbJob, nil) + mf.persistence.EXPECT().FetchWorkerTagByID(gomock.Any(), tag.ID).Return(tag, nil) // Expect the new job to be broadcast. jobUpdate := api.EventJobUpdate{ Id: dbJob.UUID, Name: &dbJob.Name, - Priority: dbJob.Priority, + Priority: int(dbJob.Priority), Status: dbJob.Status, Type: dbJob.JobType, - Updated: dbJob.UpdatedAt, + Updated: dbJob.UpdatedAt.Time, } mf.broadcaster.EXPECT().BroadcastNewJob(jobUpdate) @@ -404,14 +393,12 @@ func TestSubmitJobWithWorkerTag(t *testing.T) { requestWorkerStore(echoCtx, &worker) require.NoError(t, mf.flamenco.SubmitJob(echoCtx)) - submittedJob.Metadata = new(api.JobMetadata) - submittedJob.Settings = new(api.JobSettings) submittedJob.SubmitterPlatform = "" // Not persisted in the database. assertResponseJSON(t, echoCtx, http.StatusOK, api.Job{ SubmittedJob: submittedJob, Id: dbJob.UUID, Created: dbJob.CreatedAt, - Updated: dbJob.UpdatedAt, + Updated: dbJob.UpdatedAt.Time, DeleteRequestedAt: nil, Activity: "", Status: api.JobStatusQueued, @@ -580,17 +567,15 @@ func TestSetJobStatus_happy(t *testing.T) { Reason: "someone pushed a button", } dbJob := persistence.Job{ - UUID: jobID, - Name: "test job", - Status: api.JobStatusActive, - Settings: persistence.StringInterfaceMap{}, - Metadata: persistence.StringStringMap{}, + UUID: jobID, + Name: "test job", + Status: api.JobStatusActive, } // Set up expectations. ctx := gomock.Any() mf.persistence.EXPECT().FetchJob(ctx, jobID).Return(&dbJob, nil) - mf.stateMachine.EXPECT().JobStatusChange(ctx, &dbJob, statusUpdate.Status, "someone pushed a button") + mf.stateMachine.EXPECT().JobStatusChange(ctx, jobID, statusUpdate.Status, "someone pushed a button") // Going to Cancel Requested should NOT clear the failure list. 
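The test fixtures in this file now model nullable database columns with the standard library's database/sql wrapper types (sql.NullInt64 for WorkerTagID and WorkerID, sql.NullTime for UpdatedAt and LastTouchedAt) instead of pointers to preloaded model structs. A minimal sketch of that pattern, using only the standard library; the concrete values are illustrative:

package example

import (
	"database/sql"
	"time"
)

// nullableColumns shows how a nullable foreign key and a nullable
// timestamp are set and read back. Valid must be true before the
// wrapped Int64/Time value is meaningful.
func nullableColumns(now time.Time) {
	workerTagID := sql.NullInt64{Int64: 4477, Valid: true}
	if workerTagID.Valid {
		_ = workerTagID.Int64 // e.g. look up the worker tag by its numeric ID
	}

	var lastTouched sql.NullTime // zero value: NULL, i.e. never touched
	if !lastTouched.Valid {
		_ = lastTouched // task has not been touched by any worker yet
	}

	updated := sql.NullTime{Time: now, Valid: true}
	_ = updated.Time // only use .Time when Valid is true
}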
@@ -635,8 +620,6 @@ func TestSetJobPrio(t *testing.T) { UUID: jobID, Name: "test job", Priority: 50, - Settings: persistence.StringInterfaceMap{}, - Metadata: persistence.StringStringMap{}, } echoCtx := mf.prepareMockedJSONRequest(prioUpdate) @@ -655,7 +638,7 @@ func TestSetJobPrio(t *testing.T) { RefreshTasks: false, Priority: prioUpdate.Priority, Status: dbJob.Status, - Updated: dbJob.UpdatedAt, + Updated: dbJob.UpdatedAt.Time, } mf.broadcaster.EXPECT().BroadcastJobUpdate(expectUpdate) @@ -677,18 +660,16 @@ func TestSetJobStatusFailedToRequeueing(t *testing.T) { Reason: "someone pushed a button", } dbJob := persistence.Job{ - UUID: jobID, - Name: "test job", - Status: api.JobStatusFailed, - Settings: persistence.StringInterfaceMap{}, - Metadata: persistence.StringStringMap{}, + UUID: jobID, + Name: "test job", + Status: api.JobStatusFailed, } // Set up expectations. echoCtx := mf.prepareMockedJSONRequest(statusUpdate) ctx := echoCtx.Request().Context() mf.persistence.EXPECT().FetchJob(moremock.ContextWithDeadline(), jobID).Return(&dbJob, nil) - mf.stateMachine.EXPECT().JobStatusChange(ctx, &dbJob, statusUpdate.Status, "someone pushed a button") + mf.stateMachine.EXPECT().JobStatusChange(ctx, jobID, statusUpdate.Status, "someone pushed a button") mf.persistence.EXPECT().ClearFailureListOfJob(ctx, &dbJob) mf.persistence.EXPECT().ClearJobBlocklist(ctx, &dbJob) @@ -712,31 +693,35 @@ func TestSetTaskStatusQueued(t *testing.T) { Reason: "someone pushed a button", } dbJob := persistence.Job{ - Model: persistence.Model{ID: 47}, - UUID: jobID, - Name: "test job", - Status: api.JobStatusFailed, - Settings: persistence.StringInterfaceMap{}, - Metadata: persistence.StringStringMap{}, + ID: 47, + UUID: jobID, + Name: "test job", + Status: api.JobStatusFailed, } dbTask := persistence.Task{ UUID: taskID, Name: "test task", Status: api.TaskStatusFailed, - Job: &dbJob, JobID: dbJob.ID, } // Set up expectations. echoCtx := mf.prepareMockedJSONRequest(statusUpdate) ctx := echoCtx.Request().Context() - mf.persistence.EXPECT().FetchTask(ctx, taskID).Return(&dbTask, nil) - mf.stateMachine.EXPECT().TaskStatusChange(ctx, &dbTask, statusUpdate.Status) - mf.persistence.EXPECT().ClearFailureListOfTask(ctx, &dbTask) - updatedTask := dbTask - updatedTask.Activity = "someone pushed a button" - mf.persistence.EXPECT().SaveTaskActivity(ctx, &updatedTask) + taskJobWorker := persistence.TaskJobWorker{ + Task: dbTask, + JobUUID: dbJob.UUID, + WorkerUUID: "", + } + + taskWithActivity := dbTask + taskWithActivity.Activity = "someone pushed a button" + + mf.persistence.EXPECT().FetchTask(ctx, taskID).Return(taskJobWorker, nil) + mf.persistence.EXPECT().SaveTaskActivity(ctx, &taskWithActivity) + mf.stateMachine.EXPECT().TaskStatusChange(ctx, &taskWithActivity, statusUpdate.Status) + mf.persistence.EXPECT().ClearFailureListOfTask(ctx, &taskWithActivity) // Do the call. err := mf.flamenco.SetTaskStatus(echoCtx, taskID) @@ -914,12 +899,10 @@ func TestDeleteJob(t *testing.T) { jobID := "18a9b096-d77e-438c-9be2-74397038298b" dbJob := persistence.Job{ - Model: persistence.Model{ID: 47}, - UUID: jobID, - Name: "test job", - Status: api.JobStatusFailed, - Settings: persistence.StringInterfaceMap{}, - Metadata: persistence.StringStringMap{}, + ID: 47, + UUID: jobID, + Name: "test job", + Status: api.JobStatusFailed, } // Set up expectations. 
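Job settings/metadata and task commands are now stored as raw JSON ([]byte) and decoded on the way out, which is what the json.Unmarshal calls in jobDBtoAPI and taskToAPI above do. A small sketch of that decode step, standard library only, with a blob mirroring the test fixtures:

package example

import (
	"encoding/json"
	"fmt"
)

// decodeJobSettings parses the JSON settings blob stored with a job row.
func decodeJobSettings(blob []byte) (map[string]interface{}, error) {
	var settings map[string]interface{}
	if err := json.Unmarshal(blob, &settings); err != nil {
		return nil, fmt.Errorf("parsing job settings as JSON: %w", err)
	}
	return settings, nil
}

For the fixtures above, decodeJobSettings([]byte(`{"result": "/render/frames/exploding.kittens"}`)) yields a one-entry map. Note that jobDBtoAPI logs a decode error instead of returning it, so a corrupt blob degrades to a job without settings rather than failing the whole request.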
diff --git a/internal/manager/api_impl/mocks/api_impl_mock.gen.go b/internal/manager/api_impl/mocks/api_impl_mock.gen.go index dfccebaa..8a19e7e0 100644 --- a/internal/manager/api_impl/mocks/api_impl_mock.gen.go +++ b/internal/manager/api_impl/mocks/api_impl_mock.gen.go @@ -44,7 +44,7 @@ func (m *MockPersistenceService) EXPECT() *MockPersistenceServiceMockRecorder { } // AddWorkerToJobBlocklist mocks base method. -func (m *MockPersistenceService) AddWorkerToJobBlocklist(arg0 context.Context, arg1 *persistence.Job, arg2 *sqlc.Worker, arg3 string) error { +func (m *MockPersistenceService) AddWorkerToJobBlocklist(arg0 context.Context, arg1, arg2 int64, arg3 string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AddWorkerToJobBlocklist", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(error) @@ -58,7 +58,7 @@ func (mr *MockPersistenceServiceMockRecorder) AddWorkerToJobBlocklist(arg0, arg1 } // AddWorkerToTaskFailedList mocks base method. -func (m *MockPersistenceService) AddWorkerToTaskFailedList(arg0 context.Context, arg1 *persistence.Task, arg2 *sqlc.Worker) (int, error) { +func (m *MockPersistenceService) AddWorkerToTaskFailedList(arg0 context.Context, arg1 *sqlc.Task, arg2 *sqlc.Worker) (int, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AddWorkerToTaskFailedList", arg0, arg1, arg2) ret0, _ := ret[0].(int) @@ -73,7 +73,7 @@ func (mr *MockPersistenceServiceMockRecorder) AddWorkerToTaskFailedList(arg0, ar } // ClearFailureListOfJob mocks base method. -func (m *MockPersistenceService) ClearFailureListOfJob(arg0 context.Context, arg1 *persistence.Job) error { +func (m *MockPersistenceService) ClearFailureListOfJob(arg0 context.Context, arg1 *sqlc.Job) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ClearFailureListOfJob", arg0, arg1) ret0, _ := ret[0].(error) @@ -87,7 +87,7 @@ func (mr *MockPersistenceServiceMockRecorder) ClearFailureListOfJob(arg0, arg1 i } // ClearFailureListOfTask mocks base method. -func (m *MockPersistenceService) ClearFailureListOfTask(arg0 context.Context, arg1 *persistence.Task) error { +func (m *MockPersistenceService) ClearFailureListOfTask(arg0 context.Context, arg1 *sqlc.Task) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ClearFailureListOfTask", arg0, arg1) ret0, _ := ret[0].(error) @@ -101,7 +101,7 @@ func (mr *MockPersistenceServiceMockRecorder) ClearFailureListOfTask(arg0, arg1 } // ClearJobBlocklist mocks base method. -func (m *MockPersistenceService) ClearJobBlocklist(arg0 context.Context, arg1 *persistence.Job) error { +func (m *MockPersistenceService) ClearJobBlocklist(arg0 context.Context, arg1 *sqlc.Job) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ClearJobBlocklist", arg0, arg1) ret0, _ := ret[0].(error) @@ -115,7 +115,7 @@ func (mr *MockPersistenceServiceMockRecorder) ClearJobBlocklist(arg0, arg1 inter } // CountTaskFailuresOfWorker mocks base method. -func (m *MockPersistenceService) CountTaskFailuresOfWorker(arg0 context.Context, arg1 *persistence.Job, arg2 *sqlc.Worker, arg3 string) (int, error) { +func (m *MockPersistenceService) CountTaskFailuresOfWorker(arg0 context.Context, arg1 string, arg2 int64, arg3 string) (int, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CountTaskFailuresOfWorker", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(int) @@ -186,10 +186,10 @@ func (mr *MockPersistenceServiceMockRecorder) DeleteWorkerTag(arg0, arg1 interfa } // FetchJob mocks base method. 
-func (m *MockPersistenceService) FetchJob(arg0 context.Context, arg1 string) (*persistence.Job, error) { +func (m *MockPersistenceService) FetchJob(arg0 context.Context, arg1 string) (*sqlc.Job, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "FetchJob", arg0, arg1) - ret0, _ := ret[0].(*persistence.Job) + ret0, _ := ret[0].(*sqlc.Job) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -215,11 +215,26 @@ func (mr *MockPersistenceServiceMockRecorder) FetchJobBlocklist(arg0, arg1 inter return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchJobBlocklist", reflect.TypeOf((*MockPersistenceService)(nil).FetchJobBlocklist), arg0, arg1) } +// FetchJobByID mocks base method. +func (m *MockPersistenceService) FetchJobByID(arg0 context.Context, arg1 int64) (*sqlc.Job, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FetchJobByID", arg0, arg1) + ret0, _ := ret[0].(*sqlc.Job) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FetchJobByID indicates an expected call of FetchJobByID. +func (mr *MockPersistenceServiceMockRecorder) FetchJobByID(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchJobByID", reflect.TypeOf((*MockPersistenceService)(nil).FetchJobByID), arg0, arg1) +} + // FetchJobs mocks base method. -func (m *MockPersistenceService) FetchJobs(arg0 context.Context) ([]*persistence.Job, error) { +func (m *MockPersistenceService) FetchJobs(arg0 context.Context) ([]*sqlc.Job, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "FetchJobs", arg0) - ret0, _ := ret[0].([]*persistence.Job) + ret0, _ := ret[0].([]*sqlc.Job) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -246,10 +261,10 @@ func (mr *MockPersistenceServiceMockRecorder) FetchTagsOfWorker(arg0, arg1 inter } // FetchTask mocks base method. -func (m *MockPersistenceService) FetchTask(arg0 context.Context, arg1 string) (*persistence.Task, error) { +func (m *MockPersistenceService) FetchTask(arg0 context.Context, arg1 string) (persistence.TaskJobWorker, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "FetchTask", arg0, arg1) - ret0, _ := ret[0].(*persistence.Task) + ret0, _ := ret[0].(persistence.TaskJobWorker) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -261,7 +276,7 @@ func (mr *MockPersistenceServiceMockRecorder) FetchTask(arg0, arg1 interface{}) } // FetchTaskFailureList mocks base method. -func (m *MockPersistenceService) FetchTaskFailureList(arg0 context.Context, arg1 *persistence.Task) ([]*sqlc.Worker, error) { +func (m *MockPersistenceService) FetchTaskFailureList(arg0 context.Context, arg1 *sqlc.Task) ([]*sqlc.Worker, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "FetchTaskFailureList", arg0, arg1) ret0, _ := ret[0].([]*sqlc.Worker) @@ -320,6 +335,21 @@ func (mr *MockPersistenceServiceMockRecorder) FetchWorkerTag(arg0, arg1 interfac return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchWorkerTag", reflect.TypeOf((*MockPersistenceService)(nil).FetchWorkerTag), arg0, arg1) } +// FetchWorkerTagByID mocks base method. +func (m *MockPersistenceService) FetchWorkerTagByID(arg0 context.Context, arg1 int64) (sqlc.WorkerTag, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FetchWorkerTagByID", arg0, arg1) + ret0, _ := ret[0].(sqlc.WorkerTag) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FetchWorkerTagByID indicates an expected call of FetchWorkerTagByID. 
+func (mr *MockPersistenceServiceMockRecorder) FetchWorkerTagByID(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchWorkerTagByID", reflect.TypeOf((*MockPersistenceService)(nil).FetchWorkerTagByID), arg0, arg1) +} + // FetchWorkerTags mocks base method. func (m *MockPersistenceService) FetchWorkerTags(arg0 context.Context) ([]sqlc.WorkerTag, error) { m.ctrl.T.Helper() @@ -336,10 +366,10 @@ func (mr *MockPersistenceServiceMockRecorder) FetchWorkerTags(arg0 interface{}) } // FetchWorkerTask mocks base method. -func (m *MockPersistenceService) FetchWorkerTask(arg0 context.Context, arg1 *sqlc.Worker) (*persistence.Task, error) { +func (m *MockPersistenceService) FetchWorkerTask(arg0 context.Context, arg1 *sqlc.Worker) (*persistence.TaskJob, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "FetchWorkerTask", arg0, arg1) - ret0, _ := ret[0].(*persistence.Task) + ret0, _ := ret[0].(*persistence.TaskJob) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -381,10 +411,10 @@ func (mr *MockPersistenceServiceMockRecorder) GetLastRenderedJobUUID(arg0 interf } // QueryJobTaskSummaries mocks base method. -func (m *MockPersistenceService) QueryJobTaskSummaries(arg0 context.Context, arg1 string) ([]*persistence.Task, error) { +func (m *MockPersistenceService) QueryJobTaskSummaries(arg0 context.Context, arg1 string) ([]sqlc.QueryJobTaskSummariesRow, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "QueryJobTaskSummaries", arg0, arg1) - ret0, _ := ret[0].([]*persistence.Task) + ret0, _ := ret[0].([]sqlc.QueryJobTaskSummariesRow) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -410,7 +440,7 @@ func (mr *MockPersistenceServiceMockRecorder) RemoveFromJobBlocklist(arg0, arg1, } // SaveJobPriority mocks base method. -func (m *MockPersistenceService) SaveJobPriority(arg0 context.Context, arg1 *persistence.Job) error { +func (m *MockPersistenceService) SaveJobPriority(arg0 context.Context, arg1 *sqlc.Job) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SaveJobPriority", arg0, arg1) ret0, _ := ret[0].(error) @@ -424,7 +454,7 @@ func (mr *MockPersistenceServiceMockRecorder) SaveJobPriority(arg0, arg1 interfa } // SaveTaskActivity mocks base method. -func (m *MockPersistenceService) SaveTaskActivity(arg0 context.Context, arg1 *persistence.Task) error { +func (m *MockPersistenceService) SaveTaskActivity(arg0 context.Context, arg1 *sqlc.Task) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SaveTaskActivity", arg0, arg1) ret0, _ := ret[0].(error) @@ -480,10 +510,10 @@ func (mr *MockPersistenceServiceMockRecorder) SaveWorkerTag(arg0, arg1 interface } // ScheduleTask mocks base method. -func (m *MockPersistenceService) ScheduleTask(arg0 context.Context, arg1 *sqlc.Worker) (*persistence.Task, error) { +func (m *MockPersistenceService) ScheduleTask(arg0 context.Context, arg1 *sqlc.Worker) (*persistence.ScheduledTask, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ScheduleTask", arg0, arg1) - ret0, _ := ret[0].(*persistence.Task) + ret0, _ := ret[0].(*persistence.ScheduledTask) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -495,7 +525,7 @@ func (mr *MockPersistenceServiceMockRecorder) ScheduleTask(arg0, arg1 interface{ } // SetLastRendered mocks base method. 
-func (m *MockPersistenceService) SetLastRendered(arg0 context.Context, arg1 *persistence.Job) error { +func (m *MockPersistenceService) SetLastRendered(arg0 context.Context, arg1 string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SetLastRendered", arg0, arg1) ret0, _ := ret[0].(error) @@ -523,7 +553,7 @@ func (mr *MockPersistenceServiceMockRecorder) StoreAuthoredJob(arg0, arg1 interf } // TaskTouchedByWorker mocks base method. -func (m *MockPersistenceService) TaskTouchedByWorker(arg0 context.Context, arg1 *persistence.Task) error { +func (m *MockPersistenceService) TaskTouchedByWorker(arg0 context.Context, arg1 string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "TaskTouchedByWorker", arg0, arg1) ret0, _ := ret[0].(error) @@ -565,7 +595,7 @@ func (mr *MockPersistenceServiceMockRecorder) WorkerSetTags(arg0, arg1, arg2 int } // WorkersLeftToRun mocks base method. -func (m *MockPersistenceService) WorkersLeftToRun(arg0 context.Context, arg1 *persistence.Job, arg2 string) (map[string]bool, error) { +func (m *MockPersistenceService) WorkersLeftToRun(arg0 context.Context, arg1 *sqlc.Job, arg2 string) (map[string]bool, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "WorkersLeftToRun", arg0, arg1, arg2) ret0, _ := ret[0].(map[string]bool) @@ -1018,7 +1048,7 @@ func (m *MockTaskStateMachine) EXPECT() *MockTaskStateMachineMockRecorder { } // JobStatusChange mocks base method. -func (m *MockTaskStateMachine) JobStatusChange(arg0 context.Context, arg1 *persistence.Job, arg2 api.JobStatus, arg3 string) error { +func (m *MockTaskStateMachine) JobStatusChange(arg0 context.Context, arg1 string, arg2 api.JobStatus, arg3 string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "JobStatusChange", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(error) @@ -1046,7 +1076,7 @@ func (mr *MockTaskStateMachineMockRecorder) RequeueActiveTasksOfWorker(arg0, arg } // RequeueFailedTasksOfWorkerOfJob mocks base method. -func (m *MockTaskStateMachine) RequeueFailedTasksOfWorkerOfJob(arg0 context.Context, arg1 *sqlc.Worker, arg2 *persistence.Job, arg3 string) error { +func (m *MockTaskStateMachine) RequeueFailedTasksOfWorkerOfJob(arg0 context.Context, arg1 *sqlc.Worker, arg2, arg3 string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RequeueFailedTasksOfWorkerOfJob", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(error) @@ -1060,7 +1090,7 @@ func (mr *MockTaskStateMachineMockRecorder) RequeueFailedTasksOfWorkerOfJob(arg0 } // TaskStatusChange mocks base method. -func (m *MockTaskStateMachine) TaskStatusChange(arg0 context.Context, arg1 *persistence.Task, arg2 api.TaskStatus) error { +func (m *MockTaskStateMachine) TaskStatusChange(arg0 context.Context, arg1 *sqlc.Task, arg2 api.TaskStatus) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "TaskStatusChange", arg0, arg1, arg2) ret0, _ := ret[0].(error) @@ -1390,7 +1420,7 @@ func (m *MockJobDeleter) EXPECT() *MockJobDeleterMockRecorder { } // QueueJobDeletion mocks base method. -func (m *MockJobDeleter) QueueJobDeletion(arg0 context.Context, arg1 *persistence.Job) error { +func (m *MockJobDeleter) QueueJobDeletion(arg0 context.Context, arg1 *sqlc.Job) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "QueueJobDeletion", arg0, arg1) ret0, _ := ret[0].(error) @@ -1418,7 +1448,7 @@ func (mr *MockJobDeleterMockRecorder) QueueMassJobDeletion(arg0, arg1 interface{ } // WhatWouldBeDeleted mocks base method. 
-func (m *MockJobDeleter) WhatWouldBeDeleted(arg0 *persistence.Job) api.JobDeletionInfo { +func (m *MockJobDeleter) WhatWouldBeDeleted(arg0 *sqlc.Job) api.JobDeletionInfo { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "WhatWouldBeDeleted", arg0) ret0, _ := ret[0].(api.JobDeletionInfo) diff --git a/internal/manager/api_impl/worker_mgt.go b/internal/manager/api_impl/worker_mgt.go index 3a525c60..8882bbd9 100644 --- a/internal/manager/api_impl/worker_mgt.go +++ b/internal/manager/api_impl/worker_mgt.go @@ -66,7 +66,7 @@ func (f *Flamenco) FetchWorker(e echo.Context, workerUUID string) error { return sendAPIError(e, http.StatusInternalServerError, "error fetching worker tags: %v", err) } - dbTask, err := f.persist.FetchWorkerTask(ctx, dbWorker) + taskJob, err := f.persist.FetchWorkerTask(ctx, dbWorker) switch { case errors.Is(err, context.Canceled): return handleConnectionClosed(e, logger, "fetching task assigned to worker") @@ -78,10 +78,10 @@ func (f *Flamenco) FetchWorker(e echo.Context, workerUUID string) error { logger.Debug().Msg("fetched worker") apiWorker := workerDBtoAPI(*dbWorker) - if dbTask != nil { + if taskJob != nil { apiWorkerTask := api.WorkerTask{ - TaskSummary: taskDBtoSummary(dbTask), - JobId: dbTask.Job.UUID, + TaskSummary: taskDBtoSummaryAPI(taskJob.Task), + JobId: taskJob.JobUUID, } apiWorker.Task = &apiWorkerTask } diff --git a/internal/manager/api_impl/worker_mgt_test.go b/internal/manager/api_impl/worker_mgt_test.go index 1384255d..bbb0d12e 100644 --- a/internal/manager/api_impl/worker_mgt_test.go +++ b/internal/manager/api_impl/worker_mgt_test.go @@ -96,17 +96,23 @@ func TestFetchWorker(t *testing.T) { // Test with worker that does NOT have a status change requested, and DOES have an assigned task. mf.persistence.EXPECT().FetchWorker(gomock.Any(), workerUUID).Return(&worker, nil) + + assignedJob := persistence.Job{UUID: "f0e25ee4-0d13-4291-afc3-e9446b555aaf"} assignedTask := persistence.Task{ UUID: "806057d5-759a-4e75-86a4-356d43f28cff", Name: "test task", - Job: &persistence.Job{UUID: "f0e25ee4-0d13-4291-afc3-e9446b555aaf"}, Status: api.TaskStatusActive, } mf.persistence.EXPECT().FetchTagsOfWorker(gomock.Any(), workerUUID).Return([]persistence.WorkerTag{ {UUID: "0e701402-c4cc-49b0-8b8c-3eb8718d463a", Name: "EEVEE"}, {UUID: "59211f0a-81cc-4148-b0b7-32b3e2dcdb8f", Name: "Cycles"}, }, nil) - mf.persistence.EXPECT().FetchWorkerTask(gomock.Any(), &worker).Return(&assignedTask, nil) + assignedTaskJob := persistence.TaskJob{ + Task: assignedTask, + JobUUID: assignedJob.UUID, + IsActive: true, + } + mf.persistence.EXPECT().FetchWorkerTask(gomock.Any(), &worker).Return(&assignedTaskJob, nil) echo = mf.prepareMockedRequest(nil) err = mf.flamenco.FetchWorker(echo, workerUUID) @@ -127,7 +133,7 @@ func TestFetchWorker(t *testing.T) { Name: assignedTask.Name, Status: assignedTask.Status, }, - JobId: assignedTask.Job.UUID, + JobId: assignedJob.UUID, }, Tags: &[]api.WorkerTag{ {Id: ptr("0e701402-c4cc-49b0-8b8c-3eb8718d463a"), Name: "EEVEE"}, diff --git a/internal/manager/api_impl/worker_task_updates.go b/internal/manager/api_impl/worker_task_updates.go index 7d132386..eac08378 100644 --- a/internal/manager/api_impl/worker_task_updates.go +++ b/internal/manager/api_impl/worker_task_updates.go @@ -29,7 +29,7 @@ func (f *Flamenco) TaskUpdate(e echo.Context, taskID string) error { // Fetch the task, to see if this worker is even allowed to send us updates. 
ctx := e.Request().Context() - dbTask, err := f.persist.FetchTask(ctx, taskID) + taskJobWorker, err := f.persist.FetchTask(ctx, taskID) if err != nil { logger.Warn().Err(err).Msg("cannot fetch task") if errors.Is(err, persistence.ErrTaskNotFound) { @@ -37,9 +37,6 @@ func (f *Flamenco) TaskUpdate(e echo.Context, taskID string) error { } return sendAPIError(e, http.StatusInternalServerError, "error fetching task") } - if dbTask == nil { - panic("task could not be fetched, but database gave no error either") - } // Decode the request body. var taskUpdate api.TaskUpdate @@ -47,13 +44,15 @@ func (f *Flamenco) TaskUpdate(e echo.Context, taskID string) error { logger.Warn().Err(err).Msg("bad request received") return sendAPIError(e, http.StatusBadRequest, "invalid format") } - if dbTask.WorkerID == nil { + if !taskJobWorker.Task.WorkerID.Valid { logger.Warn(). Msg("worker trying to update task that's not assigned to any worker") return sendAPIError(e, http.StatusConflict, "task %+v is not assigned to any worker, so also not to you", taskID) } - if *dbTask.WorkerID != uint(worker.ID) { - logger.Warn().Msg("worker trying to update task that's assigned to another worker") + if taskJobWorker.Task.WorkerID.Int64 != worker.ID { + logger.Warn(). + Str("assignedToWorker", taskJobWorker.WorkerUUID). + Msg("worker trying to update task that's assigned to another worker") return sendAPIError(e, http.StatusConflict, "task %+v is not assigned to you", taskID) } @@ -70,8 +69,8 @@ func (f *Flamenco) TaskUpdate(e echo.Context, taskID string) error { bgCtx, bgCtxCancel := bgContext() defer bgCtxCancel() - taskUpdateErr := f.doTaskUpdate(bgCtx, logger, worker, dbTask, taskUpdate) - workerUpdateErr := f.workerPingedTask(logger, dbTask) + taskUpdateErr := f.doTaskUpdate(bgCtx, logger, taskJobWorker.JobUUID, worker, &taskJobWorker.Task, taskUpdate) + workerUpdateErr := f.workerPingedTask(logger, taskJobWorker.Task.UUID) workerSeenErr := f.workerSeen(logger, worker) if taskUpdateErr != nil { @@ -91,14 +90,11 @@ func (f *Flamenco) TaskUpdate(e echo.Context, taskID string) error { func (f *Flamenco) doTaskUpdate( ctx context.Context, logger zerolog.Logger, + jobUUID string, w *persistence.Worker, dbTask *persistence.Task, update api.TaskUpdate, ) error { - if dbTask.Job == nil { - logger.Panic().Msg("dbTask.Job is nil, unable to continue") - } - var dbErrActivity error if update.Activity != nil { @@ -113,7 +109,7 @@ func (f *Flamenco) doTaskUpdate( // Manager in response to a status change, should be logged after that. if update.Log != nil { // Errors writing the log to disk are already logged by logStorage, and can be safely ignored here. - _ = f.logStorage.Write(logger, dbTask.Job.UUID, dbTask.UUID, *update.Log) + _ = f.logStorage.Write(logger, jobUUID, dbTask.UUID, *update.Log) } if update.TaskStatus == nil { @@ -124,7 +120,7 @@ func (f *Flamenco) doTaskUpdate( var err error if *update.TaskStatus == api.TaskStatusFailed { // Failure is more complex than just going to the failed state. - err = f.onTaskFailed(ctx, logger, w, dbTask, update) + err = f.onTaskFailed(ctx, logger, w, dbTask, jobUUID, update) } else { // Just go to the given state. err = f.stateMachine.TaskStatusChange(ctx, dbTask, *update.TaskStatus) @@ -150,6 +146,7 @@ func (f *Flamenco) onTaskFailed( logger zerolog.Logger, worker *persistence.Worker, task *persistence.Task, + jobUUID string, update api.TaskUpdate, ) error { // Sanity check. 
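FetchTask now returns a persistence.TaskJobWorker value instead of a *persistence.Task with preloaded Job and Worker pointers. The struct definition is not part of this diff; judging only from how it is used here (taskJobWorker.Task, .JobUUID, .WorkerUUID), its shape is roughly the following sketch, with inferred, non-authoritative comments:

// Inferred from usage in this diff, not copied from the persistence package.
type TaskJobWorker struct {
	Task       Task   // the task row itself
	JobUUID    string // UUID of the job the task belongs to
	WorkerUUID string // UUID of the assigned worker; empty when unassigned
}

Handlers such as TaskUpdate then check taskJobWorker.Task.WorkerID.Valid and compare WorkerID.Int64 with the requesting worker's ID, rather than dereferencing a preloaded Worker pointer.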
@@ -164,18 +161,18 @@ func (f *Flamenco) onTaskFailed( } logger = logger.With().Str("taskType", task.Type).Logger() - wasBlacklisted, shoudlFailJob, err := f.maybeBlocklistWorker(ctx, logger, worker, task) + wasBlacklisted, shoudlFailJob, err := f.maybeBlocklistWorker(ctx, logger, worker, jobUUID, task) if err != nil { return fmt.Errorf("block-listing worker: %w", err) } if shoudlFailJob { // There are no more workers left to finish the job. - return f.failJobAfterCatastroficTaskFailure(ctx, logger, worker, task) + return f.failJobAfterCatastroficTaskFailure(ctx, logger, worker, jobUUID, task) } if wasBlacklisted { // Requeue all tasks of this job & task type that were hard-failed before by this worker. reason := fmt.Sprintf("worker %s was blocked from tasks of type %q", worker.Name, task.Type) - err := f.stateMachine.RequeueFailedTasksOfWorkerOfJob(ctx, worker, task.Job, reason) + err := f.stateMachine.RequeueFailedTasksOfWorkerOfJob(ctx, worker, jobUUID, reason) if err != nil { return err } @@ -189,7 +186,7 @@ func (f *Flamenco) onTaskFailed( Logger() if numFailed >= threshold { - return f.hardFailTask(ctx, logger, worker, task, numFailed) + return f.hardFailTask(ctx, logger, worker, jobUUID, task, numFailed) } numWorkers, err := f.numWorkersCapableOfRunningTask(ctx, task) @@ -203,9 +200,9 @@ func (f *Flamenco) onTaskFailed( // and thus it is still counted. // In such condition we should just fail the job itself. if numWorkers <= 1 { - return f.failJobAfterCatastroficTaskFailure(ctx, logger, worker, task) + return f.failJobAfterCatastroficTaskFailure(ctx, logger, worker, jobUUID, task) } - return f.softFailTask(ctx, logger, worker, task, numFailed) + return f.softFailTask(ctx, logger, worker, jobUUID, task, numFailed) } // maybeBlocklistWorker potentially block-lists the Worker, and checks whether @@ -218,11 +215,12 @@ func (f *Flamenco) maybeBlocklistWorker( ctx context.Context, logger zerolog.Logger, worker *persistence.Worker, + jobUUID string, task *persistence.Task, ) (wasBlacklisted, shouldFailJob bool, err error) { - numFailures, err := f.persist.CountTaskFailuresOfWorker(ctx, task.Job, worker, task.Type) + numFailures, err := f.persist.CountTaskFailuresOfWorker(ctx, jobUUID, worker.ID, task.Type) if err != nil { - return false, false, fmt.Errorf("counting failures of worker on job %q, task type %q: %w", task.Job.UUID, task.Type, err) + return false, false, fmt.Errorf("counting failures of worker on job %q, task type %q: %w", jobUUID, task.Type, err) } // The received task update hasn't been persisted in the database yet, // so we should count that too. @@ -238,7 +236,7 @@ func (f *Flamenco) maybeBlocklistWorker( } // Blocklist the Worker. - if err := f.blocklistWorker(ctx, logger, worker, task); err != nil { + if err := f.blocklistWorker(ctx, logger, worker, jobUUID, task); err != nil { return true, false, err } @@ -251,12 +249,13 @@ func (f *Flamenco) blocklistWorker( ctx context.Context, logger zerolog.Logger, worker *persistence.Worker, + jobUUID string, task *persistence.Task, ) error { logger.Warn(). - Str("job", task.Job.UUID). + Str("job", jobUUID). 
Msg("block-listing worker") - err := f.persist.AddWorkerToJobBlocklist(ctx, task.Job, worker, task.Type) + err := f.persist.AddWorkerToJobBlocklist(ctx, task.JobID, worker.ID, task.Type) if err != nil { return fmt.Errorf("adding worker to block list: %w", err) } @@ -264,11 +263,16 @@ func (f *Flamenco) blocklistWorker( } func (f *Flamenco) numWorkersCapableOfRunningTask(ctx context.Context, task *persistence.Task) (int, error) { + job, err := f.persist.FetchJobByID(ctx, task.JobID) + if err != nil { + return 0, fmt.Errorf("fetching job of task %s: %w", task.UUID, err) + } + // See which workers are left to run tasks of this type, on this job, - workersLeft, err := f.persist.WorkersLeftToRun(ctx, task.Job, task.Type) + workersLeft, err := f.persist.WorkersLeftToRun(ctx, job, task.Type) if err != nil { return 0, fmt.Errorf("fetching workers available to run tasks of type %q on job %q: %w", - task.Job.UUID, task.Type, err) + job.UUID, task.Type, err) } // Remove (from the list of available workers) those who failed this task before. @@ -290,13 +294,14 @@ func (f *Flamenco) failJobAfterCatastroficTaskFailure( ctx context.Context, logger zerolog.Logger, worker *persistence.Worker, + jobUUID string, task *persistence.Task, ) error { taskLog := fmt.Sprintf( "Task failed by worker %s, Manager will fail the entire job as there are no more workers left for tasks of type %q.", worker.Identifier(), task.Type, ) - if err := f.logStorage.WriteTimestamped(logger, task.Job.UUID, task.UUID, taskLog); err != nil { + if err := f.logStorage.WriteTimestamped(logger, jobUUID, task.UUID, taskLog); err != nil { logger.Error().Err(err).Msg("error writing failure notice to task log") } @@ -309,17 +314,18 @@ func (f *Flamenco) failJobAfterCatastroficTaskFailure( newJobStatus := api.JobStatusFailed logger.Info(). - Str("job", task.Job.UUID). + Str("job", jobUUID). Str("newJobStatus", string(newJobStatus)). 
Msg("no more workers left to run tasks of this type, failing the entire job") reason := fmt.Sprintf("no more workers left to run tasks of type %q", task.Type) - return f.stateMachine.JobStatusChange(ctx, task.Job, newJobStatus, reason) + return f.stateMachine.JobStatusChange(ctx, jobUUID, newJobStatus, reason) } func (f *Flamenco) hardFailTask( ctx context.Context, logger zerolog.Logger, worker *persistence.Worker, + jobUUID string, task *persistence.Task, numFailed int, ) error { @@ -329,7 +335,7 @@ func (f *Flamenco) hardFailTask( "Task failed by %s, Manager will mark it as hard failure", pluralizer.Pluralize("worker", numFailed, true), ) - if err := f.logStorage.WriteTimestamped(logger, task.Job.UUID, task.UUID, taskLog); err != nil { + if err := f.logStorage.WriteTimestamped(logger, jobUUID, task.UUID, taskLog); err != nil { logger.Error().Err(err).Msg("error writing failure notice to task log") } @@ -343,6 +349,7 @@ func (f *Flamenco) softFailTask( ctx context.Context, logger zerolog.Logger, worker *persistence.Worker, + jobUUID string, task *persistence.Task, numFailed int, ) error { @@ -357,7 +364,7 @@ func (f *Flamenco) softFailTask( failsToThreshold, pluralizer.Pluralize("failure", failsToThreshold, false), ) - if err := f.logStorage.WriteTimestamped(logger, task.Job.UUID, task.UUID, taskLog); err != nil { + if err := f.logStorage.WriteTimestamped(logger, jobUUID, task.UUID, taskLog); err != nil { logger.Error().Err(err).Msg("error writing failure notice to task log") } diff --git a/internal/manager/api_impl/worker_task_updates_test.go b/internal/manager/api_impl/worker_task_updates_test.go index 0cc6fd2d..1d142895 100644 --- a/internal/manager/api_impl/worker_task_updates_test.go +++ b/internal/manager/api_impl/worker_task_updates_test.go @@ -4,6 +4,7 @@ package api_impl import ( "context" + "database/sql" "testing" "github.com/golang/mock/gomock" @@ -32,17 +33,21 @@ func TestTaskUpdate(t *testing.T) { // Construct the task that's supposed to be updated. taskID := "181eab68-1123-4790-93b1-94309a899411" jobID := "e4719398-7cfa-4877-9bab-97c2d6c158b5" - mockJob := persistence.Job{UUID: jobID} + mockJob := persistence.Job{ID: 1234, UUID: jobID} mockTask := persistence.Task{ UUID: taskID, - Worker: &worker, - WorkerID: ptr(uint(worker.ID)), - Job: &mockJob, + WorkerID: sql.NullInt64{Int64: worker.ID, Valid: true}, + JobID: mockJob.ID, Activity: "pre-update activity", } // Expect the task to be fetched. - mf.persistence.EXPECT().FetchTask(gomock.Any(), taskID).Return(&mockTask, nil) + taskJobWorker := persistence.TaskJobWorker{ + Task: mockTask, + JobUUID: jobID, + WorkerUUID: worker.UUID, + } + mf.persistence.EXPECT().FetchTask(gomock.Any(), taskID).Return(taskJobWorker, nil) // Expect the task status change to be handed to the state machine. var statusChangedtask persistence.Task @@ -64,10 +69,10 @@ func TestTaskUpdate(t *testing.T) { mf.logStorage.EXPECT().Write(gomock.Any(), jobID, taskID, "line1\nline2\n") // Expect a 'touch' of the task. 
- var touchedTask persistence.Task + var touchedTaskUUID string mf.persistence.EXPECT().TaskTouchedByWorker(gomock.Any(), gomock.Any()).DoAndReturn( - func(ctx context.Context, task *persistence.Task) error { - touchedTask = *task + func(ctx context.Context, taskUUID string) error { + touchedTaskUUID = taskUUID return nil }) mf.persistence.EXPECT().WorkerSeen(gomock.Any(), &worker) @@ -81,7 +86,7 @@ func TestTaskUpdate(t *testing.T) { require.NoError(t, err) assert.Equal(t, mockTask.UUID, statusChangedtask.UUID) assert.Equal(t, mockTask.UUID, actUpdatedTask.UUID) - assert.Equal(t, mockTask.UUID, touchedTask.UUID) + assert.Equal(t, mockTask.UUID, touchedTaskUUID) assert.Equal(t, "testing", statusChangedtask.Activity) assert.Equal(t, "testing", actUpdatedTask.Activity) } @@ -101,12 +106,11 @@ func TestTaskUpdateFailed(t *testing.T) { // Construct the task that's supposed to be updated. taskID := "181eab68-1123-4790-93b1-94309a899411" jobID := "e4719398-7cfa-4877-9bab-97c2d6c158b5" - mockJob := persistence.Job{UUID: jobID} + mockJob := persistence.Job{ID: 1234, UUID: jobID} mockTask := persistence.Task{ UUID: taskID, - Worker: &worker, - WorkerID: ptr(uint(worker.ID)), - Job: &mockJob, + WorkerID: sql.NullInt64{Int64: worker.ID, Valid: true}, + JobID: mockJob.ID, Activity: "pre-update activity", Type: "misc", } @@ -121,20 +125,26 @@ func TestTaskUpdateFailed(t *testing.T) { const numSubTests = 2 // Expect the task to be fetched for each sub-test: - mf.persistence.EXPECT().FetchTask(gomock.Any(), taskID).Return(&mockTask, nil).Times(numSubTests) + taskJobWorker := persistence.TaskJobWorker{ + Task: mockTask, + JobUUID: jobID, + WorkerUUID: worker.UUID, + } + mf.persistence.EXPECT().FetchTask(gomock.Any(), taskID).Return(taskJobWorker, nil).Times(numSubTests) // Expect a 'touch' of the task for each sub-test: - mf.persistence.EXPECT().TaskTouchedByWorker(gomock.Any(), &mockTask).Times(numSubTests) + mf.persistence.EXPECT().TaskTouchedByWorker(gomock.Any(), taskID).Times(numSubTests) mf.persistence.EXPECT().WorkerSeen(gomock.Any(), &worker).Times(numSubTests) // Mimick that this is always first failure of this worker/job/tasktype combo: - mf.persistence.EXPECT().CountTaskFailuresOfWorker(gomock.Any(), &mockJob, &worker, "misc").Return(0, nil).Times(numSubTests) + mf.persistence.EXPECT().CountTaskFailuresOfWorker(gomock.Any(), jobID, worker.ID, "misc").Return(0, nil).Times(numSubTests) { // Expect the Worker to be added to the list of failed workers. // This returns 1, which is less than the failure threshold -> soft failure expected. mf.persistence.EXPECT().AddWorkerToTaskFailedList(gomock.Any(), &mockTask, &worker).Return(1, nil) + mf.persistence.EXPECT().FetchJobByID(gomock.Any(), mockTask.JobID).Return(&mockJob, nil) mf.persistence.EXPECT().WorkersLeftToRun(gomock.Any(), &mockJob, "misc"). Return(map[string]bool{"60453eec-5a26-43e9-9da2-d00506d492cc": true, "ce312357-29cd-4389-81ab-4d43e30945f8": true}, nil) mf.persistence.EXPECT().FetchTaskFailureList(gomock.Any(), &mockTask). @@ -185,12 +195,11 @@ func TestBlockingAfterFailure(t *testing.T) { // Construct the task that's supposed to be updated. 
taskID := "181eab68-1123-4790-93b1-94309a899411" jobID := "e4719398-7cfa-4877-9bab-97c2d6c158b5" - mockJob := persistence.Job{UUID: jobID} + mockJob := persistence.Job{ID: 1234, UUID: jobID} mockTask := persistence.Task{ UUID: taskID, - Worker: &worker, - WorkerID: ptr(uint(worker.ID)), - Job: &mockJob, + WorkerID: sql.NullInt64{Int64: worker.ID, Valid: true}, + JobID: mockJob.ID, Activity: "pre-update activity", Type: "misc", } @@ -205,26 +214,32 @@ func TestBlockingAfterFailure(t *testing.T) { const numSubTests = 3 // Expect the task to be fetched for each sub-test: - mf.persistence.EXPECT().FetchTask(gomock.Any(), taskID).Return(&mockTask, nil).Times(numSubTests) + taskJobWorker := persistence.TaskJobWorker{ + Task: mockTask, + JobUUID: jobID, + WorkerUUID: worker.UUID, + } + mf.persistence.EXPECT().FetchTask(gomock.Any(), taskID).Return(taskJobWorker, nil).Times(numSubTests) // Expect a 'touch' of the task for each sub-test: - mf.persistence.EXPECT().TaskTouchedByWorker(gomock.Any(), &mockTask).Times(numSubTests) + mf.persistence.EXPECT().TaskTouchedByWorker(gomock.Any(), taskID).Times(numSubTests) mf.persistence.EXPECT().WorkerSeen(gomock.Any(), &worker).Times(numSubTests) // Mimick that this is the 3rd of this worker/job/tasktype combo, and thus should trigger a block. // Returns 2 because there have been 2 previous failures. mf.persistence.EXPECT(). - CountTaskFailuresOfWorker(gomock.Any(), &mockJob, &worker, "misc"). + CountTaskFailuresOfWorker(gomock.Any(), jobID, worker.ID, "misc"). Return(2, nil). Times(numSubTests) // Expect the worker to be blocked. mf.persistence.EXPECT(). - AddWorkerToJobBlocklist(gomock.Any(), &mockJob, &worker, "misc"). + AddWorkerToJobBlocklist(gomock.Any(), mockJob.ID, worker.ID, "misc"). Times(numSubTests) { // Mimick that there is another worker to work on this task, so the job should continue happily. + mf.persistence.EXPECT().FetchJobByID(gomock.Any(), mockTask.JobID).Return(&mockJob, nil).Times(2) mf.persistence.EXPECT().WorkersLeftToRun(gomock.Any(), &mockJob, "misc"). Return(map[string]bool{"60453eec-5a26-43e9-9da2-d00506d492cc": true, "ce312357-29cd-4389-81ab-4d43e30945f8": true}, nil).Times(2) mf.persistence.EXPECT().FetchTaskFailureList(gomock.Any(), &mockTask). @@ -242,7 +257,7 @@ func TestBlockingAfterFailure(t *testing.T) { // Because the job didn't fail in its entirety, the tasks previously failed // by the Worker should be requeued so they can be picked up by another. mf.stateMachine.EXPECT().RequeueFailedTasksOfWorkerOfJob( - gomock.Any(), &worker, &mockJob, + gomock.Any(), &worker, jobID, "worker дрон was blocked from tasks of type \"misc\"") // Do the call. @@ -255,6 +270,7 @@ func TestBlockingAfterFailure(t *testing.T) { { // Test without any workers left to run these tasks on this job due to blocklisting. This should fail the entire job. + mf.persistence.EXPECT().FetchJobByID(gomock.Any(), mockTask.JobID).Return(&mockJob, nil) mf.persistence.EXPECT().WorkersLeftToRun(gomock.Any(), &mockJob, "misc"). Return(map[string]bool{}, nil) mf.persistence.EXPECT().FetchTaskFailureList(gomock.Any(), &mockTask). @@ -272,7 +288,7 @@ func TestBlockingAfterFailure(t *testing.T) { // Expect failure of the job. mf.stateMachine.EXPECT(). 
- JobStatusChange(gomock.Any(), &mockJob, api.JobStatusFailed, "no more workers left to run tasks of type \"misc\"") + JobStatusChange(gomock.Any(), jobID, api.JobStatusFailed, "no more workers left to run tasks of type \"misc\"") // Because the job failed, there is no need to re-queue any tasks previously failed by this worker. @@ -290,6 +306,7 @@ func TestBlockingAfterFailure(t *testing.T) { theOtherFailingWorker := persistence.Worker{ UUID: "ce312357-29cd-4389-81ab-4d43e30945f8", } + mf.persistence.EXPECT().FetchJobByID(gomock.Any(), mockTask.JobID).Return(&mockJob, nil) mf.persistence.EXPECT().WorkersLeftToRun(gomock.Any(), &mockJob, "misc"). Return(map[string]bool{theOtherFailingWorker.UUID: true}, nil) mf.persistence.EXPECT().FetchTaskFailureList(gomock.Any(), &mockTask). @@ -307,7 +324,7 @@ func TestBlockingAfterFailure(t *testing.T) { // Expect failure of the job. mf.stateMachine.EXPECT(). - JobStatusChange(gomock.Any(), &mockJob, api.JobStatusFailed, "no more workers left to run tasks of type \"misc\"") + JobStatusChange(gomock.Any(), jobID, api.JobStatusFailed, "no more workers left to run tasks of type \"misc\"") // Because the job failed, there is no need to re-queue any tasks previously failed by this worker. @@ -335,12 +352,11 @@ func TestJobFailureAfterWorkerTaskFailure(t *testing.T) { // Construct the task that's supposed to be updated. taskID := "181eab68-1123-4790-93b1-94309a899411" jobID := "e4719398-7cfa-4877-9bab-97c2d6c158b5" - mockJob := persistence.Job{UUID: jobID} + mockJob := persistence.Job{ID: 1234, UUID: jobID} mockTask := persistence.Task{ UUID: taskID, - Worker: &worker, - WorkerID: ptr(uint(worker.ID)), - Job: &mockJob, + WorkerID: sql.NullInt64{Int64: worker.ID, Valid: true}, + JobID: mockJob.ID, Activity: "pre-update activity", Type: "misc", } @@ -354,15 +370,21 @@ func TestJobFailureAfterWorkerTaskFailure(t *testing.T) { mf.config.EXPECT().Get().Return(&conf).Times(2) - mf.persistence.EXPECT().FetchTask(gomock.Any(), taskID).Return(&mockTask, nil) + taskJobWorker := persistence.TaskJobWorker{ + Task: mockTask, + JobUUID: jobID, + WorkerUUID: worker.UUID, + } + mf.persistence.EXPECT().FetchTask(gomock.Any(), taskID).Return(taskJobWorker, nil) - mf.persistence.EXPECT().TaskTouchedByWorker(gomock.Any(), &mockTask) + mf.persistence.EXPECT().TaskTouchedByWorker(gomock.Any(), taskID) mf.persistence.EXPECT().WorkerSeen(gomock.Any(), &worker) - mf.persistence.EXPECT().CountTaskFailuresOfWorker(gomock.Any(), &mockJob, &worker, "misc").Return(0, nil) + mf.persistence.EXPECT().CountTaskFailuresOfWorker(gomock.Any(), jobID, worker.ID, "misc").Return(0, nil) mf.persistence.EXPECT().AddWorkerToTaskFailedList(gomock.Any(), &mockTask, &worker).Return(1, nil) + mf.persistence.EXPECT().FetchJobByID(gomock.Any(), mockTask.JobID).Return(&mockJob, nil) mf.persistence.EXPECT().WorkersLeftToRun(gomock.Any(), &mockJob, "misc"). Return(map[string]bool{"e7632d62-c3b8-4af0-9e78-01752928952c": true}, nil) mf.persistence.EXPECT().FetchTaskFailureList(gomock.Any(), &mockTask). @@ -376,7 +398,7 @@ func TestJobFailureAfterWorkerTaskFailure(t *testing.T) { // Expect failure of the job. mf.stateMachine.EXPECT(). 
- JobStatusChange(gomock.Any(), &mockJob, api.JobStatusFailed, "no more workers left to run tasks of type \"misc\"") + JobStatusChange(gomock.Any(), jobID, api.JobStatusFailed, "no more workers left to run tasks of type \"misc\"") // Do the call echoCtx := mf.prepareMockedJSONRequest(taskUpdate) diff --git a/internal/manager/api_impl/workers.go b/internal/manager/api_impl/workers.go index 1feaae1e..d5f37284 100644 --- a/internal/manager/api_impl/workers.go +++ b/internal/manager/api_impl/workers.go @@ -4,6 +4,7 @@ package api_impl import ( "context" + "encoding/json" "errors" "fmt" "io" @@ -12,6 +13,7 @@ import ( "github.com/labstack/echo/v4" "github.com/rs/zerolog" + "github.com/rs/zerolog/log" "projects.blender.org/studio/flamenco/internal/manager/eventbus" "projects.blender.org/studio/flamenco/internal/manager/last_rendered" @@ -331,7 +333,7 @@ func (f *Flamenco) ScheduleTask(e echo.Context) error { } // Get a task to execute: - dbTask, err := f.persist.ScheduleTask(reqCtx, worker) + scheduledTask, err := f.persist.ScheduleTask(reqCtx, worker) if err != nil { if persistence.ErrIsDBBusy(err) { logger.Warn().Msg("database busy scheduling task for worker") @@ -340,7 +342,7 @@ func (f *Flamenco) ScheduleTask(e echo.Context) error { logger.Warn().Err(err).Msg("error scheduling task for worker") return sendAPIError(e, http.StatusInternalServerError, "internal error finding a task for you: %v", err) } - if dbTask == nil { + if scheduledTask == nil { return e.NoContent(http.StatusNoContent) } @@ -351,18 +353,18 @@ func (f *Flamenco) ScheduleTask(e echo.Context) error { // Add a note to the task log about the worker assignment. msg := fmt.Sprintf("Task assigned to worker %s (%s)", worker.Name, worker.UUID) - if err := f.logStorage.WriteTimestamped(logger, dbTask.Job.UUID, dbTask.UUID, msg); err != nil { + if err := f.logStorage.WriteTimestamped(logger, scheduledTask.JobUUID, scheduledTask.Task.UUID, msg); err != nil { return sendAPIError(e, http.StatusInternalServerError, "internal error appending to task log: %v", err) } // Move the task to 'active' status so that it won't be assigned to another // worker. This also enables the task timeout monitoring. - if err := f.stateMachine.TaskStatusChange(bgCtx, dbTask, api.TaskStatusActive); err != nil { + if err := f.stateMachine.TaskStatusChange(bgCtx, &scheduledTask.Task, api.TaskStatusActive); err != nil { return sendAPIError(e, http.StatusInternalServerError, "internal error marking task as active: %v", err) } // Start timeout measurement as soon as the Worker gets the task assigned. 
- if err := f.workerPingedTask(logger, dbTask); err != nil { + if err := f.workerPingedTask(logger, scheduledTask.Task.UUID); err != nil { return sendAPIError(e, http.StatusInternalServerError, "internal error updating task for timeout calculation: %v", err) } @@ -371,23 +373,23 @@ func (f *Flamenco) ScheduleTask(e echo.Context) error { f.broadcaster.BroadcastWorkerUpdate(update) // Convert database objects to API objects: - apiCommands := []api.Command{} - for _, cmd := range dbTask.Commands { - apiCommands = append(apiCommands, api.Command{ - Name: cmd.Name, - Parameters: cmd.Parameters, - }) - } apiTask := api.AssignedTask{ - Uuid: dbTask.UUID, - Commands: apiCommands, - Job: dbTask.Job.UUID, - JobPriority: dbTask.Job.Priority, - JobType: dbTask.Job.JobType, - Name: dbTask.Name, - Priority: dbTask.Priority, - Status: api.TaskStatus(dbTask.Status), - TaskType: dbTask.Type, + Uuid: scheduledTask.Task.UUID, + Job: scheduledTask.JobUUID, + JobPriority: int(scheduledTask.JobPriority), + JobType: scheduledTask.JobType, + Name: scheduledTask.Task.Name, + Priority: int(scheduledTask.Task.Priority), + Status: api.TaskStatus(scheduledTask.Task.Status), + TaskType: scheduledTask.Task.Type, + } + + if err := json.Unmarshal(scheduledTask.Task.Commands, &apiTask.Commands); err != nil { + log.Error(). + Str("task", scheduledTask.Task.UUID). + AnErr("cause", err). + Msg("could not parse task commands JSON") + return sendAPIError(e, http.StatusInternalServerError, "internal error parsing task commands JSON: %v", err) } // Perform variable replacement before sending to the Worker. @@ -423,19 +425,17 @@ func (f *Flamenco) TaskOutputProduced(e echo.Context, taskID string) error { } // Fetch the task, to find its job UUID: - dbTask, err := f.persist.FetchTask(ctx, taskID) + taskJobWorker, err := f.persist.FetchTask(ctx, taskID) switch { case errors.Is(err, persistence.ErrTaskNotFound): return e.JSON(http.StatusNotFound, "Task does not exist") case err != nil: logger.Error().Err(err).Msg("TaskOutputProduced: cannot fetch task") return sendAPIError(e, http.StatusInternalServerError, "error fetching task") - case dbTask == nil: - panic("task could not be fetched, but database gave no error either") } // Include the job UUID in the logger. - jobUUID := dbTask.Job.UUID + jobUUID := taskJobWorker.JobUUID logger = logger.With().Str("job", jobUUID).Logger() // Read the image bytes into memory. @@ -459,7 +459,7 @@ func (f *Flamenco) TaskOutputProduced(e echo.Context, taskID string) error { Callback: func(ctx context.Context) { // Store this job as the last one to get a rendered image. - err := f.persist.SetLastRendered(ctx, dbTask.Job) + err := f.persist.SetLastRendered(ctx, taskJobWorker.JobUUID) if err != nil { logger.Error().Err(err).Msg("TaskOutputProduced: error marking this job as the last one to receive render output") } @@ -497,12 +497,12 @@ func (f *Flamenco) TaskOutputProduced(e echo.Context, taskID string) error { func (f *Flamenco) workerPingedTask( logger zerolog.Logger, - task *persistence.Task, + taskUUID string, ) error { bgCtx, bgCtxCancel := bgContext() defer bgCtxCancel() - err := f.persist.TaskTouchedByWorker(bgCtx, task) + err := f.persist.TaskTouchedByWorker(bgCtx, taskUUID) if err != nil { logger.Error().Err(err).Msg("error marking task as 'touched' by worker") return err @@ -549,7 +549,7 @@ func (f *Flamenco) MayWorkerRun(e echo.Context, taskID string) error { // Fetch the task, to see if this worker is allowed to run it. 
ctx := e.Request().Context() - dbTask, err := f.persist.FetchTask(ctx, taskID) + taskJobWorker, err := f.persist.FetchTask(ctx, taskID) if err != nil { if errors.Is(err, persistence.ErrTaskNotFound) { mkr := api.MayKeepRunning{Reason: "Task not found"} @@ -558,16 +558,13 @@ func (f *Flamenco) MayWorkerRun(e echo.Context, taskID string) error { logger.Error().Err(err).Msg("MayWorkerRun: cannot fetch task") return sendAPIError(e, http.StatusInternalServerError, "error fetching task") } - if dbTask == nil { - panic("task could not be fetched, but database gave no error either") - } - mkr := mayWorkerRun(worker, dbTask) + mkr := mayWorkerRun(worker, &taskJobWorker.Task) // Errors saving the "worker pinged task" and "worker seen" fields in the // database are just logged. It's not something to bother the worker with. if mkr.MayKeepRunning { - _ = f.workerPingedTask(logger, dbTask) + _ = f.workerPingedTask(logger, taskJobWorker.Task.UUID) } _ = f.workerSeen(logger, worker) @@ -582,7 +579,7 @@ func mayWorkerRun(worker *persistence.Worker, dbTask *persistence.Task) api.MayK StatusChangeRequested: true, } } - if dbTask.WorkerID == nil || *dbTask.WorkerID != uint(worker.ID) { + if !dbTask.WorkerID.Valid || dbTask.WorkerID.Int64 != worker.ID { return api.MayKeepRunning{Reason: "task not assigned to this worker"} } if !task_state_machine.IsRunnableTaskStatus(dbTask.Status) { diff --git a/internal/manager/api_impl/workers_test.go b/internal/manager/api_impl/workers_test.go index cee31635..14dc9d79 100644 --- a/internal/manager/api_impl/workers_test.go +++ b/internal/manager/api_impl/workers_test.go @@ -5,6 +5,7 @@ package api_impl import ( "bytes" "context" + "database/sql" "io" "net/http" "testing" @@ -32,22 +33,28 @@ func TestTaskScheduleHappy(t *testing.T) { // Expect a call into the persistence layer, which should return a scheduled task. 
job := persistence.Job{ - UUID: "583a7d59-887a-4c6c-b3e4-a753018f71b0", + ID: 1234, + UUID: "583a7d59-887a-4c6c-b3e4-a753018f71b0", + Priority: 47, + JobType: "simple-blender-test", } task := persistence.Task{ - UUID: "4107c7aa-e86d-4244-858b-6c4fce2af503", - Job: &job, - Commands: []persistence.Command{ - {Name: "test", Parameters: map[string]interface{}{ - "param": "prefix-{variable}-suffix", - }}, - }, + UUID: "4107c7aa-e86d-4244-858b-6c4fce2af503", + JobID: job.ID, + Priority: 327, + Commands: []byte(`[{"name": "test", "parameters": {"param": "prefix-{variable}-suffix"}}]`), + } + scheduledTask := persistence.ScheduledTask{ + Task: task, + JobUUID: job.UUID, + JobPriority: job.Priority, + JobType: job.JobType, } ctx := echo.Request().Context() bgCtx := gomock.Not(ctx) - mf.persistence.EXPECT().ScheduleTask(ctx, &worker).Return(&task, nil) - mf.persistence.EXPECT().TaskTouchedByWorker(bgCtx, &task) + mf.persistence.EXPECT().ScheduleTask(ctx, &worker).Return(&scheduledTask, nil) + mf.persistence.EXPECT().TaskTouchedByWorker(bgCtx, task.UUID) mf.persistence.EXPECT().WorkerSeen(bgCtx, &worker) mf.expectExpandVariables(t, config.VariableAudienceWorkers, @@ -66,8 +73,11 @@ func TestTaskScheduleHappy(t *testing.T) { // Check the response assignedTask := api.AssignedTask{ - Uuid: task.UUID, - Job: job.UUID, + Uuid: task.UUID, + Job: job.UUID, + JobType: "simple-blender-test", + JobPriority: 47, + Priority: 327, Commands: []api.Command{ {Name: "test", Parameters: map[string]interface{}{ "param": "prefix-value-suffix", @@ -493,17 +503,16 @@ func TestMayWorkerRun(t *testing.T) { } job := persistence.Job{ + ID: 1234, UUID: "583a7d59-887a-4c6c-b3e4-a753018f71b0", } task := persistence.Task{ UUID: "4107c7aa-e86d-4244-858b-6c4fce2af503", - Job: &job, + JobID: job.ID, Status: api.TaskStatusActive, } - mf.persistence.EXPECT().FetchTask(gomock.Any(), task.UUID).Return(&task, nil).AnyTimes() - // Expect the worker to be marked as 'seen' regardless of whether it may run // its current task or not, so equal to the number of calls to // `MayWorkerRun()` below. @@ -511,6 +520,13 @@ func TestMayWorkerRun(t *testing.T) { // Test: unhappy, task unassigned { + taskJobWorker := persistence.TaskJobWorker{ + Task: task, + JobUUID: job.UUID, + WorkerUUID: "", + } + mf.persistence.EXPECT().FetchTask(gomock.Any(), task.UUID).Return(taskJobWorker, nil) + echo := prepareRequest() err := mf.flamenco.MayWorkerRun(echo, task.UUID) require.NoError(t, err) @@ -522,11 +538,18 @@ func TestMayWorkerRun(t *testing.T) { // Test: happy, task assigned to this worker. { + task.WorkerID = sql.NullInt64{Int64: worker.ID, Valid: true} + taskJobWorker := persistence.TaskJobWorker{ + Task: task, + JobUUID: job.UUID, + WorkerUUID: worker.UUID, + } + mf.persistence.EXPECT().FetchTask(gomock.Any(), task.UUID).Return(taskJobWorker, nil) + // Expect a 'touch' of the task. - mf.persistence.EXPECT().TaskTouchedByWorker(gomock.Any(), &task).Return(nil) + mf.persistence.EXPECT().TaskTouchedByWorker(gomock.Any(), task.UUID).Return(nil) echo := prepareRequest() - task.WorkerID = ptr(uint(worker.ID)) err := mf.flamenco.MayWorkerRun(echo, task.UUID) require.NoError(t, err) assertResponseJSON(t, echo, http.StatusOK, api.MayKeepRunning{ @@ -536,9 +559,16 @@ func TestMayWorkerRun(t *testing.T) { // Test: unhappy, assigned but cancelled. 
{ - echo := prepareRequest() - task.WorkerID = ptr(uint(worker.ID)) + task.WorkerID = sql.NullInt64{Int64: worker.ID, Valid: true} task.Status = api.TaskStatusCanceled + taskJobWorker := persistence.TaskJobWorker{ + Task: task, + JobUUID: job.UUID, + WorkerUUID: worker.UUID, + } + mf.persistence.EXPECT().FetchTask(gomock.Any(), task.UUID).Return(taskJobWorker, nil) + + echo := prepareRequest() err := mf.flamenco.MayWorkerRun(echo, task.UUID) require.NoError(t, err) assertResponseJSON(t, echo, http.StatusOK, api.MayKeepRunning{ @@ -550,9 +580,16 @@ func TestMayWorkerRun(t *testing.T) { // Test: unhappy, assigned and runnable but worker should go to bed. { worker.StatusChangeRequest(api.WorkerStatusAsleep, false) - echo := prepareRequest() - task.WorkerID = ptr(uint(worker.ID)) + task.WorkerID = sql.NullInt64{Int64: worker.ID, Valid: true} task.Status = api.TaskStatusActive + taskJobWorker := persistence.TaskJobWorker{ + Task: task, + JobUUID: job.UUID, + WorkerUUID: worker.UUID, + } + mf.persistence.EXPECT().FetchTask(gomock.Any(), task.UUID).Return(taskJobWorker, nil) + + echo := prepareRequest() err := mf.flamenco.MayWorkerRun(echo, task.UUID) require.NoError(t, err) assertResponseJSON(t, echo, http.StatusOK, api.MayKeepRunning{ @@ -564,13 +601,20 @@ func TestMayWorkerRun(t *testing.T) { // Test: happy, assigned and runnable; worker should go to bed after task is finished. { - // Expect a 'touch' of the task. - mf.persistence.EXPECT().TaskTouchedByWorker(gomock.Any(), &task).Return(nil) - worker.StatusChangeRequest(api.WorkerStatusAsleep, true) - echo := prepareRequest() - task.WorkerID = ptr(uint(worker.ID)) + task.WorkerID = sql.NullInt64{Int64: worker.ID, Valid: true} task.Status = api.TaskStatusActive + taskJobWorker := persistence.TaskJobWorker{ + Task: task, + JobUUID: job.UUID, + WorkerUUID: worker.UUID, + } + mf.persistence.EXPECT().FetchTask(gomock.Any(), task.UUID).Return(taskJobWorker, nil) + + // Expect a 'touch' of the task. + mf.persistence.EXPECT().TaskTouchedByWorker(gomock.Any(), task.UUID).Return(nil) + + echo := prepareRequest() err := mf.flamenco.MayWorkerRun(echo, task.UUID) require.NoError(t, err) assertResponseJSON(t, echo, http.StatusOK, api.MayKeepRunning{ @@ -593,13 +637,19 @@ func TestTaskOutputProduced(t *testing.T) { } job := persistence.Job{ + ID: 1234, UUID: "583a7d59-887a-4c6c-b3e4-a753018f71b0", } task := persistence.Task{ UUID: "4107c7aa-e86d-4244-858b-6c4fce2af503", - Job: &job, + JobID: job.ID, Status: api.TaskStatusActive, } + taskJobWorker := persistence.TaskJobWorker{ + Task: task, + JobUUID: job.UUID, + WorkerUUID: worker.UUID, + } // Mock body to use in the request. 
bodyBytes := []byte("JPEG file contents") @@ -640,7 +690,7 @@ func TestTaskOutputProduced(t *testing.T) { // Test: unhappy, wrong mime type { mf.persistence.EXPECT().WorkerSeen(gomock.Any(), &worker) - mf.persistence.EXPECT().FetchTask(gomock.Any(), task.UUID).Return(&task, nil) + mf.persistence.EXPECT().FetchTask(gomock.Any(), task.UUID).Return(taskJobWorker, nil) echo := prepareRequest(bytes.NewReader(bodyBytes)) echo.Request().Header.Set("Content-Type", "image/openexr") @@ -654,7 +704,7 @@ func TestTaskOutputProduced(t *testing.T) { // Test: unhappy, queue full { mf.persistence.EXPECT().WorkerSeen(gomock.Any(), &worker) - mf.persistence.EXPECT().FetchTask(gomock.Any(), task.UUID).Return(&task, nil) + mf.persistence.EXPECT().FetchTask(gomock.Any(), task.UUID).Return(taskJobWorker, nil) echo := prepareRequest(bytes.NewReader(bodyBytes)) mf.lastRender.EXPECT().QueueImage(gomock.Any()).Return(last_rendered.ErrQueueFull) @@ -667,7 +717,7 @@ func TestTaskOutputProduced(t *testing.T) { // Test: happy { mf.persistence.EXPECT().WorkerSeen(gomock.Any(), &worker) - mf.persistence.EXPECT().FetchTask(gomock.Any(), task.UUID).Return(&task, nil) + mf.persistence.EXPECT().FetchTask(gomock.Any(), task.UUID).Return(taskJobWorker, nil) // Don't expect persistence.SetLastRendered(...) quite yet. That should be // called after the image processing is done. @@ -691,7 +741,7 @@ func TestTaskOutputProduced(t *testing.T) { if assert.NotNil(t, actualPayload) { ctx := context.Background() - mf.persistence.EXPECT().SetLastRendered(ctx, &job) + mf.persistence.EXPECT().SetLastRendered(ctx, job.UUID) expectBroadcast := api.EventLastRenderedUpdate{ JobId: job.UUID, diff --git a/internal/manager/eventbus/events_jobs.go b/internal/manager/eventbus/events_jobs.go index b399b0ed..e4fdf8a6 100644 --- a/internal/manager/eventbus/events_jobs.go +++ b/internal/manager/eventbus/events_jobs.go @@ -17,10 +17,10 @@ func NewJobUpdate(job *persistence.Job) api.EventJobUpdate { jobUpdate := api.EventJobUpdate{ Id: job.UUID, Name: &job.Name, - Updated: job.UpdatedAt, + Updated: job.UpdatedAt.Time, Status: job.Status, Type: job.JobType, - Priority: job.Priority, + Priority: int(job.Priority), } if job.DeleteRequestedAt.Valid { @@ -34,14 +34,12 @@ func NewJobUpdate(job *persistence.Job) api.EventJobUpdate { // fills in the fields that represent the current state of the task. For // example, it omits `PreviousStatus`. The omitted fields can be filled in by // the caller. -// -// Assumes task.Job is not nil. 
-func NewTaskUpdate(task *persistence.Task) api.EventTaskUpdate { +func NewTaskUpdate(task persistence.Task, jobUUID string) api.EventTaskUpdate { taskUpdate := api.EventTaskUpdate{ Id: task.UUID, - JobId: task.Job.UUID, + JobId: jobUUID, Name: task.Name, - Updated: task.UpdatedAt, + Updated: task.UpdatedAt.Time, Status: task.Status, Activity: task.Activity, } diff --git a/internal/manager/job_deleter/job_deleter.go b/internal/manager/job_deleter/job_deleter.go index 42659a61..6855ded1 100644 --- a/internal/manager/job_deleter/job_deleter.go +++ b/internal/manager/job_deleter/job_deleter.go @@ -250,7 +250,7 @@ func (s *Service) canDeleteShamanCheckout(logger zerolog.Logger, job *persistenc return false } - checkoutID := job.Storage.ShamanCheckoutID + checkoutID := job.StorageShamanCheckoutID if checkoutID == "" { logger.Debug().Msg("job deleter: job was not created with Shaman (or before Flamenco v3.2), cannot delete job files") return false } diff --git a/internal/manager/job_deleter/mocks/interfaces_mock.gen.go b/internal/manager/job_deleter/mocks/interfaces_mock.gen.go index c927c5b8..e5165d60 100644 --- a/internal/manager/job_deleter/mocks/interfaces_mock.gen.go +++ b/internal/manager/job_deleter/mocks/interfaces_mock.gen.go @@ -10,7 +10,7 @@ import ( time "time" gomock "github.com/golang/mock/gomock" - persistence "projects.blender.org/studio/flamenco/internal/manager/persistence" + sqlc "projects.blender.org/studio/flamenco/internal/manager/persistence/sqlc" api "projects.blender.org/studio/flamenco/pkg/api" ) @@ -52,10 +52,10 @@ func (mr *MockPersistenceServiceMockRecorder) DeleteJob(arg0, arg1 interface{}) } // FetchJob mocks base method. -func (m *MockPersistenceService) FetchJob(arg0 context.Context, arg1 string) (*persistence.Job, error) { +func (m *MockPersistenceService) FetchJob(arg0 context.Context, arg1 string) (*sqlc.Job, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "FetchJob", arg0, arg1) - ret0, _ := ret[0].(*persistence.Job) + ret0, _ := ret[0].(*sqlc.Job) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -109,7 +109,7 @@ func (mr *MockPersistenceServiceMockRecorder) RequestIntegrityCheck() *gomock.Ca } // RequestJobDeletion mocks base method. -func (m *MockPersistenceService) RequestJobDeletion(arg0 context.Context, arg1 *persistence.Job) error { +func (m *MockPersistenceService) RequestJobDeletion(arg0 context.Context, arg1 *sqlc.Job) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "RequestJobDeletion", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } diff --git a/internal/manager/persistence/jobs.go b/internal/manager/persistence/jobs.go index d70dc515..1e19d405 100644 --- a/internal/manager/persistence/jobs.go +++ b/internal/manager/persistence/jobs.go @@ -19,69 +19,27 @@ import ( "projects.blender.org/studio/flamenco/pkg/api" ) -type Job struct { - Model - UUID string +type Job = sqlc.Job +type Task = sqlc.Task - Name string - JobType string - Priority int - Status api.JobStatus - Activity string +// TaskJobWorker represents a task, with identifiers for its job and the worker it's assigned to. +type TaskJobWorker struct { + Task Task + JobUUID string + WorkerUUID string +} - Settings StringInterfaceMap - Metadata StringStringMap - - DeleteRequestedAt sql.NullTime - - Storage JobStorageInfo - - WorkerTagID *uint - WorkerTag *WorkerTag +// TaskJob represents a task, with an identifier for its job. +type TaskJob struct { + Task Task + JobUUID string + IsActive bool // Whether the worker assigned to this task is actually working on it.
} type StringInterfaceMap map[string]interface{} type StringStringMap map[string]string -// DeleteRequested returns whether deletion of this job was requested. -func (j *Job) DeleteRequested() bool { - return j.DeleteRequestedAt.Valid -} - -// JobStorageInfo contains info about where the job files are stored. It is -// intended to be used when removing a job, which may include the removal of its -// files. -type JobStorageInfo struct { - // ShamanCheckoutID is only set when the job was actually using Shaman storage. - ShamanCheckoutID string -} - -type Task struct { - Model - UUID string - - Name string - Type string - JobID uint - Job *Job - JobUUID string // Fetched by SQLC, handled by GORM in Task.AfterFind() - IndexInJob int - Priority int - Status api.TaskStatus - - // Which worker is/was working on this. - WorkerID *uint - Worker *Worker - WorkerUUID string // Fetched by SQLC, handled by GORM in Task.AfterFind() - LastTouchedAt time.Time // Should contain UTC timestamps. - - // Dependencies are tasks that need to be completed before this one can run. - Dependencies []*Task - - Commands Commands - Activity string -} - +// Commands is the schema used for (un)marshalling sqlc.Task.Commands. type Commands []Command type Command struct { @@ -123,15 +81,7 @@ func (js *StringStringMap) Scan(value interface{}) error { } // TaskFailure keeps track of which Worker failed which Task. -type TaskFailure struct { - // Don't include the standard Gorm ID, UpdatedAt, or DeletedAt fields, as they're useless here. - // Entries will never be updated, and should never be soft-deleted but just purged from existence. - CreatedAt time.Time - TaskID uint - Task *Task - WorkerID uint - Worker *Worker -} +type TaskFailure = sqlc.TaskFailure // StoreJob stores an AuthoredJob and its tasks, and saves it to the database. // The job will be in 'under construction' status. It is up to the caller to transition it to its desired initial status. @@ -334,7 +284,7 @@ func (db *DB) storeAuthoredJobTaks( func (db *DB) FetchJob(ctx context.Context, jobUUID string) (*Job, error) { queries := db.queries() - sqlcJob, err := queries.FetchJob(ctx, jobUUID) + job, err := queries.FetchJob(ctx, jobUUID) switch { case errors.Is(err, sql.ErrNoRows): return nil, ErrJobNotFound @@ -342,23 +292,22 @@ func (db *DB) FetchJob(ctx context.Context, jobUUID string) (*Job, error) { return nil, jobError(err, "fetching job") } - gormJob, err := convertSqlcJob(sqlcJob) - if err != nil { - return nil, err + return &job, nil +} + +// FetchJobByID fetches a single job by its database ID, without fetching its tasks.
+func (db *DB) FetchJobByID(ctx context.Context, jobID int64) (*Job, error) { + queries := db.queries() + + job, err := queries.FetchJobByID(ctx, jobID) + switch { + case errors.Is(err, sql.ErrNoRows): + return nil, ErrJobNotFound + case err != nil: + return nil, jobError(err, "fetching job") } - if sqlcJob.WorkerTagID.Valid { - workerTag, err := fetchWorkerTagByID(ctx, queries, sqlcJob.WorkerTagID.Int64) - switch { - case errors.Is(err, sql.ErrNoRows): - return nil, ErrWorkerTagNotFound - case err != nil: - return nil, workerTagError(err, "fetching worker tag of job") - } - gormJob.WorkerTag = &workerTag - } - - return &gormJob, nil + return &job, nil } func (db *DB) FetchJobs(ctx context.Context) ([]*Job, error) { @@ -369,28 +318,13 @@ func (db *DB) FetchJobs(ctx context.Context) ([]*Job, error) { return nil, jobError(err, "fetching all jobs") } - gormJobs := make([]*Job, len(sqlcJobs)) - for index, sqlcJob := range sqlcJobs { - gormJob, err := convertSqlcJob(sqlcJob) - if err != nil { - return nil, err - } - - if sqlcJob.WorkerTagID.Valid { - workerTag, err := fetchWorkerTagByID(ctx, queries, sqlcJob.WorkerTagID.Int64) - switch { - case errors.Is(err, sql.ErrNoRows): - return nil, ErrWorkerTagNotFound - case err != nil: - return nil, workerTagError(err, "fetching worker tag of job") - } - gormJob.WorkerTag = &workerTag - } - - gormJobs[index] = &gormJob + // TODO: just return []Job instead of converting the array. + jobPointers := make([]*Job, len(sqlcJobs)) + for index := range sqlcJobs { + jobPointers[index] = &sqlcJobs[index] } - return gormJobs, nil + return jobPointers, nil } // FetchJobShamanCheckoutID fetches the job's Shaman Checkout ID. @@ -492,21 +426,18 @@ func (db *DB) FetchJobsDeletionRequested(ctx context.Context) ([]string, error) func (db *DB) FetchJobsInStatus(ctx context.Context, jobStatuses ...api.JobStatus) ([]*Job, error) { queries := db.queries() - sqlcJobs, err := queries.FetchJobsInStatus(ctx, jobStatuses) + jobs, err := queries.FetchJobsInStatus(ctx, jobStatuses) if err != nil { return nil, jobError(err, "fetching jobs in status %q", jobStatuses) } - var jobs []*Job - for index := range sqlcJobs { - job, err := convertSqlcJob(sqlcJobs[index]) - if err != nil { - return nil, jobError(err, "converting fetched jobs in status %q", jobStatuses) - } - jobs = append(jobs, &job) + // TODO: just return []Job instead of converting the array. + pointers := make([]*Job, len(jobs)) + for index := range jobs { + pointers[index] = &jobs[index] } - return jobs, nil + return pointers, nil } // SaveJobStatus saves the job's Status and Activity fields. 
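Not part of the patch itself: a minimal sketch of calling the new FetchJobByID from code that only holds a task row; the helper name is illustrative, and FetchTaskJobUUID remains the lighter-weight option when only the job's UUID is needed.

package example

import (
    "context"
    "fmt"

    "projects.blender.org/studio/flamenco/internal/manager/persistence"
)

// jobOfTask resolves the full job row of a task via its integer foreign key.
func jobOfTask(ctx context.Context, db *persistence.DB, task persistence.Task) (*persistence.Job, error) {
    job, err := db.FetchJobByID(ctx, task.JobID)
    if err != nil {
        return nil, fmt.Errorf("fetching job of task %s: %w", task.UUID, err)
    }
    return job, nil
}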
@@ -552,7 +483,7 @@ func (db *DB) SaveJobStorageInfo(ctx context.Context, j *Job) error { params := sqlc.SaveJobStorageInfoParams{ ID: int64(j.ID), - StorageShamanCheckoutID: j.Storage.ShamanCheckoutID, + StorageShamanCheckoutID: j.StorageShamanCheckoutID, } err := queries.SaveJobStorageInfo(ctx, params) @@ -562,68 +493,21 @@ func (db *DB) SaveJobStorageInfo(ctx context.Context, j *Job) error { return nil } -func (db *DB) FetchTask(ctx context.Context, taskUUID string) (*Task, error) { +func (db *DB) FetchTask(ctx context.Context, taskUUID string) (TaskJobWorker, error) { queries := db.queries() taskRow, err := queries.FetchTask(ctx, taskUUID) if err != nil { - return nil, taskError(err, "fetching task %s", taskUUID) + return TaskJobWorker{}, taskError(err, "fetching task %s", taskUUID) } - return convertSqlTaskWithJobAndWorker(ctx, queries, taskRow.Task) -} - -// TODO: remove this code, and let the code that calls into the persistence -// service fetch the job/worker explicitly when needed. -func convertSqlTaskWithJobAndWorker( - ctx context.Context, - queries *sqlc.Queries, - task sqlc.Task, -) (*Task, error) { - var ( - gormJob Job - worker Worker - err error - ) - - // Fetch & convert the Job. - if task.JobID > 0 { - sqlcJob, err := queries.FetchJobByID(ctx, task.JobID) - if err != nil { - return nil, jobError(err, "fetching job of task %s", task.UUID) - } - - gormJob, err = convertSqlcJob(sqlcJob) - if err != nil { - return nil, jobError(err, "converting job of task %s", task.UUID) - } + taskJobWorker := TaskJobWorker{ + Task: taskRow.Task, + JobUUID: taskRow.JobUUID.String, + WorkerUUID: taskRow.WorkerUUID.String, } - // Fetch the Worker. - if task.WorkerID.Valid && task.WorkerID.Int64 > 0 { - worker, err = queries.FetchWorkerUnconditionalByID(ctx, task.WorkerID.Int64) - if err != nil { - return nil, taskError(err, "fetching worker assigned to task %s", task.UUID) - } - } - - // Convert the Task. - gormTask, err := convertSqlcTask(task, gormJob.UUID, worker.UUID) - if err != nil { - return nil, err - } - - // Put the Job & Worker into the Task. - if gormJob.ID > 0 { - gormTask.Job = &gormJob - gormTask.JobUUID = gormJob.UUID - } - if worker.ID > 0 { - gormTask.Worker = &worker - gormTask.WorkerUUID = worker.UUID - } - - return gormTask, nil + return taskJobWorker, nil } // FetchTaskJobUUID fetches the job UUID of the given task. 
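Not part of the patch itself: a minimal sketch of consuming the reworked FetchTask, which now returns a TaskJobWorker value (with the zero value on error) instead of a *Task; the helper name and message format are illustrative only.

package example

import (
    "context"
    "fmt"

    "projects.blender.org/studio/flamenco/internal/manager/persistence"
)

// describeTask fetches a task and reports which job and worker it belongs to.
func describeTask(ctx context.Context, db *persistence.DB, taskUUID string) (string, error) {
    taskJobWorker, err := db.FetchTask(ctx, taskUUID)
    if err != nil {
        return "", fmt.Errorf("fetching task %s: %w", taskUUID, err)
    }
    return fmt.Sprintf("task %s of job %s, assigned to worker %q",
        taskJobWorker.Task.UUID, taskJobWorker.JobUUID, taskJobWorker.WorkerUUID), nil
}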
@@ -655,32 +539,16 @@ func (db *DB) SaveTask(ctx context.Context, t *Task) error { } param := sqlc.UpdateTaskParams{ - UpdatedAt: db.nowNullable(), - Name: t.Name, - Type: t.Type, - Priority: int64(t.Priority), - Status: t.Status, - Commands: commandsJSON, - Activity: t.Activity, - ID: int64(t.ID), - } - if t.WorkerID != nil { - param.WorkerID = sql.NullInt64{ - Int64: int64(*t.WorkerID), - Valid: true, - } - } else if t.Worker != nil && t.Worker.ID > 0 { - param.WorkerID = sql.NullInt64{ - Int64: int64(t.Worker.ID), - Valid: true, - } - } - - if !t.LastTouchedAt.IsZero() { - param.LastTouchedAt = sql.NullTime{ - Time: t.LastTouchedAt, - Valid: true, - } + UpdatedAt: db.nowNullable(), + Name: t.Name, + Type: t.Type, + Priority: t.Priority, + Status: t.Status, + Commands: commandsJSON, + Activity: t.Activity, + ID: t.ID, + WorkerID: t.WorkerID, + LastTouchedAt: t.LastTouchedAt, } err = queries.UpdateTask(ctx, param) @@ -724,26 +592,20 @@ func (db *DB) SaveTaskActivity(ctx context.Context, t *Task) error { func (db *DB) TaskAssignToWorker(ctx context.Context, t *Task, w *Worker) error { queries := db.queries() + t.WorkerID = sql.NullInt64{Int64: w.ID, Valid: true} + err := queries.TaskAssignToWorker(ctx, sqlc.TaskAssignToWorkerParams{ UpdatedAt: db.nowNullable(), - WorkerID: sql.NullInt64{ - Int64: int64(w.ID), - Valid: true, - }, - ID: int64(t.ID), + WorkerID: t.WorkerID, + ID: t.ID, }) if err != nil { return taskError(err, "assigning task %s to worker %s", t.UUID, w.UUID) } - - // Update the task itself. - t.Worker = w - t.WorkerID = ptr(uint(w.ID)) - return nil } -func (db *DB) FetchTasksOfWorkerInStatus(ctx context.Context, worker *Worker, taskStatus api.TaskStatus) ([]*Task, error) { +func (db *DB) FetchTasksOfWorkerInStatus(ctx context.Context, worker *Worker, taskStatus api.TaskStatus) ([]TaskJob, error) { queries := db.queries() rows, err := queries.FetchTasksOfWorkerInStatus(ctx, sqlc.FetchTasksOfWorkerInStatusParams{ @@ -757,38 +619,15 @@ func (db *DB) FetchTasksOfWorkerInStatus(ctx context.Context, worker *Worker, ta return nil, taskError(err, "finding tasks of worker %s in status %q", worker.UUID, taskStatus) } - jobCache := make(map[uint]*Job) - - result := make([]*Task, len(rows)) + result := make([]TaskJob, len(rows)) for i := range rows { - jobUUID := rows[i].JobUUID.String - gormTask, err := convertSqlcTask(rows[i].Task, jobUUID, worker.UUID) - if err != nil { - return nil, err - } - gormTask.Worker = worker - gormTask.WorkerID = ptr(uint(worker.ID)) - - // Fetch the job, either from the cache or from the database. This is done - // here because the task_state_machine functionality expects that task.Job - // is set. - // TODO: make that code fetch the job details it needs, rather than fetching - // the entire job here. 
- job := jobCache[gormTask.JobID] - if job == nil { - job, err = db.FetchJob(ctx, jobUUID) - if err != nil { - return nil, jobError(err, "finding job %s of task %s", jobUUID, gormTask.UUID) - } - } - gormTask.Job = job - - result[i] = gormTask + result[i].Task = rows[i].Task + result[i].JobUUID = rows[i].JobUUID } return result, nil } -func (db *DB) FetchTasksOfWorkerInStatusOfJob(ctx context.Context, worker *Worker, taskStatus api.TaskStatus, job *Job) ([]*Task, error) { +func (db *DB) FetchTasksOfWorkerInStatusOfJob(ctx context.Context, worker *Worker, taskStatus api.TaskStatus, jobUUID string) ([]*Task, error) { queries := db.queries() rows, err := queries.FetchTasksOfWorkerInStatusOfJob(ctx, sqlc.FetchTasksOfWorkerInStatusOfJobParams{ @@ -796,24 +635,17 @@ func (db *DB) FetchTasksOfWorkerInStatusOfJob(ctx context.Context, worker *Worke Int64: int64(worker.ID), Valid: true, }, - JobID: int64(job.ID), + JobUUID: jobUUID, TaskStatus: taskStatus, }) if err != nil { - return nil, taskError(err, "finding tasks of worker %s in status %q and job %s", worker.UUID, taskStatus, job.UUID) + return nil, taskError(err, "finding tasks of worker %s in status %q and job %s", worker.UUID, taskStatus, jobUUID) } + // TODO: just return []Task instead of creating an array of pointers. result := make([]*Task, len(rows)) for i := range rows { - gormTask, err := convertSqlcTask(rows[i].Task, job.UUID, worker.UUID) - if err != nil { - return nil, err - } - gormTask.Job = job - gormTask.JobID = job.ID - gormTask.Worker = worker - gormTask.WorkerID = ptr(uint(worker.ID)) - result[i] = gormTask + result[i] = &rows[i].Task } return result, nil } @@ -865,7 +697,7 @@ func (db *DB) CountTasksOfJobInStatus( } // FetchTaskIDsOfJob returns all tasks of the given job. -func (db *DB) FetchTasksOfJob(ctx context.Context, job *Job) ([]*Task, error) { +func (db *DB) FetchTasksOfJob(ctx context.Context, job *Job) ([]TaskJobWorker, error) { queries := db.queries() rows, err := queries.FetchTasksOfJob(ctx, int64(job.ID)) @@ -873,20 +705,17 @@ func (db *DB) FetchTasksOfJob(ctx context.Context, job *Job) ([]*Task, error) { return nil, taskError(err, "fetching tasks of job %s", job.UUID) } - result := make([]*Task, len(rows)) + result := make([]TaskJobWorker, len(rows)) for i := range rows { - gormTask, err := convertSqlcTask(rows[i].Task, job.UUID, rows[i].WorkerUUID.String) - if err != nil { - return nil, err - } - gormTask.Job = job - result[i] = gormTask + result[i].Task = rows[i].Task + result[i].JobUUID = job.UUID + result[i].WorkerUUID = rows[i].WorkerUUID.String } return result, nil } // FetchTasksOfJobInStatus returns those tasks of the given job that have any of the given statuses. 
-func (db *DB) FetchTasksOfJobInStatus(ctx context.Context, job *Job, taskStatuses ...api.TaskStatus) ([]*Task, error) { +func (db *DB) FetchTasksOfJobInStatus(ctx context.Context, job *Job, taskStatuses ...api.TaskStatus) ([]TaskJobWorker, error) { queries := db.queries() rows, err := queries.FetchTasksOfJobInStatus(ctx, sqlc.FetchTasksOfJobInStatusParams{ @@ -897,14 +726,11 @@ func (db *DB) FetchTasksOfJobInStatus(ctx context.Context, job *Job, taskStatuse return nil, taskError(err, "fetching tasks of job %s in status %q", job.UUID, taskStatuses) } - result := make([]*Task, len(rows)) + result := make([]TaskJobWorker, len(rows)) for i := range rows { - gormTask, err := convertSqlcTask(rows[i].Task, job.UUID, rows[i].WorkerUUID.String) - if err != nil { - return nil, err - } - gormTask.Job = job - result[i] = gormTask + result[i].Task = rows[i].Task + result[i].JobUUID = job.UUID + result[i].WorkerUUID = rows[i].WorkerUUID.String } return result, nil } @@ -958,22 +784,18 @@ func (db *DB) UpdateJobsTaskStatusesConditional(ctx context.Context, job *Job, } // TaskTouchedByWorker marks the task as 'touched' by a worker. This is used for timeout detection. -func (db *DB) TaskTouchedByWorker(ctx context.Context, t *Task) error { +func (db *DB) TaskTouchedByWorker(ctx context.Context, taskUUID string) error { queries := db.queries() now := db.nowNullable() err := queries.TaskTouchedByWorker(ctx, sqlc.TaskTouchedByWorkerParams{ UpdatedAt: now, LastTouchedAt: now, - ID: int64(t.ID), + UUID: taskUUID, }) if err != nil { return taskError(err, "saving task 'last touched at'") } - - // Also update the given task, so that it's consistent with the database. - t.LastTouchedAt = now.Time - return nil } @@ -1040,82 +862,3 @@ func (db *DB) FetchTaskFailureList(ctx context.Context, t *Task) ([]*Worker, err } return workers, nil } - -// convertSqlcJob converts a job from the SQLC-generated model to the model -// expected by the rest of the code. This is mostly in place to aid in the GORM -// to SQLC migration. It is intended that eventually the rest of the code will -// use the same SQLC-generated model. -func convertSqlcJob(job sqlc.Job) (Job, error) { - dbJob := Job{ - Model: Model{ - ID: uint(job.ID), - CreatedAt: job.CreatedAt, - UpdatedAt: job.UpdatedAt.Time, - }, - UUID: job.UUID, - Name: job.Name, - JobType: job.JobType, - Priority: int(job.Priority), - Status: api.JobStatus(job.Status), - Activity: job.Activity, - DeleteRequestedAt: job.DeleteRequestedAt, - Storage: JobStorageInfo{ - ShamanCheckoutID: job.StorageShamanCheckoutID, - }, - } - - if err := json.Unmarshal(job.Settings, &dbJob.Settings); err != nil { - return Job{}, jobError(err, fmt.Sprintf("job %s has invalid settings: %v", job.UUID, err)) - } - - if err := json.Unmarshal(job.Metadata, &dbJob.Metadata); err != nil { - return Job{}, jobError(err, fmt.Sprintf("job %s has invalid metadata: %v", job.UUID, err)) - } - - if job.WorkerTagID.Valid { - workerTagID := uint(job.WorkerTagID.Int64) - dbJob.WorkerTagID = &workerTagID - } - - return dbJob, nil -} - -// convertSqlcTask converts a FetchTaskRow from the SQLC-generated model to the -// model expected by the rest of the code. This is mostly in place to aid in the -// GORM to SQLC migration. It is intended that eventually the rest of the code -// will use the same SQLC-generated model. 
-func convertSqlcTask(task sqlc.Task, jobUUID string, workerUUID string) (*Task, error) { - dbTask := Task{ - Model: Model{ - ID: uint(task.ID), - CreatedAt: task.CreatedAt, - UpdatedAt: task.UpdatedAt.Time, - }, - - UUID: task.UUID, - Name: task.Name, - Type: task.Type, - IndexInJob: int(task.IndexInJob), - Priority: int(task.Priority), - Status: api.TaskStatus(task.Status), - LastTouchedAt: task.LastTouchedAt.Time, - Activity: task.Activity, - - JobID: uint(task.JobID), - JobUUID: jobUUID, - WorkerUUID: workerUUID, - } - - // TODO: convert dependencies? - - if task.WorkerID.Valid { - workerID := uint(task.WorkerID.Int64) - dbTask.WorkerID = &workerID - } - - if err := json.Unmarshal(task.Commands, &dbTask.Commands); err != nil { - return nil, taskError(err, "task %s of job %s has invalid commands: %v", task.UUID, jobUUID, err) - } - - return &dbTask, nil -} diff --git a/internal/manager/persistence/jobs_blocklist.go b/internal/manager/persistence/jobs_blocklist.go index f149b97d..75fbe285 100644 --- a/internal/manager/persistence/jobs_blocklist.go +++ b/internal/manager/persistence/jobs_blocklist.go @@ -13,11 +13,11 @@ import ( type JobBlockListEntry = sqlc.FetchJobBlocklistRow // AddWorkerToJobBlocklist prevents this Worker of getting any task, of this type, on this job, from the task scheduler. -func (db *DB) AddWorkerToJobBlocklist(ctx context.Context, job *Job, worker *Worker, taskType string) error { - if job.ID == 0 { +func (db *DB) AddWorkerToJobBlocklist(ctx context.Context, jobID int64, workerID int64, taskType string) error { + if jobID == 0 { panic("Cannot add worker to job blocklist with zero job ID") } - if worker.ID == 0 { + if workerID == 0 { panic("Cannot add worker to job blocklist with zero worker ID") } if taskType == "" { @@ -28,8 +28,8 @@ func (db *DB) AddWorkerToJobBlocklist(ctx context.Context, job *Job, worker *Wor return queries.AddWorkerToJobBlocklist(ctx, sqlc.AddWorkerToJobBlocklistParams{ CreatedAt: db.nowNullable().Time, - JobID: int64(job.ID), - WorkerID: int64(worker.ID), + JobID: jobID, + WorkerID: workerID, TaskType: taskType, }) } @@ -72,18 +72,18 @@ func (db *DB) WorkersLeftToRun(ctx context.Context, job *Job, taskType string) ( workerUUIDs []string err error ) - if job.WorkerTagID == nil { - workerUUIDs, err = queries.WorkersLeftToRun(ctx, sqlc.WorkersLeftToRunParams{ - JobID: int64(job.ID), - TaskType: taskType, - }) - } else { + if job.WorkerTagID.Valid { workerUUIDs, err = queries.WorkersLeftToRunWithWorkerTag(ctx, sqlc.WorkersLeftToRunWithWorkerTagParams{ - JobID: int64(job.ID), + JobID: job.ID, TaskType: taskType, - WorkerTagID: int64(*job.WorkerTagID), + WorkerTagID: job.WorkerTagID.Int64, }) + } else { + workerUUIDs, err = queries.WorkersLeftToRun(ctx, sqlc.WorkersLeftToRunParams{ + JobID: job.ID, + TaskType: taskType, + }) } if err != nil { return nil, err @@ -99,13 +99,13 @@ func (db *DB) WorkersLeftToRun(ctx context.Context, job *Job, taskType string) ( } // CountTaskFailuresOfWorker returns the number of task failures of this worker, on this particular job and task type. 
-func (db *DB) CountTaskFailuresOfWorker(ctx context.Context, job *Job, worker *Worker, taskType string) (int, error) { +func (db *DB) CountTaskFailuresOfWorker(ctx context.Context, jobUUID string, workerID int64, taskType string) (int, error) { var numFailures int64 queries := db.queries() numFailures, err := queries.CountTaskFailuresOfWorker(ctx, sqlc.CountTaskFailuresOfWorkerParams{ - JobID: int64(job.ID), - WorkerID: int64(worker.ID), + JobUUID: jobUUID, + WorkerID: workerID, TaskType: taskType, }) diff --git a/internal/manager/persistence/jobs_blocklist_test.go b/internal/manager/persistence/jobs_blocklist_test.go index 35985933..3624ffed 100644 --- a/internal/manager/persistence/jobs_blocklist_test.go +++ b/internal/manager/persistence/jobs_blocklist_test.go @@ -18,7 +18,7 @@ func TestAddWorkerToJobBlocklist(t *testing.T) { { // Add a worker to the block list. - err := db.AddWorkerToJobBlocklist(ctx, job, worker, "blender") + err := db.AddWorkerToJobBlocklist(ctx, job.ID, worker.ID, "blender") require.NoError(t, err) list, err := queries.Test_FetchJobBlocklist(ctx) @@ -33,7 +33,7 @@ func TestAddWorkerToJobBlocklist(t *testing.T) { { // Adding the same worker again should be a no-op. - err := db.AddWorkerToJobBlocklist(ctx, job, worker, "blender") + err := db.AddWorkerToJobBlocklist(ctx, job.ID, worker.ID, "blender") require.NoError(t, err) list, err := queries.Test_FetchJobBlocklist(ctx) @@ -48,7 +48,7 @@ func TestFetchJobBlocklist(t *testing.T) { // Add a worker to the block list. worker := createWorker(ctx, t, db) - err := db.AddWorkerToJobBlocklist(ctx, job, worker, "blender") + err := db.AddWorkerToJobBlocklist(ctx, job.ID, worker.ID, "blender") require.NoError(t, err) list, err := db.FetchJobBlocklist(ctx, job.UUID) @@ -68,9 +68,9 @@ func TestClearJobBlocklist(t *testing.T) { // Add a worker and some entries to the block list. worker := createWorker(ctx, t, db) - err := db.AddWorkerToJobBlocklist(ctx, job, worker, "blender") + err := db.AddWorkerToJobBlocklist(ctx, job.ID, worker.ID, "blender") require.NoError(t, err) - err = db.AddWorkerToJobBlocklist(ctx, job, worker, "ffmpeg") + err = db.AddWorkerToJobBlocklist(ctx, job.ID, worker.ID, "ffmpeg") require.NoError(t, err) // Clear the blocklist. @@ -89,9 +89,9 @@ func TestRemoveFromJobBlocklist(t *testing.T) { // Add a worker and some entries to the block list. worker := createWorker(ctx, t, db) - err := db.AddWorkerToJobBlocklist(ctx, job, worker, "blender") + err := db.AddWorkerToJobBlocklist(ctx, job.ID, worker.ID, "blender") require.NoError(t, err) - err = db.AddWorkerToJobBlocklist(ctx, job, worker, "ffmpeg") + err = db.AddWorkerToJobBlocklist(ctx, job.ID, worker.ID, "ffmpeg") require.NoError(t, err) // Remove an entry. @@ -148,20 +148,20 @@ func TestWorkersLeftToRun(t *testing.T) { assert.Equal(t, uuidMap(worker1, worker2, workerC1), left) // Two workers, one blocked. - _ = db.AddWorkerToJobBlocklist(ctx, job, worker1, "blender") + _ = db.AddWorkerToJobBlocklist(ctx, job.ID, worker1.ID, "blender") left, err = db.WorkersLeftToRun(ctx, job, "blender") require.NoError(t, err) assert.Equal(t, uuidMap(worker2, workerC1), left) // All workers blocked. 
- _ = db.AddWorkerToJobBlocklist(ctx, job, worker2, "blender") - _ = db.AddWorkerToJobBlocklist(ctx, job, workerC1, "blender") + _ = db.AddWorkerToJobBlocklist(ctx, job.ID, worker2.ID, "blender") + _ = db.AddWorkerToJobBlocklist(ctx, job.ID, workerC1.ID, "blender") left, err = db.WorkersLeftToRun(ctx, job, "blender") require.NoError(t, err) assert.Empty(t, left) // Two workers, unknown job. - fakeJob := Job{Model: Model{ID: 327}} + fakeJob := Job{ID: 327} left, err = db.WorkersLeftToRun(ctx, &fakeJob, "blender") require.NoError(t, err) assert.Equal(t, uuidMap(worker1, worker2, workerC1), left) @@ -222,13 +222,13 @@ func TestWorkersLeftToRunWithTags(t *testing.T) { assert.Equal(t, uuidMap(workerC13, workerC1), left) // One worker blocked, one worker remain. - _ = db.AddWorkerToJobBlocklist(ctx, job, workerC1, "blender") + _ = db.AddWorkerToJobBlocklist(ctx, job.ID, workerC1.ID, "blender") left, err = db.WorkersLeftToRun(ctx, job, "blender") require.NoError(t, err) assert.Equal(t, uuidMap(workerC13), left) // All taged workers blocked. - _ = db.AddWorkerToJobBlocklist(ctx, job, workerC13, "blender") + _ = db.AddWorkerToJobBlocklist(ctx, job.ID, workerC13.ID, "blender") left, err = db.WorkersLeftToRun(ctx, job, "blender") require.NoError(t, err) assert.Empty(t, left) @@ -238,13 +238,17 @@ func TestCountTaskFailuresOfWorker(t *testing.T) { ctx, close, db, dbJob, authoredJob := jobTasksTestFixtures(t) defer close() - task0, err := db.FetchTask(ctx, authoredJob.Tasks[0].UUID) + taskJobWorker0, err := db.FetchTask(ctx, authoredJob.Tasks[0].UUID) require.NoError(t, err) - task1, err := db.FetchTask(ctx, authoredJob.Tasks[1].UUID) + taskJobWorker1, err := db.FetchTask(ctx, authoredJob.Tasks[1].UUID) require.NoError(t, err) - task2, err := db.FetchTask(ctx, authoredJob.Tasks[2].UUID) + taskJobWorker2, err := db.FetchTask(ctx, authoredJob.Tasks[2].UUID) require.NoError(t, err) + task0 := taskJobWorker0.Task + task1 := taskJobWorker1.Task + task2 := taskJobWorker2.Task + // Sanity check on the test data. assert.Equal(t, "blender", task0.Type) assert.Equal(t, "blender", task1.Type) @@ -254,28 +258,28 @@ func TestCountTaskFailuresOfWorker(t *testing.T) { worker2 := createWorkerFrom(ctx, t, db, *worker1) // Store some failures for different tasks - _, _ = db.AddWorkerToTaskFailedList(ctx, task0, worker1) - _, _ = db.AddWorkerToTaskFailedList(ctx, task1, worker1) - _, _ = db.AddWorkerToTaskFailedList(ctx, task1, worker2) - _, _ = db.AddWorkerToTaskFailedList(ctx, task2, worker1) + _, _ = db.AddWorkerToTaskFailedList(ctx, &task0, worker1) + _, _ = db.AddWorkerToTaskFailedList(ctx, &task1, worker1) + _, _ = db.AddWorkerToTaskFailedList(ctx, &task1, worker2) + _, _ = db.AddWorkerToTaskFailedList(ctx, &task2, worker1) // Multiple failures. - numBlender1, err := db.CountTaskFailuresOfWorker(ctx, dbJob, worker1, "blender") + numBlender1, err := db.CountTaskFailuresOfWorker(ctx, dbJob.UUID, worker1.ID, "blender") require.NoError(t, err) assert.Equal(t, 2, numBlender1) // Single failure, but multiple tasks exist of this type. - numBlender2, err := db.CountTaskFailuresOfWorker(ctx, dbJob, worker2, "blender") + numBlender2, err := db.CountTaskFailuresOfWorker(ctx, dbJob.UUID, worker2.ID, "blender") require.NoError(t, err) assert.Equal(t, 1, numBlender2) // Single failure, only one task of this type exists. 
- numFFMpeg1, err := db.CountTaskFailuresOfWorker(ctx, dbJob, worker1, "ffmpeg") + numFFMpeg1, err := db.CountTaskFailuresOfWorker(ctx, dbJob.UUID, worker1.ID, "ffmpeg") require.NoError(t, err) assert.Equal(t, 1, numFFMpeg1) // No failure. - numFFMpeg2, err := db.CountTaskFailuresOfWorker(ctx, dbJob, worker2, "ffmpeg") + numFFMpeg2, err := db.CountTaskFailuresOfWorker(ctx, dbJob.UUID, worker2.ID, "ffmpeg") require.NoError(t, err) assert.Equal(t, 0, numFFMpeg2) } diff --git a/internal/manager/persistence/jobs_query.go b/internal/manager/persistence/jobs_query.go index 133f3c92..75e33cfd 100644 --- a/internal/manager/persistence/jobs_query.go +++ b/internal/manager/persistence/jobs_query.go @@ -5,42 +5,39 @@ import ( "context" "github.com/rs/zerolog/log" + "projects.blender.org/studio/flamenco/internal/manager/persistence/sqlc" "projects.blender.org/studio/flamenco/pkg/api" ) +type TaskSummary = sqlc.QueryJobTaskSummariesRow + // QueryJobTaskSummaries retrieves all tasks of the job, but not all fields of those tasks. // Fields are synchronised with api.TaskSummary. -func (db *DB) QueryJobTaskSummaries(ctx context.Context, jobUUID string) ([]*Task, error) { +func (db *DB) QueryJobTaskSummaries(ctx context.Context, jobUUID string) ([]TaskSummary, error) { logger := log.Ctx(ctx) logger.Debug().Str("job", jobUUID).Msg("querying task summaries") queries := db.queries() - sqlcPartialTasks, err := queries.QueryJobTaskSummaries(ctx, jobUUID) + summaries, err := queries.QueryJobTaskSummaries(ctx, jobUUID) if err != nil { return nil, err } - // Convert to partial GORM tasks. - gormTasks := make([]*Task, len(sqlcPartialTasks)) - for index, task := range sqlcPartialTasks { - gormTask := Task{ - Model: Model{ - ID: uint(task.ID), - UpdatedAt: task.UpdatedAt.Time, - }, - + result := make([]TaskSummary, len(summaries)) + for index, task := range summaries { + result[index] = TaskSummary{ + ID: task.ID, + UpdatedAt: task.UpdatedAt, UUID: task.UUID, Name: task.Name, Type: task.Type, - IndexInJob: int(task.IndexInJob), - Priority: int(task.Priority), + IndexInJob: task.IndexInJob, + Priority: task.Priority, Status: api.TaskStatus(task.Status), - JobUUID: jobUUID, } - gormTasks[index] = &gormTask } - return gormTasks, nil + return result, nil } // JobStatusCount is a mapping from job status to the number of jobs in that status. 
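The blocklist and task-summary changes above move these persistence calls from model pointers to plain IDs, UUIDs, and sqlc row types. As a rough caller-side sketch (not part of this patch; the helper name blockAndCount and the hard-coded "blender" task type are illustrative only), the reworked signatures would be used roughly like this:

package example // hypothetical example, not part of the patch

import (
	"context"
	"fmt"

	"projects.blender.org/studio/flamenco/internal/manager/persistence"
)

// blockAndCount blocklists a worker for one task type on a job, then reports
// how many tasks of that type the worker has failed on that job.
func blockAndCount(ctx context.Context, db *persistence.DB, job *persistence.Job, worker *persistence.Worker) error {
	// The blocklist call now takes the numeric database IDs directly.
	if err := db.AddWorkerToJobBlocklist(ctx, job.ID, worker.ID, "blender"); err != nil {
		return err
	}

	// Failure counting is keyed by job UUID plus the worker's numeric ID.
	numFailed, err := db.CountTaskFailuresOfWorker(ctx, job.UUID, worker.ID, "blender")
	if err != nil {
		return err
	}

	// Task summaries now come back as sqlc rows instead of partial GORM tasks.
	summaries, err := db.QueryJobTaskSummaries(ctx, job.UUID)
	if err != nil {
		return err
	}

	fmt.Printf("job %s: %d blender failures by worker %s, %d tasks total\n",
		job.UUID, numFailed, worker.UUID, len(summaries))
	return nil
}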
diff --git a/internal/manager/persistence/jobs_query_test.go b/internal/manager/persistence/jobs_query_test.go index 753cc1b3..e7245e55 100644 --- a/internal/manager/persistence/jobs_query_test.go +++ b/internal/manager/persistence/jobs_query_test.go @@ -49,7 +49,7 @@ func TestQueryJobTaskSummaries(t *testing.T) { assert.Len(t, summaries, len(expectTaskUUIDs)) for index, summary := range summaries { assert.True(t, expectTaskUUIDs[summary.UUID], "%q should be in %v", summary.UUID, expectTaskUUIDs) - assert.Equal(t, index+1, summary.IndexInJob) + assert.Equal(t, int64(index+1), summary.IndexInJob) } } diff --git a/internal/manager/persistence/jobs_test.go b/internal/manager/persistence/jobs_test.go index 9dde4585..2c2711ef 100644 --- a/internal/manager/persistence/jobs_test.go +++ b/internal/manager/persistence/jobs_test.go @@ -4,6 +4,8 @@ package persistence // SPDX-License-Identifier: GPL-3.0-or-later import ( + "database/sql" + "encoding/json" "fmt" "math" "testing" @@ -36,11 +38,17 @@ func TestStoreAuthoredJob(t *testing.T) { assert.Equal(t, job.JobID, fetchedJob.UUID) assert.Equal(t, job.Name, fetchedJob.Name) assert.Equal(t, job.JobType, fetchedJob.JobType) - assert.Equal(t, job.Priority, fetchedJob.Priority) + assert.Equal(t, job.Priority, int(fetchedJob.Priority)) assert.Equal(t, api.JobStatusUnderConstruction, fetchedJob.Status) - assert.EqualValues(t, map[string]interface{}(job.Settings), fetchedJob.Settings) - assert.EqualValues(t, map[string]string(job.Metadata), fetchedJob.Metadata) - assert.Equal(t, "", fetchedJob.Storage.ShamanCheckoutID) + assert.Equal(t, "", fetchedJob.StorageShamanCheckoutID) + + var parsedSettings map[string]interface{} + assert.NoError(t, json.Unmarshal(fetchedJob.Settings, &parsedSettings)) + assert.EqualValues(t, map[string]interface{}(job.Settings), parsedSettings) + + var parsedMetadata map[string]string + assert.NoError(t, json.Unmarshal(fetchedJob.Metadata, &parsedMetadata)) + assert.EqualValues(t, map[string]string(job.Metadata), parsedMetadata) // Fetch result of job. result, err := queries.FetchTasksOfJob(ctx, int64(fetchedJob.ID)) @@ -70,7 +78,7 @@ func TestStoreAuthoredJobWithShamanCheckoutID(t *testing.T) { require.NoError(t, err) require.NotNil(t, fetchedJob) - assert.Equal(t, job.Storage.ShamanCheckoutID, fetchedJob.Storage.ShamanCheckoutID) + assert.Equal(t, job.Storage.ShamanCheckoutID, fetchedJob.StorageShamanCheckoutID) } func TestStoreAuthoredJobWithWorkerTag(t *testing.T) { @@ -97,12 +105,8 @@ func TestStoreAuthoredJobWithWorkerTag(t *testing.T) { require.NotNil(t, fetchedJob) require.NotNil(t, fetchedJob.WorkerTagID) - assert.Equal(t, int64(*fetchedJob.WorkerTagID), workerTag.ID) - - require.NotNil(t, fetchedJob.WorkerTag) - assert.Equal(t, fetchedJob.WorkerTag.Name, workerTag.Name) - assert.Equal(t, fetchedJob.WorkerTag.Description, workerTag.Description) - assert.Equal(t, fetchedJob.WorkerTag.UUID, workerTagUUID) + assert.Equal(t, fetchedJob.WorkerTagID.Int64, workerTag.ID) + assert.True(t, fetchedJob.WorkerTagID.Valid) } func TestFetchTaskJobUUID(t *testing.T) { @@ -135,20 +139,20 @@ func TestSaveJobStorageInfo(t *testing.T) { dbJob, err := db.FetchJob(ctx, authoredJob.JobID) require.NoError(t, err) assert.NotNil(t, dbJob) - assert.EqualValues(t, startTime, dbJob.UpdatedAt) + assert.EqualValues(t, startTime, dbJob.UpdatedAt.Time) // Move the clock forward. updateTime := time.Date(2023, time.February, 7, 15, 10, 0, 0, time.UTC) mockNow = updateTime // Save the storage info. 
- dbJob.Storage.ShamanCheckoutID = "shaman/checkout/id" + dbJob.StorageShamanCheckoutID = "shaman/checkout/id" require.NoError(t, db.SaveJobStorageInfo(ctx, dbJob)) // Check that the UpdatedAt field wasn't touched. updatedJob, err := db.FetchJob(ctx, authoredJob.JobID) require.NoError(t, err) - assert.Equal(t, startTime, updatedJob.UpdatedAt, "SaveJobStorageInfo should not touch UpdatedAt") + assert.Equal(t, startTime, updatedJob.UpdatedAt.Time, "SaveJobStorageInfo should not touch UpdatedAt") } func TestSaveJobPriority(t *testing.T) { @@ -161,7 +165,7 @@ func TestSaveJobPriority(t *testing.T) { require.NoError(t, err) // Set a new priority. - newPriority := 47 + newPriority := int64(47) dbJob, err := db.FetchJob(ctx, authoredJob.JobID) require.NoError(t, err) require.NotEqual(t, newPriority, dbJob.Priority, @@ -316,7 +320,7 @@ func TestRequestJobMassDeletion(t *testing.T) { // Request that "job3 and older" gets deleted. timeOfDeleteRequest := realNowFunc() db.nowfunc = func() time.Time { return timeOfDeleteRequest } - uuids, err := db.RequestJobMassDeletion(ctx, job3.UpdatedAt) + uuids, err := db.RequestJobMassDeletion(ctx, job3.UpdatedAt.Time) require.NoError(t, err) db.nowfunc = realNowFunc @@ -428,11 +432,11 @@ func TestCountTasksOfJobInStatus(t *testing.T) { assert.Equal(t, 3, numQueued) assert.Equal(t, 3, numTotal) - // Make one task failed. - task, err := db.FetchTask(ctx, authoredJob.Tasks[0].UUID) + // Mark one task as failed. + taskJobWorker, err := db.FetchTask(ctx, authoredJob.Tasks[0].UUID) require.NoError(t, err) - task.Status = api.TaskStatusFailed - require.NoError(t, db.SaveTask(ctx, task)) + taskJobWorker.Task.Status = api.TaskStatusFailed + require.NoError(t, db.SaveTask(ctx, &taskJobWorker.Task)) numQueued, numTotal, err = db.CountTasksOfJobInStatus(ctx, job, api.TaskStatusQueued) require.NoError(t, err) @@ -515,33 +519,33 @@ func TestFetchTasksOfJobInStatus(t *testing.T) { allTasks, err := db.FetchTasksOfJob(ctx, job) require.NoError(t, err) - assert.Equal(t, job, allTasks[0].Job, "FetchTasksOfJob should set job pointer") + assert.Equal(t, job.UUID, allTasks[0].JobUUID, "FetchTasksOfJob should set job UUID") - tasks, err := db.FetchTasksOfJobInStatus(ctx, job, api.TaskStatusQueued) + tasksJobsWorkers, err := db.FetchTasksOfJobInStatus(ctx, job, api.TaskStatusQueued) require.NoError(t, err) - assert.Equal(t, allTasks, tasks) - assert.Equal(t, job, tasks[0].Job, "FetchTasksOfJobInStatus should set job pointer") + assert.Equal(t, allTasks, tasksJobsWorkers) + assert.Equal(t, job.UUID, tasksJobsWorkers[0].JobUUID, "FetchTasksOfJobInStatus should set job UUID") - // Make one task failed. - task, err := db.FetchTask(ctx, authoredJob.Tasks[0].UUID) + // Mark one task as failed. + taskJobWorker, err := db.FetchTask(ctx, authoredJob.Tasks[0].UUID) require.NoError(t, err) - task.Status = api.TaskStatusFailed - require.NoError(t, db.SaveTask(ctx, task)) + taskJobWorker.Task.Status = api.TaskStatusFailed + require.NoError(t, db.SaveTask(ctx, &taskJobWorker.Task)) - tasks, err = db.FetchTasksOfJobInStatus(ctx, job, api.TaskStatusQueued) + tasksJobsWorkers, err = db.FetchTasksOfJobInStatus(ctx, job, api.TaskStatusQueued) require.NoError(t, err) - assert.Equal(t, []*Task{allTasks[1], allTasks[2]}, tasks) + assert.Equal(t, []TaskJobWorker{allTasks[1], allTasks[2]}, tasksJobsWorkers) // Check the failed task. This cannot directly compare to `allTasks[0]` // because saving the task above changed some of its fields. 
- tasks, err = db.FetchTasksOfJobInStatus(ctx, job, api.TaskStatusFailed) + tasksJobsWorkers, err = db.FetchTasksOfJobInStatus(ctx, job, api.TaskStatusFailed) require.NoError(t, err) - assert.Len(t, tasks, 1) - assert.Equal(t, allTasks[0].ID, tasks[0].ID) + assert.Len(t, tasksJobsWorkers, 1) + assert.Equal(t, allTasks[0].Task.ID, tasksJobsWorkers[0].Task.ID) - tasks, err = db.FetchTasksOfJobInStatus(ctx, job, api.TaskStatusActive) + tasksJobsWorkers, err = db.FetchTasksOfJobInStatus(ctx, job, api.TaskStatusActive) require.NoError(t, err) - assert.Empty(t, tasks) + assert.Empty(t, tasksJobsWorkers) } func TestSaveTaskActivity(t *testing.T) { @@ -549,16 +553,20 @@ func TestSaveTaskActivity(t *testing.T) { defer close() taskUUID := authoredJob.Tasks[0].UUID - task, err := db.FetchTask(ctx, taskUUID) + taskJobWorker, err := db.FetchTask(ctx, taskUUID) require.NoError(t, err) + + task := taskJobWorker.Task require.Equal(t, api.TaskStatusQueued, task.Status) task.Activity = "Somebody ran a ünit test" task.Status = api.TaskStatusPaused // Should not be saved. - require.NoError(t, db.SaveTaskActivity(ctx, task)) + require.NoError(t, db.SaveTaskActivity(ctx, &task)) - dbTask, err := db.FetchTask(ctx, taskUUID) + dbTaskJobWorker, err := db.FetchTask(ctx, taskUUID) require.NoError(t, err) + + dbTask := dbTaskJobWorker.Task require.Equal(t, "Somebody ran a ünit test", dbTask.Activity) require.Equal(t, api.TaskStatusQueued, dbTask.Status, "SaveTaskActivity() should not save the task status") @@ -568,44 +576,40 @@ func TestTaskAssignToWorker(t *testing.T) { ctx, close, db, _, authoredJob := jobTasksTestFixtures(t) defer close() - task, err := db.FetchTask(ctx, authoredJob.Tasks[1].UUID) + taskJobWorker, err := db.FetchTask(ctx, authoredJob.Tasks[1].UUID) require.NoError(t, err) + assert.Zero(t, taskJobWorker.WorkerUUID) + assert.Equal(t, authoredJob.JobID, taskJobWorker.JobUUID) w := createWorker(ctx, t, db) - require.NoError(t, db.TaskAssignToWorker(ctx, task, w)) - - if task.Worker == nil { - t.Error("task.Worker == nil") - } else { - assert.Equal(t, w, task.Worker) - } - if task.WorkerID == nil { - t.Error("task.WorkerID == nil") - } else { - assert.Equal(t, w.ID, int64(*task.WorkerID)) - } + require.NoError(t, db.TaskAssignToWorker(ctx, &taskJobWorker.Task, w)) + assert.Equal(t, + sql.NullInt64{Int64: w.ID, Valid: true}, + taskJobWorker.Task.WorkerID) } func TestFetchTasksOfWorkerInStatus(t *testing.T) { ctx, close, db, _, authoredJob := jobTasksTestFixtures(t) defer close() - task, err := db.FetchTask(ctx, authoredJob.Tasks[1].UUID) + taskJobWorker, err := db.FetchTask(ctx, authoredJob.Tasks[1].UUID) + task := taskJobWorker.Task require.NoError(t, err) w := createWorker(ctx, t, db) - require.NoError(t, db.TaskAssignToWorker(ctx, task, w)) + require.NoError(t, db.TaskAssignToWorker(ctx, &task, w)) - tasks, err := db.FetchTasksOfWorkerInStatus(ctx, w, task.Status) + tasksJobsWorkers, err := db.FetchTasksOfWorkerInStatus(ctx, w, task.Status) require.NoError(t, err) - assert.Len(t, tasks, 1, "worker should have one task in status %q", task.Status) - assert.Equal(t, task.ID, tasks[0].ID) - assert.Equal(t, task.UUID, tasks[0].UUID) + assert.Len(t, tasksJobsWorkers, 1, "worker should have one task in status %q", task.Status) + assert.Equal(t, task.ID, tasksJobsWorkers[0].Task.ID) + assert.Equal(t, task.UUID, tasksJobsWorkers[0].Task.UUID) + assert.Equal(t, authoredJob.JobID, tasksJobsWorkers[0].JobUUID) - assert.NotEqual(t, api.TaskStatusCanceled, task.Status) - tasks, err = 
db.FetchTasksOfWorkerInStatus(ctx, w, api.TaskStatusCanceled) + require.NotEqual(t, api.TaskStatusCanceled, task.Status) + tasksJobsWorkers, err = db.FetchTasksOfWorkerInStatus(ctx, w, api.TaskStatusCanceled) require.NoError(t, err) - assert.Empty(t, tasks, "worker should have no task in status %q", w) + assert.Empty(t, tasksJobsWorkers, "worker should have no task in status %q", api.TaskStatusCanceled) } func TestFetchTasksOfWorkerInStatusOfJob(t *testing.T) { @@ -630,42 +634,42 @@ func TestFetchTasksOfWorkerInStatusOfJob(t *testing.T) { // Assign a task from each job to each Worker. // Also double-check the test precondition that all tasks have the same status. { // Job / Worker. - task1, err := db.FetchTask(ctx, authoredJob.Tasks[1].UUID) + taskJobWorker1, err := db.FetchTask(ctx, authoredJob.Tasks[1].UUID) require.NoError(t, err) - require.NoError(t, db.TaskAssignToWorker(ctx, task1, worker)) - require.Equal(t, task1.Status, api.TaskStatusQueued) + require.NoError(t, db.TaskAssignToWorker(ctx, &taskJobWorker1.Task, worker)) + require.Equal(t, taskJobWorker1.Task.Status, api.TaskStatusQueued) - task2, err := db.FetchTask(ctx, authoredJob.Tasks[0].UUID) + taskJobWorker2, err := db.FetchTask(ctx, authoredJob.Tasks[0].UUID) require.NoError(t, err) - require.NoError(t, db.TaskAssignToWorker(ctx, task2, worker)) - require.Equal(t, task2.Status, api.TaskStatusQueued) + require.NoError(t, db.TaskAssignToWorker(ctx, &taskJobWorker2.Task, worker)) + require.Equal(t, taskJobWorker2.Task.Status, api.TaskStatusQueued) } { // Job / Other Worker. - task, err := db.FetchTask(ctx, authoredJob.Tasks[2].UUID) + taskJobWorker, err := db.FetchTask(ctx, authoredJob.Tasks[2].UUID) require.NoError(t, err) - require.NoError(t, db.TaskAssignToWorker(ctx, task, otherWorker)) - require.Equal(t, task.Status, api.TaskStatusQueued) + require.NoError(t, db.TaskAssignToWorker(ctx, &taskJobWorker.Task, otherWorker)) + require.Equal(t, taskJobWorker.Task.Status, api.TaskStatusQueued) } { // Other Job / Worker. - task, err := db.FetchTask(ctx, otherJob.Tasks[1].UUID) + taskJobWorker, err := db.FetchTask(ctx, otherJob.Tasks[1].UUID) require.NoError(t, err) - require.NoError(t, db.TaskAssignToWorker(ctx, task, worker)) - require.Equal(t, task.Status, api.TaskStatusQueued) + require.NoError(t, db.TaskAssignToWorker(ctx, &taskJobWorker.Task, worker)) + require.Equal(t, taskJobWorker.Task.Status, api.TaskStatusQueued) } { // Other Job / Other Worker. - task, err := db.FetchTask(ctx, otherJob.Tasks[2].UUID) + taskJobWorker, err := db.FetchTask(ctx, otherJob.Tasks[2].UUID) require.NoError(t, err) - require.NoError(t, db.TaskAssignToWorker(ctx, task, otherWorker)) - require.Equal(t, task.Status, api.TaskStatusQueued) + require.NoError(t, db.TaskAssignToWorker(ctx, &taskJobWorker.Task, otherWorker)) + require.Equal(t, taskJobWorker.Task.Status, api.TaskStatusQueued) } { // Test active tasks, should be none. - tasks, err := db.FetchTasksOfWorkerInStatusOfJob(ctx, worker, api.TaskStatusActive, dbJob) + tasks, err := db.FetchTasksOfWorkerInStatusOfJob(ctx, worker, api.TaskStatusActive, dbJob.UUID) require.NoError(t, err) require.Len(t, tasks, 0) } { // Test queued tasks, should be two. 
- tasks, err := db.FetchTasksOfWorkerInStatusOfJob(ctx, worker, api.TaskStatusQueued, dbJob) + tasks, err := db.FetchTasksOfWorkerInStatusOfJob(ctx, worker, api.TaskStatusQueued, dbJob.UUID) require.NoError(t, err) require.Len(t, tasks, 2) assert.Equal(t, authoredJob.Tasks[0].UUID, tasks[0].UUID) @@ -675,7 +679,7 @@ func TestFetchTasksOfWorkerInStatusOfJob(t *testing.T) { worker := createWorker(ctx, t, db, func(worker *Worker) { worker.UUID = "6534a1d4-f58e-4f2c-8925-4b2cd6caac22" }) - tasks, err := db.FetchTasksOfWorkerInStatusOfJob(ctx, worker, api.TaskStatusQueued, dbJob) + tasks, err := db.FetchTasksOfWorkerInStatusOfJob(ctx, worker, api.TaskStatusQueued, dbJob.UUID) require.NoError(t, err) require.Len(t, tasks, 0) } @@ -685,27 +689,29 @@ func TestTaskTouchedByWorker(t *testing.T) { ctx, close, db, _, authoredJob := jobTasksTestFixtures(t) defer close() - task, err := db.FetchTask(ctx, authoredJob.Tasks[1].UUID) + taskJobWorker, err := db.FetchTask(ctx, authoredJob.Tasks[1].UUID) + task := taskJobWorker.Task require.NoError(t, err) - assert.True(t, task.LastTouchedAt.IsZero()) + assert.Zero(t, task.LastTouchedAt) now := db.now() - err = db.TaskTouchedByWorker(ctx, task) + err = db.TaskTouchedByWorker(ctx, task.UUID) require.NoError(t, err) - // Test the task instance as well as the database entry. - dbTask, err := db.FetchTask(ctx, task.UUID) + // Test the task as it is in the database. + dbTaskJobWorker, err := db.FetchTask(ctx, task.UUID) require.NoError(t, err) - assert.WithinDuration(t, now, task.LastTouchedAt, time.Second) - assert.WithinDuration(t, now, dbTask.LastTouchedAt, time.Second) + assert.True(t, dbTaskJobWorker.Task.LastTouchedAt.Valid) + assert.WithinDuration(t, now, dbTaskJobWorker.Task.LastTouchedAt.Time, time.Second) } func TestAddWorkerToTaskFailedList(t *testing.T) { ctx, close, db, _, authoredJob := jobTasksTestFixtures(t) defer close() - task, err := db.FetchTask(ctx, authoredJob.Tasks[1].UUID) + taskJobWorker, err := db.FetchTask(ctx, authoredJob.Tasks[1].UUID) require.NoError(t, err) + task := taskJobWorker.Task worker1 := createWorker(ctx, t, db) @@ -719,17 +725,17 @@ func TestAddWorkerToTaskFailedList(t *testing.T) { require.NoError(t, err) // First failure should be registered just fine. - numFailed, err := db.AddWorkerToTaskFailedList(ctx, task, worker1) + numFailed, err := db.AddWorkerToTaskFailedList(ctx, &task, worker1) require.NoError(t, err) assert.Equal(t, 1, numFailed) // Calling again should be a no-op and not cause any errors. - numFailed, err = db.AddWorkerToTaskFailedList(ctx, task, worker1) + numFailed, err = db.AddWorkerToTaskFailedList(ctx, &task, worker1) require.NoError(t, err) assert.Equal(t, 1, numFailed) // Another worker should be able to fail this task as well. - numFailed, err = db.AddWorkerToTaskFailedList(ctx, task, worker2) + numFailed, err = db.AddWorkerToTaskFailedList(ctx, &task, worker2) require.NoError(t, err) assert.Equal(t, 2, numFailed) @@ -743,8 +749,10 @@ func TestClearFailureListOfTask(t *testing.T) { defer close() queries := db.queries() - task1, _ := db.FetchTask(ctx, authoredJob.Tasks[1].UUID) - task2, _ := db.FetchTask(ctx, authoredJob.Tasks[2].UUID) + taskJobWorker1, _ := db.FetchTask(ctx, authoredJob.Tasks[1].UUID) + taskJobWorker2, _ := db.FetchTask(ctx, authoredJob.Tasks[2].UUID) + task1 := taskJobWorker1.Task + task2 := taskJobWorker2.Task worker1 := createWorker(ctx, t, db) @@ -758,12 +766,12 @@ func TestClearFailureListOfTask(t *testing.T) { require.NoError(t, err) // Store some failures for different tasks. 
- _, _ = db.AddWorkerToTaskFailedList(ctx, task1, worker1) - _, _ = db.AddWorkerToTaskFailedList(ctx, task1, worker2) - _, _ = db.AddWorkerToTaskFailedList(ctx, task2, worker1) + _, _ = db.AddWorkerToTaskFailedList(ctx, &task1, worker1) + _, _ = db.AddWorkerToTaskFailedList(ctx, &task1, worker2) + _, _ = db.AddWorkerToTaskFailedList(ctx, &task2, worker1) // Clearing should just update this one task. - require.NoError(t, db.ClearFailureListOfTask(ctx, task1)) + require.NoError(t, db.ClearFailureListOfTask(ctx, &task1)) failures, err := queries.Test_FetchTaskFailures(ctx) require.NoError(t, err) if assert.Len(t, failures, 1) { @@ -781,19 +789,22 @@ func TestClearFailureListOfJob(t *testing.T) { authoredJob2 := duplicateJobAndTasks(authoredJob1) persistAuthoredJob(t, ctx, db, authoredJob2) - task1_1, _ := db.FetchTask(ctx, authoredJob1.Tasks[1].UUID) - task1_2, _ := db.FetchTask(ctx, authoredJob1.Tasks[2].UUID) - task2_1, _ := db.FetchTask(ctx, authoredJob2.Tasks[1].UUID) + taskJobWorker1_1, _ := db.FetchTask(ctx, authoredJob1.Tasks[1].UUID) + taskJobWorker1_2, _ := db.FetchTask(ctx, authoredJob1.Tasks[2].UUID) + taskJobWorker2_1, _ := db.FetchTask(ctx, authoredJob2.Tasks[1].UUID) + task1_1 := taskJobWorker1_1.Task + task1_2 := taskJobWorker1_2.Task + task2_1 := taskJobWorker2_1.Task worker1 := createWorker(ctx, t, db) worker2 := createWorkerFrom(ctx, t, db, *worker1) // Store some failures for different tasks and jobs - _, _ = db.AddWorkerToTaskFailedList(ctx, task1_1, worker1) - _, _ = db.AddWorkerToTaskFailedList(ctx, task1_1, worker2) - _, _ = db.AddWorkerToTaskFailedList(ctx, task1_2, worker1) - _, _ = db.AddWorkerToTaskFailedList(ctx, task2_1, worker1) - _, _ = db.AddWorkerToTaskFailedList(ctx, task2_1, worker2) + _, _ = db.AddWorkerToTaskFailedList(ctx, &task1_1, worker1) + _, _ = db.AddWorkerToTaskFailedList(ctx, &task1_1, worker2) + _, _ = db.AddWorkerToTaskFailedList(ctx, &task1_2, worker1) + _, _ = db.AddWorkerToTaskFailedList(ctx, &task2_1, worker1) + _, _ = db.AddWorkerToTaskFailedList(ctx, &task2_1, worker2) // Sanity check: there should be 5 failures registered now. assert.Equal(t, 5, countTaskFailures(ctx, db)) @@ -815,16 +826,18 @@ func TestFetchTaskFailureList(t *testing.T) { defer close() // Test with non-existing task. - fakeTask := Task{Model: Model{ID: 327}} + fakeTask := Task{ID: 327} failures, err := db.FetchTaskFailureList(ctx, &fakeTask) require.NoError(t, err) assert.Empty(t, failures) - task1_1, _ := db.FetchTask(ctx, authoredJob1.Tasks[1].UUID) - task1_2, _ := db.FetchTask(ctx, authoredJob1.Tasks[2].UUID) + taskJobWorker1_1, _ := db.FetchTask(ctx, authoredJob1.Tasks[1].UUID) + taskJobWorker1_2, _ := db.FetchTask(ctx, authoredJob1.Tasks[2].UUID) + task1_1 := taskJobWorker1_1.Task + task1_2 := taskJobWorker1_2.Task // Test without failures. - failures, err = db.FetchTaskFailureList(ctx, task1_1) + failures, err = db.FetchTaskFailureList(ctx, &task1_1) require.NoError(t, err) assert.Empty(t, failures) @@ -832,12 +845,12 @@ func TestFetchTaskFailureList(t *testing.T) { worker2 := createWorkerFrom(ctx, t, db, *worker1) // Store some failures for different tasks and jobs - _, _ = db.AddWorkerToTaskFailedList(ctx, task1_1, worker1) - _, _ = db.AddWorkerToTaskFailedList(ctx, task1_1, worker2) - _, _ = db.AddWorkerToTaskFailedList(ctx, task1_2, worker1) + _, _ = db.AddWorkerToTaskFailedList(ctx, &task1_1, worker1) + _, _ = db.AddWorkerToTaskFailedList(ctx, &task1_1, worker2) + _, _ = db.AddWorkerToTaskFailedList(ctx, &task1_2, worker1) // Fetch one task's failure list. 
- failures, err = db.FetchTaskFailureList(ctx, task1_1) + failures, err = db.FetchTaskFailureList(ctx, &task1_1) require.NoError(t, err) if assert.Len(t, failures, 2) { diff --git a/internal/manager/persistence/last_rendered.go b/internal/manager/persistence/last_rendered.go index f837ed2a..ccc36f46 100644 --- a/internal/manager/persistence/last_rendered.go +++ b/internal/manager/persistence/last_rendered.go @@ -15,14 +15,19 @@ import ( // This is used to show the global last-rendered image in the web interface. // SetLastRendered sets this job as the one with the most recent rendered image. -func (db *DB) SetLastRendered(ctx context.Context, j *Job) error { +func (db *DB) SetLastRendered(ctx context.Context, jobUUID string) error { queries := db.queries() + jobID, err := queries.FetchJobIDFromUUID(ctx, jobUUID) + if err != nil { + return jobError(err, "finding job with UUID %q", jobUUID) + } + now := db.nowNullable() return queries.SetLastRendered(ctx, sqlc.SetLastRenderedParams{ CreatedAt: now.Time, UpdatedAt: now, - JobID: int64(j.ID), + JobID: jobID, }) } diff --git a/internal/manager/persistence/last_rendered_test.go b/internal/manager/persistence/last_rendered_test.go index 7b997f20..b9647827 100644 --- a/internal/manager/persistence/last_rendered_test.go +++ b/internal/manager/persistence/last_rendered_test.go @@ -17,7 +17,7 @@ func TestSetLastRendered(t *testing.T) { authoredJob2 := authorTestJob("1295757b-e668-4c49-8b89-f73db8270e42", "just-a-job") job2 := persistAuthoredJob(t, ctx, db, authoredJob2) - require.NoError(t, db.SetLastRendered(ctx, job1)) + require.NoError(t, db.SetLastRendered(ctx, job1.UUID)) { entries, err := queries.Test_FetchLastRendered(ctx) require.NoError(t, err) @@ -26,7 +26,7 @@ func TestSetLastRendered(t *testing.T) { } } - require.NoError(t, db.SetLastRendered(ctx, job2)) + require.NoError(t, db.SetLastRendered(ctx, job2.UUID)) { entries, err := queries.Test_FetchLastRendered(ctx) require.NoError(t, err) @@ -49,7 +49,7 @@ func TestGetLastRenderedJobUUID(t *testing.T) { { // Test with first render. - require.NoError(t, db.SetLastRendered(ctx, job1)) + require.NoError(t, db.SetLastRendered(ctx, job1.UUID)) lastUUID, err := db.GetLastRenderedJobUUID(ctx) require.NoError(t, err) assert.Equal(t, job1.UUID, lastUUID) @@ -60,7 +60,7 @@ func TestGetLastRenderedJobUUID(t *testing.T) { authoredJob2 := authorTestJob("1295757b-e668-4c49-8b89-f73db8270e42", "just-a-job") job2 := persistAuthoredJob(t, ctx, db, authoredJob2) - require.NoError(t, db.SetLastRendered(ctx, job2)) + require.NoError(t, db.SetLastRendered(ctx, job2.UUID)) lastUUID, err := db.GetLastRenderedJobUUID(ctx) require.NoError(t, err) assert.Equal(t, job2.UUID, lastUUID) diff --git a/internal/manager/persistence/sqlc/methods.go b/internal/manager/persistence/sqlc/methods.go index 445af587..7b461737 100644 --- a/internal/manager/persistence/sqlc/methods.go +++ b/internal/manager/persistence/sqlc/methods.go @@ -20,6 +20,11 @@ func (ss *SleepSchedule) SetNextCheck(nextCheck time.Time) { } } +// DeleteRequested returns whether deletion of this job was requested. +func (j *Job) DeleteRequested() bool { + return j.DeleteRequestedAt.Valid +} + func (w *Worker) Identifier() string { // Avoid a panic when worker.Identifier() is called on a nil pointer. 
if w == nil { diff --git a/internal/manager/persistence/sqlc/query_jobs.sql b/internal/manager/persistence/sqlc/query_jobs.sql index 968af2e1..4e89e838 100644 --- a/internal/manager/persistence/sqlc/query_jobs.sql +++ b/internal/manager/persistence/sqlc/query_jobs.sql @@ -117,18 +117,19 @@ LEFT JOIN workers ON (tasks.worker_id = workers.id) WHERE tasks.uuid = @uuid; -- name: FetchTasksOfWorkerInStatus :many -SELECT sqlc.embed(tasks), jobs.UUID as jobUUID +SELECT sqlc.embed(tasks), jobs.uuid as jobuuid FROM tasks -LEFT JOIN jobs ON (tasks.job_id = jobs.id) +INNER JOIN jobs ON (tasks.job_id = jobs.id) WHERE tasks.worker_id = @worker_id AND tasks.status = @task_status; -- name: FetchTasksOfWorkerInStatusOfJob :many SELECT sqlc.embed(tasks) FROM tasks +LEFT JOIN jobs ON (tasks.job_id = jobs.id) WHERE tasks.worker_id = @worker_id - AND tasks.job_id = @job_id - AND tasks.status = @task_status; + AND tasks.status = @task_status + AND jobs.uuid = @jobuuid; -- name: FetchTasksOfJob :many SELECT sqlc.embed(tasks), workers.UUID as workerUUID @@ -199,7 +200,7 @@ WHERE id=@id; UPDATE tasks SET updated_at = @updated_at, last_touched_at = @last_touched_at -WHERE id=@id; +WHERE uuid=@uuid; -- name: JobCountTasksInStatus :one -- Fetch number of tasks in the given status, of the given job. @@ -235,6 +236,13 @@ SELECT sqlc.embed(workers) FROM workers INNER JOIN task_failures TF on TF.worker_id=workers.id WHERE TF.task_id=@task_id; +-- name: FetchJobIDFromUUID :one +-- Fetch the job's database ID by its UUID. +-- +-- This query is here to keep the SetLastRendered query below simpler, +-- mostly because that query is alread hitting a limitation of sqlc. +SELECT id FROM jobs WHERE uuid=@jobuuid; + -- name: SetLastRendered :exec -- Set the 'last rendered' job info. -- @@ -309,9 +317,10 @@ AND WTM.worker_tag_id = @worker_tag_id; -- name: CountTaskFailuresOfWorker :one SELECT count(TF.task_id) FROM task_failures TF INNER JOIN tasks T ON TF.task_id = T.id +INNER JOIN jobs J ON T.job_id = J.id WHERE TF.worker_id = @worker_id -AND T.job_id = @job_id +AND J.uuid = @jobuuid AND T.type = @task_type; @@ -326,11 +335,17 @@ SELECT status, count(id) as status_count FROM jobs GROUP BY status; -- name: FetchTimedOutTasks :many -SELECT * +SELECT sqlc.embed(tasks), + -- Cast to remove nullability from the generated structs. + CAST(jobs.uuid AS VARCHAR(36)) as jobuuid, + CAST(workers.name AS VARCHAR(64)) as worker_name, + CAST(workers.uuid AS VARCHAR(36)) as workeruuid FROM tasks +LEFT JOIN jobs ON jobs.id = tasks.job_id +LEFT JOIN workers ON workers.id = tasks.worker_id WHERE - status = @task_status -AND last_touched_at <= @untouched_since; + tasks.status = @task_status +AND tasks.last_touched_at <= @untouched_since; -- name: Test_CountJobs :one -- Count the number of jobs in the database. Only used in unit tests. 
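These query changes move the hot paths from numeric IDs to UUIDs. The sketch below is purely illustrative and not part of the patch: markTaskAlive and jobIDForUUID are invented names, and the timestamp handling is simplified. It shows how the regenerated sqlc methods would be called for the UUID-keyed touch-update and for the UUID-to-ID bridge that SetLastRendered relies on:

package example // hypothetical example, not part of the patch

import (
	"context"
	"database/sql"
	"time"

	"projects.blender.org/studio/flamenco/internal/manager/persistence/sqlc"
)

// markTaskAlive records a worker heartbeat on a task, addressing it by UUID.
func markTaskAlive(ctx context.Context, queries *sqlc.Queries, taskUUID string) error {
	now := sql.NullTime{Time: time.Now(), Valid: true}
	return queries.TaskTouchedByWorker(ctx, sqlc.TaskTouchedByWorkerParams{
		UpdatedAt:     now,
		LastTouchedAt: now,
		UUID:          taskUUID,
	})
}

// jobIDForUUID resolves a job's numeric database ID from its UUID,
// the same bridge used by SetLastRendered above.
func jobIDForUUID(ctx context.Context, queries *sqlc.Queries, jobUUID string) (int64, error) {
	return queries.FetchJobIDFromUUID(ctx, jobUUID)
}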
diff --git a/internal/manager/persistence/sqlc/query_jobs.sql.go b/internal/manager/persistence/sqlc/query_jobs.sql.go index 8960398e..87ef4770 100644 --- a/internal/manager/persistence/sqlc/query_jobs.sql.go +++ b/internal/manager/persistence/sqlc/query_jobs.sql.go @@ -89,20 +89,21 @@ func (q *Queries) ClearJobBlocklist(ctx context.Context, jobuuid string) error { const countTaskFailuresOfWorker = `-- name: CountTaskFailuresOfWorker :one SELECT count(TF.task_id) FROM task_failures TF INNER JOIN tasks T ON TF.task_id = T.id +INNER JOIN jobs J ON T.job_id = J.id WHERE TF.worker_id = ?1 -AND T.job_id = ?2 +AND J.uuid = ?2 AND T.type = ?3 ` type CountTaskFailuresOfWorkerParams struct { WorkerID int64 - JobID int64 + JobUUID string TaskType string } func (q *Queries) CountTaskFailuresOfWorker(ctx context.Context, arg CountTaskFailuresOfWorkerParams) (int64, error) { - row := q.db.QueryRowContext(ctx, countTaskFailuresOfWorker, arg.WorkerID, arg.JobID, arg.TaskType) + row := q.db.QueryRowContext(ctx, countTaskFailuresOfWorker, arg.WorkerID, arg.JobUUID, arg.TaskType) var count int64 err := row.Scan(&count) return count, err @@ -352,6 +353,21 @@ func (q *Queries) FetchJobByID(ctx context.Context, id int64) (Job, error) { return i, err } +const fetchJobIDFromUUID = `-- name: FetchJobIDFromUUID :one +SELECT id FROM jobs WHERE uuid=?1 +` + +// Fetch the job's database ID by its UUID. +// +// This query is here to keep the SetLastRendered query below simpler, +// mostly because that query is alread hitting a limitation of sqlc. +func (q *Queries) FetchJobIDFromUUID(ctx context.Context, jobuuid string) (int64, error) { + row := q.db.QueryRowContext(ctx, fetchJobIDFromUUID, jobuuid) + var id int64 + err := row.Scan(&id) + return id, err +} + const fetchJobShamanCheckoutID = `-- name: FetchJobShamanCheckoutID :one SELECT storage_shaman_checkout_id FROM jobs WHERE uuid=?1 ` @@ -736,9 +752,9 @@ func (q *Queries) FetchTasksOfJobInStatus(ctx context.Context, arg FetchTasksOfJ } const fetchTasksOfWorkerInStatus = `-- name: FetchTasksOfWorkerInStatus :many -SELECT tasks.id, tasks.created_at, tasks.updated_at, tasks.uuid, tasks.name, tasks.type, tasks.job_id, tasks.index_in_job, tasks.priority, tasks.status, tasks.worker_id, tasks.last_touched_at, tasks.commands, tasks.activity, jobs.UUID as jobUUID +SELECT tasks.id, tasks.created_at, tasks.updated_at, tasks.uuid, tasks.name, tasks.type, tasks.job_id, tasks.index_in_job, tasks.priority, tasks.status, tasks.worker_id, tasks.last_touched_at, tasks.commands, tasks.activity, jobs.uuid as jobuuid FROM tasks -LEFT JOIN jobs ON (tasks.job_id = jobs.id) +INNER JOIN jobs ON (tasks.job_id = jobs.id) WHERE tasks.worker_id = ?1 AND tasks.status = ?2 ` @@ -750,7 +766,7 @@ type FetchTasksOfWorkerInStatusParams struct { type FetchTasksOfWorkerInStatusRow struct { Task Task - JobUUID sql.NullString + JobUUID string } func (q *Queries) FetchTasksOfWorkerInStatus(ctx context.Context, arg FetchTasksOfWorkerInStatusParams) ([]FetchTasksOfWorkerInStatusRow, error) { @@ -795,15 +811,16 @@ func (q *Queries) FetchTasksOfWorkerInStatus(ctx context.Context, arg FetchTasks const fetchTasksOfWorkerInStatusOfJob = `-- name: FetchTasksOfWorkerInStatusOfJob :many SELECT tasks.id, tasks.created_at, tasks.updated_at, tasks.uuid, tasks.name, tasks.type, tasks.job_id, tasks.index_in_job, tasks.priority, tasks.status, tasks.worker_id, tasks.last_touched_at, tasks.commands, tasks.activity FROM tasks +LEFT JOIN jobs ON (tasks.job_id = jobs.id) WHERE tasks.worker_id = ?1 - AND tasks.job_id = ?2 - 
AND tasks.status = ?3 + AND tasks.status = ?2 + AND jobs.uuid = ?3 ` type FetchTasksOfWorkerInStatusOfJobParams struct { WorkerID sql.NullInt64 - JobID int64 TaskStatus api.TaskStatus + JobUUID string } type FetchTasksOfWorkerInStatusOfJobRow struct { @@ -811,7 +828,7 @@ type FetchTasksOfWorkerInStatusOfJobRow struct { } func (q *Queries) FetchTasksOfWorkerInStatusOfJob(ctx context.Context, arg FetchTasksOfWorkerInStatusOfJobParams) ([]FetchTasksOfWorkerInStatusOfJobRow, error) { - rows, err := q.db.QueryContext(ctx, fetchTasksOfWorkerInStatusOfJob, arg.WorkerID, arg.JobID, arg.TaskStatus) + rows, err := q.db.QueryContext(ctx, fetchTasksOfWorkerInStatusOfJob, arg.WorkerID, arg.TaskStatus, arg.JobUUID) if err != nil { return nil, err } @@ -849,11 +866,17 @@ func (q *Queries) FetchTasksOfWorkerInStatusOfJob(ctx context.Context, arg Fetch } const fetchTimedOutTasks = `-- name: FetchTimedOutTasks :many -SELECT id, created_at, updated_at, uuid, name, type, job_id, index_in_job, priority, status, worker_id, last_touched_at, commands, activity +SELECT tasks.id, tasks.created_at, tasks.updated_at, tasks.uuid, tasks.name, tasks.type, tasks.job_id, tasks.index_in_job, tasks.priority, tasks.status, tasks.worker_id, tasks.last_touched_at, tasks.commands, tasks.activity, + -- Cast to remove nullability from the generated structs. + CAST(jobs.uuid AS VARCHAR(36)) as jobuuid, + CAST(workers.name AS VARCHAR(64)) as worker_name, + CAST(workers.uuid AS VARCHAR(36)) as workeruuid FROM tasks +LEFT JOIN jobs ON jobs.id = tasks.job_id +LEFT JOIN workers ON workers.id = tasks.worker_id WHERE - status = ?1 -AND last_touched_at <= ?2 + tasks.status = ?1 +AND tasks.last_touched_at <= ?2 ` type FetchTimedOutTasksParams struct { @@ -861,30 +884,40 @@ type FetchTimedOutTasksParams struct { UntouchedSince sql.NullTime } -func (q *Queries) FetchTimedOutTasks(ctx context.Context, arg FetchTimedOutTasksParams) ([]Task, error) { +type FetchTimedOutTasksRow struct { + Task Task + JobUUID string + WorkerName string + WorkerUUID string +} + +func (q *Queries) FetchTimedOutTasks(ctx context.Context, arg FetchTimedOutTasksParams) ([]FetchTimedOutTasksRow, error) { rows, err := q.db.QueryContext(ctx, fetchTimedOutTasks, arg.TaskStatus, arg.UntouchedSince) if err != nil { return nil, err } defer rows.Close() - var items []Task + var items []FetchTimedOutTasksRow for rows.Next() { - var i Task + var i FetchTimedOutTasksRow if err := rows.Scan( - &i.ID, - &i.CreatedAt, - &i.UpdatedAt, - &i.UUID, - &i.Name, - &i.Type, - &i.JobID, - &i.IndexInJob, - &i.Priority, - &i.Status, - &i.WorkerID, - &i.LastTouchedAt, - &i.Commands, - &i.Activity, + &i.Task.ID, + &i.Task.CreatedAt, + &i.Task.UpdatedAt, + &i.Task.UUID, + &i.Task.Name, + &i.Task.Type, + &i.Task.JobID, + &i.Task.IndexInJob, + &i.Task.Priority, + &i.Task.Status, + &i.Task.WorkerID, + &i.Task.LastTouchedAt, + &i.Task.Commands, + &i.Task.Activity, + &i.JobUUID, + &i.WorkerName, + &i.WorkerUUID, ); err != nil { return nil, err } @@ -1221,17 +1254,17 @@ const taskTouchedByWorker = `-- name: TaskTouchedByWorker :exec UPDATE tasks SET updated_at = ?1, last_touched_at = ?2 -WHERE id=?3 +WHERE uuid=?3 ` type TaskTouchedByWorkerParams struct { UpdatedAt sql.NullTime LastTouchedAt sql.NullTime - ID int64 + UUID string } func (q *Queries) TaskTouchedByWorker(ctx context.Context, arg TaskTouchedByWorkerParams) error { - _, err := q.db.ExecContext(ctx, taskTouchedByWorker, arg.UpdatedAt, arg.LastTouchedAt, arg.ID) + _, err := q.db.ExecContext(ctx, taskTouchedByWorker, arg.UpdatedAt, 
arg.LastTouchedAt, arg.UUID) return err } diff --git a/internal/manager/persistence/sqlc/query_task_scheduler.sql b/internal/manager/persistence/sqlc/query_task_scheduler.sql index f88d7f3c..f80b7e91 100644 --- a/internal/manager/persistence/sqlc/query_task_scheduler.sql +++ b/internal/manager/persistence/sqlc/query_task_scheduler.sql @@ -1,7 +1,7 @@ -- name: FetchAssignedAndRunnableTaskOfWorker :one -- Fetch a task that's assigned to this worker, and is in a runnable state. -SELECT sqlc.embed(tasks) +SELECT sqlc.embed(tasks), jobs.uuid as jobuuid, jobs.priority as job_priority, jobs.job_type as job_type FROM tasks INNER JOIN jobs ON tasks.job_id = jobs.id WHERE tasks.status=@active_task_status @@ -20,7 +20,7 @@ LIMIT 1; -- -- The order in the WHERE clause is important, slices should come last. See -- https://github.com/sqlc-dev/sqlc/issues/2452 for more info. -SELECT sqlc.embed(tasks) +SELECT sqlc.embed(tasks), jobs.uuid as jobuuid, jobs.priority as job_priority, jobs.job_type as job_type FROM tasks INNER JOIN jobs ON tasks.job_id = jobs.id LEFT JOIN task_failures TF ON tasks.id = TF.task_id AND TF.worker_id=@worker_id @@ -51,8 +51,8 @@ ORDER BY jobs.priority DESC, tasks.priority DESC; -- Find the currently-active task assigned to a Worker. If not found, find the last task this Worker worked on. SELECT sqlc.embed(tasks), - sqlc.embed(jobs), - (tasks.status = @task_status_active AND jobs.status = @job_status_active) as is_active + jobs.uuid as jobuuid, + CAST(tasks.status = @task_status_active AND jobs.status = @job_status_active AS BOOLEAN) as is_active FROM tasks INNER JOIN jobs ON tasks.job_id = jobs.id WHERE diff --git a/internal/manager/persistence/sqlc/query_task_scheduler.sql.go b/internal/manager/persistence/sqlc/query_task_scheduler.sql.go index 7573ecb7..f83b5757 100644 --- a/internal/manager/persistence/sqlc/query_task_scheduler.sql.go +++ b/internal/manager/persistence/sqlc/query_task_scheduler.sql.go @@ -31,7 +31,7 @@ func (q *Queries) AssignTaskToWorker(ctx context.Context, arg AssignTaskToWorker } const fetchAssignedAndRunnableTaskOfWorker = `-- name: FetchAssignedAndRunnableTaskOfWorker :one -SELECT tasks.id, tasks.created_at, tasks.updated_at, tasks.uuid, tasks.name, tasks.type, tasks.job_id, tasks.index_in_job, tasks.priority, tasks.status, tasks.worker_id, tasks.last_touched_at, tasks.commands, tasks.activity +SELECT tasks.id, tasks.created_at, tasks.updated_at, tasks.uuid, tasks.name, tasks.type, tasks.job_id, tasks.index_in_job, tasks.priority, tasks.status, tasks.worker_id, tasks.last_touched_at, tasks.commands, tasks.activity, jobs.uuid as jobuuid, jobs.priority as job_priority, jobs.job_type as job_type FROM tasks INNER JOIN jobs ON tasks.job_id = jobs.id WHERE tasks.status=?1 @@ -47,7 +47,10 @@ type FetchAssignedAndRunnableTaskOfWorkerParams struct { } type FetchAssignedAndRunnableTaskOfWorkerRow struct { - Task Task + Task Task + JobUUID string + JobPriority int64 + JobType string } // Fetch a task that's assigned to this worker, and is in a runnable state. 
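Because the scheduler queries above now select the job's UUID, priority, and type alongside the task, a caller can describe or wrap the assigned task without a second lookup of the Job row. The following sketch is illustrative only: the function name and the runnableJobStatuses parameter are assumptions, the parameter field types follow the usage shown in task_scheduler.go below, and error handling is reduced to the no-rows case.

package example // hypothetical example, not part of the patch

import (
	"context"
	"database/sql"
	"errors"
	"fmt"

	"projects.blender.org/studio/flamenco/internal/manager/persistence/sqlc"
	"projects.blender.org/studio/flamenco/pkg/api"
)

// assignedTaskInfo returns a short description of the worker's already-assigned,
// runnable task, or "" when there is none.
func assignedTaskInfo(ctx context.Context, queries *sqlc.Queries, workerID int64, runnableJobStatuses []api.JobStatus) (string, error) {
	row, err := queries.FetchAssignedAndRunnableTaskOfWorker(ctx, sqlc.FetchAssignedAndRunnableTaskOfWorkerParams{
		ActiveTaskStatus:  api.TaskStatusActive,
		ActiveJobStatuses: runnableJobStatuses,
		WorkerID:          workerID,
	})
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return "", nil // No assigned task; not an error.
	case err != nil:
		return "", err
	}

	// The job's UUID, priority, and type come straight from the joined columns on the row.
	return fmt.Sprintf("task %s of %s job %s (job priority %d)",
		row.Task.UUID, row.JobType, row.JobUUID, row.JobPriority), nil
}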
@@ -81,6 +84,9 @@ func (q *Queries) FetchAssignedAndRunnableTaskOfWorker(ctx context.Context, arg &i.Task.LastTouchedAt, &i.Task.Commands, &i.Task.Activity, + &i.JobUUID, + &i.JobPriority, + &i.JobType, ) return i, err } @@ -88,8 +94,8 @@ func (q *Queries) FetchAssignedAndRunnableTaskOfWorker(ctx context.Context, arg const fetchWorkerTask = `-- name: FetchWorkerTask :one SELECT tasks.id, tasks.created_at, tasks.updated_at, tasks.uuid, tasks.name, tasks.type, tasks.job_id, tasks.index_in_job, tasks.priority, tasks.status, tasks.worker_id, tasks.last_touched_at, tasks.commands, tasks.activity, - jobs.id, jobs.created_at, jobs.updated_at, jobs.uuid, jobs.name, jobs.job_type, jobs.priority, jobs.status, jobs.activity, jobs.settings, jobs.metadata, jobs.delete_requested_at, jobs.storage_shaman_checkout_id, jobs.worker_tag_id, - (tasks.status = ?1 AND jobs.status = ?2) as is_active + jobs.uuid as jobuuid, + CAST(tasks.status = ?1 AND jobs.status = ?2 AS BOOLEAN) as is_active FROM tasks INNER JOIN jobs ON tasks.job_id = jobs.id WHERE @@ -108,8 +114,8 @@ type FetchWorkerTaskParams struct { type FetchWorkerTaskRow struct { Task Task - Job Job - IsActive interface{} + JobUUID string + IsActive bool } // Find the currently-active task assigned to a Worker. If not found, find the last task this Worker worked on. @@ -131,27 +137,14 @@ func (q *Queries) FetchWorkerTask(ctx context.Context, arg FetchWorkerTaskParams &i.Task.LastTouchedAt, &i.Task.Commands, &i.Task.Activity, - &i.Job.ID, - &i.Job.CreatedAt, - &i.Job.UpdatedAt, - &i.Job.UUID, - &i.Job.Name, - &i.Job.JobType, - &i.Job.Priority, - &i.Job.Status, - &i.Job.Activity, - &i.Job.Settings, - &i.Job.Metadata, - &i.Job.DeleteRequestedAt, - &i.Job.StorageShamanCheckoutID, - &i.Job.WorkerTagID, + &i.JobUUID, &i.IsActive, ) return i, err } const findRunnableTask = `-- name: FindRunnableTask :one -SELECT tasks.id, tasks.created_at, tasks.updated_at, tasks.uuid, tasks.name, tasks.type, tasks.job_id, tasks.index_in_job, tasks.priority, tasks.status, tasks.worker_id, tasks.last_touched_at, tasks.commands, tasks.activity +SELECT tasks.id, tasks.created_at, tasks.updated_at, tasks.uuid, tasks.name, tasks.type, tasks.job_id, tasks.index_in_job, tasks.priority, tasks.status, tasks.worker_id, tasks.last_touched_at, tasks.commands, tasks.activity, jobs.uuid as jobuuid, jobs.priority as job_priority, jobs.job_type as job_type FROM tasks INNER JOIN jobs ON tasks.job_id = jobs.id LEFT JOIN task_failures TF ON tasks.id = TF.task_id AND TF.worker_id=?1 @@ -189,7 +182,10 @@ type FindRunnableTaskParams struct { } type FindRunnableTaskRow struct { - Task Task + Task Task + JobUUID string + JobPriority int64 + JobType string } // Find a task to be run by a worker. This is the core of the task scheduler. @@ -255,6 +251,9 @@ func (q *Queries) FindRunnableTask(ctx context.Context, arg FindRunnableTaskPara &i.Task.LastTouchedAt, &i.Task.Commands, &i.Task.Activity, + &i.JobUUID, + &i.JobPriority, + &i.JobType, ) return i, err } diff --git a/internal/manager/persistence/task_scheduler.go b/internal/manager/persistence/task_scheduler.go index 39903b0f..5179b65e 100644 --- a/internal/manager/persistence/task_scheduler.go +++ b/internal/manager/persistence/task_scheduler.go @@ -22,10 +22,22 @@ var ( // completedTaskStatuses = []api.TaskStatus{api.TaskStatusCompleted} ) +// ScheduledTask contains a Task and some info about its job. +// +// This structure is returned from different points in the code below, and +// filled from different sqlc-generated structs. 
That's why it has to be an +// explicit struct here, rather than an alias for some sqlc struct. +type ScheduledTask struct { + Task Task + JobUUID string + JobPriority int64 + JobType string +} + // ScheduleTask finds a task to execute by the given worker. // If no task is available, (nil, nil) is returned, as this is not an error situation. // NOTE: this does not also fetch returnedTask.Worker, but returnedTask.WorkerID is set. -func (db *DB) ScheduleTask(ctx context.Context, w *Worker) (*Task, error) { +func (db *DB) ScheduleTask(ctx context.Context, w *Worker) (*ScheduledTask, error) { logger := log.With().Str("worker", w.UUID).Logger() logger.Trace().Msg("finding task for worker") @@ -41,32 +53,27 @@ func (db *DB) ScheduleTask(ctx context.Context, w *Worker) (*Task, error) { defer qtx.rollback() - task, err := db.scheduleTask(ctx, qtx.queries, w, logger) + scheduledTask, err := db.scheduleTask(ctx, qtx.queries, w, logger) if err != nil { return nil, err } - if task == nil { + if scheduledTask == nil { // No task means no changes to the database. // It's fine to just roll back the transaction. return nil, nil } - gormTask, err := convertSqlTaskWithJobAndWorker(ctx, qtx.queries, *task) - if err != nil { - return nil, err - } - if err := qtx.commit(); err != nil { return nil, fmt.Errorf( "could not commit database transaction after scheduling task %s for worker %s: %w", - task.UUID, w.UUID, err) + scheduledTask.Task.UUID, w.UUID, err) } - return gormTask, nil + return scheduledTask, nil } -func (db *DB) scheduleTask(ctx context.Context, queries *sqlc.Queries, w *Worker, logger zerolog.Logger) (*sqlc.Task, error) { +func (db *DB) scheduleTask(ctx context.Context, queries *sqlc.Queries, w *Worker, logger zerolog.Logger) (*ScheduledTask, error) { if w.ID == 0 { panic("worker should be in database, but has zero ID") } @@ -76,11 +83,12 @@ func (db *DB) scheduleTask(ctx context.Context, queries *sqlc.Queries, w *Worker // Note that this task type could be blocklisted or no longer supported by the // Worker, but since it's active that is unlikely. { - row, err := queries.FetchAssignedAndRunnableTaskOfWorker(ctx, sqlc.FetchAssignedAndRunnableTaskOfWorkerParams{ - ActiveTaskStatus: api.TaskStatusActive, - ActiveJobStatuses: schedulableJobStatuses, - WorkerID: workerID, - }) + row, err := queries.FetchAssignedAndRunnableTaskOfWorker( + ctx, sqlc.FetchAssignedAndRunnableTaskOfWorkerParams{ + ActiveTaskStatus: api.TaskStatusActive, + ActiveJobStatuses: schedulableJobStatuses, + WorkerID: workerID, + }) switch { case errors.Is(err, sql.ErrNoRows): @@ -88,11 +96,18 @@ func (db *DB) scheduleTask(ctx context.Context, queries *sqlc.Queries, w *Worker case err != nil: return nil, err case row.Task.ID > 0: - return &row.Task, nil + // Task was previously assigned, just go for it again. + scheduledTask := ScheduledTask{ + Task: row.Task, + JobUUID: row.JobUUID, + JobPriority: row.JobPriority, + JobType: row.JobType, + } + return &scheduledTask, nil } } - task, err := findTaskForWorker(ctx, queries, w) + scheduledTask, err := findTaskForWorker(ctx, queries, w) switch { case errors.Is(err, sql.ErrNoRows): @@ -107,10 +122,11 @@ func (db *DB) scheduleTask(ctx context.Context, queries *sqlc.Queries, w *Worker } // Assign the task to the worker. 
+ assignmentTimestamp := db.nowNullable() err = queries.AssignTaskToWorker(ctx, sqlc.AssignTaskToWorkerParams{ WorkerID: workerID, - Now: db.nowNullable(), - TaskID: task.ID, + TaskID: scheduledTask.Task.ID, + Now: assignmentTimestamp, }) switch { @@ -119,32 +135,33 @@ func (db *DB) scheduleTask(ctx context.Context, queries *sqlc.Queries, w *Worker return nil, errDatabaseBusy case err != nil: logger.Warn(). - Str("taskID", task.UUID). + Str("taskID", scheduledTask.Task.UUID). Err(err). Msg("assigning task to worker") return nil, fmt.Errorf("assigning task to worker: %w", err) } // Make sure the returned task matches the database. - task.WorkerID = workerID + scheduledTask.Task.WorkerID = workerID + scheduledTask.Task.UpdatedAt = assignmentTimestamp logger.Info(). - Str("taskID", task.UUID). + Str("taskID", scheduledTask.Task.UUID). Msg("assigned task to worker") - return &task, nil + return scheduledTask, nil } func findTaskForWorker( ctx context.Context, queries *sqlc.Queries, w *Worker, -) (sqlc.Task, error) { +) (*ScheduledTask, error) { // Construct the list of worker tag IDs to check. tags, err := queries.FetchTagsOfWorker(ctx, w.UUID) if err != nil { - return sqlc.Task{}, err + return nil, err } workerTags := make([]sql.NullInt64, len(tags)) for index, tag := range tags { @@ -160,10 +177,17 @@ func findTaskForWorker( WorkerTags: workerTags, }) if err != nil { - return sqlc.Task{}, err + return nil, err } if row.Task.ID == 0 { - return sqlc.Task{}, nil + return nil, nil } - return row.Task, nil + + scheduledTask := ScheduledTask{ + Task: row.Task, + JobUUID: row.JobUUID, + JobPriority: row.JobPriority, + JobType: row.JobType, + } + return &scheduledTask, nil } diff --git a/internal/manager/persistence/task_scheduler_test.go b/internal/manager/persistence/task_scheduler_test.go index 3e1b4035..04f0d9f3 100644 --- a/internal/manager/persistence/task_scheduler_test.go +++ b/internal/manager/persistence/task_scheduler_test.go @@ -17,7 +17,7 @@ import ( "projects.blender.org/studio/flamenco/pkg/api" ) -const schedulerTestTimeout = 100 * time.Millisecond +const schedulerTestTimeout = 100 * time.Hour const schedulerTestTimeoutlong = 5000 * time.Millisecond func TestNoTasks(t *testing.T) { @@ -26,8 +26,8 @@ func TestNoTasks(t *testing.T) { w := linuxWorker(t, db) - task, err := db.ScheduleTask(ctx, &w) - assert.Nil(t, task) + scheduledTask, err := db.ScheduleTask(ctx, &w) + assert.Nil(t, scheduledTask) require.NoError(t, err) } @@ -41,24 +41,25 @@ func TestOneJobOneTask(t *testing.T) { atj := authorTestJob("b6a1d859-122f-4791-8b78-b943329a9989", "simple-blender-render", authTask) job := constructTestJob(ctx, t, db, atj) - task, err := db.ScheduleTask(ctx, &w) + scheduledTask, err := db.ScheduleTask(ctx, &w) require.NoError(t, err) // Check the returned task. - require.NotNil(t, task) - assert.Equal(t, job.ID, task.JobID) - require.NotNil(t, task.WorkerID, "no worker assigned to returned task") - assert.Equal(t, w.ID, int64(*task.WorkerID), "task must be assigned to the requesting worker") + require.NotNil(t, scheduledTask) + assert.Equal(t, job.ID, scheduledTask.Task.JobID) + require.True(t, scheduledTask.Task.WorkerID.Valid, "no worker assigned to returned task") + assert.Equal(t, w.ID, scheduledTask.Task.WorkerID.Int64, "task must be assigned to the requesting worker") // Check the task in the database. 
now := db.now() - dbTask, err := db.FetchTask(context.Background(), authTask.UUID) + taskJobWorker, err := db.FetchTask(context.Background(), authTask.UUID) require.NoError(t, err) - require.NotNil(t, dbTask) - require.NotNil(t, dbTask.WorkerID, "no worker assigned to task in database") + require.NotNil(t, taskJobWorker) + require.True(t, taskJobWorker.Task.WorkerID.Valid, "no worker assigned to task in database") + require.NotZero(t, taskJobWorker.WorkerUUID, "no worker fetched from database even though one is assigned to the task") - assert.Equal(t, w.ID, int64(*dbTask.WorkerID), "task must be assigned to the requesting worker") - assert.WithinDuration(t, now, dbTask.LastTouchedAt, time.Second, "task must be 'touched' by the worker after scheduling") + assert.Equal(t, w.ID, taskJobWorker.Task.WorkerID.Int64, "task must be assigned to the requesting worker") + assert.WithinDuration(t, now, taskJobWorker.Task.LastTouchedAt.Time, time.Second, "task must be 'touched' by the worker after scheduling") } func TestOneJobThreeTasksByPrio(t *testing.T) { @@ -78,14 +79,14 @@ func TestOneJobThreeTasksByPrio(t *testing.T) { job := constructTestJob(ctx, t, db, atj) - task, err := db.ScheduleTask(ctx, &w) + scheduledTask, err := db.ScheduleTask(ctx, &w) require.NoError(t, err) - require.NotNil(t, task) + require.NotNil(t, scheduledTask) - assert.Equal(t, job.ID, task.JobID) - assert.NotNil(t, task.Job) + assert.Equal(t, job.ID, scheduledTask.Task.JobID) + assert.Equal(t, job.UUID, scheduledTask.JobUUID) - assert.Equal(t, att2.Name, task.Name, "the high-prio task should have been chosen") + assert.Equal(t, att2.Name, scheduledTask.Task.Name, "the high-prio task should have been chosen") } func TestOneJobThreeTasksByDependencies(t *testing.T) { @@ -105,11 +106,12 @@ func TestOneJobThreeTasksByDependencies(t *testing.T) { att1, att2, att3) job := constructTestJob(ctx, t, db, atj) - task, err := db.ScheduleTask(ctx, &w) + scheduledTask, err := db.ScheduleTask(ctx, &w) require.NoError(t, err) - require.NotNil(t, task) - assert.Equal(t, job.ID, task.JobID) - assert.Equal(t, att1.Name, task.Name, "the first task should have been chosen") + require.NotNil(t, scheduledTask) + assert.Equal(t, job.ID, scheduledTask.Task.JobID) + assert.Equal(t, job.UUID, scheduledTask.JobUUID) + assert.Equal(t, att1.Name, scheduledTask.Task.Name, "the first task should have been chosen") } func TestTwoJobsThreeTasks(t *testing.T) { @@ -143,11 +145,11 @@ func TestTwoJobsThreeTasks(t *testing.T) { constructTestJob(ctx, t, db, atj1) job2 := constructTestJob(ctx, t, db, atj2) - task, err := db.ScheduleTask(ctx, &w) + scheduledTask, err := db.ScheduleTask(ctx, &w) require.NoError(t, err) - require.NotNil(t, task) - assert.Equal(t, job2.ID, task.JobID) - assert.Equal(t, att2_3.Name, task.Name, "the 3rd task of the 2nd job should have been chosen") + require.NotNil(t, scheduledTask) + assert.Equal(t, job2.ID, scheduledTask.Task.JobID) + assert.Equal(t, att2_3.Name, scheduledTask.Task.Name, "the 3rd task of the 2nd job should have been chosen") } // TestFanOutFanIn tests one starting task, then multiple tasks that depend on @@ -190,14 +192,14 @@ func TestFanOutFanIn(t *testing.T) { // Check the order in which tasks are handed out. executionOrder := []string{} // Slice of task names. 
for index := range 6 { - task, err := db.ScheduleTask(ctx, &w) + scheduledTask, err := db.ScheduleTask(ctx, &w) require.NoError(t, err) - require.NotNil(t, task, "task #%d is nil", index) - executionOrder = append(executionOrder, task.Name) + require.NotNil(t, scheduledTask, "task #%d is nil", index) + executionOrder = append(executionOrder, scheduledTask.Task.Name) // Fake that the task has been completed by the worker. - task.Status = api.TaskStatusCompleted - require.NoError(t, db.SaveTaskStatus(ctx, task)) + scheduledTask.Task.Status = api.TaskStatusCompleted + require.NoError(t, db.SaveTaskStatus(ctx, &scheduledTask.Task)) } expectedOrder := []string{ @@ -230,10 +232,10 @@ func TestSomeButNotAllDependenciesCompleted(t *testing.T) { setTaskStatus(t, db, att1.UUID, api.TaskStatusCompleted) w := linuxWorker(t, db) - task, err := db.ScheduleTask(ctx, &w) + scheduledTask, err := db.ScheduleTask(ctx, &w) require.NoError(t, err) - if task != nil { - t.Fatalf("there should not be any task assigned, but received %q", task.Name) + if scheduledTask != nil { + t.Fatalf("there should not be any task assigned, but received %q", scheduledTask.Task.Name) } } @@ -259,16 +261,16 @@ func TestAlreadyAssigned(t *testing.T) { // another, higher-prio task to be done. dbTask3, err := db.FetchTask(ctx, att3.UUID) require.NoError(t, err) - dbTask3.WorkerID = ptr(uint(w.ID)) - dbTask3.Status = api.TaskStatusActive - err = db.SaveTask(ctx, dbTask3) + dbTask3.Task.WorkerID = sql.NullInt64{Int64: w.ID, Valid: true} + dbTask3.Task.Status = api.TaskStatusActive + err = db.SaveTask(ctx, &dbTask3.Task) require.NoError(t, err) - task, err := db.ScheduleTask(ctx, &w) + scheduledTask, err := db.ScheduleTask(ctx, &w) require.NoError(t, err) - require.NotNil(t, task) + require.NotNil(t, scheduledTask) - assert.Equal(t, att3.Name, task.Name, "the already-assigned task should have been chosen") + assert.Equal(t, att3.Name, scheduledTask.Task.Name, "the already-assigned task should have been chosen") } func TestAssignedToOtherWorker(t *testing.T) { @@ -292,17 +294,17 @@ func TestAssignedToOtherWorker(t *testing.T) { // it shouldn't matter which worker it's assigned to. dbTask2, err := db.FetchTask(ctx, att2.UUID) require.NoError(t, err) - dbTask2.WorkerID = ptr(uint(w2.ID)) - dbTask2.Status = api.TaskStatusQueued - err = db.SaveTask(ctx, dbTask2) + dbTask2.Task.WorkerID = sql.NullInt64{Int64: w2.ID, Valid: true} + dbTask2.Task.Status = api.TaskStatusQueued + err = db.SaveTask(ctx, &dbTask2.Task) require.NoError(t, err) - task, err := db.ScheduleTask(ctx, &w) + scheduledTask, err := db.ScheduleTask(ctx, &w) require.NoError(t, err) - require.NotNil(t, task) + require.NotNil(t, scheduledTask) - assert.Equal(t, att2.Name, task.Name, "the high-prio task should have been chosen") - assert.Equal(t, int64(*task.WorkerID), w.ID, "the task should now be assigned to the worker it was scheduled for") + assert.Equal(t, att2.Name, scheduledTask.Task.Name, "the high-prio task should have been chosen") + assert.Equal(t, w.ID, scheduledTask.Task.WorkerID.Int64, "the task should now be assigned to the worker it was scheduled for") } func TestPreviouslyFailed(t *testing.T) { @@ -320,17 +322,17 @@ func TestPreviouslyFailed(t *testing.T) { job := constructTestJob(ctx, t, db, atj) // Mimick that this worker already failed the first task. 
- tasks, err := db.FetchTasksOfJob(ctx, job) + taskJobWorkers, err := db.FetchTasksOfJob(ctx, job) require.NoError(t, err) - numFailed, err := db.AddWorkerToTaskFailedList(ctx, tasks[0], &w) + numFailed, err := db.AddWorkerToTaskFailedList(ctx, &taskJobWorkers[0].Task, &w) require.NoError(t, err) assert.Equal(t, 1, numFailed) - // This should assign the 2nd task. - task, err := db.ScheduleTask(ctx, &w) + // This should assign the 2nd task. + scheduledTask, err := db.ScheduleTask(ctx, &w) require.NoError(t, err) - require.NotNil(t, task) - assert.Equal(t, att2.Name, task.Name, "the second task should have been chosen") + require.NotNil(t, scheduledTask) + assert.Equal(t, att2.Name, scheduledTask.Task.Name, "the second task should have been chosen") } func TestWorkerTagJobWithTag(t *testing.T) { @@ -369,14 +371,14 @@ func TestWorkerTagJobWithTag(t *testing.T) { job.WorkerTagUUID = tag1.UUID constructTestJob(ctx, t, db, job) - task, err := db.ScheduleTask(ctx, &workerC) + scheduledTask, err := db.ScheduleTask(ctx, &workerC) require.NoError(t, err) - require.NotNil(t, task, "job with matching tag should be scheduled") - assert.Equal(t, authTask.UUID, task.UUID) + require.NotNil(t, scheduledTask, "job with matching tag should be scheduled") + assert.Equal(t, authTask.UUID, scheduledTask.Task.UUID) - task, err = db.ScheduleTask(ctx, &workerNC) + scheduledTask, err = db.ScheduleTask(ctx, &workerNC) require.NoError(t, err) - assert.Nil(t, task, "job with tag should not be scheduled for worker without tag") + assert.Nil(t, scheduledTask, "job with tag should not be scheduled for worker without tag") } } @@ -402,15 +404,15 @@ func TestWorkerTagJobWithoutTag(t *testing.T) { job := authorTestJob("b6a1d859-122f-4791-8b78-b943329a9989", "simple-blender-render", authTask) constructTestJob(ctx, t, db, job) - task, err := db.ScheduleTask(ctx, &workerC) + scheduledTask, err := db.ScheduleTask(ctx, &workerC) require.NoError(t, err) - require.NotNil(t, task, "job without tag should always be scheduled to worker in some tag") - assert.Equal(t, authTask.UUID, task.UUID) + require.NotNil(t, scheduledTask, "job without tag should always be scheduled to worker in some tag") + assert.Equal(t, authTask.UUID, scheduledTask.Task.UUID) - task, err = db.ScheduleTask(ctx, &workerNC) + scheduledTask, err = db.ScheduleTask(ctx, &workerNC) require.NoError(t, err) - require.NotNil(t, task, "job without tag should always be scheduled to worker without tag") - assert.Equal(t, authTask.UUID, task.UUID) + require.NotNil(t, scheduledTask, "job without tag should always be scheduled to worker without tag") + assert.Equal(t, authTask.UUID, scheduledTask.Task.UUID) } func TestBlocklisted(t *testing.T) { @@ -428,14 +430,14 @@ func TestBlocklisted(t *testing.T) { job := constructTestJob(ctx, t, db, atj) // Mimick that this worker was already blocked for 'blender' tasks of this job. - err := db.AddWorkerToJobBlocklist(ctx, job, &w, "blender") + err := db.AddWorkerToJobBlocklist(ctx, job.ID, w.ID, "blender") require.NoError(t, err) - // This should assign the 2nd task. - task, err := db.ScheduleTask(ctx, &w) + // This should assign the 2nd task. 
+ scheduledTask, err := db.ScheduleTask(ctx, &w) require.NoError(t, err) - require.NotNil(t, task) - assert.Equal(t, att2.Name, task.Name, "the second task should have been chosen") + require.NotNil(t, scheduledTask) + assert.Equal(t, att2.Name, scheduledTask.Task.Name, "the second task should have been chosen") } // To test: blocklists @@ -486,12 +488,12 @@ func authorTestTask(name, taskType string, dependencies ...*job_compilers.Author func setTaskStatus(t *testing.T, db *DB, taskUUID string, status api.TaskStatus) { ctx := context.Background() - task, err := db.FetchTask(ctx, taskUUID) + taskJobWorker, err := db.FetchTask(ctx, taskUUID) require.NoError(t, err) - task.Status = status + taskJobWorker.Task.Status = status - require.NoError(t, db.SaveTask(ctx, task)) + require.NoError(t, db.SaveTask(ctx, &taskJobWorker.Task)) } func linuxWorker(t *testing.T, db *DB, updaters ...func(worker *Worker)) Worker { diff --git a/internal/manager/persistence/timeout.go b/internal/manager/persistence/timeout.go index af74d4a0..b31ba88a 100644 --- a/internal/manager/persistence/timeout.go +++ b/internal/manager/persistence/timeout.go @@ -21,16 +21,16 @@ var workerStatusNoTimeout = []api.WorkerStatus{ api.WorkerStatusOffline, } +type TimedOutTaskInfo = sqlc.FetchTimedOutTasksRow + // FetchTimedOutTasks returns a slice of tasks that have timed out. // // In order to time out, a task must be in status `active` and not touched by a // Worker since `untouchedSince`. -// -// The returned tasks also have their `Job` and `Worker` fields set. -func (db *DB) FetchTimedOutTasks(ctx context.Context, untouchedSince time.Time) ([]*Task, error) { +func (db *DB) FetchTimedOutTasks(ctx context.Context, untouchedSince time.Time) ([]TimedOutTaskInfo, error) { queries := db.queries() - sqlcTasks, err := queries.FetchTimedOutTasks(ctx, sqlc.FetchTimedOutTasksParams{ + timedOut, err := queries.FetchTimedOutTasks(ctx, sqlc.FetchTimedOutTasksParams{ TaskStatus: api.TaskStatusActive, UntouchedSince: sql.NullTime{Time: untouchedSince, Valid: true}, }) @@ -38,17 +38,7 @@ func (db *DB) FetchTimedOutTasks(ctx context.Context, untouchedSince time.Time) if err != nil { return nil, taskError(err, "finding timed out tasks (untouched since %s)", untouchedSince.String()) } - - result := make([]*Task, len(sqlcTasks)) - for index, task := range sqlcTasks { - gormTask, err := convertSqlTaskWithJobAndWorker(ctx, queries, task) - if err != nil { - return nil, err - } - result[index] = gormTask - } - - return result, nil + return timedOut, nil } func (db *DB) FetchTimedOutWorkers(ctx context.Context, lastSeenBefore time.Time) ([]*Worker, error) { diff --git a/internal/manager/persistence/timeout_test.go b/internal/manager/persistence/timeout_test.go index 344dbc4f..ba2ecc27 100644 --- a/internal/manager/persistence/timeout_test.go +++ b/internal/manager/persistence/timeout_test.go @@ -16,15 +16,15 @@ func TestFetchTimedOutTasks(t *testing.T) { ctx, close, db, job, _ := jobTasksTestFixtures(t) defer close() - tasks, err := db.FetchTasksOfJob(ctx, job) + tasksOfJob, err := db.FetchTasksOfJob(ctx, job) require.NoError(t, err) now := db.now() deadline := now.Add(-5 * time.Minute) // Mark the task as last touched before the deadline, i.e. old enough for a timeout. 
- task := tasks[0] - task.LastTouchedAt = deadline.Add(-1 * time.Minute) + task := &tasksOfJob[0].Task + task.LastTouchedAt = sql.NullTime{Time: deadline.Add(-1 * time.Minute), Valid: true} require.NoError(t, db.SaveTask(ctx, task)) w := createWorker(ctx, t, db) @@ -45,9 +45,10 @@ func TestFetchTimedOutTasks(t *testing.T) { if assert.Len(t, timedout, 1) { // Other fields will be different, like the 'UpdatedAt' field -- this just // tests that the expected task is returned. - assert.Equal(t, task.UUID, timedout[0].UUID) - assert.Equal(t, job, timedout[0].Job, "the job should be included in the result as well") - assert.Equal(t, w.UUID, timedout[0].Worker.UUID, "the worker should be included in the result as well") + assert.Equal(t, task.UUID, timedout[0].Task.UUID) + assert.Equal(t, job.UUID, timedout[0].JobUUID, "the job should be included in the result as well") + assert.Equal(t, w.UUID, timedout[0].WorkerUUID, "the worker UUID should be included in the result as well") + assert.Equal(t, w.Name, timedout[0].WorkerName, "the worker Name should be included in the result as well") } } diff --git a/internal/manager/persistence/worker_sleep_schedule.go b/internal/manager/persistence/worker_sleep_schedule.go index 798a766f..125238cc 100644 --- a/internal/manager/persistence/worker_sleep_schedule.go +++ b/internal/manager/persistence/worker_sleep_schedule.go @@ -16,6 +16,7 @@ import ( // sent to the 'asleep' and 'awake' states. type SleepSchedule = sqlc.SleepSchedule +// SleepScheduleOwned represents a sleep schedule + info about the worker that's controlled by it. type SleepScheduleOwned struct { SleepSchedule SleepSchedule WorkerName string diff --git a/internal/manager/persistence/worker_tag.go b/internal/manager/persistence/worker_tag.go index 26e0efbd..67563c1b 100644 --- a/internal/manager/persistence/worker_tag.go +++ b/internal/manager/persistence/worker_tag.go @@ -54,6 +54,11 @@ func (db *DB) FetchWorkerTag(ctx context.Context, uuid string) (WorkerTag, error return workerTag, nil } +func (db *DB) FetchWorkerTagByID(ctx context.Context, id int64) (WorkerTag, error) { + queries := db.queries() + return fetchWorkerTagByID(ctx, queries, id) +} + // fetchWorkerTagByID fetches the worker tag using the given database instance. func fetchWorkerTagByID(ctx context.Context, queries *sqlc.Queries, id int64) (WorkerTag, error) { workerTag, err := queries.FetchWorkerTagByID(ctx, id) diff --git a/internal/manager/persistence/workers.go b/internal/manager/persistence/workers.go index c633cf04..d20ca06d 100644 --- a/internal/manager/persistence/workers.go +++ b/internal/manager/persistence/workers.go @@ -99,7 +99,7 @@ func (db *DB) FetchWorkers(ctx context.Context) ([]*Worker, error) { } // FetchWorkerTask returns the most recent task assigned to the given Worker. -func (db *DB) FetchWorkerTask(ctx context.Context, worker *Worker) (*Task, error) { +func (db *DB) FetchWorkerTask(ctx context.Context, worker *Worker) (*TaskJob, error) { queries := db.queries() // Convert the WorkerID to a NullInt64. As task.worker_id can be NULL, this is @@ -119,27 +119,12 @@ func (db *DB) FetchWorkerTask(ctx context.Context, worker *Worker) (*Task, error return nil, taskError(err, "fetching task assigned to Worker %s", worker.UUID) } - // Found a task! 
- if row.Job.ID == 0 { - panic(fmt.Sprintf("task found but with no job: %#v", row)) + taskJob := TaskJob{ + Task: row.Task, + JobUUID: row.JobUUID, + IsActive: row.IsActive, } - if row.Task.ID == 0 { - panic(fmt.Sprintf("task found but with zero ID: %#v", row)) - } - - // Convert the task & job to gorm data types. - gormTask, err := convertSqlcTask(row.Task, row.Job.UUID, worker.UUID) - if err != nil { - return nil, err - } - gormJob, err := convertSqlcJob(row.Job) - if err != nil { - return nil, err - } - gormTask.Job = &gormJob - gormTask.Worker = worker - - return gormTask, nil + return &taskJob, nil } func (db *DB) SaveWorkerStatus(ctx context.Context, w *Worker) error { diff --git a/internal/manager/persistence/workers_test.go b/internal/manager/persistence/workers_test.go index 4a3cf52d..949889d3 100644 --- a/internal/manager/persistence/workers_test.go +++ b/internal/manager/persistence/workers_test.go @@ -95,26 +95,29 @@ func TestFetchWorkerTask(t *testing.T) { assignedTask, err := db.ScheduleTask(ctx, &w) require.NoError(t, err) - require.Equal(t, assignedTask.UUID, authTask1.UUID) + require.Equal(t, authTask1.UUID, assignedTask.Task.UUID) + require.Equal(t, jobUUID, assignedTask.JobUUID) + require.Equal(t, atj.JobType, assignedTask.JobType) + require.Equal(t, atj.Priority, int(assignedTask.JobPriority)) { // Assigned task should be returned. foundTask, err := db.FetchWorkerTask(ctx, &w) require.NoError(t, err) require.NotNil(t, foundTask) - assert.Equal(t, assignedTask.UUID, foundTask.UUID) - assert.Equal(t, jobUUID, foundTask.Job.UUID, "the job UUID should be returned as well") + assert.Equal(t, assignedTask.Task.UUID, foundTask.Task.UUID) + assert.Equal(t, assignedTask.JobUUID, foundTask.JobUUID, "the job UUID should be returned as well") } // Set the task to 'completed'. - assignedTask.Status = api.TaskStatusCompleted - require.NoError(t, db.SaveTaskStatus(ctx, assignedTask)) + assignedTask.Task.Status = api.TaskStatusCompleted + require.NoError(t, db.SaveTaskStatus(ctx, &assignedTask.Task)) { // Completed-but-last-assigned task should be returned. foundTask, err := db.FetchWorkerTask(ctx, &w) require.NoError(t, err) require.NotNil(t, foundTask) - assert.Equal(t, assignedTask.UUID, foundTask.UUID) - assert.Equal(t, jobUUID, foundTask.Job.UUID, "the job UUID should be returned as well") + assert.Equal(t, assignedTask.Task.UUID, foundTask.Task.UUID) + assert.Equal(t, jobUUID, foundTask.JobUUID, "the job UUID should be returned as well") } // Assign another task. Since the remainder of this test depends on the order @@ -125,26 +128,26 @@ func TestFetchWorkerTask(t *testing.T) { newlyAssignedTask, err := db.ScheduleTask(ctx, &w) require.NoError(t, err) require.NotNil(t, newlyAssignedTask) - require.Equal(t, newlyAssignedTask.UUID, authTask2.UUID) + require.Equal(t, newlyAssignedTask.Task.UUID, authTask2.UUID) { // Newly assigned task should be returned. foundTask, err := db.FetchWorkerTask(ctx, &w) require.NoError(t, err) require.NotNil(t, foundTask) - assert.Equal(t, newlyAssignedTask.UUID, foundTask.UUID) - assert.Equal(t, jobUUID, foundTask.Job.UUID, "the job UUID should be returned as well") + assert.Equal(t, newlyAssignedTask.Task.UUID, foundTask.Task.UUID) + assert.Equal(t, jobUUID, foundTask.JobUUID, "the job UUID should be returned as well") } // Set the new task to 'completed'. 
- newlyAssignedTask.Status = api.TaskStatusCompleted - require.NoError(t, db.SaveTaskStatus(ctx, newlyAssignedTask)) + newlyAssignedTask.Task.Status = api.TaskStatusCompleted + require.NoError(t, db.SaveTaskStatus(ctx, &newlyAssignedTask.Task)) { // Completed-but-last-assigned task should be returned. foundTask, err := db.FetchWorkerTask(ctx, &w) require.NoError(t, err) require.NotNil(t, foundTask) - assert.Equal(t, newlyAssignedTask.UUID, foundTask.UUID) - assert.Equal(t, jobUUID, foundTask.Job.UUID, "the job UUID should be returned as well") + assert.Equal(t, newlyAssignedTask.Task.UUID, foundTask.Task.UUID) + assert.Equal(t, jobUUID, foundTask.JobUUID, "the job UUID should be returned as well") } } @@ -309,10 +312,9 @@ func TestDeleteWorker(t *testing.T) { persistAuthoredJob(t, ctx, db, authJob) taskUUID := authJob.Tasks[0].UUID { - task, err := db.FetchTask(ctx, taskUUID) + taskJobWorker, err := db.FetchTask(ctx, taskUUID) require.NoError(t, err) - task.Worker = &w1 - require.NoError(t, db.SaveTask(ctx, task)) + require.NoError(t, db.TaskAssignToWorker(ctx, &taskJobWorker.Task, &w1)) } // Delete the worker. @@ -320,12 +322,11 @@ func TestDeleteWorker(t *testing.T) { // Check the task after deletion of the Worker. { - fetchedTask, err := db.FetchTask(ctx, taskUUID) + taskJobWorker, err := db.FetchTask(ctx, taskUUID) require.NoError(t, err) - assert.Equal(t, taskUUID, fetchedTask.UUID) - assert.Equal(t, w1.UUID, fetchedTask.Worker.UUID) - assert.NotZero(t, fetchedTask.Worker.DeletedAt.Time) - assert.True(t, fetchedTask.Worker.DeletedAt.Valid) + assert.Equal(t, taskUUID, taskJobWorker.Task.UUID) + assert.Equal(t, w1.UUID, taskJobWorker.WorkerUUID) + assert.Equal(t, authJob.JobID, taskJobWorker.JobUUID) } } diff --git a/internal/manager/sleep_scheduler/mocks/interfaces_mock.gen.go b/internal/manager/sleep_scheduler/mocks/interfaces_mock.gen.go index 82811c3e..e33025b6 100644 --- a/internal/manager/sleep_scheduler/mocks/interfaces_mock.gen.go +++ b/internal/manager/sleep_scheduler/mocks/interfaces_mock.gen.go @@ -37,20 +37,6 @@ func (m *MockPersistenceService) EXPECT() *MockPersistenceServiceMockRecorder { return m.recorder } -// CreateWorker mocks base method. -func (m *MockPersistenceService) CreateWorker(arg0 context.Context, arg1 *sqlc.Worker) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateWorker", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// CreateWorker indicates an expected call of CreateWorker. -func (mr *MockPersistenceServiceMockRecorder) CreateWorker(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateWorker", reflect.TypeOf((*MockPersistenceService)(nil).CreateWorker), arg0, arg1) -} - // FetchSleepScheduleWorker mocks base method. 
func (m *MockPersistenceService) FetchSleepScheduleWorker(arg0 context.Context, arg1 sqlc.SleepSchedule) (*sqlc.Worker, error) { m.ctrl.T.Helper() diff --git a/internal/manager/task_state_machine/interfaces.go b/internal/manager/task_state_machine/interfaces.go index 77c2db9f..f36af4b7 100644 --- a/internal/manager/task_state_machine/interfaces.go +++ b/internal/manager/task_state_machine/interfaces.go @@ -33,9 +33,11 @@ type PersistenceService interface { UpdateJobsTaskStatusesConditional(ctx context.Context, job *persistence.Job, statusesToUpdate []api.TaskStatus, taskStatus api.TaskStatus, activity string) error + FetchJob(ctx context.Context, jobUUID string) (*persistence.Job, error) + FetchJobByID(ctx context.Context, jobID int64) (*persistence.Job, error) FetchJobsInStatus(ctx context.Context, jobStatuses ...api.JobStatus) ([]*persistence.Job, error) - FetchTasksOfWorkerInStatus(context.Context, *persistence.Worker, api.TaskStatus) ([]*persistence.Task, error) - FetchTasksOfWorkerInStatusOfJob(context.Context, *persistence.Worker, api.TaskStatus, *persistence.Job) ([]*persistence.Task, error) + FetchTasksOfWorkerInStatus(context.Context, *persistence.Worker, api.TaskStatus) ([]persistence.TaskJob, error) + FetchTasksOfWorkerInStatusOfJob(ctx context.Context, worker *persistence.Worker, status api.TaskStatus, jobUUID string) ([]*persistence.Task, error) } // PersistenceService should be a subset of persistence.DB diff --git a/internal/manager/task_state_machine/mocks/interfaces_mock.gen.go b/internal/manager/task_state_machine/mocks/interfaces_mock.gen.go index ef004fba..05ffefa4 100644 --- a/internal/manager/task_state_machine/mocks/interfaces_mock.gen.go +++ b/internal/manager/task_state_machine/mocks/interfaces_mock.gen.go @@ -39,7 +39,7 @@ func (m *MockPersistenceService) EXPECT() *MockPersistenceServiceMockRecorder { } // CountTasksOfJobInStatus mocks base method. -func (m *MockPersistenceService) CountTasksOfJobInStatus(arg0 context.Context, arg1 *persistence.Job, arg2 ...api.TaskStatus) (int, int, error) { +func (m *MockPersistenceService) CountTasksOfJobInStatus(arg0 context.Context, arg1 *sqlc.Job, arg2 ...api.TaskStatus) (int, int, error) { m.ctrl.T.Helper() varargs := []interface{}{arg0, arg1} for _, a := range arg2 { @@ -59,15 +59,45 @@ func (mr *MockPersistenceServiceMockRecorder) CountTasksOfJobInStatus(arg0, arg1 return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountTasksOfJobInStatus", reflect.TypeOf((*MockPersistenceService)(nil).CountTasksOfJobInStatus), varargs...) } +// FetchJob mocks base method. +func (m *MockPersistenceService) FetchJob(arg0 context.Context, arg1 string) (*sqlc.Job, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FetchJob", arg0, arg1) + ret0, _ := ret[0].(*sqlc.Job) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FetchJob indicates an expected call of FetchJob. +func (mr *MockPersistenceServiceMockRecorder) FetchJob(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchJob", reflect.TypeOf((*MockPersistenceService)(nil).FetchJob), arg0, arg1) +} + +// FetchJobByID mocks base method. +func (m *MockPersistenceService) FetchJobByID(arg0 context.Context, arg1 int64) (*sqlc.Job, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FetchJobByID", arg0, arg1) + ret0, _ := ret[0].(*sqlc.Job) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FetchJobByID indicates an expected call of FetchJobByID. 
+func (mr *MockPersistenceServiceMockRecorder) FetchJobByID(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchJobByID", reflect.TypeOf((*MockPersistenceService)(nil).FetchJobByID), arg0, arg1) +} + // FetchJobsInStatus mocks base method. -func (m *MockPersistenceService) FetchJobsInStatus(arg0 context.Context, arg1 ...api.JobStatus) ([]*persistence.Job, error) { +func (m *MockPersistenceService) FetchJobsInStatus(arg0 context.Context, arg1 ...api.JobStatus) ([]*sqlc.Job, error) { m.ctrl.T.Helper() varargs := []interface{}{arg0} for _, a := range arg1 { varargs = append(varargs, a) } ret := m.ctrl.Call(m, "FetchJobsInStatus", varargs...) - ret0, _ := ret[0].([]*persistence.Job) + ret0, _ := ret[0].([]*sqlc.Job) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -80,10 +110,10 @@ func (mr *MockPersistenceServiceMockRecorder) FetchJobsInStatus(arg0 interface{} } // FetchTasksOfWorkerInStatus mocks base method. -func (m *MockPersistenceService) FetchTasksOfWorkerInStatus(arg0 context.Context, arg1 *sqlc.Worker, arg2 api.TaskStatus) ([]*persistence.Task, error) { +func (m *MockPersistenceService) FetchTasksOfWorkerInStatus(arg0 context.Context, arg1 *sqlc.Worker, arg2 api.TaskStatus) ([]persistence.TaskJob, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "FetchTasksOfWorkerInStatus", arg0, arg1, arg2) - ret0, _ := ret[0].([]*persistence.Task) + ret0, _ := ret[0].([]persistence.TaskJob) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -95,10 +125,10 @@ func (mr *MockPersistenceServiceMockRecorder) FetchTasksOfWorkerInStatus(arg0, a } // FetchTasksOfWorkerInStatusOfJob mocks base method. -func (m *MockPersistenceService) FetchTasksOfWorkerInStatusOfJob(arg0 context.Context, arg1 *sqlc.Worker, arg2 api.TaskStatus, arg3 *persistence.Job) ([]*persistence.Task, error) { +func (m *MockPersistenceService) FetchTasksOfWorkerInStatusOfJob(arg0 context.Context, arg1 *sqlc.Worker, arg2 api.TaskStatus, arg3 string) ([]*sqlc.Task, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "FetchTasksOfWorkerInStatusOfJob", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].([]*persistence.Task) + ret0, _ := ret[0].([]*sqlc.Task) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -110,7 +140,7 @@ func (mr *MockPersistenceServiceMockRecorder) FetchTasksOfWorkerInStatusOfJob(ar } // JobHasTasksInStatus mocks base method. -func (m *MockPersistenceService) JobHasTasksInStatus(arg0 context.Context, arg1 *persistence.Job, arg2 api.TaskStatus) (bool, error) { +func (m *MockPersistenceService) JobHasTasksInStatus(arg0 context.Context, arg1 *sqlc.Job, arg2 api.TaskStatus) (bool, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "JobHasTasksInStatus", arg0, arg1, arg2) ret0, _ := ret[0].(bool) @@ -125,7 +155,7 @@ func (mr *MockPersistenceServiceMockRecorder) JobHasTasksInStatus(arg0, arg1, ar } // SaveJobStatus mocks base method. -func (m *MockPersistenceService) SaveJobStatus(arg0 context.Context, arg1 *persistence.Job) error { +func (m *MockPersistenceService) SaveJobStatus(arg0 context.Context, arg1 *sqlc.Job) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SaveJobStatus", arg0, arg1) ret0, _ := ret[0].(error) @@ -139,7 +169,7 @@ func (mr *MockPersistenceServiceMockRecorder) SaveJobStatus(arg0, arg1 interface } // SaveTask mocks base method. 
-func (m *MockPersistenceService) SaveTask(arg0 context.Context, arg1 *persistence.Task) error { +func (m *MockPersistenceService) SaveTask(arg0 context.Context, arg1 *sqlc.Task) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SaveTask", arg0, arg1) ret0, _ := ret[0].(error) @@ -153,7 +183,7 @@ func (mr *MockPersistenceServiceMockRecorder) SaveTask(arg0, arg1 interface{}) * } // SaveTaskActivity mocks base method. -func (m *MockPersistenceService) SaveTaskActivity(arg0 context.Context, arg1 *persistence.Task) error { +func (m *MockPersistenceService) SaveTaskActivity(arg0 context.Context, arg1 *sqlc.Task) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SaveTaskActivity", arg0, arg1) ret0, _ := ret[0].(error) @@ -167,7 +197,7 @@ func (mr *MockPersistenceServiceMockRecorder) SaveTaskActivity(arg0, arg1 interf } // SaveTaskStatus mocks base method. -func (m *MockPersistenceService) SaveTaskStatus(arg0 context.Context, arg1 *persistence.Task) error { +func (m *MockPersistenceService) SaveTaskStatus(arg0 context.Context, arg1 *sqlc.Task) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SaveTaskStatus", arg0, arg1) ret0, _ := ret[0].(error) @@ -181,7 +211,7 @@ func (mr *MockPersistenceServiceMockRecorder) SaveTaskStatus(arg0, arg1 interfac } // UpdateJobsTaskStatuses mocks base method. -func (m *MockPersistenceService) UpdateJobsTaskStatuses(arg0 context.Context, arg1 *persistence.Job, arg2 api.TaskStatus, arg3 string) error { +func (m *MockPersistenceService) UpdateJobsTaskStatuses(arg0 context.Context, arg1 *sqlc.Job, arg2 api.TaskStatus, arg3 string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UpdateJobsTaskStatuses", arg0, arg1, arg2, arg3) ret0, _ := ret[0].(error) @@ -195,7 +225,7 @@ func (mr *MockPersistenceServiceMockRecorder) UpdateJobsTaskStatuses(arg0, arg1, } // UpdateJobsTaskStatusesConditional mocks base method. -func (m *MockPersistenceService) UpdateJobsTaskStatusesConditional(arg0 context.Context, arg1 *persistence.Job, arg2 []api.TaskStatus, arg3 api.TaskStatus, arg4 string) error { +func (m *MockPersistenceService) UpdateJobsTaskStatusesConditional(arg0 context.Context, arg1 *sqlc.Job, arg2 []api.TaskStatus, arg3 api.TaskStatus, arg4 string) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UpdateJobsTaskStatusesConditional", arg0, arg1, arg2, arg3, arg4) ret0, _ := ret[0].(error) diff --git a/internal/manager/task_state_machine/task_state_machine.go b/internal/manager/task_state_machine/task_state_machine.go index 1ab687ea..b73e2149 100644 --- a/internal/manager/task_state_machine/task_state_machine.go +++ b/internal/manager/task_state_machine/task_state_machine.go @@ -40,13 +40,23 @@ func (sm *StateMachine) TaskStatusChange( task *persistence.Task, newTaskStatus api.TaskStatus, ) error { + if task.JobID == 0 { + log.Panic().Str("task", task.UUID).Msg("task without job ID, cannot handle this") + return nil // Will not run because of the panic. 
+ } + + job, err := sm.persist.FetchJobByID(ctx, task.JobID) + if err != nil { + return fmt.Errorf("cannot fetch the job of task %s: %w", task.UUID, err) + } + oldTaskStatus := task.Status - if err := sm.taskStatusChangeOnly(ctx, task, newTaskStatus); err != nil { + if err := sm.taskStatusChangeOnly(ctx, task, job, newTaskStatus); err != nil { return err } - if err := sm.updateJobAfterTaskStatusChange(ctx, task, oldTaskStatus); err != nil { + if err := sm.updateJobAfterTaskStatusChange(ctx, task, job, oldTaskStatus); err != nil { return fmt.Errorf("updating job after task status change: %w", err) } return nil @@ -57,19 +67,16 @@ func (sm *StateMachine) TaskStatusChange( func (sm *StateMachine) taskStatusChangeOnly( ctx context.Context, task *persistence.Task, + job *persistence.Job, newTaskStatus api.TaskStatus, ) error { - if task.JobUUID == "" { - log.Panic().Str("task", task.UUID).Msg("task without job UUID, cannot handle this") - return nil // Will not run because of the panic. - } oldTaskStatus := task.Status task.Status = newTaskStatus logger := log.With(). Str("task", task.UUID). - Str("job", task.JobUUID). + Str("job", job.UUID). Str("taskStatusOld", string(oldTaskStatus)). Str("taskStatusNew", string(newTaskStatus)). Logger() @@ -82,12 +89,12 @@ func (sm *StateMachine) taskStatusChangeOnly( if oldTaskStatus != newTaskStatus { // logStorage already logs any error, and an error here shouldn't block the // rest of the function. - _ = sm.logStorage.WriteTimestamped(logger, task.JobUUID, task.UUID, + _ = sm.logStorage.WriteTimestamped(logger, job.UUID, task.UUID, fmt.Sprintf("task changed status %s -> %s", oldTaskStatus, newTaskStatus)) } // Broadcast this change to the SocketIO clients. - taskUpdate := eventbus.NewTaskUpdate(task) + taskUpdate := eventbus.NewTaskUpdate(*task, job.UUID) taskUpdate.PreviousStatus = &oldTaskStatus sm.broadcaster.BroadcastTaskUpdate(taskUpdate) @@ -97,16 +104,18 @@ func (sm *StateMachine) taskStatusChangeOnly( // updateJobAfterTaskStatusChange updates the job status based on the status of // this task and other tasks in the job. func (sm *StateMachine) updateJobAfterTaskStatusChange( - ctx context.Context, task *persistence.Task, oldTaskStatus api.TaskStatus, + ctx context.Context, + task *persistence.Task, + job *persistence.Job, + oldTaskStatus api.TaskStatus, ) error { - job := task.Job if job == nil { log.Panic().Str("task", task.UUID).Msg("task without job, cannot handle this") return nil // Will not run because of the panic. } logger := log.With(). - Str("job", task.JobUUID). + Str("job", job.UUID). Str("task", task.UUID). Str("taskStatusOld", string(oldTaskStatus)). Str("taskStatusNew", string(task.Status)). @@ -136,7 +145,7 @@ func (sm *StateMachine) updateJobAfterTaskStatusChange( default: logger.Info().Msg("job became active because one of its task changed status") reason := fmt.Sprintf("task became %s", task.Status) - return sm.JobStatusChange(ctx, job, api.JobStatusActive, reason) + return sm.jobStatusChange(ctx, job, api.JobStatusActive, reason) } case api.TaskStatusCompleted: @@ -163,7 +172,7 @@ func (sm *StateMachine) jobStatusIfAThenB( Str("jobStatusOld", string(ifStatus)). Str("jobStatusNew", string(thenStatus)). Msg("Job will change status because one of its task changed status") - return sm.JobStatusChange(ctx, job, thenStatus, reason) + return sm.jobStatusChange(ctx, job, thenStatus, reason) } // isJobPausingComplete returns true when the job status is pause-requested and there are no more active tasks. 
@@ -189,7 +198,7 @@ func (sm *StateMachine) updateJobOnTaskStatusCanceled(ctx context.Context, logge if numRunnable == 0 { // NOTE: this does NOT cancel any non-runnable (paused/failed) tasks. If that's desired, just cancel the job as a whole. logger.Info().Msg("canceled task was last runnable task of job, canceling job") - return sm.JobStatusChange(ctx, job, api.JobStatusCanceled, "canceled task was last runnable task of job, canceling job") + return sm.jobStatusChange(ctx, job, api.JobStatusCanceled, "canceled task was last runnable task of job, canceling job") } // Deal with the special case when the job is in pause-requested status. @@ -198,7 +207,7 @@ func (sm *StateMachine) updateJobOnTaskStatusCanceled(ctx context.Context, logge return err } if toBePaused { - return sm.JobStatusChange(ctx, job, api.JobStatusPaused, "no more active tasks after task cancellation") + return sm.jobStatusChange(ctx, job, api.JobStatusPaused, "no more active tasks after task cancellation") } return nil @@ -221,7 +230,7 @@ func (sm *StateMachine) updateJobOnTaskStatusFailed(ctx context.Context, logger if failedPercentage >= taskFailJobPercentage { failLogger.Info().Msg("failing job because too many of its tasks failed") - return sm.JobStatusChange(ctx, job, api.JobStatusFailed, "too many tasks failed") + return sm.jobStatusChange(ctx, job, api.JobStatusFailed, "too many tasks failed") } // If the job didn't fail, this failure indicates that at least the job is active. failLogger.Info().Msg("task failed, but not enough to fail the job") @@ -232,7 +241,7 @@ func (sm *StateMachine) updateJobOnTaskStatusFailed(ctx context.Context, logger return err } if toBePaused { - return sm.JobStatusChange(ctx, job, api.JobStatusPaused, "no more active tasks after task failure") + return sm.jobStatusChange(ctx, job, api.JobStatusPaused, "no more active tasks after task failure") } return sm.jobStatusIfAThenB(ctx, logger, job, api.JobStatusQueued, api.JobStatusActive, @@ -247,7 +256,7 @@ func (sm *StateMachine) updateJobOnTaskStatusCompleted(ctx context.Context, logg } if numComplete == numTotal { logger.Info().Msg("all tasks of job are completed, job is completed") - return sm.JobStatusChange(ctx, job, api.JobStatusCompleted, "all tasks completed") + return sm.jobStatusChange(ctx, job, api.JobStatusCompleted, "all tasks completed") } // Deal with the special case when the job is in pause-requested status. @@ -256,7 +265,7 @@ func (sm *StateMachine) updateJobOnTaskStatusCompleted(ctx context.Context, logg return err } if toBePaused { - return sm.JobStatusChange(ctx, job, api.JobStatusPaused, "no more active tasks after task completion") + return sm.jobStatusChange(ctx, job, api.JobStatusPaused, "no more active tasks after task completion") } logger.Info(). @@ -268,6 +277,23 @@ func (sm *StateMachine) updateJobOnTaskStatusCompleted(ctx context.Context, logg // JobStatusChange gives a Job a new status, and handles the resulting status changes on its tasks. func (sm *StateMachine) JobStatusChange( + ctx context.Context, + jobUUID string, + newJobStatus api.JobStatus, + reason string, +) error { + job, err := sm.persist.FetchJob(ctx, jobUUID) + if err != nil { + return err + } + return sm.jobStatusChange(ctx, job, newJobStatus, reason) +} + +// jobStatusChange gives a Job a new status, and handles the resulting status changes on its tasks. +// +// This is the private implementation, which takes the job as an argument. The +// public function (above) takes a job's UUID instead, so that it's easier to call. 
+func (sm *StateMachine) jobStatusChange( ctx context.Context, job *persistence.Job, newJobStatus api.JobStatus, @@ -369,7 +395,7 @@ func (sm *StateMachine) jobStatusSet(ctx context.Context, } // Handle the status change. - result, err := sm.updateTasksAfterJobStatusChange(ctx, logger, job, oldJobStatus) + result, err := sm.updateTasksAfterjobStatusChange(ctx, logger, job, oldJobStatus) if err != nil { return "", fmt.Errorf("updating job's tasks after job status change: %w", err) } @@ -383,7 +409,7 @@ func (sm *StateMachine) jobStatusSet(ctx context.Context, return result.followingJobStatus, nil } -// tasksUpdateResult is returned by `updateTasksAfterJobStatusChange`. +// tasksUpdateResult is returned by `updateTasksAfterjobStatusChange`. type tasksUpdateResult struct { // FollowingJobStatus is set when the task updates should trigger another job status update. followingJobStatus api.JobStatus @@ -394,14 +420,14 @@ type tasksUpdateResult struct { massTaskUpdate bool } -// updateTasksAfterJobStatusChange updates the status of its tasks based on the +// updateTasksAfterjobStatusChange updates the status of its tasks based on the // new status of this job. // // NOTE: this function assumes that the job already has its new status. // // Returns the new state the job should go into after this change, or an empty // string if there is no subsequent change necessary. -func (sm *StateMachine) updateTasksAfterJobStatusChange( +func (sm *StateMachine) updateTasksAfterjobStatusChange( ctx context.Context, logger zerolog.Logger, job *persistence.Job, diff --git a/internal/manager/task_state_machine/task_state_machine_test.go b/internal/manager/task_state_machine/task_state_machine_test.go index 531bd7ca..eeedad26 100644 --- a/internal/manager/task_state_machine/task_state_machine_test.go +++ b/internal/manager/task_state_machine/task_state_machine_test.go @@ -31,13 +31,15 @@ func TestTaskStatusChangeQueuedToActive(t *testing.T) { defer mockCtrl.Finish() // T: queued > active --> J: queued > active - task := taskWithStatus(api.JobStatusQueued, api.TaskStatusQueued) - mocks.expectSaveTaskWithStatus(t, task, api.TaskStatusActive) - mocks.expectWriteTaskLogTimestamped(t, task, "task changed status queued -> active") - mocks.expectSaveJobWithStatus(t, task.Job, api.JobStatusActive) - mocks.expectBroadcastJobChange(task.Job, api.JobStatusQueued, api.JobStatusActive) - mocks.expectBroadcastTaskChange(task, api.TaskStatusQueued, api.TaskStatusActive) + task, job := taskWithStatus(api.JobStatusQueued, api.TaskStatusQueued) + mocks.expectSaveTaskWithStatus(t, task, api.TaskStatusActive) + mocks.expectWriteTaskLogTimestamped(t, task, job.UUID, "task changed status queued -> active") + mocks.expectSaveJobWithStatus(t, job, api.JobStatusActive) + mocks.expectBroadcastJobChange(job, api.JobStatusQueued, api.JobStatusActive) + mocks.expectBroadcastTaskChange(task, job.UUID, api.TaskStatusQueued, api.TaskStatusActive) + + mocks.expectFetchJobOfTask(task, job) require.NoError(t, sm.TaskStatusChange(ctx, task, api.TaskStatusActive)) } @@ -46,17 +48,18 @@ func TestTaskStatusChangeSaveTaskAfterJobChangeFailure(t *testing.T) { defer mockCtrl.Finish() // A task status change should be saved, even when triggering the job change errors somehow. - task := taskWithStatus(api.JobStatusQueued, api.TaskStatusQueued) + task, job := taskWithStatus(api.JobStatusQueued, api.TaskStatusQueued) jobSaveErr := errors.New("hypothetical job save error") mocks.persist.EXPECT(). - SaveJobStatus(gomock.Any(), task.Job). 
+ SaveJobStatus(gomock.Any(), job). Return(jobSaveErr) // Expect a call to save the task in the persistence layer, regardless of the above error. mocks.expectSaveTaskWithStatus(t, task, api.TaskStatusActive) - mocks.expectWriteTaskLogTimestamped(t, task, "task changed status queued -> active") - mocks.expectBroadcastTaskChange(task, api.TaskStatusQueued, api.TaskStatusActive) + mocks.expectWriteTaskLogTimestamped(t, task, job.UUID, "task changed status queued -> active") + mocks.expectBroadcastTaskChange(task, job.UUID, api.TaskStatusQueued, api.TaskStatusActive) + mocks.expectFetchJobOfTask(task, job) returnedErr := sm.TaskStatusChange(ctx, task, api.TaskStatusActive) assert.ErrorIs(t, returnedErr, jobSaveErr, "the returned error should wrap the persistence layer error") @@ -67,37 +70,42 @@ func TestTaskStatusChangeActiveToCompleted(t *testing.T) { defer mockCtrl.Finish() // Job has three tasks. - task := taskWithStatus(api.JobStatusActive, api.TaskStatusActive) - task2 := taskOfSameJob(task, api.TaskStatusActive) - task3 := taskOfSameJob(task, api.TaskStatusActive) + task1, job := taskWithStatus(api.JobStatusActive, api.TaskStatusActive) + + task2 := taskOfSameJob(task1, api.TaskStatusActive) + task3 := taskOfSameJob(task1, api.TaskStatusActive) // First task completing: T: active > completed --> J: active > active - mocks.expectSaveTaskWithStatus(t, task, api.TaskStatusCompleted) - mocks.expectWriteTaskLogTimestamped(t, task, "task changed status active -> completed") - mocks.expectBroadcastTaskChange(task, api.TaskStatusActive, api.TaskStatusCompleted) - mocks.persist.EXPECT().CountTasksOfJobInStatus(ctx, task.Job, api.TaskStatusCompleted).Return(1, 3, nil) // 1 of 3 complete. - require.NoError(t, sm.TaskStatusChange(ctx, task, api.TaskStatusCompleted)) + mocks.expectSaveTaskWithStatus(t, task1, api.TaskStatusCompleted) + mocks.expectWriteTaskLogTimestamped(t, task1, job.UUID, "task changed status active -> completed") + mocks.expectBroadcastTaskChange(task1, job.UUID, api.TaskStatusActive, api.TaskStatusCompleted) + mocks.persist.EXPECT().CountTasksOfJobInStatus(ctx, job, api.TaskStatusCompleted).Return(1, 3, nil) // 1 of 3 complete. + mocks.expectFetchJobOfTask(task1, job) + require.NoError(t, sm.TaskStatusChange(ctx, task1, api.TaskStatusCompleted)) // Second task hickup: T: active > soft-failed --> J: active > active mocks.expectSaveTaskWithStatus(t, task2, api.TaskStatusSoftFailed) - mocks.expectWriteTaskLogTimestamped(t, task2, "task changed status active -> soft-failed") - mocks.expectBroadcastTaskChange(task2, api.TaskStatusActive, api.TaskStatusSoftFailed) + mocks.expectWriteTaskLogTimestamped(t, task2, job.UUID, "task changed status active -> soft-failed") + mocks.expectBroadcastTaskChange(task2, job.UUID, api.TaskStatusActive, api.TaskStatusSoftFailed) + mocks.expectFetchJobOfTask(task2, job) require.NoError(t, sm.TaskStatusChange(ctx, task2, api.TaskStatusSoftFailed)) // Second task completing: T: soft-failed > completed --> J: active > active mocks.expectSaveTaskWithStatus(t, task2, api.TaskStatusCompleted) - mocks.expectWriteTaskLogTimestamped(t, task2, "task changed status soft-failed -> completed") - mocks.expectBroadcastTaskChange(task2, api.TaskStatusSoftFailed, api.TaskStatusCompleted) - mocks.persist.EXPECT().CountTasksOfJobInStatus(ctx, task.Job, api.TaskStatusCompleted).Return(2, 3, nil) // 2 of 3 complete. 
+ mocks.expectWriteTaskLogTimestamped(t, task2, job.UUID, "task changed status soft-failed -> completed") + mocks.expectBroadcastTaskChange(task2, job.UUID, api.TaskStatusSoftFailed, api.TaskStatusCompleted) + mocks.persist.EXPECT().CountTasksOfJobInStatus(ctx, job, api.TaskStatusCompleted).Return(2, 3, nil) // 2 of 3 complete. + mocks.expectFetchJobOfTask(task2, job) require.NoError(t, sm.TaskStatusChange(ctx, task2, api.TaskStatusCompleted)) // Third task completing: T: active > completed --> J: active > completed mocks.expectSaveTaskWithStatus(t, task3, api.TaskStatusCompleted) - mocks.expectWriteTaskLogTimestamped(t, task3, "task changed status active -> completed") - mocks.expectBroadcastTaskChange(task3, api.TaskStatusActive, api.TaskStatusCompleted) - mocks.persist.EXPECT().CountTasksOfJobInStatus(ctx, task.Job, api.TaskStatusCompleted).Return(3, 3, nil) // 3 of 3 complete. - mocks.expectSaveJobWithStatus(t, task.Job, api.JobStatusCompleted) - mocks.expectBroadcastJobChange(task.Job, api.JobStatusActive, api.JobStatusCompleted) + mocks.expectWriteTaskLogTimestamped(t, task3, job.UUID, "task changed status active -> completed") + mocks.expectBroadcastTaskChange(task3, job.UUID, api.TaskStatusActive, api.TaskStatusCompleted) + mocks.persist.EXPECT().CountTasksOfJobInStatus(ctx, job, api.TaskStatusCompleted).Return(3, 3, nil) // 3 of 3 complete. + mocks.expectSaveJobWithStatus(t, job, api.JobStatusCompleted) + mocks.expectBroadcastJobChange(job, api.JobStatusActive, api.JobStatusCompleted) + mocks.expectFetchJobOfTask(task3, job) require.NoError(t, sm.TaskStatusChange(ctx, task3, api.TaskStatusCompleted)) } @@ -107,13 +115,15 @@ func TestTaskStatusChangeQueuedToFailed(t *testing.T) { defer mockCtrl.Finish() // T: queued > failed (1% task failure) --> J: queued > active - task := taskWithStatus(api.JobStatusQueued, api.TaskStatusQueued) + task, job := taskWithStatus(api.JobStatusQueued, api.TaskStatusQueued) + mocks.expectSaveTaskWithStatus(t, task, api.TaskStatusFailed) - mocks.expectWriteTaskLogTimestamped(t, task, "task changed status queued -> failed") - mocks.expectBroadcastTaskChange(task, api.TaskStatusQueued, api.TaskStatusFailed) - mocks.expectSaveJobWithStatus(t, task.Job, api.JobStatusActive) - mocks.persist.EXPECT().CountTasksOfJobInStatus(ctx, task.Job, api.TaskStatusFailed).Return(1, 100, nil) // 1 out of 100 failed. - mocks.expectBroadcastJobChange(task.Job, api.JobStatusQueued, api.JobStatusActive) + mocks.expectWriteTaskLogTimestamped(t, task, job.UUID, "task changed status queued -> failed") + mocks.expectBroadcastTaskChange(task, job.UUID, api.TaskStatusQueued, api.TaskStatusFailed) + mocks.expectSaveJobWithStatus(t, job, api.JobStatusActive) + mocks.persist.EXPECT().CountTasksOfJobInStatus(ctx, job, api.TaskStatusFailed).Return(1, 100, nil) // 1 out of 100 failed. + mocks.expectBroadcastJobChange(job, api.JobStatusQueued, api.JobStatusActive) + mocks.expectFetchJobOfTask(task, job) require.NoError(t, sm.TaskStatusChange(ctx, task, api.TaskStatusFailed)) } @@ -123,16 +133,16 @@ func TestTaskStatusChangeActiveToFailedFailJob(t *testing.T) { defer mockCtrl.Finish() // T: active > failed (10% task1 failure) --> J: active > failed + cancellation of any runnable tasks. 
- task1 := taskWithStatus(api.JobStatusActive, api.TaskStatusActive) + task1, job := taskWithStatus(api.JobStatusActive, api.TaskStatusActive) mocks.expectSaveTaskWithStatus(t, task1, api.TaskStatusFailed) - mocks.expectWriteTaskLogTimestamped(t, task1, "task changed status active -> failed") + mocks.expectWriteTaskLogTimestamped(t, task1, job.UUID, "task changed status active -> failed") // The change to the failed task should be broadcast. - mocks.expectBroadcastTaskChange(task1, api.TaskStatusActive, api.TaskStatusFailed) - mocks.expectSaveJobWithStatus(t, task1.Job, api.JobStatusFailed) + mocks.expectBroadcastTaskChange(task1, job.UUID, api.TaskStatusActive, api.TaskStatusFailed) + mocks.expectSaveJobWithStatus(t, job, api.JobStatusFailed) // The resulting cancellation of the other tasks should be communicated as mass-task-update in the job update broadcast. - mocks.expectBroadcastJobChangeWithTaskRefresh(task1.Job, api.JobStatusActive, api.JobStatusFailed) + mocks.expectBroadcastJobChangeWithTaskRefresh(job, api.JobStatusActive, api.JobStatusFailed) - mocks.persist.EXPECT().CountTasksOfJobInStatus(ctx, task1.Job, api.TaskStatusFailed).Return(10, 100, nil) // 10 out of 100 failed. + mocks.persist.EXPECT().CountTasksOfJobInStatus(ctx, job, api.TaskStatusFailed).Return(10, 100, nil) // 10 out of 100 failed. // Expect failure of the job to trigger cancellation of remaining tasks. taskStatusesToCancel := []api.TaskStatus{ @@ -141,9 +151,10 @@ func TestTaskStatusChangeActiveToFailedFailJob(t *testing.T) { api.TaskStatusSoftFailed, } - mocks.persist.EXPECT().UpdateJobsTaskStatusesConditional(ctx, task1.Job, taskStatusesToCancel, api.TaskStatusCanceled, + mocks.persist.EXPECT().UpdateJobsTaskStatusesConditional(ctx, job, taskStatusesToCancel, api.TaskStatusCanceled, "Manager cancelled this task because the job got status \"failed\".", ) + mocks.expectFetchJobOfTask(task1, job) require.NoError(t, sm.TaskStatusChange(ctx, task1, api.TaskStatusFailed)) } @@ -153,21 +164,23 @@ func TestTaskStatusChangeRequeueOnCompletedJob(t *testing.T) { defer mockCtrl.Finish() // T: completed > queued --> J: completed > requeueing > queued - task := taskWithStatus(api.JobStatusCompleted, api.TaskStatusCompleted) + task, job := taskWithStatus(api.JobStatusCompleted, api.TaskStatusCompleted) + mocks.expectSaveTaskWithStatus(t, task, api.TaskStatusQueued) - mocks.expectWriteTaskLogTimestamped(t, task, "task changed status completed -> queued") - mocks.expectBroadcastTaskChange(task, api.TaskStatusCompleted, api.TaskStatusQueued) - mocks.expectSaveJobWithStatus(t, task.Job, api.JobStatusRequeueing) - mocks.expectBroadcastJobChangeWithTaskRefresh(task.Job, api.JobStatusCompleted, api.JobStatusRequeueing) - mocks.expectBroadcastJobChangeWithTaskRefresh(task.Job, api.JobStatusRequeueing, api.JobStatusQueued) + mocks.expectWriteTaskLogTimestamped(t, task, job.UUID, "task changed status completed -> queued") + mocks.expectBroadcastTaskChange(task, job.UUID, api.TaskStatusCompleted, api.TaskStatusQueued) + mocks.expectSaveJobWithStatus(t, job, api.JobStatusRequeueing) + mocks.expectBroadcastJobChangeWithTaskRefresh(job, api.JobStatusCompleted, api.JobStatusRequeueing) + mocks.expectBroadcastJobChangeWithTaskRefresh(job, api.JobStatusRequeueing, api.JobStatusQueued) // Expect queueing of the job to trigger queueing of all its tasks, if those tasks were all completed before. // 2 out of 3 completed, because one was just queued. 
- mocks.persist.EXPECT().CountTasksOfJobInStatus(ctx, task.Job, api.TaskStatusCompleted).Return(2, 3, nil) - mocks.persist.EXPECT().UpdateJobsTaskStatuses(ctx, task.Job, api.TaskStatusQueued, + mocks.persist.EXPECT().CountTasksOfJobInStatus(ctx, job, api.TaskStatusCompleted).Return(2, 3, nil) + mocks.persist.EXPECT().UpdateJobsTaskStatuses(ctx, job, api.TaskStatusQueued, "Queued because job transitioned status from \"completed\" to \"requeueing\"", ) - mocks.expectSaveJobWithStatus(t, task.Job, api.JobStatusQueued) + mocks.expectSaveJobWithStatus(t, job, api.JobStatusQueued) + mocks.expectFetchJobOfTask(task, job) require.NoError(t, sm.TaskStatusChange(ctx, task, api.TaskStatusQueued)) } @@ -176,28 +189,29 @@ func TestTaskStatusChangeCancelSingleTask(t *testing.T) { mockCtrl, ctx, sm, mocks := taskStateMachineTestFixtures(t) defer mockCtrl.Finish() - task := taskWithStatus(api.JobStatusCancelRequested, api.TaskStatusActive) + task, job := taskWithStatus(api.JobStatusCancelRequested, api.TaskStatusActive) task2 := taskOfSameJob(task, api.TaskStatusQueued) - job := task.Job // T1: active > cancelled --> J: cancel-requested > cancel-requested mocks.expectSaveTaskWithStatus(t, task, api.TaskStatusCanceled) - mocks.expectWriteTaskLogTimestamped(t, task, "task changed status active -> canceled") - mocks.expectBroadcastTaskChange(task, api.TaskStatusActive, api.TaskStatusCanceled) + mocks.expectWriteTaskLogTimestamped(t, task, job.UUID, "task changed status active -> canceled") + mocks.expectBroadcastTaskChange(task, job.UUID, api.TaskStatusActive, api.TaskStatusCanceled) mocks.persist.EXPECT().CountTasksOfJobInStatus(ctx, job, api.TaskStatusActive, api.TaskStatusQueued, api.TaskStatusSoftFailed). Return(1, 2, nil) + mocks.expectFetchJobOfTask(task, job) require.NoError(t, sm.TaskStatusChange(ctx, task, api.TaskStatusCanceled)) // T2: queued > cancelled --> J: cancel-requested > canceled mocks.expectSaveTaskWithStatus(t, task2, api.TaskStatusCanceled) - mocks.expectWriteTaskLogTimestamped(t, task2, "task changed status queued -> canceled") - mocks.expectBroadcastTaskChange(task2, api.TaskStatusQueued, api.TaskStatusCanceled) + mocks.expectWriteTaskLogTimestamped(t, task2, job.UUID, "task changed status queued -> canceled") + mocks.expectBroadcastTaskChange(task2, job.UUID, api.TaskStatusQueued, api.TaskStatusCanceled) mocks.persist.EXPECT().CountTasksOfJobInStatus(ctx, job, api.TaskStatusActive, api.TaskStatusQueued, api.TaskStatusSoftFailed). Return(0, 2, nil) mocks.expectSaveJobWithStatus(t, job, api.JobStatusCanceled) - mocks.expectBroadcastJobChange(task.Job, api.JobStatusCancelRequested, api.JobStatusCanceled) + mocks.expectBroadcastJobChange(job, api.JobStatusCancelRequested, api.JobStatusCanceled) + mocks.expectFetchJobOfTask(task2, job) require.NoError(t, sm.TaskStatusChange(ctx, task2, api.TaskStatusCanceled)) } @@ -206,20 +220,20 @@ func TestTaskStatusChangeCancelSingleTaskWithOtherFailed(t *testing.T) { mockCtrl, ctx, sm, mocks := taskStateMachineTestFixtures(t) defer mockCtrl.Finish() - task1 := taskWithStatus(api.JobStatusCancelRequested, api.TaskStatusActive) + task1, job := taskWithStatus(api.JobStatusCancelRequested, api.TaskStatusActive) task2 := taskOfSameJob(task1, api.TaskStatusFailed) taskOfSameJob(task2, api.TaskStatusPaused) - job := task1.Job // T1: active > cancelled --> J: cancel-requested > canceled because T2 already failed and cannot run anyway. 
mocks.expectSaveTaskWithStatus(t, task1, api.TaskStatusCanceled) - mocks.expectWriteTaskLogTimestamped(t, task1, "task changed status active -> canceled") - mocks.expectBroadcastTaskChange(task1, api.TaskStatusActive, api.TaskStatusCanceled) + mocks.expectWriteTaskLogTimestamped(t, task1, job.UUID, "task changed status active -> canceled") + mocks.expectBroadcastTaskChange(task1, job.UUID, api.TaskStatusActive, api.TaskStatusCanceled) mocks.persist.EXPECT().CountTasksOfJobInStatus(ctx, job, api.TaskStatusActive, api.TaskStatusQueued, api.TaskStatusSoftFailed). Return(0, 3, nil) mocks.expectSaveJobWithStatus(t, job, api.JobStatusCanceled) - mocks.expectBroadcastJobChange(task1.Job, api.JobStatusCancelRequested, api.JobStatusCanceled) + mocks.expectBroadcastJobChange(job, api.JobStatusCancelRequested, api.JobStatusCanceled) + mocks.expectFetchJobOfTask(task1, job) // The canceled task just stays canceled, so don't expectBroadcastTaskChange(task3). @@ -231,10 +245,12 @@ func TestTaskStatusChangeUnknownStatus(t *testing.T) { defer mockCtrl.Finish() // T: queued > borked --> saved to DB but otherwise ignored w.r.t. job status changes. - task := taskWithStatus(api.JobStatusQueued, api.TaskStatusQueued) + task, job := taskWithStatus(api.JobStatusQueued, api.TaskStatusQueued) + mocks.expectSaveTaskWithStatus(t, task, api.TaskStatus("borked")) - mocks.expectWriteTaskLogTimestamped(t, task, "task changed status queued -> borked") - mocks.expectBroadcastTaskChange(task, api.TaskStatusQueued, api.TaskStatus("borked")) + mocks.expectWriteTaskLogTimestamped(t, task, job.UUID, "task changed status queued -> borked") + mocks.expectBroadcastTaskChange(task, job.UUID, api.TaskStatusQueued, api.TaskStatus("borked")) + mocks.expectFetchJobOfTask(task, job) require.NoError(t, sm.TaskStatusChange(ctx, task, api.TaskStatus("borked"))) } @@ -243,12 +259,11 @@ func TestJobRequeueWithSomeCompletedTasks(t *testing.T) { mockCtrl, ctx, sm, mocks := taskStateMachineTestFixtures(t) defer mockCtrl.Finish() - task1 := taskWithStatus(api.JobStatusActive, api.TaskStatusCompleted) + _, job := taskWithStatus(api.JobStatusActive, api.TaskStatusCompleted) // These are not necessary to create for this test, but just imagine these tasks are there too. // This is mimicked by returning (1, 3, nil) when counting the tasks (1 of 3 completed). // task2 := taskOfSameJob(task1, api.TaskStatusFailed) // task3 := taskOfSameJob(task2, api.TaskStatusSoftFailed) - job := task1.Job mocks.expectSaveJobWithStatus(t, job, api.JobStatusRequeueing) @@ -270,19 +285,18 @@ func TestJobRequeueWithSomeCompletedTasks(t *testing.T) { mocks.expectBroadcastJobChangeWithTaskRefresh(job, api.JobStatusActive, api.JobStatusRequeueing) mocks.expectBroadcastJobChangeWithTaskRefresh(job, api.JobStatusRequeueing, api.JobStatusQueued) - require.NoError(t, sm.JobStatusChange(ctx, job, api.JobStatusRequeueing, "someone wrote a unittest")) + require.NoError(t, sm.jobStatusChange(ctx, job, api.JobStatusRequeueing, "someone wrote a unittest")) } func TestJobRequeueWithAllCompletedTasks(t *testing.T) { mockCtrl, ctx, sm, mocks := taskStateMachineTestFixtures(t) defer mockCtrl.Finish() - task1 := taskWithStatus(api.JobStatusCompleted, api.TaskStatusCompleted) + _, job := taskWithStatus(api.JobStatusCompleted, api.TaskStatusCompleted) // These are not necessary to create for this test, but just imagine these tasks are there too. // This is mimicked by returning (3, 3, nil) when counting the tasks (3 of 3 completed). 
// task2 := taskOfSameJob(task1, api.TaskStatusCompleted) // task3 := taskOfSameJob(task2, api.TaskStatusCompleted) - job := task1.Job call1 := mocks.expectSaveJobWithStatus(t, job, api.JobStatusRequeueing) @@ -302,17 +316,16 @@ func TestJobRequeueWithAllCompletedTasks(t *testing.T) { mocks.expectBroadcastJobChangeWithTaskRefresh(job, api.JobStatusCompleted, api.JobStatusRequeueing) mocks.expectBroadcastJobChangeWithTaskRefresh(job, api.JobStatusRequeueing, api.JobStatusQueued) - require.NoError(t, sm.JobStatusChange(ctx, job, api.JobStatusRequeueing, "someone wrote a unit test")) + require.NoError(t, sm.jobStatusChange(ctx, job, api.JobStatusRequeueing, "someone wrote a unit test")) } func TestJobCancelWithSomeCompletedTasks(t *testing.T) { mockCtrl, ctx, sm, mocks := taskStateMachineTestFixtures(t) defer mockCtrl.Finish() - task1 := taskWithStatus(api.JobStatusActive, api.TaskStatusCompleted) + _, job := taskWithStatus(api.JobStatusActive, api.TaskStatusCompleted) // task2 := taskOfSameJob(task1, api.TaskStatusFailed) // task3 := taskOfSameJob(task2, api.TaskStatusSoftFailed) - job := task1.Job mocks.expectSaveJobWithStatus(t, job, api.JobStatusCancelRequested) @@ -333,17 +346,17 @@ func TestJobCancelWithSomeCompletedTasks(t *testing.T) { mocks.expectBroadcastJobChangeWithTaskRefresh(job, api.JobStatusActive, api.JobStatusCancelRequested) mocks.expectBroadcastJobChange(job, api.JobStatusCancelRequested, api.JobStatusCanceled) - require.NoError(t, sm.JobStatusChange(ctx, job, api.JobStatusCancelRequested, "someone wrote a unittest")) + require.NoError(t, sm.jobStatusChange(ctx, job, api.JobStatusCancelRequested, "someone wrote a unittest")) } func TestJobPauseWithAllQueuedTasks(t *testing.T) { mockCtrl, ctx, sm, mocks := taskStateMachineTestFixtures(t) defer mockCtrl.Finish() - task1 := taskWithStatus(api.JobStatusQueued, api.TaskStatusQueued) + task1, job := taskWithStatus(api.JobStatusQueued, api.TaskStatusQueued) task2 := taskOfSameJob(task1, api.TaskStatusQueued) task3 := taskOfSameJob(task2, api.TaskStatusQueued) - job := task3.Job + _ = task3 mocks.expectSaveJobWithStatus(t, job, api.JobStatusPauseRequested) @@ -363,17 +376,17 @@ func TestJobPauseWithAllQueuedTasks(t *testing.T) { mocks.expectBroadcastJobChangeWithTaskRefresh(job, api.JobStatusQueued, api.JobStatusPauseRequested) mocks.expectBroadcastJobChange(job, api.JobStatusPauseRequested, api.JobStatusPaused) - require.NoError(t, sm.JobStatusChange(ctx, job, api.JobStatusPauseRequested, "someone wrote a unittest")) + require.NoError(t, sm.jobStatusChange(ctx, job, api.JobStatusPauseRequested, "someone wrote a unittest")) } func TestJobPauseWithSomeCompletedTasks(t *testing.T) { mockCtrl, ctx, sm, mocks := taskStateMachineTestFixtures(t) defer mockCtrl.Finish() - task1 := taskWithStatus(api.JobStatusQueued, api.TaskStatusCompleted) + task1, job := taskWithStatus(api.JobStatusQueued, api.TaskStatusCompleted) task2 := taskOfSameJob(task1, api.TaskStatusQueued) task3 := taskOfSameJob(task2, api.TaskStatusQueued) - job := task3.Job + _ = task3 mocks.expectSaveJobWithStatus(t, job, api.JobStatusPauseRequested) @@ -393,17 +406,17 @@ func TestJobPauseWithSomeCompletedTasks(t *testing.T) { mocks.expectBroadcastJobChangeWithTaskRefresh(job, api.JobStatusQueued, api.JobStatusPauseRequested) mocks.expectBroadcastJobChange(job, api.JobStatusPauseRequested, api.JobStatusPaused) - require.NoError(t, sm.JobStatusChange(ctx, job, api.JobStatusPauseRequested, "someone wrote a unittest")) + require.NoError(t, sm.jobStatusChange(ctx, job, 
api.JobStatusPauseRequested, "someone wrote a unittest")) } func TestJobPauseWithSomeActiveTasks(t *testing.T) { mockCtrl, ctx, sm, mocks := taskStateMachineTestFixtures(t) defer mockCtrl.Finish() - task1 := taskWithStatus(api.JobStatusActive, api.TaskStatusActive) + task1, job := taskWithStatus(api.JobStatusActive, api.TaskStatusActive) task2 := taskOfSameJob(task1, api.TaskStatusCompleted) task3 := taskOfSameJob(task2, api.TaskStatusQueued) - job := task3.Job + _ = task3 mocks.expectSaveJobWithStatus(t, job, api.JobStatusPauseRequested) @@ -421,17 +434,17 @@ func TestJobPauseWithSomeActiveTasks(t *testing.T) { Return(1, 3, nil) mocks.expectBroadcastJobChangeWithTaskRefresh(job, api.JobStatusActive, api.JobStatusPauseRequested) - require.NoError(t, sm.JobStatusChange(ctx, job, api.JobStatusPauseRequested, "someone wrote a unittest")) + require.NoError(t, sm.jobStatusChange(ctx, job, api.JobStatusPauseRequested, "someone wrote a unittest")) } func TestCheckStuck(t *testing.T) { mockCtrl, ctx, sm, mocks := taskStateMachineTestFixtures(t) defer mockCtrl.Finish() - task1 := taskWithStatus(api.JobStatusActive, api.TaskStatusCompleted) + task1, job := taskWithStatus(api.JobStatusActive, api.TaskStatusCompleted) + _ = task1 // task2 := taskOfSameJob(task1, api.TaskStatusFailed) // task3 := taskOfSameJob(task2, api.TaskStatusSoftFailed) - job := task1.Job job.Status = api.JobStatusRequeueing mocks.persist.EXPECT().FetchJobsInStatus(ctx, api.JobStatusCancelRequested, api.JobStatusRequeueing). @@ -469,6 +482,13 @@ func mockedTaskStateMachine(mockCtrl *gomock.Controller) (*StateMachine, *StateM return sm, &mocks } +func (m *StateMachineMocks) expectFetchJobOfTask( + task *persistence.Task, + jobToReturn *persistence.Job, +) *gomock.Call { + return m.persist.EXPECT().FetchJobByID(gomock.Any(), task.JobID).Return(jobToReturn, nil) +} + func (m *StateMachineMocks) expectSaveTaskWithStatus( t *testing.T, task *persistence.Task, @@ -484,10 +504,11 @@ func (m *StateMachineMocks) expectSaveTaskWithStatus( func (m *StateMachineMocks) expectWriteTaskLogTimestamped( t *testing.T, task *persistence.Task, + jobUUID string, logtext string, ) *gomock.Call { return m.logStorage.EXPECT().WriteTimestamped( - gomock.Any(), task.Job.UUID, task.UUID, logtext, + gomock.Any(), jobUUID, task.UUID, logtext, ) } @@ -514,7 +535,7 @@ func (m *StateMachineMocks) expectBroadcastJobChange( PreviousStatus: &fromStatus, RefreshTasks: false, Status: toStatus, - Updated: job.UpdatedAt, + Updated: job.UpdatedAt.Time, } return m.broadcaster.EXPECT().BroadcastJobUpdate(expectUpdate) } @@ -529,20 +550,21 @@ func (m *StateMachineMocks) expectBroadcastJobChangeWithTaskRefresh( PreviousStatus: &fromStatus, RefreshTasks: true, Status: toStatus, - Updated: job.UpdatedAt, + Updated: job.UpdatedAt.Time, } return m.broadcaster.EXPECT().BroadcastJobUpdate(expectUpdate) } func (m *StateMachineMocks) expectBroadcastTaskChange( task *persistence.Task, + jobUUID string, fromStatus, toStatus api.TaskStatus, ) *gomock.Call { expectUpdate := api.EventTaskUpdate{ Id: task.UUID, - JobId: task.Job.UUID, + JobId: jobUUID, Name: task.Name, - Updated: task.UpdatedAt, + Updated: task.UpdatedAt.Time, PreviousStatus: &fromStatus, Status: toStatus, } @@ -550,37 +572,31 @@ func (m *StateMachineMocks) expectBroadcastTaskChange( } /* taskWithStatus() creates a task of a certain status, with a job of a certain status. 
-func taskWithStatus(jobStatus api.JobStatus, taskStatus api.TaskStatus) *persistence.Task {
+func taskWithStatus(jobStatus api.JobStatus, taskStatus api.TaskStatus) (*persistence.Task, *persistence.Job) {
     job := persistence.Job{
-        Model: persistence.Model{ID: 47},
-        UUID: "test-job-f3f5-4cef-9cd7-e67eb28eaf3e",
+        ID: 47,
+        UUID: "test-job-f3f5-4cef-9cd7-e67eb28eaf3e",
         Status: jobStatus,
     }
     task := persistence.Task{
-        Model: persistence.Model{ID: 327},
-        UUID: "testtask-0001-4e28-aeea-8cbaf2fc96a5",
-
-        JobUUID: job.UUID,
-        JobID: job.ID,
-        Job: &job,
-
+        ID: 327,
+        JobID: job.ID,
+        UUID: "testtask-0001-4e28-aeea-8cbaf2fc96a5",
         Status: taskStatus,
     }
-    return &task
+    return &task, &job
 }

 /* taskOfSameJob() creates a task of a certain status, on the same job as the given task. */
 func taskOfSameJob(task *persistence.Task, taskStatus api.TaskStatus) *persistence.Task {
     newTaskID := task.ID + 1
     return &persistence.Task{
-        Model: persistence.Model{ID: newTaskID},
-        UUID: fmt.Sprintf("testtask-%04d-4e28-aeea-8cbaf2fc96a5", newTaskID),
-        JobUUID: task.JobUUID,
-        JobID: task.JobID,
-        Job: task.Job,
-        Status: taskStatus,
+        ID: newTaskID,
+        UUID: fmt.Sprintf("testtask-%04d-4e28-aeea-8cbaf2fc96a5", newTaskID),
+        JobID: task.JobID,
+        Status: taskStatus,
     }
 }
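Note: the test helpers no longer embed the job in the task; a test that needs both now receives them as separate return values and passes the job around explicitly. A minimal usage sketch, assuming the fixtures and helpers from the test file above (everything shown already appears in this diff, nothing else is implied):

// Sketch only: how the reworked helpers are consumed in a test of this package.
func exampleTaskFixtures(t *testing.T, mocks *StateMachineMocks) {
    task1, job := taskWithStatus(api.JobStatusQueued, api.TaskStatusQueued)
    task2 := taskOfSameJob(task1, api.TaskStatusQueued) // linked to the same job via JobID only

    // Anything that needs the job now gets it explicitly instead of via task.Job:
    mocks.expectSaveJobWithStatus(t, job, api.JobStatusPauseRequested)
    mocks.expectFetchJobOfTask(task2, job) // the state machine resolves the job by task.JobID
}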
diff --git a/internal/manager/task_state_machine/worker_requeue.go b/internal/manager/task_state_machine/worker_requeue.go
index 0c14c1a8..81392290 100644
--- a/internal/manager/task_state_machine/worker_requeue.go
+++ b/internal/manager/task_state_machine/worker_requeue.go
@@ -19,13 +19,19 @@ func (sm *StateMachine) RequeueActiveTasksOfWorker(
     reason string,
 ) error {
     // Fetch the tasks to update.
-    tasks, err := sm.persist.FetchTasksOfWorkerInStatus(
+    tasksJobs, err := sm.persist.FetchTasksOfWorkerInStatus(
         ctx, worker, api.TaskStatusActive)
     if err != nil {
         return err
     }
-    return sm.requeueTasksOfWorker(ctx, tasks, worker, reason)
+    // Run each task change through the task state machine.
+    var lastErr error
+    for _, taskJobWorker := range tasksJobs {
+        lastErr = sm.requeueTaskOfWorker(ctx, &taskJobWorker.Task, taskJobWorker.JobUUID, worker, reason)
+    }
+
+    return lastErr
 }

 // RequeueFailedTasksOfWorkerOfJob re-queues all failed tasks of this worker on this job.
@@ -34,22 +40,29 @@ func (sm *StateMachine) RequeueFailedTasksOfWorkerOfJob(
     ctx context.Context,
     worker *persistence.Worker,
-    job *persistence.Job,
+    jobUUID string,
     reason string,
 ) error {
     // Fetch the tasks to update.
     tasks, err := sm.persist.FetchTasksOfWorkerInStatusOfJob(
-        ctx, worker, api.TaskStatusFailed, job)
+        ctx, worker, api.TaskStatusFailed, jobUUID)
     if err != nil {
         return err
     }
-    return sm.requeueTasksOfWorker(ctx, tasks, worker, reason)
+    // Run each task change through the task state machine.
+    var lastErr error
+    for _, task := range tasks {
+        lastErr = sm.requeueTaskOfWorker(ctx, task, jobUUID, worker, reason)
+    }
+
+    return lastErr
 }

-func (sm *StateMachine) requeueTasksOfWorker(
+func (sm *StateMachine) requeueTaskOfWorker(
     ctx context.Context,
-    tasks []*persistence.Task,
+    task *persistence.Task,
+    jobUUID string,
     worker *persistence.Worker,
     reason string,
 ) error {
@@ -58,35 +71,30 @@ func (sm *StateMachine) requeueTasksOfWorker(
         Str("reason", reason).
         Logger()

-    // Run each task change through the task state machine.
-    var lastErr error
-    for _, task := range tasks {
-        logger.Info().
+    logger.Info().
+        Str("task", task.UUID).
+        Msg("re-queueing task")
+
+    // Write to task activity that it got requeued because of worker sign-off.
+    task.Activity = "Task was requeued by Manager because " + reason
+    if err := sm.persist.SaveTaskActivity(ctx, task); err != nil {
+        logger.Warn().Err(err).
             Str("task", task.UUID).
-            Msg("re-queueing task")
-
-        // Write to task activity that it got requeued because of worker sign-off.
-        task.Activity = "Task was requeued by Manager because " + reason
-        if err := sm.persist.SaveTaskActivity(ctx, task); err != nil {
-            logger.Warn().Err(err).
-                Str("task", task.UUID).
-                Str("reason", reason).
-                Str("activity", task.Activity).
-                Msg("error saving task activity to database")
-            lastErr = err
-        }
-
-        if err := sm.TaskStatusChange(ctx, task, api.TaskStatusQueued); err != nil {
-            logger.Warn().Err(err).
-                Str("task", task.UUID).
-                Str("reason", reason).
-                Msg("error queueing task")
-            lastErr = err
-        }
-
-        // The error is already logged by the log storage.
-        _ = sm.logStorage.WriteTimestamped(logger, task.Job.UUID, task.UUID, task.Activity)
+            Str("reason", reason).
+            Str("activity", task.Activity).
+            Msg("error saving task activity to database")
     }
-    return lastErr
+
+    err := sm.TaskStatusChange(ctx, task, api.TaskStatusQueued)
+    if err != nil {
+        logger.Warn().Err(err).
+            Str("task", task.UUID).
+            Str("reason", reason).
+            Msg("error queueing task")
+    }
+
+    // The error is already logged by the log storage.
+    _ = sm.logStorage.WriteTimestamped(logger, jobUUID, task.UUID, task.Activity)
+
+    return err
 }
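The requeue path now fetches task-plus-job rows instead of tasks that carry a Job pointer, and requeueTaskOfWorker handles one task per call; if several tasks fail to requeue, the last error wins but every task is still attempted. The exact definition of persistence.TaskJob is not part of this diff; based on how it is used here and constructed in the test below, it is roughly this shape (a sketch, not the authoritative definition):

// Rough shape implied by the usage above; the real definition lives in
// internal/manager/persistence and may carry additional fields.
type TaskJob struct {
    Task    Task   // the task row itself
    JobUUID string // UUID of the job this task belongs to, for logging and events
}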
diff --git a/internal/manager/task_state_machine/worker_requeue_test.go b/internal/manager/task_state_machine/worker_requeue_test.go
index 2cef414b..02b16b04 100644
--- a/internal/manager/task_state_machine/worker_requeue_test.go
+++ b/internal/manager/task_state_machine/worker_requeue_test.go
@@ -22,9 +22,12 @@ func TestRequeueActiveTasksOfWorker(t *testing.T) {
     // Mock that the worker has two active tasks. It shouldn't happen, but even
     // when it does, both should be requeued when the worker signs off.
-    task1 := taskWithStatus(api.JobStatusActive, api.TaskStatusActive)
+    task1, job := taskWithStatus(api.JobStatusActive, api.TaskStatusActive)
     task2 := taskOfSameJob(task1, api.TaskStatusActive)
-    workerTasks := []*persistence.Task{task1, task2}
+    workerTasks := []persistence.TaskJob{
+        {Task: *task1, JobUUID: job.UUID},
+        {Task: *task2, JobUUID: job.UUID},
+    }

     task1PrevStatus := task1.Status
     task2PrevStatus := task2.Status
@@ -32,39 +35,50 @@ func TestRequeueActiveTasksOfWorker(t *testing.T) {
     mocks.persist.EXPECT().FetchTasksOfWorkerInStatus(ctx, &worker, api.TaskStatusActive).Return(workerTasks, nil)

     // Expect this re-queueing to end up in the task's log and activity.
-    mocks.persist.EXPECT().SaveTaskActivity(ctx, task1) // TODO: test saved activity value
-    mocks.persist.EXPECT().SaveTaskActivity(ctx, task2) // TODO: test saved activity value
-    mocks.persist.EXPECT().SaveTaskStatus(ctx, task1) // TODO: test saved task status
-    mocks.persist.EXPECT().SaveTaskStatus(ctx, task2) // TODO: test saved task status
     logMsg1 := "task changed status active -> queued"
-    mocks.logStorage.EXPECT().WriteTimestamped(gomock.Any(), task1.Job.UUID, task1.UUID, logMsg1)
-    mocks.logStorage.EXPECT().WriteTimestamped(gomock.Any(), task2.Job.UUID, task2.UUID, logMsg1)
     logMsg2 := "Task was requeued by Manager because worker had to test"
-    mocks.logStorage.EXPECT().WriteTimestamped(gomock.Any(), task1.Job.UUID, task1.UUID, logMsg2)
-    mocks.logStorage.EXPECT().WriteTimestamped(gomock.Any(), task2.Job.UUID, task2.UUID, logMsg2)
+
+    task1WithActivity := *task1
+    task1WithActivity.Activity = logMsg2
+    task2WithActivity := *task2
+    task2WithActivity.Activity = logMsg2
+    task1WithActivityAndStatus := task1WithActivity
+    task1WithActivityAndStatus.Status = api.TaskStatusQueued
+    task2WithActivityAndStatus := task2WithActivity
+    task2WithActivityAndStatus.Status = api.TaskStatusQueued
+    mocks.persist.EXPECT().SaveTaskActivity(ctx, &task1WithActivity)
+    mocks.persist.EXPECT().SaveTaskActivity(ctx, &task2WithActivity)
+    mocks.persist.EXPECT().SaveTaskStatus(ctx, &task1WithActivityAndStatus)
+    mocks.persist.EXPECT().SaveTaskStatus(ctx, &task2WithActivityAndStatus)
+
+    mocks.logStorage.EXPECT().WriteTimestamped(gomock.Any(), job.UUID, task1.UUID, logMsg1)
+    mocks.logStorage.EXPECT().WriteTimestamped(gomock.Any(), job.UUID, task2.UUID, logMsg1)
+
+    mocks.logStorage.EXPECT().WriteTimestamped(gomock.Any(), job.UUID, task1.UUID, logMsg2)
+    mocks.logStorage.EXPECT().WriteTimestamped(gomock.Any(), job.UUID, task2.UUID, logMsg2)

     mocks.broadcaster.EXPECT().BroadcastTaskUpdate(api.EventTaskUpdate{
         Activity: logMsg2,
         Id: task1.UUID,
-        JobId: task1.Job.UUID,
+        JobId: job.UUID,
         Name: task1.Name,
         PreviousStatus: &task1PrevStatus,
         Status: api.TaskStatusQueued,
-        Updated: task1.UpdatedAt,
+        Updated: task1.UpdatedAt.Time,
     })
     mocks.broadcaster.EXPECT().BroadcastTaskUpdate(api.EventTaskUpdate{
         Activity: logMsg2,
         Id: task2.UUID,
-        JobId: task2.Job.UUID,
+        JobId: job.UUID,
         Name: task2.Name,
         PreviousStatus: &task2PrevStatus,
         Status: api.TaskStatusQueued,
-        Updated: task2.UpdatedAt,
+        Updated: task2.UpdatedAt.Time,
     })

+    mocks.expectFetchJobOfTask(task1, job)
+    mocks.expectFetchJobOfTask(task2, job)
+
     err := sm.RequeueActiveTasksOfWorker(ctx, &worker, "worker had to test")
     require.NoError(t, err)
 }
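The TODO-marked "any value" expectations are replaced by expectations on copies of the tasks with the mutations the state machine is supposed to make; gomock wraps plain arguments in its Eq matcher, which compares with reflect.DeepEqual and follows pointers, so the saved Activity and Status are now asserted exactly. A condensed sketch of the pattern, reusing the names from the test above:

// Condensed from TestRequeueActiveTasksOfWorker: the expectation matches on the
// full task value, so the copy spells out exactly what must have been written.
expectedActivity := "Task was requeued by Manager because worker had to test"

task1WithActivity := *task1
task1WithActivity.Activity = expectedActivity

task1WithActivityAndStatus := task1WithActivity
task1WithActivityAndStatus.Status = api.TaskStatusQueued

mocks.persist.EXPECT().SaveTaskActivity(ctx, &task1WithActivity)
mocks.persist.EXPECT().SaveTaskStatus(ctx, &task1WithActivityAndStatus)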
diff --git a/internal/manager/timeout_checker/interfaces.go b/internal/manager/timeout_checker/interfaces.go
index dd51fb4f..a9631114 100644
--- a/internal/manager/timeout_checker/interfaces.go
+++ b/internal/manager/timeout_checker/interfaces.go
@@ -9,7 +9,6 @@ import (
     "github.com/rs/zerolog"
     "projects.blender.org/studio/flamenco/internal/manager/eventbus"
     "projects.blender.org/studio/flamenco/internal/manager/persistence"
-    "projects.blender.org/studio/flamenco/internal/manager/persistence/sqlc"
     "projects.blender.org/studio/flamenco/internal/manager/task_state_machine"
     "projects.blender.org/studio/flamenco/pkg/api"
 )
@@ -18,9 +17,11 @@ import (

 //go:generate go run github.com/golang/mock/mockgen -destination mocks/interfaces_mock.gen.go -package mocks projects.blender.org/studio/flamenco/internal/manager/timeout_checker PersistenceService,TaskStateMachine,LogStorage,ChangeBroadcaster
 type PersistenceService interface {
-    FetchTimedOutTasks(ctx context.Context, untouchedSince time.Time) ([]*persistence.Task, error)
+    FetchTimedOutTasks(ctx context.Context, untouchedSince time.Time) ([]persistence.TimedOutTaskInfo, error)
     FetchTimedOutWorkers(ctx context.Context, lastSeenBefore time.Time) ([]*persistence.Worker, error)
-    SaveWorker(ctx context.Context, w *sqlc.Worker) error
+    FetchWorker(ctx context.Context, workerUUID string) (*persistence.Worker, error)
+    SaveWorker(ctx context.Context, w *persistence.Worker) error
+    FetchJob(ctx context.Context, jobUUID string) (*persistence.Job, error)
 }

 var _ PersistenceService = (*persistence.DB)(nil)
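The timeout checker's PersistenceService no longer imports the sqlc package directly: it gains FetchWorker and FetchJob, takes a *persistence.Worker again for SaveWorker, and receives denormalized persistence.TimedOutTaskInfo rows. The var _ PersistenceService = (*persistence.DB)(nil) line above is the usual compile-time assertion that the concrete persistence layer still satisfies this consumer-defined subset. A sketch of the same pattern (the TaskStateMachine assertion is a hypothetical analogue, not taken from this diff):

// Compile-time checks cost nothing at runtime; they fail the build as soon as
// the concrete type stops implementing the consumer's narrowed interface.
var (
    _ PersistenceService = (*persistence.DB)(nil)
    _ TaskStateMachine   = (*task_state_machine.StateMachine)(nil) // hypothetical analogous check
)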
diff --git a/internal/manager/timeout_checker/mocks/interfaces_mock.gen.go b/internal/manager/timeout_checker/mocks/interfaces_mock.gen.go
index 509af5bd..e3b1d7a6 100644
--- a/internal/manager/timeout_checker/mocks/interfaces_mock.gen.go
+++ b/internal/manager/timeout_checker/mocks/interfaces_mock.gen.go
@@ -11,7 +11,6 @@ import (
     gomock "github.com/golang/mock/gomock"
     zerolog "github.com/rs/zerolog"

-    persistence "projects.blender.org/studio/flamenco/internal/manager/persistence"
     sqlc "projects.blender.org/studio/flamenco/internal/manager/persistence/sqlc"
     api "projects.blender.org/studio/flamenco/pkg/api"
 )
@@ -39,11 +38,26 @@ func (m *MockPersistenceService) EXPECT() *MockPersistenceServiceMockRecorder {
     return m.recorder
 }

+// FetchJob mocks base method.
+func (m *MockPersistenceService) FetchJob(arg0 context.Context, arg1 string) (*sqlc.Job, error) {
+    m.ctrl.T.Helper()
+    ret := m.ctrl.Call(m, "FetchJob", arg0, arg1)
+    ret0, _ := ret[0].(*sqlc.Job)
+    ret1, _ := ret[1].(error)
+    return ret0, ret1
+}
+
+// FetchJob indicates an expected call of FetchJob.
+func (mr *MockPersistenceServiceMockRecorder) FetchJob(arg0, arg1 interface{}) *gomock.Call {
+    mr.mock.ctrl.T.Helper()
+    return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchJob", reflect.TypeOf((*MockPersistenceService)(nil).FetchJob), arg0, arg1)
+}
+
 // FetchTimedOutTasks mocks base method.
-func (m *MockPersistenceService) FetchTimedOutTasks(arg0 context.Context, arg1 time.Time) ([]*persistence.Task, error) {
+func (m *MockPersistenceService) FetchTimedOutTasks(arg0 context.Context, arg1 time.Time) ([]sqlc.FetchTimedOutTasksRow, error) {
     m.ctrl.T.Helper()
     ret := m.ctrl.Call(m, "FetchTimedOutTasks", arg0, arg1)
-    ret0, _ := ret[0].([]*persistence.Task)
+    ret0, _ := ret[0].([]sqlc.FetchTimedOutTasksRow)
     ret1, _ := ret[1].(error)
     return ret0, ret1
 }
@@ -69,6 +83,21 @@ func (mr *MockPersistenceServiceMockRecorder) FetchTimedOutWorkers(arg0, arg1 in
     return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchTimedOutWorkers", reflect.TypeOf((*MockPersistenceService)(nil).FetchTimedOutWorkers), arg0, arg1)
 }

+// FetchWorker mocks base method.
+func (m *MockPersistenceService) FetchWorker(arg0 context.Context, arg1 string) (*sqlc.Worker, error) {
+    m.ctrl.T.Helper()
+    ret := m.ctrl.Call(m, "FetchWorker", arg0, arg1)
+    ret0, _ := ret[0].(*sqlc.Worker)
+    ret1, _ := ret[1].(error)
+    return ret0, ret1
+}
+
+// FetchWorker indicates an expected call of FetchWorker.
+func (mr *MockPersistenceServiceMockRecorder) FetchWorker(arg0, arg1 interface{}) *gomock.Call {
+    mr.mock.ctrl.T.Helper()
+    return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchWorker", reflect.TypeOf((*MockPersistenceService)(nil).FetchWorker), arg0, arg1)
+}
+
 // SaveWorker mocks base method.
 func (m *MockPersistenceService) SaveWorker(arg0 context.Context, arg1 *sqlc.Worker) error {
     m.ctrl.T.Helper()
@@ -121,7 +150,7 @@ func (mr *MockTaskStateMachineMockRecorder) RequeueActiveTasksOfWorker(arg0, arg
 }

 // TaskStatusChange mocks base method.
-func (m *MockTaskStateMachine) TaskStatusChange(arg0 context.Context, arg1 *persistence.Task, arg2 api.TaskStatus) error {
+func (m *MockTaskStateMachine) TaskStatusChange(arg0 context.Context, arg1 *sqlc.Task, arg2 api.TaskStatus) error {
     m.ctrl.T.Helper()
     ret := m.ctrl.Call(m, "TaskStatusChange", arg0, arg1, arg2)
     ret0, _ := ret[0].(error)
diff --git a/internal/manager/timeout_checker/tasks.go b/internal/manager/timeout_checker/tasks.go
index afee2337..b95538a9 100644
--- a/internal/manager/timeout_checker/tasks.go
+++ b/internal/manager/timeout_checker/tasks.go
@@ -21,38 +21,44 @@ func (ttc *TimeoutChecker) checkTasks(ctx context.Context) {
         Logger()
     logger.Trace().Msg("TimeoutChecker: finding active tasks that have not been touched since threshold")

-    tasks, err := ttc.persist.FetchTimedOutTasks(ctx, timeoutThreshold)
+    timeoutTaskInfo, err := ttc.persist.FetchTimedOutTasks(ctx, timeoutThreshold)
     if err != nil {
         log.Error().Err(err).Msg("TimeoutChecker: error fetching timed-out tasks from database")
         return
     }

-    if len(tasks) == 0 {
+    if len(timeoutTaskInfo) == 0 {
         logger.Trace().Msg("TimeoutChecker: no timed-out tasks")
         return
     }

     logger.Debug().
-        Int("numTasks", len(tasks)).
+        Int("numTasks", len(timeoutTaskInfo)).
         Msg("TimeoutChecker: failing all active tasks that have not been touched since threshold")

-    for _, task := range tasks {
-        ttc.timeoutTask(ctx, task)
+    for _, taskInfo := range timeoutTaskInfo {
+        ttc.timeoutTask(ctx, taskInfo)
     }
 }

 // timeoutTask marks a task as 'failed' due to a timeout.
-func (ttc *TimeoutChecker) timeoutTask(ctx context.Context, task *persistence.Task) {
-    workerIdent, logger := ttc.assignedWorker(task)
+func (ttc *TimeoutChecker) timeoutTask(ctx context.Context, taskInfo persistence.TimedOutTaskInfo) {
+    task := taskInfo.Task
+    workerIdent, logger := ttc.assignedWorker(taskInfo)

     task.Activity = fmt.Sprintf("Task timed out on worker %s", workerIdent)
-    err := ttc.taskStateMachine.TaskStatusChange(ctx, task, api.TaskStatusFailed)
+    err := ttc.taskStateMachine.TaskStatusChange(ctx, &task, api.TaskStatusFailed)
     if err != nil {
         logger.Error().Err(err).Msg("TimeoutChecker: error saving timed-out task to database")
     }

-    err = ttc.logStorage.WriteTimestamped(logger, task.Job.UUID, task.UUID,
+    lastTouchedAt := "forever"
+    if task.LastTouchedAt.Valid {
+        lastTouchedAt = task.LastTouchedAt.Time.Format(time.RFC3339)
+    }
+
+    err = ttc.logStorage.WriteTimestamped(logger, taskInfo.JobUUID, task.UUID,
         fmt.Sprintf("Task timed out. It was assigned to worker %s, but untouched since %s",
-            workerIdent, task.LastTouchedAt.Format(time.RFC3339)))
+            workerIdent, lastTouchedAt))
     if err != nil {
         logger.Error().Err(err).Msg("TimeoutChecker: error writing timeout info to the task log")
     }
@@ -60,27 +66,20 @@ func (ttc *TimeoutChecker) timeoutTask(ctx context.Context, task *persistence.Ta

 // assignedWorker returns a description of the worker assigned to this task,
 // and a logger configured for it.
-func (ttc *TimeoutChecker) assignedWorker(task *persistence.Task) (string, zerolog.Logger) {
-    logCtx := log.With().Str("task", task.UUID)
+func (ttc *TimeoutChecker) assignedWorker(taskInfo persistence.TimedOutTaskInfo) (string, zerolog.Logger) {
+    logCtx := log.With().Str("task", taskInfo.Task.UUID)

-    if task.WorkerID == nil {
+    if taskInfo.WorkerUUID == "" {
         logger := logCtx.Logger()
         logger.Warn().Msg("TimeoutChecker: task timed out, but was not assigned to any worker")
         return "-unassigned-", logger
     }

-    if task.Worker == nil {
-        logger := logCtx.Logger()
-        logger.Warn().Uint("workerDBID", *task.WorkerID).
-            Msg("TimeoutChecker: task is assigned to worker that no longer exists")
-        return "-unknown-", logger
-    }
-
     logCtx = logCtx.
-        Str("worker", task.Worker.UUID).
-        Str("workerName", task.Worker.Name)
+        Str("worker", taskInfo.WorkerUUID).
+        Str("workerName", taskInfo.WorkerName)

     logger := logCtx.Logger()
     logger.Warn().Msg("TimeoutChecker: task timed out")
-    return task.Worker.Identifier(), logger
+    return fmt.Sprintf("%s (%s)", taskInfo.WorkerName, taskInfo.WorkerUUID), logger
 }
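timeoutTask and assignedWorker now work from a persistence.TimedOutTaskInfo value that already carries the job UUID and the assigned worker's name and UUID, so the checker no longer needs the Task.Job and Task.Worker pointers, and the separate "-unknown-" branch for a vanished worker disappears. The exact definition is in the persistence package; from its use above and in the tests below it is roughly this sketch:

// Rough shape implied by the usage in tasks.go and tasks_test.go; the real
// definition lives in internal/manager/persistence and may differ in detail.
type TimedOutTaskInfo struct {
    Task       Task   // the timed-out task row, including LastTouchedAt and WorkerID
    JobUUID    string // job the task belongs to, used for task-log storage
    WorkerUUID string // empty when the task was never assigned to a worker
    WorkerName string
}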
diff --git a/internal/manager/timeout_checker/tasks_test.go b/internal/manager/timeout_checker/tasks_test.go
index 0c4747c8..6267838c 100644
--- a/internal/manager/timeout_checker/tasks_test.go
+++ b/internal/manager/timeout_checker/tasks_test.go
@@ -4,6 +4,7 @@ package timeout_checker

 import (
     "context"
+    "database/sql"
     "errors"
     "testing"
     "time"
@@ -109,7 +110,7 @@ func TestTaskTimeout(t *testing.T) {

     lastTime := mocks.clock.Now().UTC().Add(-1 * time.Hour)

-    job := persistence.Job{UUID: "JOB-UUID"}
+    job := persistence.Job{ID: 327, UUID: "JOB-UUID"}
     worker := persistence.Worker{
         UUID: "WORKER-UUID",
         Name: "Tester",
     }
@@ -117,36 +118,35 @@ func TestTaskTimeout(t *testing.T) {
     taskUnassigned := persistence.Task{
         UUID: "TASK-UUID-UNASSIGNED",
-        Job: &job,
-        LastTouchedAt: lastTime,
-    }
-    taskUnknownWorker := persistence.Task{
-        UUID: "TASK-UUID-UNKNOWN",
-        Job: &job,
-        LastTouchedAt: lastTime,
-        WorkerID: ptr(uint(worker.ID)),
+        JobID: job.ID,
+        LastTouchedAt: sql.NullTime{Time: lastTime, Valid: true},
     }
     taskAssigned := persistence.Task{
         UUID: "TASK-UUID-ASSIGNED",
-        Job: &job,
-        LastTouchedAt: lastTime,
-        WorkerID: ptr(uint(worker.ID)),
-        Worker: &worker,
+        JobID: job.ID,
+        LastTouchedAt: sql.NullTime{Time: lastTime, Valid: true},
+        WorkerID: sql.NullInt64{Int64: worker.ID, Valid: true},
     }

     mocks.persist.EXPECT().FetchTimedOutWorkers(mocks.ctx, gomock.Any()).AnyTimes().Return(nil, nil)
+    timedoutTaskInfo := []persistence.TimedOutTaskInfo{
+        {Task: taskUnassigned, JobUUID: job.UUID, WorkerName: "", WorkerUUID: ""},
+        {Task: taskAssigned, JobUUID: job.UUID, WorkerName: worker.Name, WorkerUUID: worker.UUID},
+    }
     mocks.persist.EXPECT().FetchTimedOutTasks(mocks.ctx, gomock.Any()).
-        Return([]*persistence.Task{&taskUnassigned, &taskUnknownWorker, &taskAssigned}, nil)
+        Return(timedoutTaskInfo, nil)

-    mocks.taskStateMachine.EXPECT().TaskStatusChange(mocks.ctx, &taskUnassigned, api.TaskStatusFailed)
-    mocks.taskStateMachine.EXPECT().TaskStatusChange(mocks.ctx, &taskUnknownWorker, api.TaskStatusFailed)
-    mocks.taskStateMachine.EXPECT().TaskStatusChange(mocks.ctx, &taskAssigned, api.TaskStatusFailed)
+    taskUnassignedWithActivity := taskUnassigned
+    taskUnassignedWithActivity.Activity = "Task timed out on worker -unassigned-"
+    taskAssignedWithActivity := taskAssigned
+    taskAssignedWithActivity.Activity = "Task timed out on worker Tester (WORKER-UUID)"
+
+    mocks.taskStateMachine.EXPECT().TaskStatusChange(mocks.ctx, &taskUnassignedWithActivity, api.TaskStatusFailed)
+    mocks.taskStateMachine.EXPECT().TaskStatusChange(mocks.ctx, &taskAssignedWithActivity, api.TaskStatusFailed)

     mocks.logStorage.EXPECT().WriteTimestamped(gomock.Any(), job.UUID, taskUnassigned.UUID,
         "Task timed out. It was assigned to worker -unassigned-, but untouched since 2022-06-09T11:00:00Z")
-    mocks.logStorage.EXPECT().WriteTimestamped(gomock.Any(), job.UUID, taskUnknownWorker.UUID,
-        "Task timed out. It was assigned to worker -unknown-, but untouched since 2022-06-09T11:00:00Z")
     mocks.logStorage.EXPECT().WriteTimestamped(gomock.Any(), job.UUID, taskAssigned.UUID,
         "Task timed out. It was assigned to worker Tester (WORKER-UUID), but untouched since 2022-06-09T11:00:00Z")
diff --git a/internal/manager/timeout_checker/workers_test.go b/internal/manager/timeout_checker/workers_test.go
index 79f43512..57a30cf8 100644
--- a/internal/manager/timeout_checker/workers_test.go
+++ b/internal/manager/timeout_checker/workers_test.go
@@ -38,7 +38,7 @@ func TestWorkerTimeout(t *testing.T) {
     }

     // No tasks are timing out in this test.
-    mocks.persist.EXPECT().FetchTimedOutTasks(mocks.ctx, gomock.Any()).Return([]*persistence.Task{}, nil)
+    mocks.persist.EXPECT().FetchTimedOutTasks(mocks.ctx, gomock.Any()).Return([]persistence.TimedOutTaskInfo{}, nil)
     mocks.persist.EXPECT().FetchTimedOutWorkers(mocks.ctx, gomock.Any()).
         Return([]*persistence.Worker{&worker}, nil)
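With the move to sqlc-style models, nullable columns surface in the fixtures as database/sql null types: LastTouchedAt becomes sql.NullTime and WorkerID becomes sql.NullInt64, and callers must check Valid before using the wrapped value, as timeoutTask does with its "forever" fallback. A minimal, self-contained illustration of that pattern, independent of the Flamenco fixtures above:

package main

import (
    "database/sql"
    "fmt"
    "time"
)

func main() {
    // Unassigned task: WorkerID is NULL and LastTouchedAt was never set.
    var workerID sql.NullInt64
    var lastTouched sql.NullTime

    touched := "forever"
    if lastTouched.Valid {
        touched = lastTouched.Time.Format(time.RFC3339)
    }
    fmt.Println(workerID.Valid, touched) // false forever

    // Assigned task: wrap the concrete values and mark them valid.
    workerID = sql.NullInt64{Int64: 327, Valid: true}
    lastTouched = sql.NullTime{Time: time.Now().Add(-time.Hour), Valid: true}
    fmt.Println(workerID.Int64, lastTouched.Valid) // 327 true
}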