From 80440077cf2ef43ae1a518aa18b46ea31b383043 Mon Sep 17 00:00:00 2001 From: OpenClaw Agent Date: Sun, 15 Mar 2026 04:47:03 +0800 Subject: [PATCH] fix: harden auth, sse, search, and docs --- README.md | 20 +++++- main.go | 195 ++++++++++++++++++++++++++++++++++++++++----------- mcp_tools.go | 5 +- tao_core.go | 107 +++++++++++++++++++++++++--- 4 files changed, 275 insertions(+), 52 deletions(-) diff --git a/README.md b/README.md index ac79b9d..d559868 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ knowledge_ocean/ **特性:** - `POST /mcp/message` 只返回 `202 Accepted` - 所有 JSON‑RPC 响应通过 SSE `event: message` 返回 -- CORS 已放行(含 `OPTIONS`) +- CORS 可配置(默认不放行;设置 `TAO_CORS_ORIGINS=*` 才全放开) --- @@ -130,13 +130,27 @@ knowledge_ocean/ - `repo_path` (可选,默认 `/root/.openclaw/workspace/tao_mcp_go`) 行为: -- `git pull`(失败不阻断) +- `git pull`(需 `TAO_ALLOW_GIT_PULL=true` 才会执行,失败不阻断) +- `repo_path` 必须在 `TAO_ALLOWED_REPOS` 白名单内 - 扫描 `Inspirations` 中 `#Todo/#Fix` - 生成 `_Proposals/proposal_.md` - 当日日记记录摘要 --- +## 配置项(安全建议) + +- **TAO_AUTH_TOKEN**:必填;不设置则启动失败 +- **TAO_ALLOW_ANON**:是否允许匿名访问(默认 false) +- **TAO_CORS_ORIGINS**:允许的来源列表,逗号分隔;`*` 为全放开 +- **TAO_DEBUG**:是否输出请求体日志(默认 false) +- **TAO_SEARCH_ROOT**:检索根目录(必须在 MEMORY_ROOT 下) +- **TAO_SEARCH_MAX_FILES**:检索文件上限(默认 2000) +- **TAO_ALLOWED_REPOS**:inspect_and_propose 允许的仓库白名单(逗号分隔) +- **TAO_ALLOW_GIT_PULL**:是否允许 inspect_and_propose 执行 git pull(默认 false) + +--- + ## OpenClaw 接入(示例) **Base URL** @@ -146,7 +160,7 @@ https://mcp.good.xx.kg **Auth Token** ``` -a3c60a86ed2a7d317b8855faa94a05d1 +YOUR_TOKEN_HERE ``` **Instructions(粘贴)** diff --git a/main.go b/main.go index f960898..f45e035 100644 --- a/main.go +++ b/main.go @@ -8,6 +8,8 @@ import ( "log" "net/http" "os" + "path/filepath" + "strconv" "strings" "time" ) @@ -90,14 +92,11 @@ func (s *TaoServer) dispatchMCP(token string, client string, req MCPRequest) { return } - if token == "" { + if token == "" && !getEnvBool("TAO_ALLOW_ANON", false) { log.Printf("[MCP Response] missing token for 
// getEnvBool reads the environment variable key as a boolean flag.
//
// Surrounding whitespace and letter case are ignored. "1", "true", "yes" and
// "on" yield true; "0", "false", "no" and "off" yield false. An unset, empty
// or unrecognised value yields def.
func getEnvBool(key string, def bool) bool {
	switch strings.ToLower(strings.TrimSpace(os.Getenv(key))) {
	case "1", "true", "yes", "on":
		return true
	case "0", "false", "no", "off":
		return false
	default:
		return def
	}
}

// getEnvInt reads the environment variable key as a base-10 integer.
//
// Surrounding whitespace is ignored, matching getEnvBool. An unset, empty or
// unparsable value yields def.
func getEnvInt(key string, def int) int {
	raw := strings.TrimSpace(os.Getenv(key))
	if raw == "" {
		return def
	}
	n, err := strconv.Atoi(raw)
	if err != nil {
		return def
	}
	return n
}

