From f4574e619066f82221fdbbe93b1798b8696b2fa5 Mon Sep 17 00:00:00 2001 From: OpenClaw Agent Date: Sun, 15 Mar 2026 01:19:36 +0800 Subject: [PATCH] feat: support client_id for multi-connection --- README.md | 6 ++++-- main.go | 29 +++++++++++++++++++++++------ 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 41c06e2..ac79b9d 100644 --- a/README.md +++ b/README.md @@ -42,8 +42,10 @@ knowledge_ocean/ ## MCP 端点 -- **SSE**:`/mcp/sse?token=...` -- **消息**:`/mcp/message?token=...` +- **SSE**:`/mcp/sse?token=...&client=...` +- **消息**:`/mcp/message?token=...&client=...` + +> `client` 用于区分多端连接,避免同一 token 抢占 SSE 通道。 **特性:** - `POST /mcp/message` 只返回 `202 Accepted` diff --git a/main.go b/main.go index 4bbf4bd..f960898 100644 --- a/main.go +++ b/main.go @@ -36,7 +36,7 @@ func sendMCPResponse(id any, result any) MCPResponse { return MCPResponse{JSONRPC: "2.0", ID: id, Result: result} } -func (s *TaoServer) dispatchMCP(token string, req MCPRequest) { +func (s *TaoServer) dispatchMCP(token string, client string, req MCPRequest) { var resp MCPResponse resp.JSONRPC = "2.0" resp.ID = req.ID @@ -94,13 +94,17 @@ func (s *TaoServer) dispatchMCP(token string, req MCPRequest) { log.Printf("[MCP Response] missing token for method=%s", req.Method) return } - if ch, ok := s.conns.Load(token); ok { + connKey := token + if client != "" { + connKey = token + "_" + client + } + if ch, ok := s.conns.Load(connKey); ok { if b, err := json.Marshal(resp); err == nil { ch.(chan string) <- string(b) log.Printf("[MCP Response] sent via SSE method=%s", req.Method) } } else { - log.Printf("[MCP Response] no SSE channel for token=%s method=%s", token, req.Method) + log.Printf("[MCP Response] no SSE channel for token=%s client=%s method=%s", token, client, req.Method) } } @@ -210,6 +214,7 @@ func (s *TaoServer) SSEHandler(w http.ResponseWriter, r *http.Request) { } // 若通过 query token 访问,也把 token 拼到 endpoint(便于客户端无 Header) token := r.URL.Query().Get("token") + client := r.URL.Query().Get("client") if token != "" { if strings.Contains(endpoint, "?") { endpoint = endpoint + "&token=" + token @@ -217,14 +222,25 @@ func (s *TaoServer) SSEHandler(w http.ResponseWriter, r *http.Request) { endpoint = endpoint + "?token=" + token } } + if client != "" { + if strings.Contains(endpoint, "?") { + endpoint = endpoint + "&client=" + client + } else { + endpoint = endpoint + "?client=" + client + } + } fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", endpoint) flusher.Flush() var msgChan chan string if token != "" { msgChan = make(chan string, 50) - s.conns.Store(token, msgChan) - defer s.conns.Delete(token) + connKey := token + if client != "" { + connKey = token + "_" + client + } + s.conns.Store(connKey, msgChan) + defer s.conns.Delete(connKey) } ticker := time.NewTicker(5 * time.Second) @@ -266,9 +282,10 @@ func (s *TaoServer) MessageHandler(w http.ResponseWriter, r *http.Request) { } token := r.URL.Query().Get("token") + client := r.URL.Query().Get("client") w.WriteHeader(http.StatusAccepted) - go s.dispatchMCP(token, req) + go s.dispatchMCP(token, client, req) } // --- 主程序 (Main) ---