123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139 |
- package config
import (
	"errors"
	"fmt"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/spf13/viper"
)
- type Config struct {
- Server ServerConfig `mapstructure:"server"`
- Database DatabaseConfig `mapstructure:"database"`
- Speed SpeedConfig `mapstructure:"speed"`
- Log LogConfig `mapstructure:"log"`
- }
// ServerConfig is the "server" section of the configuration.
type ServerConfig struct {
	// Port is the server port (key "server.port").
	Port int `mapstructure:"port"`
	// Host is the server host address (key "server.host").
	Host string `mapstructure:"host"`
}
// DatabaseConfig is the "database" section of the configuration.
type DatabaseConfig struct {
	// Host is the database server host (key "database.host").
	Host string `mapstructure:"host"`
	// Port is the database server port (key "database.port").
	Port int `mapstructure:"port"`
	// Username is the database login user (key "database.username").
	Username string `mapstructure:"username"`
	// Password is the database login password (key "database.password").
	// NOTE(review): this is a secret; avoid printing the struct with %v/%+v.
	Password string `mapstructure:"password"`
	// Database is the name of the database to use (key "database.database").
	Database string `mapstructure:"database"`
}
// SpeedConfig is the "speed" section of the configuration.
type SpeedConfig struct {
	// TestURLs lists the URLs used by the speed test (key "speed.test_urls").
	TestURLs []string `mapstructure:"test_urls"`
	// Timeout is a time.Duration read from key "speed.timeout";
	// presumably the per-test timeout — confirm against the test runner.
	Timeout time.Duration `mapstructure:"timeout"`
	// Concurrency is read from key "speed.concurrency"; presumably the
	// number of tests allowed to run at once — confirm against the runner.
	Concurrency int `mapstructure:"concurrency"`
	// Interval is a time.Duration read from key "speed.interval";
	// presumably the period between scheduled test rounds — confirm.
	Interval time.Duration `mapstructure:"interval"`
	// TestOnStart is read from key "speed.test_on_start"; presumably
	// whether a test round runs at startup — confirm.
	TestOnStart bool `mapstructure:"test_on_start"`
}
// LogConfig is the "log" section of the configuration.
type LogConfig struct {
	// Level is the logging level name (key "log.level").
	Level string `mapstructure:"level"`
	// File is the log file path (key "log.file"); what an empty value
	// means is decided by the logger setup — presumably "no file", confirm.
	File string `mapstructure:"file"`
}
- func Load() (*Config, error) {
- // 设置默认值
- setDefaults()
- // 读取环境变量
- viper.AutomaticEnv()
- // 读取配置文件
- viper.SetConfigName("config")
- viper.SetConfigType("yaml")
- viper.AddConfigPath(".")
- viper.AddConfigPath("./config")
- if err := viper.ReadInConfig(); err != nil {
- // 如果配置文件不存在,使用默认值
- if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
- return nil, err
- }
- }
- var config Config
- if err := viper.Unmarshal(&config); err != nil {
- return nil, err
- }
- return &config, nil
- }
- func setDefaults() {
- // 服务器配置
- viper.SetDefault("server.port", 3000)
- viper.SetDefault("server.host", "0.0.0.0")
- // 数据库配置
- viper.SetDefault("database.host", "localhost")
- viper.SetDefault("database.port", 3306)
- viper.SetDefault("database.username", "root")
- viper.SetDefault("database.password", "")
- viper.SetDefault("database.database", "clash_speed_test")
- // 测速配置
- viper.SetDefault("speed.test_urls", []string{
- "https://www.google.com",
- "https://www.youtube.com",
- "https://www.github.com",
- })
- viper.SetDefault("speed.timeout", 10*time.Second)
- viper.SetDefault("speed.concurrency", 5)
- viper.SetDefault("speed.interval", 5*time.Minute)
- viper.SetDefault("speed.test_on_start", true)
- // 日志配置
- viper.SetDefault("log.level", "info")
- viper.SetDefault("log.file", "")
- // 环境变量覆盖
- if port := os.Getenv("PORT"); port != "" {
- if p, err := strconv.Atoi(port); err == nil {
- viper.Set("server.port", p)
- }
- }
- if dbHost := os.Getenv("DB_HOST"); dbHost != "" {
- viper.Set("database.host", dbHost)
- }
- if dbPort := os.Getenv("DB_PORT"); dbPort != "" {
- if p, err := strconv.Atoi(dbPort); err == nil {
- viper.Set("database.port", p)
- }
- }
- if dbUser := os.Getenv("DB_USER"); dbUser != "" {
- viper.Set("database.username", dbUser)
- }
- if dbPass := os.Getenv("DB_PASS"); dbPass != "" {
- viper.Set("database.password", dbPass)
- }
- if dbName := os.Getenv("DB_NAME"); dbName != "" {
- viper.Set("database.database", dbName)
- }
- if interval := os.Getenv("SPEED_TEST_INTERVAL"); interval != "" {
- if d, err := time.ParseDuration(interval); err == nil {
- viper.Set("speed.interval", d)
- }
- }
- if timeout := os.Getenv("SPEED_TEST_TIMEOUT"); timeout != "" {
- if d, err := time.ParseDuration(timeout); err == nil {
- viper.Set("speed.timeout", d)
- }
- }
- }
|