Manager: implement persistence layer interface for task status machine
Implement the functions used by the task status machine in the DB persistence layer.
This commit is contained in:
parent
7279f2e35f
commit
9a5bbb4131
@ -219,3 +219,80 @@ func (db *DB) SaveTask(ctx context.Context, t *Task) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) JobHasTasksInStatus(ctx context.Context, job *Job, taskStatus api.TaskStatus) (bool, error) {
|
||||
var numTasksInStatus int64
|
||||
tx := db.gormDB.Model(&Task{}).
|
||||
Where("job_id", job.ID).
|
||||
Where("status", taskStatus).
|
||||
Count(&numTasksInStatus)
|
||||
if tx.Error != nil {
|
||||
return false, tx.Error
|
||||
}
|
||||
return numTasksInStatus > 0, nil
|
||||
}
|
||||
|
||||
func (db *DB) CountTasksOfJobInStatus(ctx context.Context, job *Job, taskStatus api.TaskStatus) (numInStatus, numTotal int, err error) {
|
||||
type Result struct {
|
||||
Status api.TaskStatus
|
||||
NumTasks int
|
||||
}
|
||||
var results []Result
|
||||
|
||||
tx := db.gormDB.Debug().Model(&Task{}).
|
||||
Select("status, count(*) as num_tasks").
|
||||
Where("job_id", job.ID).
|
||||
Group("status").
|
||||
Scan(&results)
|
||||
|
||||
if tx.Error != nil {
|
||||
return 0, 0, fmt.Errorf("count tasks of job %s in status %q: %w", job.UUID, taskStatus, tx.Error)
|
||||
}
|
||||
|
||||
for _, result := range results {
|
||||
if result.Status == taskStatus {
|
||||
numInStatus += result.NumTasks
|
||||
}
|
||||
numTotal += result.NumTasks
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// UpdateJobsTaskStatuses updates the status & activity of the tasks of `job`.
|
||||
func (db *DB) UpdateJobsTaskStatuses(ctx context.Context, job *Job,
|
||||
taskStatus api.TaskStatus, activity string) error {
|
||||
|
||||
if taskStatus == "" {
|
||||
return errors.New("empty status not allowed")
|
||||
}
|
||||
|
||||
tx := db.gormDB.Model(Task{}).
|
||||
Where("job_Id = ?", job.ID).
|
||||
Updates(Task{Status: taskStatus, Activity: activity})
|
||||
|
||||
if tx.Error != nil {
|
||||
return tx.Error
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateJobsTaskStatusesConditional updates the status & activity of the tasks of `job`,
|
||||
// limited to those tasks with status in `statusesToUpdate`.
|
||||
func (db *DB) UpdateJobsTaskStatusesConditional(ctx context.Context, job *Job,
|
||||
statusesToUpdate []api.TaskStatus, taskStatus api.TaskStatus, activity string) error {
|
||||
|
||||
if taskStatus == "" {
|
||||
return errors.New("empty status not allowed")
|
||||
}
|
||||
|
||||
tx := db.gormDB.Debug().Model(Task{}).
|
||||
Where("job_Id = ?", job.ID).
|
||||
Where("status in ?", statusesToUpdate).
|
||||
Updates(Task{Status: taskStatus, Activity: activity})
|
||||
|
||||
if tx.Error != nil {
|
||||
return tx.Error
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -37,6 +37,157 @@ func TestStoreAuthoredJob(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
job := createTestAuthoredJobWithTasks()
|
||||
err := db.StoreAuthoredJob(ctx, job)
|
||||
assert.NoError(t, err)
|
||||
|
||||
fetchedJob, err := db.FetchJob(ctx, job.JobID)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, fetchedJob)
|
||||
|
||||
// Test contents of fetched job
|
||||
assert.Equal(t, job.JobID, fetchedJob.UUID)
|
||||
assert.Equal(t, job.Name, fetchedJob.Name)
|
||||
assert.Equal(t, job.JobType, fetchedJob.JobType)
|
||||
assert.Equal(t, job.Priority, fetchedJob.Priority)
|
||||
assert.Equal(t, api.JobStatusUnderConstruction, fetchedJob.Status)
|
||||
assert.EqualValues(t, map[string]interface{}(job.Settings), fetchedJob.Settings)
|
||||
assert.EqualValues(t, map[string]string(job.Metadata), fetchedJob.Metadata)
|
||||
|
||||
// Fetch tasks of job.
|
||||
var dbJob Job
|
||||
tx := db.gormDB.Where(&Job{UUID: job.JobID}).Find(&dbJob)
|
||||
assert.NoError(t, tx.Error)
|
||||
var tasks []Task
|
||||
tx = db.gormDB.Where("job_id = ?", dbJob.ID).Find(&tasks)
|
||||
assert.NoError(t, tx.Error)
|
||||
|
||||
if len(tasks) != 3 {
|
||||
t.Fatalf("expected 3 tasks, got %d", len(tasks))
|
||||
}
|
||||
|
||||
// TODO: test task contents.
|
||||
assert.Equal(t, api.TaskStatusQueued, tasks[0].Status)
|
||||
assert.Equal(t, api.TaskStatusQueued, tasks[1].Status)
|
||||
assert.Equal(t, api.TaskStatusQueued, tasks[2].Status)
|
||||
}
|
||||
|
||||
func TestJobHasTasksInStatus(t *testing.T) {
|
||||
ctx, db, job, _ := jobTasksTestFixtures(t)
|
||||
|
||||
hasTasks, err := db.JobHasTasksInStatus(ctx, job, api.TaskStatusQueued)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, hasTasks, "expected freshly-created job to have queued tasks")
|
||||
|
||||
hasTasks, err = db.JobHasTasksInStatus(ctx, job, api.TaskStatusActive)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, hasTasks, "expected freshly-created job to have no active tasks")
|
||||
}
|
||||
|
||||
func TestCountTasksOfJobInStatus(t *testing.T) {
|
||||
ctx, db, job, authoredJob := jobTasksTestFixtures(t)
|
||||
|
||||
numQueued, numTotal, err := db.CountTasksOfJobInStatus(ctx, job, api.TaskStatusQueued)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, numQueued)
|
||||
assert.Equal(t, 3, numTotal)
|
||||
|
||||
// Make one task failed.
|
||||
task, err := db.FetchTask(ctx, authoredJob.Tasks[0].UUID)
|
||||
assert.NoError(t, err)
|
||||
task.Status = api.TaskStatusFailed
|
||||
assert.NoError(t, db.SaveTask(ctx, task))
|
||||
|
||||
numQueued, numTotal, err = db.CountTasksOfJobInStatus(ctx, job, api.TaskStatusQueued)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, numQueued)
|
||||
assert.Equal(t, 3, numTotal)
|
||||
|
||||
numFailed, numTotal, err := db.CountTasksOfJobInStatus(ctx, job, api.TaskStatusFailed)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, numFailed)
|
||||
assert.Equal(t, 3, numTotal)
|
||||
|
||||
numActive, numTotal, err := db.CountTasksOfJobInStatus(ctx, job, api.TaskStatusActive)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, numActive)
|
||||
assert.Equal(t, 3, numTotal)
|
||||
}
|
||||
|
||||
func TestUpdateJobsTaskStatuses(t *testing.T) {
|
||||
ctx, db, job, authoredJob := jobTasksTestFixtures(t)
|
||||
|
||||
err := db.UpdateJobsTaskStatuses(ctx, job, api.TaskStatusSoftFailed, "testing æctivity")
|
||||
assert.NoError(t, err)
|
||||
|
||||
numSoftFailed, numTotal, err := db.CountTasksOfJobInStatus(ctx, job, api.TaskStatusSoftFailed)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, numSoftFailed, "all tasks should have had their status changed")
|
||||
assert.Equal(t, 3, numTotal)
|
||||
|
||||
task, err := db.FetchTask(ctx, authoredJob.Tasks[0].UUID)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "testing æctivity", task.Activity)
|
||||
|
||||
// Empty status should be rejected.
|
||||
err = db.UpdateJobsTaskStatuses(ctx, job, "", "testing empty status")
|
||||
assert.Error(t, err)
|
||||
|
||||
numEmpty, _, err := db.CountTasksOfJobInStatus(ctx, job, "")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, numEmpty, "tasks should not have their status changed")
|
||||
|
||||
numSoftFailed, _, err = db.CountTasksOfJobInStatus(ctx, job, api.TaskStatusSoftFailed)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, numSoftFailed, "all tasks should still be soft-failed")
|
||||
}
|
||||
|
||||
func TestUpdateJobsTaskStatusesConditional(t *testing.T) {
|
||||
ctx, db, job, authoredJob := jobTasksTestFixtures(t)
|
||||
|
||||
getTask := func(taskIndex int) *Task {
|
||||
task, err := db.FetchTask(ctx, authoredJob.Tasks[taskIndex].UUID)
|
||||
if err != nil {
|
||||
t.Fatalf("Fetching task %d: %v", taskIndex, err)
|
||||
}
|
||||
return task
|
||||
}
|
||||
|
||||
setTaskStatus := func(taskIndex int, taskStatus api.TaskStatus) {
|
||||
task := getTask(taskIndex)
|
||||
task.Status = taskStatus
|
||||
if err := db.SaveTask(ctx, task); err != nil {
|
||||
t.Fatalf("Setting task %d to status %s: %v", taskIndex, taskStatus, err)
|
||||
}
|
||||
}
|
||||
|
||||
setTaskStatus(0, api.TaskStatusFailed)
|
||||
setTaskStatus(1, api.TaskStatusCompleted)
|
||||
setTaskStatus(2, api.TaskStatusActive)
|
||||
|
||||
err := db.UpdateJobsTaskStatusesConditional(ctx, job,
|
||||
[]api.TaskStatus{api.TaskStatusFailed, api.TaskStatusActive},
|
||||
api.TaskStatusCancelRequested, "some activity")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Task statuses should have updated for tasks 0 and 2.
|
||||
assert.Equal(t, api.TaskStatusCancelRequested, getTask(0).Status)
|
||||
assert.Equal(t, api.TaskStatusCompleted, getTask(1).Status)
|
||||
assert.Equal(t, api.TaskStatusCancelRequested, getTask(2).Status)
|
||||
|
||||
err = db.UpdateJobsTaskStatusesConditional(ctx, job,
|
||||
[]api.TaskStatus{api.TaskStatusFailed, api.TaskStatusActive},
|
||||
"", "empty task status should be disallowed")
|
||||
assert.Error(t, err)
|
||||
|
||||
// Task statuses should remain unchanged.
|
||||
assert.Equal(t, api.TaskStatusCancelRequested, getTask(0).Status)
|
||||
assert.Equal(t, api.TaskStatusCompleted, getTask(1).Status)
|
||||
assert.Equal(t, api.TaskStatusCancelRequested, getTask(2).Status)
|
||||
|
||||
}
|
||||
|
||||
func createTestAuthoredJobWithTasks() job_compilers.AuthoredJob {
|
||||
task1 := job_compilers.AuthoredTask{
|
||||
Name: "render-1-3",
|
||||
Type: "blender",
|
||||
@ -93,36 +244,28 @@ func TestStoreAuthoredJob(t *testing.T) {
|
||||
Tasks: []job_compilers.AuthoredTask{task1, task2, task3},
|
||||
}
|
||||
|
||||
err := db.StoreAuthoredJob(ctx, job)
|
||||
assert.NoError(t, err)
|
||||
|
||||
fetchedJob, err := db.FetchJob(ctx, job.JobID)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, fetchedJob)
|
||||
|
||||
// Test contents of fetched job
|
||||
assert.Equal(t, job.JobID, fetchedJob.UUID)
|
||||
assert.Equal(t, job.Name, fetchedJob.Name)
|
||||
assert.Equal(t, job.JobType, fetchedJob.JobType)
|
||||
assert.Equal(t, job.Priority, fetchedJob.Priority)
|
||||
assert.Equal(t, api.JobStatusUnderConstruction, fetchedJob.Status)
|
||||
assert.EqualValues(t, map[string]interface{}(job.Settings), fetchedJob.Settings)
|
||||
assert.EqualValues(t, map[string]string(job.Metadata), fetchedJob.Metadata)
|
||||
|
||||
// Fetch tasks of job.
|
||||
var dbJob Job
|
||||
tx := db.gormDB.Where(&Job{UUID: job.JobID}).Find(&dbJob)
|
||||
assert.NoError(t, tx.Error)
|
||||
var tasks []Task
|
||||
tx = db.gormDB.Where("job_id = ?", dbJob.ID).Find(&tasks)
|
||||
assert.NoError(t, tx.Error)
|
||||
|
||||
if len(tasks) != 3 {
|
||||
t.Fatalf("expected 3 tasks, got %d", len(tasks))
|
||||
return job
|
||||
}
|
||||
|
||||
// TODO: test task contents.
|
||||
assert.Equal(t, api.TaskStatusQueued, tasks[0].Status)
|
||||
assert.Equal(t, api.TaskStatusQueued, tasks[1].Status)
|
||||
assert.Equal(t, api.TaskStatusQueued, tasks[2].Status)
|
||||
func jobTasksTestFixtures(t *testing.T) (context.Context, *DB, *Job, job_compilers.AuthoredJob) {
|
||||
db := CreateTestDB(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
authoredJob := createTestAuthoredJobWithTasks()
|
||||
err := db.StoreAuthoredJob(ctx, authoredJob)
|
||||
if err != nil {
|
||||
t.Fatalf("error storing authored job in DB: %v", err)
|
||||
}
|
||||
|
||||
dbJob, err := db.FetchJob(ctx, authoredJob.JobID)
|
||||
if err != nil {
|
||||
t.Fatalf("error fetching job from DB: %v", err)
|
||||
}
|
||||
if dbJob == nil {
|
||||
t.Fatalf("nil job obtained from DB but with no error!")
|
||||
}
|
||||
|
||||
return ctx, db, dbJob, authoredJob
|
||||
}
|
||||
|
@ -42,7 +42,7 @@ type StateMachine struct {
|
||||
// Generate mock implementations of these interfaces.
|
||||
//go:generate go run github.com/golang/mock/mockgen -destination mocks/interfaces_mock.gen.go -package mocks gitlab.com/blender/flamenco-ng-poc/internal/manager/task_state_machine PersistenceService
|
||||
|
||||
type PersistenceService interface { // Subset of persistence.DB
|
||||
type PersistenceService interface {
|
||||
SaveTask(ctx context.Context, task *persistence.Task) error
|
||||
SaveJobStatus(ctx context.Context, j *persistence.Job) error
|
||||
|
||||
@ -59,6 +59,9 @@ type PersistenceService interface { // Subset of persistence.DB
|
||||
statusesToUpdate []api.TaskStatus, taskStatus api.TaskStatus, activity string) error
|
||||
}
|
||||
|
||||
// PersistenceService should be a subset of persistence.DB
|
||||
var _ PersistenceService = (*persistence.DB)(nil)
|
||||
|
||||
func NewStateMachine(persist PersistenceService) *StateMachine {
|
||||
return &StateMachine{
|
||||
persist: persist,
|
||||
|
Loading…
x
Reference in New Issue
Block a user