diff --git a/internal/manager/api_impl/api_impl.go b/internal/manager/api_impl/api_impl.go index d1925449..a03fc3c7 100644 --- a/internal/manager/api_impl/api_impl.go +++ b/internal/manager/api_impl/api_impl.go @@ -39,6 +39,7 @@ type Flamenco struct { type PersistenceService interface { StoreAuthoredJob(ctx context.Context, authoredJob job_compilers.AuthoredJob) error FetchJob(ctx context.Context, jobID string) (*persistence.Job, error) + FetchTask(ctx context.Context, taskID string) (*persistence.Task, error) CreateWorker(ctx context.Context, w *persistence.Worker) error FetchWorker(ctx context.Context, uuid string) (*persistence.Worker, error) diff --git a/internal/manager/api_impl/jobs.go b/internal/manager/api_impl/jobs.go index 03e36c84..e96e97a5 100644 --- a/internal/manager/api_impl/jobs.go +++ b/internal/manager/api_impl/jobs.go @@ -69,6 +69,9 @@ func (f *Flamenco) SubmitJob(e echo.Context) error { logger = logger.With().Str("job_id", authoredJob.JobID).Logger() + // TODO: check whether this job should be queued immediately or start paused. + authoredJob.Status = api.JobStatusQueued + if err := f.persist.StoreAuthoredJob(ctx, *authoredJob); err != nil { logger.Error().Err(err).Msg("error persisting job in database") return sendAPIError(e, http.StatusInternalServerError, "error persisting job in database") @@ -84,8 +87,7 @@ func (f *Flamenco) SubmitJob(e echo.Context) error { func (f *Flamenco) FetchJob(e echo.Context, jobId string) error { // TODO: move this into some middleware. - logger := log.With(). - Str("ip", e.RealIP()). + logger := requestLogger(e).With(). Str("job_id", jobId). 
Logger() @@ -121,3 +123,47 @@ func (f *Flamenco) FetchJob(e echo.Context, jobId string) error { return e.JSON(http.StatusOK, apiJob) } + +// TaskUpdate receives a task update from a Worker. It only verifies that the +// task exists and is assigned to the requesting Worker; see the TODO below. +func (f *Flamenco) TaskUpdate(e echo.Context, taskID string) error { + logger := requestLogger(e) + + if _, err := uuid.Parse(taskID); err != nil { + logger.Debug().Msg("invalid task ID received") + return sendAPIError(e, http.StatusBadRequest, "task ID not valid") + } + logger = logger.With().Str("taskID", taskID).Logger() + + // Fetch the task, to see if this worker is even allowed to send us updates. + ctx := e.Request().Context() + dbTask, err := f.persist.FetchTask(ctx, taskID) + if err != nil { + logger.Warn().Err(err).Msg("cannot fetch task") + return sendAPIError(e, http.StatusNotFound, fmt.Sprintf("task %+v not found", taskID)) + } + + // requestWorker() can return nil (ScheduleTask guards against that too), so + // check before dereferencing the worker below. + worker := requestWorker(e) + if worker == nil { + logger.Warn().Msg("task update sent by non-worker") + return sendAPIError(e, http.StatusBadRequest, "not authenticated as Worker") + } + if dbTask.Worker == nil { + logger.Warn(). + Str("requestingWorkerID", worker.UUID). + Msg("worker trying to update task that's not assigned to any worker") + return sendAPIError(e, http.StatusConflict, fmt.Sprintf("task %+v is not assigned to any worker, so also not to you", taskID)) + } + if dbTask.Worker.UUID != worker.UUID { + logger.Warn(). + Str("requestingWorkerID", worker.UUID). + Str("assignedWorkerID", dbTask.Worker.UUID). + Msg("worker trying to update task that's assigned to another worker") + return sendAPIError(e, http.StatusConflict, fmt.Sprintf("task %+v is not assigned to you, but to worker %v", taskID, dbTask.Worker.UUID)) + } + + // TODO: actually handle the task update. 
+ return e.String(http.StatusNoContent, "") +} diff --git a/internal/manager/api_impl/workers.go b/internal/manager/api_impl/workers.go index b2748ed6..1b8bd163 100644 --- a/internal/manager/api_impl/workers.go +++ b/internal/manager/api_impl/workers.go @@ -174,19 +174,42 @@ func (f *Flamenco) ScheduleTask(e echo.Context) error { logger := requestLogger(e) logger.Info().Msg("worker requesting task") - return e.JSON(http.StatusOK, &api.AssignedTask{ - Uuid: uuid.New().String(), - Commands: []api.Command{ - {Name: "echo", Settings: echo.Map{"payload": "Simon says \"Shaders!\""}}, - {Name: "blender", Settings: echo.Map{"blender_cmd": "/shared/bin/blender"}}, - }, - Job: uuid.New().String(), - JobPriority: 50, - JobType: "blender-render", - Name: "A1032", - Priority: 50, - Status: "active", - TaskType: "blender-render", - User: "", - }) + // Figure out which worker is requesting a task: + worker := requestWorker(e) + if worker == nil { + logger.Warn().Msg("task requested by non-worker") + return sendAPIError(e, http.StatusBadRequest, "not authenticated as Worker") + } + + // Get a task to execute: + dbTask, err := f.persist.ScheduleTask(worker) + if err != nil { + logger.Warn().Err(err).Msg("error scheduling task for worker") + return sendAPIError(e, http.StatusInternalServerError, "internal error finding a task for you: %v", err) + } + if dbTask == nil { + return e.String(http.StatusNoContent, "") + } + + // Convert database objects to API objects: + apiCommands := []api.Command{} + for _, cmd := range dbTask.Commands { + apiCommands = append(apiCommands, api.Command{ + Name: cmd.Type, + Settings: cmd.Parameters, + }) + } + apiTask := api.AssignedTask{ + Uuid: dbTask.UUID, + Commands: apiCommands, + Job: dbTask.Job.UUID, + JobPriority: dbTask.Job.Priority, + JobType: dbTask.Job.JobType, + Name: dbTask.Name, + Priority: dbTask.Priority, + Status: api.TaskStatus(dbTask.Status), + TaskType: dbTask.Type, + } + + return e.JSON(http.StatusOK, apiTask) } diff --git 
a/internal/manager/persistence/db_migration.go b/internal/manager/persistence/db_migration.go index 5793baad..bed83b99 100644 --- a/internal/manager/persistence/db_migration.go +++ b/internal/manager/persistence/db_migration.go @@ -22,6 +22,8 @@ package persistence import ( "fmt" + + "github.com/rs/zerolog/log" ) func (db *DB) migrate() error { @@ -29,5 +31,6 @@ func (db *DB) migrate() error { if err != nil { return fmt.Errorf("failed to automigrate database: %v", err) } + log.Debug().Msg("database automigration succesful") return nil } diff --git a/internal/manager/persistence/jobs.go b/internal/manager/persistence/jobs.go index 0d32e7d0..cbb98966 100644 --- a/internal/manager/persistence/jobs.go +++ b/internal/manager/persistence/jobs.go @@ -59,7 +59,9 @@ type Task struct { Priority int `gorm:"type:smallint;not null"` Status string `gorm:"type:varchar(16);not null"` - // TODO: include info about which worker is/was working on this. + // Which worker is/was working on this. + WorkerID *uint + Worker *Worker `gorm:"foreignkey:WorkerID;references:ID;constraint:OnDelete:CASCADE"` // Dependencies are tasks that need to be completed before this one can run. 
Dependencies []*Task `gorm:"many2many:task_dependencies;constraint:OnDelete:CASCADE"` @@ -199,3 +201,18 @@ func (db *DB) SaveJobStatus(ctx context.Context, j *Job) error { } return nil } + +// FetchTask returns the task with the given UUID, or the Gorm error when it +// cannot be fetched. The task's Worker association is loaded as well, so that +// callers can see which worker the task is assigned to. +func (db *DB) FetchTask(ctx context.Context, taskUUID string) (*Task, error) { + dbTask := Task{} + findResult := db.gormDB.WithContext(ctx). + Preload("Worker"). // Gorm does not load associations by itself. + First(&dbTask, "uuid = ?", taskUUID) + if findResult.Error != nil { + return nil, findResult.Error + } + + return &dbTask, nil +} diff --git a/internal/manager/persistence/task_scheduler.go b/internal/manager/persistence/task_scheduler.go index 23b453cd..261c9597 100644 --- a/internal/manager/persistence/task_scheduler.go +++ b/internal/manager/persistence/task_scheduler.go @@ -22,6 +22,7 @@ package persistence import ( "errors" + "fmt" "github.com/rs/zerolog/log" "gitlab.com/blender/flamenco-ng-poc/pkg/api" @@ -51,31 +52,57 @@ func (db *DB) findTaskForWorker(w *Worker) (*Task, error) { logger.Debug().Msg("finding task for worker") task := Task{} - gormDB := db.GormDB() - tx := gormDB.Debug(). - Model(&task). - Joins("left join jobs on tasks.job_id = jobs.id"). - Joins("left join task_dependencies on tasks.id = task_dependencies.task_id"). - Joins("left join tasks as tdeps on tdeps.id = task_dependencies.dependency_id"). - Where("tasks.status in ?", schedulableTaskStatuses). // Schedulable task statuses - Where("tdeps.status in ? or tdeps.status is NULL", completedTaskStatuses). // Dependencies completed - Where("jobs.status in ?", schedulableJobStatuses). // Schedulable job statuses - // TODO: Supported task types - // TODO: Non-blacklisted - Order("jobs.priority desc"). // Highest job priority - Order("priority desc"). // Highest task priority - Limit(1). - Preload("Job"). - First(&task) - if tx.Error != nil { - if errors.Is(tx.Error, gorm.ErrRecordNotFound) { + // Run two queries in one transaction: + // 1. find task, and + // 2. assign the task to the worker. + err := db.gormDB.Transaction(func(tx *gorm.DB) error { + findTaskResult := tx.Debug(). 
+ Model(&task). + Joins("left join jobs on tasks.job_id = jobs.id"). + Joins("left join task_dependencies on tasks.id = task_dependencies.task_id"). + Joins("left join tasks as tdeps on tdeps.id = task_dependencies.dependency_id"). + Where("tasks.status in ?", schedulableTaskStatuses). // Schedulable task statuses + Where("tdeps.status in ? or tdeps.status is NULL", completedTaskStatuses). // Dependencies completed + Where("jobs.status in ?", schedulableJobStatuses). // Schedulable job statuses + // TODO: Supported task types + // TODO: assigned to this worker or not assigned at all + // TODO: Non-blacklisted + Order("jobs.priority desc"). // Highest job priority + Order("priority desc"). // Highest task priority + Limit(1). + Preload("Job"). + First(&task) + + if findTaskResult.Error != nil { + return findTaskResult.Error + } + + // Found a task, now assign it to the requesting worker. + // Without the Select() call, Gorm will try and also store task.Job in the jobs database, which is not what we want. + if err := tx.Debug().Model(&task).Select("worker_id").Updates(Task{WorkerID: &w.ID}).Error; err != nil { + logger.Warn(). + Str("taskID", task.UUID). + Err(err). + Msg("error assigning task to worker") + return fmt.Errorf("error assigning task to worker: %v", err) + } + + return nil + }) + + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { logger.Debug().Msg("no task for worker") return nil, nil } - logger.Error().Err(tx.Error).Msg("error finding task for worker") - return nil, tx.Error + logger.Error().Err(err).Msg("error finding task for worker") + return nil, fmt.Errorf("error finding task for worker: %w", err) } + logger.Info(). + Str("taskID", task.UUID). 
+ Msg("assigned task to worker") + return &task, nil } diff --git a/internal/manager/persistence/task_scheduler_test.go b/internal/manager/persistence/task_scheduler_test.go index a853b58e..2158304d 100644 --- a/internal/manager/persistence/task_scheduler_test.go +++ b/internal/manager/persistence/task_scheduler_test.go @@ -33,7 +33,7 @@ import ( func TestNoTasks(t *testing.T) { db := CreateTestDB(t) - w := linuxWorker() + w := linuxWorker(t, db) task, err := db.ScheduleTask(&w) assert.Nil(t, task) @@ -42,7 +42,7 @@ func TestNoTasks(t *testing.T) { func TestOneJobOneTask(t *testing.T) { db := CreateTestDB(t) - w := linuxWorker() + w := linuxWorker(t, db) authTask := authorTestTask("the task", "blender-render") atj := authorTestJob("b6a1d859-122f-4791-8b78-b943329a9989", "simple-blender-render", authTask) @@ -54,11 +54,22 @@ func TestOneJobOneTask(t *testing.T) { t.Fatal("task is nil") } assert.Equal(t, job.ID, task.JobID) + + // Test that the task has been assigned to this worker. + dbTask, err := db.FetchTask(context.Background(), authTask.UUID) + assert.NoError(t, err) + if dbTask == nil { + t.Fatal("task cannot be fetched from database") + } + if dbTask.WorkerID == nil { + t.Fatal("no worker assigned to task") + } + assert.Equal(t, w.ID, *dbTask.WorkerID) } func TestOneJobThreeTasksByPrio(t *testing.T) { db := CreateTestDB(t) - w := linuxWorker() + w := linuxWorker(t, db) att1 := authorTestTask("1 low-prio task", "blender-render") att2 := authorTestTask("2 high-prio task", "render-preview") @@ -87,7 +98,7 @@ func TestOneJobThreeTasksByPrio(t *testing.T) { func TestOneJobThreeTasksByDependencies(t *testing.T) { db := CreateTestDB(t) - w := linuxWorker() + w := linuxWorker(t, db) att1 := authorTestTask("1 low-prio task", "blender-render") att2 := authorTestTask("2 high-prio task", "render-preview") @@ -111,7 +122,7 @@ func TestOneJobThreeTasksByDependencies(t *testing.T) { func TestTwoJobsThreeTasks(t *testing.T) { db := CreateTestDB(t) - w := linuxWorker() + w := 
linuxWorker(t, db) att1_1 := authorTestTask("1.1 low-prio task", "blender-render") att1_2 := authorTestTask("1.2 high-prio task", "render-preview") @@ -201,7 +212,7 @@ func authorTestTask(name, taskType string, dependencies ...*job_compilers.Author return task } -func linuxWorker() Worker { +func linuxWorker(t *testing.T, db *DB) Worker { w := Worker{ UUID: "b13b8322-3e96-41c3-940a-3d581008a5f8", Name: "Linux", @@ -209,5 +220,12 @@ func linuxWorker() Worker { Status: api.WorkerStatusAwake, SupportedTaskTypes: "blender,ffmpeg,file-management", } + + err := db.gormDB.Save(&w).Error + if err != nil { + t.Logf("cannot save Linux worker: %v", err) + t.FailNow() + } + return w }