package handler

import (
	"context"
	"fmt"
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v5"
	"github.com/redis/go-redis/v9"
	"golang.org/x/crypto/bcrypt"
	"gorm.io/gorm"

	"spider/internal/config"
	"spider/internal/model"
	"spider/internal/store"
)

// jwtExpiry is the lifetime of tokens issued by Login. It is also used as the
// TTL of logout blacklist entries.
const jwtExpiry = 7 * 24 * time.Hour // 7 days

// getJWTSecret returns the HMAC key used to sign and verify tokens: the
// configured Security.JWTSecret when set, otherwise a built-in default.
//
// NOTE(review): the built-in default is hard-coded here, so anyone who can read
// this source can forge tokens for deployments that do not configure a secret.
// Consider refusing to start without a configured secret.
func getJWTSecret() string {
	if cfg := config.Get(); cfg != nil && cfg.Security.JWTSecret != "" {
		return cfg.Security.JWTSecret
	}
	return "spider-jwt-secret-2026"
}

// authRedis is the shared Redis client for the token blacklist, set via
// SetAuthRedis. When it is nil, Logout does not revoke tokens and JWTAuth
// skips the blacklist check.
var authRedis *redis.Client

// SetAuthRedis sets the Redis client used for token blacklisting.
func SetAuthRedis(r *redis.Client) { authRedis = r }

// AuthHandler handles authentication endpoints.
type AuthHandler struct {
	store *store.Store
}

// LoginRequest is the login payload.
type LoginRequest struct {
	Username string `json:"username" binding:"required"`
	Password string `json:"password" binding:"required"`
}

// Login handles POST /auth/login.
//
// It verifies the username and password, rejects disabled accounts, and
// responds with a signed HS256 JWT (valid for jwtExpiry) plus basic user info.
// The last-login columns and a "login" audit entry are written in a background
// goroutine whose errors are ignored.
func (h *AuthHandler) Login(c *gin.Context) {
	var req LoginRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		Fail(c, 400, "请输入用户名和密码")
		return
	}

	var user model.User
	if err := h.store.DB.Where("username = ?", req.Username).First(&user).Error; err != nil {
		Fail(c, 401, "用户名或密码错误")
		return
	}
	if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(req.Password)); err != nil {
		Fail(c, 401, "用户名或密码错误")
		return
	}
	// Checked only after the password matches, so callers without valid
	// credentials cannot learn that an account exists or is disabled.
	if !user.Enabled {
		Fail(c, 403, "账号已禁用")
		return
	}

	// Generate JWT with expiration.
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
		"user_id":  user.ID,
		"username": user.Username,
		"role":     user.Role,
		"exp":      time.Now().Add(jwtExpiry).Unix(),
	})
	tokenStr, err := token.SignedString([]byte(getJWTSecret()))
	if err != nil {
		Fail(c, 500, "生成令牌失败")
		return
	}

	// Best-effort bookkeeping off the request path; errors are ignored.
	// The goroutine receives its own copy of the user, so anything GORM may
	// write back into the model cannot touch the struct used for the response.
	now := time.Now()
	ip := c.ClientIP()
	go func(u model.User) {
		h.store.DB.Model(&u).Updates(map[string]any{
			"last_login_at": now,
			"last_login_ip": ip,
		})
		h.store.DB.Create(&model.AuditLog{
			Username:   u.Username,
			Action:     "login",
			TargetType: "user",
			TargetID:   fmt.Sprintf("%d", u.ID),
			IP:         ip,
		})
	}(user)

	OK(c, gin.H{
		"token": tokenStr,
		"user": gin.H{
			"id":                   user.ID,
			"username":             user.Username,
			"nickname":             user.Nickname,
			"role":                 user.Role,
			"must_change_password": user.MustChangePassword,
		},
	})
}

