From d43947898dbe70a0a21adde893ccf29081db1863 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sybren=20A=2E=20St=C3=BCvel?= Date: Thu, 26 Sep 2024 21:20:01 +0200 Subject: [PATCH] Manager: replace final job-related queries with sqlc Ref: #104305 --- .../persistence/jobs_blocklist_test.go | 15 ++- .../manager/persistence/jobs_query_test.go | 6 +- internal/manager/persistence/jobs_test.go | 79 +++++++-------- .../manager/persistence/sqlc/query_jobs.sql | 20 ++++ .../persistence/sqlc/query_jobs.sql.go | 98 +++++++++++++++++++ 5 files changed, 164 insertions(+), 54 deletions(-) diff --git a/internal/manager/persistence/jobs_blocklist_test.go b/internal/manager/persistence/jobs_blocklist_test.go index da9e69bb..8894e681 100644 --- a/internal/manager/persistence/jobs_blocklist_test.go +++ b/internal/manager/persistence/jobs_blocklist_test.go @@ -14,19 +14,19 @@ func TestAddWorkerToJobBlocklist(t *testing.T) { defer close() worker := createWorker(ctx, t, db) + queries := db.queries() { // Add a worker to the block list. err := db.AddWorkerToJobBlocklist(ctx, job, worker, "blender") require.NoError(t, err) - list := []JobBlock{} - tx := db.gormDB.Model(&JobBlock{}).Scan(&list) - require.NoError(t, tx.Error) + list, err := queries.Test_FetchJobBlocklist(ctx) + require.NoError(t, err) if assert.Len(t, list, 1) { entry := list[0] - assert.Equal(t, entry.JobID, job.ID) - assert.Equal(t, entry.WorkerID, worker.ID) + assert.Equal(t, entry.JobID, int64(job.ID)) + assert.Equal(t, entry.WorkerID, int64(worker.ID)) assert.Equal(t, entry.TaskType, "blender") } } @@ -36,9 +36,8 @@ func TestAddWorkerToJobBlocklist(t *testing.T) { err := db.AddWorkerToJobBlocklist(ctx, job, worker, "blender") require.NoError(t, err) - list := []JobBlock{} - tx := db.gormDB.Model(&JobBlock{}).Scan(&list) - require.NoError(t, tx.Error) + list, err := queries.Test_FetchJobBlocklist(ctx) + require.NoError(t, err) assert.Len(t, list, 1, "No new entry should have been created") } } diff --git a/internal/manager/persistence/jobs_query_test.go b/internal/manager/persistence/jobs_query_test.go index 64e399ba..b2b18040 100644 --- a/internal/manager/persistence/jobs_query_test.go +++ b/internal/manager/persistence/jobs_query_test.go @@ -19,6 +19,7 @@ import ( func TestQueryJobTaskSummaries(t *testing.T) { ctx, close, db, job, authoredJob := jobTasksTestFixtures(t) defer close() + queries := db.queries() expectTaskUUIDs := map[string]bool{} for _, task := range authoredJob.Tasks { @@ -37,9 +38,8 @@ func TestQueryJobTaskSummaries(t *testing.T) { persistAuthoredJob(t, ctx, db, otherAuthoredJob) // Sanity check for the above code, there should be 6 tasks overall, 3 per job. - var numTasks int64 - tx := db.gormDB.Model(&Task{}).Count(&numTasks) - require.NoError(t, tx.Error) + numTasks, err := queries.Test_CountTasks(ctx) + require.NoError(t, err) assert.Equal(t, int64(6), numTasks) // Get the task summaries of a particular job. diff --git a/internal/manager/persistence/jobs_test.go b/internal/manager/persistence/jobs_test.go index cd220759..389b9554 100644 --- a/internal/manager/persistence/jobs_test.go +++ b/internal/manager/persistence/jobs_test.go @@ -21,6 +21,7 @@ import ( func TestStoreAuthoredJob(t *testing.T) { ctx, cancel, db := persistenceTestFixtures(1 * time.Second) defer cancel() + queries := db.queries() job := createTestAuthoredJobWithTasks() err := db.StoreAuthoredJob(ctx, job) @@ -40,22 +41,18 @@ func TestStoreAuthoredJob(t *testing.T) { assert.EqualValues(t, map[string]string(job.Metadata), fetchedJob.Metadata) assert.Equal(t, "", fetchedJob.Storage.ShamanCheckoutID) - // Fetch tasks of job. - var dbJob Job - tx := db.gormDB.Where(&Job{UUID: job.JobID}).Find(&dbJob) - require.NoError(t, tx.Error) - var tasks []Task - tx = db.gormDB.Where("job_id = ?", dbJob.ID).Find(&tasks) - require.NoError(t, tx.Error) + // Fetch result of job. + result, err := queries.FetchTasksOfJob(ctx, int64(fetchedJob.ID)) + require.NoError(t, err) - if len(tasks) != 3 { - t.Fatalf("expected 3 tasks, got %d", len(tasks)) + if len(result) != 3 { + t.Fatalf("expected 3 tasks, got %d", len(result)) } // TODO: test task contents. - assert.Equal(t, api.TaskStatusQueued, tasks[0].Status) - assert.Equal(t, api.TaskStatusQueued, tasks[1].Status) - assert.Equal(t, api.TaskStatusQueued, tasks[2].Status) + assert.Equal(t, api.TaskStatusQueued, api.TaskStatus(result[0].Task.Status)) + assert.Equal(t, api.TaskStatusQueued, api.TaskStatus(result[1].Task.Status)) + assert.Equal(t, api.TaskStatusQueued, api.TaskStatus(result[2].Task.Status)) } func TestStoreAuthoredJobWithShamanCheckoutID(t *testing.T) { @@ -180,6 +177,7 @@ func TestSaveJobPriority(t *testing.T) { func TestDeleteJob(t *testing.T) { ctx, cancel, db := persistenceTestFixtures(1 * time.Second) defer cancel() + queries := db.queries() authJob := createTestAuthoredJobWithTasks() authJob.Name = "Job to delete" @@ -199,16 +197,14 @@ func TestDeleteJob(t *testing.T) { assert.ErrorIs(t, err, ErrJobNotFound, "deleted jobs should not be found") // Test that the job is really gone. - var numJobs int64 - tx := db.gormDB.Model(&Job{}).Count(&numJobs) - require.NoError(t, tx.Error) + numJobs, err := queries.Test_CountJobs(ctx) + require.NoError(t, err) assert.Equal(t, int64(1), numJobs, "the job should have been deleted, and the other one should still be there") // Test that the tasks are gone too. - var numTasks int64 - tx = db.gormDB.Model(&Task{}).Count(&numTasks) - require.NoError(t, tx.Error) + numTasks, err := queries.Test_CountTasks(ctx) + require.NoError(t, err) assert.Equal(t, otherJobTaskCount, numTasks, "tasks should have been deleted along with their job, and the other job's tasks should still be there") @@ -218,9 +214,9 @@ func TestDeleteJob(t *testing.T) { assert.Equal(t, otherJob.Name, dbOtherJob.Name) // Test that all the remaining tasks belong to that particular job. - tx = db.gormDB.Model(&Task{}).Where(Task{JobID: dbOtherJob.ID}).Count(&numTasks) - require.NoError(t, tx.Error) - assert.Equal(t, otherJobTaskCount, numTasks, + tasksOfJob, err := queries.FetchTasksOfJob(ctx, int64(dbOtherJob.ID)) + require.NoError(t, err) + assert.Equal(t, len(tasksOfJob), int(numTasks), "all remaining tasks should belong to the other job") } @@ -738,15 +734,13 @@ func TestAddWorkerToTaskFailedList(t *testing.T) { // Deleting the task should also delete the failures. require.NoError(t, db.DeleteJob(ctx, authoredJob.JobID)) - var num int64 - tx := db.gormDB.Model(&TaskFailure{}).Count(&num) - require.NoError(t, tx.Error) - assert.Zero(t, num) + assert.Zero(t, countTaskFailures(ctx, db)) } func TestClearFailureListOfTask(t *testing.T) { ctx, close, db, _, authoredJob := jobTasksTestFixtures(t) defer close() + queries := db.queries() task1, _ := db.FetchTask(ctx, authoredJob.Tasks[1].UUID) task2, _ := db.FetchTask(ctx, authoredJob.Tasks[2].UUID) @@ -769,18 +763,18 @@ func TestClearFailureListOfTask(t *testing.T) { // Clearing should just update this one task. require.NoError(t, db.ClearFailureListOfTask(ctx, task1)) - var failures = []TaskFailure{} - tx := db.gormDB.Model(&TaskFailure{}).Scan(&failures) - require.NoError(t, tx.Error) + failures, err := queries.Test_FetchTaskFailures(ctx) + require.NoError(t, err) if assert.Len(t, failures, 1) { - assert.Equal(t, task2.ID, failures[0].TaskID) - assert.Equal(t, worker1.ID, failures[0].WorkerID) + assert.Equal(t, int64(task2.ID), failures[0].TaskID) + assert.Equal(t, int64(worker1.ID), failures[0].WorkerID) } } func TestClearFailureListOfJob(t *testing.T) { ctx, close, db, dbJob1, authoredJob1 := jobTasksTestFixtures(t) defer close() + queries := db.queries() // Construct a cloned version of the job. authoredJob2 := duplicateJobAndTasks(authoredJob1) @@ -801,18 +795,17 @@ func TestClearFailureListOfJob(t *testing.T) { _, _ = db.AddWorkerToTaskFailedList(ctx, task2_1, worker2) // Sanity check: there should be 5 failures registered now. - assert.Equal(t, 5, countTaskFailures(db)) + assert.Equal(t, 5, countTaskFailures(ctx, db)) // Clearing should be limited to the given job. require.NoError(t, db.ClearFailureListOfJob(ctx, dbJob1)) - var failures = []TaskFailure{} - tx := db.gormDB.Model(&TaskFailure{}).Scan(&failures) - require.NoError(t, tx.Error) + failures, err := queries.Test_FetchTaskFailures(ctx) + require.NoError(t, err) if assert.Len(t, failures, 2) { - assert.Equal(t, task2_1.ID, failures[0].TaskID) - assert.Equal(t, worker1.ID, failures[0].WorkerID) - assert.Equal(t, task2_1.ID, failures[1].TaskID) - assert.Equal(t, worker2.ID, failures[1].WorkerID) + assert.Equal(t, int64(task2_1.ID), failures[0].TaskID) + assert.Equal(t, int64(worker1.ID), failures[0].WorkerID) + assert.Equal(t, int64(task2_1.ID), failures[1].TaskID) + assert.Equal(t, int64(worker2.ID), failures[1].WorkerID) } } @@ -1059,11 +1052,11 @@ func createWorkerFrom(ctx context.Context, t *testing.T, db *DB, worker Worker) return dbWorker } -func countTaskFailures(db *DB) int { - var numFailures int64 - tx := db.gormDB.Model(&TaskFailure{}).Count(&numFailures) - if tx.Error != nil { - panic(tx.Error) +func countTaskFailures(ctx context.Context, db *DB) int { + queries := db.queries() + numFailures, err := queries.Test_CountTaskFailures(ctx) + if err != nil { + panic(err) } if numFailures > math.MaxInt { diff --git a/internal/manager/persistence/sqlc/query_jobs.sql b/internal/manager/persistence/sqlc/query_jobs.sql index 3928b737..63153939 100644 --- a/internal/manager/persistence/sqlc/query_jobs.sql +++ b/internal/manager/persistence/sqlc/query_jobs.sql @@ -274,6 +274,10 @@ WHERE AND job_blocks.worker_id in (SELECT workers.id FROM workers WHERE workers.uuid=@workeruuid) AND job_blocks.task_type = @task_type; +-- name: Test_FetchJobBlocklist :many +-- Fetch all job block list entries. Used only in unit tests. +SELECT * FROM job_blocks; + -- name: WorkersLeftToRun :many SELECT workers.uuid FROM workers WHERE id NOT IN ( @@ -324,3 +328,19 @@ FROM tasks WHERE status = @task_status AND last_touched_at <= @untouched_since; + +-- name: Test_CountJobs :one +-- Count the number of jobs in the database. Only used in unit tests. +SELECT count(*) AS count FROM jobs; + +-- name: Test_CountTasks :one +-- Count the number of tasks in the database. Only used in unit tests. +SELECT count(*) AS count FROM tasks; + +-- name: Test_CountTaskFailures :one +-- Count the number of task failures in the database. Only used in unit tests. +SELECT count(*) AS count FROM task_failures; + +-- name: Test_FetchTaskFailures :many +-- Fetch all task failures in the database. Only used in unit tests. +SELECT * FROM task_failures; diff --git a/internal/manager/persistence/sqlc/query_jobs.sql.go b/internal/manager/persistence/sqlc/query_jobs.sql.go index 80c98666..9ae0ab52 100644 --- a/internal/manager/persistence/sqlc/query_jobs.sql.go +++ b/internal/manager/persistence/sqlc/query_jobs.sql.go @@ -1235,6 +1235,104 @@ func (q *Queries) TaskTouchedByWorker(ctx context.Context, arg TaskTouchedByWork return err } +const test_CountJobs = `-- name: Test_CountJobs :one +SELECT count(*) AS count FROM jobs +` + +// Count the number of jobs in the database. Only used in unit tests. +func (q *Queries) Test_CountJobs(ctx context.Context) (int64, error) { + row := q.db.QueryRowContext(ctx, test_CountJobs) + var count int64 + err := row.Scan(&count) + return count, err +} + +const test_CountTaskFailures = `-- name: Test_CountTaskFailures :one +SELECT count(*) AS count FROM task_failures +` + +// Count the number of task failures in the database. Only used in unit tests. +func (q *Queries) Test_CountTaskFailures(ctx context.Context) (int64, error) { + row := q.db.QueryRowContext(ctx, test_CountTaskFailures) + var count int64 + err := row.Scan(&count) + return count, err +} + +const test_CountTasks = `-- name: Test_CountTasks :one +SELECT count(*) AS count FROM tasks +` + +// Count the number of tasks in the database. Only used in unit tests. +func (q *Queries) Test_CountTasks(ctx context.Context) (int64, error) { + row := q.db.QueryRowContext(ctx, test_CountTasks) + var count int64 + err := row.Scan(&count) + return count, err +} + +const test_FetchJobBlocklist = `-- name: Test_FetchJobBlocklist :many +SELECT id, created_at, job_id, worker_id, task_type FROM job_blocks +` + +// Fetch all job block list entries. Used only in unit tests. +func (q *Queries) Test_FetchJobBlocklist(ctx context.Context) ([]JobBlock, error) { + rows, err := q.db.QueryContext(ctx, test_FetchJobBlocklist) + if err != nil { + return nil, err + } + defer rows.Close() + var items []JobBlock + for rows.Next() { + var i JobBlock + if err := rows.Scan( + &i.ID, + &i.CreatedAt, + &i.JobID, + &i.WorkerID, + &i.TaskType, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const test_FetchTaskFailures = `-- name: Test_FetchTaskFailures :many +SELECT created_at, task_id, worker_id FROM task_failures +` + +// Fetch all task failures in the database. Only used in unit tests. +func (q *Queries) Test_FetchTaskFailures(ctx context.Context) ([]TaskFailure, error) { + rows, err := q.db.QueryContext(ctx, test_FetchTaskFailures) + if err != nil { + return nil, err + } + defer rows.Close() + var items []TaskFailure + for rows.Next() { + var i TaskFailure + if err := rows.Scan(&i.CreatedAt, &i.TaskID, &i.WorkerID); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const updateJobsTaskStatuses = `-- name: UpdateJobsTaskStatuses :exec UPDATE tasks SET updated_at = ?1,