task.go 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. package handler
  2. import (
  3. "context"
  4. "fmt"
  5. "net/http"
  6. "strconv"
  7. "strings"
  8. "time"
  9. "github.com/gin-gonic/gin"
  10. "github.com/gorilla/websocket"
  11. "spider/internal/model"
  12. "spider/internal/store"
  13. "spider/internal/task"
  14. )
  15. // TaskHandler handles task-related HTTP and WebSocket requests.
  16. type TaskHandler struct {
  17. store *store.Store
  18. taskMgr *task.Manager
  19. }
  20. // List handles GET /tasks
  21. func (h *TaskHandler) List(c *gin.Context) {
  22. page, pageSize, offset := parsePage(c)
  23. status := c.Query("status")
  24. plugin := c.Query("plugin_name")
  25. dateFrom := c.Query("date_from")
  26. dateTo := c.Query("date_to")
  27. query := h.store.DB.Model(&model.TaskLog{}).Order("created_at DESC")
  28. if status != "" {
  29. query = query.Where("status = ?", status)
  30. }
  31. if plugin != "" {
  32. query = query.Where("plugin_name = ?", plugin)
  33. }
  34. if dateFrom != "" {
  35. t, err := time.Parse("2006-01-02", dateFrom)
  36. if err == nil {
  37. query = query.Where("created_at >= ?", t)
  38. }
  39. }
  40. if dateTo != "" {
  41. t, err := time.Parse("2006-01-02", dateTo)
  42. if err == nil {
  43. query = query.Where("created_at < ?", t.AddDate(0, 0, 1))
  44. }
  45. }
  46. var total int64
  47. query.Count(&total)
  48. var tasks []model.TaskLog
  49. if err := query.Limit(pageSize).Offset(offset).Find(&tasks).Error; err != nil {
  50. Fail(c, 500, err.Error())
  51. return
  52. }
  53. PageOK(c, tasks, total, page, pageSize)
  54. }
  55. // Start handles POST /tasks/start
  56. func (h *TaskHandler) Start(c *gin.Context) {
  57. var req task.StartRequest
  58. if err := c.ShouldBindJSON(&req); err != nil {
  59. Fail(c, 400, err.Error())
  60. return
  61. }
  62. // Validate proxy_mode
  63. if req.ProxyMode != "" && req.ProxyMode != "single" && req.ProxyMode != "pool" {
  64. Fail(c, 400, "proxy_mode 必须是 single 或 pool")
  65. return
  66. }
  67. if req.ProxyMode == "single" && (req.ProxyID == nil || *req.ProxyID == 0) {
  68. Fail(c, 400, "固定代理模式需要指定 proxy_id")
  69. return
  70. }
  71. // Special case: clean task
  72. if req.PluginName == "clean" {
  73. taskLog, err := h.taskMgr.StartClean()
  74. if err != nil {
  75. Fail(c, 409, err.Error())
  76. return
  77. }
  78. LogAudit(h.store, c, "create", "task", fmt.Sprintf("%d", taskLog.ID), gin.H{"type": "clean"})
  79. c.JSON(http.StatusCreated, Response{Code: 0, Message: "ok", Data: taskLog})
  80. return
  81. }
  82. taskLog, err := h.taskMgr.StartTask(req)
  83. if err != nil {
  84. Fail(c, 409, err.Error())
  85. return
  86. }
  87. LogAudit(h.store, c, "create", "task", fmt.Sprintf("%d", taskLog.ID), gin.H{"plugin": req.PluginName})
  88. c.JSON(http.StatusCreated, Response{Code: 0, Message: "ok", Data: taskLog})
  89. }
  90. // Get handles GET /tasks/:id
  91. func (h *TaskHandler) Get(c *gin.Context) {
  92. id, err := strconv.ParseUint(c.Param("id"), 10, 64)
  93. if err != nil {
  94. Fail(c, 400, "invalid id")
  95. return
  96. }
  97. var taskLog model.TaskLog
  98. if err := h.store.DB.First(&taskLog, id).Error; err != nil {
  99. Fail(c, 404, "task not found")
  100. return
  101. }
  102. progress := h.taskMgr.GetProgress(uint(id))
  103. OK(c, gin.H{
  104. "task": taskLog,
  105. "progress": progress,
  106. })
  107. }
  108. // Stop handles POST /tasks/:id/stop
  109. func (h *TaskHandler) Stop(c *gin.Context) {
  110. id, err := strconv.ParseUint(c.Param("id"), 10, 64)
  111. if err != nil {
  112. Fail(c, 400, "invalid id")
  113. return
  114. }
  115. if err := h.taskMgr.StopTask(uint(id)); err != nil {
  116. Fail(c, 500, err.Error())
  117. return
  118. }
  119. LogAudit(h.store, c, "update", "task", fmt.Sprintf("%d", id), gin.H{"action": "stop"})
  120. OK(c, gin.H{"message": "stop signal sent"})
  121. }
  122. // Retry handles POST /tasks/:id/retry — restarts a failed/stopped task with same config.
  123. func (h *TaskHandler) Retry(c *gin.Context) {
  124. id, err := strconv.ParseUint(c.Param("id"), 10, 64)
  125. if err != nil {
  126. Fail(c, 400, "invalid id")
  127. return
  128. }
  129. var original model.TaskLog
  130. if err := h.store.DB.First(&original, id).Error; err != nil {
  131. Fail(c, 404, "task not found")
  132. return
  133. }
  134. if original.Status != "failed" && original.Status != "stopped" {
  135. Fail(c, 400, "只能重试失败或已停止的任务")
  136. return
  137. }
  138. var taskLog *model.TaskLog
  139. if original.TaskType == "clean" {
  140. taskLog, err = h.taskMgr.StartClean()
  141. } else {
  142. req := task.StartRequest{
  143. PluginName: original.PluginName,
  144. ProxyMode: original.ProxyMode,
  145. }
  146. if original.ProxyID != nil {
  147. req.ProxyID = original.ProxyID
  148. }
  149. taskLog, err = h.taskMgr.StartTask(req)
  150. }
  151. if err != nil {
  152. Fail(c, 409, err.Error())
  153. return
  154. }
  155. LogAudit(h.store, c, "create", "task", fmt.Sprintf("%d", taskLog.ID), gin.H{"retry_from": id})
  156. c.JSON(http.StatusCreated, Response{Code: 0, Message: "ok", Data: taskLog})
  157. }
  158. // Logs handles GET /tasks/:id/logs via WebSocket.
  159. // Uses Redis Pub/Sub for real-time updates instead of polling the database.
  160. func (h *TaskHandler) Logs(c *gin.Context) {
  161. id, err := strconv.ParseUint(c.Param("id"), 10, 64)
  162. if err != nil {
  163. c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"})
  164. return
  165. }
  166. upgrader := websocket.Upgrader{
  167. CheckOrigin: func(r *http.Request) bool { return true },
  168. }
  169. conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
  170. if err != nil {
  171. return
  172. }
  173. defer conn.Close()
  174. send := func(msg string) bool {
  175. err := conn.WriteMessage(websocket.TextMessage, []byte(msg))
  176. return err == nil
  177. }
  178. // Fetch task record
  179. var taskLog model.TaskLog
  180. if err := h.store.DB.First(&taskLog, id).Error; err != nil {
  181. send(fmt.Sprintf("[错误] 任务 #%d 不存在", id))
  182. return
  183. }
  184. // Send history logs
  185. logs := h.taskMgr.GetLogs(uint(id))
  186. for _, line := range logs {
  187. if !send(line) {
  188. return
  189. }
  190. }
  191. // If already finished, close
  192. if taskLog.Status == "completed" || taskLog.Status == "failed" || taskLog.Status == "stopped" {
  193. send(fmt.Sprintf("[完成] 任务已结束,状态: %s", taskLog.Status))
  194. return
  195. }
  196. // Subscribe to Redis Pub/Sub for real-time updates (no DB polling!)
  197. rdb := h.taskMgr.GetRedis()
  198. pubsub := rdb.Subscribe(context.Background(), task.TaskPubSubChannel(uint(id)))
  199. defer pubsub.Close()
  200. // Detect client disconnect
  201. clientGone := make(chan struct{})
  202. go func() {
  203. for {
  204. if _, _, err := conn.ReadMessage(); err != nil {
  205. close(clientGone)
  206. return
  207. }
  208. }
  209. }()
  210. ch := pubsub.Channel()
  211. // Safety timeout: if no updates for 5 minutes, check DB and close
  212. timeout := time.NewTimer(5 * time.Minute)
  213. defer timeout.Stop()
  214. for {
  215. select {
  216. case <-clientGone:
  217. return
  218. case msg, ok := <-ch:
  219. if !ok {
  220. return
  221. }
  222. if !send(msg.Payload) {
  223. return
  224. }
  225. // Check if task completed
  226. if isFinishMessage(msg.Payload) {
  227. return
  228. }
  229. timeout.Reset(5 * time.Minute)
  230. case <-timeout.C:
  231. // Fallback: check DB if task is still running
  232. var t model.TaskLog
  233. if err := h.store.DB.First(&t, id).Error; err != nil {
  234. return
  235. }
  236. if t.Status == "completed" || t.Status == "failed" || t.Status == "stopped" {
  237. send(fmt.Sprintf("[完成] 任务已结束,状态: %s", t.Status))
  238. return
  239. }
  240. timeout.Reset(5 * time.Minute)
  241. }
  242. }
  243. }
  // isFinishMessage reports whether a published log payload marks the end of a
