| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392 |
- package task
- import (
- "context"
- "encoding/json"
- "fmt"
- "log"
- "sync"
- "time"
- "github.com/redis/go-redis/v9"
- "gorm.io/gorm"
- "spider/internal/model"
- "spider/internal/plugin"
- "spider/internal/processor"
- "spider/internal/store"
- )
// StartRequest carries the parameters accepted when a collection task is
// started.
type StartRequest struct {
	// PluginName names the collector plugin to run. The binding tag marks it
	// as required.
	PluginName string `json:"plugin_name" binding:"required"`

	// AutoClean asks for the processor to run after collection finishes.
	// NOTE(review): a plain bool cannot tell "omitted" from an explicit false,
	// and StartTask turns false into true, so cleaning is effectively always
	// enabled — confirm whether callers need a way to opt out.
	AutoClean bool `json:"auto_clean"`
}
- // Manager manages plugin task lifecycle using goroutines.
- // Replaces the asynq-based worker.
- type Manager struct {
- db *gorm.DB
- redis *redis.Client
- registry *plugin.Registry
- store *store.Store
- processor *processor.Processor
- mu sync.Mutex
- running map[uint]context.CancelFunc // taskID -> cancel
- }
- // NewManager creates a new task manager.
- func NewManager(db *gorm.DB, rdb *redis.Client, reg *plugin.Registry, s *store.Store, proc *processor.Processor) *Manager {
- return &Manager{
- db: db,
- redis: rdb,
- registry: reg,
- store: s,
- processor: proc,
- running: make(map[uint]context.CancelFunc),
- }
- }
- // StartTask validates, creates a TaskLog record, and runs the plugin in a goroutine.
- func (m *Manager) StartTask(req StartRequest) (*model.TaskLog, error) {
- // Validate plugin exists
- collector, err := m.registry.Get(req.PluginName)
- if err != nil {
- return nil, fmt.Errorf("unknown plugin: %s", req.PluginName)
- }
- // Check if same plugin is already running
- var count int64
- m.db.Model(&model.TaskLog{}).
- Where("plugin_name = ? AND status = ?", req.PluginName, "running").
- Count(&count)
- if count > 0 {
- return nil, fmt.Errorf("plugin %s is already running", req.PluginName)
- }
- // Create task log record
- now := time.Now()
- taskLog := &model.TaskLog{
- TaskType: "collect",
- PluginName: req.PluginName,
- Status: "running",
- StartedAt: &now,
- }
- if err := m.db.Create(taskLog).Error; err != nil {
- return nil, fmt.Errorf("create task log: %w", err)
- }
- // Build config for the plugin
- cfg, err := m.buildPluginConfig(req.PluginName)
- if err != nil {
- m.failTask(taskLog, err)
- return nil, err
- }
- // Start in goroutine
- ctx, cancel := context.WithCancel(context.Background())
- m.mu.Lock()
- m.running[taskLog.ID] = cancel
- m.mu.Unlock()
- autoClean := req.AutoClean
- // Default to true if not explicitly set
- if !req.AutoClean {
- autoClean = true
- }
- go m.runTask(ctx, taskLog, collector, cfg, autoClean)
- return taskLog, nil
- }
- func (m *Manager) runTask(ctx context.Context, taskLog *model.TaskLog, collector plugin.Collector, cfg map[string]any, autoClean bool) {
- defer func() {
- m.mu.Lock()
- delete(m.running, taskLog.ID)
- m.mu.Unlock()
- }()
- m.writeLog(ctx, taskLog.ID, fmt.Sprintf("开始采集: %s", collector.Name()))
- merchantCount := 0
- errCount := 0
- // Callback: for each merchant found, save to raw table
- callback := func(data plugin.MerchantData) {
- inserted, err := m.store.SaveRaw(data)
- if err != nil {
- errCount++
- log.Printf("[task] save raw error: %v", err)
- return
- }
- if inserted {
- merchantCount++
- if merchantCount%10 == 0 {
- m.writeProgress(ctx, taskLog.ID, collector.Name(), merchantCount, 0,
- fmt.Sprintf("已采集 %d 个商户", merchantCount))
- }
- }
- }
- // Run the collector
- runErr := collector.Run(ctx, cfg, callback)
- // Check if stopped
- if m.isStopRequested(ctx, taskLog.ID) || ctx.Err() != nil {
- m.writeLog(ctx, taskLog.ID, "任务已停止")
- finishedAt := time.Now()
- m.db.Model(taskLog).Updates(map[string]any{
- "status": "stopped",
- "finished_at": &finishedAt,
- "merchants_added": merchantCount,
- "errors_count": errCount,
- })
- return
- }
- if runErr != nil {
- m.failTask(taskLog, runErr)
- m.writeLog(ctx, taskLog.ID, "采集失败: "+runErr.Error())
- return
- }
- m.writeLog(ctx, taskLog.ID, fmt.Sprintf("采集完成: 新增 %d 个商户", merchantCount))
- // Auto-clean: run processor on new raw records
- if autoClean && merchantCount > 0 {
- m.writeLog(ctx, taskLog.ID, "开始清洗流程...")
- m.writeProgress(ctx, taskLog.ID, "clean", 0, 0, "清洗中...")
- m.processor.SetProgressFn(func(step string, current, total int, msg string) {
- m.writeProgress(ctx, taskLog.ID, step, current, total, msg)
- m.writeLog(ctx, taskLog.ID, fmt.Sprintf("[%s] %d/%d %s", step, current, total, msg))
- })
- procResult, procErr := m.processor.Process(ctx)
- if procErr != nil {
- m.writeLog(ctx, taskLog.ID, "清洗失败: "+procErr.Error())
- } else {
- m.writeLog(ctx, taskLog.ID, fmt.Sprintf("清洗完成: Hot=%d, Warm=%d, Cold=%d",
- procResult.HotCount, procResult.WarmCount, procResult.ColdCount))
- }
- }
- // Complete
- finishedAt := time.Now()
- detail := fmt.Sprintf("采集 %d 个商户, 错误 %d 次", merchantCount, errCount)
- m.db.Model(taskLog).Updates(map[string]any{
- "status": "completed",
- "finished_at": &finishedAt,
- "merchants_added": merchantCount,
- "errors_count": errCount,
- "detail": detail,
- })
- m.writeProgress(ctx, taskLog.ID, "done", 100, 100, "任务完成")
- m.writeLog(ctx, taskLog.ID, "任务完成")
- log.Printf("[task] task %d completed: %s", taskLog.ID, detail)
- }
- // StartClean runs the processor independently (not tied to a plugin).
- func (m *Manager) StartClean() (*model.TaskLog, error) {
- var count int64
- m.db.Model(&model.TaskLog{}).
- Where("task_type = ? AND status = ?", "clean", "running").
- Count(&count)
- if count > 0 {
- return nil, fmt.Errorf("clean task is already running")
- }
- now := time.Now()
- taskLog := &model.TaskLog{
- TaskType: "clean",
- PluginName: "",
- Status: "running",
- StartedAt: &now,
- }
- if err := m.db.Create(taskLog).Error; err != nil {
- return nil, err
- }
- ctx, cancel := context.WithCancel(context.Background())
- m.mu.Lock()
- m.running[taskLog.ID] = cancel
- m.mu.Unlock()
- go func() {
- defer func() {
- m.mu.Lock()
- delete(m.running, taskLog.ID)
- m.mu.Unlock()
- }()
- m.writeLog(ctx, taskLog.ID, "开始独立清洗任务")
- m.processor.SetProgressFn(func(step string, current, total int, msg string) {
- m.writeProgress(ctx, taskLog.ID, step, current, total, msg)
- })
- result, err := m.processor.Process(ctx)
- finishedAt := time.Now()
- if err != nil {
- m.db.Model(taskLog).Updates(map[string]any{
- "status": "failed",
- "finished_at": &finishedAt,
- "detail": err.Error(),
- })
- return
- }
- detail := fmt.Sprintf("输入 %d, Hot=%d, Warm=%d, Cold=%d",
- result.InputCount, result.HotCount, result.WarmCount, result.ColdCount)
- m.db.Model(taskLog).Updates(map[string]any{
- "status": "completed",
- "finished_at": &finishedAt,
- "items_processed": result.InputCount,
- "merchants_added": result.OutputCount,
- "detail": detail,
- })
- m.writeLog(ctx, taskLog.ID, "清洗完成: "+detail)
- }()
- return taskLog, nil
- }
- // StopTask cancels a running task.
- func (m *Manager) StopTask(taskID uint) error {
- // Set Redis stop signal
- key := fmt.Sprintf("spider:task:stop:%d", taskID)
- m.redis.Set(context.Background(), key, "1", time.Hour)
- // Cancel the goroutine context
- m.mu.Lock()
- cancel, ok := m.running[taskID]
- m.mu.Unlock()
- if ok {
- cancel()
- }
- // Also try to stop the collector
- var taskLog model.TaskLog
- if err := m.db.First(&taskLog, taskID).Error; err == nil && taskLog.PluginName != "" {
- if collector, err := m.registry.Get(taskLog.PluginName); err == nil {
- collector.Stop()
- }
- }
- return nil
- }
- // GetProgress reads live progress from Redis.
- func (m *Manager) GetProgress(taskID uint) map[string]any {
- key := fmt.Sprintf("spider:task:progress:%d", taskID)
- vals, err := m.redis.HGetAll(context.Background(), key).Result()
- if err != nil {
- return nil
- }
- result := make(map[string]any)
- for k, v := range vals {
- result[k] = v
- }
- return result
- }
- // GetLogs reads task logs from Redis.
- func (m *Manager) GetLogs(taskID uint) []string {
- key := fmt.Sprintf("spider:task:logs:%d", taskID)
- logs, err := m.redis.LRange(context.Background(), key, 0, -1).Result()
- if err != nil {
- return nil
- }
- return logs
- }
- // buildPluginConfig builds the config map for a plugin from the DB.
- func (m *Manager) buildPluginConfig(pluginName string) (map[string]any, error) {
- cfg := make(map[string]any)
- switch pluginName {
- case "web_collector":
- keywords, err := m.store.ListEnabledKeywords()
- if err != nil {
- return nil, err
- }
- kws := make([]string, 0, len(keywords))
- for _, k := range keywords {
- kws = append(kws, k.Keyword)
- }
- cfg["keywords"] = kws
- case "tg_collector":
- seeds, err := m.store.ListSeeds()
- if err != nil {
- return nil, err
- }
- seedNames := make([]string, 0, len(seeds))
- for _, s := range seeds {
- seedNames = append(seedNames, s.Keyword)
- }
- cfg["seeds"] = seedNames
- cfg["max_depth"] = 3
- cfg["max_channels"] = 500
- cfg["message_limit"] = 500
- case "github_collector":
- keywords, err := m.store.ListEnabledKeywords()
- if err != nil {
- return nil, err
- }
- kws := make([]string, 0, len(keywords))
- for _, k := range keywords {
- kws = append(kws, k.Keyword)
- }
- cfg["keywords"] = kws
- cfg["repos_limit"] = 50
- }
- return cfg, nil
- }
- func (m *Manager) failTask(taskLog *model.TaskLog, err error) {
- finishedAt := time.Now()
- m.db.Model(taskLog).Updates(map[string]any{
- "status": "failed",
- "finished_at": &finishedAt,
- "detail": err.Error(),
- })
- }
- func (m *Manager) isStopRequested(ctx context.Context, taskID uint) bool {
- key := fmt.Sprintf("spider:task:stop:%d", taskID)
- val, _ := m.redis.Get(ctx, key).Result()
- return val == "1"
- }
- func (m *Manager) writeLog(ctx context.Context, taskID uint, msg string) {
- key := fmt.Sprintf("spider:task:logs:%d", taskID)
- ts := time.Now().Format("15:04:05")
- line := fmt.Sprintf("[%s] %s", ts, msg)
- m.redis.RPush(ctx, key, line)
- m.redis.LTrim(ctx, key, -500, -1)
- m.redis.Expire(ctx, key, 24*time.Hour)
- }
- func (m *Manager) writeProgress(ctx context.Context, taskID uint, phase string, current, total int, message string) {
- key := fmt.Sprintf("spider:task:progress:%d", taskID)
- now := time.Now().UTC().Format(time.RFC3339)
- fields := map[string]any{
- "phase": phase,
- "current": current,
- "total": total,
- "message": message,
- "updated_at": now,
- }
- b, _ := json.Marshal(fields)
- m.redis.Set(ctx, key, string(b), 24*time.Hour)
- }
|