package service import ( "encoding/json" "fmt" "time" "github.com/hibiken/asynq" "github.com/redis/go-redis/v9" "golang.org/x/net/context" "gorm.io/gorm" "spider/internal/model" "spider/internal/worker" ) // StartTaskRequest is the payload for starting a new task. type StartTaskRequest struct { TaskType string `json:"task_type" binding:"required"` Target string `json:"target"` TestRun *worker.TestRun `json:"test_run"` SkipPhases []string `json:"skip_phases"` } // TaskService manages task lifecycle. type TaskService struct { db *gorm.DB redis *redis.Client client *asynq.Client } // NewTaskService creates a TaskService. The asynq.Client is constructed from the // same Redis client options used by the rest of the application. func NewTaskService(db *gorm.DB, rdb *redis.Client) *TaskService { opts := rdb.Options() client := asynq.NewClient(asynq.RedisClientOpt{ Addr: opts.Addr, Password: opts.Password, DB: opts.DB, }) return &TaskService{ db: db, redis: rdb, client: client, } } // asynqTypeForTaskType maps model task_type to asynq task type constant. func asynqTypeForTaskType(taskType string) (string, error) { m := map[string]string{ "full": worker.TypeFullPipeline, "discover": worker.TypeDiscover, "search": worker.TypeSearch, "github": worker.TypeGithub, "scrape": worker.TypeScrape, "crawl": worker.TypeCrawl, "clean": worker.TypeClean, "score": worker.TypeScore, } at, ok := m[taskType] if !ok { return "", fmt.Errorf("unknown task type: %s", taskType) } return at, nil } // StartTask validates, creates a Task record, and enqueues it via asynq. func (s *TaskService) StartTask(req StartTaskRequest) (*model.Task, error) { // Check if a task of the same type is already running. var count int64 if err := s.db.Model(&model.Task{}). Where("task_type = ? AND status = ?", req.TaskType, "running"). Count(&count).Error; err != nil { return nil, fmt.Errorf("check running tasks: %w", err) } if count > 0 { return nil, fmt.Errorf("a %s task is already running", req.TaskType) } // Validate and get asynq type. asynqType, err := asynqTypeForTaskType(req.TaskType) if err != nil { return nil, err } // Encode params. paramsJSON, err := json.Marshal(req) if err != nil { return nil, fmt.Errorf("marshal params: %w", err) } // Create Task record in DB. task := &model.Task{ TaskType: req.TaskType, Status: "pending", Params: paramsJSON, } if err := s.db.Create(task).Error; err != nil { return nil, fmt.Errorf("create task record: %w", err) } // Build asynq payload. payload := worker.TaskPayload{ TaskID: task.ID, Target: req.Target, TestRun: req.TestRun, SkipPhases: req.SkipPhases, } payloadBytes, err := json.Marshal(payload) if err != nil { return nil, fmt.Errorf("marshal payload: %w", err) } // Enqueue. asynqTask := asynq.NewTask(asynqType, payloadBytes, asynq.Queue(worker.QueueDefault)) if _, err := s.client.Enqueue(asynqTask); err != nil { // Roll back the DB record to failed. s.db.Model(task).Updates(map[string]interface{}{"status": "failed", "error_msg": err.Error()}) return nil, fmt.Errorf("enqueue task: %w", err) } return task, nil } // StopTask marks the task as stopped in the DB and sets a Redis stop signal. func (s *TaskService) StopTask(taskID uint, force bool) error { var task model.Task if err := s.db.First(&task, taskID).Error; err != nil { return fmt.Errorf("task not found: %w", err) } // Set the Redis stop signal so the worker can detect it. stopKey := fmt.Sprintf("spider:task:stop:%d", taskID) if err := s.redis.Set(context.Background(), stopKey, "1", time.Hour).Err(); err != nil { return fmt.Errorf("set stop signal: %w", err) } // If force, immediately update the DB status. if force { finishedAt := time.Now() if err := s.db.Model(&model.Task{}).Where("id = ?", taskID).Updates(map[string]interface{}{ "status": "stopped", "finished_at": &finishedAt, }).Error; err != nil { return fmt.Errorf("update task stopped: %w", err) } } return nil } // GetProgress reads progress from Redis and returns a merged map. func (s *TaskService) GetProgress(task *model.Task) map[string]interface{} { result := make(map[string]interface{}) // Start with DB-stored progress if any. if len(task.Progress) > 0 { _ = json.Unmarshal(task.Progress, &result) } // Overlay with live Redis progress. progressKey := fmt.Sprintf("spider:task:progress:%d", task.ID) vals, err := s.redis.HGetAll(context.Background(), progressKey).Result() if err == nil && len(vals) > 0 { for k, v := range vals { result[k] = v } } return result } // IsStopRequested checks whether a stop signal has been set for the given task. func (s *TaskService) IsStopRequested(taskID uint) bool { stopKey := fmt.Sprintf("spider:task:stop:%d", taskID) val, err := s.redis.Get(context.Background(), stopKey).Result() if err != nil { return false } return val == "1" }