| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569 |
- package task
import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"strconv"
	"sync"
	"time"

	"github.com/redis/go-redis/v9"
	"gorm.io/gorm"

	"spider/internal/model"
	"spider/internal/notification"
	"spider/internal/plugin"
	"spider/internal/processor"
	proxypool "spider/internal/proxy"
	"spider/internal/store"
)
// StartRequest is the payload for starting a new task.
type StartRequest struct {
	// PluginName selects the registered collector plugin to run (required).
	PluginName string `json:"plugin_name" binding:"required"`
	// AutoClean controls whether the processor runs after collection.
	// A nil pointer means "not specified" and defaults to true.
	AutoClean *bool `json:"auto_clean"`
	// TargetGroup, when set, targets a specific TG group/channel for
	// collection instead of the configured seeds.
	TargetGroup string `json:"target_group,omitempty"`
	// ProxyID optionally pins a single proxy (by DB id) for this task.
	ProxyID *uint `json:"proxy_id,omitempty"`
	// ProxyMode is "single" (default) or "pool".
	ProxyMode string `json:"proxy_mode,omitempty"`
}
// Manager manages plugin task lifecycle using goroutines.
// Replaces the asynq-based worker.
//
// Concurrency: mu guards both the running map and proxyPool. Each started
// task registers its context.CancelFunc in running (keyed by TaskLog ID)
// and removes itself when it finishes.
type Manager struct {
	db        *gorm.DB
	redis     *redis.Client
	registry  *plugin.Registry
	store     *store.Store
	processor *processor.Processor
	mu        sync.Mutex
	running   map[uint]context.CancelFunc // taskID -> cancel
	notifier  *notification.Manager
	proxyPool *proxypool.Pool // current active proxy pool (nil when not using pool mode)
}
- // NewManager creates a new task manager.
- func NewManager(db *gorm.DB, rdb *redis.Client, reg *plugin.Registry, s *store.Store, proc *processor.Processor) *Manager {
- return &Manager{
- db: db,
- redis: rdb,
- registry: reg,
- store: s,
- processor: proc,
- running: make(map[uint]context.CancelFunc),
- }
- }
- // StartTask validates, creates a TaskLog record, and runs the plugin in a goroutine.
- func (m *Manager) StartTask(req StartRequest) (*model.TaskLog, error) {
- // Validate plugin exists
- collector, err := m.registry.Get(req.PluginName)
- if err != nil {
- return nil, fmt.Errorf("unknown plugin: %s", req.PluginName)
- }
- // Check if same plugin is already running
- var count int64
- m.db.Model(&model.TaskLog{}).
- Where("plugin_name = ? AND status = ?", req.PluginName, "running").
- Count(&count)
- if count > 0 {
- return nil, fmt.Errorf("plugin %s is already running", req.PluginName)
- }
- // Resolve proxy configuration
- var proxyID *uint
- var proxyName string
- var proxyURL string
- var pool *proxypool.Pool
- if req.ProxyMode == "pool" {
- // Pool mode: load all enabled proxies
- var proxies []model.Proxy
- m.db.Where("enabled = ?", true).Find(&proxies)
- if len(proxies) == 0 {
- return nil, fmt.Errorf("代理池模式但没有可用的代理")
- }
- pool = proxypool.NewPool(3, 2*time.Minute)
- names := make([]string, 0, len(proxies))
- for _, p := range proxies {
- pool.Add(p.ID, p.Name, p.ProxyURL(), p.Region)
- names = append(names, p.Name)
- }
- proxyName = fmt.Sprintf("代理池(%d个)", len(proxies))
- m.mu.Lock()
- m.proxyPool = pool
- m.mu.Unlock()
- log.Printf("[task] using proxy pool with %d proxies: %v", len(proxies), names)
- } else if req.ProxyID != nil && *req.ProxyID > 0 {
- // Single proxy mode
- var proxy model.Proxy
- if err := m.db.First(&proxy, *req.ProxyID).Error; err == nil {
- proxyID = &proxy.ID
- proxyName = proxy.Name
- proxyURL = proxy.ProxyURL()
- }
- }
- // Create task log record
- now := time.Now()
- taskLog := &model.TaskLog{
- TaskType: "collect",
- PluginName: req.PluginName,
- Status: "running",
- StartedAt: &now,
- ProxyID: proxyID,
- ProxyName: proxyName,
- ProxyMode: req.ProxyMode,
- }
- if err := m.db.Create(taskLog).Error; err != nil {
- return nil, fmt.Errorf("create task log: %w", err)
- }
- // Build config for the plugin
- cfg, err := m.buildPluginConfig(req.PluginName)
- if err != nil {
- m.failTask(taskLog, err)
- return nil, err
- }
- // Set proxy in config
- if pool != nil {
- cfg["proxy_pool"] = pool
- // Also set an initial proxy_url for compatibility
- cfg["proxy_url"] = pool.Next()
- } else if proxyURL != "" {
- cfg["proxy_url"] = proxyURL
- }
- // If targeting a specific group, override seeds config
- if req.TargetGroup != "" {
- cfg["seeds"] = []string{req.TargetGroup}
- cfg["target_group"] = req.TargetGroup
- cfg["max_depth"] = 0 // don't BFS discover, just scrape this group
- cfg["max_channels"] = 1
- }
- // Start in goroutine
- ctx, cancel := context.WithCancel(context.Background())
- m.mu.Lock()
- m.running[taskLog.ID] = cancel
- m.mu.Unlock()
- // Default to true if not explicitly set
- autoClean := req.AutoClean == nil || *req.AutoClean
- go m.runTask(ctx, taskLog, collector, cfg, autoClean)
- return taskLog, nil
- }
// runTask executes a collector plugin to completion inside its own
// goroutine: it saves each discovered merchant via the callback, detects
// stop requests, optionally runs the cleaning processor, and records the
// final status on the TaskLog row.
//
// Defer ordering matters: the recover handler is registered last so it
// runs first, writing the "failed" status before the cleanup defer removes
// the task from m.running and clears the proxy pool.
func (m *Manager) runTask(ctx context.Context, taskLog *model.TaskLog, collector plugin.Collector, cfg map[string]any, autoClean bool) {
	defer func() {
		m.mu.Lock()
		delete(m.running, taskLog.ID)
		m.proxyPool = nil // clear stale pool reference
		m.mu.Unlock()
	}()
	// Recover from panics to prevent crashing the entire server
	defer func() {
		if r := recover(); r != nil {
			log.Printf("[task] PANIC in task %d: %v", taskLog.ID, r)
			finishedAt := time.Now()
			m.db.Model(taskLog).Updates(map[string]any{
				"status":      "failed",
				"finished_at": &finishedAt,
				"detail":      fmt.Sprintf("panic: %v", r),
			})
		}
	}()
	// Create detail logger for this task
	dl := NewDetailLogger(m.db, taskLog.ID)
	defer dl.Close()
	collector.SetLogger(dl)
	m.writeLog(ctx, taskLog.ID, fmt.Sprintf("开始采集: %s", collector.Name()))
	// Counters are only touched from this goroutine — the callback is
	// assumed to be invoked synchronously by collector.Run (TODO confirm
	// for each plugin implementation).
	merchantCount := 0
	errCount := 0
	// Callback: for each merchant found, save to raw table + group-member relationship
	callback := func(data plugin.MerchantData) {
		inserted, err := m.store.SaveRaw(data)
		if err != nil {
			errCount++
			log.Printf("[task] save raw error: %v", err)
			return
		}
		// Record group-member relationship if source is a TG group/channel
		if data.GroupUsername != "" && data.TgUsername != "" {
			m.store.SaveGroupMember(data.GroupUsername, data.TgUsername, data.GroupTitle, data.SourceType, taskLog.ID)
		}
		if inserted {
			merchantCount++
			// Throttle progress writes to every 10th merchant.
			if merchantCount%10 == 0 {
				m.writeProgress(ctx, taskLog.ID, collector.Name(), merchantCount, 0,
					fmt.Sprintf("已采集 %d 个商户", merchantCount))
			}
		}
	}
	// Run the collector
	runErr := collector.Run(ctx, cfg, callback)
	// Check if stopped: either the Redis stop flag was set (StopTask) or
	// the context was cancelled (StopAll / shutdown). A stopped task is
	// recorded as "stopped", not "failed", regardless of runErr.
	if m.isStopRequested(ctx, taskLog.ID) || ctx.Err() != nil {
		m.writeLog(ctx, taskLog.ID, "任务已停止")
		finishedAt := time.Now()
		m.db.Model(taskLog).Updates(map[string]any{
			"status":          "stopped",
			"finished_at":     &finishedAt,
			"merchants_added": merchantCount,
			"errors_count":    errCount,
		})
		return
	}
	if runErr != nil {
		m.failTask(taskLog, runErr)
		m.writeLog(ctx, taskLog.ID, "采集失败: "+runErr.Error())
		m.notify("task_failed", "任务失败", fmt.Sprintf("采集任务 #%d (%s) 失败: %s", taskLog.ID, collector.Name(), runErr.Error()))
		return
	}
	m.writeLog(ctx, taskLog.ID, fmt.Sprintf("采集完成: 新增 %d 个商户", merchantCount))
	// Auto-clean: run processor on new raw records
	if autoClean && merchantCount > 0 {
		m.writeLog(ctx, taskLog.ID, "开始清洗流程...")
		m.writeProgress(ctx, taskLog.ID, "clean", 0, 0, "清洗中...")
		m.processor.SetProgressFn(func(step string, current, total int, msg string) {
			m.writeProgress(ctx, taskLog.ID, step, current, total, msg)
			m.writeLog(ctx, taskLog.ID, fmt.Sprintf("[%s] %d/%d %s", step, current, total, msg))
		})
		m.processor.SetLogger(dl)
		procResult, procErr := m.processor.Process(ctx)
		// A clean failure does not fail the task — collection succeeded.
		if procErr != nil {
			m.writeLog(ctx, taskLog.ID, "清洗失败: "+procErr.Error())
		} else {
			m.writeLog(ctx, taskLog.ID, fmt.Sprintf("清洗完成: Hot=%d, Warm=%d, Cold=%d",
				procResult.HotCount, procResult.WarmCount, procResult.ColdCount))
		}
	}
	// Complete: persist final counters and, in pool mode, the pool summary.
	finishedAt := time.Now()
	detail := fmt.Sprintf("采集 %d 个商户, 错误 %d 次", merchantCount, errCount)
	if pool, ok := cfg["proxy_pool"].(*proxypool.Pool); ok && pool != nil {
		detail += " | " + pool.Summary()
	}
	m.db.Model(taskLog).Updates(map[string]any{
		"status":          "completed",
		"finished_at":     &finishedAt,
		"merchants_added": merchantCount,
		"errors_count":    errCount,
		"detail":          detail,
	})
	m.writeProgress(ctx, taskLog.ID, "done", 100, 100, "任务完成")
	m.writeLog(ctx, taskLog.ID, "任务完成")
	log.Printf("[task] task %d completed: %s", taskLog.ID, detail)
	m.notify("task_completed", "任务完成", fmt.Sprintf("采集任务 #%d (%s): %s", taskLog.ID, collector.Name(), detail))
}
// StartClean runs the processor independently (not tied to a plugin).
//
// Mirrors StartTask's lifecycle: a "clean" TaskLog row is created up front
// with status "running", the work happens in a goroutine (with panic
// recovery), and the final status/detail is written when Process returns.
// Only one clean task may run at a time.
func (m *Manager) StartClean() (*model.TaskLog, error) {
	// Refuse to start a second concurrent clean task.
	var count int64
	m.db.Model(&model.TaskLog{}).
		Where("task_type = ? AND status = ?", "clean", "running").
		Count(&count)
	if count > 0 {
		return nil, fmt.Errorf("clean task is already running")
	}
	now := time.Now()
	taskLog := &model.TaskLog{
		TaskType:   "clean",
		PluginName: "",
		Status:     "running",
		StartedAt:  &now,
	}
	if err := m.db.Create(taskLog).Error; err != nil {
		return nil, err
	}
	// Register the cancel func so StopTask/StopAll can cancel this clean run.
	ctx, cancel := context.WithCancel(context.Background())
	m.mu.Lock()
	m.running[taskLog.ID] = cancel
	m.mu.Unlock()
	go func() {
		defer func() {
			m.mu.Lock()
			delete(m.running, taskLog.ID)
			m.mu.Unlock()
		}()
		// Recover runs before the cleanup defer above, writing the failed
		// status first.
		defer func() {
			if r := recover(); r != nil {
				log.Printf("[task] PANIC in clean task %d: %v", taskLog.ID, r)
				finishedAt := time.Now()
				m.db.Model(taskLog).Updates(map[string]any{
					"status":      "failed",
					"finished_at": &finishedAt,
					"detail":      fmt.Sprintf("panic: %v", r),
				})
			}
		}()
		m.writeLog(ctx, taskLog.ID, "开始独立清洗任务")
		m.processor.SetProgressFn(func(step string, current, total int, msg string) {
			m.writeProgress(ctx, taskLog.ID, step, current, total, msg)
		})
		result, err := m.processor.Process(ctx)
		finishedAt := time.Now()
		if err != nil {
			m.db.Model(taskLog).Updates(map[string]any{
				"status":      "failed",
				"finished_at": &finishedAt,
				"detail":      err.Error(),
			})
			m.notify("task_failed", "清洗任务失败", fmt.Sprintf("清洗任务 #%d 失败: %s", taskLog.ID, err.Error()))
			return
		}
		detail := fmt.Sprintf("输入 %d, Hot=%d, Warm=%d, Cold=%d",
			result.InputCount, result.HotCount, result.WarmCount, result.ColdCount)
		m.db.Model(taskLog).Updates(map[string]any{
			"status":          "completed",
			"finished_at":     &finishedAt,
			"items_processed": result.InputCount,
			"merchants_added": result.OutputCount,
			"detail":          detail,
		})
		m.writeLog(ctx, taskLog.ID, "清洗完成: "+detail)
		m.notify("task_completed", "清洗任务完成", fmt.Sprintf("清洗任务 #%d: %s", taskLog.ID, detail))
		// Notify for new hot merchants
		if result.HotCount > 0 {
			m.notify("new_hot_merchant", "发现优质商户", fmt.Sprintf("清洗任务 #%d 发现 %d 个优质商户", taskLog.ID, result.HotCount))
		}
	}()
	return taskLog, nil
}
- // StopAll cancels all running tasks (used during graceful shutdown).
- func (m *Manager) StopAll() {
- m.mu.Lock()
- running := make(map[uint]context.CancelFunc, len(m.running))
- for id, cancel := range m.running {
- running[id] = cancel
- }
- m.mu.Unlock()
- for id, cancel := range running {
- log.Printf("[task] stopping task %d for shutdown", id)
- cancel()
- finishedAt := time.Now()
- m.db.Model(&model.TaskLog{}).Where("id = ? AND status = ?", id, "running").
- Updates(map[string]any{
- "status": "stopped",
- "finished_at": &finishedAt,
- "detail": "服务关闭,任务停止",
- })
- }
- // Clear proxy pool reference
- m.mu.Lock()
- m.proxyPool = nil
- m.mu.Unlock()
- }
- // StopTask cancels a running task.
- func (m *Manager) StopTask(taskID uint) error {
- // Set Redis stop signal
- key := fmt.Sprintf("spider:task:stop:%d", taskID)
- m.redis.Set(context.Background(), key, "1", time.Hour)
- // Cancel the goroutine context
- m.mu.Lock()
- cancel, ok := m.running[taskID]
- m.mu.Unlock()
- if ok {
- cancel()
- }
- // Also try to stop the collector
- var taskLog model.TaskLog
- if err := m.db.First(&taskLog, taskID).Error; err == nil && taskLog.PluginName != "" {
- if collector, err := m.registry.Get(taskLog.PluginName); err == nil {
- collector.Stop()
- }
- }
- return nil
- }
- // GetProgress reads live progress from Redis.
- func (m *Manager) GetProgress(taskID uint) map[string]any {
- key := fmt.Sprintf("spider:task:progress:%d", taskID)
- vals, err := m.redis.HGetAll(context.Background(), key).Result()
- if err != nil {
- return nil
- }
- result := make(map[string]any)
- for k, v := range vals {
- result[k] = v
- }
- return result
- }
- // GetLogs reads task logs from Redis.
- func (m *Manager) GetLogs(taskID uint) []string {
- key := fmt.Sprintf("spider:task:logs:%d", taskID)
- logs, err := m.redis.LRange(context.Background(), key, 0, -1).Result()
- if err != nil {
- return nil
- }
- return logs
- }
- // buildPluginConfig builds the config map for a plugin from the DB.
- func (m *Manager) buildPluginConfig(pluginName string) (map[string]any, error) {
- cfg := make(map[string]any)
- switch pluginName {
- case "web_collector":
- keywords, err := m.store.ListEnabledKeywords()
- if err != nil {
- return nil, err
- }
- kws := make([]string, 0, len(keywords))
- for _, k := range keywords {
- kws = append(kws, k.Keyword)
- }
- cfg["keywords"] = kws
- case "tg_collector":
- seeds, err := m.store.ListSeeds()
- if err != nil {
- return nil, err
- }
- seedNames := make([]string, 0, len(seeds))
- for _, s := range seeds {
- seedNames = append(seedNames, s.Keyword)
- }
- cfg["seeds"] = seedNames
- cfg["max_depth"] = 3
- cfg["max_channels"] = 500
- cfg["message_limit"] = 500
- case "github_collector":
- keywords, err := m.store.ListEnabledKeywords()
- if err != nil {
- return nil, err
- }
- kws := make([]string, 0, len(keywords))
- for _, k := range keywords {
- kws = append(kws, k.Keyword)
- }
- cfg["keywords"] = kws
- cfg["repos_limit"] = 50
- }
- return cfg, nil
- }
- func (m *Manager) failTask(taskLog *model.TaskLog, err error) {
- finishedAt := time.Now()
- m.db.Model(taskLog).Updates(map[string]any{
- "status": "failed",
- "finished_at": &finishedAt,
- "detail": err.Error(),
- })
- }
- func (m *Manager) isStopRequested(ctx context.Context, taskID uint) bool {
- key := fmt.Sprintf("spider:task:stop:%d", taskID)
- val, _ := m.redis.Get(ctx, key).Result()
- return val == "1"
- }
// TaskPubSubChannel returns the Redis Pub/Sub channel name for a task.
//
// Uses strconv instead of fmt.Sprintf: a plain uint-to-decimal conversion
// avoids the interface boxing and reflection of fmt on this hot path
// (called once per log line and progress update).
func TaskPubSubChannel(taskID uint) string {
	return "spider:task:events:" + strconv.FormatUint(uint64(taskID), 10)
}
- func (m *Manager) writeLog(ctx context.Context, taskID uint, msg string) {
- key := fmt.Sprintf("spider:task:logs:%d", taskID)
- ts := time.Now().Format("15:04:05")
- line := fmt.Sprintf("[%s] %s", ts, msg)
- m.redis.RPush(ctx, key, line)
- m.redis.LTrim(ctx, key, -500, -1)
- m.redis.Expire(ctx, key, 24*time.Hour)
- // Publish to Pub/Sub for real-time WebSocket delivery
- m.redis.Publish(ctx, TaskPubSubChannel(taskID), line)
- }
- func (m *Manager) writeProgress(ctx context.Context, taskID uint, phase string, current, total int, message string) {
- key := fmt.Sprintf("spider:task:progress:%d", taskID)
- now := time.Now().UTC().Format(time.RFC3339)
- fields := map[string]any{
- "phase": phase,
- "current": current,
- "total": total,
- "message": message,
- "updated_at": now,
- }
- b, _ := json.Marshal(fields)
- m.redis.Set(ctx, key, string(b), 24*time.Hour)
- // Publish progress to Pub/Sub
- m.redis.Publish(ctx, TaskPubSubChannel(taskID), fmt.Sprintf("[进度] %s", message))
- }
- // GetProxyPool returns the current active proxy pool (may be nil).
- func (m *Manager) GetProxyPool() *proxypool.Pool {
- m.mu.Lock()
- defer m.mu.Unlock()
- return m.proxyPool
- }
// GetRedis returns the Redis client for WebSocket Pub/Sub subscription.
// The client is set once in NewManager and never mutated, so no locking
// is needed here.
func (m *Manager) GetRedis() *redis.Client {
	return m.redis
}
// SetNotifier sets the notification manager for event dispatching.
// NOTE(review): assigned without holding mu — assumed to be called once
// during startup before any task runs; confirm against the caller.
func (m *Manager) SetNotifier(n *notification.Manager) {
	m.notifier = n
}
- func (m *Manager) notify(eventType, title, msg string) {
- if m.notifier == nil {
- return
- }
- m.notifier.Send(notification.Event{
- Type: eventType,
- Title: title,
- Message: msg,
- })
- }
// ListPlugins returns all registered plugin names.
// Delegates directly to the plugin registry.
func (m *Manager) ListPlugins() []string {
	return m.registry.List()
}
|