feat: support client_id for multi-connection

This commit is contained in:
OpenClaw Agent
2026-03-15 01:19:36 +08:00
parent dfd24b341f
commit f4574e6190
2 changed files with 27 additions and 8 deletions

View File

@@ -42,8 +42,10 @@ knowledge_ocean/
## MCP 端点 ## MCP 端点
- **SSE**`/mcp/sse?token=...` - **SSE**`/mcp/sse?token=...&client=...`
- **消息**`/mcp/message?token=...` - **消息**`/mcp/message?token=...&client=...`
> `client` 用于区分多端连接,避免同一 token 抢占 SSE 通道。
**特性:** **特性:**
- `POST /mcp/message` 只返回 `202 Accepted` - `POST /mcp/message` 只返回 `202 Accepted`

29
main.go
View File

@@ -36,7 +36,7 @@ func sendMCPResponse(id any, result any) MCPResponse {
return MCPResponse{JSONRPC: "2.0", ID: id, Result: result} return MCPResponse{JSONRPC: "2.0", ID: id, Result: result}
} }
func (s *TaoServer) dispatchMCP(token string, req MCPRequest) { func (s *TaoServer) dispatchMCP(token string, client string, req MCPRequest) {
var resp MCPResponse var resp MCPResponse
resp.JSONRPC = "2.0" resp.JSONRPC = "2.0"
resp.ID = req.ID resp.ID = req.ID
@@ -94,13 +94,17 @@ func (s *TaoServer) dispatchMCP(token string, req MCPRequest) {
log.Printf("[MCP Response] missing token for method=%s", req.Method) log.Printf("[MCP Response] missing token for method=%s", req.Method)
return return
} }
if ch, ok := s.conns.Load(token); ok { connKey := token
if client != "" {
connKey = token + "_" + client
}
if ch, ok := s.conns.Load(connKey); ok {
if b, err := json.Marshal(resp); err == nil { if b, err := json.Marshal(resp); err == nil {
ch.(chan string) <- string(b) ch.(chan string) <- string(b)
log.Printf("[MCP Response] sent via SSE method=%s", req.Method) log.Printf("[MCP Response] sent via SSE method=%s", req.Method)
} }
} else { } else {
log.Printf("[MCP Response] no SSE channel for token=%s method=%s", token, req.Method) log.Printf("[MCP Response] no SSE channel for token=%s client=%s method=%s", token, client, req.Method)
} }
} }
@@ -210,6 +214,7 @@ func (s *TaoServer) SSEHandler(w http.ResponseWriter, r *http.Request) {
} }
// 若通过 query token 访问,也把 token 拼到 endpoint便于客户端无 Header // 若通过 query token 访问,也把 token 拼到 endpoint便于客户端无 Header
token := r.URL.Query().Get("token") token := r.URL.Query().Get("token")
client := r.URL.Query().Get("client")
if token != "" { if token != "" {
if strings.Contains(endpoint, "?") { if strings.Contains(endpoint, "?") {
endpoint = endpoint + "&token=" + token endpoint = endpoint + "&token=" + token
@@ -217,14 +222,25 @@ func (s *TaoServer) SSEHandler(w http.ResponseWriter, r *http.Request) {
endpoint = endpoint + "?token=" + token endpoint = endpoint + "?token=" + token
} }
} }
if client != "" {
if strings.Contains(endpoint, "?") {
endpoint = endpoint + "&client=" + client
} else {
endpoint = endpoint + "?client=" + client
}
}
fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", endpoint) fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", endpoint)
flusher.Flush() flusher.Flush()
var msgChan chan string var msgChan chan string
if token != "" { if token != "" {
msgChan = make(chan string, 50) msgChan = make(chan string, 50)
s.conns.Store(token, msgChan) connKey := token
defer s.conns.Delete(token) if client != "" {
connKey = token + "_" + client
}
s.conns.Store(connKey, msgChan)
defer s.conns.Delete(connKey)
} }
ticker := time.NewTicker(5 * time.Second) ticker := time.NewTicker(5 * time.Second)
@@ -266,9 +282,10 @@ func (s *TaoServer) MessageHandler(w http.ResponseWriter, r *http.Request) {
} }
token := r.URL.Query().Get("token") token := r.URL.Query().Get("token")
client := r.URL.Query().Get("client")
w.WriteHeader(http.StatusAccepted) w.WriteHeader(http.StatusAccepted)
go s.dispatchMCP(token, req) go s.dispatchMCP(token, client, req)
} }
// --- 主程序 (Main) --- // --- 主程序 (Main) ---