package handler

import (
	"context"
	"fmt"
	"net/http"
	"strconv"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"

	"spider/internal/model"
	"spider/internal/store"
	"spider/internal/task"
)

// TaskHandler handles task-related HTTP and WebSocket requests.
type TaskHandler struct {
	store   *store.Store
	taskMgr *task.Manager
}

// List handles GET /tasks.
//
// It returns one page of task logs ordered by created_at, newest first.
// Optional query filters:
//   - status:      exact match on the task status
//   - plugin_name: exact match on the plugin name
//   - date_from:   YYYY-MM-DD, inclusive lower bound on created_at
//   - date_to:     YYYY-MM-DD, the whole day is included (upper bound is the
//     start of the following day, exclusive)
//
// Dates that fail to parse are ignored rather than rejected. Database errors
// (from either the count or the page query) are reported as 500.
func (h *TaskHandler) List(c *gin.Context) {
	page, pageSize, offset := parsePage(c)
	status := c.Query("status")
	plugin := c.Query("plugin_name")
	dateFrom := c.Query("date_from")
	dateTo := c.Query("date_to")

	query := h.store.DB.Model(&model.TaskLog{}).Order("created_at DESC")
	if status != "" {
		query = query.Where("status = ?", status)
	}
	if plugin != "" {
		query = query.Where("plugin_name = ?", plugin)
	}
	if dateFrom != "" {
		if t, err := time.Parse("2006-01-02", dateFrom); err == nil {
			query = query.Where("created_at >= ?", t)
		}
	}
	if dateTo != "" {
		if t, err := time.Parse("2006-01-02", dateTo); err == nil {
			// Exclusive bound at the next midnight so the whole day is included.
			query = query.Where("created_at < ?", t.AddDate(0, 0, 1))
		}
	}

	// The total drives pagination; a failed count must not be reported as 0.
	var total int64
	if err := query.Count(&total).Error; err != nil {
		Fail(c, 500, err.Error())
		return
	}

	var tasks []model.TaskLog
	if err := query.Limit(pageSize).Offset(offset).Find(&tasks).Error; err != nil {
		Fail(c, 500, err.Error())
		return
	}
	PageOK(c, tasks, total, page, pageSize)
}

// Start handles POST /tasks/start.
//
// The JSON body is bound into a task.StartRequest. proxy_mode may be empty,
// "single" or "pool"; "single" additionally requires a non-nil, non-zero
// proxy_id. The special plugin name "clean" starts a clean task through
// StartClean, which does not receive the request (its proxy settings are
// still validated above). Errors from the task manager are reported as 409.
// On success the created task log is returned with 201 and an audit entry
// is written.
func (h *TaskHandler) Start(c *gin.Context) {
	var req task.StartRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		Fail(c, 400, err.Error())
		return
	}

	// Validate proxy_mode
	if req.ProxyMode != "" && req.ProxyMode != "single" && req.ProxyMode != "pool" {
		Fail(c, 400, "proxy_mode 必须是 single 或 pool")
		return
	}
	if req.ProxyMode == "single" && (req.ProxyID == nil || *req.ProxyID == 0) {
		Fail(c, 400, "固定代理模式需要指定 proxy_id")
		return
	}

	// Special case: clean task
	if req.PluginName == "clean" {
		taskLog, err := h.taskMgr.StartClean()
		if err != nil {
			Fail(c, 409, err.Error())
			return
		}
		LogAudit(h.store, c, "create", "task", fmt.Sprintf("%d", taskLog.ID), gin.H{"type": "clean"})
		c.JSON(http.StatusCreated, Response{Code: 0, Message: "ok", Data: taskLog})
		return
	}

	taskLog, err := h.taskMgr.StartTask(req)
	if err != nil {
		Fail(c, 409, err.Error())
		return
	}
	LogAudit(h.store, c, "create", "task", fmt.Sprintf("%d", taskLog.ID), gin.H{"plugin": req.PluginName})
	c.JSON(http.StatusCreated, Response{Code: 0, Message: "ok", Data: taskLog})
}

// Get handles GET /tasks/:id.
//
// It returns the stored task record together with the progress reported by
// the task manager. Any lookup error is reported as 404.
func (h *TaskHandler) Get(c *gin.Context) {
	id, err := strconv.ParseUint(c.Param("id"), 10, 64)
	if err != nil {
		Fail(c, 400, "invalid id")
		return
	}

	var taskLog model.TaskLog
	if err := h.store.DB.First(&taskLog, id).Error; err != nil {
		Fail(c, 404, "task not found")
		return
	}

	progress := h.taskMgr.GetProgress(uint(id))
	OK(c, gin.H{
		"task":     taskLog,
		"progress": progress,
	})
}

// Stop handles POST /tasks/:id/stop.
//
// It asks the task manager to stop the task and records an audit entry.
// Any error from StopTask is reported as 500.
func (h *TaskHandler) Stop(c *gin.Context) {
	id, err := strconv.ParseUint(c.Param("id"), 10, 64)
	if err != nil {
		Fail(c, 400, "invalid id")
		return
	}
	if err := h.taskMgr.StopTask(uint(id)); err != nil {
		Fail(c, 500, err.Error())
		return
	}
	LogAudit(h.store, c, "update", "task", fmt.Sprintf("%d", id), gin.H{"action": "stop"})
	OK(c, gin.H{"message": "stop signal sent"})
}

