diff --git a/internal/manager/persistence/db.go b/internal/manager/persistence/db.go index 2a0f45c8..fb3325b8 100644 --- a/internal/manager/persistence/db.go +++ b/internal/manager/persistence/db.go @@ -134,7 +134,9 @@ func openDBWithConfig(dsn string, config *gorm.Config) (*DB, error) { sqlDB.SetMaxOpenConns(1) // Max num of open connections to the database. // Always enable foreign key checks, to make SQLite behave like a real database. - if err := db.pragmaForeignKeys(true); err != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := db.pragmaForeignKeys(ctx, true); err != nil { return nil, err } @@ -237,29 +239,22 @@ func (db *DB) now() sql.NullTime { } } -func (db *DB) pragmaForeignKeys(enabled bool) error { - var ( - value int - noun string - ) +func (db *DB) pragmaForeignKeys(ctx context.Context, enabled bool) error { + var noun string switch enabled { case false: - value = 0 noun = "disabl" case true: - value = 1 noun = "enabl" } log.Trace().Msgf("%sing SQLite foreign key checks", noun) - // SQLite doesn't seem to like SQL parameters for `PRAGMA`, so `PRAGMA foreign_keys = ?` doesn't work. - sql := fmt.Sprintf("PRAGMA foreign_keys = %d", value) - - if tx := db.gormDB.Exec(sql); tx.Error != nil { - return fmt.Errorf("%sing foreign keys: %w", noun, tx.Error) + queries := db.queries() + if err := queries.PragmaForeignKeysSet(ctx, enabled); err != nil { + return fmt.Errorf("%sing foreign keys: %w", noun, err) } - fkEnabled, err := db.areForeignKeysEnabled() + fkEnabled, err := db.areForeignKeysEnabled(ctx) if err != nil { return err } @@ -270,12 +265,13 @@ func (db *DB) pragmaForeignKeys(enabled bool) error { return nil } -func (db *DB) areForeignKeysEnabled() (bool, error) { +func (db *DB) areForeignKeysEnabled(ctx context.Context) (bool, error) { log.Trace().Msg("checking whether SQLite foreign key checks are enabled") - var fkEnabled int - if tx := db.gormDB.Raw("PRAGMA foreign_keys").Scan(&fkEnabled); tx.Error != nil { - return false, fmt.Errorf("checking whether the database has foreign key checks are enabled: %w", tx.Error) + queries := db.queries() + fkEnabled, err := queries.PragmaForeignKeysGet(ctx) + if err != nil { + return false, fmt.Errorf("checking whether the database has foreign key checks are enabled: %w", err) } - return fkEnabled != 0, nil + return fkEnabled, nil } diff --git a/internal/manager/persistence/db_migration.go b/internal/manager/persistence/db_migration.go index 416c8a8f..4c0bacef 100644 --- a/internal/manager/persistence/db_migration.go +++ b/internal/manager/persistence/db_migration.go @@ -41,7 +41,7 @@ func (db *DB) migrate(ctx context.Context) error { // of data, foreign keys are disabled here instead of in the migration SQL // files, so that it can't be forgotten. - if err := db.pragmaForeignKeys(false); err != nil { + if err := db.pragmaForeignKeys(ctx, false); err != nil { log.Fatal().AnErr("cause", err).Msgf("could not disable foreign key constraints before performing database migrations, please report a bug at %s", website.BugReportURL) } @@ -52,7 +52,7 @@ func (db *DB) migrate(ctx context.Context) error { } // Re-enable foreign key checks. - if err := db.pragmaForeignKeys(true); err != nil { + if err := db.pragmaForeignKeys(ctx, true); err != nil { log.Fatal().AnErr("cause", err).Msgf("could not re-enable foreign key constraints after performing database migrations, please report a bug at %s", website.BugReportURL) } diff --git a/internal/manager/persistence/integrity.go b/internal/manager/persistence/integrity.go index 710b05b7..61dd336c 100644 --- a/internal/manager/persistence/integrity.go +++ b/internal/manager/persistence/integrity.go @@ -74,7 +74,7 @@ func (db *DB) performIntegrityCheck(ctx context.Context) (ok bool) { log.Debug().Msg("database: performing integrity check") - db.ensureForeignKeysEnabled() + db.ensureForeignKeysEnabled(checkCtx) if !db.pragmaIntegrityCheck(checkCtx) { return false @@ -162,8 +162,8 @@ func (db *DB) pragmaForeignKeyCheck(ctx context.Context) (ok bool) { // connection to the low-level SQLite driver. Unfortunately the GORM-embedded // SQLite doesn't have an 'on-connect' callback function to always enable // foreign keys. -func (db *DB) ensureForeignKeysEnabled() { - fkEnabled, err := db.areForeignKeysEnabled() +func (db *DB) ensureForeignKeysEnabled(ctx context.Context) { + fkEnabled, err := db.areForeignKeysEnabled(ctx) if err != nil { log.Error().AnErr("cause", err).Msg("database: could not check whether foreign keys are enabled") @@ -175,7 +175,7 @@ func (db *DB) ensureForeignKeysEnabled() { } log.Warn().Msg("database: foreign keys are disabled, re-enabling them") - if err := db.pragmaForeignKeys(true); err != nil { + if err := db.pragmaForeignKeys(ctx, true); err != nil { log.Error().AnErr("cause", err).Msg("database: error re-enabling foreign keys") return } diff --git a/internal/manager/persistence/jobs.go b/internal/manager/persistence/jobs.go index 31a9b8cc..352e488e 100644 --- a/internal/manager/persistence/jobs.go +++ b/internal/manager/persistence/jobs.go @@ -389,11 +389,11 @@ func (db *DB) FetchJobShamanCheckoutID(ctx context.Context, jobUUID string) (str // The deletion cascades to its tasks and other job-related tables. func (db *DB) DeleteJob(ctx context.Context, jobUUID string) error { // As a safety measure, refuse to delete jobs unless foreign key constraints are active. - fkEnabled, err := db.areForeignKeysEnabled() - if err != nil { - return fmt.Errorf("checking whether foreign keys are enabled: %w", err) - } - if !fkEnabled { + fkEnabled, err := db.areForeignKeysEnabled(ctx) + switch { + case err != nil: + return err + case !fkEnabled: return ErrDeletingWithoutFK } diff --git a/internal/manager/persistence/jobs_test.go b/internal/manager/persistence/jobs_test.go index 7a2cb85e..cd220759 100644 --- a/internal/manager/persistence/jobs_test.go +++ b/internal/manager/persistence/jobs_test.go @@ -255,7 +255,7 @@ func TestDeleteJobWithoutFK(t *testing.T) { authJob.Name = "Job to delete" persistAuthoredJob(t, ctx, db, authJob) - require.NoError(t, db.pragmaForeignKeys(false)) + require.NoError(t, db.pragmaForeignKeys(ctx, false)) err := db.DeleteJob(ctx, authJob.JobID) require.ErrorIs(t, err, ErrDeletingWithoutFK) diff --git a/internal/manager/persistence/sqlc/integrity.go b/internal/manager/persistence/sqlc/integrity.go index 99ae7831..e412a2d7 100644 --- a/internal/manager/persistence/sqlc/integrity.go +++ b/internal/manager/persistence/sqlc/integrity.go @@ -38,3 +38,28 @@ func (q *Queries) PragmaIntegrityCheck(ctx context.Context) ([]PragmaIntegrityCh } return items, nil } + +// SQLite doesn't seem to like SQL parameters for `PRAGMA`, so `PRAGMA foreign_keys = ?` doesn't work. +const pragmaForeignKeysEnable = `PRAGMA foreign_keys = 1` +const pragmaForeignKeysDisable = `PRAGMA foreign_keys = 0` + +func (q *Queries) PragmaForeignKeysSet(ctx context.Context, enable bool) error { + var sql string + if enable { + sql = pragmaForeignKeysEnable + } else { + sql = pragmaForeignKeysDisable + } + + _, err := q.db.ExecContext(ctx, sql) + return err +} + +const pragmaForeignKeys = `PRAGMA foreign_keys` + +func (q *Queries) PragmaForeignKeysGet(ctx context.Context) (bool, error) { + row := q.db.QueryRowContext(ctx, pragmaForeignKeys) + var fkEnabled bool + err := row.Scan(&fkEnabled) + return fkEnabled, err +} diff --git a/internal/manager/persistence/worker_tag.go b/internal/manager/persistence/worker_tag.go index 0fdebeec..0ddb4f9f 100644 --- a/internal/manager/persistence/worker_tag.go +++ b/internal/manager/persistence/worker_tag.go @@ -73,11 +73,11 @@ func (db *DB) SaveWorkerTag(ctx context.Context, tag *WorkerTag) error { // DeleteWorkerTag deletes the given tag, after unassigning all workers from it. func (db *DB) DeleteWorkerTag(ctx context.Context, uuid string) error { // As a safety measure, refuse to delete unless foreign key constraints are active. - fkEnabled, err := db.areForeignKeysEnabled() - if err != nil { - return fmt.Errorf("checking whether foreign keys are enabled: %w", err) - } - if !fkEnabled { + fkEnabled, err := db.areForeignKeysEnabled(ctx) + switch { + case err != nil: + return err + case !fkEnabled: return ErrDeletingWithoutFK } diff --git a/internal/manager/persistence/worker_tag_test.go b/internal/manager/persistence/worker_tag_test.go index f754cfa6..955da9cd 100644 --- a/internal/manager/persistence/worker_tag_test.go +++ b/internal/manager/persistence/worker_tag_test.go @@ -85,7 +85,7 @@ func TestDeleteTagsWithoutFK(t *testing.T) { require.NoError(t, f.db.CreateWorkerTag(f.ctx, &secondTag)) // Try deleting with foreign key constraints disabled. - require.NoError(t, f.db.pragmaForeignKeys(false)) + require.NoError(t, f.db.pragmaForeignKeys(f.ctx, false)) err = f.db.DeleteWorkerTag(f.ctx, f.tag.UUID) require.ErrorIs(t, err, ErrDeletingWithoutFK) diff --git a/internal/manager/persistence/workers.go b/internal/manager/persistence/workers.go index 5621087d..6c51aa48 100644 --- a/internal/manager/persistence/workers.go +++ b/internal/manager/persistence/workers.go @@ -139,11 +139,11 @@ func (db *DB) FetchWorker(ctx context.Context, uuid string) (*Worker, error) { func (db *DB) DeleteWorker(ctx context.Context, uuid string) error { // As a safety measure, refuse to delete unless foreign key constraints are active. - fkEnabled, err := db.areForeignKeysEnabled() - if err != nil { - return fmt.Errorf("checking whether foreign keys are enabled: %w", err) - } - if !fkEnabled { + fkEnabled, err := db.areForeignKeysEnabled(ctx) + switch { + case err != nil: + return err + case !fkEnabled: return ErrDeletingWithoutFK } diff --git a/internal/manager/persistence/workers_test.go b/internal/manager/persistence/workers_test.go index 976f6ca7..77f6119c 100644 --- a/internal/manager/persistence/workers_test.go +++ b/internal/manager/persistence/workers_test.go @@ -342,7 +342,7 @@ func TestDeleteWorkerNoForeignKeys(t *testing.T) { require.NoError(t, db.CreateWorker(ctx, &w1)) // Try deleting with foreign key constraints disabled. - require.NoError(t, db.pragmaForeignKeys(false)) + require.NoError(t, db.pragmaForeignKeys(ctx, false)) require.ErrorIs(t, ErrDeletingWithoutFK, db.DeleteWorker(ctx, w1.UUID)) // The worker should still exist.