Manager: protect task log writing with mutex

A per-task mutex is used to protect the writing of task logs, so that
multiple goroutines can safely write to the same task log.
This commit is contained in:
Sybren A. Stüvel 2022-06-09 14:44:54 +02:00
parent 92d6693871
commit 04dd479248
3 changed files with 95 additions and 2 deletions

View File

@ -29,6 +29,10 @@ type Worker struct {
SupportedTaskTypes string `gorm:"type:varchar(255);default:''"` // comma-separated list of task types.
}
// Identifier returns a human-readable label for the worker, consisting of its
// name followed by its UUID in parentheses.
func (w *Worker) Identifier() string {
	const layout = "%s (%s)"
	return fmt.Sprintf(layout, w.Name, w.UUID)
}
// TaskTypes returns the worker's supported task types as list of strings.
func (w *Worker) TaskTypes() []string {
return strings.Split(w.SupportedTaskTypes, ",")

View File

@ -9,6 +9,7 @@ import (
"os"
"path"
"path/filepath"
"sync"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
@ -22,6 +23,10 @@ const (
// Storage can write data to task logs, rotate logs, etc.
// Create instances with NewStorage(), which allocates the mutex and the
// per-task lock map declared below.
type Storage struct {
BasePath string // Directory where task logs are stored.
// Locks to only allow one goroutine at a time to handle the logs of a certain task.
// `mutex` is held by taskLock() while it looks up or creates an entry in
// `taskLocks`.
mutex *sync.Mutex
// `taskLocks` maps a task ID to the mutex for that task's log; entries are
// created on demand by taskLock().
taskLocks map[string]*sync.Mutex
}
// NewStorage creates a new log storage rooted at `basePath`.
@ -39,7 +44,9 @@ func NewStorage(basePath string) *Storage {
Msg("task logs")
return &Storage{
BasePath: basePath,
BasePath: basePath,
mutex: new(sync.Mutex),
taskLocks: make(map[string]*sync.Mutex),
}
}
@ -50,6 +57,9 @@ func (s *Storage) Write(logger zerolog.Logger, jobID, taskID string, logText str
return nil
}
s.taskLock(taskID)
defer s.taskUnlock(taskID)
filepath := s.filepath(jobID, taskID)
logger = logger.With().Str("filepath", filepath).Logger()
@ -94,6 +104,9 @@ func (s *Storage) RotateFile(logger zerolog.Logger, jobID, taskID string) {
logpath := s.filepath(jobID, taskID)
logger = logger.With().Str("logpath", logpath).Logger()
s.taskLock(taskID)
defer s.taskUnlock(taskID)
err := rotateLogFile(logger, logpath)
if err != nil {
// rotateLogFile() has already logged something, so we can ignore `err`.
@ -117,6 +130,9 @@ func (s *Storage) filepath(jobID, taskID string) string {
func (s *Storage) Tail(jobID, taskID string) (string, error) {
filepath := s.filepath(jobID, taskID)
s.taskLock(taskID)
defer s.taskUnlock(taskID)
file, err := os.Open(filepath)
if err != nil {
return "", fmt.Errorf("unable to open log file for reading: %w", err)
@ -163,3 +179,25 @@ func (s *Storage) Tail(jobID, taskID string) (string, error) {
return string(buffer), nil
}
// taskLock blocks until the calling goroutine holds the lock of the given
// task. The per-task mutex is created the first time a task ID is seen.
//
// s.mutex only guards the s.taskLocks map. It is released before waiting on
// the per-task mutex, so that waiting for one task never blocks goroutines
// that handle other tasks, nor the goroutine that has to unlock this task.
func (s *Storage) taskLock(taskID string) {
	s.mutex.Lock()
	mutex := s.taskLocks[taskID]
	if mutex == nil {
		mutex = new(sync.Mutex)
		s.taskLocks[taskID] = mutex
	}
	s.mutex.Unlock()

	mutex.Lock()
}

// taskUnlock releases the lock of the given task, which must have been
// obtained earlier with taskLock(). It panics when the task was never locked.
func (s *Storage) taskUnlock(taskID string) {
	// The map is not modified here, but taskLock() may be inserting an entry for
	// another task at this very moment. Reading a map while it is being written
	// is a data race in Go, so the lookup has to happen under s.mutex as well.
	s.mutex.Lock()
	mutex := s.taskLocks[taskID]
	s.mutex.Unlock()

	if mutex == nil {
		panic("trying to unlock task that is not yet locked")
	}
	mutex.Unlock()
}

View File

@ -7,9 +7,12 @@ import (
"io/ioutil"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/stretchr/testify/assert"
)
@ -18,7 +21,7 @@ func tempStorage() *Storage {
if err != nil {
panic(err)
}
return &Storage{temppath}
return NewStorage(temppath)
}
func TestLogWriting(t *testing.T) {
@ -137,3 +140,51 @@ func TestLogTail(t *testing.T) {
string(contents),
)
}
// TestLogWritingParallel checks that many goroutines can write to the same
// task log at the same time, and that every write ends up in the file as one
// uninterrupted line.
func TestLogWritingParallel(t *testing.T) {
	s := tempStorage()
	defer os.RemoveAll(s.BasePath)
	// defer t.Errorf("not removing %s", s.BasePath)

	numGoroutines := 1000 // How many goroutines run in parallel.
	runLength := 100      // How many characters are logged, per goroutine.

	wg := sync.WaitGroup{}
	wg.Add(numGoroutines)

	jobID := "6d9a05a1-261e-4f6f-93b0-8c4f6b6d500d"
	taskID := "d19888cc-c389-4a24-aebf-8458ababdb02"

	for i := 0; i < numGoroutines; i++ {
		// Write lines of 100 characters to the task log. Each goroutine writes a
		// different character, starting at 'A'.
		go func(i int32) {
			defer wg.Done()

			logger := log.With().Int32("goroutine", i).Logger()
			letter := rune(int32('A') + (i % 26))
			if len(string(letter)) > 1 {
				panic("this test assumes only single-byte runes are used")
			}
			logText := strings.Repeat(string(letter), runLength)
			assert.NoError(t, s.Write(logger, jobID, taskID, logText))
		}(int32(i))
	}
	wg.Wait()

	// Test that the final log contains 1000 lines of 100 characters, without
	// any run getting interrupted by another one.
	contents, err := os.ReadFile(s.filepath(jobID, taskID))
	if !assert.NoError(t, err) {
		// Without file contents the checks below would only add noise.
		return
	}

	lines := strings.Split(string(contents), "\n")
	assert.Equal(t, numGoroutines+1, len(lines),
		"each goroutine should have written a single line, and the file should have a newline at the end")

	for lineIndex, line := range lines {
		if lineIndex == numGoroutines {
			assert.Empty(t, line, "the last line should be empty")
			continue
		}
		// The format arguments have to match the verbs: first the expected
		// length, then the index of the offending line.
		assert.Lenf(t, line, runLength, "each line should be %d runes long; line #%d is not", runLength, lineIndex)
	}
}