task.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. package handler
  2. import (
  3. "fmt"
  4. "net/http"
  5. "strconv"
  6. "time"
  7. "github.com/gin-gonic/gin"
  8. "github.com/gorilla/websocket"
  9. "spider/internal/model"
  10. "spider/internal/store"
  11. "spider/internal/task"
  12. )
  13. // TaskHandler handles task-related HTTP and WebSocket requests.
  14. type TaskHandler struct {
  15. store *store.Store
  16. taskMgr *task.Manager
  17. }
  18. // List handles GET /tasks
  19. func (h *TaskHandler) List(c *gin.Context) {
  20. page, pageSize, offset := parsePage(c)
  21. status := c.Query("status")
  22. query := h.store.DB.Model(&model.TaskLog{}).Order("created_at DESC")
  23. if status != "" {
  24. query = query.Where("status = ?", status)
  25. }
  26. var total int64
  27. query.Count(&total)
  28. var tasks []model.TaskLog
  29. if err := query.Limit(pageSize).Offset(offset).Find(&tasks).Error; err != nil {
  30. Fail(c, 500, err.Error())
  31. return
  32. }
  33. PageOK(c, tasks, total, page, pageSize)
  34. }
  35. // Start handles POST /tasks/start
  36. func (h *TaskHandler) Start(c *gin.Context) {
  37. var req task.StartRequest
  38. if err := c.ShouldBindJSON(&req); err != nil {
  39. Fail(c, 400, err.Error())
  40. return
  41. }
  42. // Special case: clean task
  43. if req.PluginName == "clean" {
  44. taskLog, err := h.taskMgr.StartClean()
  45. if err != nil {
  46. Fail(c, 409, err.Error())
  47. return
  48. }
  49. c.JSON(http.StatusCreated, Response{Code: 0, Message: "ok", Data: taskLog})
  50. return
  51. }
  52. taskLog, err := h.taskMgr.StartTask(req)
  53. if err != nil {
  54. Fail(c, 409, err.Error())
  55. return
  56. }
  57. c.JSON(http.StatusCreated, Response{Code: 0, Message: "ok", Data: taskLog})
  58. }
  59. // Get handles GET /tasks/:id
  60. func (h *TaskHandler) Get(c *gin.Context) {
  61. id, err := strconv.ParseUint(c.Param("id"), 10, 64)
  62. if err != nil {
  63. Fail(c, 400, "invalid id")
  64. return
  65. }
  66. var taskLog model.TaskLog
  67. if err := h.store.DB.First(&taskLog, id).Error; err != nil {
  68. Fail(c, 404, "task not found")
  69. return
  70. }
  71. progress := h.taskMgr.GetProgress(uint(id))
  72. OK(c, gin.H{
  73. "task": taskLog,
  74. "progress": progress,
  75. })
  76. }
  77. // Stop handles POST /tasks/:id/stop
  78. func (h *TaskHandler) Stop(c *gin.Context) {
  79. id, err := strconv.ParseUint(c.Param("id"), 10, 64)
  80. if err != nil {
  81. Fail(c, 400, "invalid id")
  82. return
  83. }
  84. if err := h.taskMgr.StopTask(uint(id)); err != nil {
  85. Fail(c, 500, err.Error())
  86. return
  87. }
  88. OK(c, gin.H{"message": "stop signal sent"})
  89. }
  90. // Logs handles GET /tasks/:id/logs via WebSocket.
  91. func (h *TaskHandler) Logs(c *gin.Context) {
  92. id, err := strconv.ParseUint(c.Param("id"), 10, 64)
  93. if err != nil {
  94. c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"})
  95. return
  96. }
  97. upgrader := websocket.Upgrader{
  98. CheckOrigin: func(r *http.Request) bool { return true },
  99. }
  100. conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
  101. if err != nil {
  102. return
  103. }
  104. defer conn.Close()
  105. send := func(msg string) bool {
  106. err := conn.WriteMessage(websocket.TextMessage, []byte(msg))
  107. return err == nil
  108. }
  109. // Fetch task record
  110. var taskLog model.TaskLog
  111. if err := h.store.DB.First(&taskLog, id).Error; err != nil {
  112. send(fmt.Sprintf("[错误] 任务 #%d 不存在", id))
  113. return
  114. }
  115. // Send history logs
  116. logs := h.taskMgr.GetLogs(uint(id))
  117. for _, line := range logs {
  118. if !send(line) {
  119. return
  120. }
  121. }
  122. // If finished, close
  123. if taskLog.Status == "completed" || taskLog.Status == "failed" || taskLog.Status == "stopped" {
  124. send(fmt.Sprintf("[完成] 任务已结束,状态: %s", taskLog.Status))
  125. return
  126. }
  127. // Stream live updates
  128. clientGone := make(chan struct{})
  129. go func() {
  130. for {
  131. if _, _, err := conn.ReadMessage(); err != nil {
  132. close(clientGone)
  133. return
  134. }
  135. }
  136. }()
  137. ticker := time.NewTicker(time.Second)
  138. defer ticker.Stop()
  139. for {
  140. select {
  141. case <-clientGone:
  142. return
  143. case <-ticker.C:
  144. var t model.TaskLog
  145. if err := h.store.DB.First(&t, id).Error; err != nil {
  146. return
  147. }
  148. progress := h.taskMgr.GetProgress(uint(id))
  149. if progress != nil {
  150. msg := fmt.Sprintf("[进度] %v", progress["message"])
  151. if !send(msg) {
  152. return
  153. }
  154. }
  155. if t.Status == "completed" || t.Status == "failed" || t.Status == "stopped" {
  156. send(fmt.Sprintf("[完成] 任务已结束,状态: %s", t.Status))
  157. return
  158. }
  159. }
  160. }
  161. }