// Package handlers implements the HTTP handlers of the SMS receiver: the HTML
// pages (message list, message detail, receive logs, statistics, login) and
// the JSON API (SMS ingestion, message listing, statistics).
package handlers

import (
	"crypto/subtle"
	"database/sql"
	"encoding/json"
	"fmt"
	"html/template"
	"log"
	"net/http"
	"path/filepath"
	"strconv"
	"strings"
	"time"

	"github.com/gorilla/mux"
	"sms-receiver-go/auth"
	"sms-receiver-go/config"
	"sms-receiver-go/database"
	"sms-receiver-go/models"
	"sms-receiver-go/sign"
)

// displayTimeLayout is the layout used for every timestamp shown to users.
const displayTimeLayout = "2006-01-02 15:04:05"

// templates holds every parsed page template. It is set by InitTemplates and
// is left untouched when parsing fails.
var templates *template.Template

// InitTemplates parses every *.html file in templatesPath into the package
// template set and registers the helper functions the templates use.
//
// It returns a wrapped error if the glob matches no files or any template
// fails to parse. On failure the previously loaded templates (if any) are kept.
func InitTemplates(templatesPath string) error {
	funcMap := template.FuncMap{
		// Arithmetic. "add" accepts int or int64 operands (for example page
		// numbers and int64 totals); operands of any other type count as 0.
		"add":     func(a, b interface{}) int64 { return toInt64(a) + toInt64(b) },
		"sub":     func(a, b int) int { return a - b },
		"mul":     func(a, b int) int { return a * b },
		"div":     func(a, b int) int { return a / b },
		"ceilDiv": func(a, b int) int { return (a + b - 1) / b },

		// Comparisons. NOTE: these override html/template's built-in
		// eq/ne/lt/le/gt/ge. eq/ne compare the interface values directly, so
		// int 1 and int64 1 are NOT equal here (the built-in would say they are).
		// lt/le/gt/ge accept only int operands, so passing an int64 such as
		// .total is a template execution error.
		"eq": func(a, b interface{}) bool { return a == b },
		"ne": func(a, b interface{}) bool { return a != b },
		"lt": func(a, b int) bool { return a < b },
		"le": func(a, b int) bool { return a <= b },
		"gt": func(a, b int) bool { return a > b },
		"ge": func(a, b int) bool { return a >= b },

		// Miscellaneous.
		"seq": createRange,
		// mulFloat returns a*b/100 as a float64 (presumably for percentage
		// display — TODO confirm against the templates).
		"mulFloat": func(a, b int64) float64 { return float64(a) * float64(b) / 100 },
	}

	// Parse into a local first: ParseGlob returns a nil template on error, and
	// assigning that directly would make every later ExecuteTemplate panic.
	parsed, err := template.New("root").Funcs(funcMap).ParseGlob(filepath.Join(templatesPath, "*.html"))
	if err != nil {
		return fmt.Errorf("加载模板失败: %w", err)
	}
	templates = parsed

	// Debug aid: list the names of the loaded templates.
	log.Printf("已加载的模板:")
	for _, t := range templates.Templates() {
		log.Printf(" - %s", t.Name())
	}
	return nil
}

// toInt64 converts an int or int64 template argument to int64. Any other type
// yields 0.
func toInt64(v interface{}) int64 {
	switch n := v.(type) {
	case int:
		return int64(n)
	case int64:
		return n
	default:
		return 0
	}
}

// createRange returns the inclusive integer sequence start..end. If end is
// before start it returns an empty slice rather than panicking on a negative
// allocation size.
func createRange(start, end int) []int {
	if end < start {
		return []int{}
	}
	result := make([]int, end-start+1)
	for i := range result {
		result[i] = start + i
	}
	return result
}

// displayLocation returns the configured display time zone. If
// config.Get().Timezone cannot be loaded, it falls back to time.Local.
// time.LoadLocation returns a nil *Location on error, and Time.In panics on nil,
// so the error must not be ignored.
func displayLocation() *time.Location {
	tz := config.Get().Timezone
	loc, err := time.LoadLocation(tz)
	if err != nil {
		log.Printf("加载时区 %q 失败, 使用本地时区: %v", tz, err)
		return time.Local
	}
	return loc
}

