package worker

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"time"

	"github.com/hibiken/asynq"
	"github.com/redis/go-redis/v9"
	"gorm.io/gorm"

	"spider/internal/llm"
	"spider/internal/model"
	"spider/internal/pipeline"
	"spider/internal/search"
	"spider/internal/telegram"
)

const (
	QueueDefault = "default"

	TypeFullPipeline = "task:full"
	TypeDiscover     = "task:discover"
	TypeSearch       = "task:search"
	TypeGithub       = "task:github"
	TypeScrape       = "task:scrape"
	TypeCrawl        = "task:crawl"
	TypeClean        = "task:clean"
	TypeScore        = "task:score"
)

// lockKeyForType returns the Redis lock key for a given task type.
func lockKeyForType(taskType string) string {
	if taskType == "full" {
		return "spider:task:lock:global"
	}
	return fmt.Sprintf("spider:task:lock:%s", taskType)
}

// progressKey returns the Redis hash key for task progress.
func progressKey(taskID uint) string {
	return fmt.Sprintf("spider:task:progress:%d", taskID)
}

// stopKey returns the Redis key used to signal stop for a task.
func stopKey(taskID uint) string {
	return fmt.Sprintf("spider:task:stop:%d", taskID)
}

// TaskPayload is the asynq task payload.
type TaskPayload struct {
	TaskID     uint     `json:"task_id"`
	Target     string   `json:"target,omitempty"`
	TestRun    *TestRun `json:"test_run,omitempty"`
	SkipPhases []string `json:"skip_phases,omitempty"`
}

// TestRun limits items processed during a test run.
type TestRun struct {
	ItemLimit    int `json:"item_limit"`
	MessageLimit int `json:"message_limit"`
}

// Worker wraps the asynq server.
type Worker struct {
	server       *asynq.Server
	mux          *asynq.ServeMux
	db           *gorm.DB
	redis        *redis.Client
	tgManager    *telegram.AccountManager
	llmClient    *llm.Client
	settings     pipeline.Settings
	serperClient *search.SerperClient
	githubToken  string
	pipeline     *pipeline.Runner
}

// New creates and configures a new Worker.
func New(
	redisAddr, redisPassword string, redisDB int,
	db *gorm.DB, rdb *redis.Client,
	tgManager *telegram.AccountManager, llmClient *llm.Client,
	settings pipeline.Settings, serperClient *search.SerperClient, githubToken string,
) *Worker {
	srv := asynq.NewServer(
		asynq.RedisClientOpt{
			Addr:     redisAddr,
			Password: redisPassword,
			DB:       redisDB,
		},
		asynq.Config{
			Concurrency: 4,
			Queues: map[string]int{
				QueueDefault: 10,
			},
		},
	)

	runner := pipeline.NewRunner(db, rdb)
	runner.RegisterPhase(pipeline.NewDiscoverPhase(db, tgManager, settings))
	runner.RegisterPhase(pipeline.NewSearchPhase(db, serperClient, settings))
	runner.RegisterPhase(pipeline.NewGithubPhase(db, githubToken, settings))
	runner.RegisterPhase(pipeline.NewScrapePhase(db, tgManager, llmClient, settings, rdb))
	runner.RegisterPhase(pipeline.NewCrawlPhase(db, llmClient, settings))
	runner.RegisterPhase(pipeline.NewCleanPhase(db, tgManager, settings))
	runner.RegisterPhase(pipeline.NewScorePhase(db))

	w := &Worker{
		server:       srv,
		mux:          asynq.NewServeMux(),
		db:           db,
		redis:        rdb,
		tgManager:    tgManager,
		llmClient:    llmClient,
		settings:     settings,
		serperClient: serperClient,
		githubToken:  githubToken,
		pipeline:     runner,
	}

	// Register all task types to the same generic handler.
	w.mux.HandleFunc(TypeFullPipeline, w.processTask)
	w.mux.HandleFunc(TypeDiscover, w.processTask)
	w.mux.HandleFunc(TypeSearch, w.processTask)
	w.mux.HandleFunc(TypeGithub, w.processTask)
	w.mux.HandleFunc(TypeScrape, w.processTask)
	w.mux.HandleFunc(TypeCrawl, w.processTask)
	w.mux.HandleFunc(TypeClean, w.processTask)
	w.mux.HandleFunc(TypeScore, w.processTask)

	return w
}

// acquireLock tries to acquire a Redis SET NX EX lock. Returns true on success.
func (w *Worker) acquireLock(ctx context.Context, lockKey string) bool {
	ok, err := w.redis.SetNX(ctx, lockKey, "1", 24*time.Hour).Result()
	if err != nil {
		log.Printf("[worker] acquireLock error key=%s: %v", lockKey, err)
		return false
	}
	return ok
}

// releaseLock deletes the Redis lock key.
func (w *Worker) releaseLock(ctx context.Context, lockKey string) {
	if err := w.redis.Del(ctx, lockKey).Err(); err != nil {
		log.Printf("[worker] releaseLock error key=%s: %v", lockKey, err)
	}
}

// writeLog appends a timestamped log line to the Redis list for this task.
// Keeps only the last 500 entries and sets a 24-hour TTL.
func (w *Worker) writeLog(ctx context.Context, taskID uint, msg string) {
	key := fmt.Sprintf("spider:task:logs:%d", taskID)
	ts := time.Now().Format("15:04:05")
	line := fmt.Sprintf("[%s] %s", ts, msg)
	w.redis.RPush(ctx, key, line)
	w.redis.LTrim(ctx, key, -500, -1)
	w.redis.Expire(ctx, key, 24*time.Hour)
}

// writeProgress writes task progress fields to Redis.
func (w *Worker) writeProgress(ctx context.Context, taskID uint, phase string, current, total int, message string) {
	key := progressKey(taskID)
	now := time.Now().UTC().Format(time.RFC3339)
	err := w.redis.HSet(ctx, key,
		"phase", phase,
		"current", current,
		"total", total,
		"message", message,
		"updated_at", now,
	).Err()
	if err != nil {
		log.Printf("[worker] writeProgress error task=%d: %v", taskID, err)
		return
	}
	w.redis.Expire(ctx, key, 24*time.Hour)
}

// isStopRequested checks whether a stop signal has been set for this task.
func (w *Worker) isStopRequested(ctx context.Context, taskID uint) bool {
	val, err := w.redis.Get(ctx, stopKey(taskID)).Result()
	if err != nil {
		return false
	}
	return val == "1"
}

// taskTypeFromAsynqType converts an asynq type string to the model task_type value.
func taskTypeFromAsynqType(asynqType string) string {
	switch asynqType {
	case TypeFullPipeline:
		return "full"
	case TypeDiscover:
		return "discover"
	case TypeSearch:
		return "search"
	case TypeGithub:
		return "github"
	case TypeScrape:
		return "scrape"
	case TypeCrawl:
		return "crawl"
	case TypeClean:
		return "clean"
	case TypeScore:
		return "score"
	default:
		return asynqType
	}
}

// processTask is the core handler invoked for every registered task type.
func (w *Worker) processTask(ctx context.Context, t *asynq.Task) error {
	var payload TaskPayload
	if err := json.Unmarshal(t.Payload(), &payload); err != nil {
		return fmt.Errorf("unmarshal payload: %w", err)
	}
	taskID := payload.TaskID
	taskType := taskTypeFromAsynqType(t.Type())
	lockKey := lockKeyForType(taskType)

	log.Printf("[worker] processing task id=%d type=%s", taskID, taskType)

	// Acquire distributed lock.
	if !w.acquireLock(ctx, lockKey) {
		return fmt.Errorf("another %s task is already running, skipping", taskType)
	}
	defer w.releaseLock(ctx, lockKey)

	// 1. Update task status → running.
	now := time.Now()
	if err := w.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", taskID).Updates(map[string]interface{}{
		"status":     "running",
		"started_at": &now,
	}).Error; err != nil {
		return fmt.Errorf("update task running: %w", err)
	}

	// 2. Write initial progress to Redis and log task start.
	w.writeProgress(ctx, taskID, taskType, 0, 0, "Task starting...")
	w.writeLog(ctx, taskID, fmt.Sprintf("Task started: %s (id=%d)", taskType, taskID))

	// 3. Fetch the full task record for the pipeline.
	var task model.Task
	if err := w.db.WithContext(ctx).First(&task, taskID).Error; err != nil {
		return fmt.Errorf("fetch task record: %w", err)
	}

	// 4. Build pipeline options from payload.
	opts := &pipeline.Options{
		Target:     payload.Target,
		SkipPhases: payload.SkipPhases,
	}
	if payload.TestRun != nil {
		opts.TestRun = &pipeline.TestRun{
			ItemLimit:    payload.TestRun.ItemLimit,
			MessageLimit: payload.TestRun.MessageLimit,
		}
	}

	// Wire progress reporter so pipeline phases report through writeProgress and writeLog.
	w.pipeline.SetProgressReporter(func(phase string, current, total int, message string) {
		// Log stop requests observed mid-run; the stop itself is only enforced
		// after Run returns (see the check below).
		if w.isStopRequested(ctx, taskID) {
			log.Printf("[worker] task %d stop requested during phase=%s", taskID, phase)
		}
		w.writeProgress(ctx, taskID, phase, current, total, message)
		if message != "" {
			w.writeLog(ctx, taskID, fmt.Sprintf("[%s] %d/%d %s", phase, current, total, message))
		}
	})

	// 5. Run the pipeline. For full tasks, phase failures are logged but non-fatal.
	if pipelineErr := w.pipeline.Run(ctx, &task, opts); pipelineErr != nil {
		// Single-phase tasks propagate errors; full-pipeline errors are already handled inside Run.
		log.Printf("[worker] pipeline error task=%d: %v", taskID, pipelineErr)
		errorTime := time.Now()
		w.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", taskID).Updates(map[string]interface{}{
			"status":      "failed",
			"finished_at": &errorTime,
			"error_msg":   pipelineErr.Error(),
		})
		w.writeProgress(ctx, taskID, taskType, 0, 0, "Task failed: "+pipelineErr.Error())
		w.writeLog(ctx, taskID, "Task failed: "+pipelineErr.Error())
		return pipelineErr
	}

	// Check for stop request after pipeline finishes.
	if w.isStopRequested(ctx, taskID) {
		log.Printf("[worker] task %d stop requested", taskID)
		stopTime := time.Now()
		w.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", taskID).Updates(map[string]interface{}{
			"status":      "stopped",
			"finished_at": &stopTime,
		})
		w.writeProgress(ctx, taskID, taskType, 0, 0, "Task stopped")
		w.writeLog(ctx, taskID, "Task stopped")
		return nil
	}

	// 6. Mark task as completed.
	finishedAt := time.Now()
	resultJSON, _ := json.Marshal(map[string]interface{}{"message": "task completed successfully"})
	if err := w.db.WithContext(ctx).Model(&model.Task{}).Where("id = ?", taskID).Updates(map[string]interface{}{
		"status":      "completed",
		"finished_at": &finishedAt,
		"result":      resultJSON,
	}).Error; err != nil {
		return fmt.Errorf("update task completed: %w", err)
	}
	w.writeProgress(ctx, taskID, taskType, 100, 100, "Task completed")
	w.writeLog(ctx, taskID, "Task completed")
	log.Printf("[worker] task %d completed", taskID)
	return nil
}

// Start runs the asynq server (blocking).
func (w *Worker) Start() error {
	return w.server.Run(w.mux)
}

// Stop gracefully shuts down the asynq server.
func (w *Worker) Stop() {
	w.server.Shutdown()
}
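
// --- Illustrative producer-side helpers (not part of the original worker) ---
//
// The two sketches below show how a producer or API layer might interact with
// this worker: enqueueing a task whose payload matches TaskPayload, and setting
// the stop flag that processTask polls via isStopRequested. The function names
// and option choices here are assumptions for illustration, not an existing API
// in this package.

// EnqueueFullPipeline is a minimal sketch of enqueueing a full-pipeline task.
func EnqueueFullPipeline(client *asynq.Client, taskID uint) error {
	payload, err := json.Marshal(TaskPayload{TaskID: taskID})
	if err != nil {
		return err
	}
	// Route to the queue this worker consumes. MaxRetry(0) is an assumed choice
	// that avoids silently re-running a pipeline that failed partway through.
	_, err = client.Enqueue(
		asynq.NewTask(TypeFullPipeline, payload),
		asynq.Queue(QueueDefault),
		asynq.MaxRetry(0),
	)
	return err
}

// RequestStop is a minimal sketch of signaling a running task to stop. The
// 24-hour TTL mirrors the expiry used for the task's progress and log keys.
func RequestStop(ctx context.Context, rdb *redis.Client, taskID uint) error {
	return rdb.Set(ctx, stopKey(taskID), "1", 24*time.Hour).Err()
}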