Manager: convert worker tag queries to sqlc

Ref: #104305
This commit is contained in:
Sybren A. Stüvel 2024-09-18 15:48:42 +02:00
parent 35313477a0
commit 4bd6dc64b0
5 changed files with 382 additions and 95 deletions

View File

@ -179,14 +179,14 @@ func (db *DB) StoreAuthoredJob(ctx context.Context, authoredJob job_compilers.Au
}
if authoredJob.WorkerTagUUID != "" {
dbTag, err := qtx.queries.FetchWorkerTagByUUID(ctx, authoredJob.WorkerTagUUID)
workerTag, err := qtx.queries.FetchWorkerTagByUUID(ctx, authoredJob.WorkerTagUUID)
switch {
case errors.Is(err, sql.ErrNoRows):
return fmt.Errorf("no worker tag %q found", authoredJob.WorkerTagUUID)
case err != nil:
return fmt.Errorf("could not find worker tag %q: %w", authoredJob.WorkerTagUUID, err)
}
params.WorkerTagID = sql.NullInt64{Int64: dbTag.WorkerTag.ID, Valid: true}
params.WorkerTagID = sql.NullInt64{Int64: workerTag.ID, Valid: true}
}
log.Debug().
@ -358,7 +358,7 @@ func (db *DB) FetchJob(ctx context.Context, jobUUID string) (*Job, error) {
}
if sqlcJob.WorkerTagID.Valid {
workerTag, err := fetchWorkerTagByID(db.gormDB, uint(sqlcJob.WorkerTagID.Int64))
workerTag, err := fetchWorkerTagByID(ctx, queries, sqlcJob.WorkerTagID.Int64)
switch {
case errors.Is(err, sql.ErrNoRows):
return nil, ErrWorkerTagNotFound
@ -387,7 +387,7 @@ func (db *DB) FetchJobs(ctx context.Context) ([]*Job, error) {
}
if sqlcJob.WorkerTagID.Valid {
workerTag, err := fetchWorkerTagByID(db.gormDB, uint(sqlcJob.WorkerTagID.Int64))
workerTag, err := fetchWorkerTagByID(ctx, queries, sqlcJob.WorkerTagID.Int64)
switch {
case errors.Is(err, sql.ErrNoRows):
return nil, ErrWorkerTagNotFound

View File

@ -33,10 +33,15 @@ INSERT INTO workers (
)
RETURNING id;
-- name: AddWorkerTagMembership :exec
-- name: WorkerAddTagMembership :exec
INSERT INTO worker_tag_membership (worker_tag_id, worker_id)
VALUES (@worker_tag_id, @worker_id);
-- name: WorkerRemoveTagMemberships :exec
DELETE
FROM worker_tag_membership
WHERE worker_id=@worker_id;
-- name: FetchWorkers :many
SELECT sqlc.embed(workers) FROM workers
WHERE deleted_at IS NULL;
@ -53,18 +58,61 @@ SELECT * FROM workers WHERE workers.uuid = @uuid;
-- FetchWorkerUnconditional ignores soft-deletion status and just returns the worker.
SELECT * FROM workers WHERE workers.id = @worker_id;
-- name: FetchWorkerTags :many
-- name: FetchTagsOfWorker :many
SELECT worker_tags.*
FROM worker_tags
LEFT JOIN worker_tag_membership m ON (m.worker_tag_id = worker_tags.id)
LEFT JOIN workers on (m.worker_id = workers.id)
WHERE workers.uuid = @uuid;
-- name: FetchWorkerTags :many
SELECT *
FROM worker_tags;
-- name: FetchWorkerTagByUUID :one
SELECT sqlc.embed(worker_tags)
SELECT *
FROM worker_tags
WHERE worker_tags.uuid = @uuid;
-- name: FetchWorkerTagsByUUIDs :many
SELECT *
FROM worker_tags
WHERE uuid in (sqlc.slice('uuids'));
-- name: FetchWorkerTagByID :one
SELECT *
FROM worker_tags
WHERE id=@worker_tag_id;
-- name: SaveWorkerTag :exec
UPDATE worker_tags
SET
updated_at=@updated_at,
uuid=@uuid,
name=@name,
description=@description
WHERE id=@worker_tag_id;
-- name: DeleteWorkerTag :execrows
DELETE FROM worker_tags
WHERE uuid=@uuid;
-- name: CreateWorkerTag :execlastid
INSERT INTO worker_tags (
created_at,
uuid,
name,
description
) VALUES (
@created_at,
@uuid,
@name,
@description
);
-- name: CountWorkerTags :one
SELECT count(id) as count FROM worker_tags;
-- name: SoftDeleteWorker :execrows
UPDATE workers SET deleted_at=@deleted_at
WHERE uuid=@uuid;

View File

@ -8,22 +8,19 @@ package sqlc
import (
"context"
"database/sql"
"strings"
"time"
)
const addWorkerTagMembership = `-- name: AddWorkerTagMembership :exec
INSERT INTO worker_tag_membership (worker_tag_id, worker_id)
VALUES (?1, ?2)
const countWorkerTags = `-- name: CountWorkerTags :one
SELECT count(id) as count FROM worker_tags
`
type AddWorkerTagMembershipParams struct {
WorkerTagID int64
WorkerID int64
}
func (q *Queries) AddWorkerTagMembership(ctx context.Context, arg AddWorkerTagMembershipParams) error {
_, err := q.db.ExecContext(ctx, addWorkerTagMembership, arg.WorkerTagID, arg.WorkerID)
return err
func (q *Queries) CountWorkerTags(ctx context.Context) (int64, error) {
row := q.db.QueryRowContext(ctx, countWorkerTags)
var count int64
err := row.Scan(&count)
return count, err
}
const createWorker = `-- name: CreateWorker :one
@ -100,6 +97,91 @@ func (q *Queries) CreateWorker(ctx context.Context, arg CreateWorkerParams) (int
return id, err
}
const createWorkerTag = `-- name: CreateWorkerTag :execlastid
INSERT INTO worker_tags (
created_at,
uuid,
name,
description
) VALUES (
?1,
?2,
?3,
?4
)
`
type CreateWorkerTagParams struct {
CreatedAt time.Time
UUID string
Name string
Description string
}
func (q *Queries) CreateWorkerTag(ctx context.Context, arg CreateWorkerTagParams) (int64, error) {
result, err := q.db.ExecContext(ctx, createWorkerTag,
arg.CreatedAt,
arg.UUID,
arg.Name,
arg.Description,
)
if err != nil {
return 0, err
}
return result.LastInsertId()
}
const deleteWorkerTag = `-- name: DeleteWorkerTag :execrows
DELETE FROM worker_tags
WHERE uuid=?1
`
func (q *Queries) DeleteWorkerTag(ctx context.Context, uuid string) (int64, error) {
result, err := q.db.ExecContext(ctx, deleteWorkerTag, uuid)
if err != nil {
return 0, err
}
return result.RowsAffected()
}
const fetchTagsOfWorker = `-- name: FetchTagsOfWorker :many
SELECT worker_tags.id, worker_tags.created_at, worker_tags.updated_at, worker_tags.uuid, worker_tags.name, worker_tags.description
FROM worker_tags
LEFT JOIN worker_tag_membership m ON (m.worker_tag_id = worker_tags.id)
LEFT JOIN workers on (m.worker_id = workers.id)
WHERE workers.uuid = ?1
`
func (q *Queries) FetchTagsOfWorker(ctx context.Context, uuid string) ([]WorkerTag, error) {
rows, err := q.db.QueryContext(ctx, fetchTagsOfWorker, uuid)
if err != nil {
return nil, err
}
defer rows.Close()
var items []WorkerTag
for rows.Next() {
var i WorkerTag
if err := rows.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.UUID,
&i.Name,
&i.Description,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const fetchWorker = `-- name: FetchWorker :one
SELECT id, created_at, updated_at, uuid, secret, name, address, platform, software, status, last_seen_at, status_requested, lazy_status_request, supported_task_types, deleted_at, can_restart FROM workers WHERE workers.uuid = ?1 and deleted_at is NULL
`
@ -129,40 +211,99 @@ func (q *Queries) FetchWorker(ctx context.Context, uuid string) (Worker, error)
return i, err
}
const fetchWorkerTagByID = `-- name: FetchWorkerTagByID :one
SELECT id, created_at, updated_at, uuid, name, description
FROM worker_tags
WHERE id=?1
`
func (q *Queries) FetchWorkerTagByID(ctx context.Context, workerTagID int64) (WorkerTag, error) {
row := q.db.QueryRowContext(ctx, fetchWorkerTagByID, workerTagID)
var i WorkerTag
err := row.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.UUID,
&i.Name,
&i.Description,
)
return i, err
}
const fetchWorkerTagByUUID = `-- name: FetchWorkerTagByUUID :one
SELECT worker_tags.id, worker_tags.created_at, worker_tags.updated_at, worker_tags.uuid, worker_tags.name, worker_tags.description
SELECT id, created_at, updated_at, uuid, name, description
FROM worker_tags
WHERE worker_tags.uuid = ?1
`
type FetchWorkerTagByUUIDRow struct {
WorkerTag WorkerTag
}
func (q *Queries) FetchWorkerTagByUUID(ctx context.Context, uuid string) (FetchWorkerTagByUUIDRow, error) {
func (q *Queries) FetchWorkerTagByUUID(ctx context.Context, uuid string) (WorkerTag, error) {
row := q.db.QueryRowContext(ctx, fetchWorkerTagByUUID, uuid)
var i FetchWorkerTagByUUIDRow
var i WorkerTag
err := row.Scan(
&i.WorkerTag.ID,
&i.WorkerTag.CreatedAt,
&i.WorkerTag.UpdatedAt,
&i.WorkerTag.UUID,
&i.WorkerTag.Name,
&i.WorkerTag.Description,
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.UUID,
&i.Name,
&i.Description,
)
return i, err
}
const fetchWorkerTags = `-- name: FetchWorkerTags :many
SELECT worker_tags.id, worker_tags.created_at, worker_tags.updated_at, worker_tags.uuid, worker_tags.name, worker_tags.description
SELECT id, created_at, updated_at, uuid, name, description
FROM worker_tags
LEFT JOIN worker_tag_membership m ON (m.worker_tag_id = worker_tags.id)
LEFT JOIN workers on (m.worker_id = workers.id)
WHERE workers.uuid = ?1
`
func (q *Queries) FetchWorkerTags(ctx context.Context, uuid string) ([]WorkerTag, error) {
rows, err := q.db.QueryContext(ctx, fetchWorkerTags, uuid)
func (q *Queries) FetchWorkerTags(ctx context.Context) ([]WorkerTag, error) {
rows, err := q.db.QueryContext(ctx, fetchWorkerTags)
if err != nil {
return nil, err
}
defer rows.Close()
var items []WorkerTag
for rows.Next() {
var i WorkerTag
if err := rows.Scan(
&i.ID,
&i.CreatedAt,
&i.UpdatedAt,
&i.UUID,
&i.Name,
&i.Description,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const fetchWorkerTagsByUUIDs = `-- name: FetchWorkerTagsByUUIDs :many
SELECT id, created_at, updated_at, uuid, name, description
FROM worker_tags
WHERE uuid in (/*SLICE:uuids*/?)
`
func (q *Queries) FetchWorkerTagsByUUIDs(ctx context.Context, uuids []string) ([]WorkerTag, error) {
query := fetchWorkerTagsByUUIDs
var queryParams []interface{}
if len(uuids) > 0 {
for _, v := range uuids {
queryParams = append(queryParams, v)
}
query = strings.Replace(query, "/*SLICE:uuids*/?", strings.Repeat(",?", len(uuids))[1:], 1)
} else {
query = strings.Replace(query, "/*SLICE:uuids*/?", "NULL", 1)
}
rows, err := q.db.QueryContext(ctx, query, queryParams...)
if err != nil {
return nil, err
}
@ -381,6 +522,35 @@ func (q *Queries) SaveWorkerStatus(ctx context.Context, arg SaveWorkerStatusPara
return err
}
const saveWorkerTag = `-- name: SaveWorkerTag :exec
UPDATE worker_tags
SET
updated_at=?1,
uuid=?2,
name=?3,
description=?4
WHERE id=?5
`
type SaveWorkerTagParams struct {
UpdatedAt sql.NullTime
UUID string
Name string
Description string
WorkerTagID int64
}
func (q *Queries) SaveWorkerTag(ctx context.Context, arg SaveWorkerTagParams) error {
_, err := q.db.ExecContext(ctx, saveWorkerTag,
arg.UpdatedAt,
arg.UUID,
arg.Name,
arg.Description,
arg.WorkerTagID,
)
return err
}
const softDeleteWorker = `-- name: SoftDeleteWorker :execrows
UPDATE workers SET deleted_at=?1
WHERE uuid=?2
@ -433,6 +603,32 @@ func (q *Queries) SummarizeWorkerStatuses(ctx context.Context) ([]SummarizeWorke
return items, nil
}
const workerAddTagMembership = `-- name: WorkerAddTagMembership :exec
INSERT INTO worker_tag_membership (worker_tag_id, worker_id)
VALUES (?1, ?2)
`
type WorkerAddTagMembershipParams struct {
WorkerTagID int64
WorkerID int64
}
func (q *Queries) WorkerAddTagMembership(ctx context.Context, arg WorkerAddTagMembershipParams) error {
_, err := q.db.ExecContext(ctx, workerAddTagMembership, arg.WorkerTagID, arg.WorkerID)
return err
}
const workerRemoveTagMemberships = `-- name: WorkerRemoveTagMemberships :exec
DELETE
FROM worker_tag_membership
WHERE worker_id=?1
`
func (q *Queries) WorkerRemoveTagMemberships(ctx context.Context, workerID int64) error {
_, err := q.db.ExecContext(ctx, workerRemoveTagMemberships, workerID)
return err
}
const workerSeen = `-- name: WorkerSeen :exec
UPDATE workers SET
updated_at=?1,

View File

@ -6,7 +6,7 @@ import (
"context"
"fmt"
"gorm.io/gorm"
"projects.blender.org/studio/flamenco/internal/manager/persistence/sqlc"
)
type WorkerTag struct {
@ -20,51 +20,69 @@ type WorkerTag struct {
}
func (db *DB) CreateWorkerTag(ctx context.Context, wc *WorkerTag) error {
if err := db.gormDB.WithContext(ctx).Create(wc).Error; err != nil {
queries := db.queries()
now := db.gormDB.NowFunc()
dbID, err := queries.CreateWorkerTag(ctx, sqlc.CreateWorkerTagParams{
CreatedAt: now,
UUID: wc.UUID,
Name: wc.Name,
Description: wc.Description,
})
if err != nil {
return fmt.Errorf("creating new worker tag: %w", err)
}
wc.ID = uint(dbID)
wc.CreatedAt = now
return nil
}
// HasWorkerTags returns whether there are any tags defined at all.
func (db *DB) HasWorkerTags(ctx context.Context) (bool, error) {
var count int64
tx := db.gormDB.WithContext(ctx).
Model(&WorkerTag{}).
Count(&count)
if err := tx.Error; err != nil {
queries := db.queries()
count, err := queries.CountWorkerTags(ctx)
if err != nil {
return false, workerTagError(err, "counting worker tags")
}
return count > 0, nil
}
func (db *DB) FetchWorkerTag(ctx context.Context, uuid string) (*WorkerTag, error) {
tx := db.gormDB.WithContext(ctx)
return fetchWorkerTag(tx, uuid)
}
queries := db.queries()
// fetchWorkerTag fetches the worker tag using the given database instance.
func fetchWorkerTag(gormDB *gorm.DB, uuid string) (*WorkerTag, error) {
w := WorkerTag{}
tx := gormDB.First(&w, "uuid = ?", uuid)
if tx.Error != nil {
return nil, workerTagError(tx.Error, "fetching worker tag")
workerTag, err := queries.FetchWorkerTagByUUID(ctx, uuid)
if err != nil {
return nil, workerTagError(err, "fetching worker tag")
}
return &w, nil
return convertSqlcWorkerTag(workerTag), nil
}
// fetchWorkerTagByID fetches the worker tag using the given database instance.
func fetchWorkerTagByID(gormDB *gorm.DB, id uint) (*WorkerTag, error) {
w := WorkerTag{}
tx := gormDB.First(&w, "id = ?", id)
if tx.Error != nil {
return nil, workerTagError(tx.Error, "fetching worker tag")
func fetchWorkerTagByID(ctx context.Context, queries *sqlc.Queries, id int64) (*WorkerTag, error) {
workerTag, err := queries.FetchWorkerTagByID(ctx, id)
if err != nil {
return nil, workerTagError(err, "fetching worker tag")
}
return &w, nil
return convertSqlcWorkerTag(workerTag), nil
}
func (db *DB) SaveWorkerTag(ctx context.Context, tag *WorkerTag) error {
if err := db.gormDB.WithContext(ctx).Save(tag).Error; err != nil {
queries := db.queries()
err := queries.SaveWorkerTag(ctx, sqlc.SaveWorkerTagParams{
UpdatedAt: db.now(),
UUID: tag.UUID,
Name: tag.Name,
Description: tag.Description,
WorkerTagID: int64(tag.ID),
})
if err != nil {
return workerTagError(err, "saving worker tag")
}
return nil
@ -81,51 +99,77 @@ func (db *DB) DeleteWorkerTag(ctx context.Context, uuid string) error {
return ErrDeletingWithoutFK
}
tx := db.gormDB.WithContext(ctx).
Where("uuid = ?", uuid).
Delete(&WorkerTag{})
if tx.Error != nil {
return workerTagError(tx.Error, "deleting worker tag")
}
if tx.RowsAffected == 0 {
queries := db.queries()
rowsUpdated, err := queries.DeleteWorkerTag(ctx, uuid)
switch {
case err != nil:
return workerTagError(err, "deleting worker tag")
case rowsUpdated == 0:
return ErrWorkerTagNotFound
}
return nil
}
func (db *DB) FetchWorkerTags(ctx context.Context) ([]*WorkerTag, error) {
tags := make([]*WorkerTag, 0)
tx := db.gormDB.WithContext(ctx).Model(&WorkerTag{}).Scan(&tags)
if tx.Error != nil {
return nil, workerTagError(tx.Error, "fetching all worker tags")
queries := db.queries()
tags, err := queries.FetchWorkerTags(ctx)
if err != nil {
return nil, workerTagError(err, "fetching all worker tags")
}
return tags, nil
gormTags := make([]*WorkerTag, len(tags))
for index, tag := range tags {
gormTags[index] = convertSqlcWorkerTag(tag)
}
return gormTags, nil
}
func (db *DB) fetchWorkerTagsWithUUID(ctx context.Context, tagUUIDs []string) ([]*WorkerTag, error) {
tags := make([]*WorkerTag, 0)
tx := db.gormDB.WithContext(ctx).
Model(&WorkerTag{}).
Where("uuid in ?", tagUUIDs).
Scan(&tags)
if tx.Error != nil {
return nil, workerTagError(tx.Error, "fetching all worker tags")
func (db *DB) fetchWorkerTagsWithUUID(
ctx context.Context,
queries *sqlc.Queries,
tagUUIDs []string,
) ([]*WorkerTag, error) {
tags, err := queries.FetchWorkerTagsByUUIDs(ctx, tagUUIDs)
if err != nil {
return nil, workerTagError(err, "fetching all worker tags")
}
return tags, nil
gormTags := make([]*WorkerTag, len(tags))
for index, tag := range tags {
gormTags[index] = convertSqlcWorkerTag(tag)
}
return gormTags, nil
}
func (db *DB) WorkerSetTags(ctx context.Context, worker *Worker, tagUUIDs []string) error {
tags, err := db.fetchWorkerTagsWithUUID(ctx, tagUUIDs)
qtx, err := db.queriesWithTX()
if err != nil {
return err
}
defer qtx.rollback()
tags, err := db.fetchWorkerTagsWithUUID(ctx, qtx.queries, tagUUIDs)
if err != nil {
return workerTagError(err, "fetching worker tags")
}
err = db.gormDB.WithContext(ctx).
Model(worker).
Association("Tags").
Replace(tags)
err = qtx.queries.WorkerRemoveTagMemberships(ctx, int64(worker.ID))
if err != nil {
return workerTagError(err, "updating worker tags")
return workerTagError(err, "un-assigning existing worker tags")
}
return nil
for _, tag := range tags {
err = qtx.queries.WorkerAddTagMembership(ctx, sqlc.WorkerAddTagMembershipParams{
WorkerID: int64(worker.ID),
WorkerTagID: int64(tag.ID),
})
if err != nil {
return workerTagError(err, "assigning worker tags")
}
}
return qtx.commit()
}

View File

@ -101,7 +101,7 @@ func (db *DB) CreateWorker(ctx context.Context, w *Worker) error {
// TODO: remove the create-with-tags functionality to a higher-level function.
// This code is just here to make this function work like the GORM code did.
for _, tag := range w.Tags {
err := queries.AddWorkerTagMembership(ctx, sqlc.AddWorkerTagMembershipParams{
err := queries.WorkerAddTagMembership(ctx, sqlc.WorkerAddTagMembershipParams{
WorkerTagID: int64(tag.ID),
WorkerID: workerID,
})
@ -122,7 +122,7 @@ func (db *DB) FetchWorker(ctx context.Context, uuid string) (*Worker, error) {
}
// TODO: remove this code, and let the caller fetch the tags when interested in them.
workerTags, err := queries.FetchWorkerTags(ctx, uuid)
workerTags, err := queries.FetchTagsOfWorker(ctx, uuid)
if err != nil {
return nil, workerTagError(err, "fetching tags of worker %s", uuid)
}
@ -130,8 +130,7 @@ func (db *DB) FetchWorker(ctx context.Context, uuid string) (*Worker, error) {
convertedWorker := convertSqlcWorker(worker)
convertedWorker.Tags = make([]*WorkerTag, len(workerTags))
for index := range workerTags {
convertedTag := convertSqlcWorkerTag(workerTags[index])
convertedWorker.Tags[index] = &convertedTag
convertedWorker.Tags[index] = convertSqlcWorkerTag(workerTags[index])
}
return &convertedWorker, nil
@ -338,8 +337,8 @@ func convertSqlcWorker(worker sqlc.Worker) Worker {
// the model expected by the rest of the code. This is mostly in place to aid in
// the GORM to SQLC migration. It is intended that eventually the rest of the
// code will use the same SQLC-generated model.
func convertSqlcWorkerTag(tag sqlc.WorkerTag) WorkerTag {
return WorkerTag{
func convertSqlcWorkerTag(tag sqlc.WorkerTag) *WorkerTag {
return &WorkerTag{
Model: Model{
ID: uint(tag.ID),
CreatedAt: tag.CreatedAt,