manager.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583
  1. package task
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "log"
  7. "sync"
  8. "time"
  9. "github.com/redis/go-redis/v9"
  10. "gorm.io/gorm"
  11. "spider/internal/model"
  12. "spider/internal/notification"
  13. "spider/internal/plugin"
  14. "spider/internal/processor"
  15. proxypool "spider/internal/proxy"
  16. "spider/internal/store"
  17. )
  18. // StartRequest is the payload for starting a new task.
  19. type StartRequest struct {
  20. PluginName string `json:"plugin_name" binding:"required"`
  21. AutoClean *bool `json:"auto_clean"` // run processor after collection (default true)
  22. TargetGroup string `json:"target_group,omitempty"` // target a specific TG group/channel for collection
  23. ProxyID *uint `json:"proxy_id,omitempty"` // optional single proxy for this task
  24. ProxyMode string `json:"proxy_mode,omitempty"` // "single" (default) or "pool"
  25. MaxMerchants int `json:"max_merchants,omitempty"` // stop after collecting this many merchants (0 = unlimited)
  26. MaxDurationMins int `json:"max_duration_mins,omitempty"` // stop after this many minutes (0 = unlimited)
  27. ResumeSnapshot bool `json:"resume_snapshot,omitempty"` // resume from last snapshot
  28. }
