// Package task manages plugin task lifecycle using goroutines.
package task

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"sync"
	"time"

	"github.com/redis/go-redis/v9"
	"gorm.io/gorm"

	"spider/internal/model"
	"spider/internal/notification"
	"spider/internal/plugin"
	"spider/internal/processor"
	proxypool "spider/internal/proxy"
	"spider/internal/store"
)

// StartRequest is the payload for starting a new task.
type StartRequest struct {
	PluginName      string `json:"plugin_name" binding:"required"`
	AutoClean       *bool  `json:"auto_clean"`             // run processor after collection (default true)
	TargetGroup     string `json:"target_group,omitempty"` // target a specific TG group/channel for collection
	ProxyID         *uint  `json:"proxy_id,omitempty"`     // optional single proxy for this task
	ProxyMode       string `json:"proxy_mode,omitempty"`   // "single" (default) or "pool"
	MaxMerchants    int    `json:"max_merchants,omitempty"`     // stop after collecting this many merchants (0 = unlimited)
	MaxDurationMins int    `json:"max_duration_mins,omitempty"` // stop after this many minutes (0 = unlimited)
	ResumeSnapshot  bool   `json:"resume_snapshot,omitempty"`   // resume from last snapshot
}

// Manager manages plugin task lifecycle using goroutines.
// Replaces the asynq-based worker.
type Manager struct {
	db        *gorm.DB
	redis     *redis.Client
	registry  *plugin.Registry
	store     *store.Store
	processor *processor.Processor

	// mu guards both `running` and `proxyPool`.
	mu      sync.Mutex
	running map[uint]context.CancelFunc // taskID -> cancel

	notifier  *notification.Manager
	proxyPool *proxypool.Pool // current active proxy pool (nil when not using pool mode)
}

// NewManager creates a new task manager.
func NewManager(db *gorm.DB, rdb *redis.Client, reg *plugin.Registry, s *store.Store, proc *processor.Processor) *Manager {
	return &Manager{
		db:        db,
		redis:     rdb,
		registry:  reg,
		store:     s,
		processor: proc,
		running:   make(map[uint]context.CancelFunc),
	}
}

