| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181 |
- package service
- import (
- "encoding/json"
- "fmt"
- "time"
- "github.com/hibiken/asynq"
- "github.com/redis/go-redis/v9"
- "golang.org/x/net/context"
- "gorm.io/gorm"
- "spider/internal/model"
- "spider/internal/worker"
- )
// StartTaskRequest is the payload for starting a new task.
//
// StartTask JSON-encodes the whole struct and stores it as the task's Params,
// so the field order and json tags below define the stored shape.
type StartTaskRequest struct {
	// TaskType selects what to run. StartTask accepts "full", "discover",
	// "search", "github", "scrape", "crawl", "clean" and "score".
	// NOTE(review): the binding tag presumably makes the HTTP binder (gin)
	// reject requests without it — confirm.
	TaskType   string          `json:"task_type" binding:"required"`
	// Target is forwarded unchanged to the worker in the asynq payload.
	Target     string          `json:"target"`
	// TestRun is forwarded unchanged to the worker; nil when the request omits it.
	TestRun    *worker.TestRun `json:"test_run"`
	// SkipPhases is forwarded unchanged to the worker.
	SkipPhases []string        `json:"skip_phases"`
}
// TaskService manages task lifecycle: it persists Task rows, enqueues them via
// asynq, and exchanges stop signals and live progress with workers via Redis.
type TaskService struct {
	db     *gorm.DB      // persistence for model.Task rows
	redis  *redis.Client // stop flags ("spider:task:stop:<id>") and progress hashes
	client *asynq.Client // enqueues tasks; built from the same Redis options as redis
}
- // NewTaskService creates a TaskService. The asynq.Client is constructed from the
- // same Redis client options used by the rest of the application.
- func NewTaskService(db *gorm.DB, rdb *redis.Client) *TaskService {
- opts := rdb.Options()
- client := asynq.NewClient(asynq.RedisClientOpt{
- Addr: opts.Addr,
- Password: opts.Password,
- DB: opts.DB,
- })
- return &TaskService{
- db: db,
- redis: rdb,
- client: client,
- }
- }
- // asynqTypeForTaskType maps model task_type to asynq task type constant.
- func asynqTypeForTaskType(taskType string) (string, error) {
- m := map[string]string{
- "full": worker.TypeFullPipeline,
- "discover": worker.TypeDiscover,
- "search": worker.TypeSearch,
- "github": worker.TypeGithub,
- "scrape": worker.TypeScrape,
- "crawl": worker.TypeCrawl,
- "clean": worker.TypeClean,
- "score": worker.TypeScore,
- }
- at, ok := m[taskType]
- if !ok {
- return "", fmt.Errorf("unknown task type: %s", taskType)
- }
- return at, nil
- }
- // StartTask validates, creates a Task record, and enqueues it via asynq.
- func (s *TaskService) StartTask(req StartTaskRequest) (*model.Task, error) {
- // Check if a task of the same type is already running.
- var count int64
- if err := s.db.Model(&model.Task{}).
- Where("task_type = ? AND status = ?", req.TaskType, "running").
- Count(&count).Error; err != nil {
- return nil, fmt.Errorf("check running tasks: %w", err)
- }
- if count > 0 {
- return nil, fmt.Errorf("a %s task is already running", req.TaskType)
- }
- // Validate and get asynq type.
- asynqType, err := asynqTypeForTaskType(req.TaskType)
- if err != nil {
- return nil, err
- }
- // Encode params.
- paramsJSON, err := json.Marshal(req)
- if err != nil {
- return nil, fmt.Errorf("marshal params: %w", err)
- }
- // Create Task record in DB.
- task := &model.Task{
- TaskType: req.TaskType,
- Status: "pending",
- Params: paramsJSON,
- }
- if err := s.db.Create(task).Error; err != nil {
- return nil, fmt.Errorf("create task record: %w", err)
- }
- // Build asynq payload.
- payload := worker.TaskPayload{
- TaskID: task.ID,
- Target: req.Target,
- TestRun: req.TestRun,
- SkipPhases: req.SkipPhases,
- }
- payloadBytes, err := json.Marshal(payload)
- if err != nil {
- return nil, fmt.Errorf("marshal payload: %w", err)
- }
- // Enqueue.
- asynqTask := asynq.NewTask(asynqType, payloadBytes, asynq.Queue(worker.QueueDefault))
- if _, err := s.client.Enqueue(asynqTask); err != nil {
- // Roll back the DB record to failed.
- s.db.Model(task).Updates(map[string]interface{}{"status": "failed", "error_msg": err.Error()})
- return nil, fmt.Errorf("enqueue task: %w", err)
- }
- return task, nil
- }
- // StopTask marks the task as stopped in the DB and sets a Redis stop signal.
- func (s *TaskService) StopTask(taskID uint, force bool) error {
- var task model.Task
- if err := s.db.First(&task, taskID).Error; err != nil {
- return fmt.Errorf("task not found: %w", err)
- }
- // Set the Redis stop signal so the worker can detect it.
- stopKey := fmt.Sprintf("spider:task:stop:%d", taskID)
- if err := s.redis.Set(context.Background(), stopKey, "1", time.Hour).Err(); err != nil {
- return fmt.Errorf("set stop signal: %w", err)
- }
- // If force, immediately update the DB status.
- if force {
- finishedAt := time.Now()
- if err := s.db.Model(&model.Task{}).Where("id = ?", taskID).Updates(map[string]interface{}{
- "status": "stopped",
- "finished_at": &finishedAt,
- }).Error; err != nil {
- return fmt.Errorf("update task stopped: %w", err)
- }
- }
- return nil
- }
- // GetProgress reads progress from Redis and returns a merged map.
- func (s *TaskService) GetProgress(task *model.Task) map[string]interface{} {
- result := make(map[string]interface{})
- // Start with DB-stored progress if any.
- if len(task.Progress) > 0 {
- _ = json.Unmarshal(task.Progress, &result)
- }
- // Overlay with live Redis progress.
- progressKey := fmt.Sprintf("spider:task:progress:%d", task.ID)
- vals, err := s.redis.HGetAll(context.Background(), progressKey).Result()
- if err == nil && len(vals) > 0 {
- for k, v := range vals {
- result[k] = v
- }
- }
- return result
- }
- // IsStopRequested checks whether a stop signal has been set for the given task.
- func (s *TaskService) IsStopRequested(taskID uint) bool {
- stopKey := fmt.Sprintf("spider:task:stop:%d", taskID)
- val, err := s.redis.Get(context.Background(), stopKey).Result()
- if err != nil {
- return false
- }
- return val == "1"
- }
|