| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325 |
- package worker
- import (
- "context"
- "encoding/json"
- "fmt"
- "log"
- "time"
- "github.com/hibiken/asynq"
- "github.com/redis/go-redis/v9"
- "gorm.io/gorm"
- "spider/internal/llm"
- "spider/internal/model"
- "spider/internal/pipeline"
- "spider/internal/search"
- "spider/internal/telegram"
- )
// Queue name and asynq task-type identifiers. Each Type* constant is routed
// to the same generic handler (see New); taskTypeFromAsynqType converts these
// into the short model-level names ("full", "discover", ...).
const (
	// QueueDefault is the only queue this worker consumes.
	QueueDefault = "default"

	TypeFullPipeline = "task:full"
	TypeDiscover     = "task:discover"
	TypeSearch       = "task:search"
	TypeGithub       = "task:github"
	TypeScrape       = "task:scrape"
	TypeCrawl        = "task:crawl"
	TypeClean        = "task:clean"
	TypeScore        = "task:score"
)
// lockKeyForType returns the Redis key guarding concurrent execution of a
// task type. Full-pipeline runs share one global lock; every other type is
// locked under its own per-type key.
func lockKeyForType(taskType string) string {
	if taskType != "full" {
		return fmt.Sprintf("spider:task:lock:%s", taskType)
	}
	return "spider:task:lock:global"
}
// progressKey returns the Redis hash key under which progress for the given
// task is stored.
func progressKey(taskID uint) string {
	const prefix = "spider:task:progress:"
	return fmt.Sprintf("%s%d", prefix, taskID)
}
// stopKey returns the Redis key polled for a task's stop signal.
func stopKey(taskID uint) string {
	const prefix = "spider:task:stop:"
	return fmt.Sprintf("%s%d", prefix, taskID)
}
// TaskPayload is the JSON payload carried by every asynq task this worker
// handles.
type TaskPayload struct {
	TaskID     uint     `json:"task_id"`               // primary key of the model.Task row to process
	Target     string   `json:"target,omitempty"`      // optional target, forwarded to pipeline.Options
	TestRun    *TestRun `json:"test_run,omitempty"`    // non-nil enables a limited test run
	SkipPhases []string `json:"skip_phases,omitempty"` // phases to skip, forwarded to pipeline.Options
}
// TestRun limits how much work is processed during a test run; the limits are
// copied into pipeline.TestRun before the pipeline executes.
type TestRun struct {
	ItemLimit    int `json:"item_limit"`    // max items to process
	MessageLimit int `json:"message_limit"` // max messages to process
}
// Worker wraps the asynq server together with every dependency the task
// handler needs.
type Worker struct {
	server       *asynq.Server   // blocking task consumer; see Start/Stop
	mux          *asynq.ServeMux // routes all task types to processTask
	db           *gorm.DB        // task rows are read and updated here
	redis        *redis.Client   // locks, progress hashes, log lists, stop signals
	tgManager    *telegram.AccountManager
	llmClient    *llm.Client
	settings     pipeline.Settings
	serperClient *search.SerperClient
	githubToken  string
	pipeline     *pipeline.Runner // runs the registered phases for a task
}
- // New creates and configures a new Worker.
- func New(redisAddr, redisPassword string, redisDB int, db *gorm.DB, rdb *redis.Client, tgManager *telegram.AccountManager, llmClient *llm.Client, settings pipeline.Settings, serperClient *search.SerperClient, githubToken string) *Worker {
- srv := asynq.NewServer(
- asynq.RedisClientOpt{
- Addr: redisAddr,
- Password: redisPassword,
- DB: redisDB,
- },
- asynq.Config{
- Concurrency: 4,
- Queues: map[string]int{
- QueueDefault: 10,
- },
- },
- )
- runner := pipeline.NewRunner(db, rdb)
- runner.RegisterPhase(pipeline.NewDiscoverPhase(db, tgManager, settings))
- runner.RegisterPhase(pipeline.NewSearchPhase(db, serperClient, settings))
- runner.RegisterPhase(pipeline.NewGithubPhase(db, githubToken, settings))
- runner.RegisterPhase(pipeline.NewScrapePhase(db, tgManager, llmClient, settings, rdb))
- runner.RegisterPhase(pipeline.NewCrawlPhase(db, llmClient, settings))
- runner.RegisterPhase(pipeline.NewCleanPhase(db, tgManager, settings))
- runner.RegisterPhase(pipeline.NewScorePhase(db))
- w := &Worker{
- server: srv,
- mux: asynq.NewServeMux(),
- db: db,
- redis: rdb,
- tgManager: tgManager,
- llmClient: llmClient,
- settings: settings,
- serperClient: serperClient,
- githubToken: githubToken,
- pipeline: runner,
- }
- // Register all task types to the same generic handler.
- w.mux.HandleFunc(TypeFullPipeline, w.processTask)
- w.mux.HandleFunc(TypeDiscover, w.processTask)
- w.mux.HandleFunc(TypeSearch, w.processTask)
- w.mux.HandleFunc(TypeGithub, w.processTask)
- w.mux.HandleFunc(TypeScrape, w.processTask)
- w.mux.HandleFunc(TypeCrawl, w.processTask)
- w.mux.HandleFunc(TypeClean, w.processTask)
- w.mux.HandleFunc(TypeScore, w.processTask)
- return w
- }
- // acquireLock tries to acquire a Redis SET NX EX lock. Returns true on success.
- func (w *Worker) acquireLock(ctx context.Context, lockKey string) bool {
- ok, err := w.redis.SetNX(ctx, lockKey, "1", 24*time.Hour).Result()
- if err != nil {
- log.Printf("[worker] acquireLock error key=%s: %v", lockKey, err)
- return false
- }
- return ok
- }
- // releaseLock deletes the Redis lock key.
- func (w *Worker) releaseLock(ctx context.Context, lockKey string) {
- if err := w.redis.Del(ctx, lockKey).Err(); err != nil {
- log.Printf("[worker] releaseLock error key=%s: %v", lockKey, err)
- }
- }
- // writeLog appends a timestamped log line to the Redis list for this task.
- // Keeps only the last 500 entries and sets a 24-hour TTL.
- func (w *Worker) writeLog(ctx context.Context, taskID uint, msg string) {
- key := fmt.Sprintf("spider:task:logs:%d", taskID)
- ts := time.Now().Format("15:04:05")
- line := fmt.Sprintf("[%s] %s", ts, msg)
- w.redis.RPush(ctx, key, line)
- w.redis.LTrim(ctx, key, -500, -1)
- w.redis.Expire(ctx, key, 24*time.Hour)
- }
- // writeProgress writes task progress fields to Redis.
- func (w *Worker) writeProgress(ctx context.Context, taskID uint, phase string, current, total int, message string) {
- key := progressKey(taskID)
- now := time.Now().UTC().Format(time.RFC3339)
- err := w.redis.HSet(ctx, key,
- "phase", phase,
- "current", current,
- "total", total,
- "message", message,
- "updated_at", now,
- ).Err()
- if err != nil {
- log.Printf("[worker] writeProgress error task=%d: %v", taskID, err)
- return
- }
- w.redis.Expire(ctx, key, 24*time.Hour)
- }
- // isStopRequested checks whether a stop signal has been set for this task.
- func (w *Worker) isStopRequested(ctx context.Context, taskID uint) bool {
- val, err := w.redis.Get(ctx, stopKey(taskID)).Result()
- if err != nil {
- return false
- }
- return val == "1"
- }
- // taskTypeFromAsynqType converts an asynq type string to the model task_type value.
- func taskTypeFromAsynqType(asynqType string) string {
- switch asynqType {
- case TypeFullPipeline:
- return "full"
- case TypeDiscover:
- return "discover"
- case TypeSearch:
- return "search"
- case TypeGithub:
- return "github"
- case TypeScrape:
- return "scrape"
- case TypeCrawl:
- return "crawl"
- case TypeClean:
- return "clean"
- case TypeScore:
- return "score"
- default:
- return asynqType
- }
- }
// processTask is the core handler invoked for every registered asynq task
// type. Flow: unmarshal payload → acquire the per-type distributed lock →
// mark the DB task row "running" → run the pipeline with a wired progress
// reporter → persist the terminal state ("failed", "stopped", or
// "completed") to both the DB and the Redis progress/log keys.
func (w *Worker) processTask(ctx context.Context, t *asynq.Task) error {
	var payload TaskPayload
	if err := json.Unmarshal(t.Payload(), &payload); err != nil {
		return fmt.Errorf("unmarshal payload: %w", err)
	}
	taskID := payload.TaskID
	taskType := taskTypeFromAsynqType(t.Type())
	lockKey := lockKeyForType(taskType)
	log.Printf("[worker] processing task id=%d type=%s", taskID, taskType)
	// Acquire distributed lock. Returning an error here hands the task back
	// to asynq as failed (it may retry per queue policy).
	if !w.acquireLock(ctx, lockKey) {
		return fmt.Errorf("another %s task is already running, skipping", taskType)
	}
	defer w.releaseLock(ctx, lockKey)
	// 1. Update task status → running.
	now := time.Now()
	if err := w.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", taskID).Updates(map[string]interface{}{
		"status":     "running",
		"started_at": &now,
	}).Error; err != nil {
		return fmt.Errorf("update task running: %w", err)
	}
	// 2. Write initial progress to Redis and log task start.
	w.writeProgress(ctx, taskID, taskType, 0, 0, "任务启动中...")
	w.writeLog(ctx, taskID, fmt.Sprintf("任务开始: %s (id=%d)", taskType, taskID))
	// 3. Fetch the full task record for the pipeline.
	var task model.Task
	if err := w.db.WithContext(ctx).First(&task, taskID).Error; err != nil {
		return fmt.Errorf("fetch task record: %w", err)
	}
	// 4. Build pipeline options from payload.
	opts := &pipeline.Options{
		Target:     payload.Target,
		SkipPhases: payload.SkipPhases,
	}
	if payload.TestRun != nil {
		opts.TestRun = &pipeline.TestRun{
			ItemLimit:    payload.TestRun.ItemLimit,
			MessageLimit: payload.TestRun.MessageLimit,
		}
	}
	// Wire progress reporter so pipeline phases report through writeProgress
	// and writeLog.
	// NOTE(review): w.pipeline is shared across all task types, so two
	// concurrently running tasks of different types would race on
	// SetProgressReporter and cross-wire their progress — confirm Runner
	// serializes runs, or move the reporter into Options.
	w.pipeline.SetProgressReporter(func(phase string, current, total int, message string) {
		// Also check for stop signal on each progress report.
		// NOTE(review): this only logs; it does not cancel the pipeline —
		// presumably Runner polls the stop key itself. Verify.
		if w.isStopRequested(ctx, taskID) {
			log.Printf("[worker] task %d stop requested during phase=%s", taskID, phase)
		}
		w.writeProgress(ctx, taskID, phase, current, total, message)
		if message != "" {
			w.writeLog(ctx, taskID, fmt.Sprintf("[%s] %d/%d %s", phase, current, total, message))
		}
	})
	// 5. Run the pipeline. For full tasks, phase failures are logged but non-fatal.
	if pipelineErr := w.pipeline.Run(ctx, &task, opts); pipelineErr != nil {
		// Single-phase tasks propagate errors; full-pipeline errors are already handled inside Run.
		log.Printf("[worker] pipeline error task=%d: %v", taskID, pipelineErr)
		errorTime := time.Now()
		// Best-effort status write: a DB error here is intentionally dropped
		// so that pipelineErr is what the caller (asynq) sees.
		w.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", taskID).Updates(map[string]interface{}{
			"status":      "failed",
			"finished_at": &errorTime,
			"error_msg":   pipelineErr.Error(),
		})
		w.writeProgress(ctx, taskID, taskType, 0, 0, "任务失败: "+pipelineErr.Error())
		w.writeLog(ctx, taskID, "任务失败: "+pipelineErr.Error())
		return pipelineErr
	}
	// Check for stop request after pipeline finishes.
	// NOTE(review): the stop key is never deleted in this handler; unless the
	// setter applies a TTL or clears it, a re-run of the same task ID would
	// immediately report "stopped" — confirm with the stop-signal producer.
	if w.isStopRequested(ctx, taskID) {
		log.Printf("[worker] task %d stop requested", taskID)
		stopTime := time.Now()
		w.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", taskID).Updates(map[string]interface{}{
			"status":      "stopped",
			"finished_at": &stopTime,
		})
		w.writeProgress(ctx, taskID, taskType, 0, 0, "任务已停止")
		w.writeLog(ctx, taskID, "任务已停止")
		return nil
	}
	// 6. Mark task as completed.
	finishedAt := time.Now()
	resultJSON, _ := json.Marshal(map[string]interface{}{"message": "task completed successfully"})
	if err := w.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", taskID).Updates(map[string]interface{}{
		"status":      "completed",
		"finished_at": &finishedAt,
		"result":      resultJSON,
	}).Error; err != nil {
		return fmt.Errorf("update task completed: %w", err)
	}
	w.writeProgress(ctx, taskID, taskType, 100, 100, "任务完成")
	w.writeLog(ctx, taskID, "任务完成")
	log.Printf("[worker] task %d completed", taskID)
	return nil
}
- // Start runs the asynq server (blocking).
- func (w *Worker) Start() error {
- return w.server.Run(w.mux)
- }
- // Stop gracefully shuts down the asynq server.
- func (w *Worker) Stop() {
- w.server.Shutdown()
- }
|