fix: harden auth, sse, search, and docs
This commit is contained in:
20
README.md
20
README.md
@@ -50,7 +50,7 @@ knowledge_ocean/
|
||||
**特性:**
|
||||
- `POST /mcp/message` 只返回 `202 Accepted`
|
||||
- 所有 JSON‑RPC 响应通过 SSE `event: message` 返回
|
||||
- CORS 已放行(含 `OPTIONS`)
|
||||
- CORS 可配置(默认不放行;设置 `TAO_CORS_ORIGINS=*` 才全放开)
|
||||
|
||||
---
|
||||
|
||||
@@ -130,13 +130,27 @@ knowledge_ocean/
|
||||
- `repo_path` (可选,默认 `/root/.openclaw/workspace/tao_mcp_go`)
|
||||
|
||||
行为:
|
||||
- `git pull`(失败不阻断)
|
||||
- `git pull`(需 `TAO_ALLOW_GIT_PULL=true` 才会执行,失败不阻断)
|
||||
- `repo_path` 必须在 `TAO_ALLOWED_REPOS` 白名单内
|
||||
- 扫描 `Inspirations` 中 `#Todo/#Fix`
|
||||
- 生成 `_Proposals/proposal_<timestamp>.md`
|
||||
- 当日日记记录摘要
|
||||
|
||||
---
|
||||
|
||||
## 配置项(安全建议)
|
||||
|
||||
- **TAO_AUTH_TOKEN**:必填;不设置则启动失败
|
||||
- **TAO_ALLOW_ANON**:是否允许匿名访问(默认 false)
|
||||
- **TAO_CORS_ORIGINS**:允许的来源列表,逗号分隔;`*` 为全放开
|
||||
- **TAO_DEBUG**:是否输出请求体日志(默认 false)
|
||||
- **TAO_SEARCH_ROOT**:检索根目录(必须在 MEMORY_ROOT 下)
|
||||
- **TAO_SEARCH_MAX_FILES**:检索文件上限(默认 2000)
|
||||
- **TAO_ALLOWED_REPOS**:inspect_and_propose 允许的仓库白名单(逗号分隔)
|
||||
- **TAO_ALLOW_GIT_PULL**:是否允许 inspect_and_propose 执行 git pull(默认 false)
|
||||
|
||||
---
|
||||
|
||||
## OpenClaw 接入(示例)
|
||||
|
||||
**Base URL**
|
||||
@@ -146,7 +160,7 @@ https://mcp.good.xx.kg
|
||||
|
||||
**Auth Token**
|
||||
```
|
||||
a3c60a86ed2a7d317b8855faa94a05d1
|
||||
YOUR_TOKEN_HERE
|
||||
```
|
||||
|
||||
**Instructions(粘贴)**
|
||||
|
||||
195
main.go
195
main.go
@@ -8,6 +8,8 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -90,14 +92,11 @@ func (s *TaoServer) dispatchMCP(token string, client string, req MCPRequest) {
|
||||
return
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
if token == "" && !getEnvBool("TAO_ALLOW_ANON", false) {
|
||||
log.Printf("[MCP Response] missing token for method=%s", req.Method)
|
||||
return
|
||||
}
|
||||
connKey := token
|
||||
if client != "" {
|
||||
connKey = token + "_" + client
|
||||
}
|
||||
connKey := buildConnKey(token, client)
|
||||
if ch, ok := s.conns.Load(connKey); ok {
|
||||
if b, err := json.Marshal(resp); err == nil {
|
||||
ch.(chan string) <- string(b)
|
||||
@@ -115,19 +114,120 @@ func getEnv(key, def string) string {
|
||||
return def
|
||||
}
|
||||
|
||||
func getEnvBool(key string, def bool) bool {
|
||||
v := strings.ToLower(strings.TrimSpace(os.Getenv(key)))
|
||||
if v == "" {
|
||||
return def
|
||||
}
|
||||
switch v {
|
||||
case "1", "true", "yes", "on":
|
||||
return true
|
||||
case "0", "false", "no", "off":
|
||||
return false
|
||||
default:
|
||||
return def
|
||||
}
|
||||
}
|
||||
|
||||
func getEnvInt(key string, def int) int {
|
||||
if v := os.Getenv(key); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return def
|
||||
}
|
||||
|
||||
func extractToken(r *http.Request) (string, bool) {
|
||||
if q := r.URL.Query().Get("token"); q != "" {
|
||||
return q, true
|
||||
}
|
||||
h := r.Header.Get("Authorization")
|
||||
if strings.HasPrefix(h, "Bearer ") {
|
||||
return strings.TrimSpace(strings.TrimPrefix(h, "Bearer ")), false
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func buildConnKey(token string, client string) string {
|
||||
if token == "" {
|
||||
token = "anon"
|
||||
}
|
||||
if client != "" {
|
||||
return token + "_" + client
|
||||
}
|
||||
return token
|
||||
}
|
||||
|
||||
func generateClientID() string {
|
||||
return fmt.Sprintf("c%d", time.Now().UnixNano())
|
||||
}
|
||||
|
||||
func parseCORSOrigins() (bool, []string) {
|
||||
raw := strings.TrimSpace(os.Getenv("TAO_CORS_ORIGINS"))
|
||||
if raw == "" {
|
||||
return false, nil
|
||||
}
|
||||
if raw == "*" {
|
||||
return true, nil
|
||||
}
|
||||
parts := strings.Split(raw, ",")
|
||||
var origins []string
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
origins = append(origins, p)
|
||||
}
|
||||
}
|
||||
return false, origins
|
||||
}
|
||||
|
||||
func setCORSHeaders(w http.ResponseWriter, r *http.Request) {
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
return
|
||||
}
|
||||
allowAll, origins := parseCORSOrigins()
|
||||
if allowAll {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
} else {
|
||||
allowed := false
|
||||
for _, o := range origins {
|
||||
if o == origin {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
return
|
||||
}
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
w.Header().Set("Vary", "Origin")
|
||||
}
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||
}
|
||||
|
||||
func isSubpath(path string, base string) bool {
|
||||
absPath, err1 := filepath.Abs(path)
|
||||
absBase, err2 := filepath.Abs(base)
|
||||
if err1 != nil || err2 != nil {
|
||||
return false
|
||||
}
|
||||
if absPath == absBase {
|
||||
return true
|
||||
}
|
||||
return strings.HasPrefix(absPath, absBase+string(filepath.Separator))
|
||||
}
|
||||
|
||||
// --- 以简御繁:鉴权 ---
|
||||
func (s *TaoServer) checkAuth(r *http.Request) bool {
|
||||
token := getEnv("TAO_AUTH_TOKEN", "")
|
||||
if token == "" {
|
||||
return true // 未配置则不启用鉴权
|
||||
return getEnvBool("TAO_ALLOW_ANON", false)
|
||||
}
|
||||
// Header Bearer
|
||||
h := r.Header.Get("Authorization")
|
||||
if h == "Bearer "+token {
|
||||
return true
|
||||
}
|
||||
// Query token
|
||||
if q := r.URL.Query().Get("token"); q != "" && q == token {
|
||||
reqToken, _ := extractToken(r)
|
||||
if reqToken == token {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
@@ -136,9 +236,7 @@ func (s *TaoServer) checkAuth(r *http.Request) bool {
|
||||
func (s *TaoServer) requireAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "OPTIONS" {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||
setCORSHeaders(w, r)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
@@ -197,8 +295,8 @@ func (s *TaoServer) SSEHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("X-Accel-Buffering", "no")
|
||||
setCORSHeaders(w, r)
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
@@ -212,14 +310,24 @@ func (s *TaoServer) SSEHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if style == "message" {
|
||||
endpoint = "message"
|
||||
}
|
||||
// 若通过 query token 访问,也把 token 拼到 endpoint(便于客户端无 Header)
|
||||
token := r.URL.Query().Get("token")
|
||||
|
||||
queryToken := r.URL.Query().Get("token")
|
||||
token, _ := extractToken(r)
|
||||
if token == "" && !getEnvBool("TAO_ALLOW_ANON", false) {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
client := r.URL.Query().Get("client")
|
||||
if token != "" {
|
||||
if client == "" {
|
||||
client = generateClientID()
|
||||
}
|
||||
|
||||
if queryToken != "" {
|
||||
if strings.Contains(endpoint, "?") {
|
||||
endpoint = endpoint + "&token=" + token
|
||||
endpoint = endpoint + "&token=" + queryToken
|
||||
} else {
|
||||
endpoint = endpoint + "?token=" + token
|
||||
endpoint = endpoint + "?token=" + queryToken
|
||||
}
|
||||
}
|
||||
if client != "" {
|
||||
@@ -232,16 +340,10 @@ func (s *TaoServer) SSEHandler(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", endpoint)
|
||||
flusher.Flush()
|
||||
|
||||
var msgChan chan string
|
||||
if token != "" {
|
||||
msgChan = make(chan string, 50)
|
||||
connKey := token
|
||||
if client != "" {
|
||||
connKey = token + "_" + client
|
||||
}
|
||||
s.conns.Store(connKey, msgChan)
|
||||
defer s.conns.Delete(connKey)
|
||||
}
|
||||
msgChan := make(chan string, 50)
|
||||
connKey := buildConnKey(token, client)
|
||||
s.conns.Store(connKey, msgChan)
|
||||
defer s.conns.Delete(connKey)
|
||||
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
@@ -262,9 +364,7 @@ func (s *TaoServer) SSEHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// --- MCP Message ---
|
||||
func (s *TaoServer) MessageHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||
setCORSHeaders(w, r)
|
||||
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@@ -273,7 +373,11 @@ func (s *TaoServer) MessageHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
log.Printf("[MCP POST] From=%s URL=%s Body=%s", r.RemoteAddr, r.URL.String(), string(bodyBytes))
|
||||
if getEnvBool("TAO_DEBUG", false) {
|
||||
log.Printf("[MCP POST] From=%s URL=%s Body=%s", r.RemoteAddr, r.URL.String(), string(bodyBytes))
|
||||
} else {
|
||||
log.Printf("[MCP POST] From=%s URL=%s", r.RemoteAddr, r.URL.String())
|
||||
}
|
||||
|
||||
var req MCPRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
@@ -281,7 +385,7 @@ func (s *TaoServer) MessageHandler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
token := r.URL.Query().Get("token")
|
||||
token, _ := extractToken(r)
|
||||
client := r.URL.Query().Get("client")
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
|
||||
@@ -290,10 +394,23 @@ func (s *TaoServer) MessageHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// --- 主程序 (Main) ---
|
||||
func main() {
|
||||
if getEnv("TAO_AUTH_TOKEN", "") == "" && !getEnvBool("TAO_ALLOW_ANON", false) {
|
||||
log.Fatal("TAO_AUTH_TOKEN is required unless TAO_ALLOW_ANON=true")
|
||||
}
|
||||
|
||||
memoryRoot := getEnv("MEMORY_ROOT", "./knowledge_ocean")
|
||||
searchRoot := getEnv("TAO_SEARCH_ROOT", memoryRoot)
|
||||
if !isSubpath(searchRoot, memoryRoot) {
|
||||
log.Printf("TAO_SEARCH_ROOT must be under MEMORY_ROOT, fallback to MEMORY_ROOT")
|
||||
searchRoot = memoryRoot
|
||||
}
|
||||
|
||||
server := &TaoServer{
|
||||
config: Config{
|
||||
MemoryRoot: getEnv("MEMORY_ROOT", "./knowledge_ocean"),
|
||||
Port: getEnv("PORT", "5001"),
|
||||
MemoryRoot: memoryRoot,
|
||||
Port: getEnv("PORT", "5001"),
|
||||
SearchRoot: searchRoot,
|
||||
MaxSearchFiles: getEnvInt("TAO_SEARCH_MAX_FILES", 2000),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
@@ -446,7 +447,6 @@ func (s *TaoServer) RegisterTools() {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func buildToolList() []map[string]interface{} {
|
||||
var toolList []map[string]interface{}
|
||||
for _, t := range ToolRegistry {
|
||||
@@ -456,6 +456,9 @@ func buildToolList() []map[string]interface{} {
|
||||
"inputSchema": t.InputSchema,
|
||||
})
|
||||
}
|
||||
sort.Slice(toolList, func(i, j int) bool {
|
||||
return toolList[i]["name"].(string) < toolList[j]["name"].(string)
|
||||
})
|
||||
return toolList
|
||||
}
|
||||
|
||||
|
||||
107
tao_core.go
107
tao_core.go
@@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
@@ -13,8 +14,10 @@ import (
|
||||
|
||||
// --- 道 (Config & State) ---
|
||||
type Config struct {
|
||||
MemoryRoot string
|
||||
Port string
|
||||
MemoryRoot string
|
||||
Port string
|
||||
SearchRoot string
|
||||
MaxSearchFiles int
|
||||
}
|
||||
|
||||
type TaoServer struct {
|
||||
@@ -415,14 +418,14 @@ func (s *TaoServer) HousekeepMemory(targetMonth string) (string, error) {
|
||||
if entry.IsDir() && strings.HasPrefix(entry.Name(), "W") {
|
||||
src := filepath.Join(monthDir, entry.Name())
|
||||
dst := filepath.Join(archiveRoot, entry.Name())
|
||||
if err := os.Rename(src, dst); err != nil {
|
||||
if err := movePath(src, dst); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
if !entry.IsDir() && entry.Name() != "Month_Summary.md" {
|
||||
src := filepath.Join(monthDir, entry.Name())
|
||||
dst := filepath.Join(archiveRoot, entry.Name())
|
||||
if err := os.Rename(src, dst); err != nil {
|
||||
if err := movePath(src, dst); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
@@ -434,6 +437,72 @@ func (s *TaoServer) HousekeepMemory(targetMonth string) (string, error) {
|
||||
return fmt.Sprintf("归档完成: %s", archiveRoot), nil
|
||||
}
|
||||
|
||||
func movePath(src string, dst string) error {
|
||||
if err := os.Rename(src, dst); err == nil {
|
||||
return nil
|
||||
}
|
||||
info, err := os.Stat(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if info.IsDir() {
|
||||
return copyDir(src, dst)
|
||||
}
|
||||
if err := copyFile(src, dst); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.RemoveAll(src)
|
||||
}
|
||||
|
||||
func copyDir(src string, dst string) error {
|
||||
if err := os.MkdirAll(dst, 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
entries, err := os.ReadDir(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, entry := range entries {
|
||||
sPath := filepath.Join(src, entry.Name())
|
||||
dPath := filepath.Join(dst, entry.Name())
|
||||
if entry.IsDir() {
|
||||
if err := copyDir(sPath, dPath); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := copyFile(sPath, dPath); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func copyFile(src string, dst string) error {
|
||||
in, err := os.Open(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer in.Close()
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
out, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
if _, err := io.Copy(out, in); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := out.Sync(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func mustAtoi(s string) int {
|
||||
n, _ := strconv.Atoi(s)
|
||||
return n
|
||||
@@ -447,8 +516,23 @@ func (s *TaoServer) InspectAndPropose(repoPath string) (string, error) {
|
||||
repoPath = "/root/.openclaw/workspace/tao_mcp_go"
|
||||
}
|
||||
|
||||
// 1) 拉取最新代码(若失败则继续)
|
||||
_ = exec.Command("git", "-C", repoPath, "pull").Run()
|
||||
allowed := getEnv("TAO_ALLOWED_REPOS", repoPath)
|
||||
allowList := strings.Split(allowed, ",")
|
||||
permitted := false
|
||||
for _, item := range allowList {
|
||||
item = strings.TrimSpace(item)
|
||||
if item != "" && repoPath == item {
|
||||
permitted = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !permitted {
|
||||
return "repo_path not allowed", fmt.Errorf("repo_path not allowed: %s", repoPath)
|
||||
}
|
||||
|
||||
if getEnvBool("TAO_ALLOW_GIT_PULL", false) {
|
||||
_ = exec.Command("git", "-C", repoPath, "pull").Run()
|
||||
}
|
||||
|
||||
// 2) 收集灵感(包含 #Todo/#Fix)
|
||||
inspDir := filepath.Join(s.config.MemoryRoot, "Inspirations")
|
||||
@@ -568,7 +652,7 @@ func (s *TaoServer) RecordSummary(content string, weekOffset int) (string, error
|
||||
return summaryPath, nil
|
||||
}
|
||||
|
||||
// SearchMemory 遍历所有 Markdown 文件,寻找包含关键词的内容
|
||||
// SearchMemoryAdvanced 遍历所有 Markdown 文件,寻找包含关键词的内容
|
||||
func (s *TaoServer) SearchMemoryAdvanced(keyword string, related []string, causal bool, includeArchive bool) ([]string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -589,14 +673,19 @@ func (s *TaoServer) SearchMemoryAdvanced(keyword string, related []string, causa
|
||||
}
|
||||
}
|
||||
|
||||
err := filepath.Walk(s.config.MemoryRoot, func(path string, info os.FileInfo, err error) error {
|
||||
scanned := 0
|
||||
err := filepath.Walk(s.config.SearchRoot, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if s.config.MaxSearchFiles > 0 && scanned >= s.config.MaxSearchFiles {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
if !includeArchive && info.IsDir() && strings.Contains(path, string(filepath.Separator)+"_Archive"+string(filepath.Separator)) {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
if !info.IsDir() && filepath.Ext(path) == ".md" {
|
||||
scanned++
|
||||
if !includeArchive && strings.Contains(path, string(filepath.Separator)+"_Archive"+string(filepath.Separator)) {
|
||||
return nil
|
||||
}
|
||||
@@ -607,7 +696,7 @@ func (s *TaoServer) SearchMemoryAdvanced(keyword string, related []string, causa
|
||||
text := string(content)
|
||||
for _, term := range terms {
|
||||
if term != "" && strings.Contains(text, term) {
|
||||
rel, _ := filepath.Rel(s.config.MemoryRoot, path)
|
||||
rel, _ := filepath.Rel(s.config.SearchRoot, path)
|
||||
label := "命中"
|
||||
isCausal := term != keyword
|
||||
if isCausal {
|
||||
|
||||
Reference in New Issue
Block a user