| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250 |
- package handler
- import (
- "fmt"
- "net/http"
- "strconv"
- "time"
- "github.com/gin-gonic/gin"
- "github.com/gorilla/websocket"
- "github.com/redis/go-redis/v9"
- "gorm.io/gorm"
- "spider/internal/model"
- "spider/internal/service"
- )
// validTaskTypes is the set of task_type values accepted by POST /tasks/start.
// It is built once during package initialisation from the names listed below;
// looking up any other key yields false.
var validTaskTypes = func() map[string]bool {
	names := []string{"full", "discover", "search", "github", "scrape", "crawl", "clean", "score"}
	set := make(map[string]bool, len(names))
	for _, name := range names {
		set[name] = true
	}
	return set
}()
- // TaskHandler handles task-related HTTP and WebSocket requests.
- type TaskHandler struct {
- db *gorm.DB
- taskService *service.TaskService
- redis *redis.Client
- upgrader websocket.Upgrader
- }
- // NewTaskHandler creates a TaskHandler.
- func NewTaskHandler(db *gorm.DB, svc *service.TaskService, rdb *redis.Client) *TaskHandler {
- return &TaskHandler{
- db: db,
- taskService: svc,
- redis: rdb,
- upgrader: websocket.Upgrader{
- CheckOrigin: func(r *http.Request) bool { return true },
- },
- }
- }
- // List handles GET /tasks
- // Query params: status, page, page_size
- func (h *TaskHandler) List(c *gin.Context) {
- status := c.Query("status")
- page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
- pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
- if page < 1 {
- page = 1
- }
- if pageSize < 1 || pageSize > 100 {
- pageSize = 20
- }
- offset := (page - 1) * pageSize
- query := h.db.Model(&model.Task{}).Order("created_at DESC")
- if status != "" {
- query = query.Where("status = ?", status)
- }
- var total int64
- query.Count(&total)
- var tasks []model.Task
- if err := query.Limit(pageSize).Offset(offset).Find(&tasks).Error; err != nil {
- Fail(c, 500, err.Error())
- return
- }
- PageOK(c, tasks, total, page, pageSize)
- }
- // Start handles POST /tasks/start
- func (h *TaskHandler) Start(c *gin.Context) {
- var req service.StartTaskRequest
- if err := c.ShouldBindJSON(&req); err != nil {
- Fail(c, 400, err.Error())
- return
- }
- if !validTaskTypes[req.TaskType] {
- Fail(c, 400, fmt.Sprintf("invalid task_type: %s", req.TaskType))
- return
- }
- task, err := h.taskService.StartTask(req)
- if err != nil {
- Fail(c, 409, err.Error())
- return
- }
- c.JSON(http.StatusCreated, Response{Code: 0, Message: "ok", Data: task})
- }
- // Get handles GET /tasks/:id
- func (h *TaskHandler) Get(c *gin.Context) {
- id, err := strconv.ParseUint(c.Param("id"), 10, 64)
- if err != nil {
- Fail(c, 400, "invalid id")
- return
- }
- var task model.Task
- if err := h.db.First(&task, id).Error; err != nil {
- Fail(c, 404, "task not found")
- return
- }
- progress := h.taskService.GetProgress(&task)
- OK(c, gin.H{
- "task": task,
- "progress": progress,
- })
- }
- // Stop handles POST /tasks/:id/stop
- func (h *TaskHandler) Stop(c *gin.Context) {
- id, err := strconv.ParseUint(c.Param("id"), 10, 64)
- if err != nil {
- Fail(c, 400, "invalid id")
- return
- }
- var body struct {
- Force bool `json:"force"`
- }
- _ = c.ShouldBindJSON(&body)
- if err := h.taskService.StopTask(uint(id), body.Force); err != nil {
- Fail(c, 500, err.Error())
- return
- }
- OK(c, gin.H{"message": "stop signal sent"})
- }
- // Logs handles GET /tasks/:id/logs via WebSocket.
- // On connect it immediately sends history logs from Redis, then streams live progress
- // until the task finishes or the client disconnects.
- func (h *TaskHandler) Logs(c *gin.Context) {
- id, err := strconv.ParseUint(c.Param("id"), 10, 64)
- if err != nil {
- c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"})
- return
- }
- conn, err := h.upgrader.Upgrade(c.Writer, c.Request, nil)
- if err != nil {
- return
- }
- defer conn.Close()
- ctx := c.Request.Context()
- ticker := time.NewTicker(time.Second)
- defer ticker.Stop()
- send := func(msg string) bool {
- err := conn.WriteMessage(websocket.TextMessage, []byte(msg))
- return err == nil
- }
- // Fetch task record immediately.
- var task model.Task
- if err := h.db.First(&task, id).Error; err != nil {
- send(fmt.Sprintf("[错误] 任务 #%d 不存在", id))
- return
- }
- // Send history logs from Redis list first.
- logKey := fmt.Sprintf("spider:task:logs:%d", id)
- historyLogs, _ := h.redis.LRange(ctx, logKey, 0, -1).Result()
- for _, line := range historyLogs {
- if !send(line) {
- return
- }
- }
- // If no history logs, send current task status summary.
- if len(historyLogs) == 0 {
- send(fmt.Sprintf("[信息] 任务 #%d (%s) 状态: %s", task.ID, task.TaskType, task.Status))
- // Also send current Redis progress if available.
- progressKey := fmt.Sprintf("spider:task:progress:%d", id)
- vals, _ := h.redis.HGetAll(ctx, progressKey).Result()
- if len(vals) > 0 {
- msg := fmt.Sprintf("[进度] 阶段: %s | 进度: %s/%s | %s",
- vals["phase"], vals["current"], vals["total"], vals["message"])
- send(msg)
- }
- }
- // If the task has already finished, send completion message and close.
- if task.Status == "completed" || task.Status == "failed" || task.Status == "stopped" {
- statusLabel := map[string]string{
- "completed": "完成",
- "failed": "失败",
- "stopped": "停止",
- }[task.Status]
- send(fmt.Sprintf("[完成] 任务已%s", statusLabel))
- return
- }
- // Task is still running — handle client close messages in the background.
- clientGone := make(chan struct{})
- go func() {
- for {
- if _, _, err := conn.ReadMessage(); err != nil {
- close(clientGone)
- return
- }
- }
- }()
- progressKey := fmt.Sprintf("spider:task:progress:%d", id)
- for {
- select {
- case <-clientGone:
- return
- case <-ticker.C:
- var t model.Task
- if err := h.db.First(&t, id).Error; err != nil {
- return
- }
- vals, _ := h.redis.HGetAll(ctx, progressKey).Result()
- if len(vals) > 0 {
- msg := fmt.Sprintf("[进度] 阶段: %s | %s/%s | %s",
- vals["phase"], vals["current"], vals["total"], vals["message"])
- if !send(msg) {
- return
- }
- }
- if t.Status == "completed" || t.Status == "failed" || t.Status == "stopped" {
- send(fmt.Sprintf("[完成] 任务已结束,状态: %s", t.Status))
- return
- }
- }
- }
- }
|