diff --git a/cmd/inp2ps/main.go b/cmd/inp2ps/main.go index 3061db9..2523f0b 100644 --- a/cmd/inp2ps/main.go +++ b/cmd/inp2ps/main.go @@ -13,6 +13,7 @@ import ( "os" "os/signal" "syscall" + "time" "github.com/openp2p-cn/inp2p/internal/server" "github.com/openp2p-cn/inp2p/pkg/auth" @@ -90,17 +91,16 @@ func main() { srv := server.New(cfg) srv.StartCleanup() - // Auth Middleware - authMiddleware := func(next http.HandlerFunc) http.HandlerFunc { + // Admin-only Middleware + adminMiddleware := func(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/api/v1/auth/login" { next(w, r) return } - // Check Authorization header authHeader := r.Header.Get("Authorization") - expected := fmt.Sprintf("Bearer %d", cfg.Token) - if authHeader != expected { + valid := authHeader == fmt.Sprintf("Bearer %d", cfg.Token) + if !valid { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusUnauthorized) fmt.Fprintf(w, `{"error":401,"message":"unauthorized"}`) @@ -110,6 +110,32 @@ func main() { } } + // Tenant or Admin Middleware + tenantMiddleware := func(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/v1/auth/login" { + next(w, r) + return + } + authHeader := r.Header.Get("Authorization") + if authHeader == fmt.Sprintf("Bearer %d", cfg.Token) { + next(w, r) + return + } + // check API key + if srv.Store() != nil { + if ten, err := srv.Store().VerifyAPIKey(server.BearerToken(r)); err == nil && ten != nil { + next(w, r) + return + } + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + fmt.Fprintf(w, `{"error":401,"message":"unauthorized"}`) + return + } + } + mux := http.NewServeMux() mux.HandleFunc("/ws", srv.HandleWS) @@ -117,6 +143,12 @@ func main() { webDir := "/root/.openclaw/workspace/inp2p/web" mux.Handle("/", http.FileServer(http.Dir(webDir))) + // Tenant APIs (API key auth inside handlers) + mux.HandleFunc("/api/v1/admin/tenants", adminMiddleware(srv.HandleAdminCreateTenant)) + mux.HandleFunc("/api/v1/admin/tenants/", adminMiddleware(srv.HandleAdminCreateAPIKey)) + mux.HandleFunc("/api/v1/tenants/enroll", srv.HandleTenantEnroll) + mux.HandleFunc("/api/v1/enroll/consume", srv.HandleEnrollConsume) + mux.HandleFunc("/api/v1/auth/login", func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) @@ -138,7 +170,16 @@ func main() { req.Token = req2.Token } - if req.Token != cfg.Token { + valid := req.Token == cfg.Token + if !valid { + for _, t := range cfg.Tokens { + if req.Token == t { + valid = true + break + } + } + } + if !valid { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusUnauthorized) fmt.Fprintf(w, `{"error":1,"message":"invalid token"}`) @@ -148,31 +189,54 @@ func main() { fmt.Fprintf(w, `{"error":0,"token":"%d"}`, cfg.Token) }) - mux.HandleFunc("/api/v1/health", authMiddleware(func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/api/v1/health", tenantMiddleware(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") fmt.Fprintf(w, `{"status":"ok","version":"%s","nodes":%d}`, config.Version, len(srv.GetOnlineNodes())) })) - mux.HandleFunc("/api/v1/nodes", authMiddleware(func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/api/v1/nodes", tenantMiddleware(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } w.Header().Set("Content-Type", "application/json") + // tenant filter by API key + tenantID := int64(0) + if srv.Store() != nil { + if ten, err := srv.Store().VerifyAPIKey(server.BearerToken(r)); err == nil && ten != nil { + tenantID = ten.ID + } + } + if tenantID > 0 { + nodes := srv.GetOnlineNodesByTenant(tenantID) + _ = json.NewEncoder(w).Encode(map[string]any{"nodes": nodes}) + return + } nodes := srv.GetOnlineNodes() _ = json.NewEncoder(w).Encode(map[string]any{"nodes": nodes}) })) - mux.HandleFunc("/api/v1/sdwans", authMiddleware(func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/api/v1/sdwans", tenantMiddleware(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } w.Header().Set("Content-Type", "application/json") + // tenant filter by API key + tenantID := int64(0) + if srv.Store() != nil { + if ten, err := srv.Store().VerifyAPIKey(server.BearerToken(r)); err == nil && ten != nil { + tenantID = ten.ID + } + } + if tenantID > 0 { + _ = json.NewEncoder(w).Encode(srv.GetSDWANTenant(tenantID)) + return + } _ = json.NewEncoder(w).Encode(srv.GetSDWAN()) })) - mux.HandleFunc("/api/v1/sdwan/edit", authMiddleware(func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/api/v1/sdwan/edit", tenantMiddleware(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return @@ -182,6 +246,22 @@ func main() { http.Error(w, err.Error(), http.StatusBadRequest) return } + // tenant filter by API key + tenantID := int64(0) + if srv.Store() != nil { + if ten, err := srv.Store().VerifyAPIKey(server.BearerToken(r)); err == nil && ten != nil { + tenantID = ten.ID + } + } + if tenantID > 0 { + if err := srv.SetSDWANTenant(tenantID, req); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{"error": 0, "message": "ok"}) + return + } if err := srv.SetSDWAN(req); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -191,7 +271,7 @@ func main() { })) // Remote Config Push API - mux.HandleFunc("/api/v1/nodes/apps", authMiddleware(func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/api/v1/nodes/apps", tenantMiddleware(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return @@ -209,6 +289,17 @@ func main() { http.Error(w, "node not found", http.StatusNotFound) return } + // tenant filter by API key + tenantID := int64(0) + if srv.Store() != nil { + if ten, err := srv.Store().VerifyAPIKey(server.BearerToken(r)); err == nil && ten != nil { + tenantID = ten.ID + } + } + if tenantID > 0 && node.TenantID != tenantID { + http.Error(w, "node not found", http.StatusNotFound) + return + } // Push to client _ = node.Conn.Write(protocol.MsgPush, protocol.SubPushConfig, req.Apps) w.Header().Set("Content-Type", "application/json") @@ -216,7 +307,7 @@ func main() { })) // Kick (disconnect) a node - mux.HandleFunc("/api/v1/nodes/kick", authMiddleware(func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/api/v1/nodes/kick", tenantMiddleware(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return @@ -233,13 +324,24 @@ func main() { http.Error(w, "node not found or offline", http.StatusNotFound) return } + // tenant filter by API key + tenantID := int64(0) + if srv.Store() != nil { + if ten, err := srv.Store().VerifyAPIKey(server.BearerToken(r)); err == nil && ten != nil { + tenantID = ten.ID + } + } + if tenantID > 0 && node.TenantID != tenantID { + http.Error(w, "node not found", http.StatusNotFound) + return + } node.Conn.Close() w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(map[string]any{"error": 0, "message": "node kicked"}) })) // Trigger P2P connect between two nodes - mux.HandleFunc("/api/v1/connect", authMiddleware(func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/api/v1/connect", tenantMiddleware(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return @@ -260,6 +362,17 @@ func main() { http.Error(w, "source node offline", http.StatusNotFound) return } + // tenant filter by API key + tenantID := int64(0) + if srv.Store() != nil { + if ten, err := srv.Store().VerifyAPIKey(server.BearerToken(r)); err == nil && ten != nil { + tenantID = ten.ID + } + } + if tenantID > 0 && fromNode.TenantID != tenantID { + http.Error(w, "node not found", http.StatusNotFound) + return + } app := protocol.AppConfig{ AppName: req.AppName, Protocol: "tcp", @@ -269,6 +382,14 @@ func main() { DstPort: req.DstPort, Enabled: 1, } + // enforce same-tenant target + if tenantID > 0 { + toNode := srv.GetNode(req.To) + if toNode == nil || toNode.TenantID != tenantID { + http.Error(w, "node not found", http.StatusNotFound) + return + } + } if err := srv.PushConnect(fromNode, req.To, app); err != nil { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadGateway) @@ -280,7 +401,7 @@ func main() { })) // Server uptime + detailed stats - mux.HandleFunc("/api/v1/stats", authMiddleware(func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/api/v1/stats", tenantMiddleware(func(w http.ResponseWriter, r *http.Request) { nodes := srv.GetOnlineNodes() coneCount, symmCount, unknCount := 0, 0, 0 relayCount := 0 @@ -317,7 +438,13 @@ func main() { } log.Printf("[main] signaling server on :%d (no TLS — use reverse proxy for production)", cfg.WSPort) - httpSrv := &http.Server{Handler: mux} + // Enable TCP keepalive at server level + httpSrv := &http.Server{ + Handler: mux, + ReadHeaderTimeout: 10 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 120 * time.Second, + } go func() { if err := httpSrv.Serve(ln); err != http.ErrServerClosed { log.Fatalf("[main] serve: %v", err) diff --git a/go.mod b/go.mod index 7e74c12..09ffcaa 100644 --- a/go.mod +++ b/go.mod @@ -7,3 +7,20 @@ toolchain go1.24.4 require github.com/gorilla/websocket v1.5.3 require golang.org/x/sys v0.41.0 + +require modernc.org/sqlite v1.29.0 + +require ( + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/google/uuid v1.3.0 // indirect + github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect + github.com/mattn/go-isatty v0.0.16 // indirect + github.com/ncruces/go-strftime v0.1.9 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect + modernc.org/libc v1.41.0 // indirect + modernc.org/mathutil v1.6.0 // indirect + modernc.org/memory v1.7.2 // indirect + modernc.org/strutil v1.2.0 // indirect + modernc.org/token v1.1.0 // indirect +) diff --git a/go.sum b/go.sum index a2d144d..55cdff9 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,41 @@ +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= +github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= +github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= +github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0= +golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/tools v0.17.0 h1:FvmRgNOcs3kOa+T20R1uhfP9F6HgG2mfxDv1vrx1Htc= +golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps= +modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 h1:5D53IMaUuA5InSeMu9eJtlQXS2NxAhyWQvkKEgXZhHI= +modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6/go.mod h1:Qz0X07sNOR1jWYCrJMEnbW/X55x206Q7Vt4mz6/wHp4= +modernc.org/libc v1.41.0 h1:g9YAc6BkKlgORsUWj+JwqoB1wU3o4DE3bM3yvA3k+Gk= +modernc.org/libc v1.41.0/go.mod h1:w0eszPsiXoOnoMJgrXjglgLuDy/bt5RR4y3QzUUeodY= +modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4= +modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo= +modernc.org/memory v1.7.2 h1:Klh90S215mmH8c9gO98QxQFsY+W451E8AnzjoE2ee1E= +modernc.org/memory v1.7.2/go.mod h1:NO4NVCQy0N7ln+T9ngWqOQfi7ley4vpwvARR+Hjw95E= +modernc.org/sqlite v1.29.0 h1:lQVw+ZsFM3aRG5m4myG70tbXpr3S/J1ej0KHIP4EvjM= +modernc.org/sqlite v1.29.0/go.mod h1:hG41jCYxOAOoO6BRK66AdRlmOcDzXf7qnwlwjUIOqa0= +modernc.org/strutil v1.2.0 h1:agBi9dp1I+eOnxXeiZawM8F4LawKv4NzGWSaLfyeNZA= +modernc.org/strutil v1.2.0/go.mod h1:/mdcBmfOibveCTBxUl5B5l6W+TTH1FXPLHZE6bTosX0= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/inp2pc b/inp2pc new file mode 100755 index 0000000..7c01d31 Binary files /dev/null and b/inp2pc differ diff --git a/inp2ps b/inp2ps new file mode 100755 index 0000000..801dfc0 Binary files /dev/null and b/inp2ps differ diff --git a/inp2ps.db-shm b/inp2ps.db-shm new file mode 100644 index 0000000..eefc2e1 Binary files /dev/null and b/inp2ps.db-shm differ diff --git a/inp2ps.db-wal b/inp2ps.db-wal new file mode 100644 index 0000000..5c98d63 Binary files /dev/null and b/inp2ps.db-wal differ diff --git a/internal/client/client.go b/internal/client/client.go index 5728a8f..3138394 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -95,7 +95,7 @@ func (c *Client) connectAndRun() error { c.publicIP = natResult.PublicIP c.publicPort = natResult.Port1 c.localPort = natResult.LocalPort - log.Printf("[client] NAT type=%s, publicIP=%s, publicPort=%d, localPort=%d", c.natType, c.publicIP, c.publicPort, c.localPort) + log.Printf("[client] SENDING_LOGIN_TOKEN=%d NAT type=%s, publicIP=%s, publicPort=%d, localPort=%d", c.natType, c.publicIP, c.publicPort, c.localPort) // 2. WSS Connect scheme := "ws" @@ -642,28 +642,34 @@ func (c *Client) tunReadLoop() { if c.IsStopping() { return } + // Log only real errors, not EOF or timeout + if err.Error() != "EOF" && err.Error() != "resource temporarily unavailable" { + log.Printf("[client] tun read error: %v", err) + } time.Sleep(100 * time.Millisecond) - log.Printf("[client] tun read error: %v", err) + continue } + // Skip empty packets or non-IPv4 if n == 0 || n < 20 { - log.Printf("[client] tun read error: %v", err) + continue } pkt := buf[:n] version := pkt[0] >> 4 if version != 4 { - log.Printf("[client] tun read error: %v", err) + continue // skip non-IPv4 } dstIP := net.IP(pkt[16:20]).String() c.sdwanMu.RLock() self := c.sdwanIP c.sdwanMu.RUnlock() if dstIP == self { - log.Printf("[client] tun read error: %v", err) + continue // skip packets to self } // send raw binary to avoid JSON base64 overhead - log.Printf("[client] tun: read pkt len=%d dst=%s", n, dstIP) frame := protocol.EncodeRaw(protocol.MsgTunnel, protocol.SubTunnelSDWANRaw, pkt) - _ = c.conn.WriteRaw(frame) + if err := c.conn.WriteRaw(frame); err != nil { + log.Printf("[client] tun write failed: %v", err) + } } } diff --git a/internal/client/inp2ps.db-shm b/internal/client/inp2ps.db-shm new file mode 100644 index 0000000..5a18652 Binary files /dev/null and b/internal/client/inp2ps.db-shm differ diff --git a/internal/client/inp2ps.db-wal b/internal/client/inp2ps.db-wal new file mode 100644 index 0000000..c2516e6 Binary files /dev/null and b/internal/client/inp2ps.db-wal differ diff --git a/internal/server/coordinator.go b/internal/server/coordinator.go index 0e82382..2a3f0ca 100644 --- a/internal/server/coordinator.go +++ b/internal/server/coordinator.go @@ -5,6 +5,7 @@ import ( "log" "time" + "github.com/openp2p-cn/inp2p/pkg/auth" "github.com/openp2p-cn/inp2p/pkg/protocol" ) @@ -17,12 +18,12 @@ import ( // HandleConnectReq processes a connection request from node A to node B. func (s *Server) HandleConnectReq(from *NodeInfo, req protocol.ConnectReq) error { - to := s.GetNode(req.To) + to := s.GetNodeForUser(req.To, from.Token) if to == nil || !to.IsOnline() { - // Peer offline — respond with error + // Peer offline or not visible — respond with generic not found from.Conn.Write(protocol.MsgPush, protocol.SubPushConnectRsp, protocol.ConnectRsp{ Error: 1, - Detail: fmt.Sprintf("node %s offline", req.To), + Detail: "node not found", From: req.To, To: req.From, }) @@ -38,6 +39,7 @@ func (s *Server) HandleConnectReq(from *NodeInfo, req protocol.ConnectReq) error Port: from.PublicPort, NATType: from.NATType, HasIPv4: from.HasIPv4, + Token: auth.GenTOTP(from.Token, time.Now().Unix()), } from.mu.RUnlock() @@ -47,6 +49,7 @@ func (s *Server) HandleConnectReq(from *NodeInfo, req protocol.ConnectReq) error Port: to.PublicPort, NATType: to.NATType, HasIPv4: to.HasIPv4, + Token: auth.GenTOTP(to.Token, time.Now().Unix()), } to.mu.RUnlock() diff --git a/internal/server/inp2ps.db-shm b/internal/server/inp2ps.db-shm new file mode 100644 index 0000000..444c871 Binary files /dev/null and b/internal/server/inp2ps.db-shm differ diff --git a/internal/server/inp2ps.db-wal b/internal/server/inp2ps.db-wal new file mode 100644 index 0000000..6fb447c Binary files /dev/null and b/internal/server/inp2ps.db-wal differ diff --git a/internal/server/sdwan.go b/internal/server/sdwan.go index cdcbda0..de2661b 100644 --- a/internal/server/sdwan.go +++ b/internal/server/sdwan.go @@ -12,13 +12,14 @@ import ( ) type sdwanStore struct { - mu sync.RWMutex - path string - cfg protocol.SDWANConfig + mu sync.RWMutex + path string + cfg protocol.SDWANConfig + multi map[int64]protocol.SDWANConfig } func newSDWANStore(path string) *sdwanStore { - s := &sdwanStore{path: path} + s := &sdwanStore{path: path, multi: make(map[int64]protocol.SDWANConfig)} _ = s.load() return s } @@ -33,6 +34,15 @@ func (s *sdwanStore) load() error { } return err } + // try multi-tenant first + var m map[int64]protocol.SDWANConfig + if err := json.Unmarshal(b, &m); err == nil && len(m) > 0 { + for k, v := range m { + m[k] = normalizeSDWAN(v) + } + s.multi = m + return nil + } var c protocol.SDWANConfig if err := json.Unmarshal(b, &c); err != nil { return err @@ -57,12 +67,40 @@ func (s *sdwanStore) save(cfg protocol.SDWANConfig) error { return nil } +func (s *sdwanStore) saveTenant(tenantID int64, cfg protocol.SDWANConfig) error { + s.mu.Lock() + defer s.mu.Unlock() + cfg = normalizeSDWAN(cfg) + cfg.UpdatedAt = time.Now().Unix() + if s.multi == nil { + s.multi = make(map[int64]protocol.SDWANConfig) + } + s.multi[tenantID] = cfg + b, err := json.MarshalIndent(s.multi, "", " ") + if err != nil { + return err + } + if err := os.WriteFile(s.path, b, 0644); err != nil { + return err + } + return nil +} + func (s *sdwanStore) get() protocol.SDWANConfig { s.mu.RLock() defer s.mu.RUnlock() return s.cfg } +func (s *sdwanStore) getTenant(tenantID int64) protocol.SDWANConfig { + s.mu.RLock() + defer s.mu.RUnlock() + if s.multi == nil { + return protocol.SDWANConfig{} + } + return s.multi[tenantID] +} + func normalizeSDWAN(c protocol.SDWANConfig) protocol.SDWANConfig { if c.Mode == "" { c.Mode = "hub" diff --git a/internal/server/sdwan_api.go b/internal/server/sdwan_api.go index c7f3dd4..bf92270 100644 --- a/internal/server/sdwan_api.go +++ b/internal/server/sdwan_api.go @@ -11,6 +11,10 @@ func (s *Server) GetSDWAN() protocol.SDWANConfig { return s.sdwan.get() } +func (s *Server) GetSDWANTenant(tenantID int64) protocol.SDWANConfig { + return s.sdwan.getTenant(tenantID) +} + func (s *Server) SetSDWAN(cfg protocol.SDWANConfig) error { if err := s.sdwan.save(cfg); err != nil { return err @@ -19,6 +23,14 @@ func (s *Server) SetSDWAN(cfg protocol.SDWANConfig) error { return nil } +func (s *Server) SetSDWANTenant(tenantID int64, cfg protocol.SDWANConfig) error { + if err := s.sdwan.saveTenant(tenantID, cfg); err != nil { + return err + } + s.broadcastSDWANTenant(tenantID, s.sdwan.getTenant(tenantID)) + return nil +} + func (s *Server) broadcastSDWAN(cfg protocol.SDWANConfig) { if !cfg.Enabled || cfg.GatewayCIDR == "" { return @@ -33,6 +45,20 @@ func (s *Server) broadcastSDWAN(cfg protocol.SDWANConfig) { } } +func (s *Server) broadcastSDWANTenant(tenantID int64, cfg protocol.SDWANConfig) { + if !cfg.Enabled || cfg.GatewayCIDR == "" { + return + } + s.mu.RLock() + defer s.mu.RUnlock() + for _, n := range s.nodes { + if !n.IsOnline() || n.TenantID != tenantID { + continue + } + _ = n.Conn.Write(protocol.MsgPush, protocol.SubPushSDWANConfig, cfg) + } +} + func (s *Server) pushSDWANPeer(to *NodeInfo, peer protocol.SDWANPeer) { if to == nil || !to.IsOnline() { return @@ -48,7 +74,14 @@ func (s *Server) pushSDWANDel(to *NodeInfo, peer protocol.SDWANPeer) { } func (s *Server) announceSDWANNodeOnline(nodeName string) { - cfg := s.sdwan.get() + // pick tenant config by node + s.mu.RLock() + newNode := s.nodes[nodeName] + s.mu.RUnlock() + if newNode == nil { + return + } + cfg := s.sdwan.getTenant(newNode.TenantID) if cfg.GatewayCIDR == "" { return } @@ -64,7 +97,7 @@ func (s *Server) announceSDWANNodeOnline(nodeName string) { } s.mu.RLock() - newNode := s.nodes[nodeName] + newNode = s.nodes[nodeName] if newNode == nil || !newNode.IsOnline() { s.mu.RUnlock() return @@ -74,7 +107,7 @@ func (s *Server) announceSDWANNodeOnline(nodeName string) { continue } other := s.nodes[n.Node] - if other == nil || !other.IsOnline() { + if other == nil || !other.IsOnline() || other.TenantID != newNode.TenantID { continue } // existing -> new @@ -86,7 +119,13 @@ func (s *Server) announceSDWANNodeOnline(nodeName string) { } func (s *Server) announceSDWANNodeOffline(nodeName string) { - cfg := s.sdwan.get() + s.mu.RLock() + old := s.nodes[nodeName] + s.mu.RUnlock() + if old == nil { + return + } + cfg := s.sdwan.getTenant(old.TenantID) if cfg.GatewayCIDR == "" { return } @@ -100,7 +139,7 @@ func (s *Server) announceSDWANNodeOffline(nodeName string) { s.mu.RLock() defer s.mu.RUnlock() for _, n := range s.nodes { - if n.Name == nodeName || !n.IsOnline() { + if n.Name == nodeName || !n.IsOnline() || n.TenantID != old.TenantID { continue } s.pushSDWANDel(n, protocol.SDWANPeer{Node: nodeName, IP: selfIP, Online: false}) @@ -112,7 +151,13 @@ func (s *Server) RouteSDWANPacket(from *NodeInfo, pkt protocol.SDWANPacket) { if from == nil { return } - cfg := s.sdwan.get() + // Use global config for untrusted nodes (TenantID=0), otherwise use tenant config + var cfg protocol.SDWANConfig + if from.TenantID == 0 { + cfg = s.sdwan.get() + } else { + cfg = s.sdwan.getTenant(from.TenantID) + } if cfg.GatewayCIDR == "" || pkt.DstIP == "" || len(pkt.Payload) == 0 { return } @@ -124,12 +169,18 @@ func (s *Server) RouteSDWANPacket(from *NodeInfo, pkt protocol.SDWANPacket) { toNode := "" for _, n := range cfg.Nodes { if n.IP == pkt.DstIP { - toNode = n.Node - break + candidate := s.GetNodeForUser(n.Node, from.Token) + if candidate != nil && candidate.TenantID == from.TenantID { + toNode = n.Node + break + } } if p, err := netip.ParseAddr(n.IP); err == nil && p == dst { - toNode = n.Node - break + candidate := s.GetNodeForUser(n.Node, from.Token) + if candidate != nil && candidate.TenantID == from.TenantID { + toNode = n.Node + break + } } } if toNode == "" || toNode == from.Name { @@ -138,6 +189,9 @@ func (s *Server) RouteSDWANPacket(from *NodeInfo, pkt protocol.SDWANPacket) { s.mu.RLock() to := s.nodes[toNode] + if to != nil && to.TenantID != from.TenantID { + to = nil + } s.mu.RUnlock() if to == nil || !to.IsOnline() { return diff --git a/internal/server/server.go b/internal/server/server.go index 4bd1f92..19b0abc 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -2,6 +2,7 @@ package server import ( + "fmt" "log" "net" "net/http" @@ -10,6 +11,7 @@ import ( "github.com/gorilla/websocket" "github.com/openp2p-cn/inp2p/pkg/auth" + "github.com/openp2p-cn/inp2p/internal/store" "github.com/openp2p-cn/inp2p/pkg/config" "github.com/openp2p-cn/inp2p/pkg/protocol" "github.com/openp2p-cn/inp2p/pkg/signal" @@ -17,26 +19,27 @@ import ( // NodeInfo represents a connected client node. type NodeInfo struct { - Name string `json:"name"` - Token uint64 `json:"-"` - User string `json:"user"` - Version string `json:"version"` - NATType protocol.NATType `json:"natType"` - PublicIP string `json:"publicIP"` - PublicPort int `json:"publicPort"` - LanIP string `json:"lanIP"` - OS string `json:"os"` - Mac string `json:"mac"` - ShareBandwidth int `json:"shareBandwidth"` - RelayEnabled bool `json:"relayEnabled"` - SuperRelay bool `json:"superRelay"` - HasIPv4 int `json:"hasIPv4"` - IPv6 string `json:"ipv6"` - LoginTime time.Time `json:"loginTime"` - LastHeartbeat time.Time `json:"lastHeartbeat"` - Conn *signal.Conn `json:"-"` + Name string `json:"name"` + Token uint64 `json:"-"` + TenantID int64 `json:"tenantId"` + User string `json:"user"` + Version string `json:"version"` + NATType protocol.NATType `json:"natType"` + PublicIP string `json:"publicIP"` + PublicPort int `json:"publicPort"` + LanIP string `json:"lanIP"` + OS string `json:"os"` + Mac string `json:"mac"` + ShareBandwidth int `json:"shareBandwidth"` + RelayEnabled bool `json:"relayEnabled"` + SuperRelay bool `json:"superRelay"` + HasIPv4 int `json:"hasIPv4"` + IPv6 string `json:"ipv6"` + LoginTime time.Time `json:"loginTime"` + LastHeartbeat time.Time `json:"lastHeartbeat"` + Conn *signal.Conn `json:"-"` Apps []protocol.AppConfig `json:"apps"` - mu sync.RWMutex `json:"-"` + mu sync.RWMutex `json:"-"` } // IsOnline checks if node has sent heartbeat recently. @@ -49,25 +52,43 @@ func (n *NodeInfo) IsOnline() bool { // Server is the INP2P signaling server. type Server struct { cfg config.ServerConfig - nodes map[string]*NodeInfo // node name → info + nodes map[string]*NodeInfo mu sync.RWMutex upgrader websocket.Upgrader quit chan struct{} sdwanPath string sdwan *sdwanStore + store *store.Store + tokens map[uint64]bool } +func (s *Server) Store() *store.Store { return s.store } + // New creates a new server. func New(cfg config.ServerConfig) *Server { - // Use absolute path for sdwan config to avoid working directory issues sdwanPath := "/root/.openclaw/workspace/inp2p/sdwan.json" + tokens := make(map[uint64]bool) + if cfg.Token != 0 { + tokens[cfg.Token] = true + } + for _, t := range cfg.Tokens { + tokens[t] = true + } + st, err := store.Open(cfg.DBPath) + if err != nil { + log.Printf("[server] open store failed: %v", err) + } return &Server{ cfg: cfg, nodes: make(map[string]*NodeInfo), sdwanPath: sdwanPath, sdwan: newSDWANStore(sdwanPath), + store: st, + tokens: tokens, upgrader: websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, + ReadBufferSize: 4096, + WriteBufferSize: 4096, }, quit: make(chan struct{}), } @@ -93,6 +114,42 @@ func (s *Server) GetOnlineNodes() []*NodeInfo { return out } +// GetNodeForUser returns node if token matches (legacy) or tenant matches. +func (s *Server) GetNodeForUser(name string, token uint64) *NodeInfo { + s.mu.RLock() + defer s.mu.RUnlock() + n := s.nodes[name] + if n == nil { + return nil + } + if n.Token != token && n.TenantID == 0 { + return nil + } + return n +} + +func (s *Server) GetNodeForTenant(name string, tenantID int64) *NodeInfo { + s.mu.RLock() + defer s.mu.RUnlock() + n := s.nodes[name] + if n == nil || n.TenantID != tenantID { + return nil + } + return n +} + +func (s *Server) GetOnlineNodesByTenant(tenantID int64) []*NodeInfo { + s.mu.RLock() + defer s.mu.RUnlock() + var out []*NodeInfo + for _, n := range s.nodes { + if n.IsOnline() && n.TenantID == tenantID { + out = append(out, n) + } + } + return out +} + // GetRelayNodes returns nodes that can serve as relay. // Priority: same-user private relay → super relay func (s *Server) GetRelayNodes(forUser string, excludeNodes ...string) []*NodeInfo { @@ -119,6 +176,28 @@ func (s *Server) GetRelayNodes(forUser string, excludeNodes ...string) []*NodeIn return append(privateRelays, superRelays...) } +// GetRelayNodesByTenant returns relay nodes within tenant. +func (s *Server) GetRelayNodesByTenant(tenantID int64, excludeNodes ...string) []*NodeInfo { + excludeSet := make(map[string]bool) + for _, n := range excludeNodes { + excludeSet[n] = true + } + + s.mu.RLock() + defer s.mu.RUnlock() + + var relays []*NodeInfo + for _, n := range s.nodes { + if !n.IsOnline() || excludeSet[n.Name] { + continue + } + if n.TenantID == tenantID && (n.RelayEnabled || n.SuperRelay) { + relays = append(relays, n) + } + } + return relays +} + // HandleWS is the WebSocket handler for client connections. func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) { ws, err := s.upgrader.Upgrade(w, r, nil) @@ -151,8 +230,26 @@ func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) { return } - // Verify token - if loginReq.Token != s.cfg.Token { + // Verify token: master token OR tenant API key (DB) OR node_secret (DB) + valid := s.tokens[loginReq.Token] + log.Printf("[server] login check: token=%d, cfg.Token=%d, valid=%v", loginReq.Token, s.cfg.Token, valid) + var tenantID int64 + if !valid && s.store != nil { + // try api key (string) or node secret + if loginReq.NodeSecret != "" { + if ten, err := s.store.VerifyNodeSecret(loginReq.Node, loginReq.NodeSecret); err == nil && ten != nil { + valid = true + tenantID = ten.ID + } + } + if !valid { + if ten, err := s.store.VerifyAPIKey(fmt.Sprintf("%d", loginReq.Token)); err == nil && ten != nil { + valid = true + tenantID = ten.ID + } + } + } + if !valid { log.Printf("[server] login denied: %s (token mismatch)", loginReq.Node) conn.Write(protocol.MsgLogin, protocol.SubLoginRsp, protocol.LoginRsp{ Error: 1, @@ -174,6 +271,7 @@ func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) { node := &NodeInfo{ Name: loginReq.Node, Token: loginReq.Token, + TenantID: tenantID, User: loginReq.User, Version: loginReq.Version, NATType: loginReq.NATType, @@ -211,11 +309,21 @@ func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) { s.broadcastNodeOnline(loginReq.Node) // Push current SDWAN config right after login (if exists and enabled) - if cfg := s.sdwan.get(); cfg.Enabled && cfg.GatewayCIDR != "" { - if err := conn.Write(protocol.MsgPush, protocol.SubPushSDWANConfig, cfg); err != nil { - log.Printf("[server] sdwan config push failed: %v", err) - } else { - log.Printf("[server] sdwan config pushed to %s", loginReq.Node) + if node.TenantID > 0 { + if cfg := s.sdwan.getTenant(node.TenantID); cfg.Enabled && cfg.GatewayCIDR != "" { + if err := conn.Write(protocol.MsgPush, protocol.SubPushSDWANConfig, cfg); err != nil { + log.Printf("[server] sdwan config push failed: %v", err) + } else { + log.Printf("[server] sdwan config pushed to %s", loginReq.Node) + } + } + } else { + if cfg := s.sdwan.get(); cfg.Enabled && cfg.GatewayCIDR != "" { + if err := conn.Write(protocol.MsgPush, protocol.SubPushSDWANConfig, cfg); err != nil { + log.Printf("[server] sdwan config push failed: %v", err) + } else { + log.Printf("[server] sdwan config pushed to %s", loginReq.Node) + } } } // Event-driven SDWAN peer notification @@ -378,10 +486,13 @@ func (s *Server) handleRelayNodeReq(conn *signal.Conn, requester *NodeInfo, req // PushConnect sends a punch coordination message to a peer node. func (s *Server) PushConnect(fromNode *NodeInfo, toNodeName string, app protocol.AppConfig) error { - toNode := s.GetNode(toNodeName) + toNode := s.GetNodeForUser(toNodeName, fromNode.Token) if toNode == nil || !toNode.IsOnline() { return &NodeOfflineError{Node: toNodeName} } + if fromNode.TenantID != 0 && toNode.TenantID != fromNode.TenantID { + return &NodeOfflineError{Node: toNodeName} + } // Push connect request to the destination req := protocol.ConnectReq{ @@ -392,6 +503,7 @@ func (s *Server) PushConnect(fromNode *NodeInfo, toNodeName string, app protocol IP: fromNode.PublicIP, NATType: fromNode.NATType, HasIPv4: fromNode.HasIPv4, + Token: auth.GenTOTP(fromNode.Token, time.Now().Unix()), }, AppName: app.AppName, Protocol: app.Protocol, @@ -406,12 +518,19 @@ func (s *Server) PushConnect(fromNode *NodeInfo, toNodeName string, app protocol // broadcastNodeOnline notifies interested nodes that a peer came online. func (s *Server) broadcastNodeOnline(nodeName string) { s.mu.RLock() + newNode := s.nodes[nodeName] defer s.mu.RUnlock() + if newNode == nil { + return + } for _, n := range s.nodes { if n.Name == nodeName { continue } + if n.Token != newNode.Token && (newNode.TenantID == 0 || n.TenantID != newNode.TenantID) { + continue + } // Check if this node has any app targeting the new node n.mu.RLock() interested := false diff --git a/internal/server/tenant_api.go b/internal/server/tenant_api.go new file mode 100644 index 0000000..31939bc --- /dev/null +++ b/internal/server/tenant_api.go @@ -0,0 +1,185 @@ +package server + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/openp2p-cn/inp2p/internal/store" +) + +// helpers +func BearerToken(r *http.Request) string { + h := r.Header.Get("Authorization") + if h == "" { + return "" + } + parts := strings.SplitN(h, " ", 2) + if len(parts) != 2 { + return "" + } + if strings.ToLower(parts[0]) != "bearer" { + return "" + } + return strings.TrimSpace(parts[1]) +} + +func writeJSON(w http.ResponseWriter, status int, body string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + io.WriteString(w, body) +} + +func (s *Server) HandleAdminCreateTenant(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeJSON(w, http.StatusMethodNotAllowed, `{"error":1,"message":"method not allowed"}`) + return + } + var req struct { + Name string `json:"name"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Name == "" { + writeJSON(w, http.StatusBadRequest, `{"error":1,"message":"bad request"}`) + return + } + if s.store == nil { + writeJSON(w, http.StatusInternalServerError, `{"error":1,"message":"store not ready"}`) + return + } + ten, err := s.store.CreateTenant(req.Name) + if err != nil { + writeJSON(w, http.StatusInternalServerError, `{"error":1,"message":"create tenant failed"}`) + return + } + resp := struct { + Error int `json:"error"` + Message string `json:"message"` + Tenant int64 `json:"tenant_id"` + Subnet string `json:"subnet"` + }{0, "ok", ten.ID, ten.Subnet} + b, _ := json.Marshal(resp) + writeJSON(w, http.StatusOK, string(b)) +} + +func (s *Server) HandleAdminCreateAPIKey(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeJSON(w, http.StatusMethodNotAllowed, `{"error":1,"message":"method not allowed"}`) + return + } + if s.store == nil { + writeJSON(w, http.StatusInternalServerError, `{"error":1,"message":"store not ready"}`) + return + } + // /api/v1/admin/tenants/{id}/keys + parts := strings.Split(strings.Trim(r.URL.Path, "/"), "/") + if len(parts) < 6 || parts[5] != "keys" { + writeJSON(w, http.StatusBadRequest, `{"error":1,"message":"bad request"}`) + return + } + // parts: api v1 admin tenants {id} keys + idPart := parts[4] + var tenantID int64 + _, _ = fmt.Sscanf(idPart, "%d", &tenantID) + if tenantID == 0 { + writeJSON(w, http.StatusBadRequest, `{"error":1,"message":"bad request"}`) + return + } + var req struct { + Scope string `json:"scope"` + TTL int64 `json:"ttl"` // seconds + } + _ = json.NewDecoder(r.Body).Decode(&req) + var ttl time.Duration + if req.TTL > 0 { + ttl = time.Duration(req.TTL) * time.Second + } + key, err := s.store.CreateAPIKey(tenantID, req.Scope, ttl) + if err != nil { + writeJSON(w, http.StatusInternalServerError, `{"error":1,"message":"create key failed"}`) + return + } + resp := struct { + Error int `json:"error"` + Message string `json:"message"` + APIKey string `json:"api_key"` + Tenant int64 `json:"tenant_id"` + }{0, "ok", key, tenantID} + b, _ := json.Marshal(resp) + writeJSON(w, http.StatusOK, string(b)) +} + +func (s *Server) HandleTenantEnroll(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeJSON(w, http.StatusMethodNotAllowed, `{"error":1,"message":"method not allowed"}`) + return + } + // tenant auth by API key + if s.store == nil { + writeJSON(w, http.StatusInternalServerError, `{"error":1,"message":"store not ready"}`) + return + } + tok := BearerToken(r) + ten, err := s.store.VerifyAPIKey(tok) + if err != nil || ten == nil { + writeJSON(w, http.StatusUnauthorized, `{"error":1,"message":"unauthorized"}`) + return + } + code, err := s.store.CreateEnrollToken(ten.ID, 10*time.Minute, 5) + if err != nil { + writeJSON(w, http.StatusInternalServerError, `{"error":1,"message":"create enroll failed"}`) + return + } + resp := struct { + Error int `json:"error"` + Message string `json:"message"` + Code string `json:"enroll_code"` + Tenant int64 `json:"tenant_id"` + }{0, "ok", code, ten.ID} + b, _ := json.Marshal(resp) + writeJSON(w, http.StatusOK, string(b)) +} + +func (s *Server) HandleEnrollConsume(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeJSON(w, http.StatusMethodNotAllowed, `{"error":1,"message":"method not allowed"}`) + return + } + var req struct { + Code string `json:"code"` + NodeName string `json:"node"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Code == "" || req.NodeName == "" { + writeJSON(w, http.StatusBadRequest, `{"error":1,"message":"bad request"}`) + return + } + if s.store == nil { + writeJSON(w, http.StatusInternalServerError, `{"error":1,"message":"store not ready"}`) + return + } + et, err := s.store.ConsumeEnrollToken(req.Code) + if err != nil { + s.store.IncEnrollAttempt(req.Code) + writeJSON(w, http.StatusUnauthorized, `{"error":1,"message":"invalid enroll"}`) + return + } + cred, err := s.store.CreateNodeCredential(et.TenantID, req.NodeName) + if err != nil { + writeJSON(w, http.StatusInternalServerError, `{"error":1,"message":"create node failed"}`) + return + } + resp := struct { + Error int `json:"error"` + Message string `json:"message"` + NodeID int64 `json:"node_id"` + Secret string `json:"node_secret"` + Tenant int64 `json:"tenant_id"` + }{0, "ok", cred.NodeID, cred.Secret, cred.TenantID} + b, _ := json.Marshal(resp) + writeJSON(w, http.StatusOK, string(b)) +} + +// placeholder to avoid unused import +var _ = store.Tenant{} diff --git a/internal/store/store.go b/internal/store/store.go new file mode 100644 index 0000000..0358fd5 --- /dev/null +++ b/internal/store/store.go @@ -0,0 +1,343 @@ +package store + +import ( + "crypto/rand" + "crypto/sha256" + "database/sql" + "encoding/hex" + "errors" + "fmt" + "net" + "time" + + _ "modernc.org/sqlite" +) + +type Store struct { + DB *sql.DB +} + +type Tenant struct { + ID int64 + Name string + Status int + Subnet string +} + +type APIKey struct { + ID int64 + TenantID int64 + Hash string + Scope string + Expires *time.Time + Status int +} + +type NodeCredential struct { + NodeID int64 + NodeName string + Secret string + VirtualIP string + TenantID int64 +} + +type EnrollToken struct { + ID int64 + TenantID int64 + Hash string + ExpiresAt int64 + UsedAt *int64 + MaxAttempt int + Attempts int + Status int +} + +func Open(dbPath string) (*Store, error) { + db, err := sql.Open("sqlite", dbPath) + if err != nil { + return nil, err + } + if _, err := db.Exec(`PRAGMA journal_mode=WAL;`); err != nil { + return nil, err + } + if _, err := db.Exec(`PRAGMA foreign_keys=ON;`); err != nil { + return nil, err + } + s := &Store{DB: db} + if err := s.migrate(); err != nil { + return nil, err + } + if err := s.ensureSubnetPool(); err != nil { + return nil, err + } + return s, nil +} + +func (s *Store) migrate() error { + stmts := []string{ + `CREATE TABLE IF NOT EXISTS tenants ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL UNIQUE, + status INTEGER NOT NULL DEFAULT 1, + subnet TEXT NOT NULL UNIQUE, + created_at INTEGER NOT NULL + );`, + `CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + tenant_id INTEGER NOT NULL, + role TEXT NOT NULL, + email TEXT, + password_hash TEXT, + status INTEGER NOT NULL DEFAULT 1, + created_at INTEGER NOT NULL, + FOREIGN KEY(tenant_id) REFERENCES tenants(id) + );`, + `CREATE TABLE IF NOT EXISTS api_keys ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + tenant_id INTEGER NOT NULL, + key_hash TEXT NOT NULL UNIQUE, + scope TEXT, + expires_at INTEGER, + status INTEGER NOT NULL DEFAULT 1, + created_at INTEGER NOT NULL, + FOREIGN KEY(tenant_id) REFERENCES tenants(id) + );`, + `CREATE TABLE IF NOT EXISTS nodes ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + tenant_id INTEGER NOT NULL, + node_name TEXT NOT NULL, + node_pubkey TEXT, + node_secret_hash TEXT, + virtual_ip TEXT, + status INTEGER NOT NULL DEFAULT 1, + last_seen INTEGER, + FOREIGN KEY(tenant_id) REFERENCES tenants(id) + );`, + `CREATE TABLE IF NOT EXISTS enroll_tokens ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + tenant_id INTEGER NOT NULL, + token_hash TEXT NOT NULL UNIQUE, + expires_at INTEGER NOT NULL, + used_at INTEGER, + max_attempt INTEGER NOT NULL DEFAULT 5, + attempts INTEGER NOT NULL DEFAULT 0, + status INTEGER NOT NULL DEFAULT 1, + created_at INTEGER NOT NULL, + FOREIGN KEY(tenant_id) REFERENCES tenants(id) + );`, + `CREATE TABLE IF NOT EXISTS peering_policies ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + src_tenant_id INTEGER NOT NULL, + dst_tenant_id INTEGER NOT NULL, + rules TEXT, + expires_at INTEGER, + status INTEGER NOT NULL DEFAULT 1, + created_at INTEGER NOT NULL, + FOREIGN KEY(src_tenant_id) REFERENCES tenants(id), + FOREIGN KEY(dst_tenant_id) REFERENCES tenants(id) + );`, + `CREATE TABLE IF NOT EXISTS audit_logs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + actor_type TEXT, + actor_id TEXT, + action TEXT, + target_type TEXT, + target_id TEXT, + detail TEXT, + ip TEXT, + created_at INTEGER NOT NULL + );`, + `CREATE TABLE IF NOT EXISTS subnet_pool ( + subnet TEXT PRIMARY KEY, + status INTEGER NOT NULL DEFAULT 0, + reserved INTEGER NOT NULL DEFAULT 0, + tenant_id INTEGER, + updated_at INTEGER NOT NULL + );`, + } + for _, stmt := range stmts { + if _, err := s.DB.Exec(stmt); err != nil { + return err + } + } + return nil +} + +func (s *Store) ensureSubnetPool() error { + // pool: 10.10.1.0/24 .. 10.10.254.0/24 + // reserve: 10.10.0.0/24 and 10.10.255.0/24 + rows, err := s.DB.Query(`SELECT COUNT(1) FROM subnet_pool;`) + if err != nil { + return err + } + defer rows.Close() + var count int + if rows.Next() { + _ = rows.Scan(&count) + } + if count > 0 { + return nil + } + now := time.Now().Unix() + insert := `INSERT INTO subnet_pool(subnet,status,reserved,tenant_id,updated_at) VALUES(?,?,?,?,?)` + // reserved + _, _ = s.DB.Exec(insert, "10.10.0.0/24", 0, 1, nil, now) + _, _ = s.DB.Exec(insert, "10.10.255.0/24", 0, 1, nil, now) + for i := 1; i <= 254; i++ { + sn := fmt.Sprintf("10.10.%d.0/24", i) + _, _ = s.DB.Exec(insert, sn, 0, 0, nil, now) + } + return nil +} + +func (s *Store) AllocateSubnet() (string, error) { + // find first available subnet + row := s.DB.QueryRow(`SELECT subnet FROM subnet_pool WHERE status=0 AND reserved=0 ORDER BY subnet LIMIT 1`) + var subnet string + if err := row.Scan(&subnet); err != nil { + return "", err + } + if subnet == "" { + return "", errors.New("no subnet available") + } + return subnet, nil +} + +func (s *Store) CreateTenant(name string) (*Tenant, error) { + sn, err := s.AllocateSubnet() + if err != nil { + return nil, err + } + now := time.Now().Unix() + res, err := s.DB.Exec(`INSERT INTO tenants(name,status,subnet,created_at) VALUES(?,?,?,?)`, name, 1, sn, now) + if err != nil { + return nil, err + } + id, _ := res.LastInsertId() + _, _ = s.DB.Exec(`UPDATE subnet_pool SET status=1, tenant_id=?, updated_at=? WHERE subnet=?`, id, now, sn) + return &Tenant{ID: id, Name: name, Status: 1, Subnet: sn}, nil +} + +func (s *Store) CreateNodeCredential(tenantID int64, nodeName string) (*NodeCredential, error) { + secret := randToken() + h := hashTokenString(secret) + res, err := s.DB.Exec(`INSERT INTO nodes(tenant_id,node_name,node_secret_hash,status) VALUES(?,?,?,1)`, tenantID, nodeName, h) + if err != nil { + return nil, err + } + id, _ := res.LastInsertId() + return &NodeCredential{NodeID: id, NodeName: nodeName, Secret: secret, TenantID: tenantID}, nil +} + +func (s *Store) VerifyNodeSecret(nodeName, secret string) (*Tenant, error) { + h := hashTokenString(secret) + row := s.DB.QueryRow(`SELECT t.id,t.name,t.status,t.subnet FROM nodes n JOIN tenants t ON n.tenant_id=t.id WHERE n.node_name=? AND n.node_secret_hash=? AND n.status=1`, nodeName, h) + var t Tenant + if err := row.Scan(&t.ID, &t.Name, &t.Status, &t.Subnet); err != nil { + return nil, err + } + return &t, nil +} + +func (s *Store) GetTenantByToken(token uint64) (*Tenant, error) { + h := hashToken(token) + row := s.DB.QueryRow(`SELECT t.id,t.name,t.status,t.subnet FROM api_keys k JOIN tenants t ON k.tenant_id=t.id WHERE k.key_hash=? AND k.status=1`, h) + var t Tenant + if err := row.Scan(&t.ID, &t.Name, &t.Status, &t.Subnet); err != nil { + return nil, err + } + return &t, nil +} + +func (s *Store) CreateAPIKey(tenantID int64, scope string, ttl time.Duration) (string, error) { + token := randToken() + h := hashTokenString(token) + now := time.Now().Unix() + if ttl > 0 { + e := time.Now().Add(ttl) + _, err := s.DB.Exec(`INSERT INTO api_keys(tenant_id,key_hash,scope,expires_at,status,created_at) VALUES(?,?,?,?,1,?)`, tenantID, h, scope, e.Unix(), now) + return token, err + } + _, err := s.DB.Exec(`INSERT INTO api_keys(tenant_id,key_hash,scope,expires_at,status,created_at) VALUES(?,?,?,?,1,?)`, tenantID, h, scope, nil, now) + return token, err +} + +func (s *Store) CreateEnrollToken(tenantID int64, ttl time.Duration, maxAttempt int) (string, error) { + code := randToken() + h := hashTokenString(code) + exp := time.Now().Add(ttl).Unix() + now := time.Now().Unix() + _, err := s.DB.Exec(`INSERT INTO enroll_tokens(tenant_id,token_hash,expires_at,max_attempt,attempts,status,created_at) VALUES(?,?,?,?,0,1,?)`, tenantID, h, exp, maxAttempt, now) + return code, err +} + +func (s *Store) ConsumeEnrollToken(code string) (*EnrollToken, error) { + h := hashTokenString(code) + now := time.Now().Unix() + row := s.DB.QueryRow(`SELECT id,tenant_id,expires_at,used_at,max_attempt,attempts,status FROM enroll_tokens WHERE token_hash=?`, h) + var et EnrollToken + var used sql.NullInt64 + if err := row.Scan(&et.ID, &et.TenantID, &et.ExpiresAt, &used, &et.MaxAttempt, &et.Attempts, &et.Status); err != nil { + return nil, err + } + if used.Valid { + return nil, errors.New("token already used") + } + if et.Status != 1 { + return nil, errors.New("token disabled") + } + if et.Attempts >= et.MaxAttempt { + return nil, errors.New("token attempts exceeded") + } + if now > et.ExpiresAt { + return nil, errors.New("token expired") + } + // mark used + _, err := s.DB.Exec(`UPDATE enroll_tokens SET used_at=?, attempts=attempts+1 WHERE id=?`, now, et.ID) + if err != nil { + return nil, err + } + et.UsedAt = &now + return &et, nil +} + +func (s *Store) IncEnrollAttempt(code string) { + h := hashTokenString(code) + _, _ = s.DB.Exec(`UPDATE enroll_tokens SET attempts=attempts+1 WHERE token_hash=?`, h) +} + +func hashToken(token uint64) string { + b := make([]byte, 8) + for i := uint(0); i < 8; i++ { + b[7-i] = byte(token >> (i * 8)) + } + return hashTokenBytes(b) +} + +func hashTokenString(token string) string { + return hashTokenBytes([]byte(token)) +} + +func (s *Store) VerifyAPIKey(token string) (*Tenant, error) { + h := hashTokenString(token) + row := s.DB.QueryRow(`SELECT t.id,t.name,t.status,t.subnet FROM api_keys k JOIN tenants t ON k.tenant_id=t.id WHERE k.key_hash=? AND k.status=1`, h) + var t Tenant + if err := row.Scan(&t.ID, &t.Name, &t.Status, &t.Subnet); err != nil { + return nil, err + } + return &t, nil +} + +func hashTokenBytes(b []byte) string { + h := sha256.Sum256(b) + return hex.EncodeToString(h[:]) +} + +func randToken() string { + b := make([]byte, 24) + _, _ = rand.Read(b) + return hex.EncodeToString(b) +} + +// helper to avoid unused import (net) +var _ = net.IPv4len diff --git a/pkg/config/config.go b/pkg/config/config.go index 09ad439..c73b188 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -7,6 +7,7 @@ import ( "fmt" "os" "strconv" + "strings" ) // Version info (set via -ldflags) @@ -46,8 +47,9 @@ type ServerConfig struct { CertFile string `json:"certFile"` KeyFile string `json:"keyFile"` LogLevel int `json:"logLevel"` // 0=debug, 1=info, 2=warn, 3=error - Token uint64 `json:"token"` // master token for auth - JWTKey string `json:"jwtKey"` // auto-generated if empty + Token uint64 `json:"token"` // master token for auth + Tokens []uint64 `json:"tokens"` // additional tenant tokens + JWTKey string `json:"jwtKey"` // auto-generated if empty AdminUser string `json:"adminUser"` AdminPass string `json:"adminPass"` @@ -82,6 +84,18 @@ func (c *ServerConfig) FillFromEnv() { if v := os.Getenv("INP2PS_TOKEN"); v != "" { c.Token, _ = strconv.ParseUint(v, 10, 64) } + if v := os.Getenv("INP2PS_TOKENS"); v != "" { + parts := strings.Split(v, ",") + for _, p := range parts { + p = strings.TrimSpace(p) + if p == "" { + continue + } + if tv, err := strconv.ParseUint(p, 10, 64); err == nil { + c.Tokens = append(c.Tokens, tv) + } + } + } if v := os.Getenv("INP2PS_CERT"); v != "" { c.CertFile = v } @@ -96,8 +110,8 @@ func (c *ServerConfig) FillFromEnv() { } func (c *ServerConfig) Validate() error { - if c.Token == 0 { - return fmt.Errorf("token is required (INP2PS_TOKEN or -token)") + if c.Token == 0 && len(c.Tokens) == 0 { + return fmt.Errorf("token is required (INP2PS_TOKEN or INP2PS_TOKENS)") } return nil } @@ -108,6 +122,7 @@ type ClientConfig struct { ServerPort int `json:"serverPort"` Node string `json:"node"` Token uint64 `json:"token"` + NodeSecret string `json:"nodeSecret,omitempty"` User string `json:"user,omitempty"` Insecure bool `json:"insecure"` // skip TLS verify @@ -156,8 +171,8 @@ func (c *ClientConfig) Validate() error { if c.ServerHost == "" { return fmt.Errorf("serverHost is required") } - if c.Token == 0 { - return fmt.Errorf("token is required") + if c.Token == 0 && c.NodeSecret == "" { + return fmt.Errorf("token or nodeSecret is required") } if c.Node == "" { hostname, _ := os.Hostname() diff --git a/pkg/protocol/protocol.go b/pkg/protocol/protocol.go index 1ad094f..f80e94b 100644 --- a/pkg/protocol/protocol.go +++ b/pkg/protocol/protocol.go @@ -192,6 +192,7 @@ func DecodePayload(data []byte, v interface{}) error { type LoginReq struct { Node string `json:"node"` Token uint64 `json:"token"` + NodeSecret string `json:"nodeSecret,omitempty"` User string `json:"user,omitempty"` Version string `json:"version"` NATType NATType `json:"natType"` diff --git a/pkg/signal/conn.go b/pkg/signal/conn.go index 1695da9..6029170 100644 --- a/pkg/signal/conn.go +++ b/pkg/signal/conn.go @@ -103,6 +103,33 @@ func (c *Conn) Request(mainType, subType uint16, payload interface{}, // ReadLoop reads messages and dispatches to handlers. Blocks until error or Close(). func (c *Conn) ReadLoop() error { + // keepalive to avoid idle close (read deadline = 3x ping interval) + _ = c.ws.SetReadDeadline(time.Now().Add(90 * time.Second)) + c.ws.SetPongHandler(func(string) error { + _ = c.ws.SetReadDeadline(time.Now().Add(90 * time.Second)) + return nil + }) + + // Send ping frames periodically to keep NAT/WSS alive + // Increased frequency to 10s for better resilience against network hiccups + go func() { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + for { + select { + case <-c.quit: + return + case <-ticker.C: + c.writeMu.Lock() + _ = c.ws.SetWriteDeadline(time.Now().Add(5 * time.Second)) + err := c.ws.WriteMessage(websocket.PingMessage, []byte(time.Now().Format("20060102150405"))) + if err != nil { + log.Printf("[signal] ping failed: %v, will reconnect", err) + } + c.writeMu.Unlock() + } + } + }() for { _, msg, err := c.ws.ReadMessage() if err != nil { diff --git a/web/index.html b/web/index.html index c47c462..05bd3e7 100644 --- a/web/index.html +++ b/web/index.html @@ -224,6 +224,42 @@ + + +