orbis/internal/workscheduler/workscheduler.go
kjuulh 1cf9d23491
All checks were successful
continuous-integration/drone/push Build is passing
feat: move schedules to registered workers
2025-01-18 13:18:13 +01:00

224 lines
4.9 KiB
Go

package workscheduler
import (
"context"
"errors"
"fmt"
"log/slog"
"git.front.kjuulh.io/kjuulh/orbis/internal/modelschedule"
"git.front.kjuulh.io/kjuulh/orbis/internal/worker"
"git.front.kjuulh.io/kjuulh/orbis/internal/workscheduler/repositories"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxpool"
)
//go:generate sqlc generate
// WorkScheduler assigns scheduled model runs to registered workers and
// tracks each worker's queue in Postgres via sqlc-generated repositories.
type WorkScheduler struct {
	db     *pgxpool.Pool // pool used to construct a repository per operation
	logger *slog.Logger
}
// NewWorkScheduler constructs a WorkScheduler backed by the given
// Postgres connection pool and structured logger.
func NewWorkScheduler(db *pgxpool.Pool, logger *slog.Logger) *WorkScheduler {
	return &WorkScheduler{db: db, logger: logger}
}
// Worker pairs a registered worker instance with how many more queue
// items it may be handed during the current scheduling pass.
type Worker struct {
	Instance worker.WorkerInstance
	// RemainingCapacity is the worker's configured capacity minus its
	// current queue size (see WorkScheduler.GetWorker).
	RemainingCapacity uint
}
// Workers is the set of registered workers that still have spare
// capacity, as computed by WorkScheduler.GetWorkers.
type Workers struct {
	Workers []*Worker
}
// IterateSlice returns a push-style iterator that yields batches of
// workers of exactly size entries (the final batch may be smaller).
// Workers are picked round-robin, and every yielded entry consumes one
// unit of that worker's RemainingCapacity, so a worker appears across
// all batches at most RemainingCapacity times in total.
//
// NOTE: iteration mutates the receiver's workers' RemainingCapacity;
// the computed capacity can be consumed only once.
func (w *Workers) IterateSlice(size uint) func(yield func([]Worker) bool) {
	return func(yield func([]Worker) bool) {
		// Guard the degenerate inputs: no workers, or a zero batch
		// size (which would otherwise yield empty batches while
		// silently draining all capacity).
		if len(w.Workers) == 0 || size == 0 {
			return
		}

		batch := make([]Worker, 0, size)
		for {
			drained := true
			for _, candidate := range w.Workers {
				if uint(len(batch)) == size {
					if !yield(batch) {
						return
					}
					batch = make([]Worker, 0, size)
				}

				// RemainingCapacity is unsigned; exhaustion is
				// exactly zero (was `<= 0`, meaningless for uint).
				if candidate.RemainingCapacity == 0 {
					continue
				}

				candidate.RemainingCapacity--
				batch = append(batch, *candidate)
				drained = false
			}

			if drained {
				// Flush the trailing partial batch, if any.
				if len(batch) > 0 {
					yield(batch)
				}
				return
			}
		}
	}
}
// GetWorkers computes the remaining queue capacity for every registered
// worker and returns only those that can still accept work.
//
// It returns an error if the queue-size lookup fails for any worker.
func (w *WorkScheduler) GetWorkers(ctx context.Context, registeredWorkers *worker.Workers) (*Workers, error) {
	w.logger.DebugContext(ctx, "found workers", "workers", len(registeredWorkers.Instances))

	workers := make([]*Worker, 0, len(registeredWorkers.Instances))
	for _, registeredWorker := range registeredWorkers.Instances {
		remainingCapacity, err := w.GetWorker(ctx, &registeredWorker)
		if err != nil {
			return nil, fmt.Errorf("failed to find capacity for worker: %w", err)
		}

		if remainingCapacity == 0 {
			// Fully booked; leave it out of this scheduling pass.
			// Include the worker id so the skip is diagnosable.
			w.logger.DebugContext(ctx, "skipping worker as no remaining capacity", "workerID", registeredWorker.WorkerID)
			continue
		}

		workers = append(workers, &Worker{
			Instance:          registeredWorker,
			RemainingCapacity: remainingCapacity,
		})
	}

	return &Workers{Workers: workers}, nil
}
// GetWorker returns the remaining queue capacity for a single worker:
// its configured capacity minus its current queue size, floored at
// zero.
func (w *WorkScheduler) GetWorker(
	ctx context.Context,
	worker *worker.WorkerInstance,
) (uint, error) {
	repo := repositories.New(w.db)

	currentSize, err := repo.GetCurrentQueueSize(ctx, worker.WorkerID)
	if err != nil {
		return 0, fmt.Errorf("failed to get current queue size: %s: %w", worker.WorkerID, err)
	}

	// The queue may already hold >= capacity items; the signed
	// comparison avoids the wrap-around an unsigned subtraction
	// would produce.
	if int64(worker.Capacity)-currentSize <= 0 {
		return 0, nil
	}

	return worker.Capacity - uint(currentSize), nil
}
// InsertModelRun enqueues the model run's window on the given worker's
// queue under a freshly generated schedule id.
func (w *WorkScheduler) InsertModelRun(
	ctx context.Context,
	worker Worker,
	modelRun *modelschedule.ModelRunSchedule,
) error {
	params := repositories.InsertQueueItemParams{
		ScheduleID: uuid.New(),
		WorkerID:   worker.Instance.WorkerID,
		StartRun:   pgtype.Timestamptz{Time: modelRun.Start, Valid: true},
		EndRun:     pgtype.Timestamptz{Time: modelRun.End, Valid: true},
	}

	return repositories.New(w.db).InsertQueueItem(ctx, &params)
}
// GetNext fetches the next queued schedule id for the given worker.
// An empty queue is not an error: it yields (nil, nil).
func (w *WorkScheduler) GetNext(ctx context.Context, workerID uuid.UUID) (*uuid.UUID, error) {
	schedule, err := repositories.New(w.db).GetNext(ctx, workerID)
	switch {
	case err == nil:
		return &schedule.ScheduleID, nil
	case errors.Is(err, pgx.ErrNoRows):
		// No pending items for this worker.
		return nil, nil
	default:
		return nil, fmt.Errorf("failed to get next worker item: %w", err)
	}
}
// StartProcessing marks the given queue item as picked up for
// processing.
func (w *WorkScheduler) StartProcessing(ctx context.Context, scheduleID uuid.UUID) error {
	return repositories.New(w.db).StartProcessing(ctx, scheduleID)
}
// Archive moves the given queue item out of the active queue.
func (w *WorkScheduler) Archive(ctx context.Context, scheduleID uuid.UUID) error {
	return repositories.New(w.db).Archive(ctx, scheduleID)
}
// GetUnattended picks up to 100 schedules that no live worker is
// attending and redistributes them round-robin across the currently
// registered workers. A nil/empty worker set is a no-op.
func (w *WorkScheduler) GetUnattended(ctx context.Context, registeredWorkers *worker.Workers) error {
	instances := registeredWorkers.Instances
	if len(instances) == 0 {
		// Nobody to dispatch onto.
		return nil
	}

	repo := repositories.New(w.db)
	schedules, err := repo.GetUnattended(ctx, &repositories.GetUnattendedParams{
		WorkerIds: workerIDs(registeredWorkers),
		Amount:    100,
	})
	if err != nil {
		return fmt.Errorf("failed to get unattended workers: %w", err)
	}

	for i, schedule := range schedules {
		// Round-robin assignment over the registered workers.
		assignee := instances[i%len(instances)].WorkerID
		w.logger.InfoContext(ctx, "dispatching schedule for worker", "worker", assignee, "schedule", schedule.ScheduleID)

		// "UpdateSchdule" (sic) is the generated repository name.
		params := repositories.UpdateSchduleParams{
			WorkerID:   assignee,
			ScheduleID: schedule.ScheduleID,
		}
		if err := repo.UpdateSchdule(ctx, &params); err != nil {
			return fmt.Errorf("failed to update schedule: %w", err)
		}
	}

	return nil
}
// workerIDs flattens the registered worker instances into their ids,
// preserving registration order.
func workerIDs(registeredWorkers *worker.Workers) []uuid.UUID {
	ids := make([]uuid.UUID, 0, len(registeredWorkers.Instances))
	for _, instance := range registeredWorkers.Instances {
		ids = append(ids, instance.WorkerID)
	}
	return ids
}