task.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. package handler
import (
	"errors"
	"fmt"
	"net/http"
	"strconv"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"
	"github.com/redis/go-redis/v9"
	"gorm.io/gorm"

	"spider/internal/model"
	"spider/internal/service"
)
  14. // validTaskTypes is the set of accepted task_type values.
  15. var validTaskTypes = map[string]bool{
  16. "full": true,
  17. "discover": true,
  18. "search": true,
  19. "github": true,
  20. "scrape": true,
  21. "crawl": true,
  22. "clean": true,
  23. "score": true,
  24. }
  25. // TaskHandler handles task-related HTTP and WebSocket requests.
  26. type TaskHandler struct {
  27. db *gorm.DB
  28. taskService *service.TaskService
  29. redis *redis.Client
  30. upgrader websocket.Upgrader
  31. }
  32. // NewTaskHandler creates a TaskHandler.
  33. func NewTaskHandler(db *gorm.DB, svc *service.TaskService, rdb *redis.Client) *TaskHandler {
  34. return &TaskHandler{
  35. db: db,
  36. taskService: svc,
  37. redis: rdb,
  38. upgrader: websocket.Upgrader{
  39. CheckOrigin: func(r *http.Request) bool { return true },
  40. },
  41. }
  42. }
  43. // List handles GET /tasks
  44. // Query params: status, page, page_size
  45. func (h *TaskHandler) List(c *gin.Context) {
  46. status := c.Query("status")
  47. page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
  48. pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
  49. if page < 1 {
  50. page = 1
  51. }
  52. if pageSize < 1 || pageSize > 100 {
  53. pageSize = 20
  54. }
  55. offset := (page - 1) * pageSize
  56. query := h.db.Model(&model.Task{}).Order("created_at DESC")
  57. if status != "" {
  58. query = query.Where("status = ?", status)
  59. }
  60. var total int64
  61. query.Count(&total)
  62. var tasks []model.Task
  63. if err := query.Limit(pageSize).Offset(offset).Find(&tasks).Error; err != nil {
  64. Fail(c, 500, err.Error())
  65. return
  66. }
  67. PageOK(c, tasks, total, page, pageSize)
  68. }
  69. // Start handles POST /tasks/start
  70. func (h *TaskHandler) Start(c *gin.Context) {
  71. var req service.StartTaskRequest
  72. if err := c.ShouldBindJSON(&req); err != nil {
  73. Fail(c, 400, err.Error())
  74. return
  75. }
  76. if !validTaskTypes[req.TaskType] {
  77. Fail(c, 400, fmt.Sprintf("invalid task_type: %s", req.TaskType))
  78. return
  79. }
  80. task, err := h.taskService.StartTask(req)
  81. if err != nil {
  82. Fail(c, 409, err.Error())
  83. return
  84. }
  85. c.JSON(http.StatusCreated, Response{Code: 0, Message: "ok", Data: task})
  86. }
  87. // Get handles GET /tasks/:id
  88. func (h *TaskHandler) Get(c *gin.Context) {
  89. id, err := strconv.ParseUint(c.Param("id"), 10, 64)
  90. if err != nil {
  91. Fail(c, 400, "invalid id")
  92. return
  93. }
  94. var task model.Task
  95. if err := h.db.First(&task, id).Error; err != nil {
  96. Fail(c, 404, "task not found")
  97. return
  98. }
  99. progress := h.taskService.GetProgress(&task)
  100. OK(c, gin.H{
  101. "task": task,
  102. "progress": progress,
  103. })
  104. }
  105. // Stop handles POST /tasks/:id/stop
  106. func (h *TaskHandler) Stop(c *gin.Context) {
  107. id, err := strconv.ParseUint(c.Param("id"), 10, 64)
  108. if err != nil {
  109. Fail(c, 400, "invalid id")
  110. return
  111. }
  112. var body struct {
  113. Force bool `json:"force"`
  114. }
  115. _ = c.ShouldBindJSON(&body)
  116. if err := h.taskService.StopTask(uint(id), body.Force); err != nil {
  117. Fail(c, 500, err.Error())
  118. return
  119. }
  120. OK(c, gin.H{"message": "stop signal sent"})
  121. }
  122. // Logs handles GET /tasks/:id/logs via WebSocket.
  123. // On connect it immediately sends history logs from Redis, then streams live progress
  124. // until the task finishes or the client disconnects.
  125. func (h *TaskHandler) Logs(c *gin.Context) {
  126. id, err := strconv.ParseUint(c.Param("id"), 10, 64)
  127. if err != nil {
  128. c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"})
  129. return
  130. }
  131. conn, err := h.upgrader.Upgrade(c.Writer, c.Request, nil)
  132. if err != nil {
  133. return
  134. }
  135. defer conn.Close()
  136. ctx := c.Request.Context()
  137. ticker := time.NewTicker(time.Second)
  138. defer ticker.Stop()
  139. send := func(msg string) bool {
  140. err := conn.WriteMessage(websocket.TextMessage, []byte(msg))
  141. return err == nil
  142. }
  143. // Fetch task record immediately.
  144. var task model.Task
  145. if err := h.db.First(&task, id).Error; err != nil {
  146. send(fmt.Sprintf("[错误] 任务 #%d 不存在", id))
  147. return
  148. }
  149. // Send history logs from Redis list first.
  150. logKey := fmt.Sprintf("spider:task:logs:%d", id)
  151. historyLogs, _ := h.redis.LRange(ctx, logKey, 0, -1).Result()
  152. for _, line := range historyLogs {
  153. if !send(line) {
  154. return
  155. }
  156. }
  157. // If no history logs, send current task status summary.
  158. if len(historyLogs) == 0 {
  159. send(fmt.Sprintf("[信息] 任务 #%d (%s) 状态: %s", task.ID, task.TaskType, task.Status))
  160. // Also send current Redis progress if available.
  161. progressKey := fmt.Sprintf("spider:task:progress:%d", id)
  162. vals, _ := h.redis.HGetAll(ctx, progressKey).Result()
  163. if len(vals) > 0 {
  164. msg := fmt.Sprintf("[进度] 阶段: %s | 进度: %s/%s | %s",
  165. vals["phase"], vals["current"], vals["total"], vals["message"])
  166. send(msg)
  167. }
  168. }
  169. // If the task has already finished, send completion message and close.
  170. if task.Status == "completed" || task.Status == "failed" || task.Status == "stopped" {
  171. statusLabel := map[string]string{
  172. "completed": "完成",
  173. "failed": "失败",
  174. "stopped": "停止",
  175. }[task.Status]
  176. send(fmt.Sprintf("[完成] 任务已%s", statusLabel))
  177. return
  178. }
  179. // Task is still running — handle client close messages in the background.
  180. clientGone := make(chan struct{})
  181. go func() {
  182. for {
  183. if _, _, err := conn.ReadMessage(); err != nil {
  184. close(clientGone)
  185. return
  186. }
  187. }
  188. }()
  189. progressKey := fmt.Sprintf("spider:task:progress:%d", id)
  190. for {
  191. select {
  192. case <-clientGone:
  193. return
  194. case <-ticker.C:
  195. var t model.Task
  196. if err := h.db.First(&t, id).Error; err != nil {
  197. return
  198. }
  199. vals, _ := h.redis.HGetAll(ctx, progressKey).Result()
  200. if len(vals) > 0 {
  201. msg := fmt.Sprintf("[进度] 阶段: %s | %s/%s | %s",
  202. vals["phase"], vals["current"], vals["total"], vals["message"])
  203. if !send(msg) {
  204. return
  205. }
  206. }
  207. if t.Status == "completed" || t.Status == "failed" || t.Status == "stopped" {
  208. send(fmt.Sprintf("[完成] 任务已结束,状态: %s", t.Status))
  209. return
  210. }
  211. }
  212. }
  213. }