// Retry handles POST /tasks/:id/retry — restarts a failed or stopped task with the same configuration.
func (h *TaskHandler) Retry(c *gin.Context) { id, err := strconv.ParseUint(c.Param("id"), 10, 64) if err != nil { Fail(c, 400, "invalid id") return } var original model.TaskLog if err := h.store.DB.First(&original, id).Error; err != nil { Fail(c, 404, "task not found") return } if original.Status != "failed" && original.Status != "stopped" { Fail(c, 400, "只能重试失败或已停止的任务") return } var taskLog *model.TaskLog if original.TaskType == "clean" { taskLog, err = h.taskMgr.StartClean() } else { req := task.StartRequest{ PluginName: original.PluginName, ProxyMode: original.ProxyMode, } if original.ProxyID != nil { req.ProxyID = original.ProxyID } taskLog, err = h.taskMgr.StartTask(req) } if err != nil { Fail(c, 409, err.Error()) return } LogAudit(h.store, c, "create", "task", fmt.Sprintf("%d", taskLog.ID), gin.H{"retry_from": id}) c.JSON(http.StatusCreated, Response{Code: 0, Message: "ok", Data: taskLog}) } // Logs handles GET /tasks/:id/logs via WebSocket. // Uses Redis Pub/Sub for real-time updates instead of polling the database. 
func (h *TaskHandler) Logs(c *gin.Context) { id, err := strconv.ParseUint(c.Param("id"), 10, 64) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"}) return } upgrader := websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, } conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { return } defer conn.Close() send := func(msg string) bool { err := conn.WriteMessage(websocket.TextMessage, []byte(msg)) return err == nil } // Fetch task record var taskLog model.TaskLog if err := h.store.DB.First(&taskLog, id).Error; err != nil { send(fmt.Sprintf("[错误] 任务 #%d 不存在", id)) return } // Send history logs logs := h.taskMgr.GetLogs(uint(id)) for _, line := range logs { if !send(line) { return } } // If already finished, close if taskLog.Status == "completed" || taskLog.Status == "failed" || taskLog.Status == "stopped" { send(fmt.Sprintf("[完成] 任务已结束,状态: %s", taskLog.Status)) return } // Subscribe to Redis Pub/Sub for real-time updates (no DB polling!) 
rdb := h.taskMgr.GetRedis() pubsub := rdb.Subscribe(context.Background(), task.TaskPubSubChannel(uint(id))) defer pubsub.Close() // Detect client disconnect clientGone := make(chan struct{}) go func() { for { if _, _, err := conn.ReadMessage(); err != nil { close(clientGone) return } } }() ch := pubsub.Channel() // Safety timeout: if no updates for 5 minutes, check DB and close timeout := time.NewTimer(5 * time.Minute) defer timeout.Stop() for { select { case <-clientGone: return case msg, ok := <-ch: if !ok { return } if !send(msg.Payload) { return } // Check if task completed if isFinishMessage(msg.Payload) { return } timeout.Reset(5 * time.Minute) case <-timeout.C: // Fallback: check DB if task is still running var t model.TaskLog if err := h.store.DB.First(&t, id).Error; err != nil { return } if t.Status == "completed" || t.Status == "failed" || t.Status == "stopped" { send(fmt.Sprintf("[完成] 任务已结束,状态: %s", t.Status)) return } timeout.Reset(5 * time.Minute) } } } func isFinishMessage(msg string) bool { return msg == "任务完成" || strings.Contains(msg, "任务已结束") || strings.Contains(msg, "任务已停止") || strings.Contains(msg, "采集失败") } // Details handles GET /tasks/:id/details — returns per-operation execution logs. 
// Query params: page, page_size, action (filter by action type), status (filter by status) func (h *TaskHandler) Details(c *gin.Context) { id, err := strconv.ParseUint(c.Param("id"), 10, 64) if err != nil { Fail(c, 400, "invalid id") return } page, pageSize, offset := parsePage(c) query := h.store.DB.Model(&model.TaskDetail{}).Where("task_id = ?", id) if action := c.Query("action"); action != "" { query = query.Where("action = ?", action) } if status := c.Query("status"); status != "" { query = query.Where("status = ?", status) } var total int64 query.Count(&total) var details []model.TaskDetail if err := query.Order("seq ASC").Limit(pageSize).Offset(offset).Find(&details).Error; err != nil { Fail(c, 500, err.Error()) return } // Also return action summary counts var actionCounts []struct { Action string Status string Cnt int64 } h.store.DB.Model(&model.TaskDetail{}). Where("task_id = ?", id). Select("action, status, count(*) as cnt"). Group("action, status"). Scan(&actionCounts) summary := map[string]map[string]int64{} for _, ac := range actionCounts { if summary[ac.Action] == nil { summary[ac.Action] = map[string]int64{} } summary[ac.Action][ac.Status] = ac.Cnt } OK(c, gin.H{ "items": details, "total": total, "page": page, "page_size": pageSize, "summary": summary, }) } // Plugins handles GET /plugins — returns list of available plugins. func (h *TaskHandler) Plugins(c *gin.Context) { names := h.taskMgr.ListPlugins() OK(c, names) }