diff --git a/internal/manager/persistence/errors.go b/internal/manager/persistence/errors.go new file mode 100644 index 00000000..72fefcd7 --- /dev/null +++ b/internal/manager/persistence/errors.go @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: GPL-3.0-or-later +package persistence + +import ( + "errors" + "fmt" + + "gorm.io/gorm" +) + +var ( + ErrJobNotFound = PersistenceError{Message: "job not found", Err: gorm.ErrRecordNotFound} + ErrTaskNotFound = PersistenceError{Message: "task not found", Err: gorm.ErrRecordNotFound} +) + +type PersistenceError struct { + Message string // The error message. + Err error // Any wrapped error. +} + +func (e PersistenceError) Error() string { + return fmt.Sprintf("%s: %v", e.Message, e.Err) +} + +func (e PersistenceError) Is(err error) bool { + return err == e.Err +} + +func jobError(errorToWrap error, message string, msgArgs ...interface{}) error { + return wrapError(translateGormJobError(errorToWrap), message, msgArgs...) +} + +func taskError(errorToWrap error, message string, msgArgs ...interface{}) error { + return wrapError(translateGormTaskError(errorToWrap), message, msgArgs...) +} + +func wrapError(errorToWrap error, message string, format ...interface{}) error { + // Only format if there are arguments for formatting. + var formattedMsg string + if len(format) > 0 { + formattedMsg = fmt.Sprintf(message, format...) + } else { + formattedMsg = message + } + + return PersistenceError{ + Message: formattedMsg, + Err: errorToWrap, + } +} + +// translateGormJobError translates a Gorm error to a persistence layer error. +// This helps to keep Gorm as "implementation detail" of the persistence layer. +func translateGormJobError(gormError error) error { + if errors.Is(gormError, gorm.ErrRecordNotFound) { + return ErrJobNotFound + } + return gormError +} + +// translateGormTaskError translates a Gorm error to a persistence layer error. +// This helps to keep Gorm as "implementation detail" of the persistence layer. +func translateGormTaskError(gormError error) error { + if errors.Is(gormError, gorm.ErrRecordNotFound) { + return ErrTaskNotFound + } + return gormError +} diff --git a/internal/manager/persistence/errors_test.go b/internal/manager/persistence/errors_test.go new file mode 100644 index 00000000..4a3f7d70 --- /dev/null +++ b/internal/manager/persistence/errors_test.go @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: GPL-3.0-or-later +package persistence + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "gorm.io/gorm" +) + +func TestNotFoundErrors(t *testing.T) { + assert.ErrorIs(t, ErrJobNotFound, gorm.ErrRecordNotFound) + assert.ErrorIs(t, ErrTaskNotFound, gorm.ErrRecordNotFound) + + assert.Contains(t, ErrJobNotFound.Error(), "job") + assert.Contains(t, ErrTaskNotFound.Error(), "task") +} + +func TestTranslateGormJobError(t *testing.T) { + assert.Nil(t, translateGormJobError(nil)) + assert.Equal(t, ErrJobNotFound, translateGormJobError(gorm.ErrRecordNotFound)) + + otherError := errors.New("this error is not special for this function") + assert.Equal(t, otherError, translateGormJobError(otherError)) +} + +func TestTranslateGormTaskError(t *testing.T) { + assert.Nil(t, translateGormTaskError(nil)) + assert.Equal(t, ErrTaskNotFound, translateGormTaskError(gorm.ErrRecordNotFound)) + + otherError := errors.New("this error is not special for this function") + assert.Equal(t, otherError, translateGormTaskError(otherError)) +} diff --git a/internal/manager/persistence/jobs.go b/internal/manager/persistence/jobs.go index a7517f54..ac6fcc5b 100644 --- a/internal/manager/persistence/jobs.go +++ b/internal/manager/persistence/jobs.go @@ -7,7 +7,6 @@ import ( "database/sql/driver" "encoding/json" "errors" - "fmt" "gorm.io/gorm" @@ -109,7 +108,7 @@ func (db *DB) StoreAuthoredJob(ctx context.Context, authoredJob job_compilers.Au } if err := tx.Create(&dbJob).Error; err != nil { - return fmt.Errorf("storing job: %v", err) + return jobError(err, "storing job") } uuidToTask := make(map[string]*Task) @@ -133,7 +132,7 @@ func (db *DB) StoreAuthoredJob(ctx context.Context, authoredJob job_compilers.Au // dependencies are stored below. } if err := tx.Create(&dbTask).Error; err != nil { - return fmt.Errorf("storing task: %v", err) + return taskError(err, "storing task: %v", err) } uuidToTask[authoredTask.UUID] = &dbTask @@ -147,21 +146,21 @@ func (db *DB) StoreAuthoredJob(ctx context.Context, authoredJob job_compilers.Au dbTask, ok := uuidToTask[authoredTask.UUID] if !ok { - return fmt.Errorf("unable to find task %q in the database, even though it was just authored", authoredTask.UUID) + return taskError(nil, "unable to find task %q in the database, even though it was just authored", authoredTask.UUID) } deps := make([]*Task, len(authoredTask.Dependencies)) for i, t := range authoredTask.Dependencies { depTask, ok := uuidToTask[t.UUID] if !ok { - return fmt.Errorf("finding task with UUID %q; a task depends on a task that is not part of this job", t.UUID) + return taskError(nil, "finding task with UUID %q; a task depends on a task that is not part of this job", t.UUID) } deps[i] = depTask } dbTask.Dependencies = deps if err := tx.Save(dbTask).Error; err != nil { - return fmt.Errorf("unable to store dependencies of task %q: %w", authoredTask.UUID, err) + return taskError(err, "unable to store dependencies of task %q", authoredTask.UUID) } } @@ -173,7 +172,7 @@ func (db *DB) FetchJob(ctx context.Context, jobUUID string) (*Job, error) { dbJob := Job{} findResult := db.gormDB.WithContext(ctx).First(&dbJob, "uuid = ?", jobUUID) if findResult.Error != nil { - return nil, findResult.Error + return nil, jobError(findResult.Error, "fetching job") } return &dbJob, nil @@ -184,7 +183,7 @@ func (db *DB) SaveJobStatus(ctx context.Context, j *Job) error { Model(j). Updates(Job{Status: j.Status}) if tx.Error != nil { - return fmt.Errorf("saving job status: %w", tx.Error) + return jobError(tx.Error, "saving job status") } return nil } @@ -195,21 +194,21 @@ func (db *DB) FetchTask(ctx context.Context, taskUUID string) (*Task, error) { Joins("Job"). First(&dbTask, "tasks.uuid = ?", taskUUID) if tx.Error != nil { - return nil, tx.Error + return nil, taskError(tx.Error, "fetching task") } return &dbTask, nil } func (db *DB) SaveTask(ctx context.Context, t *Task) error { if err := db.gormDB.WithContext(ctx).Save(t).Error; err != nil { - return fmt.Errorf("saving task: %w", err) + return taskError(err, "saving task") } return nil } func (db *DB) SaveTaskActivity(ctx context.Context, t *Task) error { if err := db.gormDB.Model(t).Updates(Task{Activity: t.Activity}).Error; err != nil { - return fmt.Errorf("saving task activity: %w", err) + return taskError(err, "saving task activity") } return nil } @@ -218,7 +217,7 @@ func (db *DB) TaskAssignToWorker(ctx context.Context, t *Task, w *Worker) error tx := db.gormDB.WithContext(ctx). Model(t).Updates(Task{WorkerID: &w.ID}) if tx.Error != nil { - return fmt.Errorf("assigning task %s to worker %s: %w", t.UUID, w.UUID, tx.Error) + return taskError(tx.Error, "assigning task %s to worker %s", t.UUID, w.UUID) } // Gorm updates t.WorkerID itself, but not t.Worker (even when it's added to @@ -237,7 +236,7 @@ func (db *DB) FetchTasksOfWorkerInStatus(ctx context.Context, worker *Worker, ta Where("tasks.status = ?", taskStatus). Scan(&result) if tx.Error != nil { - return nil, fmt.Errorf("finding tasks of worker %s in status %q: %w", worker.UUID, taskStatus, tx.Error) + return nil, taskError(tx.Error, "finding tasks of worker %s in status %q", worker.UUID, taskStatus) } return result, nil } @@ -250,7 +249,7 @@ func (db *DB) JobHasTasksInStatus(ctx context.Context, job *Job, taskStatus api. Where("status", taskStatus). Count(&numTasksInStatus) if tx.Error != nil { - return false, tx.Error + return false, taskError(tx.Error, "counting tasks of job %s in status %q", job.UUID, taskStatus) } return numTasksInStatus > 0, nil } @@ -270,7 +269,7 @@ func (db *DB) CountTasksOfJobInStatus(ctx context.Context, job *Job, taskStatus Scan(&results) if tx.Error != nil { - return 0, 0, fmt.Errorf("count tasks of job %s in status %q: %w", job.UUID, taskStatus, tx.Error) + return 0, 0, jobError(tx.Error, "count tasks of job %s in status %q", job.UUID, taskStatus) } for _, result := range results { @@ -283,12 +282,12 @@ func (db *DB) CountTasksOfJobInStatus(ctx context.Context, job *Job, taskStatus return } -// UpdateJobsTaskStatuses updates the status & activity of the tasks of `job`. +// UpdateJobsTaskStatuses updates the status & activity of all tasks of `job`. func (db *DB) UpdateJobsTaskStatuses(ctx context.Context, job *Job, taskStatus api.TaskStatus, activity string) error { if taskStatus == "" { - return errors.New("empty status not allowed") + return taskError(nil, "empty status not allowed") } tx := db.gormDB.WithContext(ctx). @@ -297,7 +296,7 @@ func (db *DB) UpdateJobsTaskStatuses(ctx context.Context, job *Job, Updates(Task{Status: taskStatus, Activity: activity}) if tx.Error != nil { - return tx.Error + return taskError(tx.Error, "updating status of all tasks of job %s", job.UUID) } return nil } @@ -308,7 +307,7 @@ func (db *DB) UpdateJobsTaskStatusesConditional(ctx context.Context, job *Job, statusesToUpdate []api.TaskStatus, taskStatus api.TaskStatus, activity string) error { if taskStatus == "" { - return errors.New("empty status not allowed") + return taskError(nil, "empty status not allowed") } tx := db.gormDB.WithContext(ctx). @@ -316,5 +315,8 @@ func (db *DB) UpdateJobsTaskStatusesConditional(ctx context.Context, job *Job, Where("job_Id = ?", job.ID). Where("status in ?", statusesToUpdate). Updates(Task{Status: taskStatus, Activity: activity}) - return tx.Error + if tx.Error != nil { + return taskError(tx.Error, "updating status of all tasks in status %v of job %s", statusesToUpdate, job.UUID) + } + return nil }