diff --git a/internal/manager/persistence/jobs_test.go b/internal/manager/persistence/jobs_test.go index ec9fbcdf..9dce1e69 100644 --- a/internal/manager/persistence/jobs_test.go +++ b/internal/manager/persistence/jobs_test.go @@ -33,7 +33,8 @@ import ( ) func TestStoreAuthoredJob(t *testing.T) { - db := CreateTestDB(t) + db, dbCloser := CreateTestDB(t) + defer dbCloser() ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() @@ -74,8 +75,8 @@ func TestStoreAuthoredJob(t *testing.T) { } func TestJobHasTasksInStatus(t *testing.T) { - ctx, ctxCancel, db, job, _ := jobTasksTestFixtures(t) - defer ctxCancel() + ctx, close, db, job, _ := jobTasksTestFixtures(t) + defer close() hasTasks, err := db.JobHasTasksInStatus(ctx, job, api.TaskStatusQueued) assert.NoError(t, err) @@ -87,8 +88,8 @@ func TestJobHasTasksInStatus(t *testing.T) { } func TestCountTasksOfJobInStatus(t *testing.T) { - ctx, ctxCancel, db, job, authoredJob := jobTasksTestFixtures(t) - defer ctxCancel() + ctx, close, db, job, authoredJob := jobTasksTestFixtures(t) + defer close() numQueued, numTotal, err := db.CountTasksOfJobInStatus(ctx, job, api.TaskStatusQueued) assert.NoError(t, err) @@ -118,8 +119,8 @@ func TestCountTasksOfJobInStatus(t *testing.T) { } func TestUpdateJobsTaskStatuses(t *testing.T) { - ctx, ctxCancel, db, job, authoredJob := jobTasksTestFixtures(t) - defer ctxCancel() + ctx, close, db, job, authoredJob := jobTasksTestFixtures(t) + defer close() err := db.UpdateJobsTaskStatuses(ctx, job, api.TaskStatusSoftFailed, "testing æctivity") assert.NoError(t, err) @@ -147,8 +148,8 @@ func TestUpdateJobsTaskStatuses(t *testing.T) { } func TestUpdateJobsTaskStatusesConditional(t *testing.T) { - ctx, ctxCancel, db, job, authoredJob := jobTasksTestFixtures(t) - defer ctxCancel() + ctx, close, db, job, authoredJob := jobTasksTestFixtures(t) + defer close() getTask := func(taskIndex int) *Task { task, err := db.FetchTask(ctx, authoredJob.Tasks[taskIndex].UUID) @@ -192,8 +193,8 @@ func TestUpdateJobsTaskStatusesConditional(t *testing.T) { } func TestTaskAssignToWorker(t *testing.T) { - ctx, ctxCancel, db, _, authoredJob := jobTasksTestFixtures(t) - defer ctxCancel() + ctx, close, db, _, authoredJob := jobTasksTestFixtures(t) + defer close() task, err := db.FetchTask(ctx, authoredJob.Tasks[1].UUID) assert.NoError(t, err) @@ -206,8 +207,8 @@ func TestTaskAssignToWorker(t *testing.T) { } func TestFetchTasksOfWorkerInStatus(t *testing.T) { - ctx, ctxCancel, db, _, authoredJob := jobTasksTestFixtures(t) - defer ctxCancel() + ctx, close, db, _, authoredJob := jobTasksTestFixtures(t) + defer close() task, err := db.FetchTask(ctx, authoredJob.Tasks[1].UUID) assert.NoError(t, err) @@ -288,9 +289,12 @@ func createTestAuthoredJobWithTasks() job_compilers.AuthoredJob { } func jobTasksTestFixtures(t *testing.T) (context.Context, context.CancelFunc, *DB, *Job, job_compilers.AuthoredJob) { - db := CreateTestDB(t) - - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + db, dbCloser := CreateTestDB(t) + ctx, ctxCancel := context.WithTimeout(context.Background(), 1*time.Second) + cancel := func() { + ctxCancel() + dbCloser() + } authoredJob := createTestAuthoredJobWithTasks() err := db.StoreAuthoredJob(ctx, authoredJob) diff --git a/internal/manager/persistence/task_scheduler_test.go b/internal/manager/persistence/task_scheduler_test.go index 0d294a4e..cc8eb770 100644 --- a/internal/manager/persistence/task_scheduler_test.go +++ b/internal/manager/persistence/task_scheduler_test.go @@ -33,7 +33,8 @@ import ( ) func TestNoTasks(t *testing.T) { - db := CreateTestDB(t) + db, dbCloser := CreateTestDB(t) + defer dbCloser() ctx, ctxCancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer ctxCancel() @@ -45,7 +46,8 @@ func TestNoTasks(t *testing.T) { } func TestOneJobOneTask(t *testing.T) { - db := CreateTestDB(t) + db, dbCloser := CreateTestDB(t) + defer dbCloser() ctx, ctxCancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer ctxCancel() @@ -81,7 +83,8 @@ func TestOneJobOneTask(t *testing.T) { } func TestOneJobThreeTasksByPrio(t *testing.T) { - db := CreateTestDB(t) + db, dbCloser := CreateTestDB(t) + defer dbCloser() ctx, ctxCancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer ctxCancel() @@ -113,7 +116,8 @@ func TestOneJobThreeTasksByPrio(t *testing.T) { } func TestOneJobThreeTasksByDependencies(t *testing.T) { - db := CreateTestDB(t) + db, dbCloser := CreateTestDB(t) + defer dbCloser() ctx, ctxCancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer ctxCancel() @@ -140,7 +144,8 @@ func TestOneJobThreeTasksByDependencies(t *testing.T) { } func TestTwoJobsThreeTasks(t *testing.T) { - db := CreateTestDB(t) + db, dbCloser := CreateTestDB(t) + defer dbCloser() ctx, ctxCancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer ctxCancel() diff --git a/internal/manager/persistence/test_support.go b/internal/manager/persistence/test_support.go index 1089c32a..85baead0 100644 --- a/internal/manager/persistence/test_support.go +++ b/internal/manager/persistence/test_support.go @@ -22,33 +22,56 @@ package persistence * ***** END GPL LICENSE BLOCK ***** */ import ( + "database/sql" "os" "testing" - "time" - "github.com/stretchr/testify/assert" - "golang.org/x/net/context" + "github.com/glebarez/sqlite" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "gorm.io/gorm" ) -const TestDSN = "flamenco-test.sqlite" - -func CreateTestDB(t *testing.T) *DB { - // Creating a new database should be fast. - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() +// Change this to a filename if you want to run a single test and inspect the +// resulting database. +const TestDSN = "file::memory:" +func CreateTestDB(t *testing.T) (db *DB, closer func()) { + // Delete the SQLite file if it exists on disk. if _, err := os.Stat(TestDSN); err == nil { - // File exists. if err := os.Remove(TestDSN); err != nil { t.Fatalf("unable to remove %s: %v", TestDSN, err) } } - db, err := openDB(ctx, TestDSN) - assert.NoError(t, err) + var err error + + dblogger := NewDBLogger(log.Level(zerolog.InfoLevel).Output(os.Stdout)) + sqliteConn, err := sql.Open(sqlite.DriverName, TestDSN) + if err != nil { + t.Fatalf("opening SQLite connection: %v", err) + } + + config := gorm.Config{ + Logger: dblogger, + ConnPool: sqliteConn, + } + + db, err = openDBWithConfig(TestDSN, &config) + if err != nil { + t.Fatalf("opening DB: %v", err) + } err = db.migrate() - assert.NoError(t, err) + if err != nil { + t.Fatalf("migrating DB: %v", err) + } - return db + closer = func() { + if err := sqliteConn.Close(); err != nil { + t.Fatalf("closing DB: %v", err) + } + } + + return db, closer } diff --git a/internal/manager/persistence/workers_test.go b/internal/manager/persistence/workers_test.go index 9e28de62..6d2da582 100644 --- a/internal/manager/persistence/workers_test.go +++ b/internal/manager/persistence/workers_test.go @@ -33,8 +33,8 @@ import ( ) func TestCreateFetchWorker(t *testing.T) { - db := CreateTestDB(t) - + db, dbCloser := CreateTestDB(t) + defer dbCloser() ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() @@ -69,8 +69,8 @@ func TestCreateFetchWorker(t *testing.T) { } func TestSaveWorker(t *testing.T) { - db := CreateTestDB(t) - + db, dbCloser := CreateTestDB(t) + defer dbCloser() ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel()