From aa964ac2053df2b28f7d72c26cdf0db8360b0560 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sybren=20A=2E=20St=C3=BCvel?= Date: Thu, 26 Sep 2024 23:30:28 +0200 Subject: [PATCH] Worker: replace GORM with sqlc Ref: #104305 --- internal/worker/persistence/db.go | 238 ++++++++++-------- internal/worker/persistence/db_migration.go | 12 +- internal/worker/persistence/integrity.go | 168 +++++++++++++ internal/worker/persistence/logger.go | 133 ++-------- .../0002_sqlc_compat_notnullable.sql | 34 +++ internal/worker/persistence/sqlc/db.go | 31 +++ internal/worker/persistence/sqlc/models.go | 16 ++ internal/worker/persistence/sqlc/pragma.go | 131 ++++++++++ internal/worker/persistence/sqlc/query.sql | 19 ++ internal/worker/persistence/sqlc/query.sql.go | 70 ++++++ internal/worker/persistence/sqlc/schema.sql | 7 + internal/worker/persistence/sqlite_busy.go | 20 +- .../worker/persistence/upstream_buffer.go | 56 ++--- sqlc.yaml | 17 ++ 14 files changed, 693 insertions(+), 259 deletions(-) create mode 100644 internal/worker/persistence/integrity.go create mode 100644 internal/worker/persistence/migrations/0002_sqlc_compat_notnullable.sql create mode 100644 internal/worker/persistence/sqlc/db.go create mode 100644 internal/worker/persistence/sqlc/models.go create mode 100644 internal/worker/persistence/sqlc/pragma.go create mode 100644 internal/worker/persistence/sqlc/query.sql create mode 100644 internal/worker/persistence/sqlc/query.sql.go create mode 100644 internal/worker/persistence/sqlc/schema.sql diff --git a/internal/worker/persistence/db.go b/internal/worker/persistence/db.go index ff033875..453712d3 100644 --- a/internal/worker/persistence/db.go +++ b/internal/worker/persistence/db.go @@ -5,28 +5,23 @@ package persistence import ( "context" + "database/sql" "fmt" "time" "github.com/rs/zerolog/log" - "gorm.io/gorm" + _ "modernc.org/sqlite" - "github.com/glebarez/sqlite" + "projects.blender.org/studio/flamenco/internal/worker/persistence/sqlc" ) // DB provides the database interface. type DB struct { - gormDB *gorm.DB -} + sqlDB *sql.DB + nowfunc func() time.Time -// Model contains the common database fields for most model structs. It is a -// copy of the gorm.Model struct, but without the `DeletedAt` or `UpdatedAt` -// fields. Soft deletion is not used by Flamenco, and the upstream buffer will -// not be updated. If it ever becomes necessary to support soft-deletion, see -// https://gorm.io/docs/delete.html#Soft-Delete -type Model struct { - ID uint `gorm:"primarykey"` - CreatedAt time.Time + // See PeriodicIntegrityCheck(). + consistencyCheckRequests chan struct{} } func OpenDB(ctx context.Context, dsn string) (*DB, error) { @@ -37,44 +32,6 @@ func OpenDB(ctx context.Context, dsn string) (*DB, error) { return nil, err } - if err := setBusyTimeout(db.gormDB, 5*time.Second); err != nil { - return nil, err - } - - // Perfom some maintenance at startup. 
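The nowfunc field replaces GORM's NowFunc configuration hook and is what db.now() (and db.nowNullable(), added further down) report. A minimal sketch, not part of this patch, of how a test in the persistence package could pin the clock by constructing a DB value directly; the test name and timestamp are illustrative:

package persistence

import (
	"testing"
	"time"
)

func TestClockInjection(t *testing.T) {
	fixedNow := time.Date(2024, 9, 26, 23, 30, 28, 0, time.UTC)
	db := DB{nowfunc: func() time.Time { return fixedNow }}

	if got := db.now(); !got.Equal(fixedNow) {
		t.Errorf("db.now() = %v, want %v", got, fixedNow)
	}
}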
- db.vacuum() - - if err := db.migrate(ctx); err != nil { - return nil, err - } - log.Debug().Msg("database automigration succesful") - - return db, nil -} - -func openDB(ctx context.Context, dsn string) (*DB, error) { - globalLogLevel := log.Logger.GetLevel() - dblogger := NewDBLogger(log.Level(globalLogLevel)) - - config := gorm.Config{ - Logger: dblogger, - NowFunc: nowFunc, - } - - return openDBWithConfig(dsn, &config) -} - -func openDBWithConfig(dsn string, config *gorm.Config) (*DB, error) { - dialector := sqlite.Open(dsn) - gormDB, err := gorm.Open(dialector, config) - if err != nil { - return nil, err - } - - db := &DB{ - gormDB: gormDB, - } - // Close the database connection if there was some error. This prevents // leaking database connections & should remove any write-ahead-log files. closeConnOnReturn := true @@ -87,87 +44,171 @@ func openDBWithConfig(dsn string, config *gorm.Config) (*DB, error) { } }() - // Use the generic sql.DB interface to set some connection pool options. - sqlDB, err := gormDB.DB() + if err := db.setBusyTimeout(ctx, 5*time.Second); err != nil { + return nil, err + } + + // Perfom some maintenance at startup, before trying to migrate the database. + if !db.performIntegrityCheck(ctx) { + return nil, ErrIntegrity + } + + db.vacuum(ctx) + + if err := db.migrate(ctx); err != nil { + return nil, err + } + log.Debug().Msg("database automigration succesful") + + // Perfom post-migration integrity check, just to be sure. + if !db.performIntegrityCheck(ctx) { + return nil, ErrIntegrity + } + + // Perform another vacuum after database migration, as that may have copied a + // lot of data and then dropped another lot of data. + db.vacuum(ctx) + + closeConnOnReturn = false + return db, nil +} + +func openDB(ctx context.Context, dsn string) (*DB, error) { + // Connect to the database. + sqlDB, err := sql.Open("sqlite", dsn) if err != nil { return nil, err } + // Close the database connection if there was some error. This prevents + // leaking database connections & should remove any write-ahead-log files. + closeConnOnReturn := true + defer func() { + if !closeConnOnReturn { + return + } + if err := sqlDB.Close(); err != nil { + log.Debug().AnErr("cause", err).Msg("cannot close database connection") + } + }() + // Only allow a single database connection, to avoid SQLITE_BUSY errors. // It's not certain that this'll improve the situation, but it's worth a try. sqlDB.SetMaxIdleConns(1) // Max num of connections in the idle connection pool. sqlDB.SetMaxOpenConns(1) // Max num of open connections to the database. + db := DB{ + sqlDB: sqlDB, + nowfunc: func() time.Time { return time.Now().UTC() }, + + // Buffer one request, so that even when a consistency check is already + // running, another can be queued without blocking. Queueing more than one + // doesn't make sense, though. + consistencyCheckRequests: make(chan struct{}, 1), + } + // Always enable foreign key checks, to make SQLite behave like a real database. - log.Trace().Msg("enabling SQLite foreign key checks") - if err := db.pragmaForeignKeys(true); err != nil { + pragmaCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + if err := db.pragmaForeignKeys(pragmaCtx, true); err != nil { return nil, err } + queries := db.queries() + // Write-ahead-log journal may improve writing speed. 
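From the caller's side the open/close API is unchanged by this rewrite; below is a minimal usage sketch. The DSN, the one-hour check period, and the on-error reaction are illustrative assumptions, not taken from this patch, and PeriodicIntegrityCheck is defined further down in integrity.go:

package main

import (
	"context"
	"time"

	"github.com/rs/zerolog/log"

	"projects.blender.org/studio/flamenco/internal/worker/persistence"
)

func main() {
	ctx := context.Background()

	// OpenDB runs the busy_timeout, integrity-check, vacuum and migration steps
	// shown above before handing back a usable *persistence.DB.
	db, err := persistence.OpenDB(ctx, "flamenco-worker.sqlite") // DSN is illustrative
	if err != nil {
		log.Fatal().Err(err).Msg("could not open worker database")
	}
	defer db.Close()

	// Hypothetical wiring of the periodic integrity check added in integrity.go.
	go db.PeriodicIntegrityCheck(ctx, time.Hour, func() {
		log.Error().Msg("worker database failed its integrity check")
	})

	// ... the worker's main loop would run here ...
}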
log.Trace().Msg("enabling SQLite write-ahead-log journal mode") - if tx := gormDB.Exec("PRAGMA journal_mode = WAL"); tx.Error != nil { - return nil, fmt.Errorf("enabling SQLite write-ahead-log journal mode: %w", tx.Error) + if err := queries.PragmaJournalModeWAL(pragmaCtx); err != nil { + return nil, fmt.Errorf("enabling SQLite write-ahead-log journal mode: %w", err) } // Switching from 'full' (default) to 'normal' sync may improve writing speed. log.Trace().Msg("enabling SQLite 'normal' synchronisation") - if tx := gormDB.Exec("PRAGMA synchronous = normal"); tx.Error != nil { - return nil, fmt.Errorf("enabling SQLite 'normal' sync mode: %w", tx.Error) + if err := queries.PragmaSynchronousNormal(pragmaCtx); err != nil { + return nil, fmt.Errorf("enabling SQLite 'normal' sync mode: %w", err) } closeConnOnReturn = false - return db, nil -} - -// nowFunc returns 'now' in UTC, so that GORM-managed times (createdAt, -// deletedAt, updatedAt) are stored in UTC. -func nowFunc() time.Time { - return time.Now().UTC() + return &db, nil } // vacuum executes the SQL "VACUUM" command, and logs any errors. -func (db *DB) vacuum() { - tx := db.gormDB.Exec("vacuum") - if tx.Error != nil { - log.Error().Err(tx.Error).Msg("error vacuuming database") - } -} - -func (db *DB) Close() error { - sqlDB, err := db.gormDB.DB() +func (db *DB) vacuum(ctx context.Context) { + err := db.queries().Vacuum(ctx) if err != nil { - return fmt.Errorf("getting generic database interface: %w", err) + log.Error().Err(err).Msg("error vacuuming database") } - - if err := sqlDB.Close(); err != nil { - return fmt.Errorf("closing database: %w", err) - } - return nil } -func (db *DB) pragmaForeignKeys(enabled bool) error { - var ( - value int - noun string - ) +// Close closes the connection to the database. +func (db *DB) Close() error { + return db.sqlDB.Close() +} + +// queries returns the SQLC Queries struct, connected to this database. +func (db *DB) queries() *sqlc.Queries { + loggingWrapper := LoggingDBConn{db.sqlDB} + return sqlc.New(&loggingWrapper) +} + +type queriesTX struct { + queries *sqlc.Queries + commit func() error + rollback func() error +} + +// queries returns the SQLC Queries struct, connected to this database. +// +// After calling this function, all queries should use this transaction until it +// is closed (either committed or rolled back). Otherwise SQLite will deadlock, +// as it will make any other query wait until this transaction is done. +func (db *DB) queriesWithTX() (*queriesTX, error) { + tx, err := db.sqlDB.Begin() + if err != nil { + return nil, fmt.Errorf("could not begin database transaction: %w", err) + } + + loggingWrapper := LoggingDBConn{tx} + + qtx := queriesTX{ + queries: sqlc.New(&loggingWrapper), + commit: tx.Commit, + rollback: tx.Rollback, + } + + return &qtx, nil +} + +// now returns 'now' as reported by db.nowfunc. +// It always converts the timestamp to UTC. +func (db *DB) now() time.Time { + return db.nowfunc() +} + +// nowNullable returns the result of `now()` wrapped in a sql.NullTime. +// It is nullable just for ease of use, it will never actually be null. 
+func (db *DB) nowNullable() sql.NullTime { + return sql.NullTime{ + Time: db.now(), + Valid: true, + } +} + +func (db *DB) pragmaForeignKeys(ctx context.Context, enabled bool) error { + var noun string switch enabled { case false: - value = 0 noun = "disabl" case true: - value = 1 noun = "enabl" } log.Trace().Msgf("%sing SQLite foreign key checks", noun) - // SQLite doesn't seem to like SQL parameters for `PRAGMA`, so `PRAGMA foreign_keys = ?` doesn't work. - sql := fmt.Sprintf("PRAGMA foreign_keys = %d", value) - - if tx := db.gormDB.Exec(sql); tx.Error != nil { - return fmt.Errorf("%sing foreign keys: %w", noun, tx.Error) + queries := db.queries() + if err := queries.PragmaForeignKeysSet(ctx, enabled); err != nil { + return fmt.Errorf("%sing foreign keys: %w", noun, err) } - fkEnabled, err := db.areForeignKeysEnabled() + fkEnabled, err := db.areForeignKeysEnabled(ctx) if err != nil { return err } @@ -178,12 +219,13 @@ func (db *DB) pragmaForeignKeys(enabled bool) error { return nil } -func (db *DB) areForeignKeysEnabled() (bool, error) { +func (db *DB) areForeignKeysEnabled(ctx context.Context) (bool, error) { log.Trace().Msg("checking whether SQLite foreign key checks are enabled") - var fkEnabled int - if tx := db.gormDB.Raw("PRAGMA foreign_keys").Scan(&fkEnabled); tx.Error != nil { - return false, fmt.Errorf("checking whether the database has foreign key checks are enabled: %w", tx.Error) + queries := db.queries() + fkEnabled, err := queries.PragmaForeignKeysGet(ctx) + if err != nil { + return false, fmt.Errorf("checking whether the database has foreign key checks are enabled: %w", err) } - return fkEnabled != 0, nil + return fkEnabled, nil } diff --git a/internal/worker/persistence/db_migration.go b/internal/worker/persistence/db_migration.go index 416c8a8f..2d34b18f 100644 --- a/internal/worker/persistence/db_migration.go +++ b/internal/worker/persistence/db_migration.go @@ -25,12 +25,6 @@ func (db *DB) migrate(ctx context.Context) error { log.Fatal().AnErr("cause", err).Msg("could not tell Goose to use sqlite3") } - // Hook up Goose to the database. - lowLevelDB, err := db.gormDB.DB() - if err != nil { - log.Fatal().AnErr("cause", err).Msg("GORM would not give us its low-level interface") - } - // Disable foreign key constraints during the migrations. This is necessary // for SQLite to do column renames / drops, as that requires creating a new // table with the new schema, copying the data, dropping the old table, and @@ -41,18 +35,18 @@ func (db *DB) migrate(ctx context.Context) error { // of data, foreign keys are disabled here instead of in the migration SQL // files, so that it can't be forgotten. - if err := db.pragmaForeignKeys(false); err != nil { + if err := db.pragmaForeignKeys(ctx, false); err != nil { log.Fatal().AnErr("cause", err).Msgf("could not disable foreign key constraints before performing database migrations, please report a bug at %s", website.BugReportURL) } // Run Goose. log.Debug().Msg("migrating database with Goose") - if err := goose.UpContext(ctx, lowLevelDB, "migrations"); err != nil { + if err := goose.UpContext(ctx, db.sqlDB, "migrations"); err != nil { log.Fatal().AnErr("cause", err).Msg("could not migrate database to the latest version") } // Re-enable foreign key checks. 
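The queriesWithTX helper added to db.go above has no caller inside this patch; the sketch below shows the intended commit/rollback pattern, under the assumption that some future method wants to write two task updates atomically (the method itself is hypothetical):

package persistence

import (
	"context"

	"projects.blender.org/studio/flamenco/internal/worker/persistence/sqlc"
)

// insertTwoUpdatesAtomically is a hypothetical example of using queriesWithTX.
func (db *DB) insertTwoUpdatesAtomically(ctx context.Context, a, b sqlc.InsertTaskUpdateParams) error {
	qtx, err := db.queriesWithTX()
	if err != nil {
		return err
	}
	// Rolling back after a successful commit is a harmless no-op, so a deferred
	// rollback covers every early return below.
	defer func() { _ = qtx.rollback() }()

	if err := qtx.queries.InsertTaskUpdate(ctx, a); err != nil {
		return err
	}
	if err := qtx.queries.InsertTaskUpdate(ctx, b); err != nil {
		return err
	}
	return qtx.commit()
}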
- if err := db.pragmaForeignKeys(true); err != nil { + if err := db.pragmaForeignKeys(ctx, true); err != nil { log.Fatal().AnErr("cause", err).Msgf("could not re-enable foreign key constraints after performing database migrations, please report a bug at %s", website.BugReportURL) } diff --git a/internal/worker/persistence/integrity.go b/internal/worker/persistence/integrity.go new file mode 100644 index 00000000..946797bc --- /dev/null +++ b/internal/worker/persistence/integrity.go @@ -0,0 +1,168 @@ +package persistence + +// SPDX-License-Identifier: GPL-3.0-or-later + +import ( + "context" + "errors" + "time" + + "github.com/rs/zerolog/log" +) + +var ErrIntegrity = errors.New("database integrity check failed") + +const ( + integrityCheckTimeout = 10 * time.Second +) + +// PeriodicIntegrityCheck periodically checks the database integrity. +// This function only returns when the context is done. +func (db *DB) PeriodicIntegrityCheck( + ctx context.Context, + period time.Duration, + onErrorCallback func(), +) { + if period == 0 { + log.Info().Msg("database: periodic integrity check disabled") + return + } + + log.Info(). + Stringer("period", period). + Msg("database: periodic integrity check starting") + defer log.Debug().Msg("database: periodic integrity check stopping") + + for { + select { + case <-ctx.Done(): + return + case <-time.After(period): + case <-db.consistencyCheckRequests: + } + + ok := db.performIntegrityCheck(ctx) + if !ok { + log.Error().Msg("database: periodic integrity check failed") + onErrorCallback() + } + } +} + +// RequestIntegrityCheck triggers a check of the database persistency. +func (db *DB) RequestIntegrityCheck() { + select { + case db.consistencyCheckRequests <- struct{}{}: + // Don't do anything, the work is done. + default: + log.Debug().Msg("database: could not trigger integrity check, another check might already be queued.") + } +} + +// performIntegrityCheck uses a few 'pragma' SQL statements to do some integrity checking. +// Returns true on OK, false if there was an issue. Issues are always logged. +func (db *DB) performIntegrityCheck(ctx context.Context) (ok bool) { + checkCtx, cancel := context.WithTimeout(ctx, integrityCheckTimeout) + defer cancel() + + log.Debug().Msg("database: performing integrity check") + + db.ensureForeignKeysEnabled(checkCtx) + + if !db.pragmaIntegrityCheck(checkCtx) { + return false + } + return db.pragmaForeignKeyCheck(checkCtx) +} + +// pragmaIntegrityCheck checks database file integrity. This does not include +// foreign key checks. +// +// Returns true on OK, false if there was an issue. Issues are always logged. +// +// See https: //www.sqlite.org/pragma.html#pragma_integrity_check +func (db *DB) pragmaIntegrityCheck(ctx context.Context) (ok bool) { + queries := db.queries() + issues, err := queries.PragmaIntegrityCheck(ctx) + if err != nil { + log.Error().Err(err).Msg("database: error checking integrity") + return false + } + + switch len(issues) { + case 0: + log.Warn().Msg("database: integrity check returned nothing, expected explicit 'ok'; treating as an implicit 'ok'") + return true + case 1: + if issues[0].Description == "ok" { + log.Debug().Msg("database: integrity check ok") + return true + } + } + + log.Error().Int("num_issues", len(issues)).Msg("database: integrity check failed") + for _, issue := range issues { + log.Error(). + Str("description", issue.Description). 
+ Msg("database: integrity check failure") + } + + return false +} + +// pragmaForeignKeyCheck checks whether all foreign key constraints are still valid. +// +// SQLite has optional foreign key relations, so even though Flamenco Manager +// always enables these on startup, at some point there could be some issue +// causing these checks to be skipped. +// +// Returns true on OK, false if there was an issue. Issues are always logged. +// +// See https: //www.sqlite.org/pragma.html#pragma_foreign_key_check +func (db *DB) pragmaForeignKeyCheck(ctx context.Context) (ok bool) { + queries := db.queries() + + issues, err := queries.PragmaForeignKeyCheck(ctx) + if err != nil { + log.Error().Err(err).Msg("database: error checking foreign keys") + return false + } + + if len(issues) == 0 { + log.Debug().Msg("database: foreign key check ok") + return true + } + + log.Error().Int("num_issues", len(issues)).Msg("database: foreign key check failed") + for _, issue := range issues { + log.Error(). + Str("table", issue.Table). + Int("rowid", issue.RowID). + Str("parent", issue.Parent). + Int("fkid", issue.FKID). + Msg("database: foreign key relation missing") + } + + return false +} + +// ensureForeignKeysEnabled checks whether foreign keys are enabled, and if not, +// tries to enable them. +func (db *DB) ensureForeignKeysEnabled(ctx context.Context) { + fkEnabled, err := db.areForeignKeysEnabled(ctx) + + if err != nil { + log.Error().AnErr("cause", err).Msg("database: could not check whether foreign keys are enabled") + return + } + + if fkEnabled { + return + } + + log.Warn().Msg("database: foreign keys are disabled, re-enabling them") + if err := db.pragmaForeignKeys(ctx, true); err != nil { + log.Error().AnErr("cause", err).Msg("database: error re-enabling foreign keys") + return + } +} diff --git a/internal/worker/persistence/logger.go b/internal/worker/persistence/logger.go index 2135a006..0f45d7ec 100644 --- a/internal/worker/persistence/logger.go +++ b/internal/worker/persistence/logger.go @@ -4,125 +4,34 @@ package persistence import ( "context" - "errors" - "fmt" - "time" + "database/sql" - "github.com/rs/zerolog" - "gorm.io/gorm" - gormlogger "gorm.io/gorm/logger" + "github.com/rs/zerolog/log" + + "projects.blender.org/studio/flamenco/internal/worker/persistence/sqlc" ) -// dbLogger implements the behaviour of Gorm's default logger on top of Zerolog. -// See https://github.com/go-gorm/gorm/blob/master/logger/logger.go -type dbLogger struct { - zlog *zerolog.Logger - - IgnoreRecordNotFoundError bool - SlowThreshold time.Duration +// LoggingDBConn wraps a database/sql.DB connection, so that it can be used with +// sqlc and log all the queries. +type LoggingDBConn struct { + wrappedConn sqlc.DBTX } -var _ gormlogger.Interface = (*dbLogger)(nil) +var _ sqlc.DBTX = (*LoggingDBConn)(nil) -var logLevelMap = map[gormlogger.LogLevel]zerolog.Level{ - gormlogger.Silent: zerolog.Disabled, - gormlogger.Error: zerolog.ErrorLevel, - gormlogger.Warn: zerolog.WarnLevel, - gormlogger.Info: zerolog.InfoLevel, +func (ldbc *LoggingDBConn) ExecContext(ctx context.Context, sql string, args ...interface{}) (sql.Result, error) { + log.Trace().Str("sql", sql).Interface("args", args).Msg("database: query Exec") + return ldbc.wrappedConn.ExecContext(ctx, sql, args...) } - -func gormToZlogLevel(logLevel gormlogger.LogLevel) zerolog.Level { - zlogLevel, ok := logLevelMap[logLevel] - if !ok { - // Just a default value that seemed sensible at the time of writing. 
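LoggingDBConn can sit in front of both the connection pool (queries()) and a transaction (queriesWithTX()) because *sql.DB and *sql.Tx already satisfy sqlc.DBTX. The compile-time assertions below are purely illustrative and not part of the patch:

package persistence

import (
	"database/sql"

	"projects.blender.org/studio/flamenco/internal/worker/persistence/sqlc"
)

// Illustrative only: both connection types expose ExecContext, PrepareContext,
// QueryContext and QueryRowContext with the signatures sqlc.DBTX expects.
var (
	_ sqlc.DBTX = (*sql.DB)(nil)
	_ sqlc.DBTX = (*sql.Tx)(nil)
)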
- return zerolog.DebugLevel - } - return zlogLevel +func (ldbc *LoggingDBConn) PrepareContext(ctx context.Context, sql string) (*sql.Stmt, error) { + log.Trace().Str("sql", sql).Msg("database: query Prepare") + return ldbc.wrappedConn.PrepareContext(ctx, sql) } - -// NewDBLogger wraps a zerolog logger to implement a Gorm logger interface. -func NewDBLogger(zlog zerolog.Logger) *dbLogger { - return &dbLogger{ - zlog: &zlog, - // Remaining properties default to their zero value. - } +func (ldbc *LoggingDBConn) QueryContext(ctx context.Context, sql string, args ...interface{}) (*sql.Rows, error) { + log.Trace().Str("sql", sql).Interface("args", args).Msg("database: query Query") + return ldbc.wrappedConn.QueryContext(ctx, sql, args...) } - -// LogMode returns a child logger at the given log level. -func (l *dbLogger) LogMode(logLevel gormlogger.LogLevel) gormlogger.Interface { - childLogger := l.zlog.Level(gormToZlogLevel(logLevel)) - newlogger := *l - newlogger.zlog = &childLogger - return &newlogger -} - -func (l *dbLogger) Info(ctx context.Context, msg string, args ...interface{}) { - l.logEvent(zerolog.InfoLevel, msg, args) -} - -func (l *dbLogger) Warn(ctx context.Context, msg string, args ...interface{}) { - l.logEvent(zerolog.WarnLevel, msg, args) -} - -func (l *dbLogger) Error(ctx context.Context, msg string, args ...interface{}) { - l.logEvent(zerolog.ErrorLevel, msg, args) -} - -// Trace traces the execution of SQL and potentially logs errors, warnings, and infos. -// Note that it doesn't mean "trace-level logging". -func (l *dbLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { - zlogLevel := l.zlog.GetLevel() - if zlogLevel == zerolog.Disabled { - return - } - - elapsed := time.Since(begin) - logCtx := l.zlog.With().CallerWithSkipFrameCount(5) - - // Function to lazily get the SQL, affected rows, and logger. - buildLogger := func() (loggerPtr *zerolog.Logger, sql string) { - sql, rows := fc() - logCtx = logCtx.Err(err) - if rows >= 0 { - logCtx = logCtx.Int64("rowsAffected", rows) - } - logger := logCtx.Logger() - return &logger, sql - } - - switch { - case err != nil && zlogLevel <= zerolog.ErrorLevel && (!errors.Is(err, gorm.ErrRecordNotFound) || !l.IgnoreRecordNotFoundError): - logger, sql := buildLogger() - logger.Error().Msg(sql) - - case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && zlogLevel <= zerolog.WarnLevel: - logger, sql := buildLogger() - logger.Warn(). - Str("sql", sql). - Dur("elapsed", elapsed). - Dur("slowThreshold", l.SlowThreshold). - Msg("slow database query") - - case zlogLevel <= zerolog.TraceLevel: - logger, sql := buildLogger() - logger.Trace().Msg(sql) - } -} - -// logEvent logs an even at the given level. -func (l dbLogger) logEvent(level zerolog.Level, msg string, args ...interface{}) { - if l.zlog.GetLevel() > level { - return - } - logger := l.logger(args) - logger.WithLevel(level).Msg(msg) -} - -// logger constructs a zerolog logger. The given arguments are added via reflection. -func (l dbLogger) logger(args ...interface{}) zerolog.Logger { - logCtx := l.zlog.With() - for idx, arg := range args { - logCtx.Interface(fmt.Sprintf("arg%d", idx), arg) - } - return logCtx.Logger() +func (ldbc *LoggingDBConn) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *sql.Row { + log.Trace().Str("sql", sql).Interface("args", args).Msg("database: query QueryRow") + return ldbc.wrappedConn.QueryRowContext(ctx, sql, args...) 
} diff --git a/internal/worker/persistence/migrations/0002_sqlc_compat_notnullable.sql b/internal/worker/persistence/migrations/0002_sqlc_compat_notnullable.sql new file mode 100644 index 00000000..a83137bb --- /dev/null +++ b/internal/worker/persistence/migrations/0002_sqlc_compat_notnullable.sql @@ -0,0 +1,34 @@ +-- GORM automigration wasn't smart, and thus the database had more nullable +-- columns than necessary. This migration makes columns that should never be +-- NULL actually NOT NULL. +-- +-- Since this migration recreates all tables in the database, this is now also +-- done in a way that makes the schema more compatible with sqlc (which is +-- mostly removing various quotes and backticks, and replacing char(N) with +-- varchar(N)). sqlc is the tool that replaced GORM. +-- +-- +goose Up +CREATE TABLE temp_task_updates ( + id integer NOT NULL, + created_at datetime NOT NULL, + task_id varchar(36) DEFAULT '' NOT NULL, + payload BLOB, + PRIMARY KEY (id) +); +INSERT INTO temp_task_updates + SELECT id, created_at, task_id, payload FROM task_updates; +DROP TABLE task_updates; +ALTER TABLE temp_task_updates RENAME TO task_updates; + +-- +goose Down +CREATE TABLE IF NOT EXISTS `temp_task_updates` ( + `id` integer, + `created_at` datetime, + `task_id` varchar(36) DEFAULT "", + `payload` BLOB, + PRIMARY KEY (`id`) +); +INSERT INTO temp_task_updates + SELECT id, created_at, task_id, payload FROM task_updates; +DROP TABLE task_updates; +ALTER TABLE temp_task_updates RENAME TO task_updates; diff --git a/internal/worker/persistence/sqlc/db.go b/internal/worker/persistence/sqlc/db.go new file mode 100644 index 00000000..c5852e06 --- /dev/null +++ b/internal/worker/persistence/sqlc/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.26.0 + +package sqlc + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/worker/persistence/sqlc/models.go b/internal/worker/persistence/sqlc/models.go new file mode 100644 index 00000000..e0aa694d --- /dev/null +++ b/internal/worker/persistence/sqlc/models.go @@ -0,0 +1,16 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.26.0 + +package sqlc + +import ( + "time" +) + +type TaskUpdate struct { + ID int64 + CreatedAt time.Time + TaskID string + Payload []byte +} diff --git a/internal/worker/persistence/sqlc/pragma.go b/internal/worker/persistence/sqlc/pragma.go new file mode 100644 index 00000000..134c5a57 --- /dev/null +++ b/internal/worker/persistence/sqlc/pragma.go @@ -0,0 +1,131 @@ +// Code MANUALLY written to extend the SQLC interface with some extra functions. 
+// +// This is to work around https://github.com/sqlc-dev/sqlc/issues/3237 + +package sqlc + +import ( + "context" + "fmt" + "time" +) + +const pragmaIntegrityCheck = `PRAGMA integrity_check` + +type PragmaIntegrityCheckResult struct { + Description string +} + +func (q *Queries) PragmaIntegrityCheck(ctx context.Context) ([]PragmaIntegrityCheckResult, error) { + rows, err := q.db.QueryContext(ctx, pragmaIntegrityCheck) + if err != nil { + return nil, err + } + defer rows.Close() + var items []PragmaIntegrityCheckResult + for rows.Next() { + var i PragmaIntegrityCheckResult + if err := rows.Scan( + &i.Description, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +// SQLite doesn't seem to like SQL parameters for `PRAGMA`, so `PRAGMA foreign_keys = ?` doesn't work. +const pragmaForeignKeysEnable = `PRAGMA foreign_keys = 1` +const pragmaForeignKeysDisable = `PRAGMA foreign_keys = 0` + +func (q *Queries) PragmaForeignKeysSet(ctx context.Context, enable bool) error { + var sql string + if enable { + sql = pragmaForeignKeysEnable + } else { + sql = pragmaForeignKeysDisable + } + + _, err := q.db.ExecContext(ctx, sql) + return err +} + +const pragmaForeignKeys = `PRAGMA foreign_keys` + +func (q *Queries) PragmaForeignKeysGet(ctx context.Context) (bool, error) { + row := q.db.QueryRowContext(ctx, pragmaForeignKeys) + var fkEnabled bool + err := row.Scan(&fkEnabled) + return fkEnabled, err +} + +const pragmaForeignKeyCheck = `PRAGMA foreign_key_check` + +type PragmaForeignKeyCheckResult struct { + Table string + RowID int + Parent string + FKID int +} + +func (q *Queries) PragmaForeignKeyCheck(ctx context.Context) ([]PragmaForeignKeyCheckResult, error) { + rows, err := q.db.QueryContext(ctx, pragmaForeignKeyCheck) + if err != nil { + return nil, err + } + defer rows.Close() + var items []PragmaForeignKeyCheckResult + for rows.Next() { + var i PragmaForeignKeyCheckResult + if err := rows.Scan( + &i.Table, + &i.RowID, + &i.Parent, + &i.FKID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +func (q *Queries) PragmaBusyTimeout(ctx context.Context, busyTimeout time.Duration) error { + sql := fmt.Sprintf("PRAGMA busy_timeout = %d", busyTimeout.Milliseconds()) + _, err := q.db.ExecContext(ctx, sql) + return err +} + +const pragmaJournalModeWAL = `PRAGMA journal_mode = WAL` + +func (q *Queries) PragmaJournalModeWAL(ctx context.Context) error { + _, err := q.db.ExecContext(ctx, pragmaJournalModeWAL) + return err +} + +const pragmaSynchronousNormal = `PRAGMA synchronous = normal` + +func (q *Queries) PragmaSynchronousNormal(ctx context.Context) error { + _, err := q.db.ExecContext(ctx, pragmaSynchronousNormal) + return err +} + +const vacuum = `VACUUM` + +func (q *Queries) Vacuum(ctx context.Context) error { + _, err := q.db.ExecContext(ctx, vacuum) + return err +} diff --git a/internal/worker/persistence/sqlc/query.sql b/internal/worker/persistence/sqlc/query.sql new file mode 100644 index 00000000..d2230a0d --- /dev/null +++ b/internal/worker/persistence/sqlc/query.sql @@ -0,0 +1,19 @@ +-- name: CountTaskUpdates :one +SELECT count(*) as count from task_updates; + +-- name: InsertTaskUpdate :exec +INSERT INTO task_updates ( + created_at, + task_id, + payload +) VALUES ( + 
@created_at, + @task_id, + @payload +); + +-- name: FirstTaskUpdate :one +SELECT * FROM task_updates ORDER BY id LIMIT 1; + +-- name: DeleteTaskUpdate :exec +DELETE FROM task_updates WHERE id=@task_update_id; diff --git a/internal/worker/persistence/sqlc/query.sql.go b/internal/worker/persistence/sqlc/query.sql.go new file mode 100644 index 00000000..0fb31e87 --- /dev/null +++ b/internal/worker/persistence/sqlc/query.sql.go @@ -0,0 +1,70 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.26.0 +// source: query.sql + +package sqlc + +import ( + "context" + "time" +) + +const countTaskUpdates = `-- name: CountTaskUpdates :one +SELECT count(*) as count from task_updates +` + +func (q *Queries) CountTaskUpdates(ctx context.Context) (int64, error) { + row := q.db.QueryRowContext(ctx, countTaskUpdates) + var count int64 + err := row.Scan(&count) + return count, err +} + +const deleteTaskUpdate = `-- name: DeleteTaskUpdate :exec +DELETE FROM task_updates WHERE id=?1 +` + +func (q *Queries) DeleteTaskUpdate(ctx context.Context, taskUpdateID int64) error { + _, err := q.db.ExecContext(ctx, deleteTaskUpdate, taskUpdateID) + return err +} + +const firstTaskUpdate = `-- name: FirstTaskUpdate :one +SELECT id, created_at, task_id, payload FROM task_updates ORDER BY id LIMIT 1 +` + +func (q *Queries) FirstTaskUpdate(ctx context.Context) (TaskUpdate, error) { + row := q.db.QueryRowContext(ctx, firstTaskUpdate) + var i TaskUpdate + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.TaskID, + &i.Payload, + ) + return i, err +} + +const insertTaskUpdate = `-- name: InsertTaskUpdate :exec +INSERT INTO task_updates ( + created_at, + task_id, + payload +) VALUES ( + ?1, + ?2, + ?3 +) +` + +type InsertTaskUpdateParams struct { + CreatedAt time.Time + TaskID string + Payload []byte +} + +func (q *Queries) InsertTaskUpdate(ctx context.Context, arg InsertTaskUpdateParams) error { + _, err := q.db.ExecContext(ctx, insertTaskUpdate, arg.CreatedAt, arg.TaskID, arg.Payload) + return err +} diff --git a/internal/worker/persistence/sqlc/schema.sql b/internal/worker/persistence/sqlc/schema.sql new file mode 100644 index 00000000..9d591e9e --- /dev/null +++ b/internal/worker/persistence/sqlc/schema.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS task_updates ( + id integer NOT NULL, + created_at datetime NOT NULL, + task_id varchar(36) DEFAULT '' NOT NULL, + payload BLOB NOT NULL, + PRIMARY KEY (id) +); diff --git a/internal/worker/persistence/sqlite_busy.go b/internal/worker/persistence/sqlite_busy.go index e8e814b6..78301eea 100644 --- a/internal/worker/persistence/sqlite_busy.go +++ b/internal/worker/persistence/sqlite_busy.go @@ -2,13 +2,11 @@ package persistence import ( + "context" "errors" "fmt" + "strings" "time" - - "github.com/glebarez/go-sqlite" - "gorm.io/gorm" - sqlite3 "modernc.org/sqlite/lib" ) var ( @@ -25,15 +23,19 @@ func ErrIsDBBusy(err error) bool { // isDatabaseBusyError returns true when the error returned by GORM is a // SQLITE_BUSY error. func isDatabaseBusyError(err error) bool { - sqlErr, ok := err.(*sqlite.Error) - return ok && sqlErr.Code() == sqlite3.SQLITE_BUSY + if err == nil { + return false + } + return strings.Contains(err.Error(), "SQLITE_BUSY") } // setBusyTimeout sets the SQLite busy_timeout busy handler. 
// See https://sqlite.org/pragma.html#pragma_busy_timeout -func setBusyTimeout(gormDB *gorm.DB, busyTimeout time.Duration) error { - if tx := gormDB.Exec(fmt.Sprintf("PRAGMA busy_timeout = %d", busyTimeout.Milliseconds())); tx.Error != nil { - return fmt.Errorf("setting busy_timeout: %w", tx.Error) +func (db *DB) setBusyTimeout(ctx context.Context, busyTimeout time.Duration) error { + queries := db.queries() + err := queries.PragmaBusyTimeout(ctx, busyTimeout) + if err != nil { + return fmt.Errorf("setting busy_timeout: %w", err) } return nil } diff --git a/internal/worker/persistence/upstream_buffer.go b/internal/worker/persistence/upstream_buffer.go index 46953dcb..033b86b4 100644 --- a/internal/worker/persistence/upstream_buffer.go +++ b/internal/worker/persistence/upstream_buffer.go @@ -2,9 +2,12 @@ package persistence import ( "context" + "database/sql" "encoding/json" + "errors" "fmt" + "projects.blender.org/studio/flamenco/internal/worker/persistence/sqlc" "projects.blender.org/studio/flamenco/pkg/api" ) @@ -12,10 +15,7 @@ import ( // TaskUpdate is a queued task update. type TaskUpdate struct { - Model - - TaskID string `gorm:"type:varchar(36);default:''"` - Payload []byte `gorm:"type:BLOB"` + sqlc.TaskUpdate } func (t *TaskUpdate) Unmarshal() (*api.TaskUpdateJSONRequestBody, error) { @@ -28,12 +28,11 @@ func (t *TaskUpdate) Unmarshal() (*api.TaskUpdateJSONRequestBody, error) { // UpstreamBufferQueueSize returns how many task updates are queued in the upstream buffer. func (db *DB) UpstreamBufferQueueSize(ctx context.Context) (int, error) { - var queueSize int64 - tx := db.gormDB.WithContext(ctx). - Model(&TaskUpdate{}). - Count(&queueSize) - if tx.Error != nil { - return 0, fmt.Errorf("counting queued task updates: %w", tx.Error) + queries := db.queries() + + queueSize, err := queries.CountTaskUpdates(ctx) + if err != nil { + return 0, fmt.Errorf("counting queued task updates: %w", err) } return int(queueSize), nil } @@ -45,36 +44,31 @@ func (db *DB) UpstreamBufferQueue(ctx context.Context, taskID string, apiTaskUpd return fmt.Errorf("converting task update to JSON: %w", err) } - taskUpdate := TaskUpdate{ - TaskID: taskID, - Payload: blob, - } + queries := db.queries() + err = queries.InsertTaskUpdate(ctx, sqlc.InsertTaskUpdateParams{ + CreatedAt: db.now(), + TaskID: taskID, + Payload: blob, + }) - tx := db.gormDB.WithContext(ctx).Create(&taskUpdate) - return tx.Error + return err } // UpstreamBufferFrontItem returns the first-queued item. The item remains queued. func (db *DB) UpstreamBufferFrontItem(ctx context.Context) (*TaskUpdate, error) { - taskUpdate := TaskUpdate{} - - findResult := db.gormDB.WithContext(ctx). - Order("ID"). - Limit(1). - Find(&taskUpdate) - if findResult.Error != nil { - return nil, findResult.Error - } - if taskUpdate.ID == 0 { - // No update fetched, which doesn't result in an error with Limt(1).Find(&task). + queries := db.queries() + result, err := queries.FirstTaskUpdate(ctx) + switch { + case errors.Is(err, sql.ErrNoRows): return nil, nil + case err != nil: + return nil, err } - - return &taskUpdate, nil + return &TaskUpdate{result}, err } // UpstreamBufferDiscard discards the queued task update with the given row ID. 
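isDatabaseBusyError now matches the driver's error text instead of the removed glebarez/go-sqlite error type, so exported helpers such as ErrIsDBBusy keep working for callers. A small retry sketch follows; the package name, retry count, and delay are illustrative assumptions:

package worker

import (
	"context"
	"time"

	"projects.blender.org/studio/flamenco/internal/worker/persistence"
)

// queueSizeWithRetry is a hypothetical helper, not part of this patch.
func queueSizeWithRetry(ctx context.Context, db *persistence.DB) (int, error) {
	var (
		size int
		err  error
	)
	for attempt := 0; attempt < 3; attempt++ {
		size, err = db.UpstreamBufferQueueSize(ctx)
		if !persistence.ErrIsDBBusy(err) {
			break
		}
		time.Sleep(250 * time.Millisecond)
	}
	return size, err
}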
func (db *DB) UpstreamBufferDiscard(ctx context.Context, queuedTaskUpdate *TaskUpdate) error { - tx := db.gormDB.WithContext(ctx).Delete(queuedTaskUpdate) - return tx.Error + queries := db.queries() + return queries.DeleteTaskUpdate(ctx, queuedTaskUpdate.ID) } diff --git a/sqlc.yaml b/sqlc.yaml index 3a712aca..9025df8b 100644 --- a/sqlc.yaml +++ b/sqlc.yaml @@ -51,3 +51,20 @@ sql: jobuuid: "JobUUID" taskUUID: "TaskUUID" workeruuid: "WorkerUUID" + - engine: "sqlite" + schema: "internal/worker/persistence/sqlc/schema.sql" + queries: "internal/worker/persistence/sqlc/query.sql" + gen: + go: + out: "internal/worker/persistence/sqlc" + overrides: + - db_type: "jsonb" + go_type: + import: "encoding/json" + type: "RawMessage" + rename: + uuid: "UUID" + uuids: "UUIDs" + jobuuid: "JobUUID" + taskUUID: "TaskUUID" + workeruuid: "WorkerUUID"
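For completeness, a sketch of how the worker's upstream-buffer consumer might drain the queue with the reworked API. sendToManager stands in for however the worker actually delivers an update and is hypothetical, as is the package name:

package worker

import (
	"context"

	"projects.blender.org/studio/flamenco/internal/worker/persistence"
	"projects.blender.org/studio/flamenco/pkg/api"
)

// flushUpstreamBuffer is a hypothetical consumer loop built on the functions above.
func flushUpstreamBuffer(
	ctx context.Context,
	db *persistence.DB,
	sendToManager func(ctx context.Context, taskID string, update *api.TaskUpdateJSONRequestBody) error,
) error {
	for {
		queued, err := db.UpstreamBufferFrontItem(ctx)
		if err != nil {
			return err
		}
		if queued == nil {
			// UpstreamBufferFrontItem returns (nil, nil) when the buffer is empty.
			return nil
		}

		update, err := queued.Unmarshal()
		if err != nil {
			return err
		}
		if err := sendToManager(ctx, queued.TaskID, update); err != nil {
			return err
		}

		// Only forget the update once it has actually been delivered.
		if err := db.UpstreamBufferDiscard(ctx, queued); err != nil {
			return err
		}
	}
}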