// requireLogin reports whether the request carries a valid session. If it
// does not, it redirects to /login and returns false. The second result of
// auth.CheckLogin is not used.
func requireLogin(w http.ResponseWriter, r *http.Request) bool {
	loggedIn, _ := auth.CheckLogin(w, r)
	if !loggedIn {
		http.Redirect(w, r, "/login", http.StatusSeeOther)
		return false
	}
	return true
}

// renderTemplate executes the named template into w. On failure it logs the
// error and replies 500. If the template already wrote part of the page, the
// error text is appended to that partial output.
func renderTemplate(w http.ResponseWriter, name string, data interface{}) {
	if err := templates.ExecuteTemplate(w, name, data); err != nil {
		log.Printf("模板执行错误: %v", err)
		http.Error(w, "模板渲染失败", http.StatusInternalServerError)
	}
}

// pageCount returns how many pages of size limit are needed for total rows.
// It is never less than 1, so the pager always has a page to show.
func pageCount(total int64, limit int) int {
	pages := (total + int64(limit) - 1) / int64(limit)
	if pages < 1 {
		return 1
	}
	return int(pages)
}

// Index renders the home page: a paginated, filterable SMS list plus summary
// statistics.
//
// Query parameters:
//   - page: 1-based page number; missing or invalid values mean 1.
//   - from: filter by sender number.
//   - search: free-text search passed to database.GetMessages.
func Index(w http.ResponseWriter, r *http.Request) {
	if !requireLogin(w, r) {
		return
	}

	query := r.URL.Query()
	page, _ := strconv.Atoi(query.Get("page")) // invalid/missing -> 0 -> clamped below
	if page < 1 {
		page = 1
	}
	limit := 20
	from := query.Get("from")
	search := query.Get("search")

	messages, total, err := database.GetMessages(page, limit, from, search)
	if err != nil {
		log.Printf("获取消息失败: %v", err)
		http.Error(w, "获取消息失败", http.StatusInternalServerError)
		return
	}
	log.Printf("查询结果: 总数=%d, 本页=%d 条", total, len(messages))

	// Statistics are best-effort: a failure is logged and the page is still
	// rendered. stats may be nil on error, so do not dereference it
	// unconditionally.
	stats, err := database.GetStatistics()
	if err != nil {
		log.Printf("获取统计失败: %v", err)
	} else if stats != nil {
		log.Printf("统计: 总数=%d, 今日=%d, 本周=%d", stats.Total, stats.Today, stats.Week)
	}

	// Distinct sender numbers for the filter drop-down (also best-effort).
	fromNumbers, err := getFromNumbers()
	if err != nil {
		log.Printf("获取发送方号码失败: %v", err)
	}

	totalPages := pageCount(total, limit)

	// Convert times to the configured display zone. The SMS timestamp (Unix
	// milliseconds) is shown; CreatedAt is kept as a secondary reference.
	loc := displayLocation()
	for i := range messages {
		messages[i].LocalTimestampStr = time.UnixMilli(messages[i].Timestamp).In(loc).Format(displayTimeLayout)
		messages[i].CreatedAt = messages[i].CreatedAt.In(loc)
	}

	data := map[string]interface{}{
		"messages":     messages,
		"stats":        stats,
		"total":        total,
		"totalPages":   totalPages,
		"page":         page,
		"limit":        limit,
		"fromNumbers":  fromNumbers,
		"selectedFrom": from,
		"search":       search,
	}

	log.Printf("传递给模板的数据: messages=%d, total=%d, totalPages=%d", len(messages), total, totalPages)
	if len(messages) > 0 {
		// Message content is deliberately not logged: SMS bodies can contain
		// verification codes and other secrets.
		log.Printf("第一条消息: ID=%d, From=%s", messages[0].ID, messages[0].FromNumber)
	}

	renderTemplate(w, "index.html", data)
}

