// worker.go — asynq worker: distributed locking, progress/log reporting, and
// task lifecycle management for the spider pipeline.
  1. package worker
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"time"

	"github.com/hibiken/asynq"
	"github.com/redis/go-redis/v9"
	"gorm.io/gorm"

	"spider/internal/llm"
	"spider/internal/model"
	"spider/internal/pipeline"
	"spider/internal/search"
	"spider/internal/telegram"
)
// Queue and asynq task-type identifiers. These strings must match what the
// enqueuing side uses when creating tasks.
const (
	// QueueDefault is the single asynq queue this worker consumes.
	QueueDefault = "default"
	// TypeFullPipeline runs the whole pipeline; the remaining types each
	// correspond to one phase (see taskTypeFromAsynqType for the mapping
	// to model task_type values).
	TypeFullPipeline = "task:full"
	TypeDiscover     = "task:discover"
	TypeSearch       = "task:search"
	TypeGithub       = "task:github"
	TypeScrape       = "task:scrape"
	TypeCrawl        = "task:crawl"
	TypeClean        = "task:clean"
	TypeScore        = "task:score"
)
  28. // lockKeyForType returns the Redis lock key for a given task type.
  29. func lockKeyForType(taskType string) string {
  30. if taskType == "full" {
  31. return "spider:task:lock:global"
  32. }
  33. return fmt.Sprintf("spider:task:lock:%s", taskType)
  34. }
  35. // progressKey returns the Redis hash key for task progress.
  36. func progressKey(taskID uint) string {
  37. return fmt.Sprintf("spider:task:progress:%d", taskID)
  38. }
  39. // stopKey returns the Redis key used to signal stop for a task.
  40. func stopKey(taskID uint) string {
  41. return fmt.Sprintf("spider:task:stop:%d", taskID)
  42. }
