package config

import (
	"errors"
	"fmt"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/spf13/viper"
)

// Config is the top-level application configuration returned by Load.
type Config struct {
	Server   ServerConfig   `mapstructure:"server"`
	Database DatabaseConfig `mapstructure:"database"`
	Speed    SpeedConfig    `mapstructure:"speed"`
	Log      LogConfig      `mapstructure:"log"`
}

// ServerConfig holds the host and port the server binds to
// (defaults: "0.0.0.0" and 3000).
type ServerConfig struct {
	Port int    `mapstructure:"port"`
	Host string `mapstructure:"host"`
}

// DatabaseConfig holds database connection settings. Database is the name
// of the database to use (default "clash_speed_test").
type DatabaseConfig struct {
	Host     string `mapstructure:"host"`
	Port     int    `mapstructure:"port"`
	Username string `mapstructure:"username"`
	Password string `mapstructure:"password"`
	Database string `mapstructure:"database"`
}

// SpeedConfig holds the speed-test settings: the URLs to test, the
// per-test Timeout, the Concurrency level, the Interval between runs, and
// whether a test runs at startup (TestOnStart).
type SpeedConfig struct {
	TestURLs    []string      `mapstructure:"test_urls"`
	Timeout     time.Duration `mapstructure:"timeout"`
	Concurrency int           `mapstructure:"concurrency"`
	Interval    time.Duration `mapstructure:"interval"`
	TestOnStart bool          `mapstructure:"test_on_start"`
}

// LogConfig holds logging settings. File is empty by default.
type LogConfig struct {
	Level string `mapstructure:"level"`
	File  string `mapstructure:"file"`
}

// Load builds a Config from these sources, in increasing order of precedence:
//
//   - built-in defaults (see setDefaults);
//   - an optional "config.yaml" found in "." or "./config";
//   - environment variables named after the key path, upper-cased with "."
//     replaced by "_" (e.g. SERVER_PORT, SPEED_CONCURRENCY, LOG_LEVEL);
//   - the explicit overrides PORT, DB_HOST, DB_PORT, DB_USER, DB_PASS,
//     DB_NAME, SPEED_TEST_INTERVAL and SPEED_TEST_TIMEOUT.
//
// A missing config file is not an error. Load returns an error when an
// explicit override variable holds a malformed integer or duration, when
// the config file exists but cannot be read or parsed, or when the merged
// settings cannot be decoded into Config.
//
// Load configures and reads viper's global instance.
func Load() (*Config, error) {
	setDefaults()
	if err := applyEnvOverrides(); err != nil {
		return nil, err
	}

	// Every key is nested ("server.port"); without the replacer AutomaticEnv
	// would look up names like "SERVER.PORT", which shells cannot set.
	viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
	viper.AutomaticEnv()

	viper.SetConfigName("config")
	viper.SetConfigType("yaml")
	viper.AddConfigPath(".")
	viper.AddConfigPath("./config")

	if err := viper.ReadInConfig(); err != nil {
		// The config file is optional: fall back to defaults and environment
		// when it simply does not exist, but surface any other failure.
		var notFound viper.ConfigFileNotFoundError
		if !errors.As(err, &notFound) {
			return nil, fmt.Errorf("reading config file: %w", err)
		}
	}

	var cfg Config
	if err := viper.Unmarshal(&cfg); err != nil {
		return nil, fmt.Errorf("decoding config: %w", err)
	}
	return &cfg, nil
}

// setDefaults registers the lowest-precedence value for every config key.
// Registering a default for every key also makes each key known to viper,
// which the environment lookups in Load rely on.
func setDefaults() {
	// Server.
	viper.SetDefault("server.port", 3000)
	viper.SetDefault("server.host", "0.0.0.0")

	// Database.
	viper.SetDefault("database.host", "localhost")
	viper.SetDefault("database.port", 3306)
	viper.SetDefault("database.username", "root")
	viper.SetDefault("database.password", "")
	viper.SetDefault("database.database", "clash_speed_test")

	// Speed test.
	viper.SetDefault("speed.test_urls", []string{
		"https://www.google.com",
		"https://www.youtube.com",
		"https://www.github.com",
	})
	viper.SetDefault("speed.timeout", 10*time.Second)
	viper.SetDefault("speed.concurrency", 5)
	viper.SetDefault("speed.interval", 5*time.Minute)
	viper.SetDefault("speed.test_on_start", true)

	// Logging.
	viper.SetDefault("log.level", "info")
	viper.SetDefault("log.file", "")
}

// applyEnvOverrides applies the short-name environment variables with
// viper.Set, so they take precedence over every other source. An unset or
// empty variable leaves its key untouched. A variable holding a malformed
// integer or duration is reported as an error rather than being silently
// ignored.
func applyEnvOverrides() error {
	// String values are used verbatim.
	for _, o := range []struct{ env, key string }{
		{"DB_HOST", "database.host"},
		{"DB_USER", "database.username"},
		{"DB_PASS", "database.password"},
		{"DB_NAME", "database.database"},
	} {
		if v := os.Getenv(o.env); v != "" {
			viper.Set(o.key, v)
		}
	}

	// Integer values, parsed with strconv.Atoi.
	for _, o := range []struct{ env, key string }{
		{"PORT", "server.port"},
		{"DB_PORT", "database.port"},
	} {
		v := os.Getenv(o.env)
		if v == "" {
			continue
		}
		n, err := strconv.Atoi(v)
		if err != nil {
			return fmt.Errorf("environment variable %s: %w", o.env, err)
		}
		viper.Set(o.key, n)
	}

	// Durations in time.ParseDuration syntax, e.g. "30s" or "5m".
	for _, o := range []struct{ env, key string }{
		{"SPEED_TEST_INTERVAL", "speed.interval"},
		{"SPEED_TEST_TIMEOUT", "speed.timeout"},
	} {
		v := os.Getenv(o.env)
		if v == "" {
			continue
		}
		d, err := time.ParseDuration(v)
		if err != nil {
			return fmt.Errorf("environment variable %s: %w", o.env, err)
		}
		viper.Set(o.key, d)
	}

	return nil
}