// ChangePassword handles PUT /auth/password.
//
// The caller must supply their current password. The new password is checked
// by ValidatePassword, which is the single source of password policy. It is
// then bcrypt-hashed and stored, must_change_password is cleared, and an audit
// entry is written. Hashing and DB failures are reported as HTTP 500.
func (h *AuthHandler) ChangePassword(c *gin.Context) {
	userID := c.GetUint("user_id")
	var req struct {
		OldPassword string `json:"old_password" binding:"required"`
		// No length tag: ValidatePassword enforces the configurable minimum, and
		// a fixed min=6 tag here disagreed with its default of 8.
		NewPassword string `json:"new_password" binding:"required"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		Fail(c, 400, "请输入旧密码和新密码")
		return
	}

	var user model.User
	if err := h.store.DB.First(&user, userID).Error; err != nil {
		Fail(c, 404, "用户不存在")
		return
	}
	if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(req.OldPassword)); err != nil {
		Fail(c, 400, "旧密码错误")
		return
	}
	if err := ValidatePassword(req.NewPassword); err != nil {
		Fail(c, 400, err.Error())
		return
	}

	// Ignoring this error would store an empty hash and lock the user out.
	hashed, err := bcrypt.GenerateFromPassword([]byte(req.NewPassword), bcrypt.DefaultCost)
	if err != nil {
		Fail(c, 500, "密码加密失败")
		return
	}
	if err := h.store.DB.Model(&user).Updates(map[string]any{
		"password":             string(hashed),
		"must_change_password": false,
	}).Error; err != nil {
		Fail(c, 500, "修改密码失败")
		return
	}

	LogAudit(h.store, c, "update", "user", fmt.Sprintf("%d", userID), gin.H{"action": "change_password"})
	OK(c, gin.H{"message": "密码已修改"})
}

// GetProfile handles GET /auth/profile and returns the current user's id,
// username, nickname and role.
func (h *AuthHandler) GetProfile(c *gin.Context) {
	userID := c.GetUint("user_id")
	var user model.User
	if err := h.store.DB.First(&user, userID).Error; err != nil {
		Fail(c, 404, "用户不存在")
		return
	}
	OK(c, gin.H{
		"id":       user.ID,
		"username": user.Username,
		"nickname": user.Nickname,
		"role":     user.Role,
	})
}

// UpdateProfile handles PUT /auth/profile. The user updates their own nickname.
//
// An omitted or null nickname leaves the stored value unchanged. The response
// contains the user's profile after the update.
func (h *AuthHandler) UpdateProfile(c *gin.Context) {
	userID := c.GetUint("user_id")
	var req struct {
		Nickname *string `json:"nickname"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		Fail(c, 400, err.Error())
		return
	}

	var user model.User
	if err := h.store.DB.First(&user, userID).Error; err != nil {
		Fail(c, 404, "用户不存在")
		return
	}
	if req.Nickname != nil {
		if err := h.store.DB.Model(&user).Update("nickname", *req.Nickname).Error; err != nil {
			Fail(c, 500, "更新失败")
			return
		}
	}
	// Reload so the response reflects the stored row.
	h.store.DB.First(&user, userID)

	OK(c, gin.H{
		"id":       user.ID,
		"username": user.Username,
		"nickname": user.Nickname,
		"role":     user.Role,
	})
}

// Logout handles POST /auth/logout.
//
// When Redis is configured, the bearer token is blacklisted for jwtExpiry. That
// is the full lifetime Login gives a token, so the entry outlives the token.
// Blacklisting is best-effort: a Redis error is ignored and the client is still
// told it has logged out.
func (h *AuthHandler) Logout(c *gin.Context) {
	auth := c.GetHeader("Authorization")
	if strings.HasPrefix(auth, "Bearer ") && authRedis != nil {
		tokenStr := strings.TrimPrefix(auth, "Bearer ")
		// Background context so the write is not cancelled if the client
		// disconnects before it completes.
		authRedis.Set(context.Background(), "spider:token:blacklist:"+tokenStr, "1", jwtExpiry)
	}
	LogAudit(h.store, c, "logout", "user", c.GetString("username"), nil)
	OK(c, gin.H{"message": "已退出"})
}

// ── JWT Middleware ──

// JWTAuth is the authentication middleware.
//
// It expects "Authorization: Bearer <token>" and rejects blacklisted tokens
// (when authRedis is set). It verifies the HS256 signature with getJWTSecret
// and copies the user_id, username and role claims into the gin context for
// downstream handlers. Any failure aborts the request with HTTP 401.
func JWTAuth() gin.HandlerFunc {
	return func(c *gin.Context) {
		auth := c.GetHeader("Authorization")
		if auth == "" || !strings.HasPrefix(auth, "Bearer ") {
			c.AbortWithStatusJSON(http.StatusUnauthorized, Response{Code: 401, Message: "未登录"})
			return
		}
		tokenStr := strings.TrimPrefix(auth, "Bearer ")

		// Check the blacklist. A Redis error is treated as "not blacklisted"
		// (fail open), so revoked tokens are accepted while Redis is unreachable.
		if authRedis != nil {
			blacklisted, _ := authRedis.Exists(c.Request.Context(), "spider:token:blacklist:"+tokenStr).Result()
			if blacklisted > 0 {
				c.AbortWithStatusJSON(http.StatusUnauthorized, Response{Code: 401, Message: "令牌已失效"})
				return
			}
		}

		// Pin the algorithm so the token's own "alg" header cannot select a
		// different verification method than the one Login signs with.
		token, err := jwt.Parse(tokenStr, func(*jwt.Token) (any, error) {
			return []byte(getJWTSecret()), nil
		}, jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}))
		if err != nil || !token.Valid {
			c.AbortWithStatusJSON(http.StatusUnauthorized, Response{Code: 401, Message: "令牌无效"})
			return
		}

		claims, ok := token.Claims.(jwt.MapClaims)
		if !ok {
			c.AbortWithStatusJSON(http.StatusUnauthorized, Response{Code: 401, Message: "令牌解析失败"})
			return
		}

		// Set user info in context. Numeric claims decode from JSON as float64.
		if uid, ok := claims["user_id"].(float64); ok {
			c.Set("user_id", uint(uid))
		}
		if username, ok := claims["username"].(string); ok {
			c.Set("username", username)
		}
		if role, ok := claims["role"].(string); ok {
			c.Set("role", role)
		}
		c.Next()
	}
}