// TaskPayload is the asynq task payload decoded in processTask.
type TaskPayload struct {
	// TaskID is the primary key of the model.Task row this job updates.
	TaskID uint `json:"task_id"`
	// Target is copied into pipeline.Options.Target when running.
	Target string `json:"target,omitempty"`
	// TestRun, when non-nil, is copied into pipeline.Options.TestRun to
	// limit how much work the run performs.
	TestRun *TestRun `json:"test_run,omitempty"`
	// SkipPhases lists phase identifiers the pipeline should skip
	// (forwarded to pipeline.Options.SkipPhases).
	SkipPhases []string `json:"skip_phases,omitempty"`
}
// TestRun limits items processed during a test run. Both fields are copied
// verbatim into pipeline.TestRun; their exact semantics are defined by the
// pipeline phases.
type TestRun struct {
	// ItemLimit caps the number of items processed.
	ItemLimit int `json:"item_limit"`
	// MessageLimit caps the number of messages processed.
	MessageLimit int `json:"message_limit"`
}
// Worker wraps the asynq server together with every dependency the task
// handler needs.
type Worker struct {
	server *asynq.Server   // asynq consumer; run by Start, shut down by Stop
	mux    *asynq.ServeMux // routes registered task types to processTask
	db     *gorm.DB        // stores model.Task status/result rows
	redis  *redis.Client   // locks, progress hashes, log lists, stop flags
	// Dependencies held for pipeline construction and phase execution.
	tgManager    *telegram.AccountManager
	llmClient    *llm.Client
	settings     pipeline.Settings
	serperClient *search.SerperClient
	githubToken  string
	pipeline     *pipeline.Runner // shared runner with all phases registered
}
  68. // New creates and configures a new Worker.
  69. func New(redisAddr, redisPassword string, redisDB int, db *gorm.DB, rdb *redis.Client, tgManager *telegram.AccountManager, llmClient *llm.Client, settings pipeline.Settings, serperClient *search.SerperClient, githubToken string) *Worker {
  70. srv := asynq.NewServer(
  71. asynq.RedisClientOpt{
  72. Addr: redisAddr,
  73. Password: redisPassword,
  74. DB: redisDB,
  75. },
  76. asynq.Config{
  77. Concurrency: 4,
  78. Queues: map[string]int{
  79. QueueDefault: 10,
  80. },
  81. },
  82. )
  83. runner := pipeline.NewRunner(db, rdb)
  84. runner.RegisterPhase(pipeline.NewDiscoverPhase(db, tgManager, settings))
  85. runner.RegisterPhase(pipeline.NewSearchPhase(db, serperClient, settings))
  86. runner.RegisterPhase(pipeline.NewGithubPhase(db, githubToken, settings))
  87. runner.RegisterPhase(pipeline.NewScrapePhase(db, tgManager, llmClient, settings, rdb))
  88. runner.RegisterPhase(pipeline.NewCrawlPhase(db, llmClient, settings))
  89. runner.RegisterPhase(pipeline.NewCleanPhase(db, tgManager, settings))
  90. runner.RegisterPhase(pipeline.NewScorePhase(db))
  91. w := &Worker{
  92. server: srv,
  93. mux: asynq.NewServeMux(),
  94. db: db,
  95. redis: rdb,
  96. tgManager: tgManager,
  97. llmClient: llmClient,
  98. settings: settings,
  99. serperClient: serperClient,
  100. githubToken: githubToken,
  101. pipeline: runner,
  102. }
  103. // Register all task types to the same generic handler.
  104. w.mux.HandleFunc(TypeFullPipeline, w.processTask)
  105. w.mux.HandleFunc(TypeDiscover, w.processTask)
  106. w.mux.HandleFunc(TypeSearch, w.processTask)
  107. w.mux.HandleFunc(TypeGithub, w.processTask)
  108. w.mux.HandleFunc(TypeScrape, w.processTask)
  109. w.mux.HandleFunc(TypeCrawl, w.processTask)
  110. w.mux.HandleFunc(TypeClean, w.processTask)
  111. w.mux.HandleFunc(TypeScore, w.processTask)
  112. return w
  113. }
  114. // acquireLock tries to acquire a Redis SET NX EX lock. Returns true on success.
  115. func (w *Worker) acquireLock(ctx context.Context, lockKey string) bool {
  116. ok, err := w.redis.SetNX(ctx, lockKey, "1", 24*time.Hour).Result()
  117. if err != nil {
  118. log.Printf("[worker] acquireLock error key=%s: %v", lockKey, err)
  119. return false
  120. }
  121. return ok
  122. }
  123. // releaseLock deletes the Redis lock key.
  124. func (w *Worker) releaseLock(ctx context.Context, lockKey string) {
  125. if err := w.redis.Del(ctx, lockKey).Err(); err != nil {
  126. log.Printf("[worker] releaseLock error key=%s: %v", lockKey, err)
  127. }
  128. }
  129. // writeLog appends a timestamped log line to the Redis list for this task.
  130. // Keeps only the last 500 entries and sets a 24-hour TTL.
  131. func (w *Worker) writeLog(ctx context.Context, taskID uint, msg string) {
  132. key := fmt.Sprintf("spider:task:logs:%d", taskID)
  133. ts := time.Now().Format("15:04:05")
  134. line := fmt.Sprintf("[%s] %s", ts, msg)
  135. w.redis.RPush(ctx, key, line)
  136. w.redis.LTrim(ctx, key, -500, -1)
  137. w.redis.Expire(ctx, key, 24*time.Hour)
  138. }
  139. // writeProgress writes task progress fields to Redis.
  140. func (w *Worker) writeProgress(ctx context.Context, taskID uint, phase string, current, total int, message string) {
  141. key := progressKey(taskID)
  142. now := time.Now().UTC().Format(time.RFC3339)
  143. err := w.redis.HSet(ctx, key,
  144. "phase", phase,
  145. "current", current,
  146. "total", total,
  147. "message", message,
  148. "updated_at", now,
  149. ).Err()
  150. if err != nil {
  151. log.Printf("[worker] writeProgress error task=%d: %v", taskID, err)
  152. return
  153. }
  154. w.redis.Expire(ctx, key, 24*time.Hour)
  155. }
  156. // isStopRequested checks whether a stop signal has been set for this task.
  157. func (w *Worker) isStopRequested(ctx context.Context, taskID uint) bool {
  158. val, err := w.redis.Get(ctx, stopKey(taskID)).Result()
  159. if err != nil {
  160. return false
  161. }
  162. return val == "1"
  163. }
  164. // taskTypeFromAsynqType converts an asynq type string to the model task_type value.
  165. func taskTypeFromAsynqType(asynqType string) string {
  166. switch asynqType {
  167. case TypeFullPipeline:
  168. return "full"
  169. case TypeDiscover:
  170. return "discover"
  171. case TypeSearch:
  172. return "search"
  173. case TypeGithub:
  174. return "github"
  175. case TypeScrape:
  176. return "scrape"
  177. case TypeCrawl:
  178. return "crawl"
  179. case TypeClean:
  180. return "clean"
  181. case TypeScore:
  182. return "score"
  183. default:
  184. return asynqType
  185. }
  186. }
