From 350f4f60cb905bad625913d0a14a48a8b5ee7411 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sybren=20A=2E=20St=C3=BCvel?=
Date: Fri, 15 Jul 2022 16:05:59 +0200
Subject: [PATCH] Worker: convert database interface to GORM

Convert the database interface from the stdlib `database/sql` package to
the GORM object-relational mapper. GORM is also used by the Manager, and
thus with this change both Worker and Manager have a uniform way of
accessing their databases.
---
 internal/worker/persistence/db.go           | 128 ++++++++++++++
 internal/worker/persistence/db_migration.go |  17 ++
 internal/worker/persistence/logger.go       | 128 ++++++++++++++
 internal/worker/persistence/sqlite_busy.go  |  39 +++++
 .../worker/persistence/upstream_buffer.go   |  80 +++++++++
 internal/worker/upstream_buffer.go          | 165 ++++--------------
 internal/worker/upstream_buffer_test.go     |  58 +++++-
 7 files changed, 479 insertions(+), 136 deletions(-)
 create mode 100644 internal/worker/persistence/db.go
 create mode 100644 internal/worker/persistence/db_migration.go
 create mode 100644 internal/worker/persistence/logger.go
 create mode 100644 internal/worker/persistence/sqlite_busy.go
 create mode 100644 internal/worker/persistence/upstream_buffer.go

diff --git a/internal/worker/persistence/db.go b/internal/worker/persistence/db.go
new file mode 100644
index 00000000..ecf8838c
--- /dev/null
+++ b/internal/worker/persistence/db.go
@@ -0,0 +1,128 @@
+// Package persistence provides the database interface for the Flamenco Worker.
+package persistence
+
+// SPDX-License-Identifier: GPL-3.0-or-later
+
+import (
+	"context"
+	"fmt"
+	"time"
+
+	"github.com/rs/zerolog/log"
+	"gorm.io/gorm"
+
+	"github.com/glebarez/sqlite"
+)
+
+// DB provides the database interface.
+type DB struct {
+	gormDB *gorm.DB
+}
+
+// Model contains the common database fields for most model structs. It is a
+// copy of the gorm.Model struct, but without the `DeletedAt` or `UpdatedAt`
+// fields. Flamenco does not use soft deletion, and rows in the upstream
+// buffer are never updated after creation. If soft deletion ever becomes
+// necessary, see https://gorm.io/docs/delete.html#Soft-Delete
+type Model struct {
+	ID        uint `gorm:"primarykey"`
+	CreatedAt time.Time
+}
+
+func OpenDB(ctx context.Context, dsn string) (*DB, error) {
+	log.Info().Str("dsn", dsn).Msg("opening database")
+
+	db, err := openDB(ctx, dsn)
+	if err != nil {
+		return nil, err
+	}
+
+	if err := setBusyTimeout(db.gormDB, 5*time.Second); err != nil {
+		return nil, err
+	}
+
+	// Perform some maintenance at startup.
+	db.vacuum()
+
+	if err := db.migrate(); err != nil {
+		return nil, err
+	}
+	log.Debug().Msg("database automigration successful")
+
+	return db, nil
+}
+
+func openDB(ctx context.Context, dsn string) (*DB, error) {
+	globalLogLevel := log.Logger.GetLevel()
+	dblogger := NewDBLogger(log.Level(globalLogLevel))
+
+	config := gorm.Config{
+		Logger:  dblogger,
+		NowFunc: nowFunc,
+	}
+
+	return openDBWithConfig(dsn, &config)
+}
+
+func openDBWithConfig(dsn string, config *gorm.Config) (*DB, error) {
+	dialector := sqlite.Open(dsn)
+	gormDB, err := gorm.Open(dialector, config)
+	if err != nil {
+		return nil, err
+	}
+
+	// Use the generic sql.DB interface to set some connection pool options.
+	sqlDB, err := gormDB.DB()
+	if err != nil {
+		return nil, err
+	}
+	// Only allow a single database connection, to avoid SQLITE_BUSY errors.
+	// It's not certain that this'll improve the situation, but it's worth a try.
+	sqlDB.SetMaxIdleConns(1) // Max num of connections in the idle connection pool.
+	sqlDB.SetMaxOpenConns(1) // Max num of open connections to the database.
+
+	// Enable foreign key checks.
+	log.Trace().Msg("enabling SQLite foreign key checks")
+	if tx := gormDB.Exec("PRAGMA foreign_keys = 1"); tx.Error != nil {
+		return nil, fmt.Errorf("enabling foreign keys: %w", tx.Error)
+	}
+	var fkEnabled int
+	if tx := gormDB.Raw("PRAGMA foreign_keys").Scan(&fkEnabled); tx.Error != nil {
+		return nil, fmt.Errorf("checking whether the database has foreign key checks enabled: %w", tx.Error)
+	}
+	if fkEnabled == 0 {
+		log.Error().Msg("SQLite refuses to enable foreign key checks, this may cause data loss")
+	}
+
+	db := DB{
+		gormDB: gormDB,
+	}
+
+	return &db, nil
+}
+
+// nowFunc returns 'now' in UTC, so that GORM-managed times (createdAt,
+// deletedAt, updatedAt) are stored in UTC.
+func nowFunc() time.Time {
+	return time.Now().UTC()
+}
+
+// vacuum executes the SQL "VACUUM" command, and logs any errors.
+func (db *DB) vacuum() {
+	tx := db.gormDB.Exec("vacuum")
+	if tx.Error != nil {
+		log.Error().Err(tx.Error).Msg("error vacuuming database")
+	}
+}
+
+func (db *DB) Close() error {
+	sqlDB, err := db.gormDB.DB()
+	if err != nil {
+		return fmt.Errorf("getting generic database interface: %w", err)
+	}
+
+	if err := sqlDB.Close(); err != nil {
+		return fmt.Errorf("closing database: %w", err)
+	}
+	return nil
+}
diff --git a/internal/worker/persistence/db_migration.go b/internal/worker/persistence/db_migration.go
new file mode 100644
index 00000000..fc8b5e02
--- /dev/null
+++ b/internal/worker/persistence/db_migration.go
@@ -0,0 +1,17 @@
+package persistence
+
+// SPDX-License-Identifier: GPL-3.0-or-later
+
+import (
+	"fmt"
+)
+
+func (db *DB) migrate() error {
+	err := db.gormDB.AutoMigrate(
+		&TaskUpdate{},
+	)
+	if err != nil {
+		return fmt.Errorf("failed to automigrate database: %w", err)
+	}
+	return nil
+}
diff --git a/internal/worker/persistence/logger.go b/internal/worker/persistence/logger.go
new file mode 100644
index 00000000..2135a006
--- /dev/null
+++ b/internal/worker/persistence/logger.go
@@ -0,0 +1,128 @@
+package persistence
+
+// SPDX-License-Identifier: GPL-3.0-or-later
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"time"
+
+	"github.com/rs/zerolog"
+	"gorm.io/gorm"
+	gormlogger "gorm.io/gorm/logger"
+)
+
+// dbLogger implements the behaviour of Gorm's default logger on top of Zerolog.
+// See https://github.com/go-gorm/gorm/blob/master/logger/logger.go
+type dbLogger struct {
+	zlog *zerolog.Logger
+
+	IgnoreRecordNotFoundError bool
+	SlowThreshold             time.Duration
+}
+
+var _ gormlogger.Interface = (*dbLogger)(nil)
+
+var logLevelMap = map[gormlogger.LogLevel]zerolog.Level{
+	gormlogger.Silent: zerolog.Disabled,
+	gormlogger.Error:  zerolog.ErrorLevel,
+	gormlogger.Warn:   zerolog.WarnLevel,
+	gormlogger.Info:   zerolog.InfoLevel,
+}
+
+func gormToZlogLevel(logLevel gormlogger.LogLevel) zerolog.Level {
+	zlogLevel, ok := logLevelMap[logLevel]
+	if !ok {
+		// Just a default value that seemed sensible at the time of writing.
+		return zerolog.DebugLevel
+	}
+	return zlogLevel
+}
+
+// NewDBLogger wraps a zerolog logger to implement the Gorm logger interface.
+func NewDBLogger(zlog zerolog.Logger) *dbLogger {
+	return &dbLogger{
+		zlog: &zlog,
+		// Remaining properties default to their zero value.
+	}
+}
+
+// LogMode returns a child logger at the given log level.
+func (l *dbLogger) LogMode(logLevel gormlogger.LogLevel) gormlogger.Interface {
+	childLogger := l.zlog.Level(gormToZlogLevel(logLevel))
+	newlogger := *l
+	newlogger.zlog = &childLogger
+	return &newlogger
+}
+
+func (l *dbLogger) Info(ctx context.Context, msg string, args ...interface{}) {
+	l.logEvent(zerolog.InfoLevel, msg, args...)
+}
+
+func (l *dbLogger) Warn(ctx context.Context, msg string, args ...interface{}) {
+	l.logEvent(zerolog.WarnLevel, msg, args...)
+}
+
+func (l *dbLogger) Error(ctx context.Context, msg string, args ...interface{}) {
+	l.logEvent(zerolog.ErrorLevel, msg, args...)
+}
+
+// Trace traces the execution of SQL and potentially logs errors, warnings, and infos.
+// Note that it doesn't mean "trace-level logging".
+func (l *dbLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
+	zlogLevel := l.zlog.GetLevel()
+	if zlogLevel == zerolog.Disabled {
+		return
+	}
+
+	elapsed := time.Since(begin)
+	logCtx := l.zlog.With().CallerWithSkipFrameCount(5)
+
+	// Function to lazily get the SQL, affected rows, and logger.
+	buildLogger := func() (loggerPtr *zerolog.Logger, sql string) {
+		sql, rows := fc()
+		logCtx = logCtx.Err(err)
+		if rows >= 0 {
+			logCtx = logCtx.Int64("rowsAffected", rows)
+		}
+		logger := logCtx.Logger()
+		return &logger, sql
+	}
+
+	switch {
+	case err != nil && zlogLevel <= zerolog.ErrorLevel && (!errors.Is(err, gorm.ErrRecordNotFound) || !l.IgnoreRecordNotFoundError):
+		logger, sql := buildLogger()
+		logger.Error().Msg(sql)
+
+	case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && zlogLevel <= zerolog.WarnLevel:
+		logger, sql := buildLogger()
+		logger.Warn().
+			Str("sql", sql).
+			Dur("elapsed", elapsed).
+			Dur("slowThreshold", l.SlowThreshold).
+			Msg("slow database query")
+
+	case zlogLevel <= zerolog.TraceLevel:
+		logger, sql := buildLogger()
+		logger.Trace().Msg(sql)
+	}
+}
+
+// logEvent logs an event at the given level.
+func (l *dbLogger) logEvent(level zerolog.Level, msg string, args ...interface{}) {
+	if l.zlog.GetLevel() > level {
+		return
+	}
+	logger := l.logger(args...)
+	logger.WithLevel(level).Msg(msg)
+}
+
+// logger constructs a zerolog logger. The given arguments are added via reflection.
+func (l *dbLogger) logger(args ...interface{}) zerolog.Logger {
+	logCtx := l.zlog.With()
+	for idx, arg := range args {
+		logCtx = logCtx.Interface(fmt.Sprintf("arg%d", idx), arg)
+	}
+	return logCtx.Logger()
+}
diff --git a/internal/worker/persistence/sqlite_busy.go b/internal/worker/persistence/sqlite_busy.go
new file mode 100644
index 00000000..e8e814b6
--- /dev/null
+++ b/internal/worker/persistence/sqlite_busy.go
@@ -0,0 +1,39 @@
+// SPDX-License-Identifier: GPL-3.0-or-later
+package persistence
+
+import (
+	"errors"
+	"fmt"
+	"time"
+
+	"github.com/glebarez/go-sqlite"
+	"gorm.io/gorm"
+	sqlite3 "modernc.org/sqlite/lib"
+)
+
+var (
+	// errDatabaseBusy is returned by this package when the operation could not be
+	// performed due to SQLite being busy.
+	errDatabaseBusy = errors.New("database busy")
+)
+
+// ErrIsDBBusy returns true when the error is a "database busy" error.
+func ErrIsDBBusy(err error) bool {
+	return errors.Is(err, errDatabaseBusy) || isDatabaseBusyError(err)
+}
+
+// isDatabaseBusyError returns true when the error returned by GORM is a
+// SQLITE_BUSY error.
+func isDatabaseBusyError(err error) bool {
+	sqlErr, ok := err.(*sqlite.Error)
+	return ok && sqlErr.Code() == sqlite3.SQLITE_BUSY
+}
+
+// setBusyTimeout sets the SQLite busy_timeout busy handler.
+// See https://sqlite.org/pragma.html#pragma_busy_timeout
+func setBusyTimeout(gormDB *gorm.DB, busyTimeout time.Duration) error {
+	if tx := gormDB.Exec(fmt.Sprintf("PRAGMA busy_timeout = %d", busyTimeout.Milliseconds())); tx.Error != nil {
+		return fmt.Errorf("setting busy_timeout: %w", tx.Error)
+	}
+	return nil
+}
diff --git a/internal/worker/persistence/upstream_buffer.go b/internal/worker/persistence/upstream_buffer.go
new file mode 100644
index 00000000..6fb0493b
--- /dev/null
+++ b/internal/worker/persistence/upstream_buffer.go
@@ -0,0 +1,80 @@
+package persistence
+
+// SPDX-License-Identifier: GPL-3.0-or-later
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+
+	"git.blender.org/flamenco/pkg/api"
+)
+
+// TaskUpdate is a queued task update.
+type TaskUpdate struct {
+	Model
+
+	TaskID  string `gorm:"type:varchar(36);default:''"`
+	Payload []byte `gorm:"type:BLOB"`
+}
+
+func (t *TaskUpdate) Unmarshal() (*api.TaskUpdateJSONRequestBody, error) {
+	var apiTaskUpdate api.TaskUpdateJSONRequestBody
+	if err := json.Unmarshal(t.Payload, &apiTaskUpdate); err != nil {
+		return nil, err
+	}
+	return &apiTaskUpdate, nil
+}
+
+// UpstreamBufferQueueSize returns how many task updates are queued in the upstream buffer.
+func (db *DB) UpstreamBufferQueueSize(ctx context.Context) (int, error) {
+	var queueSize int64
+	tx := db.gormDB.WithContext(ctx).
+		Model(&TaskUpdate{}).
+		Count(&queueSize)
+	if tx.Error != nil {
+		return 0, fmt.Errorf("counting queued task updates: %w", tx.Error)
+	}
+	return int(queueSize), nil
+}
+
+// UpstreamBufferQueue queues a task update in the upstream buffer.
+func (db *DB) UpstreamBufferQueue(ctx context.Context, taskID string, apiTaskUpdate api.TaskUpdateJSONRequestBody) error {
+	blob, err := json.Marshal(apiTaskUpdate)
+	if err != nil {
+		return fmt.Errorf("converting task update to JSON: %w", err)
+	}
+
+	taskUpdate := TaskUpdate{
+		TaskID:  taskID,
+		Payload: blob,
+	}
+
+	tx := db.gormDB.WithContext(ctx).Create(&taskUpdate)
+	return tx.Error
+}
+
+// UpstreamBufferFrontItem returns the first-queued item. The item remains queued.
+func (db *DB) UpstreamBufferFrontItem(ctx context.Context) (*TaskUpdate, error) {
+	taskUpdate := TaskUpdate{}
+
+	findResult := db.gormDB.WithContext(ctx).
+		Order("ID").
+		Limit(1).
+		Find(&taskUpdate)
+	if findResult.Error != nil {
+		return nil, findResult.Error
+	}
+	if taskUpdate.ID == 0 {
+		// Nothing was fetched; Limit(1).Find(&taskUpdate) does not return an
+		// error when no row matches.
+		return nil, nil
+	}
+
+	return &taskUpdate, nil
+}
+
+// UpstreamBufferDiscard removes the given task update from the queue.
+func (db *DB) UpstreamBufferDiscard(ctx context.Context, queuedTaskUpdate *TaskUpdate) error {
+	tx := db.gormDB.WithContext(ctx).Delete(queuedTaskUpdate)
+	return tx.Error
+}
diff --git a/internal/worker/upstream_buffer.go b/internal/worker/upstream_buffer.go
index e5d6abdc..4ede5e96 100644
--- a/internal/worker/upstream_buffer.go
+++ b/internal/worker/upstream_buffer.go
@@ -4,8 +4,6 @@ package worker
 
 import (
 	"context"
-	"database/sql"
-	"encoding/json"
 	"errors"
 	"fmt"
 	"net/http"
@@ -13,8 +11,8 @@ import (
 	"time"
 
 	"github.com/rs/zerolog/log"
-	_ "modernc.org/sqlite"
 
+	"git.blender.org/flamenco/internal/worker/persistence"
 	"git.blender.org/flamenco/pkg/api"
 )
 
@@ -30,7 +28,7 @@
 // UpstreamBufferDB implements the UpstreamBuffer interface using a database as backend.
type UpstreamBufferDB struct { - db *sql.DB + db UpstreamBufferPersistence dbMutex *sync.Mutex // Protects from "database locked" errors client FlamencoClient @@ -41,6 +39,14 @@ type UpstreamBufferDB struct { wg *sync.WaitGroup } +type UpstreamBufferPersistence interface { + UpstreamBufferQueueSize(ctx context.Context) (int, error) + UpstreamBufferQueue(ctx context.Context, taskID string, apiTaskUpdate api.TaskUpdateJSONRequestBody) error + UpstreamBufferFrontItem(ctx context.Context) (*persistence.TaskUpdate, error) + UpstreamBufferDiscard(ctx context.Context, queuedTaskUpdate *persistence.TaskUpdate) error + Close() error +} + const defaultUpstreamFlushInterval = 30 * time.Second const databaseContextTimeout = 10 * time.Second const flushOnShutdownTimeout = 5 * time.Second @@ -68,25 +74,12 @@ func (ub *UpstreamBufferDB) OpenDB(dbCtx context.Context, databaseFilename strin return errors.New("upstream buffer database already opened") } - db, err := sql.Open("sqlite", databaseFilename) + db, err := persistence.OpenDB(dbCtx, databaseFilename) if err != nil { return fmt.Errorf("opening %s: %w", databaseFilename, err) } - - if err := db.PingContext(dbCtx); err != nil { - return fmt.Errorf("accessing %s: %w", databaseFilename, err) - } - - if _, err := db.ExecContext(dbCtx, "PRAGMA foreign_keys = 1"); err != nil { - return fmt.Errorf("enabling foreign keys: %w", err) - } - ub.db = db - if err := ub.prepareDatabase(dbCtx); err != nil { - return err - } - ub.wg.Add(1) go ub.periodicFlushLoop() @@ -97,7 +90,7 @@ func (ub *UpstreamBufferDB) SendTaskUpdate(ctx context.Context, taskID string, u ub.dbMutex.Lock() defer ub.dbMutex.Unlock() - queueSize, err := ub.queueSize() + queueSize, err := ub.queueSize(ctx) if err != nil { return fmt.Errorf("unable to determine upstream queue size: %w", err) } @@ -151,53 +144,15 @@ func (ub *UpstreamBufferDB) Close() error { return ub.db.Close() } -// prepareDatabase creates the database schema, if necessary. -func (ub *UpstreamBufferDB) prepareDatabase(dbCtx context.Context) error { - ub.dbMutex.Lock() - defer ub.dbMutex.Unlock() - - tx, err := ub.db.BeginTx(dbCtx, nil) - if err != nil { - return fmt.Errorf("beginning database transaction: %w", err) - } - defer rollbackTransaction(tx) - - stmt := `CREATE TABLE IF NOT EXISTS task_update_queue(task_id VARCHAR(36), payload BLOB)` - log.Debug().Str("sql", stmt).Msg("creating database table") - - if _, err := tx.ExecContext(dbCtx, stmt); err != nil { - return fmt.Errorf("creating database table: %w", err) - } - - if err = tx.Commit(); err != nil { - return fmt.Errorf("commiting creation of database table: %w", err) - } - - return nil -} - -func (ub *UpstreamBufferDB) queueSize() (int, error) { +func (ub *UpstreamBufferDB) queueSize(ctx context.Context) (int, error) { if ub.db == nil { log.Panic().Msg("no database opened, unable to inspect upstream queue") } - dbCtx, dbCtxCancel := context.WithTimeout(context.Background(), databaseContextTimeout) + dbCtx, dbCtxCancel := context.WithTimeout(ctx, databaseContextTimeout) defer dbCtxCancel() - var queueSize int - - err := ub.db. - QueryRowContext(dbCtx, "SELECT count(*) FROM task_update_queue"). 
- Scan(&queueSize) - - switch { - case err == sql.ErrNoRows: - return 0, nil - case err != nil: - return 0, err - default: - return queueSize, nil - } + return ub.db.UpstreamBufferQueueSize(dbCtx) } func (ub *UpstreamBufferDB) queueTaskUpdate(taskID string, update api.TaskUpdateJSONRequestBody) error { @@ -208,35 +163,13 @@ func (ub *UpstreamBufferDB) queueTaskUpdate(taskID string, update api.TaskUpdate dbCtx, dbCtxCancel := context.WithTimeout(context.Background(), databaseContextTimeout) defer dbCtxCancel() - tx, err := ub.db.BeginTx(dbCtx, nil) - if err != nil { - return fmt.Errorf("beginning database transaction: %w", err) - } - defer rollbackTransaction(tx) - - blob, err := json.Marshal(update) - if err != nil { - return fmt.Errorf("converting task update to JSON: %w", err) - } - - stmt := `INSERT INTO task_update_queue (task_id, payload) VALUES (?, ?)` - log.Debug().Str("sql", stmt).Str("task", taskID).Msg("inserting task update") - - if _, err := tx.ExecContext(dbCtx, stmt, taskID, blob); err != nil { - return fmt.Errorf("queueing task update: %w", err) - } - - if err = tx.Commit(); err != nil { - return fmt.Errorf("committing queued task update: %w", err) - } - - return nil + return ub.db.UpstreamBufferQueue(dbCtx, taskID, update) } func (ub *UpstreamBufferDB) QueueSize() (int, error) { ub.dbMutex.Lock() defer ub.dbMutex.Unlock() - return ub.queueSize() + return ub.queueSize(context.Background()) } func (ub *UpstreamBufferDB) Flush(ctx context.Context) error { @@ -246,7 +179,7 @@ func (ub *UpstreamBufferDB) Flush(ctx context.Context) error { // See if we need to flush at all. ub.dbMutex.Lock() - queueSize, err := ub.queueSize() + queueSize, err := ub.queueSize(ctx) ub.dbMutex.Unlock() switch { @@ -273,48 +206,31 @@ func (ub *UpstreamBufferDB) Flush(ctx context.Context) error { } func (ub *UpstreamBufferDB) flushFirstItem(ctx context.Context) (done bool, err error) { - dbCtx, dbCtxCancel := context.WithTimeout(context.Background(), databaseContextTimeout) + dbCtx, dbCtxCancel := context.WithTimeout(ctx, databaseContextTimeout) defer dbCtxCancel() - tx, err := ub.db.BeginTx(dbCtx, nil) + queued, err := ub.db.UpstreamBufferFrontItem(dbCtx) if err != nil { - return false, fmt.Errorf("beginning database transaction: %w", err) + return false, fmt.Errorf("finding first queued task update: %w", err) } - defer rollbackTransaction(tx) - - stmt := `SELECT rowid, task_id, payload FROM task_update_queue ORDER BY rowid LIMIT 1` - log.Trace().Str("sql", stmt).Msg("fetching queued task updates") - - var rowID int64 - var taskID string - var blob []byte - - err = tx.QueryRowContext(dbCtx, stmt).Scan(&rowID, &taskID, &blob) - switch { - case err == sql.ErrNoRows: - // Flush operation is done. - log.Debug().Msg("task update queue empty") + if queued == nil { + // Nothing is queued. return true, nil - case err != nil: - return false, fmt.Errorf("querying task update queue: %w", err) } - logger := log.With().Str("task", taskID).Logger() + logger := log.With().Str("task", queued.TaskID).Logger() - var update api.TaskUpdateJSONRequestBody - if err := json.Unmarshal(blob, &update); err != nil { + apiTaskUpdate, err := queued.Unmarshal() + if err != nil { // If we can't unmarshal the queued task update, there is little else to do // than to discard it and ignore it ever happened. logger.Warn().Err(err). 
Msg("unable to unmarshal queued task update, discarding") - if err := ub.discardRow(tx, rowID); err != nil { - return false, err - } - return false, tx.Commit() + return false, ub.db.UpstreamBufferDiscard(dbCtx, queued) } // actually attempt delivery. - resp, err := ub.client.TaskUpdateWithResponse(ctx, taskID, update) + resp, err := ub.client.TaskUpdateWithResponse(ctx, queued.TaskID, *apiTaskUpdate) if err != nil { logger.Info().Err(err).Msg("communication with Manager still problematic") return true, err @@ -334,24 +250,10 @@ func (ub *UpstreamBufferDB) flushFirstItem(ctx context.Context) (done bool, err Msg("queued task update discarded by Manager, unknown reason") } - if err := ub.discardRow(tx, rowID); err != nil { + if err := ub.db.UpstreamBufferDiscard(dbCtx, queued); err != nil { return false, err } - return false, tx.Commit() -} - -func (ub *UpstreamBufferDB) discardRow(tx *sql.Tx, rowID int64) error { - dbCtx, dbCtxCancel := context.WithTimeout(context.Background(), databaseContextTimeout) - defer dbCtxCancel() - - stmt := `DELETE FROM task_update_queue WHERE rowid = ?` - log.Trace().Str("sql", stmt).Int64("rowID", rowID).Msg("un-queueing task update") - - _, err := tx.ExecContext(dbCtx, stmt, rowID) - if err != nil { - return fmt.Errorf("un-queueing task update: %w", err) - } - return nil + return false, nil } func (ub *UpstreamBufferDB) periodicFlushLoop() { @@ -374,10 +276,3 @@ func (ub *UpstreamBufferDB) periodicFlushLoop() { } } } - -func rollbackTransaction(tx *sql.Tx) { - if err := tx.Rollback(); err != nil && err != sql.ErrTxDone { - // log.Error().Err(err).Msg("rolling back transaction") - log.Panic().Err(err).Msg("rolling back transaction") - } -} diff --git a/internal/worker/upstream_buffer_test.go b/internal/worker/upstream_buffer_test.go index ea21c9aa..79b27b22 100644 --- a/internal/worker/upstream_buffer_test.go +++ b/internal/worker/upstream_buffer_test.go @@ -6,13 +6,14 @@ import ( "context" "errors" "fmt" + "net/http" "sync" "testing" + "time" "github.com/benbjohnson/clock" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" - _ "modernc.org/sqlite" "git.blender.org/flamenco/internal/worker/mocks" "git.blender.org/flamenco/pkg/api" @@ -106,3 +107,58 @@ func TestUpstreamBufferManagerUnavailable(t *testing.T) { assert.NoError(t, ub.Close()) } + +func TestStressingBuffer(t *testing.T) { + if testing.Short() { + t.Skip("skipping potentially heavy test due to -short CLI arg") + return + } + + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + ctx := context.Background() + + ub, mocks := mockUpstreamBufferDB(t, mockCtrl) + assert.NoError(t, ub.OpenDB(ctx, sqliteTestDBName(t))) + + // Queue task updates much faster than the Manager can handle. + taskID := "3960dec4-978e-40ab-bede-bfa6428c6ebc" + update := api.TaskUpdateJSONRequestBody{ + Activity: ptr("Testing da ünits"), + Log: ptr("¿Unicode logging should work?"), + TaskStatus: ptr(api.TaskStatusActive), + } + + // Make the Manager slow to respond. + const managerResponseTime = 250 * time.Millisecond + mocks.client.EXPECT(). + TaskUpdateWithResponse(ctx, taskID, update). + DoAndReturn(func(ctx context.Context, taskID string, body api.TaskUpdateJSONRequestBody, editors ...api.RequestEditorFn) (*api.TaskUpdateResponse, error) { + time.Sleep(managerResponseTime) + return &api.TaskUpdateResponse{ + HTTPResponse: &http.Response{StatusCode: http.StatusNoContent}, + }, nil + }). + AnyTimes() + + // Send updates MUCH faster than the slowed-down Manager can handle. 
+	wg := sync.WaitGroup{}
+	for i := 0; i < 100; i++ {
+		wg.Add(2)
+		go func() {
+			defer wg.Done()
+			err := ub.SendTaskUpdate(ctx, taskID, update)
+			assert.NoError(t, err)
+		}()
+
+		// Also mix in a bunch of flushes.
+		go func() {
+			defer wg.Done()
+			_, err := ub.flushFirstItem(ctx)
+			assert.NoError(t, err)
+		}()
+	}
+	wg.Wait()
+}
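
Note for reviewers: the three persistence calls (UpstreamBufferQueue,
UpstreamBufferFrontItem, UpstreamBufferDiscard) replace the hand-written SQL
transactions of the old implementation. A minimal sketch of the intended
queue/flush cycle, mirroring SendTaskUpdate and flushFirstItem (illustrative
only; exampleFlushCycle is a hypothetical function in package worker, and the
Manager status-code handling of flushFirstItem is omitted):

func exampleFlushCycle(ctx context.Context, db *persistence.DB, client FlamencoClient) error {
	// Queue an update locally; this is what SendTaskUpdate falls back to
	// when the Manager cannot be reached.
	activity := "rendering"
	update := api.TaskUpdateJSONRequestBody{Activity: &activity}
	if err := db.UpstreamBufferQueue(ctx, "some-task-uuid", update); err != nil {
		return err
	}

	// Flush: repeatedly peek at the oldest item, deliver it, then discard it.
	for {
		queued, err := db.UpstreamBufferFrontItem(ctx)
		if err != nil {
			return err
		}
		if queued == nil {
			return nil // Queue is empty, the flush is done.
		}

		apiUpdate, err := queued.Unmarshal()
		if err == nil {
			if _, err := client.TaskUpdateWithResponse(ctx, queued.TaskID, *apiUpdate); err != nil {
				// Manager still unreachable; keep the item queued for the next flush.
				return err
			}
		}

		// Delivered (or unparseable, as in flushFirstItem): drop the item.
		if err := db.UpstreamBufferDiscard(ctx, queued); err != nil {
			return err
		}
	}
}

Because the front item is discarded only after the Manager confirms delivery,
a Worker crash between delivery and discard re-sends the update rather than
losing it.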
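Even with the connection pool capped at a single connection and busy_timeout
set to five seconds, SQLITE_BUSY can still surface under heavy contention. A
sketch of how ErrIsDBBusy is meant to be consumed at a call site (hypothetical
helper; this patch itself does not add retry logic):

// queueSizeOrRetry is a hypothetical wrapper that treats SQLITE_BUSY as
// "unknown for now, retry later" rather than as a fatal error.
func queueSizeOrRetry(ctx context.Context, db *persistence.DB) (size int, retryLater bool, err error) {
	size, err = db.UpstreamBufferQueueSize(ctx)
	switch {
	case persistence.ErrIsDBBusy(err):
		// The 5-second busy_timeout expired while another connection held
		// the database lock; report "retry later" instead of failing.
		return 0, true, nil
	case err != nil:
		return 0, false, fmt.Errorf("counting queued task updates: %w", err)
	}
	return size, false, nil
}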