// Package task manages plugin task lifecycle using in-process goroutines.
package task

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"sync"
	"time"

	"github.com/redis/go-redis/v9"
	"gorm.io/gorm"

	"spider/internal/model"
	"spider/internal/plugin"
	"spider/internal/processor"
	"spider/internal/store"
)

// StartRequest is the payload for starting a new task.
type StartRequest struct {
	PluginName string `json:"plugin_name" binding:"required"`
	// AutoClean controls whether the processor runs after collection.
	// A pointer distinguishes "field omitted" (nil, defaults to true) from
	// an explicit false. The previous plain bool made it impossible to
	// express false: the default-handling code forced the flag to true
	// unconditionally.
	AutoClean *bool `json:"auto_clean"`
}

// Manager manages plugin task lifecycle using goroutines.
// Replaces the asynq-based worker.
type Manager struct {
	db        *gorm.DB
	redis     *redis.Client
	registry  *plugin.Registry
	store     *store.Store
	processor *processor.Processor

	mu      sync.Mutex
	running map[uint]context.CancelFunc // taskID -> cancel
}

// NewManager creates a new task manager.
func NewManager(db *gorm.DB, rdb *redis.Client, reg *plugin.Registry, s *store.Store, proc *processor.Processor) *Manager {
	return &Manager{
		db:        db,
		redis:     rdb,
		registry:  reg,
		store:     s,
		processor: proc,
		running:   make(map[uint]context.CancelFunc),
	}
}

// StartTask validates, creates a TaskLog record, and runs the plugin in a goroutine.
// It returns an error if the plugin is unknown or already running.
func (m *Manager) StartTask(req StartRequest) (*model.TaskLog, error) {
	// Validate plugin exists.
	collector, err := m.registry.Get(req.PluginName)
	if err != nil {
		return nil, fmt.Errorf("unknown plugin: %s", req.PluginName)
	}

	// Check if same plugin is already running. The Count error was
	// previously ignored, which on a DB failure would let a duplicate
	// concurrent run slip through.
	var count int64
	if err := m.db.Model(&model.TaskLog{}).
		Where("plugin_name = ? AND status = ?", req.PluginName, "running").
		Count(&count).Error; err != nil {
		return nil, fmt.Errorf("check running tasks: %w", err)
	}
	if count > 0 {
		return nil, fmt.Errorf("plugin %s is already running", req.PluginName)
	}

	// Create task log record.
	now := time.Now()
	taskLog := &model.TaskLog{
		TaskType:   "collect",
		PluginName: req.PluginName,
		Status:     "running",
		StartedAt:  &now,
	}
	if err := m.db.Create(taskLog).Error; err != nil {
		return nil, fmt.Errorf("create task log: %w", err)
	}

	// Build config for the plugin.
	cfg, err := m.buildPluginConfig(req.PluginName)
	if err != nil {
		m.failTask(taskLog, err)
		return nil, err
	}

	// Register the cancel func so StopTask can cancel this run.
	ctx, cancel := context.WithCancel(context.Background())
	m.mu.Lock()
	m.running[taskLog.ID] = cancel
	m.mu.Unlock()

	// Default auto-clean to true when the caller omitted the field.
	// (The old code `if !req.AutoClean { autoClean = true }` was a
	// tautology that forced the flag to true in every case.)
	autoClean := true
	if req.AutoClean != nil {
		autoClean = *req.AutoClean
	}

	go m.runTask(ctx, taskLog, collector, cfg, autoClean)
	return taskLog, nil
}