// processTask is the core handler invoked for every registered task type.
//
// Lifecycle: decode the payload, take the per-type (or global, for "full")
// Redis lock, mark the DB row "running", wire a progress reporter into the
// shared pipeline runner, run the pipeline, then persist the final status:
// "failed", "stopped", or "completed". Returning a non-nil error hands the
// task back to asynq's retry machinery.
func (w *Worker) processTask(ctx context.Context, t *asynq.Task) error {
	var payload TaskPayload
	if err := json.Unmarshal(t.Payload(), &payload); err != nil {
		return fmt.Errorf("unmarshal payload: %w", err)
	}
	taskID := payload.TaskID
	taskType := taskTypeFromAsynqType(t.Type())
	lockKey := lockKeyForType(taskType)
	log.Printf("[worker] processing task id=%d type=%s", taskID, taskType)
	// Acquire distributed lock. Failing here returns an error, so asynq will
	// retry the task later instead of running two of the same type at once.
	if !w.acquireLock(ctx, lockKey) {
		return fmt.Errorf("another %s task is already running, skipping", taskType)
	}
	defer w.releaseLock(ctx, lockKey)
	// 1. Update task status → running.
	now := time.Now()
	if err := w.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", taskID).Updates(map[string]interface{}{
		"status":     "running",
		"started_at": &now,
	}).Error; err != nil {
		return fmt.Errorf("update task running: %w", err)
	}
	// 2. Write initial progress to Redis and log task start.
	w.writeProgress(ctx, taskID, taskType, 0, 0, "任务启动中...")
	w.writeLog(ctx, taskID, fmt.Sprintf("任务开始: %s (id=%d)", taskType, taskID))
	// 3. Fetch the full task record for the pipeline.
	var task model.Task
	if err := w.db.WithContext(ctx).First(&task, taskID).Error; err != nil {
		return fmt.Errorf("fetch task record: %w", err)
	}
	// 4. Build pipeline options from payload.
	opts := &pipeline.Options{
		Target:     payload.Target,
		SkipPhases: payload.SkipPhases,
	}
	if payload.TestRun != nil {
		opts.TestRun = &pipeline.TestRun{
			ItemLimit:    payload.TestRun.ItemLimit,
			MessageLimit: payload.TestRun.MessageLimit,
		}
	}
	// Wire progress reporter so pipeline phases report through writeProgress and writeLog.
	// NOTE(review): w.pipeline is shared across handler goroutines
	// (Concurrency is 4) while this reporter closure is per-task; two
	// different task types running concurrently would overwrite each
	// other's reporter — confirm Runner copies or serializes this.
	w.pipeline.SetProgressReporter(func(phase string, current, total int, message string) {
		// Also check for stop signal on each progress report.
		// NOTE(review): a stop request detected here is only logged; nothing
		// aborts the pipeline until the post-run check below — confirm this
		// is intended rather than an unfinished cancellation path.
		if w.isStopRequested(ctx, taskID) {
			log.Printf("[worker] task %d stop requested during phase=%s", taskID, phase)
		}
		w.writeProgress(ctx, taskID, phase, current, total, message)
		if message != "" {
			w.writeLog(ctx, taskID, fmt.Sprintf("[%s] %d/%d %s", phase, current, total, message))
		}
	})
	// 5. Run the pipeline. For full tasks, phase failures are logged but non-fatal.
	if pipelineErr := w.pipeline.Run(ctx, &task, opts); pipelineErr != nil {
		// Single-phase tasks propagate errors; full-pipeline errors are already handled inside Run.
		log.Printf("[worker] pipeline error task=%d: %v", taskID, pipelineErr)
		errorTime := time.Now()
		// Best-effort status write: a DB error here is not surfaced because
		// the pipeline error below is the one worth returning to asynq.
		w.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", taskID).Updates(map[string]interface{}{
			"status":      "failed",
			"finished_at": &errorTime,
			"error_msg":   pipelineErr.Error(),
		})
		w.writeProgress(ctx, taskID, taskType, 0, 0, "任务失败: "+pipelineErr.Error())
		w.writeLog(ctx, taskID, "任务失败: "+pipelineErr.Error())
		return pipelineErr
	}
	// Check for stop request after pipeline finishes.
	// NOTE(review): the stop key is not cleared here — verify whoever sets
	// it applies a TTL, or a stale flag could mark a later run as stopped.
	if w.isStopRequested(ctx, taskID) {
		log.Printf("[worker] task %d stop requested", taskID)
		stopTime := time.Now()
		w.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", taskID).Updates(map[string]interface{}{
			"status":      "stopped",
			"finished_at": &stopTime,
		})
		w.writeProgress(ctx, taskID, taskType, 0, 0, "任务已停止")
		w.writeLog(ctx, taskID, "任务已停止")
		return nil
	}
	// 6. Mark task as completed.
	finishedAt := time.Now()
	// Marshaling a fixed map of string values cannot fail, so the error is
	// deliberately discarded.
	resultJSON, _ := json.Marshal(map[string]interface{}{"message": "task completed successfully"})
	if err := w.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", taskID).Updates(map[string]interface{}{
		"status":      "completed",
		"finished_at": &finishedAt,
		"result":      resultJSON,
	}).Error; err != nil {
		return fmt.Errorf("update task completed: %w", err)
	}
	w.writeProgress(ctx, taskID, taskType, 100, 100, "任务完成")
	w.writeLog(ctx, taskID, "任务完成")
	log.Printf("[worker] task %d completed", taskID)
	return nil
}
// Start runs the asynq server with the worker's mux. It blocks until the
// server stops and returns whatever error asynq's Run reports.
func (w *Worker) Start() error {
	return w.server.Run(w.mux)
}
// Stop gracefully shuts down the asynq server, unblocking Start.
func (w *Worker) Stop() {
	w.server.Shutdown()
}