// Manager manages plugin task lifecycle using goroutines.
// Replaces the asynq-based worker.
//
// Each started task runs in its own goroutine; the cancel function for its
// context is kept in running so the task can be stopped later.
type Manager struct {
	db        *gorm.DB             // task log persistence and proxy/raw-record queries
	redis     *redis.Client        // live logs, progress, stop flags and Pub/Sub events
	registry  *plugin.Registry     // lookup of collector plugins by name
	store     *store.Store         // raw merchant storage plus keyword/seed listings
	processor *processor.Processor // cleaning pipeline run after collection or on its own

	mu      sync.Mutex                  // guards running and proxyPool
	running map[uint]context.CancelFunc // taskID -> cancel

	notifier  *notification.Manager // optional; notify is a no-op while this is nil
	proxyPool *proxypool.Pool       // current active proxy pool (nil when not using pool mode)
}
  42. // NewManager creates a new task manager.
  43. func NewManager(db *gorm.DB, rdb *redis.Client, reg *plugin.Registry, s *store.Store, proc *processor.Processor) *Manager {
  44. return &Manager{
  45. db: db,
  46. redis: rdb,
  47. registry: reg,
  48. store: s,
  49. processor: proc,
  50. running: make(map[uint]context.CancelFunc),
  51. }
  52. }
  53. // StartTask validates, creates a TaskLog record, and runs the plugin in a goroutine.
  54. func (m *Manager) StartTask(req StartRequest) (*model.TaskLog, error) {
  55. // Validate plugin exists
  56. collector, err := m.registry.Get(req.PluginName)
  57. if err != nil {
  58. return nil, fmt.Errorf("unknown plugin: %s", req.PluginName)
  59. }
  60. // Check if same plugin is already running
  61. var count int64
  62. m.db.Model(&model.TaskLog{}).
  63. Where("plugin_name = ? AND status = ?", req.PluginName, "running").
  64. Count(&count)
  65. if count > 0 {
  66. return nil, fmt.Errorf("plugin %s is already running", req.PluginName)
  67. }
  68. // Resolve proxy configuration
  69. var proxyID *uint
  70. var proxyName string
  71. var proxyURL string
  72. var pool *proxypool.Pool
  73. if req.ProxyMode == "pool" {
  74. // Pool mode: load all enabled proxies
  75. var proxies []model.Proxy
  76. m.db.Where("enabled = ?", true).Find(&proxies)
  77. if len(proxies) == 0 {
  78. return nil, fmt.Errorf("代理池模式但没有可用的代理")
  79. }
  80. pool = proxypool.NewPool(3, 2*time.Minute)
  81. names := make([]string, 0, len(proxies))
  82. for _, p := range proxies {
  83. pool.Add(p.ID, p.Name, p.ProxyURL(), p.Region)
  84. names = append(names, p.Name)
  85. }
  86. proxyName = fmt.Sprintf("代理池(%d个)", len(proxies))
  87. m.mu.Lock()
  88. m.proxyPool = pool
  89. m.mu.Unlock()
  90. log.Printf("[task] using proxy pool with %d proxies: %v", len(proxies), names)
  91. } else if req.ProxyID != nil && *req.ProxyID > 0 {
  92. // Single proxy mode
  93. var proxy model.Proxy
  94. if err := m.db.First(&proxy, *req.ProxyID).Error; err == nil {
  95. proxyID = &proxy.ID
  96. proxyName = proxy.Name
  97. proxyURL = proxy.ProxyURL()
  98. }
  99. }
  100. // Create task log record
  101. now := time.Now()
  102. taskLog := &model.TaskLog{
  103. TaskType: "collect",
  104. PluginName: req.PluginName,
  105. Status: "running",
  106. StartedAt: &now,
  107. ProxyID: proxyID,
  108. ProxyName: proxyName,
  109. ProxyMode: req.ProxyMode,
  110. }
  111. if err := m.db.Create(taskLog).Error; err != nil {
  112. return nil, fmt.Errorf("create task log: %w", err)
  113. }
  114. // Build config for the plugin
  115. cfg, err := m.buildPluginConfig(req.PluginName)
  116. if err != nil {
  117. m.failTask(taskLog, err)
  118. return nil, err
  119. }
  120. // Set proxy in config
  121. if pool != nil {
  122. cfg["proxy_pool"] = pool
  123. // Also set an initial proxy_url for compatibility
  124. cfg["proxy_url"] = pool.Next()
  125. } else if proxyURL != "" {
  126. cfg["proxy_url"] = proxyURL
  127. }
  128. // Stop conditions
  129. if req.MaxMerchants > 0 {
  130. cfg["max_merchants"] = req.MaxMerchants
  131. }
  132. if req.MaxDurationMins > 0 {
  133. cfg["max_duration_mins"] = req.MaxDurationMins
  134. }
  135. cfg["resume_snapshot"] = req.ResumeSnapshot
  136. // If targeting a specific group, override seeds config
  137. if req.TargetGroup != "" {
  138. cfg["seeds"] = []string{req.TargetGroup}
  139. cfg["target_group"] = req.TargetGroup
  140. cfg["max_depth"] = 0 // don't BFS discover, just scrape this group
  141. cfg["max_channels"] = 1
  142. }
  143. // Start in goroutine
  144. ctx, cancel := context.WithCancel(context.Background())
  145. m.mu.Lock()
  146. m.running[taskLog.ID] = cancel
  147. m.mu.Unlock()
  148. // Default to true if not explicitly set
  149. autoClean := req.AutoClean == nil || *req.AutoClean
  150. go m.runTask(ctx, taskLog, collector, cfg, autoClean)
  151. return taskLog, nil
  152. }
  153. func (m *Manager) runTask(ctx context.Context, taskLog *model.TaskLog, collector plugin.Collector, cfg map[string]any, autoClean bool) {
  154. defer func() {
  155. m.mu.Lock()
  156. delete(m.running, taskLog.ID)
  157. m.proxyPool = nil // clear stale pool reference
  158. m.mu.Unlock()
  159. }()
  160. // Recover from panics to prevent crashing the entire server
  161. defer func() {
  162. if r := recover(); r != nil {
  163. log.Printf("[task] PANIC in task %d: %v", taskLog.ID, r)
  164. finishedAt := time.Now()
  165. m.db.Model(taskLog).Updates(map[string]any{
  166. "status": "failed",
  167. "finished_at": &finishedAt,
  168. "detail": fmt.Sprintf("panic: %v", r),
  169. })
  170. }
  171. }()
  172. // Create detail logger for this task
  173. dl := NewDetailLogger(m.db, taskLog.ID)
  174. defer dl.Close()
  175. collector.SetLogger(dl)
  176. m.writeLog(ctx, taskLog.ID, fmt.Sprintf("开始采集: %s", collector.Name()))
  177. merchantCount := 0
  178. errCount := 0
  179. // Callback: for each merchant found, save to raw table + group-member relationship
  180. callback := func(data plugin.MerchantData) {
  181. inserted, err := m.store.SaveRaw(data)
  182. if err != nil {
  183. errCount++
  184. log.Printf("[task] save raw error: %v", err)
  185. return
  186. }
  187. // Record group-member relationship if source is a TG group/channel
  188. if data.GroupUsername != "" && data.TgUsername != "" {
  189. m.store.SaveGroupMember(data.GroupUsername, data.TgUsername, data.GroupTitle, data.SourceType, taskLog.ID)
  190. }
  191. if inserted {
  192. merchantCount++
  193. if merchantCount%10 == 0 {
  194. m.writeProgress(ctx, taskLog.ID, collector.Name(), merchantCount, 0,
  195. fmt.Sprintf("已采集 %d 个商户", merchantCount))
  196. }
  197. }
  198. }
  199. // Run the collector
  200. runErr := collector.Run(ctx, cfg, callback)
  201. // Check if stopped
  202. if m.isStopRequested(ctx, taskLog.ID) || ctx.Err() != nil {
  203. m.writeLog(ctx, taskLog.ID, "任务已停止")
  204. finishedAt := time.Now()
  205. m.db.Model(taskLog).Updates(map[string]any{
  206. "status": "stopped",
  207. "finished_at": &finishedAt,
  208. "merchants_added": merchantCount,
  209. "errors_count": errCount,
  210. })
  211. return
  212. }
  213. if runErr != nil {
  214. m.failTask(taskLog, runErr)
  215. m.writeLog(ctx, taskLog.ID, "采集失败: "+runErr.Error())
  216. m.notify("task_failed", "任务失败", fmt.Sprintf("采集任务 #%d (%s) 失败: %s", taskLog.ID, collector.Name(), runErr.Error()))
  217. return
  218. }
  219. m.writeLog(ctx, taskLog.ID, fmt.Sprintf("采集完成: 新增 %d 个商户", merchantCount))
  220. // Auto-clean: run processor if there are any unprocessed raw records
  221. var pendingRaw int64
  222. m.db.Model(&model.MerchantRaw{}).Where("status = ?", "raw").Count(&pendingRaw)
  223. if autoClean && pendingRaw > 0 {
  224. m.writeLog(ctx, taskLog.ID, "开始清洗流程...")
  225. m.writeProgress(ctx, taskLog.ID, "clean", 0, 0, "清洗中...")
  226. m.processor.SetProgressFn(func(step string, current, total int, msg string) {
  227. m.writeProgress(ctx, taskLog.ID, step, current, total, msg)
  228. m.writeLog(ctx, taskLog.ID, fmt.Sprintf("[%s] %d/%d %s", step, current, total, msg))
  229. })
  230. m.processor.SetLogger(dl)
  231. procResult, procErr := m.processor.Process(ctx)
  232. if procErr != nil {
  233. m.writeLog(ctx, taskLog.ID, "清洗失败: "+procErr.Error())
  234. } else {
  235. m.writeLog(ctx, taskLog.ID, fmt.Sprintf("清洗完成: Hot=%d, Warm=%d, Cold=%d",
  236. procResult.HotCount, procResult.WarmCount, procResult.ColdCount))
  237. }
  238. }
  239. // Complete
  240. finishedAt := time.Now()
  241. detail := fmt.Sprintf("采集 %d 个商户, 错误 %d 次", merchantCount, errCount)
  242. if pool, ok := cfg["proxy_pool"].(*proxypool.Pool); ok && pool != nil {
  243. detail += " | " + pool.Summary()
  244. }
  245. m.db.Model(taskLog).Updates(map[string]any{
  246. "status": "completed",
  247. "finished_at": &finishedAt,
  248. "merchants_added": merchantCount,
  249. "errors_count": errCount,
  250. "detail": detail,
  251. })
  252. m.writeProgress(ctx, taskLog.ID, "done", 100, 100, "任务完成")
  253. m.writeLog(ctx, taskLog.ID, "任务完成")
  254. log.Printf("[task] task %d completed: %s", taskLog.ID, detail)
  255. m.notify("task_completed", "任务完成", fmt.Sprintf("采集任务 #%d (%s): %s", taskLog.ID, collector.Name(), detail))
  256. }
  257. // StartClean runs the processor independently (not tied to a plugin).
  258. func (m *Manager) StartClean() (*model.TaskLog, error) {
  259. var count int64
  260. m.db.Model(&model.TaskLog{}).
  261. Where("task_type = ? AND status = ?", "clean", "running").
  262. Count(&count)
  263. if count > 0 {
  264. return nil, fmt.Errorf("clean task is already running")
  265. }
  266. now := time.Now()
  267. taskLog := &model.TaskLog{
  268. TaskType: "clean",
  269. PluginName: "",
  270. Status: "running",
  271. StartedAt: &now,
  272. }
  273. if err := m.db.Create(taskLog).Error; err != nil {
  274. return nil, err
  275. }
  276. ctx, cancel := context.WithCancel(context.Background())
  277. m.mu.Lock()
  278. m.running[taskLog.ID] = cancel
  279. m.mu.Unlock()
  280. go func() {
  281. defer func() {
  282. m.mu.Lock()
  283. delete(m.running, taskLog.ID)
  284. m.mu.Unlock()
  285. }()
  286. defer func() {
  287. if r := recover(); r != nil {
  288. log.Printf("[task] PANIC in clean task %d: %v", taskLog.ID, r)
  289. finishedAt := time.Now()
  290. m.db.Model(taskLog).Updates(map[string]any{
  291. "status": "failed",
  292. "finished_at": &finishedAt,
  293. "detail": fmt.Sprintf("panic: %v", r),
  294. })
  295. }
  296. }()
  297. m.writeLog(ctx, taskLog.ID, "开始独立清洗任务")
  298. m.processor.SetProgressFn(func(step string, current, total int, msg string) {
  299. m.writeProgress(ctx, taskLog.ID, step, current, total, msg)
  300. })
  301. result, err := m.processor.Process(ctx)
  302. finishedAt := time.Now()
  303. if err != nil {
  304. m.db.Model(taskLog).Updates(map[string]any{
  305. "status": "failed",
  306. "finished_at": &finishedAt,
  307. "detail": err.Error(),
  308. })
  309. m.notify("task_failed", "清洗任务失败", fmt.Sprintf("清洗任务 #%d 失败: %s", taskLog.ID, err.Error()))
  310. return
  311. }
  312. detail := fmt.Sprintf("输入 %d, Hot=%d, Warm=%d, Cold=%d",
  313. result.InputCount, result.HotCount, result.WarmCount, result.ColdCount)
  314. m.db.Model(taskLog).Updates(map[string]any{
  315. "status": "completed",
  316. "finished_at": &finishedAt,
  317. "items_processed": result.InputCount,
  318. "merchants_added": result.OutputCount,
  319. "detail": detail,
  320. })
  321. m.writeLog(ctx, taskLog.ID, "清洗完成: "+detail)
  322. m.notify("task_completed", "清洗任务完成", fmt.Sprintf("清洗任务 #%d: %s", taskLog.ID, detail))
  323. // Notify for new hot merchants
  324. if result.HotCount > 0 {
  325. m.notify("new_hot_merchant", "发现优质商户", fmt.Sprintf("清洗任务 #%d 发现 %d 个优质商户", taskLog.ID, result.HotCount))
  326. }
  327. }()
  328. return taskLog, nil
  329. }
  330. // StopAll cancels all running tasks (used during graceful shutdown).
  331. func (m *Manager) StopAll() {
  332. m.mu.Lock()
  333. running := make(map[uint]context.CancelFunc, len(m.running))
  334. for id, cancel := range m.running {
  335. running[id] = cancel
  336. }
  337. m.mu.Unlock()
  338. for id, cancel := range running {
  339. log.Printf("[task] stopping task %d for shutdown", id)
  340. cancel()
  341. finishedAt := time.Now()
  342. m.db.Model(&model.TaskLog{}).Where("id = ? AND status = ?", id, "running").
  343. Updates(map[string]any{
  344. "status": "stopped",
  345. "finished_at": &finishedAt,
  346. "detail": "服务关闭,任务停止",
  347. })
  348. }
  349. // Clear proxy pool reference
  350. m.mu.Lock()
  351. m.proxyPool = nil
  352. m.mu.Unlock()
  353. }
  354. // StopTask cancels a running task.
  355. func (m *Manager) StopTask(taskID uint) error {
  356. // Set Redis stop signal
  357. key := fmt.Sprintf("spider:task:stop:%d", taskID)
  358. m.redis.Set(context.Background(), key, "1", time.Hour)
  359. // Cancel the goroutine context
  360. m.mu.Lock()
  361. cancel, ok := m.running[taskID]
  362. m.mu.Unlock()
  363. if ok {
  364. cancel()
  365. }
  366. // Also try to stop the collector
  367. var taskLog model.TaskLog
  368. if err := m.db.First(&taskLog, taskID).Error; err == nil && taskLog.PluginName != "" {
  369. if collector, err := m.registry.Get(taskLog.PluginName); err == nil {
  370. collector.Stop()
  371. }
  372. }
  373. return nil
  374. }
  375. // GetProgress reads live progress from Redis.
  376. func (m *Manager) GetProgress(taskID uint) map[string]any {
  377. key := fmt.Sprintf("spider:task:progress:%d", taskID)
  378. vals, err := m.redis.HGetAll(context.Background(), key).Result()
  379. if err != nil {
  380. return nil
  381. }
  382. result := make(map[string]any)
  383. for k, v := range vals {
  384. result[k] = v
  385. }
  386. return result
  387. }
  388. // GetLogs reads task logs from Redis.
  389. func (m *Manager) GetLogs(taskID uint) []string {
  390. key := fmt.Sprintf("spider:task:logs:%d", taskID)
  391. logs, err := m.redis.LRange(context.Background(), key, 0, -1).Result()
  392. if err != nil {
  393. return nil
  394. }
  395. return logs
  396. }
  397. // buildPluginConfig builds the config map for a plugin from the DB.
  398. func (m *Manager) buildPluginConfig(pluginName string) (map[string]any, error) {
  399. cfg := make(map[string]any)
  400. switch pluginName {
  401. case "web_collector":
  402. keywords, err := m.store.ListEnabledKeywords()
  403. if err != nil {
  404. return nil, err
  405. }
  406. kws := make([]string, 0, len(keywords))
  407. for _, k := range keywords {
  408. kws = append(kws, k.Keyword)
  409. }
  410. cfg["keywords"] = kws
  411. case "tg_collector":
  412. seeds, err := m.store.ListSeeds()
  413. if err != nil {
  414. return nil, err
  415. }
  416. seedNames := make([]string, 0, len(seeds))
  417. for _, s := range seeds {
  418. seedNames = append(seedNames, s.Keyword)
  419. }
  420. cfg["seeds"] = seedNames
  421. cfg["max_depth"] = 3
  422. cfg["max_channels"] = 500
  423. cfg["message_limit"] = 500
  424. case "github_collector":
  425. keywords, err := m.store.ListEnabledKeywords()
  426. if err != nil {
  427. return nil, err
  428. }
  429. kws := make([]string, 0, len(keywords))
  430. for _, k := range keywords {
  431. kws = append(kws, k.Keyword)
  432. }
  433. cfg["keywords"] = kws
  434. cfg["repos_limit"] = 50
  435. }
  436. return cfg, nil
  437. }
  438. func (m *Manager) failTask(taskLog *model.TaskLog, err error) {
  439. finishedAt := time.Now()
  440. m.db.Model(taskLog).Updates(map[string]any{
  441. "status": "failed",
  442. "finished_at": &finishedAt,
  443. "detail": err.Error(),
  444. })
  445. }
  446. func (m *Manager) isStopRequested(ctx context.Context, taskID uint) bool {
  447. key := fmt.Sprintf("spider:task:stop:%d", taskID)
  448. val, _ := m.redis.Get(ctx, key).Result()
  449. return val == "1"
  450. }
  451. // TaskPubSubChannel returns the Redis Pub/Sub channel name for a task.
  452. func TaskPubSubChannel(taskID uint) string {
  453. return fmt.Sprintf("spider:task:events:%d", taskID)
  454. }
  455. func (m *Manager) writeLog(ctx context.Context, taskID uint, msg string) {
  456. key := fmt.Sprintf("spider:task:logs:%d", taskID)
  457. ts := time.Now().Format("15:04:05")
  458. line := fmt.Sprintf("[%s] %s", ts, msg)
  459. m.redis.RPush(ctx, key, line)
  460. m.redis.LTrim(ctx, key, -500, -1)
  461. m.redis.Expire(ctx, key, 24*time.Hour)
  462. // Publish to Pub/Sub for real-time WebSocket delivery
  463. m.redis.Publish(ctx, TaskPubSubChannel(taskID), line)
  464. }
  465. func (m *Manager) writeProgress(ctx context.Context, taskID uint, phase string, current, total int, message string) {
  466. key := fmt.Sprintf("spider:task:progress:%d", taskID)
  467. now := time.Now().UTC().Format(time.RFC3339)
  468. fields := map[string]any{
  469. "phase": phase,
  470. "current": current,
  471. "total": total,
  472. "message": message,
  473. "updated_at": now,
  474. }
  475. b, _ := json.Marshal(fields)
  476. m.redis.Set(ctx, key, string(b), 24*time.Hour)
  477. // Publish progress to Pub/Sub
  478. m.redis.Publish(ctx, TaskPubSubChannel(taskID), fmt.Sprintf("[进度] %s", message))
  479. }
  480. // GetProxyPool returns the current active proxy pool (may be nil).
  481. func (m *Manager) GetProxyPool() *proxypool.Pool {
  482. m.mu.Lock()
  483. defer m.mu.Unlock()
  484. return m.proxyPool
  485. }
  486. // GetRedis returns the Redis client for WebSocket Pub/Sub subscription.
  487. func (m *Manager) GetRedis() *redis.Client {
  488. return m.redis
  489. }
  490. // SetNotifier sets the notification manager for event dispatching.
  491. func (m *Manager) SetNotifier(n *notification.Manager) {
  492. m.notifier = n
  493. }
  494. func (m *Manager) notify(eventType, title, msg string) {
  495. if m.notifier == nil {
  496. return
  497. }
  498. m.notifier.Send(notification.Event{
  499. Type: eventType,
  500. Title: title,
  501. Message: msg,
  502. })
  503. }
  504. // ListPlugins returns all registered plugin names.
  505. func (m *Manager) ListPlugins() []string {
  506. return m.registry.List()
  507. }