manager.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. package task
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "log"
  7. "sync"
  8. "time"
  9. "github.com/redis/go-redis/v9"
  10. "gorm.io/gorm"
  11. "spider/internal/model"
  12. "spider/internal/plugin"
  13. "spider/internal/processor"
  14. "spider/internal/store"
  15. )
  16. // StartRequest is the payload for starting a new task.
  17. type StartRequest struct {
  18. PluginName string `json:"plugin_name" binding:"required"`
  19. AutoClean bool `json:"auto_clean"` // run processor after collection (default true)
  20. }
  21. // Manager manages plugin task lifecycle using goroutines.
  22. // Replaces the asynq-based worker.
  23. type Manager struct {
  24. db *gorm.DB
  25. redis *redis.Client
  26. registry *plugin.Registry
  27. store *store.Store
  28. processor *processor.Processor
  29. mu sync.Mutex
  30. running map[uint]context.CancelFunc // taskID -> cancel
  31. }
  32. // NewManager creates a new task manager.
  33. func NewManager(db *gorm.DB, rdb *redis.Client, reg *plugin.Registry, s *store.Store, proc *processor.Processor) *Manager {
  34. return &Manager{
  35. db: db,
  36. redis: rdb,
  37. registry: reg,
  38. store: s,
  39. processor: proc,
  40. running: make(map[uint]context.CancelFunc),
  41. }
  42. }
  43. // StartTask validates, creates a TaskLog record, and runs the plugin in a goroutine.
  44. func (m *Manager) StartTask(req StartRequest) (*model.TaskLog, error) {
  45. // Validate plugin exists
  46. collector, err := m.registry.Get(req.PluginName)
  47. if err != nil {
  48. return nil, fmt.Errorf("unknown plugin: %s", req.PluginName)
  49. }
  50. // Check if same plugin is already running
  51. var count int64
  52. m.db.Model(&model.TaskLog{}).
  53. Where("plugin_name = ? AND status = ?", req.PluginName, "running").
  54. Count(&count)
  55. if count > 0 {
  56. return nil, fmt.Errorf("plugin %s is already running", req.PluginName)
  57. }
  58. // Create task log record
  59. now := time.Now()
  60. taskLog := &model.TaskLog{
  61. TaskType: "collect",
  62. PluginName: req.PluginName,
  63. Status: "running",
  64. StartedAt: &now,
  65. }
  66. if err := m.db.Create(taskLog).Error; err != nil {
  67. return nil, fmt.Errorf("create task log: %w", err)
  68. }
  69. // Build config for the plugin
  70. cfg, err := m.buildPluginConfig(req.PluginName)
  71. if err != nil {
  72. m.failTask(taskLog, err)
  73. return nil, err
  74. }
  75. // Start in goroutine
  76. ctx, cancel := context.WithCancel(context.Background())
  77. m.mu.Lock()
  78. m.running[taskLog.ID] = cancel
  79. m.mu.Unlock()
  80. autoClean := req.AutoClean
  81. // Default to true if not explicitly set
  82. if !req.AutoClean {
  83. autoClean = true
  84. }
  85. go m.runTask(ctx, taskLog, collector, cfg, autoClean)
  86. return taskLog, nil
  87. }
  88. func (m *Manager) runTask(ctx context.Context, taskLog *model.TaskLog, collector plugin.Collector, cfg map[string]any, autoClean bool) {
  89. defer func() {
  90. m.mu.Lock()
  91. delete(m.running, taskLog.ID)
  92. m.mu.Unlock()
  93. }()
  94. m.writeLog(ctx, taskLog.ID, fmt.Sprintf("开始采集: %s", collector.Name()))
  95. merchantCount := 0
  96. errCount := 0
  97. // Callback: for each merchant found, save to raw table
  98. callback := func(data plugin.MerchantData) {
  99. inserted, err := m.store.SaveRaw(data)
  100. if err != nil {
  101. errCount++
  102. log.Printf("[task] save raw error: %v", err)
  103. return
  104. }
  105. if inserted {
  106. merchantCount++
  107. if merchantCount%10 == 0 {
  108. m.writeProgress(ctx, taskLog.ID, collector.Name(), merchantCount, 0,
  109. fmt.Sprintf("已采集 %d 个商户", merchantCount))
  110. }
  111. }
  112. }
  113. // Run the collector
  114. runErr := collector.Run(ctx, cfg, callback)
  115. // Check if stopped
  116. if m.isStopRequested(ctx, taskLog.ID) || ctx.Err() != nil {
  117. m.writeLog(ctx, taskLog.ID, "任务已停止")
  118. finishedAt := time.Now()
  119. m.db.Model(taskLog).Updates(map[string]any{
  120. "status": "stopped",
  121. "finished_at": &finishedAt,
  122. "merchants_added": merchantCount,
  123. "errors_count": errCount,
  124. })
  125. return
  126. }
  127. if runErr != nil {
  128. m.failTask(taskLog, runErr)
  129. m.writeLog(ctx, taskLog.ID, "采集失败: "+runErr.Error())
  130. return
  131. }
  132. m.writeLog(ctx, taskLog.ID, fmt.Sprintf("采集完成: 新增 %d 个商户", merchantCount))
  133. // Auto-clean: run processor on new raw records
  134. if autoClean && merchantCount > 0 {
  135. m.writeLog(ctx, taskLog.ID, "开始清洗流程...")
  136. m.writeProgress(ctx, taskLog.ID, "clean", 0, 0, "清洗中...")
  137. m.processor.SetProgressFn(func(step string, current, total int, msg string) {
  138. m.writeProgress(ctx, taskLog.ID, step, current, total, msg)
  139. m.writeLog(ctx, taskLog.ID, fmt.Sprintf("[%s] %d/%d %s", step, current, total, msg))
  140. })
  141. procResult, procErr := m.processor.Process(ctx)
  142. if procErr != nil {
  143. m.writeLog(ctx, taskLog.ID, "清洗失败: "+procErr.Error())
  144. } else {
  145. m.writeLog(ctx, taskLog.ID, fmt.Sprintf("清洗完成: Hot=%d, Warm=%d, Cold=%d",
  146. procResult.HotCount, procResult.WarmCount, procResult.ColdCount))
  147. }
  148. }
  149. // Complete
  150. finishedAt := time.Now()
  151. detail := fmt.Sprintf("采集 %d 个商户, 错误 %d 次", merchantCount, errCount)
  152. m.db.Model(taskLog).Updates(map[string]any{
  153. "status": "completed",
  154. "finished_at": &finishedAt,
  155. "merchants_added": merchantCount,
  156. "errors_count": errCount,
  157. "detail": detail,
  158. })
  159. m.writeProgress(ctx, taskLog.ID, "done", 100, 100, "任务完成")
  160. m.writeLog(ctx, taskLog.ID, "任务完成")
  161. log.Printf("[task] task %d completed: %s", taskLog.ID, detail)
  162. }
  163. // StartClean runs the processor independently (not tied to a plugin).
  164. func (m *Manager) StartClean() (*model.TaskLog, error) {
  165. var count int64
  166. m.db.Model(&model.TaskLog{}).
  167. Where("task_type = ? AND status = ?", "clean", "running").
  168. Count(&count)
  169. if count > 0 {
  170. return nil, fmt.Errorf("clean task is already running")
  171. }
  172. now := time.Now()
  173. taskLog := &model.TaskLog{
  174. TaskType: "clean",
  175. PluginName: "",
  176. Status: "running",
  177. StartedAt: &now,
  178. }
  179. if err := m.db.Create(taskLog).Error; err != nil {
  180. return nil, err
  181. }
  182. ctx, cancel := context.WithCancel(context.Background())
  183. m.mu.Lock()
  184. m.running[taskLog.ID] = cancel
  185. m.mu.Unlock()
  186. go func() {
  187. defer func() {
  188. m.mu.Lock()
  189. delete(m.running, taskLog.ID)
  190. m.mu.Unlock()
  191. }()
  192. m.writeLog(ctx, taskLog.ID, "开始独立清洗任务")
  193. m.processor.SetProgressFn(func(step string, current, total int, msg string) {
  194. m.writeProgress(ctx, taskLog.ID, step, current, total, msg)
  195. })
  196. result, err := m.processor.Process(ctx)
  197. finishedAt := time.Now()
  198. if err != nil {
  199. m.db.Model(taskLog).Updates(map[string]any{
  200. "status": "failed",
  201. "finished_at": &finishedAt,
  202. "detail": err.Error(),
  203. })
  204. return
  205. }
  206. detail := fmt.Sprintf("输入 %d, Hot=%d, Warm=%d, Cold=%d",
  207. result.InputCount, result.HotCount, result.WarmCount, result.ColdCount)
  208. m.db.Model(taskLog).Updates(map[string]any{
  209. "status": "completed",
  210. "finished_at": &finishedAt,
  211. "items_processed": result.InputCount,
  212. "merchants_added": result.OutputCount,
  213. "detail": detail,
  214. })
  215. m.writeLog(ctx, taskLog.ID, "清洗完成: "+detail)
  216. }()
  217. return taskLog, nil
  218. }
  219. // StopTask cancels a running task.
  220. func (m *Manager) StopTask(taskID uint) error {
  221. // Set Redis stop signal
  222. key := fmt.Sprintf("spider:task:stop:%d", taskID)
  223. m.redis.Set(context.Background(), key, "1", time.Hour)
  224. // Cancel the goroutine context
  225. m.mu.Lock()
  226. cancel, ok := m.running[taskID]
  227. m.mu.Unlock()
  228. if ok {
  229. cancel()
  230. }
  231. // Also try to stop the collector
  232. var taskLog model.TaskLog
  233. if err := m.db.First(&taskLog, taskID).Error; err == nil && taskLog.PluginName != "" {
  234. if collector, err := m.registry.Get(taskLog.PluginName); err == nil {
  235. collector.Stop()
  236. }
  237. }
  238. return nil
  239. }
  240. // GetProgress reads live progress from Redis.
  241. func (m *Manager) GetProgress(taskID uint) map[string]any {
  242. key := fmt.Sprintf("spider:task:progress:%d", taskID)
  243. vals, err := m.redis.HGetAll(context.Background(), key).Result()
  244. if err != nil {
  245. return nil
  246. }
  247. result := make(map[string]any)
  248. for k, v := range vals {
  249. result[k] = v
  250. }
  251. return result
  252. }
  253. // GetLogs reads task logs from Redis.
  254. func (m *Manager) GetLogs(taskID uint) []string {
  255. key := fmt.Sprintf("spider:task:logs:%d", taskID)
  256. logs, err := m.redis.LRange(context.Background(), key, 0, -1).Result()
  257. if err != nil {
  258. return nil
  259. }
  260. return logs
  261. }
  262. // buildPluginConfig builds the config map for a plugin from the DB.
  263. func (m *Manager) buildPluginConfig(pluginName string) (map[string]any, error) {
  264. cfg := make(map[string]any)
  265. switch pluginName {
  266. case "web_collector":
  267. keywords, err := m.store.ListEnabledKeywords()
  268. if err != nil {
  269. return nil, err
  270. }
  271. kws := make([]string, 0, len(keywords))
  272. for _, k := range keywords {
  273. kws = append(kws, k.Keyword)
  274. }
  275. cfg["keywords"] = kws
  276. case "tg_collector":
  277. seeds, err := m.store.ListSeeds()
  278. if err != nil {
  279. return nil, err
  280. }
  281. seedNames := make([]string, 0, len(seeds))
  282. for _, s := range seeds {
  283. seedNames = append(seedNames, s.Keyword)
  284. }
  285. cfg["seeds"] = seedNames
  286. cfg["max_depth"] = 3
  287. cfg["max_channels"] = 500
  288. cfg["message_limit"] = 500
  289. case "github_collector":
  290. keywords, err := m.store.ListEnabledKeywords()
  291. if err != nil {
  292. return nil, err
  293. }
  294. kws := make([]string, 0, len(keywords))
  295. for _, k := range keywords {
  296. kws = append(kws, k.Keyword)
  297. }
  298. cfg["keywords"] = kws
  299. cfg["repos_limit"] = 50
  300. }
  301. return cfg, nil
  302. }
  303. func (m *Manager) failTask(taskLog *model.TaskLog, err error) {
  304. finishedAt := time.Now()
  305. m.db.Model(taskLog).Updates(map[string]any{
  306. "status": "failed",
  307. "finished_at": &finishedAt,
  308. "detail": err.Error(),
  309. })
  310. }
  311. func (m *Manager) isStopRequested(ctx context.Context, taskID uint) bool {
  312. key := fmt.Sprintf("spider:task:stop:%d", taskID)
  313. val, _ := m.redis.Get(ctx, key).Result()
  314. return val == "1"
  315. }
  316. func (m *Manager) writeLog(ctx context.Context, taskID uint, msg string) {
  317. key := fmt.Sprintf("spider:task:logs:%d", taskID)
  318. ts := time.Now().Format("15:04:05")
  319. line := fmt.Sprintf("[%s] %s", ts, msg)
  320. m.redis.RPush(ctx, key, line)
  321. m.redis.LTrim(ctx, key, -500, -1)
  322. m.redis.Expire(ctx, key, 24*time.Hour)
  323. }
  324. func (m *Manager) writeProgress(ctx context.Context, taskID uint, phase string, current, total int, message string) {
  325. key := fmt.Sprintf("spider:task:progress:%d", taskID)
  326. now := time.Now().UTC().Format(time.RFC3339)
  327. fields := map[string]any{
  328. "phase": phase,
  329. "current": current,
  330. "total": total,
  331. "message": message,
  332. "updated_at": now,
  333. }
  334. b, _ := json.Marshal(fields)
  335. m.redis.Set(ctx, key, string(b), 24*time.Hour)
  336. }