// runTask drives a collector to completion in its own goroutine, streaming
// progress/logs to Redis and persisting the final status to the DB.
func (m *Manager) runTask(ctx context.Context, taskLog *model.TaskLog, collector plugin.Collector, cfg map[string]any, autoClean bool) {
	defer func() {
		m.mu.Lock()
		delete(m.running, taskLog.ID)
		m.mu.Unlock()
	}()

	m.writeLog(ctx, taskLog.ID, fmt.Sprintf("开始采集: %s", collector.Name()))

	merchantCount := 0
	errCount := 0

	// Callback: for each merchant found, save to raw table.
	// Progress is reported every 10 newly inserted merchants.
	callback := func(data plugin.MerchantData) {
		inserted, err := m.store.SaveRaw(data)
		if err != nil {
			errCount++
			log.Printf("[task] save raw error: %v", err)
			return
		}
		if inserted {
			merchantCount++
			if merchantCount%10 == 0 {
				m.writeProgress(ctx, taskLog.ID, collector.Name(), merchantCount, 0,
					fmt.Sprintf("已采集 %d 个商户", merchantCount))
			}
		}
	}

	// Run the collector.
	runErr := collector.Run(ctx, cfg, callback)

	// A requested stop (Redis flag or context cancellation) takes
	// precedence over any error the collector returned while aborting.
	if m.isStopRequested(ctx, taskLog.ID) || ctx.Err() != nil {
		m.writeLog(ctx, taskLog.ID, "任务已停止")
		finishedAt := time.Now()
		m.db.Model(taskLog).Updates(map[string]any{
			"status":          "stopped",
			"finished_at":     &finishedAt,
			"merchants_added": merchantCount,
			"errors_count":    errCount,
		})
		return
	}

	if runErr != nil {
		m.failTask(taskLog, runErr)
		m.writeLog(ctx, taskLog.ID, "采集失败: "+runErr.Error())
		return
	}

	m.writeLog(ctx, taskLog.ID, fmt.Sprintf("采集完成: 新增 %d 个商户", merchantCount))

	// Auto-clean: run processor on new raw records. A processor failure is
	// logged but does not fail the (already successful) collection task.
	if autoClean && merchantCount > 0 {
		m.writeLog(ctx, taskLog.ID, "开始清洗流程...")
		m.writeProgress(ctx, taskLog.ID, "clean", 0, 0, "清洗中...")
		m.processor.SetProgressFn(func(step string, current, total int, msg string) {
			m.writeProgress(ctx, taskLog.ID, step, current, total, msg)
			m.writeLog(ctx, taskLog.ID, fmt.Sprintf("[%s] %d/%d %s", step, current, total, msg))
		})
		procResult, procErr := m.processor.Process(ctx)
		if procErr != nil {
			m.writeLog(ctx, taskLog.ID, "清洗失败: "+procErr.Error())
		} else {
			m.writeLog(ctx, taskLog.ID, fmt.Sprintf("清洗完成: Hot=%d, Warm=%d, Cold=%d",
				procResult.HotCount, procResult.WarmCount, procResult.ColdCount))
		}
	}

	// Complete.
	finishedAt := time.Now()
	detail := fmt.Sprintf("采集 %d 个商户, 错误 %d 次", merchantCount, errCount)
	m.db.Model(taskLog).Updates(map[string]any{
		"status":          "completed",
		"finished_at":     &finishedAt,
		"merchants_added": merchantCount,
		"errors_count":    errCount,
		"detail":          detail,
	})
	m.writeProgress(ctx, taskLog.ID, "done", 100, 100, "任务完成")
	m.writeLog(ctx, taskLog.ID, "任务完成")
	log.Printf("[task] task %d completed: %s", taskLog.ID, detail)
}

// StartClean runs the processor independently (not tied to a plugin).
// Only one clean task may run at a time.
func (m *Manager) StartClean() (*model.TaskLog, error) {
	var count int64
	if err := m.db.Model(&model.TaskLog{}).
		Where("task_type = ? AND status = ?", "clean", "running").
		Count(&count).Error; err != nil {
		return nil, fmt.Errorf("check running tasks: %w", err)
	}
	if count > 0 {
		return nil, fmt.Errorf("clean task is already running")
	}

	now := time.Now()
	taskLog := &model.TaskLog{
		TaskType:   "clean",
		PluginName: "",
		Status:     "running",
		StartedAt:  &now,
	}
	if err := m.db.Create(taskLog).Error; err != nil {
		return nil, err
	}

	ctx, cancel := context.WithCancel(context.Background())
	m.mu.Lock()
	m.running[taskLog.ID] = cancel
	m.mu.Unlock()

	go func() {
		defer func() {
			m.mu.Lock()
			delete(m.running, taskLog.ID)
			m.mu.Unlock()
		}()

		m.writeLog(ctx, taskLog.ID, "开始独立清洗任务")
		m.processor.SetProgressFn(func(step string, current, total int, msg string) {
			m.writeProgress(ctx, taskLog.ID, step, current, total, msg)
		})

		result, err := m.processor.Process(ctx)
		finishedAt := time.Now()
		if err != nil {
			m.db.Model(taskLog).Updates(map[string]any{
				"status":      "failed",
				"finished_at": &finishedAt,
				"detail":      err.Error(),
			})
			return
		}

		detail := fmt.Sprintf("输入 %d, Hot=%d, Warm=%d, Cold=%d",
			result.InputCount, result.HotCount, result.WarmCount, result.ColdCount)
		m.db.Model(taskLog).Updates(map[string]any{
			"status":          "completed",
			"finished_at":     &finishedAt,
			"items_processed": result.InputCount,
			"merchants_added": result.OutputCount,
			"detail":          detail,
		})
		m.writeLog(ctx, taskLog.ID, "清洗完成: "+detail)
	}()

	return taskLog, nil
}

