package handler import ( "fmt" "net/http" "strconv" "time" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "spider/internal/model" "spider/internal/store" "spider/internal/task" ) // TaskHandler handles task-related HTTP and WebSocket requests. type TaskHandler struct { store *store.Store taskMgr *task.Manager } // List handles GET /tasks func (h *TaskHandler) List(c *gin.Context) { page, pageSize, offset := parsePage(c) status := c.Query("status") query := h.store.DB.Model(&model.TaskLog{}).Order("created_at DESC") if status != "" { query = query.Where("status = ?", status) } var total int64 query.Count(&total) var tasks []model.TaskLog if err := query.Limit(pageSize).Offset(offset).Find(&tasks).Error; err != nil { Fail(c, 500, err.Error()) return } PageOK(c, tasks, total, page, pageSize) } // Start handles POST /tasks/start func (h *TaskHandler) Start(c *gin.Context) { var req task.StartRequest if err := c.ShouldBindJSON(&req); err != nil { Fail(c, 400, err.Error()) return } // Special case: clean task if req.PluginName == "clean" { taskLog, err := h.taskMgr.StartClean() if err != nil { Fail(c, 409, err.Error()) return } c.JSON(http.StatusCreated, Response{Code: 0, Message: "ok", Data: taskLog}) return } taskLog, err := h.taskMgr.StartTask(req) if err != nil { Fail(c, 409, err.Error()) return } c.JSON(http.StatusCreated, Response{Code: 0, Message: "ok", Data: taskLog}) } // Get handles GET /tasks/:id func (h *TaskHandler) Get(c *gin.Context) { id, err := strconv.ParseUint(c.Param("id"), 10, 64) if err != nil { Fail(c, 400, "invalid id") return } var taskLog model.TaskLog if err := h.store.DB.First(&taskLog, id).Error; err != nil { Fail(c, 404, "task not found") return } progress := h.taskMgr.GetProgress(uint(id)) OK(c, gin.H{ "task": taskLog, "progress": progress, }) } // Stop handles POST /tasks/:id/stop func (h *TaskHandler) Stop(c *gin.Context) { id, err := strconv.ParseUint(c.Param("id"), 10, 64) if err != nil { Fail(c, 400, "invalid id") return } if err := h.taskMgr.StopTask(uint(id)); err != nil { Fail(c, 500, err.Error()) return } OK(c, gin.H{"message": "stop signal sent"}) } // Logs handles GET /tasks/:id/logs via WebSocket. func (h *TaskHandler) Logs(c *gin.Context) { id, err := strconv.ParseUint(c.Param("id"), 10, 64) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"}) return } upgrader := websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, } conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { return } defer conn.Close() send := func(msg string) bool { err := conn.WriteMessage(websocket.TextMessage, []byte(msg)) return err == nil } // Fetch task record var taskLog model.TaskLog if err := h.store.DB.First(&taskLog, id).Error; err != nil { send(fmt.Sprintf("[错误] 任务 #%d 不存在", id)) return } // Send history logs logs := h.taskMgr.GetLogs(uint(id)) for _, line := range logs { if !send(line) { return } } // If finished, close if taskLog.Status == "completed" || taskLog.Status == "failed" || taskLog.Status == "stopped" { send(fmt.Sprintf("[完成] 任务已结束,状态: %s", taskLog.Status)) return } // Stream live updates clientGone := make(chan struct{}) go func() { for { if _, _, err := conn.ReadMessage(); err != nil { close(clientGone) return } } }() ticker := time.NewTicker(time.Second) defer ticker.Stop() for { select { case <-clientGone: return case <-ticker.C: var t model.TaskLog if err := h.store.DB.First(&t, id).Error; err != nil { return } progress := h.taskMgr.GetProgress(uint(id)) if progress != nil { msg := fmt.Sprintf("[进度] %v", progress["message"]) if !send(msg) { return } } if t.Status == "completed" || t.Status == "failed" || t.Status == "stopped" { send(fmt.Sprintf("[完成] 任务已结束,状态: %s", t.Status)) return } } } }