package handler import ( "fmt" "net/http" "strconv" "time" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "github.com/redis/go-redis/v9" "gorm.io/gorm" "spider/internal/model" "spider/internal/service" ) // validTaskTypes is the set of accepted task_type values. var validTaskTypes = map[string]bool{ "full": true, "discover": true, "search": true, "github": true, "scrape": true, "crawl": true, "clean": true, "score": true, } // TaskHandler handles task-related HTTP and WebSocket requests. type TaskHandler struct { db *gorm.DB taskService *service.TaskService redis *redis.Client upgrader websocket.Upgrader } // NewTaskHandler creates a TaskHandler. func NewTaskHandler(db *gorm.DB, svc *service.TaskService, rdb *redis.Client) *TaskHandler { return &TaskHandler{ db: db, taskService: svc, redis: rdb, upgrader: websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, }, } } // List handles GET /tasks // Query params: status, page, page_size func (h *TaskHandler) List(c *gin.Context) { status := c.Query("status") page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) if page < 1 { page = 1 } if pageSize < 1 || pageSize > 100 { pageSize = 20 } offset := (page - 1) * pageSize query := h.db.Model(&model.Task{}).Order("created_at DESC") if status != "" { query = query.Where("status = ?", status) } var total int64 query.Count(&total) var tasks []model.Task if err := query.Limit(pageSize).Offset(offset).Find(&tasks).Error; err != nil { Fail(c, 500, err.Error()) return } PageOK(c, tasks, total, page, pageSize) } // Start handles POST /tasks/start func (h *TaskHandler) Start(c *gin.Context) { var req service.StartTaskRequest if err := c.ShouldBindJSON(&req); err != nil { Fail(c, 400, err.Error()) return } if !validTaskTypes[req.TaskType] { Fail(c, 400, fmt.Sprintf("invalid task_type: %s", req.TaskType)) return } task, err := h.taskService.StartTask(req) if err != nil { Fail(c, 409, err.Error()) return } c.JSON(http.StatusCreated, Response{Code: 0, Message: "ok", Data: task}) } // Get handles GET /tasks/:id func (h *TaskHandler) Get(c *gin.Context) { id, err := strconv.ParseUint(c.Param("id"), 10, 64) if err != nil { Fail(c, 400, "invalid id") return } var task model.Task if err := h.db.First(&task, id).Error; err != nil { Fail(c, 404, "task not found") return } progress := h.taskService.GetProgress(&task) OK(c, gin.H{ "task": task, "progress": progress, }) } // Stop handles POST /tasks/:id/stop func (h *TaskHandler) Stop(c *gin.Context) { id, err := strconv.ParseUint(c.Param("id"), 10, 64) if err != nil { Fail(c, 400, "invalid id") return } var body struct { Force bool `json:"force"` } _ = c.ShouldBindJSON(&body) if err := h.taskService.StopTask(uint(id), body.Force); err != nil { Fail(c, 500, err.Error()) return } OK(c, gin.H{"message": "stop signal sent"}) } // Logs handles GET /tasks/:id/logs via WebSocket. // On connect it immediately sends history logs from Redis, then streams live progress // until the task finishes or the client disconnects. func (h *TaskHandler) Logs(c *gin.Context) { id, err := strconv.ParseUint(c.Param("id"), 10, 64) if err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"}) return } conn, err := h.upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { return } defer conn.Close() ctx := c.Request.Context() ticker := time.NewTicker(time.Second) defer ticker.Stop() send := func(msg string) bool { err := conn.WriteMessage(websocket.TextMessage, []byte(msg)) return err == nil } // Fetch task record immediately. var task model.Task if err := h.db.First(&task, id).Error; err != nil { send(fmt.Sprintf("[错误] 任务 #%d 不存在", id)) return } // Send history logs from Redis list first. logKey := fmt.Sprintf("spider:task:logs:%d", id) historyLogs, _ := h.redis.LRange(ctx, logKey, 0, -1).Result() for _, line := range historyLogs { if !send(line) { return } } // If no history logs, send current task status summary. if len(historyLogs) == 0 { send(fmt.Sprintf("[信息] 任务 #%d (%s) 状态: %s", task.ID, task.TaskType, task.Status)) // Also send current Redis progress if available. progressKey := fmt.Sprintf("spider:task:progress:%d", id) vals, _ := h.redis.HGetAll(ctx, progressKey).Result() if len(vals) > 0 { msg := fmt.Sprintf("[进度] 阶段: %s | 进度: %s/%s | %s", vals["phase"], vals["current"], vals["total"], vals["message"]) send(msg) } } // If the task has already finished, send completion message and close. if task.Status == "completed" || task.Status == "failed" || task.Status == "stopped" { statusLabel := map[string]string{ "completed": "完成", "failed": "失败", "stopped": "停止", }[task.Status] send(fmt.Sprintf("[完成] 任务已%s", statusLabel)) return } // Task is still running — handle client close messages in the background. clientGone := make(chan struct{}) go func() { for { if _, _, err := conn.ReadMessage(); err != nil { close(clientGone) return } } }() progressKey := fmt.Sprintf("spider:task:progress:%d", id) for { select { case <-clientGone: return case <-ticker.C: var t model.Task if err := h.db.First(&t, id).Error; err != nil { return } vals, _ := h.redis.HGetAll(ctx, progressKey).Result() if len(vals) > 0 { msg := fmt.Sprintf("[进度] 阶段: %s | %s/%s | %s", vals["phase"], vals["current"], vals["total"], vals["message"]) if !send(msg) { return } } if t.Status == "completed" || t.Status == "failed" || t.Status == "stopped" { send(fmt.Sprintf("[完成] 任务已结束,状态: %s", t.Status)) return } } } }