// RequireRole returns middleware that allows the request only if the "role"
// set by JWTAuth is one of roles. Otherwise it aborts with HTTP 403.
func RequireRole(roles ...string) gin.HandlerFunc {
	roleSet := make(map[string]bool, len(roles))
	for _, r := range roles {
		roleSet[r] = true
	}
	return func(c *gin.Context) {
		role := c.GetString("role")
		if !roleSet[role] {
			c.AbortWithStatusJSON(http.StatusForbidden, Response{Code: 403, Message: "权限不足"})
			return
		}
		c.Next()
	}
}

// RequireAction returns middleware that allows the request only if the user's
// role grants action.
//
// The "admin" role is always allowed. Otherwise the role's comma-separated
// action list is read from the role_permission record in the permission DB. If
// no permission DB is set, no record exists, or the lookup fails, the list from
// model.DefaultPermissions is used instead. A role with neither is denied with
// HTTP 403.
func RequireAction(action string) gin.HandlerFunc {
	// grants reports whether the comma-separated list contains action.
	grants := func(actions string) bool {
		for _, a := range strings.Split(actions, ",") {
			if strings.TrimSpace(a) == action {
				return true
			}
		}
		return false
	}

	return func(c *gin.Context) {
		role := c.GetString("role")

		// Admin always has all permissions.
		if role == "admin" {
			c.Next()
			return
		}

		allowed := false
		var perm model.RolePermission
		// A nil DB (setPermissionDB never called) falls back to the defaults
		// instead of dereferencing a nil *gorm.DB.
		if db := getPermissionDB(); db != nil && db.Where("role = ?", role).First(&perm).Error == nil {
			// A DB record replaces the defaults for this role entirely.
			allowed = grants(perm.Actions)
		} else if d, ok := model.DefaultPermissions()[role]; ok {
			allowed = grants(d.Actions)
		}

		if !allowed {
			c.AbortWithStatusJSON(http.StatusForbidden, Response{Code: 403, Message: "无此操作权限"})
			return
		}
		c.Next()
	}
}

// permDB is the database RequireAction reads role permissions from. It is set
// via setPermissionDB; getPermissionDB only returns it when it holds a
// *gorm.DB.
var permDB interface{ Where(query interface{}, args ...interface{}) *gorm.DB }

func setPermissionDB(db *gorm.DB) { permDB = db }

// getPermissionDB returns the configured permission DB, or nil when none is
// set (or permDB holds something other than a *gorm.DB).
func getPermissionDB() *gorm.DB {
	db, _ := permDB.(*gorm.DB)
	return db
}

// HashPassword hashes a password with bcrypt at the default cost.
//
// The bcrypt error is discarded, so the result is "" when hashing fails (for
// example, current golang.org/x/crypto/bcrypt rejects inputs longer than 72
// bytes). Callers must not store an empty result.
func HashPassword(password string) string {
	hashed, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	return string(hashed)
}

// ValidatePassword checks password meets complexity requirements.
func ValidatePassword(password string) error { minLen := 8 if cfg := config.Get(); cfg != nil && cfg.Security.PasswordMinLen > 0 { minLen = cfg.Security.PasswordMinLen } if len(password) < minLen { return fmt.Errorf("密码至少 %d 位", minLen) } var hasUpper, hasLower, hasDigit bool for _, c := range password { switch { case c >= 'A' && c <= 'Z': hasUpper = true case c >= 'a' && c <= 'z': hasLower = true case c >= '0' && c <= '9': hasDigit = true } } if !hasUpper || !hasLower || !hasDigit { return fmt.Errorf("密码必须包含大写字母、小写字母和数字") } return nil }