diff --git a/internal/manager/persistence/db.go b/internal/manager/persistence/db.go index a515d6cc..2a0f45c8 100644 --- a/internal/manager/persistence/db.go +++ b/internal/manager/persistence/db.go @@ -179,14 +179,19 @@ func (db *DB) Close() error { // queries returns the SQLC Queries struct, connected to this database. // It is intended that all GORM queries will be migrated to use this interface // instead. -func (db *DB) queries() (*sqlc.Queries, error) { +// +// Note that this function does not return an error. Instead it just panics when +// it cannot obtain the low-level GORM database interface. I have no idea when +// this will ever fail, so I'm opting to simplify the use of this function +// instead. +func (db *DB) queries() *sqlc.Queries { sqldb, err := db.gormDB.DB() if err != nil { - return nil, fmt.Errorf("could not get low-level database driver: %w", err) + panic(fmt.Sprintf("could not get low-level database driver: %v", err)) } loggingWrapper := LoggingDBConn{sqldb} - return sqlc.New(&loggingWrapper), nil + return sqlc.New(&loggingWrapper) } type queriesTX struct { @@ -205,7 +210,7 @@ type queriesTX struct { func (db *DB) queriesWithTX() (*queriesTX, error) { sqldb, err := db.gormDB.DB() if err != nil { - return nil, fmt.Errorf("could not get low-level database driver: %w", err) + panic(fmt.Sprintf("could not get low-level database driver: %v", err)) } tx, err := sqldb.Begin() diff --git a/internal/manager/persistence/integrity.go b/internal/manager/persistence/integrity.go index 04f60560..710b05b7 100644 --- a/internal/manager/persistence/integrity.go +++ b/internal/manager/persistence/integrity.go @@ -89,12 +89,7 @@ func (db *DB) performIntegrityCheck(ctx context.Context) (ok bool) { // // See https: //www.sqlite.org/pragma.html#pragma_integrity_check func (db *DB) pragmaIntegrityCheck(ctx context.Context) (ok bool) { - queries, err := db.queries() - if err != nil { - log.Error().Err(err).Msg("database: could not obtain queries object") - return false - } - + queries := db.queries() issues, err := queries.PragmaIntegrityCheck(ctx) if err != nil { log.Error().Err(err).Msg("database: error checking integrity") diff --git a/internal/manager/persistence/jobs.go b/internal/manager/persistence/jobs.go index 7941da4a..31a9b8cc 100644 --- a/internal/manager/persistence/jobs.go +++ b/internal/manager/persistence/jobs.go @@ -342,10 +342,7 @@ func (db *DB) storeAuthoredJobTaks( // FetchJob fetches a single job, without fetching its tasks. func (db *DB) FetchJob(ctx context.Context, jobUUID string) (*Job, error) { - queries, err := db.queries() - if err != nil { - return nil, err - } + queries := db.queries() sqlcJob, err := queries.FetchJob(ctx, jobUUID) switch { @@ -376,10 +373,7 @@ func (db *DB) FetchJob(ctx context.Context, jobUUID string) (*Job, error) { // FetchJobShamanCheckoutID fetches the job's Shaman Checkout ID. func (db *DB) FetchJobShamanCheckoutID(ctx context.Context, jobUUID string) (string, error) { - queries, err := db.queries() - if err != nil { - return "", err - } + queries := db.queries() checkoutID, err := queries.FetchJobShamanCheckoutID(ctx, jobUUID) switch { @@ -403,10 +397,7 @@ func (db *DB) DeleteJob(ctx context.Context, jobUUID string) error { return ErrDeletingWithoutFK } - queries, err := db.queries() - if err != nil { - return err - } + queries := db.queries() if err := queries.DeleteJob(ctx, jobUUID); err != nil { return jobError(err, "deleting job") @@ -416,10 +407,7 @@ func (db *DB) DeleteJob(ctx context.Context, jobUUID string) error { // RequestJobDeletion sets the job's "DeletionRequestedAt" field to "now". func (db *DB) RequestJobDeletion(ctx context.Context, j *Job) error { - queries, err := db.queries() - if err != nil { - return err - } + queries := db.queries() // Update the given job itself, so we don't have to re-fetch it from the database. j.DeleteRequestedAt = db.now() @@ -442,10 +430,7 @@ func (db *DB) RequestJobDeletion(ctx context.Context, j *Job) error { // RequestJobMassDeletion sets multiple job's "DeletionRequestedAt" field to "now". // The list of affected job UUIDs is returned. func (db *DB) RequestJobMassDeletion(ctx context.Context, lastUpdatedMax time.Time) ([]string, error) { - queries, err := db.queries() - if err != nil { - return nil, err - } + queries := db.queries() // In order to be able to report which jobs were affected, first fetch the // list of jobs, then update them. @@ -473,10 +458,7 @@ func (db *DB) RequestJobMassDeletion(ctx context.Context, lastUpdatedMax time.Ti } func (db *DB) FetchJobsDeletionRequested(ctx context.Context) ([]string, error) { - queries, err := db.queries() - if err != nil { - return nil, err - } + queries := db.queries() uuids, err := queries.FetchJobsDeletionRequested(ctx) if err != nil { @@ -486,10 +468,7 @@ func (db *DB) FetchJobsDeletionRequested(ctx context.Context) ([]string, error) } func (db *DB) FetchJobsInStatus(ctx context.Context, jobStatuses ...api.JobStatus) ([]*Job, error) { - queries, err := db.queries() - if err != nil { - return nil, err - } + queries := db.queries() statuses := []string{} for _, status := range jobStatuses { @@ -515,10 +494,7 @@ func (db *DB) FetchJobsInStatus(ctx context.Context, jobStatuses ...api.JobStatu // SaveJobStatus saves the job's Status and Activity fields. func (db *DB) SaveJobStatus(ctx context.Context, j *Job) error { - queries, err := db.queries() - if err != nil { - return err - } + queries := db.queries() params := sqlc.SaveJobStatusParams{ Now: db.now(), @@ -527,7 +503,7 @@ func (db *DB) SaveJobStatus(ctx context.Context, j *Job) error { Activity: j.Activity, } - err = queries.SaveJobStatus(ctx, params) + err := queries.SaveJobStatus(ctx, params) if err != nil { return jobError(err, "saving job status") } @@ -536,10 +512,7 @@ func (db *DB) SaveJobStatus(ctx context.Context, j *Job) error { // SaveJobPriority saves the job's Priority field. func (db *DB) SaveJobPriority(ctx context.Context, j *Job) error { - queries, err := db.queries() - if err != nil { - return err - } + queries := db.queries() params := sqlc.SaveJobPriorityParams{ Now: db.now(), @@ -547,7 +520,7 @@ func (db *DB) SaveJobPriority(ctx context.Context, j *Job) error { Priority: int64(j.Priority), } - err = queries.SaveJobPriority(ctx, params) + err := queries.SaveJobPriority(ctx, params) if err != nil { return jobError(err, "saving job priority") } @@ -558,17 +531,14 @@ func (db *DB) SaveJobPriority(ctx context.Context, j *Job) error { // NOTE: this function does NOT update the job's `UpdatedAt` field. This is // necessary for `cmd/shaman-checkout-id-setter` to do its work quietly. func (db *DB) SaveJobStorageInfo(ctx context.Context, j *Job) error { - queries, err := db.queries() - if err != nil { - return err - } + queries := db.queries() params := sqlc.SaveJobStorageInfoParams{ ID: int64(j.ID), StorageShamanCheckoutID: j.Storage.ShamanCheckoutID, } - err = queries.SaveJobStorageInfo(ctx, params) + err := queries.SaveJobStorageInfo(ctx, params) if err != nil { return jobError(err, "saving job storage") } @@ -576,10 +546,7 @@ func (db *DB) SaveJobStorageInfo(ctx context.Context, j *Job) error { } func (db *DB) FetchTask(ctx context.Context, taskUUID string) (*Task, error) { - queries, err := db.queries() - if err != nil { - return nil, err - } + queries := db.queries() taskRow, err := queries.FetchTask(ctx, taskUUID) if err != nil { @@ -644,10 +611,7 @@ func convertSqlTaskWithJobAndWorker( // FetchTaskJobUUID fetches the job UUID of the given task. func (db *DB) FetchTaskJobUUID(ctx context.Context, taskUUID string) (string, error) { - queries, err := db.queries() - if err != nil { - return "", err - } + queries := db.queries() jobUUID, err := queries.FetchTaskJobUUID(ctx, taskUUID) if err != nil { @@ -666,10 +630,7 @@ func (db *DB) SaveTask(ctx context.Context, t *Task) error { panic(fmt.Errorf("cannot use this function to insert a task")) } - queries, err := db.queries() - if err != nil { - return err - } + queries := db.queries() commandsJSON, err := json.Marshal(t.Commands) if err != nil { @@ -713,12 +674,9 @@ func (db *DB) SaveTask(ctx context.Context, t *Task) error { } func (db *DB) SaveTaskStatus(ctx context.Context, t *Task) error { - queries, err := db.queries() - if err != nil { - return err - } + queries := db.queries() - err = queries.UpdateTaskStatus(ctx, sqlc.UpdateTaskStatusParams{ + err := queries.UpdateTaskStatus(ctx, sqlc.UpdateTaskStatusParams{ UpdatedAt: db.now(), Status: string(t.Status), ID: int64(t.ID), @@ -730,12 +688,9 @@ func (db *DB) SaveTaskStatus(ctx context.Context, t *Task) error { } func (db *DB) SaveTaskActivity(ctx context.Context, t *Task) error { - queries, err := db.queries() - if err != nil { - return err - } + queries := db.queries() - err = queries.UpdateTaskActivity(ctx, sqlc.UpdateTaskActivityParams{ + err := queries.UpdateTaskActivity(ctx, sqlc.UpdateTaskActivityParams{ UpdatedAt: db.now(), Activity: t.Activity, ID: int64(t.ID), @@ -750,12 +705,9 @@ func (db *DB) SaveTaskActivity(ctx context.Context, t *Task) error { // This function is only used by unit tests. During normal operation, Flamenco // uses the code in task_scheduler.go to assign tasks to workers. func (db *DB) TaskAssignToWorker(ctx context.Context, t *Task, w *Worker) error { - queries, err := db.queries() - if err != nil { - return err - } + queries := db.queries() - err = queries.TaskAssignToWorker(ctx, sqlc.TaskAssignToWorkerParams{ + err := queries.TaskAssignToWorker(ctx, sqlc.TaskAssignToWorkerParams{ UpdatedAt: db.now(), WorkerID: sql.NullInt64{ Int64: int64(w.ID), @@ -775,10 +727,7 @@ func (db *DB) TaskAssignToWorker(ctx context.Context, t *Task, w *Worker) error } func (db *DB) FetchTasksOfWorkerInStatus(ctx context.Context, worker *Worker, taskStatus api.TaskStatus) ([]*Task, error) { - queries, err := db.queries() - if err != nil { - return nil, err - } + queries := db.queries() rows, err := queries.FetchTasksOfWorkerInStatus(ctx, sqlc.FetchTasksOfWorkerInStatusParams{ WorkerID: sql.NullInt64{ @@ -823,10 +772,7 @@ func (db *DB) FetchTasksOfWorkerInStatus(ctx context.Context, worker *Worker, ta } func (db *DB) FetchTasksOfWorkerInStatusOfJob(ctx context.Context, worker *Worker, taskStatus api.TaskStatus, job *Job) ([]*Task, error) { - queries, err := db.queries() - if err != nil { - return nil, err - } + queries := db.queries() rows, err := queries.FetchTasksOfWorkerInStatusOfJob(ctx, sqlc.FetchTasksOfWorkerInStatusOfJobParams{ WorkerID: sql.NullInt64{ @@ -856,10 +802,7 @@ func (db *DB) FetchTasksOfWorkerInStatusOfJob(ctx context.Context, worker *Worke } func (db *DB) JobHasTasksInStatus(ctx context.Context, job *Job, taskStatus api.TaskStatus) (bool, error) { - queries, err := db.queries() - if err != nil { - return false, err - } + queries := db.queries() count, err := queries.JobCountTasksInStatus(ctx, sqlc.JobCountTasksInStatusParams{ JobID: int64(job.ID), @@ -880,10 +823,7 @@ func (db *DB) CountTasksOfJobInStatus( job *Job, taskStatuses ...api.TaskStatus, ) (numInStatus, numTotal int, err error) { - queries, err := db.queries() - if err != nil { - return 0, 0, err - } + queries := db.queries() results, err := queries.JobCountTaskStatuses(ctx, int64(job.ID)) if err != nil { @@ -909,10 +849,7 @@ func (db *DB) CountTasksOfJobInStatus( // FetchTaskIDsOfJob returns all tasks of the given job. func (db *DB) FetchTasksOfJob(ctx context.Context, job *Job) ([]*Task, error) { - queries, err := db.queries() - if err != nil { - return nil, err - } + queries := db.queries() rows, err := queries.FetchTasksOfJob(ctx, int64(job.ID)) if err != nil { @@ -933,10 +870,7 @@ func (db *DB) FetchTasksOfJob(ctx context.Context, job *Job) ([]*Task, error) { // FetchTasksOfJobInStatus returns those tasks of the given job that have any of the given statuses. func (db *DB) FetchTasksOfJobInStatus(ctx context.Context, job *Job, taskStatuses ...api.TaskStatus) ([]*Task, error) { - queries, err := db.queries() - if err != nil { - return nil, err - } + queries := db.queries() rows, err := queries.FetchTasksOfJobInStatus(ctx, sqlc.FetchTasksOfJobInStatusParams{ JobID: int64(job.ID), @@ -966,12 +900,9 @@ func (db *DB) UpdateJobsTaskStatuses(ctx context.Context, job *Job, return taskError(nil, "empty status not allowed") } - queries, err := db.queries() - if err != nil { - return err - } + queries := db.queries() - err = queries.UpdateJobsTaskStatuses(ctx, sqlc.UpdateJobsTaskStatusesParams{ + err := queries.UpdateJobsTaskStatuses(ctx, sqlc.UpdateJobsTaskStatusesParams{ UpdatedAt: db.now(), Status: string(taskStatus), Activity: activity, @@ -993,12 +924,9 @@ func (db *DB) UpdateJobsTaskStatusesConditional(ctx context.Context, job *Job, return taskError(nil, "empty status not allowed") } - queries, err := db.queries() - if err != nil { - return err - } + queries := db.queries() - err = queries.UpdateJobsTaskStatusesConditional(ctx, sqlc.UpdateJobsTaskStatusesConditionalParams{ + err := queries.UpdateJobsTaskStatusesConditional(ctx, sqlc.UpdateJobsTaskStatusesConditionalParams{ UpdatedAt: db.now(), Status: string(taskStatus), Activity: activity, @@ -1014,13 +942,10 @@ func (db *DB) UpdateJobsTaskStatusesConditional(ctx context.Context, job *Job, // TaskTouchedByWorker marks the task as 'touched' by a worker. This is used for timeout detection. func (db *DB) TaskTouchedByWorker(ctx context.Context, t *Task) error { - queries, err := db.queries() - if err != nil { - return err - } + queries := db.queries() now := db.now() - err = queries.TaskTouchedByWorker(ctx, sqlc.TaskTouchedByWorkerParams{ + err := queries.TaskTouchedByWorker(ctx, sqlc.TaskTouchedByWorkerParams{ UpdatedAt: now, LastTouchedAt: now, ID: int64(t.ID), @@ -1044,10 +969,7 @@ func (db *DB) TaskTouchedByWorker(ctx context.Context, t *Task) error { // // Returns the new number of workers that failed this task. func (db *DB) AddWorkerToTaskFailedList(ctx context.Context, t *Task, w *Worker) (numFailed int, err error) { - queries, err := db.queries() - if err != nil { - return 0, err - } + queries := db.queries() err = queries.AddWorkerToTaskFailedList(ctx, sqlc.AddWorkerToTaskFailedListParams{ CreatedAt: db.now().Time, @@ -1074,10 +996,7 @@ func (db *DB) AddWorkerToTaskFailedList(ctx context.Context, t *Task, w *Worker) // ClearFailureListOfTask clears the list of workers that failed this task. func (db *DB) ClearFailureListOfTask(ctx context.Context, t *Task) error { - queries, err := db.queries() - if err != nil { - return err - } + queries := db.queries() return queries.ClearFailureListOfTask(ctx, int64(t.ID)) } @@ -1085,19 +1004,13 @@ func (db *DB) ClearFailureListOfTask(ctx context.Context, t *Task) error { // ClearFailureListOfJob en-mass, for all tasks of this job, clears the list of // workers that failed those tasks. func (db *DB) ClearFailureListOfJob(ctx context.Context, j *Job) error { - queries, err := db.queries() - if err != nil { - return err - } + queries := db.queries() return queries.ClearFailureListOfJob(ctx, int64(j.ID)) } func (db *DB) FetchTaskFailureList(ctx context.Context, t *Task) ([]*Worker, error) { - queries, err := db.queries() - if err != nil { - return nil, err - } + queries := db.queries() failureList, err := queries.FetchTaskFailureList(ctx, int64(t.ID)) if err != nil { diff --git a/internal/manager/persistence/jobs_blocklist.go b/internal/manager/persistence/jobs_blocklist.go index 1121c523..40403689 100644 --- a/internal/manager/persistence/jobs_blocklist.go +++ b/internal/manager/persistence/jobs_blocklist.go @@ -38,10 +38,7 @@ func (db *DB) AddWorkerToJobBlocklist(ctx context.Context, job *Job, worker *Wor panic("Cannot add worker to job blocklist with empty task type") } - queries, err := db.queries() - if err != nil { - return err - } + queries := db.queries() return queries.AddWorkerToJobBlocklist(ctx, sqlc.AddWorkerToJobBlocklistParams{ CreatedAt: db.now().Time, @@ -54,10 +51,7 @@ func (db *DB) AddWorkerToJobBlocklist(ctx context.Context, job *Job, worker *Wor // FetchJobBlocklist fetches the blocklist for the given job. // Workers are fetched too, and embedded in the returned list. func (db *DB) FetchJobBlocklist(ctx context.Context, jobUUID string) ([]JobBlock, error) { - queries, err := db.queries() - if err != nil { - return nil, err - } + queries := db.queries() rows, err := queries.FetchJobBlocklist(ctx, jobUUID) if err != nil { @@ -81,18 +75,12 @@ func (db *DB) FetchJobBlocklist(ctx context.Context, jobUUID string) ([]JobBlock // ClearJobBlocklist removes the entire blocklist of this job. func (db *DB) ClearJobBlocklist(ctx context.Context, job *Job) error { - queries, err := db.queries() - if err != nil { - return err - } + queries := db.queries() return queries.ClearJobBlocklist(ctx, job.UUID) } func (db *DB) RemoveFromJobBlocklist(ctx context.Context, jobUUID, workerUUID, taskType string) error { - queries, err := db.queries() - if err != nil { - return err - } + queries := db.queries() return queries.RemoveFromJobBlocklist(ctx, sqlc.RemoveFromJobBlocklistParams{ JobUUID: jobUUID, WorkerUUID: workerUUID, diff --git a/internal/manager/persistence/last_rendered.go b/internal/manager/persistence/last_rendered.go index 19478aaa..a5ec02aa 100644 --- a/internal/manager/persistence/last_rendered.go +++ b/internal/manager/persistence/last_rendered.go @@ -21,10 +21,7 @@ type LastRendered struct { // SetLastRendered sets this job as the one with the most recent rendered image. func (db *DB) SetLastRendered(ctx context.Context, j *Job) error { - queries, err := db.queries() - if err != nil { - return err - } + queries := db.queries() now := db.now() return queries.SetLastRendered(ctx, sqlc.SetLastRenderedParams{ @@ -36,10 +33,7 @@ func (db *DB) SetLastRendered(ctx context.Context, j *Job) error { // GetLastRendered returns the UUID of the job with the most recent rendered image. func (db *DB) GetLastRenderedJobUUID(ctx context.Context) (string, error) { - queries, err := db.queries() - if err != nil { - return "", err - } + queries := db.queries() jobUUID, err := queries.GetLastRenderedJobUUID(ctx) if errors.Is(err, sql.ErrNoRows) { diff --git a/internal/manager/persistence/workers.go b/internal/manager/persistence/workers.go index 3fd1f7e0..5621087d 100644 --- a/internal/manager/persistence/workers.go +++ b/internal/manager/persistence/workers.go @@ -69,10 +69,7 @@ func (w *Worker) StatusChangeClear() { } func (db *DB) CreateWorker(ctx context.Context, w *Worker) error { - queries, err := db.queries() - if err != nil { - return err - } + queries := db.queries() now := db.now().Time workerID, err := queries.CreateWorker(ctx, sqlc.CreateWorkerParams{ @@ -117,10 +114,7 @@ func (db *DB) CreateWorker(ctx context.Context, w *Worker) error { } func (db *DB) FetchWorker(ctx context.Context, uuid string) (*Worker, error) { - queries, err := db.queries() - if err != nil { - return nil, err - } + queries := db.queries() worker, err := queries.FetchWorker(ctx, uuid) if err != nil { @@ -153,10 +147,7 @@ func (db *DB) DeleteWorker(ctx context.Context, uuid string) error { return ErrDeletingWithoutFK } - queries, err := db.queries() - if err != nil { - return err - } + queries := db.queries() rowsAffected, err := queries.SoftDeleteWorker(ctx, sqlc.SoftDeleteWorkerParams{ DeletedAt: db.now(), @@ -172,10 +163,7 @@ func (db *DB) DeleteWorker(ctx context.Context, uuid string) error { } func (db *DB) FetchWorkers(ctx context.Context) ([]*Worker, error) { - queries, err := db.queries() - if err != nil { - return nil, err - } + queries := db.queries() workers, err := queries.FetchWorkers(ctx) if err != nil { @@ -192,10 +180,7 @@ func (db *DB) FetchWorkers(ctx context.Context) ([]*Worker, error) { // FetchWorkerTask returns the most recent task assigned to the given Worker. func (db *DB) FetchWorkerTask(ctx context.Context, worker *Worker) (*Task, error) { - queries, err := db.queries() - if err != nil { - return nil, err - } + queries := db.queries() // Convert the WorkerID to a NullInt64. As task.worker_id can be NULL, this is // what sqlc expects us to pass in. @@ -238,12 +223,9 @@ func (db *DB) FetchWorkerTask(ctx context.Context, worker *Worker) (*Task, error } func (db *DB) SaveWorkerStatus(ctx context.Context, w *Worker) error { - queries, err := db.queries() - if err != nil { - return err - } + queries := db.queries() - err = queries.SaveWorkerStatus(ctx, sqlc.SaveWorkerStatusParams{ + err := queries.SaveWorkerStatus(ctx, sqlc.SaveWorkerStatusParams{ UpdatedAt: db.now(), Status: string(w.Status), StatusRequested: string(w.StatusRequested), @@ -262,12 +244,9 @@ func (db *DB) SaveWorker(ctx context.Context, w *Worker) error { return db.CreateWorker(ctx, w) } - queries, err := db.queries() - if err != nil { - return err - } + queries := db.queries() - err = queries.SaveWorker(ctx, sqlc.SaveWorkerParams{ + err := queries.SaveWorker(ctx, sqlc.SaveWorkerParams{ UpdatedAt: db.now(), UUID: w.UUID, Secret: w.Secret, @@ -291,13 +270,10 @@ func (db *DB) SaveWorker(ctx context.Context, w *Worker) error { // WorkerSeen marks the worker as 'seen' by this Manager. This is used for timeout detection. func (db *DB) WorkerSeen(ctx context.Context, w *Worker) error { - queries, err := db.queries() - if err != nil { - return err - } + queries := db.queries() now := db.now() - err = queries.WorkerSeen(ctx, sqlc.WorkerSeenParams{ + err := queries.WorkerSeen(ctx, sqlc.WorkerSeenParams{ UpdatedAt: now, LastSeenAt: now, ID: int64(w.ID), @@ -315,10 +291,7 @@ func (db *DB) SummarizeWorkerStatuses(ctx context.Context) (WorkerStatusCount, e logger := log.Ctx(ctx) logger.Debug().Msg("database: summarizing worker statuses") - queries, err := db.queries() - if err != nil { - return nil, err - } + queries := db.queries() rows, err := queries.SummarizeWorkerStatuses(ctx) if err != nil {