// extractToken returns the credential presented by r and whether it came
// from the URL query string.
//
// A non-empty "token" query parameter takes precedence over the Authorization
// header. The header is accepted only with the Bearer scheme, which is matched
// case-insensitively because HTTP authentication scheme names are
// case-insensitive. When no credential is present it returns ("", false).
func extractToken(r *http.Request) (string, bool) {
	if q := r.URL.Query().Get("token"); q != "" {
		return q, true
	}
	const scheme = "bearer "
	h := r.Header.Get("Authorization")
	if len(h) >= len(scheme) && strings.EqualFold(h[:len(scheme)], scheme) {
		return strings.TrimSpace(h[len(scheme):]), false
	}
	return "", false
}

// buildConnKey derives the key under which a connection is registered for a
// (token, client) pair. An empty token is replaced by "anon"; a non-empty
// client is appended after an underscore.
func buildConnKey(token string, client string) string {
	key := token
	if key == "" {
		key = "anon"
	}
	if client == "" {
		return key
	}
	return key + "_" + client
}

// generateClientID returns an identifier of the form "c<unix-nanoseconds>".
//
// NOTE(review): the value is derived from the wall clock only, so two calls
// within one clock tick would produce the same ID — confirm that is acceptable
// for the callers.
func generateClientID() string {
	return fmt.Sprintf("c%d", time.Now().UnixNano())
}

// parseCORSOrigins parses TAO_CORS_ORIGINS, a comma-separated origin list.
//
// It returns (true, nil) when any entry is the wildcard "*", (false, nil)
// when the variable is unset or blank, and otherwise (false, origins) with
// each entry trimmed and empty entries dropped.
func parseCORSOrigins() (bool, []string) {
	raw := strings.TrimSpace(os.Getenv("TAO_CORS_ORIGINS"))
	if raw == "" {
		return false, nil
	}
	var origins []string
	for _, entry := range strings.Split(raw, ",") {
		entry = strings.TrimSpace(entry)
		if entry == "" {
			continue
		}
		// A wildcard anywhere in the list opens CORS for every origin; as a
		// literal list entry it could never equal a real Origin header.
		if entry == "*" {
			return true, nil
		}
		origins = append(origins, entry)
	}
	return false, origins
}

// setCORSHeaders writes CORS response headers for r according to
// TAO_CORS_ORIGINS.
//
// Nothing is written when the request has no Origin header or when the origin
// is not on the allow list. For an allow-listed origin the origin is echoed
// back and "Origin" is appended to Vary; in wildcard mode "*" is sent.
func setCORSHeaders(w http.ResponseWriter, r *http.Request) {
	origin := r.Header.Get("Origin")
	if origin == "" {
		return
	}

	hdr := w.Header()
	allowAll, origins := parseCORSOrigins()
	if allowAll {
		hdr.Set("Access-Control-Allow-Origin", "*")
	} else {
		allowed := false
		for _, o := range origins {
			if o == origin {
				allowed = true
				break
			}
		}
		if !allowed {
			return
		}
		hdr.Set("Access-Control-Allow-Origin", origin)
		// Add rather than Set so a Vary value written earlier is preserved.
		hdr.Add("Vary", "Origin")
	}
	hdr.Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
	hdr.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
}

