task_service.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. package service
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "time"
  6. "github.com/hibiken/asynq"
  7. "github.com/redis/go-redis/v9"
  8. "golang.org/x/net/context"
  9. "gorm.io/gorm"
  10. "spider/internal/model"
  11. "spider/internal/worker"
  12. )
  13. // StartTaskRequest is the payload for starting a new task.
  14. type StartTaskRequest struct {
  15. TaskType string `json:"task_type" binding:"required"`
  16. Target string `json:"target"`
  17. TestRun *worker.TestRun `json:"test_run"`
  18. SkipPhases []string `json:"skip_phases"`
  19. }
  20. // TaskService manages task lifecycle.
  21. type TaskService struct {
  22. db *gorm.DB
  23. redis *redis.Client
  24. client *asynq.Client
  25. }
  26. // NewTaskService creates a TaskService. The asynq.Client is constructed from the
  27. // same Redis client options used by the rest of the application.
  28. func NewTaskService(db *gorm.DB, rdb *redis.Client) *TaskService {
  29. opts := rdb.Options()
  30. client := asynq.NewClient(asynq.RedisClientOpt{
  31. Addr: opts.Addr,
  32. Password: opts.Password,
  33. DB: opts.DB,
  34. })
  35. return &TaskService{
  36. db: db,
  37. redis: rdb,
  38. client: client,
  39. }
  40. }
  41. // asynqTypeForTaskType maps model task_type to asynq task type constant.
  42. func asynqTypeForTaskType(taskType string) (string, error) {
  43. m := map[string]string{
  44. "full": worker.TypeFullPipeline,
  45. "discover": worker.TypeDiscover,
  46. "search": worker.TypeSearch,
  47. "github": worker.TypeGithub,
  48. "scrape": worker.TypeScrape,
  49. "crawl": worker.TypeCrawl,
  50. "clean": worker.TypeClean,
  51. "score": worker.TypeScore,
  52. }
  53. at, ok := m[taskType]
  54. if !ok {
  55. return "", fmt.Errorf("unknown task type: %s", taskType)
  56. }
  57. return at, nil
  58. }
  59. // StartTask validates, creates a Task record, and enqueues it via asynq.
  60. func (s *TaskService) StartTask(req StartTaskRequest) (*model.Task, error) {
  61. // Check if a task of the same type is already running.
  62. var count int64
  63. if err := s.db.Model(&model.Task{}).
  64. Where("task_type = ? AND status = ?", req.TaskType, "running").
  65. Count(&count).Error; err != nil {
  66. return nil, fmt.Errorf("check running tasks: %w", err)
  67. }
  68. if count > 0 {
  69. return nil, fmt.Errorf("a %s task is already running", req.TaskType)
  70. }
  71. // Validate and get asynq type.
  72. asynqType, err := asynqTypeForTaskType(req.TaskType)
  73. if err != nil {
  74. return nil, err
  75. }
  76. // Encode params.
  77. paramsJSON, err := json.Marshal(req)
  78. if err != nil {
  79. return nil, fmt.Errorf("marshal params: %w", err)
  80. }
  81. // Create Task record in DB.
  82. task := &model.Task{
  83. TaskType: req.TaskType,
  84. Status: "pending",
  85. Params: paramsJSON,
  86. }
  87. if err := s.db.Create(task).Error; err != nil {
  88. return nil, fmt.Errorf("create task record: %w", err)
  89. }
  90. // Build asynq payload.
  91. payload := worker.TaskPayload{
  92. TaskID: task.ID,
  93. Target: req.Target,
  94. TestRun: req.TestRun,
  95. SkipPhases: req.SkipPhases,
  96. }
  97. payloadBytes, err := json.Marshal(payload)
  98. if err != nil {
  99. return nil, fmt.Errorf("marshal payload: %w", err)
  100. }
  101. // Enqueue.
  102. asynqTask := asynq.NewTask(asynqType, payloadBytes, asynq.Queue(worker.QueueDefault))
  103. if _, err := s.client.Enqueue(asynqTask); err != nil {
  104. // Roll back the DB record to failed.
  105. s.db.Model(task).Updates(map[string]interface{}{"status": "failed", "error_msg": err.Error()})
  106. return nil, fmt.Errorf("enqueue task: %w", err)
  107. }
  108. return task, nil
  109. }
  110. // StopTask marks the task as stopped in the DB and sets a Redis stop signal.
  111. func (s *TaskService) StopTask(taskID uint, force bool) error {
  112. var task model.Task
  113. if err := s.db.First(&task, taskID).Error; err != nil {
  114. return fmt.Errorf("task not found: %w", err)
  115. }
  116. // Set the Redis stop signal so the worker can detect it.
  117. stopKey := fmt.Sprintf("spider:task:stop:%d", taskID)
  118. if err := s.redis.Set(context.Background(), stopKey, "1", time.Hour).Err(); err != nil {
  119. return fmt.Errorf("set stop signal: %w", err)
  120. }
  121. // If force, immediately update the DB status.
  122. if force {
  123. finishedAt := time.Now()
  124. if err := s.db.Model(&model.Task{}).Where("id = ?", taskID).Updates(map[string]interface{}{
  125. "status": "stopped",
  126. "finished_at": &finishedAt,
  127. }).Error; err != nil {
  128. return fmt.Errorf("update task stopped: %w", err)
  129. }
  130. }
  131. return nil
  132. }
  133. // GetProgress reads progress from Redis and returns a merged map.
  134. func (s *TaskService) GetProgress(task *model.Task) map[string]interface{} {
  135. result := make(map[string]interface{})
  136. // Start with DB-stored progress if any.
  137. if len(task.Progress) > 0 {
  138. _ = json.Unmarshal(task.Progress, &result)
  139. }
  140. // Overlay with live Redis progress.
  141. progressKey := fmt.Sprintf("spider:task:progress:%d", task.ID)
  142. vals, err := s.redis.HGetAll(context.Background(), progressKey).Result()
  143. if err == nil && len(vals) > 0 {
  144. for k, v := range vals {
  145. result[k] = v
  146. }
  147. }
  148. return result
  149. }
  150. // IsStopRequested checks whether a stop signal has been set for the given task.
  151. func (s *TaskService) IsStopRequested(taskID uint) bool {
  152. stopKey := fmt.Sprintf("spider:task:stop:%d", taskID)
  153. val, err := s.redis.Get(context.Background(), stopKey).Result()
  154. if err != nil {
  155. return false
  156. }
  157. return val == "1"
  158. }