fix: harden auth, sse, search, and docs

This commit is contained in:
OpenClaw Agent
2026-03-15 04:47:03 +08:00
parent f4574e6190
commit 80440077cf
4 changed files with 275 additions and 52 deletions

View File

@@ -50,7 +50,7 @@ knowledge_ocean/
**特性:** **特性:**
- `POST /mcp/message` 只返回 `202 Accepted` - `POST /mcp/message` 只返回 `202 Accepted`
- 所有 JSONRPC 响应通过 SSE `event: message` 返回 - 所有 JSONRPC 响应通过 SSE `event: message` 返回
- CORS 已放行(含 `OPTIONS` - CORS 可配置(默认不放行;设置 `TAO_CORS_ORIGINS=*` 才全放开
--- ---
@@ -130,13 +130,27 @@ knowledge_ocean/
- `repo_path` (可选,默认 `/root/.openclaw/workspace/tao_mcp_go`) - `repo_path` (可选,默认 `/root/.openclaw/workspace/tao_mcp_go`)
行为: 行为:
- `git pull`(失败不阻断) - `git pull``TAO_ALLOW_GIT_PULL=true` 才会执行,失败不阻断)
- `repo_path` 必须在 `TAO_ALLOWED_REPOS` 白名单内
- 扫描 `Inspirations``#Todo/#Fix` - 扫描 `Inspirations``#Todo/#Fix`
- 生成 `_Proposals/proposal_<timestamp>.md` - 生成 `_Proposals/proposal_<timestamp>.md`
- 当日日记记录摘要 - 当日日记记录摘要
--- ---
## 配置项(安全建议)
- **TAO_AUTH_TOKEN**:必填;不设置则启动失败
- **TAO_ALLOW_ANON**:是否允许匿名访问(默认 false
- **TAO_CORS_ORIGINS**:允许的来源列表,逗号分隔;`*` 为全放开
- **TAO_DEBUG**:是否输出请求体日志(默认 false
- **TAO_SEARCH_ROOT**:检索根目录(必须在 MEMORY_ROOT 下)
- **TAO_SEARCH_MAX_FILES**:检索文件上限(默认 2000
- **TAO_ALLOWED_REPOS**inspect_and_propose 允许的仓库白名单(逗号分隔)
- **TAO_ALLOW_GIT_PULL**:是否允许 inspect_and_propose 执行 git pull默认 false
---
## OpenClaw 接入(示例) ## OpenClaw 接入(示例)
**Base URL** **Base URL**
@@ -146,7 +160,7 @@ https://mcp.good.xx.kg
**Auth Token** **Auth Token**
``` ```
a3c60a86ed2a7d317b8855faa94a05d1 YOUR_TOKEN_HERE
``` ```
**Instructions粘贴** **Instructions粘贴**

187
main.go
View File

@@ -8,6 +8,8 @@ import (
"log" "log"
"net/http" "net/http"
"os" "os"
"path/filepath"
"strconv"
"strings" "strings"
"time" "time"
) )
@@ -90,14 +92,11 @@ func (s *TaoServer) dispatchMCP(token string, client string, req MCPRequest) {
return return
} }
if token == "" { if token == "" && !getEnvBool("TAO_ALLOW_ANON", false) {
log.Printf("[MCP Response] missing token for method=%s", req.Method) log.Printf("[MCP Response] missing token for method=%s", req.Method)
return return
} }
connKey := token connKey := buildConnKey(token, client)
if client != "" {
connKey = token + "_" + client
}
if ch, ok := s.conns.Load(connKey); ok { if ch, ok := s.conns.Load(connKey); ok {
if b, err := json.Marshal(resp); err == nil { if b, err := json.Marshal(resp); err == nil {
ch.(chan string) <- string(b) ch.(chan string) <- string(b)
@@ -115,19 +114,120 @@ func getEnv(key, def string) string {
return def return def
} }
func getEnvBool(key string, def bool) bool {
v := strings.ToLower(strings.TrimSpace(os.Getenv(key)))
if v == "" {
return def
}
switch v {
case "1", "true", "yes", "on":
return true
case "0", "false", "no", "off":
return false
default:
return def
}
}
func getEnvInt(key string, def int) int {
if v := os.Getenv(key); v != "" {
if n, err := strconv.Atoi(v); err == nil {
return n
}
}
return def
}
func extractToken(r *http.Request) (string, bool) {
if q := r.URL.Query().Get("token"); q != "" {
return q, true
}
h := r.Header.Get("Authorization")
if strings.HasPrefix(h, "Bearer ") {
return strings.TrimSpace(strings.TrimPrefix(h, "Bearer ")), false
}
return "", false
}
func buildConnKey(token string, client string) string {
if token == "" {
token = "anon"
}
if client != "" {
return token + "_" + client
}
return token
}
func generateClientID() string {
return fmt.Sprintf("c%d", time.Now().UnixNano())
}
func parseCORSOrigins() (bool, []string) {
raw := strings.TrimSpace(os.Getenv("TAO_CORS_ORIGINS"))
if raw == "" {
return false, nil
}
if raw == "*" {
return true, nil
}
parts := strings.Split(raw, ",")
var origins []string
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
origins = append(origins, p)
}
}
return false, origins
}
func setCORSHeaders(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
if origin == "" {
return
}
allowAll, origins := parseCORSOrigins()
if allowAll {
w.Header().Set("Access-Control-Allow-Origin", "*")
} else {
allowed := false
for _, o := range origins {
if o == origin {
allowed = true
break
}
}
if !allowed {
return
}
w.Header().Set("Access-Control-Allow-Origin", origin)
w.Header().Set("Vary", "Origin")
}
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
}
func isSubpath(path string, base string) bool {
absPath, err1 := filepath.Abs(path)
absBase, err2 := filepath.Abs(base)
if err1 != nil || err2 != nil {
return false
}
if absPath == absBase {
return true
}
return strings.HasPrefix(absPath, absBase+string(filepath.Separator))
}
// --- 以简御繁:鉴权 --- // --- 以简御繁:鉴权 ---
func (s *TaoServer) checkAuth(r *http.Request) bool { func (s *TaoServer) checkAuth(r *http.Request) bool {
token := getEnv("TAO_AUTH_TOKEN", "") token := getEnv("TAO_AUTH_TOKEN", "")
if token == "" { if token == "" {
return true // 未配置则不启用鉴权 return getEnvBool("TAO_ALLOW_ANON", false)
} }
// Header Bearer reqToken, _ := extractToken(r)
h := r.Header.Get("Authorization") if reqToken == token {
if h == "Bearer "+token {
return true
}
// Query token
if q := r.URL.Query().Get("token"); q != "" && q == token {
return true return true
} }
return false return false
@@ -136,9 +236,7 @@ func (s *TaoServer) checkAuth(r *http.Request) bool {
func (s *TaoServer) requireAuth(next http.HandlerFunc) http.HandlerFunc { func (s *TaoServer) requireAuth(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if r.Method == "OPTIONS" { if r.Method == "OPTIONS" {
w.Header().Set("Access-Control-Allow-Origin", "*") setCORSHeaders(w, r)
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
return return
} }
@@ -197,8 +295,8 @@ func (s *TaoServer) SSEHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive") w.Header().Set("Connection", "keep-alive")
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("X-Accel-Buffering", "no") w.Header().Set("X-Accel-Buffering", "no")
setCORSHeaders(w, r)
flusher, ok := w.(http.Flusher) flusher, ok := w.(http.Flusher)
if !ok { if !ok {
@@ -212,14 +310,24 @@ func (s *TaoServer) SSEHandler(w http.ResponseWriter, r *http.Request) {
if style == "message" { if style == "message" {
endpoint = "message" endpoint = "message"
} }
// 若通过 query token 访问,也把 token 拼到 endpoint便于客户端无 Header
token := r.URL.Query().Get("token") queryToken := r.URL.Query().Get("token")
token, _ := extractToken(r)
if token == "" && !getEnvBool("TAO_ALLOW_ANON", false) {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
client := r.URL.Query().Get("client") client := r.URL.Query().Get("client")
if token != "" { if client == "" {
client = generateClientID()
}
if queryToken != "" {
if strings.Contains(endpoint, "?") { if strings.Contains(endpoint, "?") {
endpoint = endpoint + "&token=" + token endpoint = endpoint + "&token=" + queryToken
} else { } else {
endpoint = endpoint + "?token=" + token endpoint = endpoint + "?token=" + queryToken
} }
} }
if client != "" { if client != "" {
@@ -232,16 +340,10 @@ func (s *TaoServer) SSEHandler(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", endpoint) fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", endpoint)
flusher.Flush() flusher.Flush()
var msgChan chan string msgChan := make(chan string, 50)
if token != "" { connKey := buildConnKey(token, client)
msgChan = make(chan string, 50)
connKey := token
if client != "" {
connKey = token + "_" + client
}
s.conns.Store(connKey, msgChan) s.conns.Store(connKey, msgChan)
defer s.conns.Delete(connKey) defer s.conns.Delete(connKey)
}
ticker := time.NewTicker(5 * time.Second) ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop() defer ticker.Stop()
@@ -262,9 +364,7 @@ func (s *TaoServer) SSEHandler(w http.ResponseWriter, r *http.Request) {
// --- MCP Message --- // --- MCP Message ---
func (s *TaoServer) MessageHandler(w http.ResponseWriter, r *http.Request) { func (s *TaoServer) MessageHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*") setCORSHeaders(w, r)
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
if r.Method == "OPTIONS" { if r.Method == "OPTIONS" {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@@ -273,7 +373,11 @@ func (s *TaoServer) MessageHandler(w http.ResponseWriter, r *http.Request) {
bodyBytes, _ := io.ReadAll(r.Body) bodyBytes, _ := io.ReadAll(r.Body)
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
if getEnvBool("TAO_DEBUG", false) {
log.Printf("[MCP POST] From=%s URL=%s Body=%s", r.RemoteAddr, r.URL.String(), string(bodyBytes)) log.Printf("[MCP POST] From=%s URL=%s Body=%s", r.RemoteAddr, r.URL.String(), string(bodyBytes))
} else {
log.Printf("[MCP POST] From=%s URL=%s", r.RemoteAddr, r.URL.String())
}
var req MCPRequest var req MCPRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
@@ -281,7 +385,7 @@ func (s *TaoServer) MessageHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
token := r.URL.Query().Get("token") token, _ := extractToken(r)
client := r.URL.Query().Get("client") client := r.URL.Query().Get("client")
w.WriteHeader(http.StatusAccepted) w.WriteHeader(http.StatusAccepted)
@@ -290,10 +394,23 @@ func (s *TaoServer) MessageHandler(w http.ResponseWriter, r *http.Request) {
// --- 主程序 (Main) --- // --- 主程序 (Main) ---
func main() { func main() {
if getEnv("TAO_AUTH_TOKEN", "") == "" && !getEnvBool("TAO_ALLOW_ANON", false) {
log.Fatal("TAO_AUTH_TOKEN is required unless TAO_ALLOW_ANON=true")
}
memoryRoot := getEnv("MEMORY_ROOT", "./knowledge_ocean")
searchRoot := getEnv("TAO_SEARCH_ROOT", memoryRoot)
if !isSubpath(searchRoot, memoryRoot) {
log.Printf("TAO_SEARCH_ROOT must be under MEMORY_ROOT, fallback to MEMORY_ROOT")
searchRoot = memoryRoot
}
server := &TaoServer{ server := &TaoServer{
config: Config{ config: Config{
MemoryRoot: getEnv("MEMORY_ROOT", "./knowledge_ocean"), MemoryRoot: memoryRoot,
Port: getEnv("PORT", "5001"), Port: getEnv("PORT", "5001"),
SearchRoot: searchRoot,
MaxSearchFiles: getEnvInt("TAO_SEARCH_MAX_FILES", 2000),
}, },
} }

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"sort"
"strconv" "strconv"
) )
@@ -446,7 +447,6 @@ func (s *TaoServer) RegisterTools() {
} }
} }
func buildToolList() []map[string]interface{} { func buildToolList() []map[string]interface{} {
var toolList []map[string]interface{} var toolList []map[string]interface{}
for _, t := range ToolRegistry { for _, t := range ToolRegistry {
@@ -456,6 +456,9 @@ func buildToolList() []map[string]interface{} {
"inputSchema": t.InputSchema, "inputSchema": t.InputSchema,
}) })
} }
sort.Slice(toolList, func(i, j int) bool {
return toolList[i]["name"].(string) < toolList[j]["name"].(string)
})
return toolList return toolList
} }

View File

@@ -2,6 +2,7 @@ package main
import ( import (
"fmt" "fmt"
"io"
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
@@ -15,6 +16,8 @@ import (
type Config struct { type Config struct {
MemoryRoot string MemoryRoot string
Port string Port string
SearchRoot string
MaxSearchFiles int
} }
type TaoServer struct { type TaoServer struct {
@@ -415,14 +418,14 @@ func (s *TaoServer) HousekeepMemory(targetMonth string) (string, error) {
if entry.IsDir() && strings.HasPrefix(entry.Name(), "W") { if entry.IsDir() && strings.HasPrefix(entry.Name(), "W") {
src := filepath.Join(monthDir, entry.Name()) src := filepath.Join(monthDir, entry.Name())
dst := filepath.Join(archiveRoot, entry.Name()) dst := filepath.Join(archiveRoot, entry.Name())
if err := os.Rename(src, dst); err != nil { if err := movePath(src, dst); err != nil {
return "", err return "", err
} }
} }
if !entry.IsDir() && entry.Name() != "Month_Summary.md" { if !entry.IsDir() && entry.Name() != "Month_Summary.md" {
src := filepath.Join(monthDir, entry.Name()) src := filepath.Join(monthDir, entry.Name())
dst := filepath.Join(archiveRoot, entry.Name()) dst := filepath.Join(archiveRoot, entry.Name())
if err := os.Rename(src, dst); err != nil { if err := movePath(src, dst); err != nil {
return "", err return "", err
} }
} }
@@ -434,6 +437,72 @@ func (s *TaoServer) HousekeepMemory(targetMonth string) (string, error) {
return fmt.Sprintf("归档完成: %s", archiveRoot), nil return fmt.Sprintf("归档完成: %s", archiveRoot), nil
} }
func movePath(src string, dst string) error {
if err := os.Rename(src, dst); err == nil {
return nil
}
info, err := os.Stat(src)
if err != nil {
return err
}
if info.IsDir() {
return copyDir(src, dst)
}
if err := copyFile(src, dst); err != nil {
return err
}
return os.RemoveAll(src)
}
func copyDir(src string, dst string) error {
if err := os.MkdirAll(dst, 0755); err != nil {
return err
}
entries, err := os.ReadDir(src)
if err != nil {
return err
}
for _, entry := range entries {
sPath := filepath.Join(src, entry.Name())
dPath := filepath.Join(dst, entry.Name())
if entry.IsDir() {
if err := copyDir(sPath, dPath); err != nil {
return err
}
} else {
if err := copyFile(sPath, dPath); err != nil {
return err
}
}
}
return nil
}
func copyFile(src string, dst string) error {
in, err := os.Open(src)
if err != nil {
return err
}
defer in.Close()
if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil {
return err
}
out, err := os.Create(dst)
if err != nil {
return err
}
defer out.Close()
if _, err := io.Copy(out, in); err != nil {
return err
}
if err := out.Sync(); err != nil {
return err
}
return nil
}
func mustAtoi(s string) int { func mustAtoi(s string) int {
n, _ := strconv.Atoi(s) n, _ := strconv.Atoi(s)
return n return n
@@ -447,8 +516,23 @@ func (s *TaoServer) InspectAndPropose(repoPath string) (string, error) {
repoPath = "/root/.openclaw/workspace/tao_mcp_go" repoPath = "/root/.openclaw/workspace/tao_mcp_go"
} }
// 1) 拉取最新代码(若失败则继续) allowed := getEnv("TAO_ALLOWED_REPOS", repoPath)
allowList := strings.Split(allowed, ",")
permitted := false
for _, item := range allowList {
item = strings.TrimSpace(item)
if item != "" && repoPath == item {
permitted = true
break
}
}
if !permitted {
return "repo_path not allowed", fmt.Errorf("repo_path not allowed: %s", repoPath)
}
if getEnvBool("TAO_ALLOW_GIT_PULL", false) {
_ = exec.Command("git", "-C", repoPath, "pull").Run() _ = exec.Command("git", "-C", repoPath, "pull").Run()
}
// 2) 收集灵感(包含 #Todo/#Fix // 2) 收集灵感(包含 #Todo/#Fix
inspDir := filepath.Join(s.config.MemoryRoot, "Inspirations") inspDir := filepath.Join(s.config.MemoryRoot, "Inspirations")
@@ -568,7 +652,7 @@ func (s *TaoServer) RecordSummary(content string, weekOffset int) (string, error
return summaryPath, nil return summaryPath, nil
} }
// SearchMemory 遍历所有 Markdown 文件,寻找包含关键词的内容 // SearchMemoryAdvanced 遍历所有 Markdown 文件,寻找包含关键词的内容
func (s *TaoServer) SearchMemoryAdvanced(keyword string, related []string, causal bool, includeArchive bool) ([]string, error) { func (s *TaoServer) SearchMemoryAdvanced(keyword string, related []string, causal bool, includeArchive bool) ([]string, error) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@@ -589,14 +673,19 @@ func (s *TaoServer) SearchMemoryAdvanced(keyword string, related []string, causa
} }
} }
err := filepath.Walk(s.config.MemoryRoot, func(path string, info os.FileInfo, err error) error { scanned := 0
err := filepath.Walk(s.config.SearchRoot, func(path string, info os.FileInfo, err error) error {
if err != nil { if err != nil {
return err return err
} }
if s.config.MaxSearchFiles > 0 && scanned >= s.config.MaxSearchFiles {
return filepath.SkipDir
}
if !includeArchive && info.IsDir() && strings.Contains(path, string(filepath.Separator)+"_Archive"+string(filepath.Separator)) { if !includeArchive && info.IsDir() && strings.Contains(path, string(filepath.Separator)+"_Archive"+string(filepath.Separator)) {
return filepath.SkipDir return filepath.SkipDir
} }
if !info.IsDir() && filepath.Ext(path) == ".md" { if !info.IsDir() && filepath.Ext(path) == ".md" {
scanned++
if !includeArchive && strings.Contains(path, string(filepath.Separator)+"_Archive"+string(filepath.Separator)) { if !includeArchive && strings.Contains(path, string(filepath.Separator)+"_Archive"+string(filepath.Separator)) {
return nil return nil
} }
@@ -607,7 +696,7 @@ func (s *TaoServer) SearchMemoryAdvanced(keyword string, related []string, causa
text := string(content) text := string(content)
for _, term := range terms { for _, term := range terms {
if term != "" && strings.Contains(text, term) { if term != "" && strings.Contains(text, term) {
rel, _ := filepath.Rel(s.config.MemoryRoot, path) rel, _ := filepath.Rel(s.config.SearchRoot, path)
label := "命中" label := "命中"
isCausal := term != keyword isCausal := term != keyword
if isCausal { if isCausal {