diff --git a/internal/manager/persistence/jobs_blocklist.go b/internal/manager/persistence/jobs_blocklist.go index 12dc5103..c12b09b7 100644 --- a/internal/manager/persistence/jobs_blocklist.go +++ b/internal/manager/persistence/jobs_blocklist.go @@ -108,17 +108,16 @@ func (db *DB) WorkersLeftToRun(ctx context.Context, job *Job, taskType string) ( Select("uuid"). Where("id not in (?)", blockedWorkers) - if job.WorkerClusterID != nil { - // Count workers not in any cluster + workers in the job's cluster. - clusterless := db.gormDB. - Table("worker_cluster_membership"). - Select("worker_id") + if job.WorkerClusterID == nil { + // Count all workers, so no extra restrictions are necessary. + } else { + // Only count workers in the job's cluster. jobCluster := db.gormDB. Table("worker_cluster_membership"). Select("worker_id"). Where("worker_cluster_id = ?", *job.WorkerClusterID) query = query. - Where("id not in (?) or id in (?)", clusterless, jobCluster) + Where("id in (?)", jobCluster) } // Find the workers NOT blocked. diff --git a/internal/manager/persistence/jobs_blocklist_test.go b/internal/manager/persistence/jobs_blocklist_test.go index 97e28e39..97e3503d 100644 --- a/internal/manager/persistence/jobs_blocklist_test.go +++ b/internal/manager/persistence/jobs_blocklist_test.go @@ -126,6 +126,16 @@ func TestWorkersLeftToRun(t *testing.T) { worker1 := createWorker(ctx, t, db) worker2 := createWorkerFrom(ctx, t, db, *worker1) + // Create one worker cluster. It will not be used by this job, but one of the + // workers will be assigned to it. It can get this job's tasks, though. + // Because the job is clusterless, it can be run by all. + cluster1 := WorkerCluster{UUID: "11157623-4b14-4801-bee2-271dddab6309", Name: "Cluster 1"} + require.NoError(t, db.CreateWorkerCluster(ctx, &cluster1)) + workerC1 := createWorker(ctx, t, db, func(w *Worker) { + w.UUID = "c1c1c1c1-0000-1111-2222-333333333333" + w.Clusters = []*WorkerCluster{&cluster1} + }) + uuidMap := func(workers ...*Worker) map[string]bool { theMap := map[string]bool{} for _, worker := range workers { @@ -134,21 +144,22 @@ func TestWorkersLeftToRun(t *testing.T) { return theMap } - // Two workers, no blocklist. + // Three workers, no blocklist. left, err = db.WorkersLeftToRun(ctx, job, "blender") if assert.NoError(t, err) { - assert.Equal(t, uuidMap(worker1, worker2), left) + assert.Equal(t, uuidMap(worker1, worker2, workerC1), left) } // Two workers, one blocked. _ = db.AddWorkerToJobBlocklist(ctx, job, worker1, "blender") left, err = db.WorkersLeftToRun(ctx, job, "blender") if assert.NoError(t, err) { - assert.Equal(t, uuidMap(worker2), left) + assert.Equal(t, uuidMap(worker2, workerC1), left) } - // Two workers, both blocked. + // All workers blocked. _ = db.AddWorkerToJobBlocklist(ctx, job, worker2, "blender") + _ = db.AddWorkerToJobBlocklist(ctx, job, workerC1, "blender") left, err = db.WorkersLeftToRun(ctx, job, "blender") assert.NoError(t, err) assert.Empty(t, left) @@ -157,7 +168,7 @@ func TestWorkersLeftToRun(t *testing.T) { fakeJob := Job{Model: Model{ID: 327}} left, err = db.WorkersLeftToRun(ctx, &fakeJob, "blender") if assert.NoError(t, err) { - assert.Equal(t, uuidMap(worker1, worker2), left) + assert.Equal(t, uuidMap(worker1, worker2, workerC1), left) } } @@ -193,8 +204,9 @@ func TestWorkersLeftToRunWithClusters(t *testing.T) { w.UUID = "c2c2c2c2-0000-1111-2222-333333333333" w.Clusters = []*WorkerCluster{&cluster2} }) - // No clusters, so should be able to run all. - workerCNone := createWorker(ctx, t, db, func(w *Worker) { + // No clusters, so should be able to run only clusterless jobs. Which is none + // in this test. + createWorker(ctx, t, db, func(w *Worker) { w.UUID = "00000000-0000-1111-2222-333333333333" w.Clusters = nil }) @@ -207,20 +219,19 @@ func TestWorkersLeftToRunWithClusters(t *testing.T) { return theMap } - // All Cluster 1 workers + clusterless worker, no blocklist. + // All Cluster 1 workers, no blocklist. left, err := db.WorkersLeftToRun(ctx, job, "blender") require.NoError(t, err) - assert.Equal(t, uuidMap(workerC13, workerC1, workerCNone), left) + assert.Equal(t, uuidMap(workerC13, workerC1), left) - // One worker blocked, two workers remain. + // One worker blocked, one worker remain. _ = db.AddWorkerToJobBlocklist(ctx, job, workerC1, "blender") left, err = db.WorkersLeftToRun(ctx, job, "blender") require.NoError(t, err) - assert.Equal(t, uuidMap(workerC13, workerCNone), left) + assert.Equal(t, uuidMap(workerC13), left) - // All workers blocked. + // All clustered workers blocked. _ = db.AddWorkerToJobBlocklist(ctx, job, workerC13, "blender") - _ = db.AddWorkerToJobBlocklist(ctx, job, workerCNone, "blender") left, err = db.WorkersLeftToRun(ctx, job, "blender") assert.NoError(t, err) assert.Empty(t, left)