// task's log stream. That is the case when the payload is exactly "任务完成",
// or when it contains any of "任务已结束", "任务已停止" or "采集失败".
func isFinishMessage(msg string) bool {
	if msg == "任务完成" {
		return true
	}
	for _, marker := range [...]string{"任务已结束", "任务已停止", "采集失败"} {
		if strings.Contains(msg, marker) {
			return true
		}
	}
	return false
}
  250. // Details handles GET /tasks/:id/details — returns per-operation execution logs.
  251. // Query params: page, page_size, action (filter by action type), status (filter by status)
  252. func (h *TaskHandler) Details(c *gin.Context) {
  253. id, err := strconv.ParseUint(c.Param("id"), 10, 64)
  254. if err != nil {
  255. Fail(c, 400, "invalid id")
  256. return
  257. }
  258. page, pageSize, offset := parsePage(c)
  259. query := h.store.DB.Model(&model.TaskDetail{}).Where("task_id = ?", id)
  260. if action := c.Query("action"); action != "" {
  261. query = query.Where("action = ?", action)
  262. }
  263. if status := c.Query("status"); status != "" {
  264. query = query.Where("status = ?", status)
  265. }
  266. var total int64
  267. query.Count(&total)
  268. var details []model.TaskDetail
  269. if err := query.Order("seq ASC").Limit(pageSize).Offset(offset).Find(&details).Error; err != nil {
  270. Fail(c, 500, err.Error())
  271. return
  272. }
  273. // Also return action summary counts
  274. var actionCounts []struct {
  275. Action string
  276. Status string
  277. Cnt int64
  278. }
  279. h.store.DB.Model(&model.TaskDetail{}).
  280. Where("task_id = ?", id).
  281. Select("action, status, count(*) as cnt").
  282. Group("action, status").
  283. Scan(&actionCounts)
  284. summary := map[string]map[string]int64{}
  285. for _, ac := range actionCounts {
  286. if summary[ac.Action] == nil {
  287. summary[ac.Action] = map[string]int64{}
  288. }
  289. summary[ac.Action][ac.Status] = ac.Cnt
  290. }
  291. OK(c, gin.H{
  292. "items": details,
  293. "total": total,
  294. "page": page,
  295. "page_size": pageSize,
  296. "summary": summary,
  297. })
  298. }
  299. // Plugins handles GET /plugins — returns list of available plugins.
  300. func (h *TaskHandler) Plugins(c *gin.Context) {
  301. names := h.taskMgr.ListPlugins()
  302. OK(c, names)
  303. }