// StopTask cancels a running task: it sets a Redis stop flag (read by
// isStopRequested), cancels the goroutine context, and asks the collector
// itself to stop. Stopping an unknown/finished task is not an error.
func (m *Manager) StopTask(taskID uint) error {
	// Set Redis stop signal (1h TTL; harmless if the task already ended).
	key := fmt.Sprintf("spider:task:stop:%d", taskID)
	m.redis.Set(context.Background(), key, "1", time.Hour)

	// Cancel the goroutine context.
	m.mu.Lock()
	cancel, ok := m.running[taskID]
	m.mu.Unlock()
	if ok {
		cancel()
	}

	// Also try to stop the collector directly.
	var taskLog model.TaskLog
	if err := m.db.First(&taskLog, taskID).Error; err == nil && taskLog.PluginName != "" {
		if collector, err := m.registry.Get(taskLog.PluginName); err == nil {
			collector.Stop()
		}
	}
	return nil
}
func (m *Manager) GetProgress(taskID uint) map[string]any { key := fmt.Sprintf("spider:task:progress:%d", taskID) vals, err := m.redis.HGetAll(context.Background(), key).Result() if err != nil { return nil } result := make(map[string]any) for k, v := range vals { result[k] = v } return result } // GetLogs reads task logs from Redis. func (m *Manager) GetLogs(taskID uint) []string { key := fmt.Sprintf("spider:task:logs:%d", taskID) logs, err := m.redis.LRange(context.Background(), key, 0, -1).Result() if err != nil { return nil } return logs } // buildPluginConfig builds the config map for a plugin from the DB. func (m *Manager) buildPluginConfig(pluginName string) (map[string]any, error) { cfg := make(map[string]any) switch pluginName { case "web_collector": keywords, err := m.store.ListEnabledKeywords() if err != nil { return nil, err } kws := make([]string, 0, len(keywords)) for _, k := range keywords { kws = append(kws, k.Keyword) } cfg["keywords"] = kws case "tg_collector": seeds, err := m.store.ListSeeds() if err != nil { return nil, err } seedNames := make([]string, 0, len(seeds)) for _, s := range seeds { seedNames = append(seedNames, s.Keyword) } cfg["seeds"] = seedNames cfg["max_depth"] = 3 cfg["max_channels"] = 500 cfg["message_limit"] = 500 case "github_collector": keywords, err := m.store.ListEnabledKeywords() if err != nil { return nil, err } kws := make([]string, 0, len(keywords)) for _, k := range keywords { kws = append(kws, k.Keyword) } cfg["keywords"] = kws cfg["repos_limit"] = 50 } return cfg, nil } func (m *Manager) failTask(taskLog *model.TaskLog, err error) { finishedAt := time.Now() m.db.Model(taskLog).Updates(map[string]any{ "status": "failed", "finished_at": &finishedAt, "detail": err.Error(), }) } func (m *Manager) isStopRequested(ctx context.Context, taskID uint) bool { key := fmt.Sprintf("spider:task:stop:%d", taskID) val, _ := m.redis.Get(ctx, key).Result() return val == "1" } func (m *Manager) writeLog(ctx context.Context, taskID uint, 
msg string) { key := fmt.Sprintf("spider:task:logs:%d", taskID) ts := time.Now().Format("15:04:05") line := fmt.Sprintf("[%s] %s", ts, msg) m.redis.RPush(ctx, key, line) m.redis.LTrim(ctx, key, -500, -1) m.redis.Expire(ctx, key, 24*time.Hour) } func (m *Manager) writeProgress(ctx context.Context, taskID uint, phase string, current, total int, message string) { key := fmt.Sprintf("spider:task:progress:%d", taskID) now := time.Now().UTC().Format(time.RFC3339) fields := map[string]any{ "phase": phase, "current": current, "total": total, "message": message, "updated_at": now, } b, _ := json.Marshal(fields) m.redis.Set(ctx, key, string(b), 24*time.Hour) }