// Login serves the login form (GET) and processes submitted credentials
// (any other method).
//
// When cfg.Security.Enabled is false, every submission creates a session for
// the submitted username and redirects to "/".
func Login(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodGet {
		renderTemplate(w, "login.html", map[string]string{
			"error": r.URL.Query().Get("error"),
		})
		return
	}

	username := r.FormValue("username")
	password := r.FormValue("password")
	cfg := config.Get()

	if !cfg.Security.Enabled {
		// Login verification disabled: create a session without checking
		// credentials. A session error is logged and the redirect still happens.
		if err := auth.Login(w, r, username); err != nil {
			log.Printf("创建会话失败: %v", err)
		}
		http.Redirect(w, r, "/", http.StatusSeeOther)
		return
	}

	if !credentialsMatch(username, password, cfg.Security.Username, cfg.Security.Password) {
		renderTemplate(w, "login.html", map[string]string{
			"error": "用户名或密码错误",
		})
		return
	}

	if err := auth.Login(w, r, username); err != nil {
		log.Printf("创建会话失败: %v", err)
		http.Error(w, "创建会话失败: "+err.Error(), http.StatusInternalServerError)
		return
	}
	http.Redirect(w, r, "/", http.StatusSeeOther)
}

// credentialsMatch compares the submitted credentials with the configured
// ones using constant-time comparison. Both comparisons are always evaluated,
// so response timing does not reveal which field was wrong (only length
// mismatches short-circuit inside ConstantTimeCompare).
func credentialsMatch(user, pass, wantUser, wantPass string) bool {
	userOK := subtle.ConstantTimeCompare([]byte(user), []byte(wantUser))
	passOK := subtle.ConstantTimeCompare([]byte(pass), []byte(wantPass))
	return userOK&passOK == 1
}

// Logout ends the current session and redirects to the login page.
func Logout(w http.ResponseWriter, r *http.Request) {
	auth.Logout(r, w)
	http.Redirect(w, r, "/login", http.StatusSeeOther)
}

// MessageDetail renders a single SMS identified by the {id} route variable.
// It replies 400 for a non-numeric id, 404 if the message does not exist, and
// 500 on a database error.
func MessageDetail(w http.ResponseWriter, r *http.Request) {
	if !requireLogin(w, r) {
		return
	}

	id, err := strconv.ParseInt(mux.Vars(r)["id"], 10, 64)
	if err != nil {
		http.Error(w, "无效的消息 ID", http.StatusBadRequest)
		return
	}

	msg, err := database.GetMessageByID(id)
	if err != nil {
		log.Printf("获取消息 %d 失败: %v", id, err)
		http.Error(w, "获取消息失败", http.StatusInternalServerError)
		return
	}
	if msg == nil {
		http.Error(w, "消息不存在", http.StatusNotFound)
		return
	}

	msg.TimestampStr = time.UnixMilli(msg.Timestamp).In(displayLocation()).Format(displayTimeLayout)
	// NOTE(review): html/template HTML-escapes "<br>" when Content is
	// interpolated normally, so users would see the literal tag. Confirm how
	// message_detail.html renders Content (CSS white-space: pre-wrap would be
	// a safe alternative to this replacement).
	msg.Content = strings.ReplaceAll(msg.Content, "\n", "<br>")

	renderTemplate(w, "message_detail.html", msg)
}

// Logs renders the paginated receive-log page (50 entries per page).
func Logs(w http.ResponseWriter, r *http.Request) {
	if !requireLogin(w, r) {
		return
	}

	page, _ := strconv.Atoi(r.URL.Query().Get("page")) // invalid/missing -> 0 -> clamped below
	if page < 1 {
		page = 1
	}
	limit := 50

	logs, total, err := database.GetLogs(page, limit)
	if err != nil {
		log.Printf("获取日志失败: %v", err)
		http.Error(w, "获取日志失败", http.StatusInternalServerError)
		return
	}

	renderTemplate(w, "logs.html", map[string]interface{}{
		"logs":       logs,
		"total":      total,
		"page":       page,
		"limit":      limit,
		"totalPages": pageCount(total, limit),
	})
}

// Statistics renders the statistics page.
func Statistics(w http.ResponseWriter, r *http.Request) {
	if !requireLogin(w, r) {
		return
	}

	stats, err := database.GetStatistics()
	if err != nil {
		log.Printf("获取统计失败: %v", err)
		http.Error(w, "获取统计失败", http.StatusInternalServerError)
		return
	}

	renderTemplate(w, "statistics.html", map[string]interface{}{
		"stats": stats,
	})
}

