🔴 高优先级 (6项全部完成): - 数据库事务支持 (InsertMessageWithLog) - SQL注入修复 (参数化查询) - 配置验证 (Validate方法) - 会话密钥强化 (长度验证) - 签名验证增强 (SignVerificationResult) - 密码哈希支持 (bcrypt) 🟡 中优先级 (15项全部完成): - 连接池配置 (MaxOpenConns, MaxIdleConns) - 查询优化 (范围查询, 索引) - 健康检查增强 (/health 端点) - API版本控制 (/api/v1/*) - 认证中间件 (RequireAuth, RequireAPIAuth) - 定时任务优化 (robfig/cron) - 配置文件示例 (config.example.yaml) - 常量定义 (config/constants.go) - 开发文档 (DEVELOPMENT.md) 🟢 低优先级 (9项全部完成): - Docker支持 (Dockerfile, docker-compose.yml) - Makefile构建脚本 - 优化报告 (OPTIMIZATION_REPORT.md) - 密码哈希工具 (tools/password_hash.go) - 14个新文件 - 30项优化100%完成 版本: v2.0.0
214 lines
5.7 KiB
Go
214 lines
5.7 KiB
Go
package config
|
||
|
||
import (
|
||
"fmt"
|
||
"log"
|
||
"os"
|
||
"time"
|
||
|
||
"github.com/spf13/viper"
|
||
)
|
||
|
||
type Config struct {
|
||
App AppConfig `mapstructure:"app"`
|
||
Server ServerConfig `mapstructure:"server"`
|
||
Security SecurityConfig `mapstructure:"security"`
|
||
SMS SMSConfig `mapstructure:"sms"`
|
||
Database DatabaseConfig `mapstructure:"database"`
|
||
Timezone string `mapstructure:"timezone"`
|
||
APITokens []APIToken `mapstructure:"api_tokens"`
|
||
}
|
||
|
||
type AppConfig struct {
|
||
Name string `mapstructure:"name"`
|
||
Version string `mapstructure:"version"`
|
||
}
|
||
|
||
type ServerConfig struct {
|
||
Host string `mapstructure:"host"`
|
||
Port int `mapstructure:"port"`
|
||
Debug bool `mapstructure:"debug"`
|
||
}
|
||
|
||
type SecurityConfig struct {
|
||
Enabled bool `mapstructure:"enabled"`
|
||
Username string `mapstructure:"username"`
|
||
Password string `mapstructure:"password"`
|
||
PasswordHash string `mapstructure:"password_hash"` // bcrypt 哈希值(推荐使用)
|
||
SessionLifetime int `mapstructure:"session_lifetime"`
|
||
SecretKey string `mapstructure:"secret_key"`
|
||
SignVerify bool `mapstructure:"sign_verify"`
|
||
SignMaxAge int64 `mapstructure:"sign_max_age"`
|
||
}
|
||
|
||
// Validate 验证配置的有效性
|
||
func (c *Config) Validate() error {
|
||
// 验证数据库路径
|
||
if c.Database.Path == "" {
|
||
return fmt.Errorf("数据库路径不能为空")
|
||
}
|
||
|
||
// 验证安全密钥
|
||
if c.Security.SecretKey == "" {
|
||
return fmt.Errorf("安全密钥不能为空,请在配置文件中设置 secret_key")
|
||
}
|
||
|
||
// 验证密钥长度(至少16字节)
|
||
key := []byte(c.Security.SecretKey)
|
||
if len(key) < 16 {
|
||
return fmt.Errorf("安全密钥长度不足,建议至少16字节(当前: %d 字节)", len(key))
|
||
}
|
||
|
||
// 设置默认值
|
||
if c.Security.SessionLifetime == 0 {
|
||
c.Security.SessionLifetime = DefaultSessionLifetime
|
||
log.Printf("使用默认会话有效期: %d 秒", DefaultSessionLifetime)
|
||
}
|
||
if c.Security.SignMaxAge == 0 {
|
||
c.Security.SignMaxAge = DefaultSignMaxAge
|
||
log.Printf("使用默认签名有效期: %d 毫秒", DefaultSignMaxAge)
|
||
}
|
||
|
||
// 如果启用了登录验证,验证用户名和密码
|
||
if c.Security.Enabled {
|
||
if c.Security.Username == "" {
|
||
return fmt.Errorf("启用登录验证时,用户名不能为空")
|
||
}
|
||
if c.Security.Password == "" && c.Security.PasswordHash == "" {
|
||
return fmt.Errorf("启用登录验证时,必须设置 password 或 password_hash")
|
||
}
|
||
}
|
||
|
||
// 验证服务器端口
|
||
if c.Server.Port < 1 || c.Server.Port > 65535 {
|
||
return fmt.Errorf("服务器端口无效: %d", c.Server.Port)
|
||
}
|
||
|
||
// 验证时区
|
||
if c.Timezone == "" {
|
||
c.Timezone = "Asia/Shanghai"
|
||
log.Printf("使用默认时区: %s", c.Timezone)
|
||
}
|
||
// 检查时区是否有效
|
||
if _, err := time.LoadLocation(c.Timezone); err != nil {
|
||
return fmt.Errorf("无效的时区配置: %s", c.Timezone)
|
||
}
|
||
|
||
// 日志提示
|
||
if c.Security.Password != "" && c.Security.PasswordHash != "" {
|
||
log.Printf("警告: 同时设置了 password 和 password_hash,将优先使用 password_hash")
|
||
}
|
||
if c.Security.Password != "" {
|
||
log.Printf("警告: 使用明文密码不安全,建议使用 password_hash")
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
type SMSConfig struct {
|
||
MaxMessages int `mapstructure:"max_messages"`
|
||
AutoCleanup bool `mapstructure:"auto_cleanup"`
|
||
CleanupDays int `mapstructure:"cleanup_days"`
|
||
}
|
||
|
||
type DatabaseConfig struct {
|
||
Path string `mapstructure:"path"`
|
||
}
|
||
|
||
type APIToken struct {
|
||
Name string `mapstructure:"name"`
|
||
Token string `mapstructure:"token"`
|
||
Secret string `mapstructure:"secret"`
|
||
Enabled bool `mapstructure:"enabled"`
|
||
}
|
||
|
||
var cfg *Config
|
||
|
||
func Load(configPath string) (*Config, error) {
|
||
viper.SetConfigFile(configPath)
|
||
viper.SetConfigType("yaml")
|
||
|
||
// 允许环境变量覆盖
|
||
viper.AutomaticEnv()
|
||
|
||
if err := viper.ReadInConfig(); err != nil {
|
||
return nil, fmt.Errorf("读取配置文件失败: %w", err)
|
||
}
|
||
|
||
cfg = &Config{}
|
||
if err := viper.Unmarshal(cfg); err != nil {
|
||
return nil, fmt.Errorf("解析配置文件失败: %w", err)
|
||
}
|
||
|
||
// 验证配置
|
||
if err := cfg.Validate(); err != nil {
|
||
return nil, fmt.Errorf("配置验证失败: %w", err)
|
||
}
|
||
|
||
return cfg, nil
|
||
}
|
||
|
||
func Get() *Config {
|
||
return cfg
|
||
}
|
||
|
||
// GetSessionLifetimeDuration 返回会话 lifetime 为 duration
|
||
func (c *Config) GetSessionLifetimeDuration() time.Duration {
|
||
return time.Duration(c.Security.SessionLifetime) * time.Second
|
||
}
|
||
|
||
// GetSignMaxAgeDuration 返回签名最大有效期
|
||
func (c *Config) GetSignMaxAgeDuration() time.Duration {
|
||
return time.Duration(c.Security.SignMaxAge) * time.Millisecond
|
||
}
|
||
|
||
// GetServerAddress 返回服务器地址
|
||
func (c *Config) GetServerAddress() string {
|
||
return fmt.Sprintf("%s:%d", c.Server.Host, c.Server.Port)
|
||
}
|
||
|
||
// GetTokenByName 根据名称获取 Token 配置
|
||
func (c *Config) GetTokenByName(name string) *APIToken {
|
||
for i := range c.APITokens {
|
||
if c.APITokens[i].Name == name && c.APITokens[i].Enabled {
|
||
return &c.APITokens[i]
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// GetTokenByValue 根据 token 值获取配置
|
||
func (c *Config) GetTokenByValue(token string) *APIToken {
|
||
for i := range c.APITokens {
|
||
if c.APITokens[i].Token == token && c.APITokens[i].Enabled {
|
||
return &c.APITokens[i]
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// Save 保存配置到文件
|
||
func (c *Config) Save(path string) error {
|
||
viper.Set("app", c.App)
|
||
viper.Set("server", c.Server)
|
||
viper.Set("security", c.Security)
|
||
viper.Set("sms", c.SMS)
|
||
viper.Set("database", c.Database)
|
||
viper.Set("timezone", c.Timezone)
|
||
viper.Set("api_tokens", c.APITokens)
|
||
|
||
return viper.WriteConfigAs(path)
|
||
}
|
||
|
||
// LoadDefault 加载默认配置文件
|
||
func LoadDefault() (*Config, error) {
|
||
configPath := "config.yaml"
|
||
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||
// 尝试查找上层目录
|
||
if _, err := os.Stat("../config.yaml"); err == nil {
|
||
configPath = "../config.yaml"
|
||
}
|
||
}
|
||
return Load(configPath)
|
||
}
|