// StartTask validates, creates a TaskLog record, and runs the plugin in a goroutine.
func (m *Manager) StartTask(req StartRequest) (*model.TaskLog, error) { // Validate plugin exists collector, err := m.registry.Get(req.PluginName) if err != nil { return nil, fmt.Errorf("unknown plugin: %s", req.PluginName) } // Check if same plugin is already running var count int64 m.db.Model(&model.TaskLog{}). Where("plugin_name = ? AND status = ?", req.PluginName, "running"). Count(&count) if count > 0 { return nil, fmt.Errorf("plugin %s is already running", req.PluginName) } // Resolve proxy configuration var proxyID *uint var proxyName string var proxyURL string var pool *proxypool.Pool if req.ProxyMode == "pool" { // Pool mode: load all enabled proxies var proxies []model.Proxy m.db.Where("enabled = ?", true).Find(&proxies) if len(proxies) == 0 { return nil, fmt.Errorf("代理池模式但没有可用的代理") } pool = proxypool.NewPool(3, 2*time.Minute) names := make([]string, 0, len(proxies)) for _, p := range proxies { pool.Add(p.ID, p.Name, p.ProxyURL(), p.Region) names = append(names, p.Name) } proxyName = fmt.Sprintf("代理池(%d个)", len(proxies)) m.mu.Lock() m.proxyPool = pool m.mu.Unlock() log.Printf("[task] using proxy pool with %d proxies: %v", len(proxies), names) } else if req.ProxyID != nil && *req.ProxyID > 0 { // Single proxy mode var proxy model.Proxy if err := m.db.First(&proxy, *req.ProxyID).Error; err == nil { proxyID = &proxy.ID proxyName = proxy.Name proxyURL = proxy.ProxyURL() } } // Create task log record now := time.Now() taskLog := &model.TaskLog{ TaskType: "collect", PluginName: req.PluginName, Status: "running", StartedAt: &now, ProxyID: proxyID, ProxyName: proxyName, ProxyMode: req.ProxyMode, } if err := m.db.Create(taskLog).Error; err != nil { return nil, fmt.Errorf("create task log: %w", err) } // Build config for the plugin cfg, err := m.buildPluginConfig(req.PluginName) if err != nil { m.failTask(taskLog, err) return nil, err } // Set proxy in config if pool != nil { cfg["proxy_pool"] = pool // Also set an initial proxy_url for compatibility 
cfg["proxy_url"] = pool.Next() } else if proxyURL != "" { cfg["proxy_url"] = proxyURL } // Stop conditions if req.MaxMerchants > 0 { cfg["max_merchants"] = req.MaxMerchants } if req.MaxDurationMins > 0 { cfg["max_duration_mins"] = req.MaxDurationMins } cfg["resume_snapshot"] = req.ResumeSnapshot // If targeting a specific group, override seeds config if req.TargetGroup != "" { cfg["seeds"] = []string{req.TargetGroup} cfg["target_group"] = req.TargetGroup cfg["max_depth"] = 0 // don't BFS discover, just scrape this group cfg["max_channels"] = 1 } // Start in goroutine ctx, cancel := context.WithCancel(context.Background()) m.mu.Lock() m.running[taskLog.ID] = cancel m.mu.Unlock() // Default to true if not explicitly set autoClean := req.AutoClean == nil || *req.AutoClean go m.runTask(ctx, taskLog, collector, cfg, autoClean) return taskLog, nil } func (m *Manager) runTask(ctx context.Context, taskLog *model.TaskLog, collector plugin.Collector, cfg map[string]any, autoClean bool) { defer func() { m.mu.Lock() delete(m.running, taskLog.ID) m.proxyPool = nil // clear stale pool reference m.mu.Unlock() }() // Recover from panics to prevent crashing the entire server defer func() { if r := recover(); r != nil { log.Printf("[task] PANIC in task %d: %v", taskLog.ID, r) finishedAt := time.Now() m.db.Model(taskLog).Updates(map[string]any{ "status": "failed", "finished_at": &finishedAt, "detail": fmt.Sprintf("panic: %v", r), }) } }() // Create detail logger for this task dl := NewDetailLogger(m.db, taskLog.ID) defer dl.Close() collector.SetLogger(dl) m.writeLog(ctx, taskLog.ID, fmt.Sprintf("开始采集: %s", collector.Name())) merchantCount := 0 errCount := 0 // Callback: for each merchant found, save to raw table + group-member relationship callback := func(data plugin.MerchantData) { inserted, err := m.store.SaveRaw(data) if err != nil { errCount++ log.Printf("[task] save raw error: %v", err) return } // Record group-member relationship if source is a TG group/channel if 
data.GroupUsername != "" && data.TgUsername != "" { m.store.SaveGroupMember(data.GroupUsername, data.TgUsername, data.GroupTitle, data.SourceType, taskLog.ID) } if inserted { merchantCount++ if merchantCount%10 == 0 { m.writeProgress(ctx, taskLog.ID, collector.Name(), merchantCount, 0, fmt.Sprintf("已采集 %d 个商户", merchantCount)) } } } // Run the collector runErr := collector.Run(ctx, cfg, callback) // Check if stopped if m.isStopRequested(ctx, taskLog.ID) || ctx.Err() != nil { m.writeLog(ctx, taskLog.ID, "任务已停止") finishedAt := time.Now() m.db.Model(taskLog).Updates(map[string]any{ "status": "stopped", "finished_at": &finishedAt, "merchants_added": merchantCount, "errors_count": errCount, }) return } if runErr != nil { m.failTask(taskLog, runErr) m.writeLog(ctx, taskLog.ID, "采集失败: "+runErr.Error()) m.notify("task_failed", "任务失败", fmt.Sprintf("采集任务 #%d (%s) 失败: %s", taskLog.ID, collector.Name(), runErr.Error())) return } m.writeLog(ctx, taskLog.ID, fmt.Sprintf("采集完成: 新增 %d 个商户", merchantCount)) // Auto-clean: run processor if there are any unprocessed raw records var pendingRaw int64 m.db.Model(&model.MerchantRaw{}).Where("status = ?", "raw").Count(&pendingRaw) if autoClean && pendingRaw > 0 { m.writeLog(ctx, taskLog.ID, "开始清洗流程...") m.writeProgress(ctx, taskLog.ID, "clean", 0, 0, "清洗中...") m.processor.SetProgressFn(func(step string, current, total int, msg string) { m.writeProgress(ctx, taskLog.ID, step, current, total, msg) m.writeLog(ctx, taskLog.ID, fmt.Sprintf("[%s] %d/%d %s", step, current, total, msg)) }) m.processor.SetLogger(dl) procResult, procErr := m.processor.Process(ctx) if procErr != nil { m.writeLog(ctx, taskLog.ID, "清洗失败: "+procErr.Error()) } else { m.writeLog(ctx, taskLog.ID, fmt.Sprintf("清洗完成: Hot=%d, Warm=%d, Cold=%d", procResult.HotCount, procResult.WarmCount, procResult.ColdCount)) } } // Complete finishedAt := time.Now() detail := fmt.Sprintf("采集 %d 个商户, 错误 %d 次", merchantCount, errCount) if pool, ok := cfg["proxy_pool"].(*proxypool.Pool); ok && 
pool != nil { detail += " | " + pool.Summary() } m.db.Model(taskLog).Updates(map[string]any{ "status": "completed", "finished_at": &finishedAt, "merchants_added": merchantCount, "errors_count": errCount, "detail": detail, }) m.writeProgress(ctx, taskLog.ID, "done", 100, 100, "任务完成") m.writeLog(ctx, taskLog.ID, "任务完成") log.Printf("[task] task %d completed: %s", taskLog.ID, detail) m.notify("task_completed", "任务完成", fmt.Sprintf("采集任务 #%d (%s): %s", taskLog.ID, collector.Name(), detail)) } // StartClean runs the processor independently (not tied to a plugin). func (m *Manager) StartClean() (*model.TaskLog, error) { var count int64 m.db.Model(&model.TaskLog{}). Where("task_type = ? AND status = ?", "clean", "running"). Count(&count) if count > 0 { return nil, fmt.Errorf("clean task is already running") } now := time.Now() taskLog := &model.TaskLog{ TaskType: "clean", PluginName: "", Status: "running", StartedAt: &now, } if err := m.db.Create(taskLog).Error; err != nil { return nil, err } ctx, cancel := context.WithCancel(context.Background()) m.mu.Lock() m.running[taskLog.ID] = cancel m.mu.Unlock() go func() { defer func() { m.mu.Lock() delete(m.running, taskLog.ID) m.mu.Unlock() }() defer func() { if r := recover(); r != nil { log.Printf("[task] PANIC in clean task %d: %v", taskLog.ID, r) finishedAt := time.Now() m.db.Model(taskLog).Updates(map[string]any{ "status": "failed", "finished_at": &finishedAt, "detail": fmt.Sprintf("panic: %v", r), }) } }() m.writeLog(ctx, taskLog.ID, "开始独立清洗任务") m.processor.SetProgressFn(func(step string, current, total int, msg string) { m.writeProgress(ctx, taskLog.ID, step, current, total, msg) }) result, err := m.processor.Process(ctx) finishedAt := time.Now() if err != nil { m.db.Model(taskLog).Updates(map[string]any{ "status": "failed", "finished_at": &finishedAt, "detail": err.Error(), }) m.notify("task_failed", "清洗任务失败", fmt.Sprintf("清洗任务 #%d 失败: %s", taskLog.ID, err.Error())) return } detail := fmt.Sprintf("输入 %d, Hot=%d, Warm=%d, 
Cold=%d", result.InputCount, result.HotCount, result.WarmCount, result.ColdCount) m.db.Model(taskLog).Updates(map[string]any{ "status": "completed", "finished_at": &finishedAt, "items_processed": result.InputCount, "merchants_added": result.OutputCount, "detail": detail, }) m.writeLog(ctx, taskLog.ID, "清洗完成: "+detail) m.notify("task_completed", "清洗任务完成", fmt.Sprintf("清洗任务 #%d: %s", taskLog.ID, detail)) // Notify for new hot merchants if result.HotCount > 0 { m.notify("new_hot_merchant", "发现优质商户", fmt.Sprintf("清洗任务 #%d 发现 %d 个优质商户", taskLog.ID, result.HotCount)) } }() return taskLog, nil } // StopAll cancels all running tasks (used during graceful shutdown). func (m *Manager) StopAll() { m.mu.Lock() running := make(map[uint]context.CancelFunc, len(m.running)) for id, cancel := range m.running { running[id] = cancel } m.mu.Unlock() for id, cancel := range running { log.Printf("[task] stopping task %d for shutdown", id) cancel() finishedAt := time.Now() m.db.Model(&model.TaskLog{}).Where("id = ? AND status = ?", id, "running"). Updates(map[string]any{ "status": "stopped", "finished_at": &finishedAt, "detail": "服务关闭,任务停止", }) } // Clear proxy pool reference m.mu.Lock() m.proxyPool = nil m.mu.Unlock() } // StopTask cancels a running task. func (m *Manager) StopTask(taskID uint) error { // Set Redis stop signal key := fmt.Sprintf("spider:task:stop:%d", taskID) m.redis.Set(context.Background(), key, "1", time.Hour) // Cancel the goroutine context m.mu.Lock() cancel, ok := m.running[taskID] m.mu.Unlock() if ok { cancel() } // Also try to stop the collector var taskLog model.TaskLog if err := m.db.First(&taskLog, taskID).Error; err == nil && taskLog.PluginName != "" { if collector, err := m.registry.Get(taskLog.PluginName); err == nil { collector.Stop() } } return nil } // GetProgress reads live progress from Redis. 
func (m *Manager) GetProgress(taskID uint) map[string]any { key := fmt.Sprintf("spider:task:progress:%d", taskID) vals, err := m.redis.HGetAll(context.Background(), key).Result() if err != nil { return nil } result := make(map[string]any) for k, v := range vals { result[k] = v } return result } // GetLogs reads task logs from Redis. func (m *Manager) GetLogs(taskID uint) []string { key := fmt.Sprintf("spider:task:logs:%d", taskID) logs, err := m.redis.LRange(context.Background(), key, 0, -1).Result() if err != nil { return nil } return logs } // buildPluginConfig builds the config map for a plugin from the DB. func (m *Manager) buildPluginConfig(pluginName string) (map[string]any, error) { cfg := make(map[string]any) switch pluginName { case "web_collector": keywords, err := m.store.ListEnabledKeywords() if err != nil { return nil, err } kws := make([]string, 0, len(keywords)) for _, k := range keywords { kws = append(kws, k.Keyword) } cfg["keywords"] = kws case "tg_collector": seeds, err := m.store.ListSeeds() if err != nil { return nil, err } seedNames := make([]string, 0, len(seeds)) for _, s := range seeds { seedNames = append(seedNames, s.Keyword) } cfg["seeds"] = seedNames cfg["max_depth"] = 3 cfg["max_channels"] = 500 cfg["message_limit"] = 500 case "github_collector": keywords, err := m.store.ListEnabledKeywords() if err != nil { return nil, err } kws := make([]string, 0, len(keywords)) for _, k := range keywords { kws = append(kws, k.Keyword) } cfg["keywords"] = kws cfg["repos_limit"] = 50 } return cfg, nil } func (m *Manager) failTask(taskLog *model.TaskLog, err error) { finishedAt := time.Now() m.db.Model(taskLog).Updates(map[string]any{ "status": "failed", "finished_at": &finishedAt, "detail": err.Error(), }) } func (m *Manager) isStopRequested(ctx context.Context, taskID uint) bool { key := fmt.Sprintf("spider:task:stop:%d", taskID) val, _ := m.redis.Get(ctx, key).Result() return val == "1" } // TaskPubSubChannel returns the Redis Pub/Sub channel name 
for a task. func TaskPubSubChannel(taskID uint) string { return fmt.Sprintf("spider:task:events:%d", taskID) } func (m *Manager) writeLog(ctx context.Context, taskID uint, msg string) { key := fmt.Sprintf("spider:task:logs:%d", taskID) ts := time.Now().Format("15:04:05") line := fmt.Sprintf("[%s] %s", ts, msg) m.redis.RPush(ctx, key, line) m.redis.LTrim(ctx, key, -500, -1) m.redis.Expire(ctx, key, 24*time.Hour) // Publish to Pub/Sub for real-time WebSocket delivery m.redis.Publish(ctx, TaskPubSubChannel(taskID), line) } func (m *Manager) writeProgress(ctx context.Context, taskID uint, phase string, current, total int, message string) { key := fmt.Sprintf("spider:task:progress:%d", taskID) now := time.Now().UTC().Format(time.RFC3339) fields := map[string]any{ "phase": phase, "current": current, "total": total, "message": message, "updated_at": now, } b, _ := json.Marshal(fields) m.redis.Set(ctx, key, string(b), 24*time.Hour) // Publish progress to Pub/Sub m.redis.Publish(ctx, TaskPubSubChannel(taskID), fmt.Sprintf("[进度] %s", message)) } // GetProxyPool returns the current active proxy pool (may be nil). func (m *Manager) GetProxyPool() *proxypool.Pool { m.mu.Lock() defer m.mu.Unlock() return m.proxyPool } // GetRedis returns the Redis client for WebSocket Pub/Sub subscription. func (m *Manager) GetRedis() *redis.Client { return m.redis } // SetNotifier sets the notification manager for event dispatching. func (m *Manager) SetNotifier(n *notification.Manager) { m.notifier = n } func (m *Manager) notify(eventType, title, msg string) { if m.notifier == nil { return } m.notifier.Send(notification.Event{ Type: eventType, Title: title, Message: msg, }) } // ListPlugins returns all registered plugin names. func (m *Manager) ListPlugins() []string { return m.registry.List() }