// ReceiveSMS is the ingestion API. It accepts multipart/form-data or
// urlencoded form fields:
//   - from, content: required.
//   - timestamp: Unix milliseconds; a missing or invalid value means "now".
//   - sign, device, sim: optional.
//   - token: read from the query string first, then from the form.
//
// When a token is present and cfg.Security.SignVerify is on, the signature is
// checked with sign.VerifySign. The message is stored regardless of the
// result, and the result is recorded in SignVerified. Every stored (or failed)
// message also writes a models.ReceiveLog entry.
func ReceiveSMS(w http.ResponseWriter, r *http.Request) {
	// Prefer multipart/form-data; fall back to plain form parsing.
	if err := r.ParseMultipartForm(32 << 20); err != nil {
		if err := r.ParseForm(); err != nil {
			writeJSON(w, models.APIResponse{
				Success: false,
				Error:   "解析请求失败: " + err.Error(),
			}, http.StatusBadRequest)
			return
		}
	}

	from := r.FormValue("from")
	content := r.FormValue("content")
	if from == "" || content == "" {
		writeJSON(w, models.APIResponse{
			Success: false,
			Error:   "缺少必填参数 (from: '" + from + "', content: '" + content + "')",
		}, http.StatusBadRequest)
		return
	}

	// Optional fields. ParseInt fails on "" as well, so a missing timestamp
	// also falls back to the current time.
	timestamp := time.Now().UnixMilli()
	if ts, err := strconv.ParseInt(r.FormValue("timestamp"), 10, 64); err == nil {
		timestamp = ts
	}
	signStr := r.FormValue("sign")
	device := r.FormValue("device")
	sim := r.FormValue("sim")

	token := r.URL.Query().Get("token")
	if token == "" {
		token = r.FormValue("token")
	}

	cfg := config.Get()
	// NOTE(review): requests with no token, or with SignVerify disabled, are
	// recorded as SignVerified=true although no check ran. Confirm this is
	// intended rather than Valid:false ("not verified").
	signValid := sql.NullBool{Bool: true, Valid: true}
	if token != "" && cfg.Security.SignVerify {
		valid, err := sign.VerifySign(token, timestamp, signStr, &cfg.Security)
		if err != nil {
			log.Printf("签名验证错误: %v", err)
			writeJSON(w, models.APIResponse{
				Success: false,
				Error:   "签名验证错误",
			}, http.StatusInternalServerError)
			return
		}
		signValid.Bool = valid
	}

	clientIP := getClientIP(r)

	messageID, insertErr := database.InsertMessage(&models.SMSMessage{
		FromNumber:   from,
		Content:      content,
		Timestamp:    timestamp,
		DeviceInfo:   sql.NullString{String: device, Valid: device != ""},
		SIMInfo:      sql.NullString{String: sim, Valid: sim != ""},
		SignVerified: signValid,
		IPAddress:    clientIP,
	})

	// One receive-log entry per request, marked success or error.
	entry := &models.ReceiveLog{
		FromNumber: from,
		Content:    content,
		Timestamp:  timestamp,
		Sign:       sql.NullString{String: signStr, Valid: signStr != ""},
		SignValid:  signValid,
		IPAddress:  clientIP,
		Status:     "success",
	}
	if insertErr != nil {
		log.Printf("保存消息失败: %v", insertErr)
		entry.Status = "error"
		entry.ErrorMessage = sql.NullString{String: insertErr.Error(), Valid: true}
	}
	// NOTE(review): any result returned by database.InsertLog is discarded here.
	database.InsertLog(entry)

	if insertErr != nil {
		writeJSON(w, models.APIResponse{
			Success: false,
			Error:   "保存消息失败",
		}, http.StatusInternalServerError)
		return
	}

	writeJSON(w, models.APIResponse{
		Success:   true,
		Message:   "短信已接收",
		MessageID: messageID,
	}, http.StatusOK)
}