// isSubpath reports whether path is base itself or lies beneath it.
//
// Both arguments are made absolute and cleaned first, so ".." segments cannot
// escape base. The relation is computed with filepath.Rel rather than a string
// prefix test, which also handles a base that is the filesystem root.
//
// NOTE(review): symbolic links are not resolved; a link inside base that
// points outside it is still reported as a subpath.
func isSubpath(path string, base string) bool {
	absPath, err := filepath.Abs(path)
	if err != nil {
		return false
	}
	absBase, err := filepath.Abs(base)
	if err != nil {
		return false
	}
	rel, err := filepath.Rel(absBase, absPath)
	if err != nil {
		return false
	}
	if rel == ".." {
		return false
	}
	return !strings.HasPrefix(rel, ".."+string(filepath.Separator))
}
!ok { @@ -212,14 +310,24 @@ func (s *TaoServer) SSEHandler(w http.ResponseWriter, r *http.Request) { if style == "message" { endpoint = "message" } - // 若通过 query token 访问,也把 token 拼到 endpoint(便于客户端无 Header) - token := r.URL.Query().Get("token") + + queryToken := r.URL.Query().Get("token") + token, _ := extractToken(r) + if token == "" && !getEnvBool("TAO_ALLOW_ANON", false) { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + client := r.URL.Query().Get("client") - if token != "" { + if client == "" { + client = generateClientID() + } + + if queryToken != "" { if strings.Contains(endpoint, "?") { - endpoint = endpoint + "&token=" + token + endpoint = endpoint + "&token=" + queryToken } else { - endpoint = endpoint + "?token=" + token + endpoint = endpoint + "?token=" + queryToken } } if client != "" { @@ -232,16 +340,10 @@ func (s *TaoServer) SSEHandler(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", endpoint) flusher.Flush() - var msgChan chan string - if token != "" { - msgChan = make(chan string, 50) - connKey := token - if client != "" { - connKey = token + "_" + client - } - s.conns.Store(connKey, msgChan) - defer s.conns.Delete(connKey) - } + msgChan := make(chan string, 50) + connKey := buildConnKey(token, client) + s.conns.Store(connKey, msgChan) + defer s.conns.Delete(connKey) ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() @@ -262,9 +364,7 @@ func (s *TaoServer) SSEHandler(w http.ResponseWriter, r *http.Request) { // --- MCP Message --- func (s *TaoServer) MessageHandler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + setCORSHeaders(w, r) if r.Method == "OPTIONS" { w.WriteHeader(http.StatusOK) @@ -273,7 +373,11 @@ func (s *TaoServer) MessageHandler(w http.ResponseWriter, r 
*http.Request) { bodyBytes, _ := io.ReadAll(r.Body) r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) - log.Printf("[MCP POST] From=%s URL=%s Body=%s", r.RemoteAddr, r.URL.String(), string(bodyBytes)) + if getEnvBool("TAO_DEBUG", false) { + log.Printf("[MCP POST] From=%s URL=%s Body=%s", r.RemoteAddr, r.URL.String(), string(bodyBytes)) + } else { + log.Printf("[MCP POST] From=%s URL=%s", r.RemoteAddr, r.URL.String()) + } var req MCPRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { @@ -281,7 +385,7 @@ func (s *TaoServer) MessageHandler(w http.ResponseWriter, r *http.Request) { return } - token := r.URL.Query().Get("token") + token, _ := extractToken(r) client := r.URL.Query().Get("client") w.WriteHeader(http.StatusAccepted) @@ -290,10 +394,23 @@ func (s *TaoServer) MessageHandler(w http.ResponseWriter, r *http.Request) { // --- 主程序 (Main) --- func main() { + if getEnv("TAO_AUTH_TOKEN", "") == "" && !getEnvBool("TAO_ALLOW_ANON", false) { + log.Fatal("TAO_AUTH_TOKEN is required unless TAO_ALLOW_ANON=true") + } + + memoryRoot := getEnv("MEMORY_ROOT", "./knowledge_ocean") + searchRoot := getEnv("TAO_SEARCH_ROOT", memoryRoot) + if !isSubpath(searchRoot, memoryRoot) { + log.Printf("TAO_SEARCH_ROOT must be under MEMORY_ROOT, fallback to MEMORY_ROOT") + searchRoot = memoryRoot + } + server := &TaoServer{ config: Config{ - MemoryRoot: getEnv("MEMORY_ROOT", "./knowledge_ocean"), - Port: getEnv("PORT", "5001"), + MemoryRoot: memoryRoot, + Port: getEnv("PORT", "5001"), + SearchRoot: searchRoot, + MaxSearchFiles: getEnvInt("TAO_SEARCH_MAX_FILES", 2000), }, } diff --git a/mcp_tools.go b/mcp_tools.go index a8de10a..79951aa 100644 --- a/mcp_tools.go +++ b/mcp_tools.go @@ -1,6 +1,7 @@ package main import ( + "sort" "strconv" ) @@ -446,7 +447,6 @@ func (s *TaoServer) RegisterTools() { } } - func buildToolList() []map[string]interface{} { var toolList []map[string]interface{} for _, t := range ToolRegistry { @@ -456,6 +456,9 @@ func buildToolList() 
// movePath moves the file or directory at src to dst.
//
// It tries os.Rename first. If the rename fails for any reason, it falls back
// to copying src to dst and then deleting src, so the source is gone after a
// successful call in both the file and the directory case. The rename error
// itself is deliberately discarded because the fallback supersedes it; errors
// from the fallback are returned.
//
// On a fallback failure dst may be left partially written and src is left
// untouched.
func movePath(src string, dst string) error {
	if err := os.Rename(src, dst); err == nil {
		return nil
	}

	info, err := os.Stat(src)
	if err != nil {
		return err
	}
	if info.IsDir() {
		err = copyDir(src, dst)
	} else {
		err = copyFile(src, dst)
	}
	if err != nil {
		return err
	}
	// Only remove the source once the copy is complete, so a failed copy never
	// loses data.
	return os.RemoveAll(src)
}

// copyDir recursively copies the directory tree rooted at src into dst.
//
// dst and any missing parents are created with mode 0755. Entries already
// present in dst are kept unless a source file of the same name overwrites
// them. It stops at the first error.
func copyDir(src string, dst string) error {
	if err := os.MkdirAll(dst, 0755); err != nil {
		return err
	}
	entries, err := os.ReadDir(src)
	if err != nil {
		return err
	}
	for _, entry := range entries {
		from := filepath.Join(src, entry.Name())
		to := filepath.Join(dst, entry.Name())
		if entry.IsDir() {
			err = copyDir(from, to)
		} else {
			err = copyFile(from, to)
		}
		if err != nil {
			return err
		}
	}
	return nil
}

// copyFile copies the contents of src to dst and syncs dst to stable storage.
//
// Missing parent directories of dst are created with mode 0755. dst is
// created, or truncated if it exists, using the permission bits of src. An
// error from closing dst is reported when no earlier error occurred, because
// a failed close can mean the data was not fully written.
func copyFile(src string, dst string) (err error) {
	in, err := os.Open(src)
	if err != nil {
		return err
	}
	defer in.Close()

	info, err := in.Stat()
	if err != nil {
		return err
	}
	if err = os.MkdirAll(filepath.Dir(dst), 0755); err != nil {
		return err
	}
	out, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, info.Mode().Perm())
	if err != nil {
		return err
	}
	defer func() {
		if cerr := out.Close(); err == nil {
			err = cerr
		}
	}()

	if _, err = io.Copy(out, in); err != nil {
		return err
	}
	return out.Sync()
}
func(path string, info os.FileInfo, err error) error { + scanned := 0 + err := filepath.Walk(s.config.SearchRoot, func(path string, info os.FileInfo, err error) error { if err != nil { return err } + if s.config.MaxSearchFiles > 0 && scanned >= s.config.MaxSearchFiles { + return filepath.SkipDir + } if !includeArchive && info.IsDir() && strings.Contains(path, string(filepath.Separator)+"_Archive"+string(filepath.Separator)) { return filepath.SkipDir } if !info.IsDir() && filepath.Ext(path) == ".md" { + scanned++ if !includeArchive && strings.Contains(path, string(filepath.Separator)+"_Archive"+string(filepath.Separator)) { return nil } @@ -607,7 +696,7 @@ func (s *TaoServer) SearchMemoryAdvanced(keyword string, related []string, causa text := string(content) for _, term := range terms { if term != "" && strings.Contains(text, term) { - rel, _ := filepath.Rel(s.config.MemoryRoot, path) + rel, _ := filepath.Rel(s.config.SearchRoot, path) label := "命中" isCausal := term != keyword if isCausal {