feat: support client_id for multi-connection
This commit is contained in:
@@ -42,8 +42,10 @@ knowledge_ocean/
|
||||
|
||||
## MCP 端点
|
||||
|
||||
- **SSE**:`/mcp/sse?token=...`
|
||||
- **消息**:`/mcp/message?token=...`
|
||||
- **SSE**:`/mcp/sse?token=...&client=...`
|
||||
- **消息**:`/mcp/message?token=...&client=...`
|
||||
|
||||
> `client` 用于区分多端连接,避免同一 token 抢占 SSE 通道。
|
||||
|
||||
**特性:**
|
||||
- `POST /mcp/message` 只返回 `202 Accepted`
|
||||
|
||||
29
main.go
29
main.go
@@ -36,7 +36,7 @@ func sendMCPResponse(id any, result any) MCPResponse {
|
||||
return MCPResponse{JSONRPC: "2.0", ID: id, Result: result}
|
||||
}
|
||||
|
||||
func (s *TaoServer) dispatchMCP(token string, req MCPRequest) {
|
||||
func (s *TaoServer) dispatchMCP(token string, client string, req MCPRequest) {
|
||||
var resp MCPResponse
|
||||
resp.JSONRPC = "2.0"
|
||||
resp.ID = req.ID
|
||||
@@ -94,13 +94,17 @@ func (s *TaoServer) dispatchMCP(token string, req MCPRequest) {
|
||||
log.Printf("[MCP Response] missing token for method=%s", req.Method)
|
||||
return
|
||||
}
|
||||
if ch, ok := s.conns.Load(token); ok {
|
||||
connKey := token
|
||||
if client != "" {
|
||||
connKey = token + "_" + client
|
||||
}
|
||||
if ch, ok := s.conns.Load(connKey); ok {
|
||||
if b, err := json.Marshal(resp); err == nil {
|
||||
ch.(chan string) <- string(b)
|
||||
log.Printf("[MCP Response] sent via SSE method=%s", req.Method)
|
||||
}
|
||||
} else {
|
||||
log.Printf("[MCP Response] no SSE channel for token=%s method=%s", token, req.Method)
|
||||
log.Printf("[MCP Response] no SSE channel for token=%s client=%s method=%s", token, client, req.Method)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -210,6 +214,7 @@ func (s *TaoServer) SSEHandler(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
// 若通过 query token 访问,也把 token 拼到 endpoint(便于客户端无 Header)
|
||||
token := r.URL.Query().Get("token")
|
||||
client := r.URL.Query().Get("client")
|
||||
if token != "" {
|
||||
if strings.Contains(endpoint, "?") {
|
||||
endpoint = endpoint + "&token=" + token
|
||||
@@ -217,14 +222,25 @@ func (s *TaoServer) SSEHandler(w http.ResponseWriter, r *http.Request) {
|
||||
endpoint = endpoint + "?token=" + token
|
||||
}
|
||||
}
|
||||
if client != "" {
|
||||
if strings.Contains(endpoint, "?") {
|
||||
endpoint = endpoint + "&client=" + client
|
||||
} else {
|
||||
endpoint = endpoint + "?client=" + client
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", endpoint)
|
||||
flusher.Flush()
|
||||
|
||||
var msgChan chan string
|
||||
if token != "" {
|
||||
msgChan = make(chan string, 50)
|
||||
s.conns.Store(token, msgChan)
|
||||
defer s.conns.Delete(token)
|
||||
connKey := token
|
||||
if client != "" {
|
||||
connKey = token + "_" + client
|
||||
}
|
||||
s.conns.Store(connKey, msgChan)
|
||||
defer s.conns.Delete(connKey)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
@@ -266,9 +282,10 @@ func (s *TaoServer) MessageHandler(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
token := r.URL.Query().Get("token")
|
||||
client := r.URL.Query().Get("client")
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
|
||||
go s.dispatchMCP(token, req)
|
||||
go s.dispatchMCP(token, client, req)
|
||||
}
|
||||
|
||||
// --- 主程序 (Main) ---
|
||||
|
||||
Reference in New Issue
Block a user