manager.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569
  1. package task
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "log"
  7. "sync"
  8. "time"
  9. "github.com/redis/go-redis/v9"
  10. "gorm.io/gorm"
  11. "spider/internal/model"
  12. "spider/internal/notification"
  13. "spider/internal/plugin"
  14. "spider/internal/processor"
  15. proxypool "spider/internal/proxy"
  16. "spider/internal/store"
  17. )
  18. // StartRequest is the payload for starting a new task.
  19. type StartRequest struct {
  20. PluginName string `json:"plugin_name" binding:"required"`
  21. AutoClean *bool `json:"auto_clean"` // run processor after collection (default true)
  22. TargetGroup string `json:"target_group,omitempty"` // target a specific TG group/channel for collection
  23. ProxyID *uint `json:"proxy_id,omitempty"` // optional single proxy for this task
  24. ProxyMode string `json:"proxy_mode,omitempty"` // "single" (default) or "pool"
  25. }
  26. // Manager manages plugin task lifecycle using goroutines.
  27. // Replaces the asynq-based worker.
  28. type Manager struct {
  29. db *gorm.DB
  30. redis *redis.Client
  31. registry *plugin.Registry
  32. store *store.Store
  33. processor *processor.Processor
  34. mu sync.Mutex
  35. running map[uint]context.CancelFunc // taskID -> cancel
  36. notifier *notification.Manager
  37. proxyPool *proxypool.Pool // current active proxy pool (nil when not using pool mode)
  38. }
  39. // NewManager creates a new task manager.
  40. func NewManager(db *gorm.DB, rdb *redis.Client, reg *plugin.Registry, s *store.Store, proc *processor.Processor) *Manager {
  41. return &Manager{
  42. db: db,
  43. redis: rdb,
  44. registry: reg,
  45. store: s,
  46. processor: proc,
  47. running: make(map[uint]context.CancelFunc),
  48. }
  49. }
  50. // StartTask validates, creates a TaskLog record, and runs the plugin in a goroutine.
  51. func (m *Manager) StartTask(req StartRequest) (*model.TaskLog, error) {
  52. // Validate plugin exists
  53. collector, err := m.registry.Get(req.PluginName)
  54. if err != nil {
  55. return nil, fmt.Errorf("unknown plugin: %s", req.PluginName)
  56. }
  57. // Check if same plugin is already running
  58. var count int64
  59. m.db.Model(&model.TaskLog{}).
  60. Where("plugin_name = ? AND status = ?", req.PluginName, "running").
  61. Count(&count)
  62. if count > 0 {
  63. return nil, fmt.Errorf("plugin %s is already running", req.PluginName)
  64. }
  65. // Resolve proxy configuration
  66. var proxyID *uint
  67. var proxyName string
  68. var proxyURL string
  69. var pool *proxypool.Pool
  70. if req.ProxyMode == "pool" {
  71. // Pool mode: load all enabled proxies
  72. var proxies []model.Proxy
  73. m.db.Where("enabled = ?", true).Find(&proxies)
  74. if len(proxies) == 0 {
  75. return nil, fmt.Errorf("代理池模式但没有可用的代理")
  76. }
  77. pool = proxypool.NewPool(3, 2*time.Minute)
  78. names := make([]string, 0, len(proxies))
  79. for _, p := range proxies {
  80. pool.Add(p.ID, p.Name, p.ProxyURL(), p.Region)
  81. names = append(names, p.Name)
  82. }
  83. proxyName = fmt.Sprintf("代理池(%d个)", len(proxies))
  84. m.mu.Lock()
  85. m.proxyPool = pool
  86. m.mu.Unlock()
  87. log.Printf("[task] using proxy pool with %d proxies: %v", len(proxies), names)
  88. } else if req.ProxyID != nil && *req.ProxyID > 0 {
  89. // Single proxy mode
  90. var proxy model.Proxy
  91. if err := m.db.First(&proxy, *req.ProxyID).Error; err == nil {
  92. proxyID = &proxy.ID
  93. proxyName = proxy.Name
  94. proxyURL = proxy.ProxyURL()
  95. }
  96. }
  97. // Create task log record
  98. now := time.Now()
  99. taskLog := &model.TaskLog{
  100. TaskType: "collect",
  101. PluginName: req.PluginName,
  102. Status: "running",
  103. StartedAt: &now,
  104. ProxyID: proxyID,
  105. ProxyName: proxyName,
  106. ProxyMode: req.ProxyMode,
  107. }
  108. if err := m.db.Create(taskLog).Error; err != nil {
  109. return nil, fmt.Errorf("create task log: %w", err)
  110. }
  111. // Build config for the plugin
  112. cfg, err := m.buildPluginConfig(req.PluginName)
  113. if err != nil {
  114. m.failTask(taskLog, err)
  115. return nil, err
  116. }
  117. // Set proxy in config
  118. if pool != nil {
  119. cfg["proxy_pool"] = pool
  120. // Also set an initial proxy_url for compatibility
  121. cfg["proxy_url"] = pool.Next()
  122. } else if proxyURL != "" {
  123. cfg["proxy_url"] = proxyURL
  124. }
  125. // If targeting a specific group, override seeds config
  126. if req.TargetGroup != "" {
  127. cfg["seeds"] = []string{req.TargetGroup}
  128. cfg["target_group"] = req.TargetGroup
  129. cfg["max_depth"] = 0 // don't BFS discover, just scrape this group
  130. cfg["max_channels"] = 1
  131. }
  132. // Start in goroutine
  133. ctx, cancel := context.WithCancel(context.Background())
  134. m.mu.Lock()
  135. m.running[taskLog.ID] = cancel
  136. m.mu.Unlock()
  137. // Default to true if not explicitly set
  138. autoClean := req.AutoClean == nil || *req.AutoClean
  139. go m.runTask(ctx, taskLog, collector, cfg, autoClean)
  140. return taskLog, nil
  141. }
  142. func (m *Manager) runTask(ctx context.Context, taskLog *model.TaskLog, collector plugin.Collector, cfg map[string]any, autoClean bool) {
  143. defer func() {
  144. m.mu.Lock()
  145. delete(m.running, taskLog.ID)
  146. m.proxyPool = nil // clear stale pool reference
  147. m.mu.Unlock()
  148. }()
  149. // Recover from panics to prevent crashing the entire server
  150. defer func() {
  151. if r := recover(); r != nil {
  152. log.Printf("[task] PANIC in task %d: %v", taskLog.ID, r)
  153. finishedAt := time.Now()
  154. m.db.Model(taskLog).Updates(map[string]any{
  155. "status": "failed",
  156. "finished_at": &finishedAt,
  157. "detail": fmt.Sprintf("panic: %v", r),
  158. })
  159. }
  160. }()
  161. // Create detail logger for this task
  162. dl := NewDetailLogger(m.db, taskLog.ID)
  163. defer dl.Close()
  164. collector.SetLogger(dl)
  165. m.writeLog(ctx, taskLog.ID, fmt.Sprintf("开始采集: %s", collector.Name()))
  166. merchantCount := 0
  167. errCount := 0
  168. // Callback: for each merchant found, save to raw table + group-member relationship
  169. callback := func(data plugin.MerchantData) {
  170. inserted, err := m.store.SaveRaw(data)
  171. if err != nil {
  172. errCount++
  173. log.Printf("[task] save raw error: %v", err)
  174. return
  175. }
  176. // Record group-member relationship if source is a TG group/channel
  177. if data.GroupUsername != "" && data.TgUsername != "" {
  178. m.store.SaveGroupMember(data.GroupUsername, data.TgUsername, data.GroupTitle, data.SourceType, taskLog.ID)
  179. }
  180. if inserted {
  181. merchantCount++
  182. if merchantCount%10 == 0 {
  183. m.writeProgress(ctx, taskLog.ID, collector.Name(), merchantCount, 0,
  184. fmt.Sprintf("已采集 %d 个商户", merchantCount))
  185. }
  186. }
  187. }
  188. // Run the collector
  189. runErr := collector.Run(ctx, cfg, callback)
  190. // Check if stopped
  191. if m.isStopRequested(ctx, taskLog.ID) || ctx.Err() != nil {
  192. m.writeLog(ctx, taskLog.ID, "任务已停止")
  193. finishedAt := time.Now()
  194. m.db.Model(taskLog).Updates(map[string]any{
  195. "status": "stopped",
  196. "finished_at": &finishedAt,
  197. "merchants_added": merchantCount,
  198. "errors_count": errCount,
  199. })
  200. return
  201. }
  202. if runErr != nil {
  203. m.failTask(taskLog, runErr)
  204. m.writeLog(ctx, taskLog.ID, "采集失败: "+runErr.Error())
  205. m.notify("task_failed", "任务失败", fmt.Sprintf("采集任务 #%d (%s) 失败: %s", taskLog.ID, collector.Name(), runErr.Error()))
  206. return
  207. }
  208. m.writeLog(ctx, taskLog.ID, fmt.Sprintf("采集完成: 新增 %d 个商户", merchantCount))
  209. // Auto-clean: run processor on new raw records
  210. if autoClean && merchantCount > 0 {
  211. m.writeLog(ctx, taskLog.ID, "开始清洗流程...")
  212. m.writeProgress(ctx, taskLog.ID, "clean", 0, 0, "清洗中...")
  213. m.processor.SetProgressFn(func(step string, current, total int, msg string) {
  214. m.writeProgress(ctx, taskLog.ID, step, current, total, msg)
  215. m.writeLog(ctx, taskLog.ID, fmt.Sprintf("[%s] %d/%d %s", step, current, total, msg))
  216. })
  217. m.processor.SetLogger(dl)
  218. procResult, procErr := m.processor.Process(ctx)
  219. if procErr != nil {
  220. m.writeLog(ctx, taskLog.ID, "清洗失败: "+procErr.Error())
  221. } else {
  222. m.writeLog(ctx, taskLog.ID, fmt.Sprintf("清洗完成: Hot=%d, Warm=%d, Cold=%d",
  223. procResult.HotCount, procResult.WarmCount, procResult.ColdCount))
  224. }
  225. }
  226. // Complete
  227. finishedAt := time.Now()
  228. detail := fmt.Sprintf("采集 %d 个商户, 错误 %d 次", merchantCount, errCount)
  229. if pool, ok := cfg["proxy_pool"].(*proxypool.Pool); ok && pool != nil {
  230. detail += " | " + pool.Summary()
  231. }
  232. m.db.Model(taskLog).Updates(map[string]any{
  233. "status": "completed",
  234. "finished_at": &finishedAt,
  235. "merchants_added": merchantCount,
  236. "errors_count": errCount,
  237. "detail": detail,
  238. })
  239. m.writeProgress(ctx, taskLog.ID, "done", 100, 100, "任务完成")
  240. m.writeLog(ctx, taskLog.ID, "任务完成")
  241. log.Printf("[task] task %d completed: %s", taskLog.ID, detail)
  242. m.notify("task_completed", "任务完成", fmt.Sprintf("采集任务 #%d (%s): %s", taskLog.ID, collector.Name(), detail))
  243. }
  244. // StartClean runs the processor independently (not tied to a plugin).
  245. func (m *Manager) StartClean() (*model.TaskLog, error) {
  246. var count int64
  247. m.db.Model(&model.TaskLog{}).
  248. Where("task_type = ? AND status = ?", "clean", "running").
  249. Count(&count)
  250. if count > 0 {
  251. return nil, fmt.Errorf("clean task is already running")
  252. }
  253. now := time.Now()
  254. taskLog := &model.TaskLog{
  255. TaskType: "clean",
  256. PluginName: "",
  257. Status: "running",
  258. StartedAt: &now,
  259. }
  260. if err := m.db.Create(taskLog).Error; err != nil {
  261. return nil, err
  262. }
  263. ctx, cancel := context.WithCancel(context.Background())
  264. m.mu.Lock()
  265. m.running[taskLog.ID] = cancel
  266. m.mu.Unlock()
  267. go func() {
  268. defer func() {
  269. m.mu.Lock()
  270. delete(m.running, taskLog.ID)
  271. m.mu.Unlock()
  272. }()
  273. defer func() {
  274. if r := recover(); r != nil {
  275. log.Printf("[task] PANIC in clean task %d: %v", taskLog.ID, r)
  276. finishedAt := time.Now()
  277. m.db.Model(taskLog).Updates(map[string]any{
  278. "status": "failed",
  279. "finished_at": &finishedAt,
  280. "detail": fmt.Sprintf("panic: %v", r),
  281. })
  282. }
  283. }()
  284. m.writeLog(ctx, taskLog.ID, "开始独立清洗任务")
  285. m.processor.SetProgressFn(func(step string, current, total int, msg string) {
  286. m.writeProgress(ctx, taskLog.ID, step, current, total, msg)
  287. })
  288. result, err := m.processor.Process(ctx)
  289. finishedAt := time.Now()
  290. if err != nil {
  291. m.db.Model(taskLog).Updates(map[string]any{
  292. "status": "failed",
  293. "finished_at": &finishedAt,
  294. "detail": err.Error(),
  295. })
  296. m.notify("task_failed", "清洗任务失败", fmt.Sprintf("清洗任务 #%d 失败: %s", taskLog.ID, err.Error()))
  297. return
  298. }
  299. detail := fmt.Sprintf("输入 %d, Hot=%d, Warm=%d, Cold=%d",
  300. result.InputCount, result.HotCount, result.WarmCount, result.ColdCount)
  301. m.db.Model(taskLog).Updates(map[string]any{
  302. "status": "completed",
  303. "finished_at": &finishedAt,
  304. "items_processed": result.InputCount,
  305. "merchants_added": result.OutputCount,
  306. "detail": detail,
  307. })
  308. m.writeLog(ctx, taskLog.ID, "清洗完成: "+detail)
  309. m.notify("task_completed", "清洗任务完成", fmt.Sprintf("清洗任务 #%d: %s", taskLog.ID, detail))
  310. // Notify for new hot merchants
  311. if result.HotCount > 0 {
  312. m.notify("new_hot_merchant", "发现优质商户", fmt.Sprintf("清洗任务 #%d 发现 %d 个优质商户", taskLog.ID, result.HotCount))
  313. }
  314. }()
  315. return taskLog, nil
  316. }
  317. // StopAll cancels all running tasks (used during graceful shutdown).
  318. func (m *Manager) StopAll() {
  319. m.mu.Lock()
  320. running := make(map[uint]context.CancelFunc, len(m.running))
  321. for id, cancel := range m.running {
  322. running[id] = cancel
  323. }
  324. m.mu.Unlock()
  325. for id, cancel := range running {
  326. log.Printf("[task] stopping task %d for shutdown", id)
  327. cancel()
  328. finishedAt := time.Now()
  329. m.db.Model(&model.TaskLog{}).Where("id = ? AND status = ?", id, "running").
  330. Updates(map[string]any{
  331. "status": "stopped",
  332. "finished_at": &finishedAt,
  333. "detail": "服务关闭,任务停止",
  334. })
  335. }
  336. // Clear proxy pool reference
  337. m.mu.Lock()
  338. m.proxyPool = nil
  339. m.mu.Unlock()
  340. }
  341. // StopTask cancels a running task.
  342. func (m *Manager) StopTask(taskID uint) error {
  343. // Set Redis stop signal
  344. key := fmt.Sprintf("spider:task:stop:%d", taskID)
  345. m.redis.Set(context.Background(), key, "1", time.Hour)
  346. // Cancel the goroutine context
  347. m.mu.Lock()
  348. cancel, ok := m.running[taskID]
  349. m.mu.Unlock()
  350. if ok {
  351. cancel()
  352. }
  353. // Also try to stop the collector
  354. var taskLog model.TaskLog
  355. if err := m.db.First(&taskLog, taskID).Error; err == nil && taskLog.PluginName != "" {
  356. if collector, err := m.registry.Get(taskLog.PluginName); err == nil {
  357. collector.Stop()
  358. }
  359. }
  360. return nil
  361. }
  362. // GetProgress reads live progress from Redis.
  363. func (m *Manager) GetProgress(taskID uint) map[string]any {
  364. key := fmt.Sprintf("spider:task:progress:%d", taskID)
  365. vals, err := m.redis.HGetAll(context.Background(), key).Result()
  366. if err != nil {
  367. return nil
  368. }
  369. result := make(map[string]any)
  370. for k, v := range vals {
  371. result[k] = v
  372. }
  373. return result
  374. }
  375. // GetLogs reads task logs from Redis.
  376. func (m *Manager) GetLogs(taskID uint) []string {
  377. key := fmt.Sprintf("spider:task:logs:%d", taskID)
  378. logs, err := m.redis.LRange(context.Background(), key, 0, -1).Result()
  379. if err != nil {
  380. return nil
  381. }
  382. return logs
  383. }
  384. // buildPluginConfig builds the config map for a plugin from the DB.
  385. func (m *Manager) buildPluginConfig(pluginName string) (map[string]any, error) {
  386. cfg := make(map[string]any)
  387. switch pluginName {
  388. case "web_collector":
  389. keywords, err := m.store.ListEnabledKeywords()
  390. if err != nil {
  391. return nil, err
  392. }
  393. kws := make([]string, 0, len(keywords))
  394. for _, k := range keywords {
  395. kws = append(kws, k.Keyword)
  396. }
  397. cfg["keywords"] = kws
  398. case "tg_collector":
  399. seeds, err := m.store.ListSeeds()
  400. if err != nil {
  401. return nil, err
  402. }
  403. seedNames := make([]string, 0, len(seeds))
  404. for _, s := range seeds {
  405. seedNames = append(seedNames, s.Keyword)
  406. }
  407. cfg["seeds"] = seedNames
  408. cfg["max_depth"] = 3
  409. cfg["max_channels"] = 500
  410. cfg["message_limit"] = 500
  411. case "github_collector":
  412. keywords, err := m.store.ListEnabledKeywords()
  413. if err != nil {
  414. return nil, err
  415. }
  416. kws := make([]string, 0, len(keywords))
  417. for _, k := range keywords {
  418. kws = append(kws, k.Keyword)
  419. }
  420. cfg["keywords"] = kws
  421. cfg["repos_limit"] = 50
  422. }
  423. return cfg, nil
  424. }
  425. func (m *Manager) failTask(taskLog *model.TaskLog, err error) {
  426. finishedAt := time.Now()
  427. m.db.Model(taskLog).Updates(map[string]any{
  428. "status": "failed",
  429. "finished_at": &finishedAt,
  430. "detail": err.Error(),
  431. })
  432. }
  433. func (m *Manager) isStopRequested(ctx context.Context, taskID uint) bool {
  434. key := fmt.Sprintf("spider:task:stop:%d", taskID)
  435. val, _ := m.redis.Get(ctx, key).Result()
  436. return val == "1"
  437. }
  438. // TaskPubSubChannel returns the Redis Pub/Sub channel name for a task.
  439. func TaskPubSubChannel(taskID uint) string {
  440. return fmt.Sprintf("spider:task:events:%d", taskID)
  441. }
  442. func (m *Manager) writeLog(ctx context.Context, taskID uint, msg string) {
  443. key := fmt.Sprintf("spider:task:logs:%d", taskID)
  444. ts := time.Now().Format("15:04:05")
  445. line := fmt.Sprintf("[%s] %s", ts, msg)
  446. m.redis.RPush(ctx, key, line)
  447. m.redis.LTrim(ctx, key, -500, -1)
  448. m.redis.Expire(ctx, key, 24*time.Hour)
  449. // Publish to Pub/Sub for real-time WebSocket delivery
  450. m.redis.Publish(ctx, TaskPubSubChannel(taskID), line)
  451. }
  452. func (m *Manager) writeProgress(ctx context.Context, taskID uint, phase string, current, total int, message string) {
  453. key := fmt.Sprintf("spider:task:progress:%d", taskID)
  454. now := time.Now().UTC().Format(time.RFC3339)
  455. fields := map[string]any{
  456. "phase": phase,
  457. "current": current,
  458. "total": total,
  459. "message": message,
  460. "updated_at": now,
  461. }
  462. b, _ := json.Marshal(fields)
  463. m.redis.Set(ctx, key, string(b), 24*time.Hour)
  464. // Publish progress to Pub/Sub
  465. m.redis.Publish(ctx, TaskPubSubChannel(taskID), fmt.Sprintf("[进度] %s", message))
  466. }
  467. // GetProxyPool returns the current active proxy pool (may be nil).
  468. func (m *Manager) GetProxyPool() *proxypool.Pool {
  469. m.mu.Lock()
  470. defer m.mu.Unlock()
  471. return m.proxyPool
  472. }
  473. // GetRedis returns the Redis client for WebSocket Pub/Sub subscription.
  474. func (m *Manager) GetRedis() *redis.Client {
  475. return m.redis
  476. }
  477. // SetNotifier sets the notification manager for event dispatching.
  478. func (m *Manager) SetNotifier(n *notification.Manager) {
  479. m.notifier = n
  480. }
  481. func (m *Manager) notify(eventType, title, msg string) {
  482. if m.notifier == nil {
  483. return
  484. }
  485. m.notifier.Send(notification.Event{
  486. Type: eventType,
  487. Title: title,
  488. Message: msg,
  489. })
  490. }
  491. // ListPlugins returns all registered plugin names.
  492. func (m *Manager) ListPlugins() []string {
  493. return m.registry.List()
  494. }