pipeline.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. package pipeline
  2. import (
  3. "context"
  4. "fmt"
  5. "log"
  6. "spider/internal/model"
  7. "github.com/redis/go-redis/v9"
  8. "gorm.io/gorm"
  9. )
  // fullPhaseOrder is the sequence in which a "full" task runs its phases.
  // Run walks it front to back, consulting ShouldSkip with Options.SkipPhases
  // before each entry.
  var fullPhaseOrder = []string{
  	"discover", "search", "github", "scrape",
  	"crawl", "clean", "score",
  }
  20. // Runner Pipeline 调度器
  21. type Runner struct {
  22. db *gorm.DB
  23. redis *redis.Client
  24. phases map[string]Phase // 注册的 phase,key 是 phase 名称
  25. reporter ProgressReporter
  26. }
  27. // NewRunner creates a new pipeline Runner.
  28. func NewRunner(db *gorm.DB, rdb *redis.Client) *Runner {
  29. return &Runner{
  30. db: db,
  31. redis: rdb,
  32. phases: make(map[string]Phase),
  33. }
  34. }
  35. // RegisterPhase 注册一个 phase 实现
  36. func (r *Runner) RegisterPhase(p Phase) {
  37. r.phases[p.Name()] = p
  38. }
  39. // SetProgressReporter 设置进度上报函数
  40. func (r *Runner) SetProgressReporter(fn ProgressReporter) {
  41. r.reporter = fn
  42. }
  43. // report calls the reporter if one is set; otherwise logs to stderr.
  44. func (r *Runner) report(phase string, current, total int, message string) {
  45. if r.reporter != nil {
  46. r.reporter(phase, current, total, message)
  47. }
  48. }
  49. // Run 执行 pipeline
  50. // task.TaskType: "full" | "discover" | "search" | "github" | "scrape" | "crawl" | "clean" | "score"
  51. // full 类型按顺序执行所有未跳过的 phase
  52. // 单阶段类型直接执行对应 phase
  53. func (r *Runner) Run(ctx context.Context, task *model.Task, opts *Options) error {
  54. if task.TaskType == "full" {
  55. for _, phaseName := range fullPhaseOrder {
  56. if isContextDone(ctx) {
  57. return fmt.Errorf("pipeline cancelled before phase %s", phaseName)
  58. }
  59. if ShouldSkip(phaseName, opts.SkipPhases) {
  60. log.Printf("[pipeline] skipping phase=%s (in SkipPhases)", phaseName)
  61. continue
  62. }
  63. r.report(phaseName, 0, 0, "开始 "+phaseName)
  64. if err := r.runSingle(ctx, task, phaseName, opts); err != nil {
  65. log.Printf("[pipeline] phase=%s error: %v (continuing)", phaseName, err)
  66. }
  67. r.report(phaseName, 100, 100, phaseName+" 完成")
  68. }
  69. return nil
  70. }
  71. // Single-phase task
  72. phaseName := task.TaskType
  73. if isContextDone(ctx) {
  74. return fmt.Errorf("pipeline cancelled before phase %s", phaseName)
  75. }
  76. r.report(phaseName, 0, 0, "开始 "+phaseName)
  77. if err := r.runSingle(ctx, task, phaseName, opts); err != nil {
  78. r.report(phaseName, 0, 0, phaseName+" 失败: "+err.Error())
  79. return err
  80. }
  81. r.report(phaseName, 100, 100, phaseName+" 完成")
  82. return nil
  83. }
  84. // runSingle 执行单个 phase
  85. func (r *Runner) runSingle(ctx context.Context, task *model.Task, phaseName string, opts *Options) error {
  86. p, ok := r.phases[phaseName]
  87. if !ok {
  88. return fmt.Errorf("phase %q not registered", phaseName)
  89. }
  90. return p.Run(ctx, task, opts)
  91. }
  // isContextDone reports, without blocking, whether ctx's Done channel has
  // been closed (i.e. the context was cancelled or otherwise finished). It is
  // meant for phases to poll for a stop signal between units of work.
  func isContextDone(ctx context.Context) bool {
  	select {
  	case <-ctx.Done():
  	default:
  		return false
  	}
  	return true
  }