// APIGetMessages returns a JSON page of messages.
//
// Query parameters: page (1-based, default 1), limit (default 20, capped at
// 100), from, search. It replies 401 when authentication is enabled and the
// caller has no session.
func APIGetMessages(w http.ResponseWriter, r *http.Request) {
	if !isAPIAuthenticated(r) {
		writeJSON(w, models.APIResponse{Success: false, Error: "未授权"}, http.StatusUnauthorized)
		return
	}

	query := r.URL.Query()
	page, _ := strconv.Atoi(query.Get("page"))
	if page < 1 {
		page = 1
	}
	limit, _ := strconv.Atoi(query.Get("limit"))
	switch {
	case limit <= 0:
		limit = 20
	case limit > 100:
		limit = 100
	}
	from := query.Get("from")
	search := query.Get("search")

	messages, total, err := database.GetMessages(page, limit, from, search)
	if err != nil {
		log.Printf("获取消息失败: %v", err)
		writeJSON(w, models.APIResponse{Success: false, Error: "获取消息失败"}, http.StatusInternalServerError)
		return
	}

	loc := displayLocation()
	for i := range messages {
		messages[i].LocalTimestampStr = time.UnixMilli(messages[i].Timestamp).In(loc).Format(displayTimeLayout)
	}

	writeJSON(w, models.MessageListResponse{
		Success: true,
		Data:    messages,
		Total:   total,
		Page:    page,
		Limit:   limit,
	}, http.StatusOK)
}

// APIStatistics returns the statistics as JSON. It replies 401 when
// authentication is enabled and the caller has no session.
func APIStatistics(w http.ResponseWriter, r *http.Request) {
	if !isAPIAuthenticated(r) {
		writeJSON(w, models.APIResponse{Success: false, Error: "未授权"}, http.StatusUnauthorized)
		return
	}

	stats, err := database.GetStatistics()
	if err != nil || stats == nil {
		// stats is dereferenced below, so a nil result is treated as a failure.
		log.Printf("获取统计失败: %v", err)
		writeJSON(w, models.APIResponse{Success: false, Error: "获取统计失败"}, http.StatusInternalServerError)
		return
	}

	writeJSON(w, models.StatisticsResponse{
		Success: true,
		Data:    *stats,
	}, http.StatusOK)
}

// StaticFile serves "static"+r.URL.Path from disk. http.ServeFile rejects
// request paths that contain a ".." element.
func StaticFile(w http.ResponseWriter, r *http.Request) {
	http.ServeFile(w, r, "static"+r.URL.Path)
}

// writeJSON writes data as a JSON response with the given status code. The
// status is already sent when encoding runs, so an encode error can only be
// logged.
func writeJSON(w http.ResponseWriter, data interface{}, status int) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	if err := json.NewEncoder(w).Encode(data); err != nil {
		log.Printf("写入 JSON 响应失败: %v", err)
	}
}

// getClientIP returns the first (trimmed) entry of X-Forwarded-For when it is
// non-empty, and otherwise r.RemoteAddr (which includes the port).
// X-Forwarded-For is client-controlled unless a trusted proxy overwrites it,
// so treat the result as informational only.
func getClientIP(r *http.Request) string {
	forwarded := r.Header.Get("X-Forwarded-For")
	if i := strings.IndexByte(forwarded, ','); i >= 0 {
		forwarded = forwarded[:i]
	}
	if ip := strings.TrimSpace(forwarded); ip != "" {
		return ip
	}
	return r.RemoteAddr
}

// isAPIAuthenticated reports whether an API request may proceed. It always
// returns true when security is disabled, and otherwise requires a logged-in
// session.
func isAPIAuthenticated(r *http.Request) bool {
	if !config.Get().Security.Enabled {
		return true
	}
	loggedIn, _ := auth.IsLoggedIn(r)
	return loggedIn
}

// getFromNumbers returns the distinct sender numbers in sms_messages, sorted
// ascending. Errors that occur during iteration are reported through
// rows.Err.
func getFromNumbers() ([]string, error) {
	rows, err := database.GetDB().Query("SELECT DISTINCT from_number FROM sms_messages ORDER BY from_number")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var numbers []string
	for rows.Next() {
		var number string
		if err := rows.Scan(&number); err != nil {
			return nil, err
		}
		numbers = append(numbers, number)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return numbers, nil
}