diff --git a/internal/manager/api_impl/interfaces.go b/internal/manager/api_impl/interfaces.go
index b34cbe4d..17fd7c91 100644
--- a/internal/manager/api_impl/interfaces.go
+++ b/internal/manager/api_impl/interfaces.go
@@ -65,6 +65,14 @@ type PersistenceService interface {
 	RemoveFromJobBlocklist(ctx context.Context, jobUUID, workerUUID, taskType string) error
 	ClearJobBlocklist(ctx context.Context, job *persistence.Job) error
 
+	// Worker cluster management.
+	WorkerSetClusters(ctx context.Context, worker *persistence.Worker, clusterUUIDs []string) error
+	CreateWorkerCluster(ctx context.Context, cluster *persistence.WorkerCluster) error
+	FetchWorkerCluster(ctx context.Context, uuid string) (*persistence.WorkerCluster, error)
+	FetchWorkerClusters(ctx context.Context) ([]*persistence.WorkerCluster, error)
+	DeleteWorkerCluster(ctx context.Context, uuid string) error
+	SaveWorkerCluster(ctx context.Context, cluster *persistence.WorkerCluster) error
+
 	// WorkersLeftToRun returns a set of worker UUIDs that can run tasks of the given type on the given job.
 	WorkersLeftToRun(ctx context.Context, job *persistence.Job, taskType string) (map[string]bool, error)
 	// CountTaskFailuresOfWorker returns the number of task failures of this worker, on this particular job and task type.
diff --git a/internal/manager/api_impl/jobs.go b/internal/manager/api_impl/jobs.go
index cf78eb78..fcd84a7e 100644
--- a/internal/manager/api_impl/jobs.go
+++ b/internal/manager/api_impl/jobs.go
@@ -618,6 +618,9 @@ func jobDBtoAPI(dbJob *persistence.Job) api.Job {
 	if dbJob.DeleteRequestedAt.Valid {
 		apiJob.DeleteRequestedAt = &dbJob.DeleteRequestedAt.Time
 	}
+	if dbJob.WorkerCluster != nil {
+		apiJob.WorkerCluster = &dbJob.WorkerCluster.UUID
+	}
 
 	return apiJob
 }
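
Review note: the six new PersistenceService methods give the API layer full CRUD over worker clusters, plus declarative worker assignment, and jobDBtoAPI exposes a job's cluster to API clients as a bare UUID. A minimal sketch of a typical call sequence against the interface (hypothetical wiring; `persist` is any PersistenceService implementation and the UUID is a made-up example value):

```go
// Sketch only; this function does not exist in the codebase.
func exampleClusterLifecycle(ctx context.Context, persist PersistenceService, worker *persistence.Worker) error {
	cluster := persistence.WorkerCluster{
		UUID: "b1f8a9c0-98f1-4fbc-9454-a2b81e80e26f", // example value
		Name: "render-nodes",
	}
	if err := persist.CreateWorkerCluster(ctx, &cluster); err != nil {
		return err
	}
	// Make the worker a member of exactly this cluster.
	if err := persist.WorkerSetClusters(ctx, worker, []string{cluster.UUID}); err != nil {
		return err
	}
	// Deleting the cluster unassigns its workers and detaches its jobs.
	return persist.DeleteWorkerCluster(ctx, cluster.UUID)
}
```
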
diff --git a/internal/manager/api_impl/jobs_test.go b/internal/manager/api_impl/jobs_test.go
index 125e985e..5c64babf 100644
--- a/internal/manager/api_impl/jobs_test.go
+++ b/internal/manager/api_impl/jobs_test.go
@@ -17,6 +17,7 @@ import (
 	"git.blender.org/flamenco/pkg/moremock"
 	"github.com/golang/mock/gomock"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func ptr[T any](value T) *T {
@@ -319,6 +320,103 @@ func TestSubmitJobWithShamanCheckoutID(t *testing.T) {
 	assert.NoError(t, err)
 }
 
+func TestSubmitJobWithWorkerCluster(t *testing.T) {
+	mockCtrl := gomock.NewController(t)
+	defer mockCtrl.Finish()
+
+	mf := newMockedFlamenco(mockCtrl)
+	worker := testWorker()
+
+	workerClusterUUID := "04435762-9dc8-4f13-80b7-643a6fa5b6fd"
+	cluster := persistence.WorkerCluster{
+		Model:       persistence.Model{ID: 47},
+		UUID:        workerClusterUUID,
+		Name:        "first cluster",
+		Description: "my first cluster",
+	}
+
+	submittedJob := api.SubmittedJob{
+		Name:              "поднео посао",
+		Type:              "test",
+		Priority:          50,
+		SubmitterPlatform: worker.Platform,
+		WorkerCluster:     &workerClusterUUID,
+	}
+
+	mf.expectConvertTwoWayVariables(t,
+		config.VariableAudienceWorkers,
+		config.VariablePlatform(worker.Platform),
+		map[string]string{},
+	)
+
+	// Expect the job compiler to be called.
+	authoredJob := job_compilers.AuthoredJob{
+		JobID:             "afc47568-bd9d-4368-8016-e91d945db36d",
+		WorkerClusterUUID: workerClusterUUID,
+
+		Name:     submittedJob.Name,
+		JobType:  submittedJob.Type,
+		Priority: submittedJob.Priority,
+		Status:   api.JobStatusUnderConstruction,
+		Created:  mf.clock.Now(),
+	}
+	mf.jobCompiler.EXPECT().Compile(gomock.Any(), submittedJob).Return(&authoredJob, nil)
+
+	// Expect the job to be saved with 'queued' status:
+	queuedJob := authoredJob
+	queuedJob.Status = api.JobStatusQueued
+	mf.persistence.EXPECT().StoreAuthoredJob(gomock.Any(), queuedJob).Return(nil)
+
+	// Expect the job to be fetched from the database again:
+	dbJob := persistence.Job{
+		Model: persistence.Model{
+			ID:        47,
+			CreatedAt: mf.clock.Now(),
+			UpdatedAt: mf.clock.Now(),
+		},
+		UUID:     queuedJob.JobID,
+		Name:     queuedJob.Name,
+		JobType:  queuedJob.JobType,
+		Priority: queuedJob.Priority,
+		Status:   queuedJob.Status,
+		Settings: persistence.StringInterfaceMap{},
+		Metadata: persistence.StringStringMap{},
+
+		WorkerClusterID: &cluster.ID,
+		WorkerCluster:   &cluster,
+	}
+	mf.persistence.EXPECT().FetchJob(gomock.Any(), queuedJob.JobID).Return(&dbJob, nil)
+
+	// Expect the new job to be broadcast.
+	jobUpdate := api.SocketIOJobUpdate{
+		Id:       dbJob.UUID,
+		Name:     &dbJob.Name,
+		Priority: dbJob.Priority,
+		Status:   dbJob.Status,
+		Type:     dbJob.JobType,
+		Updated:  dbJob.UpdatedAt,
+	}
+	mf.broadcaster.EXPECT().BroadcastNewJob(jobUpdate)
+
+	// Do the call.
+	echoCtx := mf.prepareMockedJSONRequest(submittedJob)
+	requestWorkerStore(echoCtx, &worker)
+	require.NoError(t, mf.flamenco.SubmitJob(echoCtx))
+
+	submittedJob.Metadata = new(api.JobMetadata)
+	submittedJob.Settings = new(api.JobSettings)
+	submittedJob.SubmitterPlatform = "" // Not persisted in the database.
+	assertResponseJSON(t, echoCtx, http.StatusOK, api.Job{
+		SubmittedJob:      submittedJob,
+		Id:                dbJob.UUID,
+		Created:           dbJob.CreatedAt,
+		Updated:           dbJob.UpdatedAt,
+		DeleteRequestedAt: nil,
+		Activity:          "",
+		Status:            api.JobStatusQueued,
+	})
+}
+
 func TestGetJobTypeHappy(t *testing.T) {
 	mockCtrl := gomock.NewController(t)
 	defer mockCtrl.Finish()
diff --git a/internal/manager/api_impl/mocks/api_impl_mock.gen.go b/internal/manager/api_impl/mocks/api_impl_mock.gen.go
index 0d261bcf..31a8aafa 100644
--- a/internal/manager/api_impl/mocks/api_impl_mock.gen.go
+++ b/internal/manager/api_impl/mocks/api_impl_mock.gen.go
@@ -141,6 +141,20 @@ func (mr *MockPersistenceServiceMockRecorder) CreateWorker(arg0, arg1 interface{
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateWorker", reflect.TypeOf((*MockPersistenceService)(nil).CreateWorker), arg0, arg1)
 }
 
+// CreateWorkerCluster mocks base method.
+func (m *MockPersistenceService) CreateWorkerCluster(arg0 context.Context, arg1 *persistence.WorkerCluster) error {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "CreateWorkerCluster", arg0, arg1)
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// CreateWorkerCluster indicates an expected call of CreateWorkerCluster.
+func (mr *MockPersistenceServiceMockRecorder) CreateWorkerCluster(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateWorkerCluster", reflect.TypeOf((*MockPersistenceService)(nil).CreateWorkerCluster), arg0, arg1)
+}
+
 // DeleteWorker mocks base method.
 func (m *MockPersistenceService) DeleteWorker(arg0 context.Context, arg1 string) error {
 	m.ctrl.T.Helper()
@@ -155,6 +169,20 @@ func (mr *MockPersistenceServiceMockRecorder) DeleteWorker(arg0, arg1 interface{
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorker", reflect.TypeOf((*MockPersistenceService)(nil).DeleteWorker), arg0, arg1)
 }
 
+// DeleteWorkerCluster mocks base method.
+func (m *MockPersistenceService) DeleteWorkerCluster(arg0 context.Context, arg1 string) error {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "DeleteWorkerCluster", arg0, arg1)
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// DeleteWorkerCluster indicates an expected call of DeleteWorkerCluster.
+func (mr *MockPersistenceServiceMockRecorder) DeleteWorkerCluster(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteWorkerCluster", reflect.TypeOf((*MockPersistenceService)(nil).DeleteWorkerCluster), arg0, arg1)
+}
+
 // FetchJob mocks base method.
 func (m *MockPersistenceService) FetchJob(arg0 context.Context, arg1 string) (*persistence.Job, error) {
 	m.ctrl.T.Helper()
@@ -230,6 +258,36 @@ func (mr *MockPersistenceServiceMockRecorder) FetchWorker(arg0, arg1 interface{}
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchWorker", reflect.TypeOf((*MockPersistenceService)(nil).FetchWorker), arg0, arg1)
 }
 
+// FetchWorkerCluster mocks base method.
+func (m *MockPersistenceService) FetchWorkerCluster(arg0 context.Context, arg1 string) (*persistence.WorkerCluster, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "FetchWorkerCluster", arg0, arg1)
+	ret0, _ := ret[0].(*persistence.WorkerCluster)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// FetchWorkerCluster indicates an expected call of FetchWorkerCluster.
+func (mr *MockPersistenceServiceMockRecorder) FetchWorkerCluster(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchWorkerCluster", reflect.TypeOf((*MockPersistenceService)(nil).FetchWorkerCluster), arg0, arg1)
+}
+
+// FetchWorkerClusters mocks base method.
+func (m *MockPersistenceService) FetchWorkerClusters(arg0 context.Context) ([]*persistence.WorkerCluster, error) {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "FetchWorkerClusters", arg0)
+	ret0, _ := ret[0].([]*persistence.WorkerCluster)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// FetchWorkerClusters indicates an expected call of FetchWorkerClusters.
+func (mr *MockPersistenceServiceMockRecorder) FetchWorkerClusters(arg0 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FetchWorkerClusters", reflect.TypeOf((*MockPersistenceService)(nil).FetchWorkerClusters), arg0)
+}
+
 // FetchWorkerTask mocks base method.
 func (m *MockPersistenceService) FetchWorkerTask(arg0 context.Context, arg1 *persistence.Worker) (*persistence.Task, error) {
 	m.ctrl.T.Helper()
@@ -375,6 +433,20 @@ func (mr *MockPersistenceServiceMockRecorder) SaveWorker(arg0, arg1 interface{})
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveWorker", reflect.TypeOf((*MockPersistenceService)(nil).SaveWorker), arg0, arg1)
 }
 
+// SaveWorkerCluster mocks base method.
+func (m *MockPersistenceService) SaveWorkerCluster(arg0 context.Context, arg1 *persistence.WorkerCluster) error {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "SaveWorkerCluster", arg0, arg1)
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// SaveWorkerCluster indicates an expected call of SaveWorkerCluster.
+func (mr *MockPersistenceServiceMockRecorder) SaveWorkerCluster(arg0, arg1 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveWorkerCluster", reflect.TypeOf((*MockPersistenceService)(nil).SaveWorkerCluster), arg0, arg1)
+}
+
 // SaveWorkerStatus mocks base method.
 func (m *MockPersistenceService) SaveWorkerStatus(arg0 context.Context, arg1 *persistence.Worker) error {
 	m.ctrl.T.Helper()
@@ -460,6 +532,20 @@ func (mr *MockPersistenceServiceMockRecorder) WorkerSeen(arg0, arg1 interface{})
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WorkerSeen", reflect.TypeOf((*MockPersistenceService)(nil).WorkerSeen), arg0, arg1)
 }
 
+// WorkerSetClusters mocks base method.
+func (m *MockPersistenceService) WorkerSetClusters(arg0 context.Context, arg1 *persistence.Worker, arg2 []string) error {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "WorkerSetClusters", arg0, arg1, arg2)
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// WorkerSetClusters indicates an expected call of WorkerSetClusters.
+func (mr *MockPersistenceServiceMockRecorder) WorkerSetClusters(arg0, arg1, arg2 interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WorkerSetClusters", reflect.TypeOf((*MockPersistenceService)(nil).WorkerSetClusters), arg0, arg1, arg2)
+}
+
 // WorkersLeftToRun mocks base method.
 func (m *MockPersistenceService) WorkersLeftToRun(arg0 context.Context, arg1 *persistence.Job, arg2 string) (map[string]bool, error) {
 	m.ctrl.T.Helper()
diff --git a/internal/manager/api_impl/worker_mgt.go b/internal/manager/api_impl/worker_mgt.go
index 94cf3e56..6b305317 100644
--- a/internal/manager/api_impl/worker_mgt.go
+++ b/internal/manager/api_impl/worker_mgt.go
@@ -182,6 +182,195 @@ func (f *Flamenco) RequestWorkerStatusChange(e echo.Context, workerUUID string)
 	return e.NoContent(http.StatusNoContent)
 }
 
+func (f *Flamenco) SetWorkerClusters(e echo.Context, workerUUID string) error {
+	ctx := e.Request().Context()
+	logger := requestLogger(e)
+	logger = logger.With().Str("worker", workerUUID).Logger()
+
+	if !uuid.IsValid(workerUUID) {
+		return sendAPIError(e, http.StatusBadRequest, "not a valid UUID")
+	}
+
+	// Decode the request body.
+	var change api.WorkerClusterChangeRequest
+	if err := e.Bind(&change); err != nil {
+		logger.Warn().Err(err).Msg("bad request received")
+		return sendAPIError(e, http.StatusBadRequest, "invalid format")
+	}
+
+	// Fetch the worker.
+	dbWorker, err := f.persist.FetchWorker(ctx, workerUUID)
+	if errors.Is(err, persistence.ErrWorkerNotFound) {
+		logger.Debug().Msg("non-existent worker requested")
+		return sendAPIError(e, http.StatusNotFound, "worker %q not found", workerUUID)
+	}
+	if err != nil {
+		logger.Error().Err(err).Msg("fetching worker")
+		return sendAPIError(e, http.StatusInternalServerError, "error fetching worker: %v", err)
+	}
+
+	logger = logger.With().
+		Strs("clusters", change.ClusterIds).
+		Logger()
+	logger.Info().Msg("worker cluster change requested")
+
+	// Store the new cluster assignment.
+	if err := f.persist.WorkerSetClusters(ctx, dbWorker, change.ClusterIds); err != nil {
+		logger.Error().Err(err).Msg("saving worker after cluster change request")
+		return sendAPIError(e, http.StatusInternalServerError, "error saving worker: %v", err)
+	}
+
+	// Broadcast the change.
+	update := webupdates.NewWorkerUpdate(dbWorker)
+	f.broadcaster.BroadcastWorkerUpdate(update)
+
+	return e.NoContent(http.StatusNoContent)
+}
+
+func (f *Flamenco) DeleteWorkerCluster(e echo.Context, clusterUUID string) error {
+	ctx := e.Request().Context()
+	logger := requestLogger(e)
+	logger = logger.With().Str("cluster", clusterUUID).Logger()
+
+	if !uuid.IsValid(clusterUUID) {
+		return sendAPIError(e, http.StatusBadRequest, "not a valid UUID")
+	}
+
+	err := f.persist.DeleteWorkerCluster(ctx, clusterUUID)
+	switch {
+	case errors.Is(err, persistence.ErrWorkerClusterNotFound):
+		logger.Debug().Msg("non-existent worker cluster requested")
+		return sendAPIError(e, http.StatusNotFound, "worker cluster %q not found", clusterUUID)
+	case err != nil:
+		logger.Error().Err(err).Msg("deleting worker cluster")
+		return sendAPIError(e, http.StatusInternalServerError, "error deleting worker cluster: %v", err)
+	}
+
+	// TODO: SocketIO broadcast of cluster deletion.
+
+	logger.Info().Msg("worker cluster deleted")
+	return e.NoContent(http.StatusNoContent)
+}
+
+func (f *Flamenco) FetchWorkerCluster(e echo.Context, clusterUUID string) error {
+	ctx := e.Request().Context()
+	logger := requestLogger(e)
+	logger = logger.With().Str("cluster", clusterUUID).Logger()
+
+	if !uuid.IsValid(clusterUUID) {
+		return sendAPIError(e, http.StatusBadRequest, "not a valid UUID")
+	}
+
+	cluster, err := f.persist.FetchWorkerCluster(ctx, clusterUUID)
+	switch {
+	case errors.Is(err, persistence.ErrWorkerClusterNotFound):
+		logger.Debug().Msg("non-existent worker cluster requested")
+		return sendAPIError(e, http.StatusNotFound, "worker cluster %q not found", clusterUUID)
+	case err != nil:
+		logger.Error().Err(err).Msg("fetching worker cluster")
+		return sendAPIError(e, http.StatusInternalServerError, "error fetching worker cluster: %v", err)
+	}
+
+	return e.JSON(http.StatusOK, workerClusterDBtoAPI(*cluster))
+}
+
+func (f *Flamenco) UpdateWorkerCluster(e echo.Context, clusterUUID string) error {
+	ctx := e.Request().Context()
+	logger := requestLogger(e)
+	logger = logger.With().Str("cluster", clusterUUID).Logger()
+
+	if !uuid.IsValid(clusterUUID) {
+		return sendAPIError(e, http.StatusBadRequest, "not a valid UUID")
+	}
+
+	// Decode the request body.
+	var update api.UpdateWorkerClusterJSONBody
+	if err := e.Bind(&update); err != nil {
+		logger.Warn().Err(err).Msg("bad request received")
+		return sendAPIError(e, http.StatusBadRequest, "invalid format")
+	}
+
+	dbCluster, err := f.persist.FetchWorkerCluster(ctx, clusterUUID)
+	switch {
+	case errors.Is(err, persistence.ErrWorkerClusterNotFound):
+		logger.Debug().Msg("non-existent worker cluster requested")
+		return sendAPIError(e, http.StatusNotFound, "worker cluster %q not found", clusterUUID)
+	case err != nil:
+		logger.Error().Err(err).Msg("fetching worker cluster")
+		return sendAPIError(e, http.StatusInternalServerError, "error fetching worker cluster: %v", err)
+	}
+
+	// Update the cluster.
+	dbCluster.Name = update.Name
+	if update.Description == nil {
+		dbCluster.Description = ""
+	} else {
+		dbCluster.Description = *update.Description
+	}
+
+	if err := f.persist.SaveWorkerCluster(ctx, dbCluster); err != nil {
+		logger.Error().Err(err).Msg("saving worker cluster")
+		return sendAPIError(e, http.StatusInternalServerError, "error saving worker cluster")
+	}
+
+	// TODO: SocketIO broadcast of cluster update.
+
+	return e.NoContent(http.StatusNoContent)
+}
+
+func (f *Flamenco) FetchWorkerClusters(e echo.Context) error {
+	ctx := e.Request().Context()
+	logger := requestLogger(e)
+
+	dbClusters, err := f.persist.FetchWorkerClusters(ctx)
+	if err != nil {
+		logger.Error().Err(err).Msg("fetching worker clusters")
+		return sendAPIError(e, http.StatusInternalServerError, "error fetching worker clusters")
+	}
+
+	apiClusters := []api.WorkerCluster{}
+	for _, dbCluster := range dbClusters {
+		apiCluster := workerClusterDBtoAPI(*dbCluster)
+		apiClusters = append(apiClusters, apiCluster)
+	}
+
+	clusterList := api.WorkerClusterList{
+		Clusters: &apiClusters,
+	}
+	return e.JSON(http.StatusOK, &clusterList)
+}
+
+func (f *Flamenco) CreateWorkerCluster(e echo.Context) error {
+	ctx := e.Request().Context()
+	logger := requestLogger(e)
+
+	// Decode the request body.
+	var apiCluster api.CreateWorkerClusterJSONBody
+	if err := e.Bind(&apiCluster); err != nil {
+		logger.Warn().Err(err).Msg("bad request received")
+		return sendAPIError(e, http.StatusBadRequest, "invalid format")
+	}
+
+	// Convert to persistence layer model.
+	dbCluster := persistence.WorkerCluster{
+		UUID: apiCluster.Id,
+		Name: apiCluster.Name,
+	}
+	if apiCluster.Description != nil {
+		dbCluster.Description = *apiCluster.Description
+	}
+
+	// Store in the database.
+	if err := f.persist.CreateWorkerCluster(ctx, &dbCluster); err != nil {
+		logger.Error().Err(err).Msg("creating worker cluster")
+		return sendAPIError(e, http.StatusInternalServerError, "error creating worker cluster")
+	}
+
+	// TODO: SocketIO broadcast of cluster creation.
+
+	return e.NoContent(http.StatusNoContent)
+}
+
 func workerSummary(w persistence.Worker) api.WorkerSummary {
 	summary := api.WorkerSummary{
 		Id: w.UUID,
@@ -211,5 +400,24 @@ func workerDBtoAPI(w persistence.Worker) api.Worker {
 		SupportedTaskTypes: w.TaskTypes(),
 	}
 
+	if len(w.Clusters) > 0 {
+		clusters := []api.WorkerCluster{}
+		for i := range w.Clusters {
+			clusters = append(clusters, workerClusterDBtoAPI(*w.Clusters[i]))
+		}
+		apiWorker.Clusters = &clusters
+	}
+
 	return apiWorker
 }
+
+func workerClusterDBtoAPI(wc persistence.WorkerCluster) api.WorkerCluster {
+	apiCluster := api.WorkerCluster{
+		Id:   wc.UUID,
+		Name: wc.Name,
+	}
+	if len(wc.Description) > 0 {
+		apiCluster.Description = &wc.Description
+	}
+	return apiCluster
+}
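
Review note: the handlers above all follow the same shape (validate the UUID, bind the body, hit the persistence layer, translate ErrWorkerClusterNotFound into a 404). To exercise them over HTTP, something like the request below should work; the URL path and port are assumptions based on the handler names, since the authoritative routes live in the OpenAPI spec, which is not part of this diff:

```go
// Sketch only: path and port are assumptions; adjust to the generated routes.
package main

import (
	"bytes"
	"fmt"
	"net/http"
)

func main() {
	// Create a cluster. Note that the client supplies the UUID itself;
	// the handler answers 204 No Content rather than echoing the resource.
	body := []byte(`{
		"id": "18d9234e-5135-458f-a1ba-a350c3d4e837",
		"name": "gpu-nodes",
		"description": "Workers with a GPU"
	}`)
	resp, err := http.Post(
		"http://localhost:8080/api/v3/worker-mgt/clusters", // assumed route
		"application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println("create cluster:", resp.Status)
}
```
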
diff --git a/internal/manager/api_impl/worker_mgt_test.go b/internal/manager/api_impl/worker_mgt_test.go
index 23586b9b..60c1521c 100644
--- a/internal/manager/api_impl/worker_mgt_test.go
+++ b/internal/manager/api_impl/worker_mgt_test.go
@@ -10,6 +10,7 @@ import (
 
 	"github.com/golang/mock/gomock"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 
 	"git.blender.org/flamenco/internal/manager/persistence"
 	"git.blender.org/flamenco/pkg/api"
@@ -260,3 +261,59 @@ func TestRequestWorkerStatusChangeRevert(t *testing.T) {
 	assert.NoError(t, err)
 	assertResponseNoContent(t, echo)
 }
+
+func TestWorkerClusterCRUDHappyFlow(t *testing.T) {
+	mockCtrl := gomock.NewController(t)
+	defer mockCtrl.Finish()
+
+	mf := newMockedFlamenco(mockCtrl)
+
+	// Create a cluster.
+	UUID := "18d9234e-5135-458f-a1ba-a350c3d4e837"
+	apiCluster := api.WorkerCluster{
+		Id:          UUID,
+		Name:        "ʻO nā manu ʻino",
+		Description: ptr("Ke aloha"),
+	}
+	expectDBCluster := persistence.WorkerCluster{
+		UUID:        UUID,
+		Name:        apiCluster.Name,
+		Description: *apiCluster.Description,
+	}
+	mf.persistence.EXPECT().CreateWorkerCluster(gomock.Any(), &expectDBCluster)
+	// TODO: expect SocketIO broadcast of the cluster creation.
+	echo := mf.prepareMockedJSONRequest(apiCluster)
+	require.NoError(t, mf.flamenco.CreateWorkerCluster(echo))
+	assertResponseNoContent(t, echo)
+
+	// Fetch the cluster.
+	mf.persistence.EXPECT().FetchWorkerCluster(gomock.Any(), UUID).Return(&expectDBCluster, nil)
+	echo = mf.prepareMockedRequest(nil)
+	require.NoError(t, mf.flamenco.FetchWorkerCluster(echo, UUID))
+	assertResponseJSON(t, echo, http.StatusOK, &apiCluster)
+
+	// Update & save.
+	newUUID := "60442762-83d3-4fc3-bf75-6ab5799cdbaa"
+	newAPICluster := api.WorkerCluster{
+		Id:   newUUID, // Intentionally change the UUID. This should just be ignored.
+		Name: "updated name",
+	}
+	expectNewDBCluster := persistence.WorkerCluster{
+		UUID:        UUID,
+		Name:        newAPICluster.Name,
+		Description: "",
+	}
+	// TODO: expect SocketIO broadcast of the cluster update.
+	mf.persistence.EXPECT().FetchWorkerCluster(gomock.Any(), UUID).Return(&expectDBCluster, nil)
+	mf.persistence.EXPECT().SaveWorkerCluster(gomock.Any(), &expectNewDBCluster)
+	echo = mf.prepareMockedJSONRequest(newAPICluster)
+	require.NoError(t, mf.flamenco.UpdateWorkerCluster(echo, UUID))
+	assertResponseNoContent(t, echo)
+
+	// Delete.
+	mf.persistence.EXPECT().DeleteWorkerCluster(gomock.Any(), UUID)
+	// TODO: expect SocketIO broadcast of the cluster deletion.
+	echo = mf.prepareMockedJSONRequest(newAPICluster)
+	require.NoError(t, mf.flamenco.DeleteWorkerCluster(echo, UUID))
+	assertResponseNoContent(t, echo)
+}
diff --git a/internal/manager/job_compilers/author.go b/internal/manager/job_compilers/author.go
index 95ec1ff3..7a457feb 100644
--- a/internal/manager/job_compilers/author.go
+++ b/internal/manager/job_compilers/author.go
@@ -20,7 +20,9 @@ type Author struct {
 }
 
 type AuthoredJob struct {
-	JobID string
+	JobID             string
+	WorkerClusterUUID string
+
 	Name     string
 	JobType  string
 	Priority int
diff --git a/internal/manager/job_compilers/job_compilers.go b/internal/manager/job_compilers/job_compilers.go
index 9d9f3452..3b4cae4c 100644
--- a/internal/manager/job_compilers/job_compilers.go
+++ b/internal/manager/job_compilers/job_compilers.go
@@ -127,6 +127,10 @@ func (s *Service) Compile(ctx context.Context, sj api.SubmittedJob) (*AuthoredJo
 		aj.Storage.ShamanCheckoutID = *sj.Storage.ShamanCheckoutId
 	}
 
+	if sj.WorkerCluster != nil {
+		aj.WorkerClusterUUID = *sj.WorkerCluster
+	}
+
 	compiler, err := vm.getCompileJob()
 	if err != nil {
 		return nil, err
@@ -139,12 +143,13 @@ func (s *Service) Compile(ctx context.Context, sj api.SubmittedJob) (*AuthoredJo
 		Int("num_tasks", len(aj.Tasks)).
 		Str("name", aj.Name).
 		Str("jobtype", aj.JobType).
+		Str("job", aj.JobID).
 		Msg("job compiled")
 
 	return &aj, nil
 }
 
-// ListJobTypes returns the list of available job types.
+// ListJobTypes returns the list of available job types.
 func (s *Service) ListJobTypes() api.AvailableJobTypes {
 	jobTypes := make([]api.AvailableJobType, 0)
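
Review note: between Compile() above and StoreAuthoredJob() in the persistence layer, a nil WorkerCluster and a pointer to an empty string behave identically: the authored job ends up with WorkerClusterUUID == "" and the cluster lookup is skipped. The combined effect, written out as a sketch (this helper does not exist in the codebase):

```go
// Hypothetical helper, for illustration only.
func effectiveClusterUUID(sj api.SubmittedJob) string {
	if sj.WorkerCluster == nil {
		return "" // no cluster requested
	}
	// An empty string also means "no cluster" further down the pipeline.
	return *sj.WorkerCluster
}
```
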
diff --git a/internal/manager/job_compilers/job_compilers_test.go b/internal/manager/job_compilers/job_compilers_test.go
index debc62df..ad746b51 100644
--- a/internal/manager/job_compilers/job_compilers_test.go
+++ b/internal/manager/job_compilers/job_compilers_test.go
@@ -45,11 +45,12 @@ func exampleSubmittedJob() api.SubmittedJob {
 		"user.name": "Sybren Stüvel",
 	}}
 	sj := api.SubmittedJob{
-		Name:     "3Д рендеринг",
-		Priority: 50,
-		Type:     "simple-blender-render",
-		Settings: &settings,
-		Metadata: &metadata,
+		Name:          "3Д рендеринг",
+		Priority:      50,
+		Type:          "simple-blender-render",
+		Settings:      &settings,
+		Metadata:      &metadata,
+		WorkerCluster: ptr("acce9983-e663-4210-b3cc-f7bfa629cb21"),
 	}
 	return sj
 }
@@ -79,6 +80,7 @@ func TestSimpleBlenderRenderHappy(t *testing.T) {
 
 	// Properties should be copied as-is.
 	assert.Equal(t, sj.Name, aj.Name)
+	assert.Equal(t, *sj.WorkerCluster, aj.WorkerClusterUUID)
 	assert.Equal(t, sj.Type, aj.JobType)
 	assert.Equal(t, sj.Priority, aj.Priority)
 	assert.EqualValues(t, sj.Settings.AdditionalProperties, aj.Settings)
@@ -137,6 +139,35 @@ func TestSimpleBlenderRenderHappy(t *testing.T) {
 	assert.Equal(t, expectDeps, tVideo.Dependencies)
 }
 
+func TestJobWithoutCluster(t *testing.T) {
+	c := mockedClock(t)
+
+	s, err := Load(c)
+	require.NoError(t, err)
+
+	// Compiling a job should be really fast.
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
+	defer cancel()
+
+	sj := exampleSubmittedJob()
+
+	// Try with nil WorkerCluster.
+	{
+		sj.WorkerCluster = nil
+		aj, err := s.Compile(ctx, sj)
+		require.NoError(t, err)
+		assert.Zero(t, aj.WorkerClusterUUID)
+	}
+
+	// Try with empty WorkerCluster.
+	{
+		sj.WorkerCluster = ptr("")
+		aj, err := s.Compile(ctx, sj)
+		require.NoError(t, err)
+		assert.Zero(t, aj.WorkerClusterUUID)
+	}
+}
+
 func TestSimpleBlenderRenderWindowsPaths(t *testing.T) {
 	c := mockedClock(t)
 
diff --git a/internal/manager/persistence/db_migration.go b/internal/manager/persistence/db_migration.go
index 998c265a..b60d9230 100644
--- a/internal/manager/persistence/db_migration.go
+++ b/internal/manager/persistence/db_migration.go
@@ -16,6 +16,7 @@ func (db *DB) migrate() error {
 		&Task{},
 		&TaskFailure{},
 		&Worker{},
+		&WorkerCluster{},
 	)
 	if err != nil {
 		return fmt.Errorf("failed to automigrate database: %v", err)
diff --git a/internal/manager/persistence/errors.go b/internal/manager/persistence/errors.go
index b8e6379c..1316f481 100644
--- a/internal/manager/persistence/errors.go
+++ b/internal/manager/persistence/errors.go
@@ -9,9 +9,10 @@ import (
 )
 
 var (
-	ErrJobNotFound    = PersistenceError{Message: "job not found", Err: gorm.ErrRecordNotFound}
-	ErrTaskNotFound   = PersistenceError{Message: "task not found", Err: gorm.ErrRecordNotFound}
-	ErrWorkerNotFound = PersistenceError{Message: "worker not found", Err: gorm.ErrRecordNotFound}
+	ErrJobNotFound           = PersistenceError{Message: "job not found", Err: gorm.ErrRecordNotFound}
+	ErrTaskNotFound          = PersistenceError{Message: "task not found", Err: gorm.ErrRecordNotFound}
+	ErrWorkerNotFound        = PersistenceError{Message: "worker not found", Err: gorm.ErrRecordNotFound}
+	ErrWorkerClusterNotFound = PersistenceError{Message: "worker cluster not found", Err: gorm.ErrRecordNotFound}
 )
 
 type PersistenceError struct {
@@ -39,6 +40,10 @@ func workerError(errorToWrap error, message string, msgArgs ...interface{}) erro
 	return wrapError(translateGormWorkerError(errorToWrap), message, msgArgs...)
 }
 
+func workerClusterError(errorToWrap error, message string, msgArgs ...interface{}) error {
+	return wrapError(translateGormWorkerClusterError(errorToWrap), message, msgArgs...)
+}
+
 func wrapError(errorToWrap error, message string, format ...interface{}) error {
 	// Only format if there are arguments for formatting.
 	var formattedMsg string
@@ -80,3 +85,12 @@ func translateGormWorkerError(gormError error) error {
 	}
 	return gormError
 }
+
+// translateGormWorkerClusterError translates a Gorm error to a persistence layer error.
+// This helps to keep Gorm as "implementation detail" of the persistence layer.
+func translateGormWorkerClusterError(gormError error) error {
+	if errors.Is(gormError, gorm.ErrRecordNotFound) {
+		return ErrWorkerClusterNotFound
+	}
+	return gormError
+}
diff --git a/internal/manager/persistence/jobs.go b/internal/manager/persistence/jobs.go
index 690c53a6..fa893f8d 100644
--- a/internal/manager/persistence/jobs.go
+++ b/internal/manager/persistence/jobs.go
@@ -35,6 +35,9 @@ type Job struct {
 	DeleteRequestedAt sql.NullTime
 
 	Storage JobStorageInfo `gorm:"embedded;embeddedPrefix:storage_"`
+
+	WorkerClusterID *uint
+	WorkerCluster   *WorkerCluster `gorm:"foreignkey:WorkerClusterID;references:ID;constraint:OnDelete:SET NULL"`
 }
 
 type StringInterfaceMap map[string]interface{}
@@ -145,6 +148,16 @@ func (db *DB) StoreAuthoredJob(ctx context.Context, authoredJob job_compilers.Au
 		},
 	}
 
+	// Find and assign the worker cluster.
+	if authoredJob.WorkerClusterUUID != "" {
+		dbCluster, err := fetchWorkerCluster(tx, authoredJob.WorkerClusterUUID)
+		if err != nil {
+			return err
+		}
+		dbJob.WorkerClusterID = &dbCluster.ID
+		dbJob.WorkerCluster = dbCluster
+	}
+
 	if err := tx.Create(&dbJob).Error; err != nil {
 		return jobError(err, "storing job")
 	}
@@ -212,6 +225,7 @@ func (db *DB) FetchJob(ctx context.Context, jobUUID string) (*Job, error) {
 	dbJob := Job{}
 	findResult := db.gormDB.WithContext(ctx).
 		Limit(1).
+		Preload("WorkerCluster").
 		Find(&dbJob, "uuid = ?", jobUUID)
 	if findResult.Error != nil {
 		return nil, jobError(findResult.Error, "fetching job")
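
Review note: the Job to WorkerCluster relation is a nullable foreign key with ON DELETE SET NULL, so deleting a cluster quietly turns its jobs into cluster-less jobs instead of deleting them. Roughly the schema that AutoMigrate would produce for the new column (a sketch, not the literal migration output):

```go
// Approximate DDL; the real AutoMigrate output may differ in details.
const jobClusterColumnSketch = `
ALTER TABLE jobs ADD COLUMN worker_cluster_id integer
	REFERENCES worker_clusters(id) ON DELETE SET NULL;
`
```
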
diff --git a/internal/manager/persistence/jobs_blocklist.go b/internal/manager/persistence/jobs_blocklist.go
index 91fd1454..12dc5103 100644
--- a/internal/manager/persistence/jobs_blocklist.go
+++ b/internal/manager/persistence/jobs_blocklist.go
@@ -103,13 +103,27 @@ func (db *DB) WorkersLeftToRun(ctx context.Context, job *Job, taskType string) (
 		Where("JB.job_id = ?", job.ID).
 		Where("JB.task_type = ?", taskType)
 
-	// Find the workers NOT blocked.
-	workers := []*Worker{}
-	tx := db.gormDB.WithContext(ctx).
+	query := db.gormDB.WithContext(ctx).
 		Model(&Worker{}).
 		Select("uuid").
-		Where("id not in (?)", blockedWorkers).
-		Scan(&workers)
+		Where("id not in (?)", blockedWorkers)
+
+	if job.WorkerClusterID != nil {
+		// Limit to workers not in any cluster + workers in the job's cluster.
+		clusterless := db.gormDB.
+			Table("worker_cluster_membership").
+			Select("worker_id")
+		jobCluster := db.gormDB.
+			Table("worker_cluster_membership").
+			Select("worker_id").
+			Where("worker_cluster_id = ?", *job.WorkerClusterID)
+		query = query.
+			Where("id not in (?) or id in (?)", clusterless, jobCluster)
+	}
+
+	// Find the workers NOT blocked.
+	workers := []*Worker{}
+	tx := query.Scan(&workers)
 	if tx.Error != nil {
 		return nil, tx.Error
 	}
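
Review note: despite its name, the `clusterless` subquery selects every worker that has at least one cluster membership; the surrounding `id not in (?)` inverts that into "workers without any cluster". For a job with a cluster, the final query has roughly this shape (a readability sketch, not generated output):

```go
// Sketch of the SQL produced for a clustered job.
const workersLeftToRunSketch = `
SELECT uuid FROM workers
WHERE id NOT IN (/* workers blocked for this job + task type */)
  AND (id NOT IN (SELECT worker_id FROM worker_cluster_membership)  -- cluster-less
   OR  id IN (SELECT worker_id FROM worker_cluster_membership
              WHERE worker_cluster_id = /* the job's cluster ID */)); -- job's cluster
`
```
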
diff --git a/internal/manager/persistence/jobs_blocklist_test.go b/internal/manager/persistence/jobs_blocklist_test.go
index a87cb5d1..97e28e39 100644
--- a/internal/manager/persistence/jobs_blocklist_test.go
+++ b/internal/manager/persistence/jobs_blocklist_test.go
@@ -4,6 +4,7 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 // SPDX-License-Identifier: GPL-3.0-or-later
@@ -160,6 +161,71 @@ func TestWorkersLeftToRun(t *testing.T) {
 	}
 }
 
+func TestWorkersLeftToRunWithClusters(t *testing.T) {
+	ctx, cancel, db := persistenceTestFixtures(t, schedulerTestTimeout)
+	defer cancel()
+
+	// Create clusters.
+	cluster1 := WorkerCluster{UUID: "11157623-4b14-4801-bee2-271dddab6309", Name: "Cluster 1"}
+	cluster2 := WorkerCluster{UUID: "22257623-4b14-4801-bee2-271dddab6309", Name: "Cluster 2"}
+	cluster3 := WorkerCluster{UUID: "33357623-4b14-4801-bee2-271dddab6309", Name: "Cluster 3"}
+	require.NoError(t, db.CreateWorkerCluster(ctx, &cluster1))
+	require.NoError(t, db.CreateWorkerCluster(ctx, &cluster2))
+	require.NoError(t, db.CreateWorkerCluster(ctx, &cluster3))
+
+	// Create a job in cluster1.
+	authoredJob := createTestAuthoredJobWithTasks()
+	authoredJob.WorkerClusterUUID = cluster1.UUID
+	job := persistAuthoredJob(t, ctx, db, authoredJob)
+
+	// Clusters 1 + 3.
+	workerC13 := createWorker(ctx, t, db, func(w *Worker) {
+		w.UUID = "c13c1313-0000-1111-2222-333333333333"
+		w.Clusters = []*WorkerCluster{&cluster1, &cluster3}
+	})
+	// Cluster 1.
+	workerC1 := createWorker(ctx, t, db, func(w *Worker) {
+		w.UUID = "c1c1c1c1-0000-1111-2222-333333333333"
+		w.Clusters = []*WorkerCluster{&cluster1}
+	})
+	// Cluster 2 worker; this one should never appear.
+	createWorker(ctx, t, db, func(w *Worker) {
+		w.UUID = "c2c2c2c2-0000-1111-2222-333333333333"
+		w.Clusters = []*WorkerCluster{&cluster2}
+	})
+	// No clusters, so should be able to run all.
+	workerCNone := createWorker(ctx, t, db, func(w *Worker) {
+		w.UUID = "00000000-0000-1111-2222-333333333333"
+		w.Clusters = nil
+	})
+
+	uuidMap := func(workers ...*Worker) map[string]bool {
+		theMap := map[string]bool{}
+		for _, worker := range workers {
+			theMap[worker.UUID] = true
+		}
+		return theMap
+	}
+
+	// All Cluster 1 workers + the clusterless worker, no blocklist.
+	left, err := db.WorkersLeftToRun(ctx, job, "blender")
+	require.NoError(t, err)
+	assert.Equal(t, uuidMap(workerC13, workerC1, workerCNone), left)
+
+	// One worker blocked, two workers remain.
+	_ = db.AddWorkerToJobBlocklist(ctx, job, workerC1, "blender")
+	left, err = db.WorkersLeftToRun(ctx, job, "blender")
+	require.NoError(t, err)
+	assert.Equal(t, uuidMap(workerC13, workerCNone), left)
+
+	// All workers blocked.
+	_ = db.AddWorkerToJobBlocklist(ctx, job, workerC13, "blender")
+	_ = db.AddWorkerToJobBlocklist(ctx, job, workerCNone, "blender")
+	left, err = db.WorkersLeftToRun(ctx, job, "blender")
+	assert.NoError(t, err)
+	assert.Empty(t, left)
+}
+
 func TestCountTaskFailuresOfWorker(t *testing.T) {
 	ctx, close, db, dbJob, authoredJob := jobTasksTestFixtures(t)
 	defer close()
diff --git a/internal/manager/persistence/jobs_query.go b/internal/manager/persistence/jobs_query.go
index 8607b689..63773b4a 100644
--- a/internal/manager/persistence/jobs_query.go
+++ b/internal/manager/persistence/jobs_query.go
@@ -64,6 +64,8 @@ func (db *DB) QueryJobs(ctx context.Context, apiQ api.JobsQuery) ([]*Job, error)
 		}
 	}
 
+	q = q.Preload("WorkerCluster")
+
 	result := []*Job{}
 	tx := q.Scan(&result)
 	return result, tx.Error
diff --git a/internal/manager/persistence/jobs_test.go b/internal/manager/persistence/jobs_test.go
index b0f340cf..e482fdec 100644
--- a/internal/manager/persistence/jobs_test.go
+++ b/internal/manager/persistence/jobs_test.go
@@ -676,7 +676,7 @@ func jobTasksTestFixtures(t *testing.T) (context.Context, context.CancelFunc, *D
 	return ctx, cancel, db, dbJob, authoredJob
 }
 
-func createWorker(ctx context.Context, t *testing.T, db *DB) *Worker {
+func createWorker(ctx context.Context, t *testing.T, db *DB, updaters ...func(*Worker)) *Worker {
 	w := Worker{
 		UUID: "f0a123a9-ab05-4ce2-8577-94802cfe74a4",
 		Name: "дрон",
@@ -685,6 +685,11 @@ func createWorker(ctx context.Context, t *testing.T, db *DB) *Worker {
 		Software:           "3.0",
 		Status:             api.WorkerStatusAwake,
 		SupportedTaskTypes: "blender,ffmpeg,file-management",
+		Clusters:           nil,
+	}
+
+	for _, updater := range updaters {
+		updater(&w)
 	}
 
 	err := db.CreateWorker(ctx, &w)
diff --git a/internal/manager/persistence/task_scheduler.go b/internal/manager/persistence/task_scheduler.go
index c1d603cc..66b632b1 100644
--- a/internal/manager/persistence/task_scheduler.go
+++ b/internal/manager/persistence/task_scheduler.go
@@ -114,18 +114,30 @@ func findTaskForWorker(tx *gorm.DB, w *Worker) (*Task, error) {
 	// a 'schedulable' status might have been assigned to a worker, representing
 	// the last worker to touch it -- it's not meant to indicate "ownership" of
 	// the task.
-	findTaskResult := tx.
-		Model(&task).
+	findTaskQuery := tx.Model(&task).
 		Joins("left join jobs on tasks.job_id = jobs.id").
 		Joins("left join task_failures TF on tasks.id = TF.task_id and TF.worker_id=?", w.ID).
-		Where("tasks.status in ?", schedulableTaskStatuses).   // Schedulable task statuses
-		Where("jobs.status in ?", schedulableJobStatuses).     // Schedulable job statuses
-		Where("tasks.type in ?", w.TaskTypes()).               // Supported task types
-		Where("tasks.id not in (?)", incompleteDepsQuery).     // Dependencies completed
-		Where("TF.worker_id is NULL").                         // Not failed before
-		Where("tasks.type not in (?)", blockedTaskTypesQuery). // Non-blocklisted
-		Order("jobs.priority desc").                           // Highest job priority
-		Order("tasks.priority desc").                          // Highest task priority
+		Where("tasks.status in ?", schedulableTaskStatuses).  // Schedulable task statuses
+		Where("jobs.status in ?", schedulableJobStatuses).    // Schedulable job statuses
+		Where("tasks.type in ?", w.TaskTypes()).              // Supported task types
+		Where("tasks.id not in (?)", incompleteDepsQuery).    // Dependencies completed
+		Where("TF.worker_id is NULL").                        // Not failed before
+		Where("tasks.type not in (?)", blockedTaskTypesQuery) // Non-blocklisted
+
+	if len(w.Clusters) > 0 {
+		// Worker is assigned to one or more clusters, so limit the available jobs
+		// to those that have no cluster, or overlap with the Worker's clusters.
+		clusterIDs := []uint{}
+		for _, cluster := range w.Clusters {
+			clusterIDs = append(clusterIDs, cluster.ID)
+		}
+		findTaskQuery = findTaskQuery.
+			Where("jobs.worker_cluster_id is NULL or worker_cluster_id in ?", clusterIDs)
+	}
+
+	findTaskResult := findTaskQuery.
+		Order("jobs.priority desc").  // Highest job priority
+		Order("tasks.priority desc"). // Highest task priority
 		Limit(1).
 		Preload("Job").
 		Find(&task)
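
Review note: this is the worker-side mirror of the job-side filter in WorkersLeftToRun. A worker without clusters skips the extra clause entirely and can pick up any job; a clustered worker additionally matches cluster-less jobs. The second column reference in the clause is unqualified, which still resolves to the jobs table, since neither tasks nor task_failures has a worker_cluster_id column. The added condition is roughly:

```go
// Shape of the added WHERE clause; `?` expands to the worker's cluster IDs.
const schedulerClusterClauseSketch = `
AND (jobs.worker_cluster_id IS NULL OR worker_cluster_id IN (?))
`
```
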
diff --git a/internal/manager/persistence/task_scheduler_test.go b/internal/manager/persistence/task_scheduler_test.go
index c9fc1de9..bc858cbe 100644
--- a/internal/manager/persistence/task_scheduler_test.go
+++ b/internal/manager/persistence/task_scheduler_test.go
@@ -8,6 +8,7 @@ import (
 	"time"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 
 	"git.blender.org/flamenco/internal/manager/job_compilers"
 	"git.blender.org/flamenco/internal/uuid"
@@ -289,6 +290,69 @@ func TestPreviouslyFailed(t *testing.T) {
 	assert.Equal(t, att2.Name, task.Name, "the second task should have been chosen")
 }
 
+func TestWorkerClusterJobWithCluster(t *testing.T) {
+	ctx, cancel, db := persistenceTestFixtures(t, schedulerTestTimeout)
+	defer cancel()
+
+	// Create worker clusters:
+	cluster1 := WorkerCluster{UUID: "f0157623-4b14-4801-bee2-271dddab6309", Name: "Cluster 1"}
+	cluster2 := WorkerCluster{UUID: "2f71dba1-cf92-4752-8386-f5926affabd5", Name: "Cluster 2"}
+	require.NoError(t, db.CreateWorkerCluster(ctx, &cluster1))
+	require.NoError(t, db.CreateWorkerCluster(ctx, &cluster2))
+
+	// Create a worker in cluster1:
+	w := linuxWorker(t, db, func(w *Worker) {
+		w.Clusters = []*WorkerCluster{&cluster1}
+	})
+
+	{ // Test job with a different cluster:
+		authTask := authorTestTask("the task", "blender")
+		job := authorTestJob("499cf0f8-e83d-4cb1-837a-df94789d07db", "simple-blender-render", authTask)
+		job.WorkerClusterUUID = cluster2.UUID
+		constructTestJob(ctx, t, db, job)
+
+		task, err := db.ScheduleTask(ctx, &w)
+		require.NoError(t, err)
+		assert.Nil(t, task, "job with different cluster should not be scheduled")
+	}
+
+	{ // Test job with a matching cluster:
+		authTask := authorTestTask("the task", "blender")
+		job := authorTestJob("5d4c2321-0bb7-4c13-a9dd-32a2c0cd156e", "simple-blender-render", authTask)
+		job.WorkerClusterUUID = cluster1.UUID
+		constructTestJob(ctx, t, db, job)
+
+		task, err := db.ScheduleTask(ctx, &w)
+		require.NoError(t, err)
+		require.NotNil(t, task, "job with matching cluster should be scheduled")
+		assert.Equal(t, authTask.UUID, task.UUID)
+	}
+}
+
+func TestWorkerClusterJobWithoutCluster(t *testing.T) {
+	ctx, cancel, db := persistenceTestFixtures(t, schedulerTestTimeout)
+	defer cancel()
+
+	// Create a worker cluster:
+	cluster1 := WorkerCluster{UUID: "f0157623-4b14-4801-bee2-271dddab6309", Name: "Cluster 1"}
+	require.NoError(t, db.CreateWorkerCluster(ctx, &cluster1))
+
+	// Create a worker in cluster1:
+	w := linuxWorker(t, db, func(w *Worker) {
+		w.Clusters = []*WorkerCluster{&cluster1}
+	})
+
+	// Test a cluster-less job:
+	authTask := authorTestTask("the task", "blender")
+	job := authorTestJob("b6a1d859-122f-4791-8b78-b943329a9989", "simple-blender-render", authTask)
+	constructTestJob(ctx, t, db, job)
+
+	task, err := db.ScheduleTask(ctx, &w)
+	require.NoError(t, err)
+	require.NotNil(t, task, "job without cluster should always be scheduled")
+	assert.Equal(t, authTask.UUID, task.UUID)
+}
+
 func TestBlocklisted(t *testing.T) {
 	ctx, cancel, db := persistenceTestFixtures(t, schedulerTestTimeout)
 	defer cancel()
@@ -383,7 +447,7 @@ func setTaskStatus(t *testing.T, db *DB, taskUUID string, status api.TaskStatus)
 	}
 }
 
-func linuxWorker(t *testing.T, db *DB) Worker {
+func linuxWorker(t *testing.T, db *DB, updaters ...func(worker *Worker)) Worker {
 	w := Worker{
 		UUID:               "b13b8322-3e96-41c3-940a-3d581008a5f8",
 		Name:               "Linux",
@@ -392,6 +456,10 @@ func linuxWorker(t *testing.T, db *DB) Worker {
 		SupportedTaskTypes: "blender,ffmpeg,file-management,misc",
 	}
 
+	for _, updater := range updaters {
+		updater(&w)
+	}
+
 	err := db.gormDB.Save(&w).Error
 	if err != nil {
 		t.Logf("cannot save Linux worker: %v", err)
diff --git a/internal/manager/persistence/test_support.go b/internal/manager/persistence/test_support.go
index 6dd5d2c8..80832819 100644
--- a/internal/manager/persistence/test_support.go
+++ b/internal/manager/persistence/test_support.go
@@ -10,9 +10,12 @@ import (
 	"testing"
 	"time"
 
+	"git.blender.org/flamenco/internal/uuid"
+	"git.blender.org/flamenco/pkg/api"
 	"github.com/glebarez/sqlite"
 	"github.com/rs/zerolog"
 	"github.com/rs/zerolog/log"
+	"github.com/stretchr/testify/require"
 	"gorm.io/gorm"
 )
 
@@ -87,3 +90,44 @@ func persistenceTestFixtures(t *testing.T, testContextTimeout time.Duration) (co
 
 	return ctx, cancel, db
 }
+
+type WorkerTestFixture struct {
+	db   *DB
+	ctx  context.Context
+	done func()
+
+	worker  *Worker
+	cluster *WorkerCluster
+}
+
+func workerTestFixtures(t *testing.T, testContextTimeout time.Duration) WorkerTestFixture {
+	ctx, cancel, db := persistenceTestFixtures(t, testContextTimeout)
+
+	w := Worker{
+		UUID:               uuid.New(),
+		Name:               "дрон",
+		Address:            "fe80::5054:ff:fede:2ad7",
+		Platform:           "linux",
+		Software:           "3.0",
+		Status:             api.WorkerStatusAwake,
+		SupportedTaskTypes: "blender,ffmpeg,file-management",
+	}
+
+	wc := WorkerCluster{
+		UUID:        uuid.New(),
+		Name:        "arbejdsklynge",
+		Description: "Worker cluster in Danish",
+	}
+
+	require.NoError(t, db.CreateWorker(ctx, &w))
+	require.NoError(t, db.CreateWorkerCluster(ctx, &wc))
+
+	return WorkerTestFixture{
+		db:   db,
+		ctx:  ctx,
+		done: cancel,
+
+		worker:  &w,
+		cluster: &wc,
+	}
+}
diff --git a/internal/manager/persistence/timeout_test.go b/internal/manager/persistence/timeout_test.go
index 729caf63..f26b1ef2 100644
--- a/internal/manager/persistence/timeout_test.go
+++ b/internal/manager/persistence/timeout_test.go
@@ -47,7 +47,7 @@ func TestFetchTimedOutTasks(t *testing.T) {
 		// tests that the expected task is returned.
 		assert.Equal(t, task.UUID, timedout[0].UUID)
 		assert.Equal(t, job, timedout[0].Job, "the job should be included in the result as well")
-		assert.Equal(t, w, timedout[0].Worker, "the worker should be included in the result as well")
+		assert.Equal(t, w.UUID, timedout[0].Worker.UUID, "the worker should be included in the result as well")
 	}
 }
diff --git a/internal/manager/persistence/worker_cluster.go b/internal/manager/persistence/worker_cluster.go
new file mode 100644
index 00000000..6d1c6ba9
--- /dev/null
+++ b/internal/manager/persistence/worker_cluster.go
@@ -0,0 +1,100 @@
+package persistence
+
+// SPDX-License-Identifier: GPL-3.0-or-later
+
+import (
+	"context"
+	"fmt"
+
+	"gorm.io/gorm"
+)
+
+type WorkerCluster struct {
+	Model
+
+	UUID        string `gorm:"type:char(36);default:'';unique;index"`
+	Name        string `gorm:"type:varchar(64);default:'';unique"`
+	Description string `gorm:"type:varchar(255);default:''"`
+
+	Workers []*Worker `gorm:"many2many:worker_cluster_membership;constraint:OnDelete:CASCADE"`
+}
+
+func (db *DB) CreateWorkerCluster(ctx context.Context, wc *WorkerCluster) error {
+	if err := db.gormDB.WithContext(ctx).Create(wc).Error; err != nil {
+		return fmt.Errorf("creating new worker cluster: %w", err)
+	}
+	return nil
+}
+
+func (db *DB) FetchWorkerCluster(ctx context.Context, uuid string) (*WorkerCluster, error) {
+	tx := db.gormDB.WithContext(ctx)
+	return fetchWorkerCluster(tx, uuid)
+}
+
+// fetchWorkerCluster fetches the worker cluster using the given database instance.
+func fetchWorkerCluster(gormDB *gorm.DB, uuid string) (*WorkerCluster, error) {
+	w := WorkerCluster{}
+	tx := gormDB.First(&w, "uuid = ?", uuid)
+	if tx.Error != nil {
+		return nil, workerClusterError(tx.Error, "fetching worker cluster")
+	}
+	return &w, nil
+}
+
+func (db *DB) SaveWorkerCluster(ctx context.Context, cluster *WorkerCluster) error {
+	if err := db.gormDB.WithContext(ctx).Save(cluster).Error; err != nil {
+		return workerClusterError(err, "saving worker cluster")
+	}
+	return nil
+}
+
+// DeleteWorkerCluster deletes the given cluster, after unassigning all workers from it.
+func (db *DB) DeleteWorkerCluster(ctx context.Context, uuid string) error {
+	tx := db.gormDB.WithContext(ctx).
+		Where("uuid = ?", uuid).
+		Delete(&WorkerCluster{})
+	if tx.Error != nil {
+		return workerClusterError(tx.Error, "deleting worker cluster")
+	}
+	if tx.RowsAffected == 0 {
+		return ErrWorkerClusterNotFound
+	}
+	return nil
+}
+
+func (db *DB) FetchWorkerClusters(ctx context.Context) ([]*WorkerCluster, error) {
+	clusters := make([]*WorkerCluster, 0)
+	tx := db.gormDB.WithContext(ctx).Model(&WorkerCluster{}).Scan(&clusters)
+	if tx.Error != nil {
+		return nil, workerClusterError(tx.Error, "fetching all worker clusters")
+	}
+	return clusters, nil
+}
+
+func (db *DB) fetchWorkerClustersWithUUID(ctx context.Context, clusterUUIDs []string) ([]*WorkerCluster, error) {
+	clusters := make([]*WorkerCluster, 0)
+	tx := db.gormDB.WithContext(ctx).
+		Model(&WorkerCluster{}).
+		Where("uuid in ?", clusterUUIDs).
+		Scan(&clusters)
+	if tx.Error != nil {
+		return nil, workerClusterError(tx.Error, "fetching worker clusters by UUID")
+	}
+	return clusters, nil
+}
+
+func (db *DB) WorkerSetClusters(ctx context.Context, worker *Worker, clusterUUIDs []string) error {
+	clusters, err := db.fetchWorkerClustersWithUUID(ctx, clusterUUIDs)
+	if err != nil {
+		return workerClusterError(err, "fetching worker clusters")
+	}
+
+	err = db.gormDB.WithContext(ctx).
+		Model(worker).
+		Association("Clusters").
+		Replace(clusters)
+	if err != nil {
+		return workerClusterError(err, "updating worker clusters")
+	}
+	return nil
+}
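
Review note: WorkerSetClusters is declarative rather than incremental; GORM's Association(...).Replace overwrites the entire membership set, so duplicate UUIDs collapse and an empty list unassigns everything, which is exactly what the tests below rely on. UUIDs that match no stored cluster are silently dropped by fetchWorkerClustersWithUUID. A usage sketch:

```go
// Sketch only; assumes `db`, `worker` and two stored clusters from context.
func exampleSetClusters(ctx context.Context, db *DB, worker *Worker, a, b *WorkerCluster) error {
	// Make the worker a member of exactly these two clusters.
	if err := db.WorkerSetClusters(ctx, worker, []string{a.UUID, b.UUID}); err != nil {
		return err
	}
	// An empty list removes every membership.
	return db.WorkerSetClusters(ctx, worker, []string{})
}
```
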
diff --git a/internal/manager/persistence/worker_cluster_test.go b/internal/manager/persistence/worker_cluster_test.go
new file mode 100644
index 00000000..e520fbc4
--- /dev/null
+++ b/internal/manager/persistence/worker_cluster_test.go
@@ -0,0 +1,150 @@
+package persistence
+
+// SPDX-License-Identifier: GPL-3.0-or-later
+
+import (
+	"testing"
+	"time"
+
+	"git.blender.org/flamenco/internal/uuid"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestCreateFetchCluster(t *testing.T) {
+	f := workerTestFixtures(t, 1*time.Second)
+	defer f.done()
+
+	// Test fetching a non-existent cluster.
+	fetchedCluster, err := f.db.FetchWorkerCluster(f.ctx, "7ee21bc8-ff1a-42d2-a6b6-cc4b529b189f")
+	assert.ErrorIs(t, err, ErrWorkerClusterNotFound)
+	assert.Nil(t, fetchedCluster)
+
+	// New cluster creation is already done in the workerTestFixtures() call.
+	assert.NotNil(t, f.cluster)
+
+	fetchedCluster, err = f.db.FetchWorkerCluster(f.ctx, f.cluster.UUID)
+	require.NoError(t, err)
+	assert.NotNil(t, fetchedCluster)
+
+	// Test the contents of the fetched cluster.
+	assert.Equal(t, f.cluster.UUID, fetchedCluster.UUID)
+	assert.Equal(t, f.cluster.Name, fetchedCluster.Name)
+	assert.Equal(t, f.cluster.Description, fetchedCluster.Description)
+	assert.Zero(t, fetchedCluster.Workers)
+}
+
+func TestFetchDeleteClusters(t *testing.T) {
+	f := workerTestFixtures(t, 1*time.Second)
+	defer f.done()
+
+	secondCluster := WorkerCluster{
+		UUID:        uuid.New(),
+		Name:        "arbeiderscluster",
+		Description: "Worker cluster in Dutch",
+	}
+	require.NoError(t, f.db.CreateWorkerCluster(f.ctx, &secondCluster))
+
+	allClusters, err := f.db.FetchWorkerClusters(f.ctx)
+	require.NoError(t, err)
+
+	require.Len(t, allClusters, 2)
+	var allClusterIDs [2]string
+	for idx := range allClusters {
+		allClusterIDs[idx] = allClusters[idx].UUID
+	}
+	assert.Contains(t, allClusterIDs, f.cluster.UUID)
+	assert.Contains(t, allClusterIDs, secondCluster.UUID)
+
+	// Test deleting the 2nd cluster.
+	require.NoError(t, f.db.DeleteWorkerCluster(f.ctx, secondCluster.UUID))
+
+	allClusters, err = f.db.FetchWorkerClusters(f.ctx)
+	require.NoError(t, err)
+	require.Len(t, allClusters, 1)
+	assert.Equal(t, f.cluster.UUID, allClusters[0].UUID)
+}
+
+func TestAssignUnassignWorkerClusters(t *testing.T) {
+	f := workerTestFixtures(t, 1*time.Second)
+	defer f.done()
+
+	assertClusters := func(msgLabel string, clusterUUIDs ...string) {
+		w, err := f.db.FetchWorker(f.ctx, f.worker.UUID)
+		require.NoError(t, err)
+
+		// Catch doubly-reported clusters, as the maps below would hide those cases.
+		assert.Len(t, w.Clusters, len(clusterUUIDs), msgLabel)
+
+		expectClusters := make(map[string]bool)
+		for _, cid := range clusterUUIDs {
+			expectClusters[cid] = true
+		}
+
+		actualClusters := make(map[string]bool)
+		for _, c := range w.Clusters {
+			actualClusters[c.UUID] = true
+		}
+
+		assert.Equal(t, expectClusters, actualClusters, msgLabel)
+	}
+
+	secondCluster := WorkerCluster{
+		UUID:        uuid.New(),
+		Name:        "arbeiderscluster",
+		Description: "Worker cluster in Dutch",
+	}
+	require.NoError(t, f.db.CreateWorkerCluster(f.ctx, &secondCluster))
+
+	// By default the Worker should not be part of any cluster.
+	assertClusters("default cluster assignment")
+
+	require.NoError(t, f.db.WorkerSetClusters(f.ctx, f.worker, []string{f.cluster.UUID}))
+	assertClusters("setting one cluster", f.cluster.UUID)
+
+	// Double assignments should also just work.
+	require.NoError(t, f.db.WorkerSetClusters(f.ctx, f.worker, []string{f.cluster.UUID, f.cluster.UUID}))
+	assertClusters("setting twice the same cluster", f.cluster.UUID)
+
+	// Multiple cluster memberships.
+	require.NoError(t, f.db.WorkerSetClusters(f.ctx, f.worker, []string{f.cluster.UUID, secondCluster.UUID}))
+	assertClusters("setting two different clusters", f.cluster.UUID, secondCluster.UUID)
+
+	// Remove memberships.
+	require.NoError(t, f.db.WorkerSetClusters(f.ctx, f.worker, []string{secondCluster.UUID}))
+	assertClusters("unassigning from the first cluster", secondCluster.UUID)
+	require.NoError(t, f.db.WorkerSetClusters(f.ctx, f.worker, []string{}))
+	assertClusters("unassigning from the second cluster")
+}
+
+func TestSaveWorkerCluster(t *testing.T) {
+	f := workerTestFixtures(t, 1*time.Second)
+	defer f.done()
+
+	f.cluster.Name = "übercluster"
+	f.cluster.Description = "ʻO kēlā hui ma laila"
+	require.NoError(t, f.db.SaveWorkerCluster(f.ctx, f.cluster))
+
+	fetched, err := f.db.FetchWorkerCluster(f.ctx, f.cluster.UUID)
+	require.NoError(t, err)
+	assert.Equal(t, f.cluster.Name, fetched.Name)
+	assert.Equal(t, f.cluster.Description, fetched.Description)
+}
+
+func TestDeleteWorkerClusterWithWorkersAssigned(t *testing.T) {
+	f := workerTestFixtures(t, 1*time.Second)
+	defer f.done()
+
+	// Assign the worker.
+	require.NoError(t, f.db.WorkerSetClusters(f.ctx, f.worker, []string{f.cluster.UUID}))
+
+	// Delete the cluster.
+	require.NoError(t, f.db.DeleteWorkerCluster(f.ctx, f.cluster.UUID))
+
+	// Check that the Worker has been unassigned from the cluster.
+	w, err := f.db.FetchWorker(f.ctx, f.worker.UUID)
+	require.NoError(t, err)
+	assert.Empty(t, w.Clusters)
+}
diff --git a/internal/manager/persistence/workers.go b/internal/manager/persistence/workers.go
index b80b0b0e..b71996ab 100644
--- a/internal/manager/persistence/workers.go
+++ b/internal/manager/persistence/workers.go
@@ -30,6 +30,8 @@ type Worker struct {
 	LazyStatusRequest bool `gorm:"type:smallint;default:0"`
 
 	SupportedTaskTypes string `gorm:"type:varchar(255);default:''"` // comma-separated list of task types.
+
+	Clusters []*WorkerCluster `gorm:"many2many:worker_cluster_membership;constraint:OnDelete:CASCADE"`
 }
 
 func (w *Worker) Identifier() string {
@@ -71,6 +73,7 @@ func (db *DB) CreateWorker(ctx context.Context, w *Worker) error {
 func (db *DB) FetchWorker(ctx context.Context, uuid string) (*Worker, error) {
 	w := Worker{}
 	tx := db.gormDB.WithContext(ctx).
+		Preload("Clusters").
 		First(&w, "uuid = ?", uuid)
 	if tx.Error != nil {
 		return nil, workerError(tx.Error, "fetching worker")
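
Review note: Worker.Clusters and WorkerCluster.Workers declare the same worker_cluster_membership join table, each with OnDelete:CASCADE, so deleting either side scrubs the membership rows; TestDeleteWorkerClusterWithWorkersAssigned above and TestDeleteWorkerWithClusterAssigned below exercise both directions. The join table looks approximately like this (a sketch, not migration output):

```go
// Approximate shape of the GORM-generated join table.
const membershipTableSketch = `
CREATE TABLE worker_cluster_membership (
	worker_id         integer REFERENCES workers(id)         ON DELETE CASCADE,
	worker_cluster_id integer REFERENCES worker_clusters(id) ON DELETE CASCADE,
	PRIMARY KEY (worker_id, worker_cluster_id)
);
`
```
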
diff --git a/internal/manager/persistence/workers_test.go b/internal/manager/persistence/workers_test.go
index aae65b71..38417f01 100644
--- a/internal/manager/persistence/workers_test.go
+++ b/internal/manager/persistence/workers_test.go
@@ -8,6 +8,7 @@ import (
 	"time"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 
 	"git.blender.org/flamenco/internal/uuid"
 	"git.blender.org/flamenco/pkg/api"
@@ -317,3 +318,19 @@ func TestDeleteWorker(t *testing.T) {
 		assert.True(t, fetchedTask.Worker.DeletedAt.Valid)
 	}
 }
+
+func TestDeleteWorkerWithClusterAssigned(t *testing.T) {
+	f := workerTestFixtures(t, 1*time.Second)
+	defer f.done()
+
+	// Assign the worker.
+	require.NoError(t, f.db.WorkerSetClusters(f.ctx, f.worker, []string{f.cluster.UUID}))
+
+	// Delete the Worker.
+	require.NoError(t, f.db.DeleteWorker(f.ctx, f.worker.UUID))
+
+	// Check that the Worker has been unassigned from the cluster.
+	cluster, err := f.db.FetchWorkerCluster(f.ctx, f.cluster.UUID)
+	require.NoError(t, err)
+	assert.Empty(t, cluster.Workers)
+}
diff --git a/internal/manager/webupdates/worker_updates.go b/internal/manager/webupdates/worker_updates.go
index f546a777..a088cdb3 100644
--- a/internal/manager/webupdates/worker_updates.go
+++ b/internal/manager/webupdates/worker_updates.go
@@ -32,6 +32,8 @@ func NewWorkerUpdate(worker *persistence.Worker) api.SocketIOWorkerUpdate {
 		workerUpdate.LastSeen = &worker.LastSeenAt
 	}
 
+	// TODO: add cluster IDs.
+
 	return workerUpdate
}
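
Review note: one possible shape for the TODO above, assuming api.SocketIOWorkerUpdate grows an optional ClusterIds field (no such field exists in this diff; this is purely illustrative). Callers would also need to ensure the worker's Clusters association is preloaded, as FetchWorker now does:

```go
// Hypothetical: requires adding `ClusterIds *[]string` to api.SocketIOWorkerUpdate.
if len(worker.Clusters) > 0 {
	clusterIDs := make([]string, len(worker.Clusters))
	for i, cluster := range worker.Clusters {
		clusterIDs[i] = cluster.UUID
	}
	workerUpdate.ClusterIds = &clusterIDs
}
```
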