| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193 |
- package handler
- import (
- "fmt"
- "net/http"
- "strconv"
- "time"
- "github.com/gin-gonic/gin"
- "github.com/gorilla/websocket"
- "spider/internal/model"
- "spider/internal/store"
- "spider/internal/task"
- )
// TaskHandler handles task-related HTTP and WebSocket requests.
//
// It combines two dependencies: the persistent store, which the handlers
// query for task records, and the task manager, which starts and stops tasks
// and reports their live progress and log lines.
type TaskHandler struct {
	// store gives access to the database (via store.DB) holding model.TaskLog records.
	store *store.Store
	// taskMgr starts/stops tasks and serves in-memory progress and logs.
	taskMgr *task.Manager
}
- // List handles GET /tasks
- func (h *TaskHandler) List(c *gin.Context) {
- page, pageSize, offset := parsePage(c)
- status := c.Query("status")
- query := h.store.DB.Model(&model.TaskLog{}).Order("created_at DESC")
- if status != "" {
- query = query.Where("status = ?", status)
- }
- var total int64
- query.Count(&total)
- var tasks []model.TaskLog
- if err := query.Limit(pageSize).Offset(offset).Find(&tasks).Error; err != nil {
- Fail(c, 500, err.Error())
- return
- }
- PageOK(c, tasks, total, page, pageSize)
- }
// Start handles POST /tasks/start.
//
// The JSON body is bound into a task.StartRequest; a binding failure is
// reported as 400. The plugin name "clean" is special-cased and routed to
// StartClean; every other request is passed to StartTask. Any error from the
// task manager is reported as 409 Conflict. On success the resulting task log
// is returned with 201 Created.
func (h *TaskHandler) Start(c *gin.Context) {
	var req task.StartRequest
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		Fail(c, 400, bindErr.Error())
		return
	}

	if req.PluginName != "clean" {
		created, startErr := h.taskMgr.StartTask(req)
		if startErr != nil {
			Fail(c, 409, startErr.Error())
			return
		}
		c.JSON(http.StatusCreated, Response{Code: 0, Message: "ok", Data: created})
		return
	}

	// The "clean" plugin has its own dedicated entry point on the manager.
	created, startErr := h.taskMgr.StartClean()
	if startErr != nil {
		Fail(c, 409, startErr.Error())
		return
	}
	c.JSON(http.StatusCreated, Response{Code: 0, Message: "ok", Data: created})
}
// Get handles GET /tasks/:id.
//
// It responds with the stored task record under "task" and the task
// manager's current progress for that id under "progress". A non-numeric id
// yields 400; a failed record lookup yields 404 "task not found".
func (h *TaskHandler) Get(c *gin.Context) {
	taskID, parseErr := strconv.ParseUint(c.Param("id"), 10, 64)
	if parseErr != nil {
		Fail(c, 400, "invalid id")
		return
	}

	var record model.TaskLog
	if dbErr := h.store.DB.First(&record, taskID).Error; dbErr != nil {
		// NOTE(review): every lookup error (not only "record not found")
		// is reported as 404 here; consider surfacing real DB failures as 500.
		Fail(c, 404, "task not found")
		return
	}

	payload := gin.H{"task": record}
	payload["progress"] = h.taskMgr.GetProgress(uint(taskID))
	OK(c, payload)
}
// Stop handles POST /tasks/:id/stop.
//
// It asks the task manager to stop the task with the given id. A non-numeric
// id yields 400 and any error returned by the manager yields 500. On success
// the response only confirms that the stop signal was sent.
func (h *TaskHandler) Stop(c *gin.Context) {
	rawID := c.Param("id")
	taskID, parseErr := strconv.ParseUint(rawID, 10, 64)
	if parseErr != nil {
		Fail(c, 400, "invalid id")
		return
	}

	if stopErr := h.taskMgr.StopTask(uint(taskID)); stopErr != nil {
		Fail(c, 500, stopErr.Error())
		return
	}
	OK(c, gin.H{"message": "stop signal sent"})
}
// Logs handles GET /tasks/:id/logs by upgrading the request to a WebSocket.
//
// The stream first replays the log lines the task manager currently holds for
// the task. If the task record already has a terminal status ("completed",
// "failed" or "stopped"), a completion line is sent and the socket is closed.
// Otherwise the task is polled once per second. Each poll pushes the log lines
// appended since the previous one, then the current progress message. The
// stream ends with a completion line once the task reaches a terminal status,
// or earlier if the client disconnects or a write fails.
//
// NOTE(review): CheckOrigin accepts every origin, so any web page can open
// this socket with the user's credentials (cross-site WebSocket hijacking).
// Restrict it to the frontend origin(s) once they are known.
func (h *TaskHandler) Logs(c *gin.Context) {
	id, err := strconv.ParseUint(c.Param("id"), 10, 64)
	if err != nil {
		// Same error shape as the other task handlers.
		Fail(c, 400, "invalid id")
		return
	}
	upgrader := websocket.Upgrader{
		CheckOrigin: func(r *http.Request) bool { return true },
	}
	conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		// Upgrade has already replied to the client with an HTTP error.
		return
	}
	defer conn.Close()

	// writeWait bounds each write so a client that stops reading cannot block
	// this handler (and its polling) forever.
	const writeWait = 10 * time.Second
	send := func(msg string) bool {
		if err := conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil {
			return false
		}
		return conn.WriteMessage(websocket.TextMessage, []byte(msg)) == nil
	}
	// closeNormally sends a close frame so the client sees a clean closure
	// rather than an abrupt disconnect. Its error is ignored because the
	// deferred conn.Close tears the connection down either way.
	closeNormally := func() {
		_ = conn.WriteControl(websocket.CloseMessage,
			websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
			time.Now().Add(writeWait))
	}
	isFinished := func(t *model.TaskLog) bool {
		return t.Status == "completed" || t.Status == "failed" || t.Status == "stopped"
	}
	finish := func(t *model.TaskLog) {
		if send(fmt.Sprintf("[完成] 任务已结束,状态: %s", t.Status)) {
			closeNormally()
		}
	}

	// Fetch task record
	var taskLog model.TaskLog
	if err := h.store.DB.First(&taskLog, id).Error; err != nil {
		if send(fmt.Sprintf("[错误] 任务 #%d 不存在", id)) {
			closeNormally()
		}
		return
	}

	// sent is the number of lines of the manager's log buffer already
	// delivered, so each flush only pushes lines appended since the last one.
	sent := 0
	flushLogs := func() bool {
		logs := h.taskMgr.GetLogs(uint(id))
		if len(logs) < sent {
			// NOTE(review): the buffer shrank (reset or trimmed), so resume from
			// its current end instead of slicing past it. Confirm GetLogs'
			// retention policy; a fixed-size ring buffer would hide new lines here.
			sent = len(logs)
		}
		for _, line := range logs[sent:] {
			if !send(line) {
				return false
			}
			sent++
		}
		return true
	}

	// Send history logs.
	if !flushLogs() {
		return
	}
	if isFinished(&taskLog) {
		finish(&taskLog)
		return
	}

	// The peer must be read to process its control frames; a read error is
	// also the signal that the client went away. The goroutine ends once the
	// deferred conn.Close makes ReadMessage fail.
	clientGone := make(chan struct{})
	go func() {
		defer close(clientGone)
		for {
			if _, _, err := conn.ReadMessage(); err != nil {
				return
			}
		}
	}()

	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-clientGone:
			return
		case <-ticker.C:
			var t model.TaskLog
			if err := h.store.DB.First(&t, id).Error; err != nil {
				return
			}
			// The status is read before flushing. Lines written before the task
			// finished are then delivered ahead of the completion line.
			if !flushLogs() {
				return
			}
			progress := h.taskMgr.GetProgress(uint(id))
			if progress != nil {
				if !send(fmt.Sprintf("[进度] %v", progress["message"])) {
					return
				}
			}
			if isFinished(&t) {
				finish(&t)
				return
			}
		}
	}
}
|