| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353 |
- package handler
- import (
- "context"
- "fmt"
- "net/http"
- "strconv"
- "strings"
- "time"
- "github.com/gin-gonic/gin"
- "github.com/gorilla/websocket"
- "spider/internal/model"
- "spider/internal/store"
- "spider/internal/task"
- )
// TaskHandler handles task-related HTTP and WebSocket requests.
//
// It reads task records directly from the database through store and
// delegates task lifecycle operations (start, stop, progress, logs, plugin
// listing) to the task manager.
type TaskHandler struct {
	store   *store.Store  // database access (store.DB) and audit logging via LogAudit
	taskMgr *task.Manager // starts/stops tasks; serves progress, logs, plugins and the Redis client
}
- // List handles GET /tasks
- func (h *TaskHandler) List(c *gin.Context) {
- page, pageSize, offset := parsePage(c)
- status := c.Query("status")
- plugin := c.Query("plugin_name")
- dateFrom := c.Query("date_from")
- dateTo := c.Query("date_to")
- query := h.store.DB.Model(&model.TaskLog{}).Order("created_at DESC")
- if status != "" {
- query = query.Where("status = ?", status)
- }
- if plugin != "" {
- query = query.Where("plugin_name = ?", plugin)
- }
- if dateFrom != "" {
- t, err := time.Parse("2006-01-02", dateFrom)
- if err == nil {
- query = query.Where("created_at >= ?", t)
- }
- }
- if dateTo != "" {
- t, err := time.Parse("2006-01-02", dateTo)
- if err == nil {
- query = query.Where("created_at < ?", t.AddDate(0, 0, 1))
- }
- }
- var total int64
- query.Count(&total)
- var tasks []model.TaskLog
- if err := query.Limit(pageSize).Offset(offset).Find(&tasks).Error; err != nil {
- Fail(c, 500, err.Error())
- return
- }
- PageOK(c, tasks, total, page, pageSize)
- }
- // Start handles POST /tasks/start
- func (h *TaskHandler) Start(c *gin.Context) {
- var req task.StartRequest
- if err := c.ShouldBindJSON(&req); err != nil {
- Fail(c, 400, err.Error())
- return
- }
- // Validate proxy_mode
- if req.ProxyMode != "" && req.ProxyMode != "single" && req.ProxyMode != "pool" {
- Fail(c, 400, "proxy_mode 必须是 single 或 pool")
- return
- }
- if req.ProxyMode == "single" && (req.ProxyID == nil || *req.ProxyID == 0) {
- Fail(c, 400, "固定代理模式需要指定 proxy_id")
- return
- }
- // Special case: clean task
- if req.PluginName == "clean" {
- taskLog, err := h.taskMgr.StartClean()
- if err != nil {
- Fail(c, 409, err.Error())
- return
- }
- LogAudit(h.store, c, "create", "task", fmt.Sprintf("%d", taskLog.ID), gin.H{"type": "clean"})
- c.JSON(http.StatusCreated, Response{Code: 0, Message: "ok", Data: taskLog})
- return
- }
- taskLog, err := h.taskMgr.StartTask(req)
- if err != nil {
- Fail(c, 409, err.Error())
- return
- }
- LogAudit(h.store, c, "create", "task", fmt.Sprintf("%d", taskLog.ID), gin.H{"plugin": req.PluginName})
- c.JSON(http.StatusCreated, Response{Code: 0, Message: "ok", Data: taskLog})
- }
- // Get handles GET /tasks/:id
- func (h *TaskHandler) Get(c *gin.Context) {
- id, err := strconv.ParseUint(c.Param("id"), 10, 64)
- if err != nil {
- Fail(c, 400, "invalid id")
- return
- }
- var taskLog model.TaskLog
- if err := h.store.DB.First(&taskLog, id).Error; err != nil {
- Fail(c, 404, "task not found")
- return
- }
- progress := h.taskMgr.GetProgress(uint(id))
- OK(c, gin.H{
- "task": taskLog,
- "progress": progress,
- })
- }
- // Stop handles POST /tasks/:id/stop
- func (h *TaskHandler) Stop(c *gin.Context) {
- id, err := strconv.ParseUint(c.Param("id"), 10, 64)
- if err != nil {
- Fail(c, 400, "invalid id")
- return
- }
- if err := h.taskMgr.StopTask(uint(id)); err != nil {
- Fail(c, 500, err.Error())
- return
- }
- LogAudit(h.store, c, "update", "task", fmt.Sprintf("%d", id), gin.H{"action": "stop"})
- OK(c, gin.H{"message": "stop signal sent"})
- }
- // Retry handles POST /tasks/:id/retry — restarts a failed/stopped task with same config.
- func (h *TaskHandler) Retry(c *gin.Context) {
- id, err := strconv.ParseUint(c.Param("id"), 10, 64)
- if err != nil {
- Fail(c, 400, "invalid id")
- return
- }
- var original model.TaskLog
- if err := h.store.DB.First(&original, id).Error; err != nil {
- Fail(c, 404, "task not found")
- return
- }
- if original.Status != "failed" && original.Status != "stopped" {
- Fail(c, 400, "只能重试失败或已停止的任务")
- return
- }
- var taskLog *model.TaskLog
- if original.TaskType == "clean" {
- taskLog, err = h.taskMgr.StartClean()
- } else {
- req := task.StartRequest{
- PluginName: original.PluginName,
- ProxyMode: original.ProxyMode,
- }
- if original.ProxyID != nil {
- req.ProxyID = original.ProxyID
- }
- taskLog, err = h.taskMgr.StartTask(req)
- }
- if err != nil {
- Fail(c, 409, err.Error())
- return
- }
- LogAudit(h.store, c, "create", "task", fmt.Sprintf("%d", taskLog.ID), gin.H{"retry_from": id})
- c.JSON(http.StatusCreated, Response{Code: 0, Message: "ok", Data: taskLog})
- }
- // Logs handles GET /tasks/:id/logs via WebSocket.
- // Uses Redis Pub/Sub for real-time updates instead of polling the database.
- func (h *TaskHandler) Logs(c *gin.Context) {
- id, err := strconv.ParseUint(c.Param("id"), 10, 64)
- if err != nil {
- c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"})
- return
- }
- upgrader := websocket.Upgrader{
- CheckOrigin: func(r *http.Request) bool { return true },
- }
- conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
- if err != nil {
- return
- }
- defer conn.Close()
- send := func(msg string) bool {
- err := conn.WriteMessage(websocket.TextMessage, []byte(msg))
- return err == nil
- }
- // Fetch task record
- var taskLog model.TaskLog
- if err := h.store.DB.First(&taskLog, id).Error; err != nil {
- send(fmt.Sprintf("[错误] 任务 #%d 不存在", id))
- return
- }
- // Send history logs
- logs := h.taskMgr.GetLogs(uint(id))
- for _, line := range logs {
- if !send(line) {
- return
- }
- }
- // If already finished, close
- if taskLog.Status == "completed" || taskLog.Status == "failed" || taskLog.Status == "stopped" {
- send(fmt.Sprintf("[完成] 任务已结束,状态: %s", taskLog.Status))
- return
- }
- // Subscribe to Redis Pub/Sub for real-time updates (no DB polling!)
- rdb := h.taskMgr.GetRedis()
- pubsub := rdb.Subscribe(context.Background(), task.TaskPubSubChannel(uint(id)))
- defer pubsub.Close()
- // Detect client disconnect
- clientGone := make(chan struct{})
- go func() {
- for {
- if _, _, err := conn.ReadMessage(); err != nil {
- close(clientGone)
- return
- }
- }
- }()
- ch := pubsub.Channel()
- // Safety timeout: if no updates for 5 minutes, check DB and close
- timeout := time.NewTimer(5 * time.Minute)
- defer timeout.Stop()
- for {
- select {
- case <-clientGone:
- return
- case msg, ok := <-ch:
- if !ok {
- return
- }
- if !send(msg.Payload) {
- return
- }
- // Check if task completed
- if isFinishMessage(msg.Payload) {
- return
- }
- timeout.Reset(5 * time.Minute)
- case <-timeout.C:
- // Fallback: check DB if task is still running
- var t model.TaskLog
- if err := h.store.DB.First(&t, id).Error; err != nil {
- return
- }
- if t.Status == "completed" || t.Status == "failed" || t.Status == "stopped" {
- send(fmt.Sprintf("[完成] 任务已结束,状态: %s", t.Status))
- return
- }
- timeout.Reset(5 * time.Minute)
- }
- }
- }
// isFinishMessage reports whether a Pub/Sub log payload signals the end of a
// task. The payload must equal "任务完成" exactly or contain one of the
// end-of-task markers ("任务已结束", "任务已停止", "采集失败").
func isFinishMessage(msg string) bool {
	if msg == "任务完成" {
		return true
	}
	for _, marker := range []string{"任务已结束", "任务已停止", "采集失败"} {
		if strings.Contains(msg, marker) {
			return true
		}
	}
	return false
}
- // Details handles GET /tasks/:id/details — returns per-operation execution logs.
- // Query params: page, page_size, action (filter by action type), status (filter by status)
- func (h *TaskHandler) Details(c *gin.Context) {
- id, err := strconv.ParseUint(c.Param("id"), 10, 64)
- if err != nil {
- Fail(c, 400, "invalid id")
- return
- }
- page, pageSize, offset := parsePage(c)
- query := h.store.DB.Model(&model.TaskDetail{}).Where("task_id = ?", id)
- if action := c.Query("action"); action != "" {
- query = query.Where("action = ?", action)
- }
- if status := c.Query("status"); status != "" {
- query = query.Where("status = ?", status)
- }
- var total int64
- query.Count(&total)
- var details []model.TaskDetail
- if err := query.Order("seq ASC").Limit(pageSize).Offset(offset).Find(&details).Error; err != nil {
- Fail(c, 500, err.Error())
- return
- }
- // Also return action summary counts
- var actionCounts []struct {
- Action string
- Status string
- Cnt int64
- }
- h.store.DB.Model(&model.TaskDetail{}).
- Where("task_id = ?", id).
- Select("action, status, count(*) as cnt").
- Group("action, status").
- Scan(&actionCounts)
- summary := map[string]map[string]int64{}
- for _, ac := range actionCounts {
- if summary[ac.Action] == nil {
- summary[ac.Action] = map[string]int64{}
- }
- summary[ac.Action][ac.Status] = ac.Cnt
- }
- OK(c, gin.H{
- "items": details,
- "total": total,
- "page": page,
- "page_size": pageSize,
- "summary": summary,
- })
- }
- // Plugins handles GET /plugins — returns list of available plugins.
- func (h *TaskHandler) Plugins(c *gin.Context) {
- names := h.taskMgr.ListPlugins()
- OK(c, names)
- }
|