feat: audit api, sdwan persist, relay fallback updates
This commit is contained in:
@@ -3,6 +3,7 @@ package client
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
@@ -45,6 +46,7 @@ type Client struct {
|
||||
sdwanStop chan struct{}
|
||||
tunMu sync.Mutex
|
||||
tunFile *os.File
|
||||
sdwanPath string
|
||||
quit chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
@@ -53,6 +55,7 @@ type Client struct {
|
||||
func New(cfg config.ClientConfig) *Client {
|
||||
c := &Client{
|
||||
cfg: cfg,
|
||||
sdwanPath: "/etc/inp2p/sdwan.json",
|
||||
natType: protocol.NATUnknown,
|
||||
tunnels: make(map[string]*tunnel.Tunnel),
|
||||
sdwanStop: make(chan struct{}),
|
||||
@@ -62,7 +65,7 @@ func New(cfg config.ClientConfig) *Client {
|
||||
}
|
||||
|
||||
if cfg.RelayEnabled {
|
||||
c.relayMgr = relay.NewManager(cfg.RelayPort, true, cfg.SuperRelay, cfg.MaxRelayLoad, cfg.Token)
|
||||
c.relayMgr = relay.NewManager(cfg.RelayPort, true, cfg.SuperRelay, cfg.MaxRelayLoad, cfg.Token, cfg.ShareBandwidth)
|
||||
}
|
||||
|
||||
return c
|
||||
@@ -95,7 +98,7 @@ func (c *Client) connectAndRun() error {
|
||||
c.publicIP = natResult.PublicIP
|
||||
c.publicPort = natResult.Port1
|
||||
c.localPort = natResult.LocalPort
|
||||
log.Printf("[client] SENDING_LOGIN_TOKEN=%d NAT type=%s, publicIP=%s, publicPort=%d, localPort=%d", c.natType, c.publicIP, c.publicPort, c.localPort)
|
||||
log.Printf("[client] SENDING_LOGIN_TOKEN=%d NAT type=%s, publicIP=%s, publicPort=%d, localPort=%d", c.cfg.Token, c.natType, c.publicIP, c.publicPort, c.localPort)
|
||||
|
||||
// 2. WSS Connect
|
||||
scheme := "ws"
|
||||
@@ -130,12 +133,14 @@ func (c *Client) connectAndRun() error {
|
||||
loginReq := protocol.LoginReq{
|
||||
Node: c.cfg.Node,
|
||||
Token: c.cfg.Token,
|
||||
NodeSecret: c.cfg.NodeSecret,
|
||||
User: c.cfg.User,
|
||||
Version: config.Version,
|
||||
NATType: c.natType,
|
||||
ShareBandwidth: c.cfg.ShareBandwidth,
|
||||
RelayEnabled: c.cfg.RelayEnabled,
|
||||
SuperRelay: c.cfg.SuperRelay,
|
||||
RelayOfficial: c.cfg.RelayOfficial,
|
||||
PublicIP: c.publicIP,
|
||||
PublicPort: c.publicPort,
|
||||
}
|
||||
@@ -236,7 +241,6 @@ func (c *Client) registerHandlers() {
|
||||
return nil
|
||||
}
|
||||
log.Printf("[client] sdwan config received: gateway=%s nodes=%d mode=%s", cfg.GatewayCIDR, len(cfg.Nodes), cfg.Mode)
|
||||
_ = os.WriteFile("sdwan.json", data[protocol.HeaderSize:], 0644)
|
||||
|
||||
// apply control+data plane
|
||||
if err := c.applySDWAN(cfg); err != nil {
|
||||
@@ -396,7 +400,7 @@ func (c *Client) connectApp(app config.AppConfig) {
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("[client] connect coordination failed for %s: %v", app.PeerNode, err)
|
||||
c.tryRelay(app)
|
||||
c.tryRelay(app, "tenant")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -404,7 +408,7 @@ func (c *Client) connectApp(app config.AppConfig) {
|
||||
protocol.DecodePayload(rspData, &rsp)
|
||||
if rsp.Error != 0 {
|
||||
log.Printf("[client] connect denied: %s", rsp.Detail)
|
||||
c.tryRelay(app)
|
||||
c.tryRelay(app, "tenant")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -420,7 +424,7 @@ func (c *Client) connectApp(app config.AppConfig) {
|
||||
|
||||
if result.Error != nil {
|
||||
log.Printf("[client] punch failed for %s: %v", app.PeerNode, result.Error)
|
||||
c.tryRelay(app)
|
||||
c.tryRelay(app, "tenant")
|
||||
c.reportConnect(app, protocol.ReportConnect{
|
||||
PeerNode: app.PeerNode, Error: result.Error.Error(),
|
||||
NATType: c.natType, PeerNATType: rsp.Peer.NATType,
|
||||
@@ -448,12 +452,12 @@ func (c *Client) connectApp(app config.AppConfig) {
|
||||
}
|
||||
|
||||
// tryRelay attempts to use a relay node.
|
||||
func (c *Client) tryRelay(app config.AppConfig) {
|
||||
log.Printf("[client] trying relay for %s", app.PeerNode)
|
||||
func (c *Client) tryRelay(app config.AppConfig, mode string) {
|
||||
log.Printf("[client] trying relay(%s) for %s", mode, app.PeerNode)
|
||||
|
||||
rspData, err := c.conn.Request(
|
||||
protocol.MsgRelay, protocol.SubRelayNodeReq,
|
||||
protocol.RelayNodeReq{PeerNode: app.PeerNode},
|
||||
protocol.RelayNodeReq{PeerNode: app.PeerNode, Mode: mode},
|
||||
protocol.MsgRelay, protocol.SubRelayNodeRsp,
|
||||
10*time.Second,
|
||||
)
|
||||
@@ -465,6 +469,11 @@ func (c *Client) tryRelay(app config.AppConfig) {
|
||||
var rsp protocol.RelayNodeRsp
|
||||
protocol.DecodePayload(rspData, &rsp)
|
||||
if rsp.Error != 0 {
|
||||
if mode != "official" {
|
||||
log.Printf("[client] no relay available for %s, fallback official", app.PeerNode)
|
||||
go c.tryRelay(app, "official")
|
||||
return
|
||||
}
|
||||
log.Printf("[client] no relay available for %s", app.PeerNode)
|
||||
return
|
||||
}
|
||||
@@ -545,6 +554,19 @@ func (c *Client) reportConnect(app config.AppConfig, rc protocol.ReportConnect)
|
||||
c.conn.Write(protocol.MsgReport, protocol.SubReportConnect, rc)
|
||||
}
|
||||
|
||||
func (c *Client) writeSDWANConfig(cfg protocol.SDWANConfig) error {
|
||||
path := c.sdwanPath
|
||||
if path == "" {
|
||||
path = "/etc/inp2p/sdwan.json"
|
||||
}
|
||||
b, err := json.MarshalIndent(cfg, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_ = os.MkdirAll("/etc/inp2p", 0755)
|
||||
return os.WriteFile(path, b, 0644)
|
||||
}
|
||||
|
||||
func (c *Client) applySDWAN(cfg protocol.SDWANConfig) error {
|
||||
selfIP := ""
|
||||
for _, n := range cfg.Nodes {
|
||||
@@ -578,11 +600,24 @@ func (c *Client) applySDWAN(cfg protocol.SDWANConfig) error {
|
||||
// fallback broad route for hub mode / compatibility
|
||||
_ = runCmd("ip", "route", "replace", pfx.String(), "dev", "optun")
|
||||
|
||||
// refresh rule/table 100 for sdwan
|
||||
_ = runCmd("ip", "rule", "add", "pref", "100", "from", selfIP, "table", "100")
|
||||
_ = runCmd("ip", "route", "replace", pfx.String(), "dev", "optun", "table", "100")
|
||||
|
||||
c.sdwanMu.Lock()
|
||||
c.sdwan = cfg
|
||||
c.sdwanIP = selfIP
|
||||
c.sdwanMu.Unlock()
|
||||
|
||||
// persist sdwan config for local use/diagnostics
|
||||
if err := c.writeSDWANConfig(cfg); err != nil {
|
||||
log.Printf("[client] write sdwan.json failed: %v", err)
|
||||
}
|
||||
|
||||
// Apply subnet proxy (if configured)
|
||||
if err := c.applySubnetProxy(cfg); err != nil {
|
||||
log.Printf("[client] applySubnetProxy failed: %v", err)
|
||||
}
|
||||
// Try to start TUN reader, but don't fail SDWAN apply if it errors
|
||||
if err := c.ensureTUNReader(); err != nil {
|
||||
log.Printf("[client] ensureTUNReader failed (non-fatal): %v", err)
|
||||
@@ -591,6 +626,39 @@ func (c *Client) applySDWAN(cfg protocol.SDWANConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// applySubnetProxy configures local subnet proxying based on SDWAN config.
|
||||
func (c *Client) applySubnetProxy(cfg protocol.SDWANConfig) error {
|
||||
if len(cfg.SubnetProxies) == 0 {
|
||||
return nil
|
||||
}
|
||||
self := c.cfg.Node
|
||||
for _, sp := range cfg.SubnetProxies {
|
||||
if sp.Node != self {
|
||||
// for non-proxy nodes, add route to virtualCIDR via proxy node IP
|
||||
proxyIP := ""
|
||||
for _, n := range cfg.Nodes {
|
||||
if n.Node == sp.Node {
|
||||
proxyIP = strings.TrimSpace(n.IP)
|
||||
break
|
||||
}
|
||||
}
|
||||
if proxyIP == "" {
|
||||
continue
|
||||
}
|
||||
_ = runCmd("ip", "route", "replace", sp.VirtualCIDR, "via", proxyIP, "dev", "optun")
|
||||
continue
|
||||
}
|
||||
// This node is the proxy
|
||||
_ = runCmd("sysctl", "-w", "net.ipv4.ip_forward=1")
|
||||
// map virtualCIDR -> localCIDR (NETMAP)
|
||||
if sp.VirtualCIDR != "" && sp.LocalCIDR != "" {
|
||||
_ = runCmd("iptables", "-t", "nat", "-A", "PREROUTING", "-d", sp.VirtualCIDR, "-j", "NETMAP", "--to", sp.LocalCIDR)
|
||||
_ = runCmd("iptables", "-t", "nat", "-A", "POSTROUTING", "-s", sp.LocalCIDR, "-j", "MASQUERADE")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) ensureTUNReader() error {
|
||||
c.tunMu.Lock()
|
||||
defer c.tunMu.Unlock()
|
||||
@@ -637,13 +705,13 @@ func (c *Client) tunReadLoop() {
|
||||
if f == nil {
|
||||
return
|
||||
}
|
||||
n, err := f.Read(buf)
|
||||
n, err := unix.Read(int(f.Fd()), buf)
|
||||
if err != nil {
|
||||
if c.IsStopping() {
|
||||
return
|
||||
}
|
||||
// Log only real errors, not EOF or timeout
|
||||
if err.Error() != "EOF" && err.Error() != "resource temporarily unavailable" {
|
||||
// Ignore transient errors
|
||||
if err != unix.EINTR && err != unix.EAGAIN {
|
||||
log.Printf("[client] tun read error: %v", err)
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
56
internal/server/admin_settings.go
Normal file
56
internal/server/admin_settings.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// GET /api/v1/admin/settings
|
||||
// POST /api/v1/admin/settings {key,value}
|
||||
func (s *Server) HandleAdminSettings(w http.ResponseWriter, r *http.Request) {
|
||||
if s.store == nil {
|
||||
writeJSON(w, http.StatusInternalServerError, `{"error":1,"message":"store not ready"}`)
|
||||
return
|
||||
}
|
||||
if r.Method == http.MethodGet {
|
||||
settings, err := s.store.ListSettings()
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, `{"error":1,"message":"list settings failed"}`)
|
||||
return
|
||||
}
|
||||
b, _ := json.Marshal(map[string]any{"error": 0, "settings": settings})
|
||||
writeJSON(w, http.StatusOK, string(b))
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodPost {
|
||||
writeJSON(w, http.StatusMethodNotAllowed, `{"error":1,"message":"method not allowed"}`)
|
||||
return
|
||||
}
|
||||
var req struct {
|
||||
Key string `json:"key"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Key == "" {
|
||||
writeJSON(w, http.StatusBadRequest, `{"error":1,"message":"bad request"}`)
|
||||
return
|
||||
}
|
||||
// allowlist
|
||||
switch req.Key {
|
||||
case "advanced_impersonate", "advanced_force_network", "advanced_cross_tenant":
|
||||
default:
|
||||
writeJSON(w, http.StatusBadRequest, `{"error":1,"message":"invalid key"}`)
|
||||
return
|
||||
}
|
||||
if req.Value == "" {
|
||||
req.Value = "0"
|
||||
}
|
||||
if err := s.store.SetSetting(req.Key, req.Value); err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, `{"error":1,"message":"set failed"}`)
|
||||
return
|
||||
}
|
||||
if ac := GetAccessContext(r); ac != nil {
|
||||
_ = s.store.AddAuditLog(ac.Kind, fmt.Sprintf("%d", ac.UserID), "setting_change", "setting", req.Key, req.Value, r.RemoteAddr)
|
||||
}
|
||||
writeJSON(w, http.StatusOK, `{"error":0,"message":"ok"}`)
|
||||
}
|
||||
40
internal/server/audit_api.go
Normal file
40
internal/server/audit_api.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// GET /api/v1/admin/audit?tenant=3&limit=50&offset=0
|
||||
func (s *Server) HandleAdminAudit(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
limit := 50
|
||||
offset := 0
|
||||
if v := r.URL.Query().Get("limit"); v != "" {
|
||||
if i, err := strconv.Atoi(v); err == nil && i > 0 && i <= 500 {
|
||||
limit = i
|
||||
}
|
||||
}
|
||||
if v := r.URL.Query().Get("offset"); v != "" {
|
||||
if i, err := strconv.Atoi(v); err == nil && i >= 0 {
|
||||
offset = i
|
||||
}
|
||||
}
|
||||
tenantID := int64(0)
|
||||
if v := r.URL.Query().Get("tenant"); v != "" {
|
||||
if i, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
tenantID = i
|
||||
}
|
||||
}
|
||||
logs, err := s.store.ListAuditLogs(tenantID, limit, offset)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"error": 0, "logs": logs})
|
||||
}
|
||||
@@ -26,6 +26,17 @@ func (s *Server) ResolveAccess(r *http.Request, masterToken uint64) (*AccessCont
|
||||
return s.ResolveTenantAccessToken(tok)
|
||||
}
|
||||
|
||||
func GetAccessContext(r *http.Request) *AccessContext {
|
||||
v := r.Context().Value(ServerCtxKeyAccess{})
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
if ac, ok := v.(*AccessContext); ok {
|
||||
return ac
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) ResolveTenantAccessToken(tok string) (*AccessContext, bool) {
|
||||
if tok == "" || s.store == nil {
|
||||
return nil, false
|
||||
|
||||
@@ -3,6 +3,7 @@ package server
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/openp2p-cn/inp2p/pkg/auth"
|
||||
@@ -68,6 +69,19 @@ func (s *Server) HandleConnectReq(from *NodeInfo, req protocol.ConnectReq) error
|
||||
return nil
|
||||
}
|
||||
|
||||
// Debug: force relay path if explicit env set
|
||||
if os.Getenv("INP2P_FORCE_RELAY") == "1" {
|
||||
log.Printf("[coord] %s → %s: force relay requested", from.Name, to.Name)
|
||||
from.Conn.Write(protocol.MsgPush, protocol.SubPushConnectRsp, protocol.ConnectRsp{
|
||||
Error: 0,
|
||||
From: to.Name,
|
||||
To: from.Name,
|
||||
Peer: toParams,
|
||||
Detail: "punch-failed",
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Push PunchStart to BOTH sides simultaneously
|
||||
punchID := fmt.Sprintf("%s-%s-%d", from.Name, to.Name, time.Now().UnixMilli())
|
||||
|
||||
|
||||
7
internal/server/ctx.go
Normal file
7
internal/server/ctx.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package server
|
||||
|
||||
// ctx key alias for main
|
||||
// NOTE: main sets this type to avoid import cycles
|
||||
// use GetAccessContext to retrieve
|
||||
|
||||
type ServerCtxKeyAccess struct{}
|
||||
Binary file not shown.
Binary file not shown.
@@ -121,5 +121,21 @@ func normalizeSDWAN(c protocol.SDWANConfig) protocol.SDWANConfig {
|
||||
c.Nodes = append(c.Nodes, protocol.SDWANNode{Node: node, IP: ip})
|
||||
}
|
||||
sort.Slice(c.Nodes, func(i, j int) bool { return c.Nodes[i].Node < c.Nodes[j].Node })
|
||||
|
||||
// de-dup subnet proxies by node+cidr
|
||||
if len(c.SubnetProxies) > 0 {
|
||||
m2 := make(map[string]protocol.SubnetProxy)
|
||||
for _, sp := range c.SubnetProxies {
|
||||
if sp.Node == "" || sp.VirtualCIDR == "" || sp.LocalCIDR == "" {
|
||||
continue
|
||||
}
|
||||
key := sp.Node + "|" + sp.VirtualCIDR + "|" + sp.LocalCIDR
|
||||
m2[key] = sp
|
||||
}
|
||||
c.SubnetProxies = c.SubnetProxies[:0]
|
||||
for _, sp := range m2 {
|
||||
c.SubnetProxies = append(c.SubnetProxies, sp)
|
||||
}
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/netip"
|
||||
|
||||
@@ -24,7 +25,7 @@ func (s *Server) SetSDWAN(cfg protocol.SDWANConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) SetSDWANTenant(tenantID int64, cfg protocol.SDWANConfig) error {
|
||||
func (s *Server) SetSDWANTenant(tenantID int64, cfg protocol.SDWANConfig, actorType, actorID, ip string) error {
|
||||
if cfg.Mode == "hub" {
|
||||
if cfg.HubNode == "" {
|
||||
return errors.New("hub mode requires hubNode")
|
||||
@@ -37,6 +38,10 @@ func (s *Server) SetSDWANTenant(tenantID int64, cfg protocol.SDWANConfig) error
|
||||
if err := s.sdwan.saveTenant(tenantID, cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
if actorType != "" && s.store != nil {
|
||||
detail := fmt.Sprintf("mode=%s hub=%s nodes=%d subnetProxies=%d", cfg.Mode, cfg.HubNode, len(cfg.Nodes), len(cfg.SubnetProxies))
|
||||
_ = s.store.AddAuditLog(actorType, actorID, "sdwan_update", "tenant", fmt.Sprintf("%d", tenantID), detail, ip)
|
||||
}
|
||||
s.broadcastSDWANTenant(tenantID, s.sdwan.getTenant(tenantID))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -33,6 +33,7 @@ type NodeInfo struct {
|
||||
ShareBandwidth int `json:"shareBandwidth"`
|
||||
RelayEnabled bool `json:"relayEnabled"`
|
||||
SuperRelay bool `json:"superRelay"`
|
||||
RelayOfficial bool `json:"relayOfficial"`
|
||||
HasIPv4 int `json:"hasIPv4"`
|
||||
IPv6 string `json:"ipv6"`
|
||||
LoginTime time.Time `json:"loginTime"`
|
||||
@@ -78,12 +79,12 @@ func New(cfg config.ServerConfig) *Server {
|
||||
if err != nil {
|
||||
log.Printf("[server] open store failed: %v", err)
|
||||
} else {
|
||||
// bootstrap default admin/admin in tenant 1
|
||||
// bootstrap default tenant if missing
|
||||
if _, gErr := st.GetTenantByID(1); gErr != nil {
|
||||
if _, _, _, cErr := st.CreateTenantWithUsers("default", "admin", "admin"); cErr != nil {
|
||||
log.Printf("[server] bootstrap default tenant failed: %v", cErr)
|
||||
} else {
|
||||
log.Printf("[server] bootstrap default tenant created (tenant=1, admin/admin)")
|
||||
log.Printf("[server] bootstrap default tenant created (tenant=1)")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -160,7 +161,7 @@ func (s *Server) GetOnlineNodesByTenant(tenantID int64) []*NodeInfo {
|
||||
}
|
||||
|
||||
// GetRelayNodes returns nodes that can serve as relay.
|
||||
// Priority: same-user private relay → super relay
|
||||
// Priority: same-user private relay → super relay (exclude official relays)
|
||||
func (s *Server) GetRelayNodes(forUser string, excludeNodes ...string) []*NodeInfo {
|
||||
excludeSet := make(map[string]bool)
|
||||
for _, n := range excludeNodes {
|
||||
@@ -172,7 +173,7 @@ func (s *Server) GetRelayNodes(forUser string, excludeNodes ...string) []*NodeIn
|
||||
|
||||
var privateRelays, superRelays []*NodeInfo
|
||||
for _, n := range s.nodes {
|
||||
if !n.IsOnline() || excludeSet[n.Name] || !n.RelayEnabled {
|
||||
if !n.IsOnline() || excludeSet[n.Name] || !n.RelayEnabled || n.RelayOfficial {
|
||||
continue
|
||||
}
|
||||
if n.User == forUser {
|
||||
@@ -200,13 +201,33 @@ func (s *Server) GetRelayNodesByTenant(tenantID int64, excludeNodes ...string) [
|
||||
if !n.IsOnline() || excludeSet[n.Name] {
|
||||
continue
|
||||
}
|
||||
if n.TenantID == tenantID && (n.RelayEnabled || n.SuperRelay) {
|
||||
if n.TenantID == tenantID && (n.RelayEnabled || n.SuperRelay) && !n.RelayOfficial {
|
||||
relays = append(relays, n)
|
||||
}
|
||||
}
|
||||
return relays
|
||||
}
|
||||
|
||||
// GetOfficialRelays returns official relay nodes (global pool)
|
||||
func (s *Server) GetOfficialRelays(excludeNodes ...string) []*NodeInfo {
|
||||
excludeSet := make(map[string]bool)
|
||||
for _, n := range excludeNodes {
|
||||
excludeSet[n] = true
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
var relays []*NodeInfo
|
||||
for _, n := range s.nodes {
|
||||
if !n.IsOnline() || excludeSet[n.Name] || !n.RelayEnabled || !n.RelayOfficial {
|
||||
continue
|
||||
}
|
||||
relays = append(relays, n)
|
||||
}
|
||||
return relays
|
||||
}
|
||||
|
||||
// HandleWS is the WebSocket handler for client connections.
|
||||
func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) {
|
||||
ws, err := s.upgrader.Upgrade(w, r, nil)
|
||||
@@ -287,6 +308,7 @@ func (s *Server) HandleWS(w http.ResponseWriter, r *http.Request) {
|
||||
ShareBandwidth: loginReq.ShareBandwidth,
|
||||
RelayEnabled: loginReq.RelayEnabled,
|
||||
SuperRelay: loginReq.SuperRelay,
|
||||
RelayOfficial: loginReq.RelayOfficial,
|
||||
PublicIP: loginReq.PublicIP,
|
||||
PublicPort: loginReq.PublicPort,
|
||||
LoginTime: time.Now(),
|
||||
@@ -464,23 +486,68 @@ func (s *Server) registerHandlers(conn *signal.Conn, node *NodeInfo) {
|
||||
|
||||
// handleRelayNodeReq finds and returns the best relay node.
|
||||
func (s *Server) handleRelayNodeReq(conn *signal.Conn, requester *NodeInfo, req protocol.RelayNodeReq) error {
|
||||
relays := s.GetRelayNodes(requester.User, requester.Name, req.PeerNode)
|
||||
|
||||
if len(relays) == 0 {
|
||||
mode := "tenant"
|
||||
if req.Mode == "official" {
|
||||
mode = "official"
|
||||
official := s.GetOfficialRelays(requester.Name, req.PeerNode)
|
||||
if len(official) == 0 {
|
||||
return conn.Write(protocol.MsgRelay, protocol.SubRelayNodeRsp, protocol.RelayNodeRsp{Error: 1})
|
||||
}
|
||||
relay := official[0]
|
||||
totp := auth.GenTOTP(relay.Token, time.Now().Unix())
|
||||
log.Printf("[server] relay selected: %s (%s) for %s → %s", relay.Name, mode, requester.Name, req.PeerNode)
|
||||
return conn.Write(protocol.MsgRelay, protocol.SubRelayNodeRsp, protocol.RelayNodeRsp{
|
||||
Error: 1,
|
||||
RelayName: relay.Name,
|
||||
RelayIP: relay.PublicIP,
|
||||
RelayPort: config.DefaultRelayPort,
|
||||
RelayToken: totp,
|
||||
Mode: mode,
|
||||
Error: 0,
|
||||
})
|
||||
}
|
||||
// prefer hub relay if sdwan mode=hub
|
||||
if requester.TenantID > 0 && s.sdwan != nil {
|
||||
cfg := s.sdwan.getTenant(requester.TenantID)
|
||||
if cfg.Mode == "hub" && cfg.HubNode != "" && cfg.HubNode != requester.Name && cfg.HubNode != req.PeerNode {
|
||||
hub := s.GetNode(cfg.HubNode)
|
||||
if hub != nil && hub.IsOnline() && hub.TenantID == requester.TenantID && hub.RelayEnabled {
|
||||
log.Printf("[server] relay selected: %s (hub) for %s → %s", hub.Name, requester.Name, req.PeerNode)
|
||||
totp := auth.GenTOTP(hub.Token, time.Now().Unix())
|
||||
return conn.Write(protocol.MsgRelay, protocol.SubRelayNodeRsp, protocol.RelayNodeRsp{
|
||||
RelayName: hub.Name,
|
||||
RelayIP: hub.PublicIP,
|
||||
RelayPort: config.DefaultRelayPort,
|
||||
RelayToken: totp,
|
||||
Mode: "private",
|
||||
Error: 0,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
// prefer same-tenant relays, exclude requester and peer
|
||||
relays := s.GetRelayNodesByTenant(requester.TenantID, requester.Name, req.PeerNode)
|
||||
if len(relays) == 0 {
|
||||
// fallback to same-user (private) then super
|
||||
relays = s.GetRelayNodes(requester.User, requester.Name, req.PeerNode)
|
||||
if len(relays) == 0 {
|
||||
// final fallback: official relays
|
||||
official := s.GetOfficialRelays(requester.Name, req.PeerNode)
|
||||
if len(official) == 0 {
|
||||
return conn.Write(protocol.MsgRelay, protocol.SubRelayNodeRsp, protocol.RelayNodeRsp{Error: 1})
|
||||
}
|
||||
relays = official
|
||||
mode = "official"
|
||||
} else if relays[0].User != requester.User {
|
||||
mode = "super"
|
||||
} else {
|
||||
mode = "private"
|
||||
}
|
||||
}
|
||||
|
||||
// Pick the first (best) relay
|
||||
relay := relays[0]
|
||||
totp := auth.GenTOTP(relay.Token, time.Now().Unix())
|
||||
|
||||
mode := "private"
|
||||
if relay.User != requester.User {
|
||||
mode = "super"
|
||||
}
|
||||
|
||||
log.Printf("[server] relay selected: %s (%s) for %s → %s", relay.Name, mode, requester.Name, req.PeerNode)
|
||||
|
||||
return conn.Write(protocol.MsgRelay, protocol.SubRelayNodeRsp, protocol.RelayNodeRsp{
|
||||
@@ -510,6 +577,7 @@ func (s *Server) PushConnect(fromNode *NodeInfo, toNodeName string, app protocol
|
||||
FromIP: fromNode.PublicIP,
|
||||
Peer: protocol.PunchParams{
|
||||
IP: fromNode.PublicIP,
|
||||
Port: fromNode.PublicPort,
|
||||
NATType: fromNode.NATType,
|
||||
HasIPv4: fromNode.HasIPv4,
|
||||
Token: auth.GenTOTP(fromNode.Token, time.Now().Unix()),
|
||||
@@ -598,6 +666,9 @@ func (s *Server) StartCleanup() {
|
||||
cfg.Mode = "mesh"
|
||||
cfg.HubNode = ""
|
||||
_ = s.sdwan.saveTenant(tid, cfg)
|
||||
if s.store != nil {
|
||||
_ = s.store.AddAuditLog("system", "0", "sdwan_update", "tenant", fmt.Sprintf("%d", tid), "hub->mesh (hub offline)", "")
|
||||
}
|
||||
s.broadcastSDWANTenant(tid, cfg)
|
||||
log.Printf("[sdwan] hub offline, auto fallback to mesh (tenant=%d)", tid)
|
||||
}
|
||||
|
||||
@@ -62,6 +62,9 @@ func (s *Server) HandleAdminCreateTenant(w http.ResponseWriter, r *http.Request)
|
||||
status = 1
|
||||
}
|
||||
_ = s.store.UpdateTenantStatus(id, status)
|
||||
if ac := GetAccessContext(r); ac != nil {
|
||||
_ = s.store.AddAuditLog(ac.Kind, fmt.Sprintf("%d", ac.UserID), "tenant_status", "tenant", fmt.Sprintf("%d", id), fmt.Sprintf("status=%d", status), r.RemoteAddr)
|
||||
}
|
||||
writeJSON(w, http.StatusOK, `{"error":0,"message":"ok"}`)
|
||||
return
|
||||
}
|
||||
@@ -165,6 +168,9 @@ func (s *Server) HandleAdminCreateAPIKey(w http.ResponseWriter, r *http.Request)
|
||||
status = 1
|
||||
}
|
||||
_ = s.store.UpdateAPIKeyStatus(keyID, status)
|
||||
if ac := GetAccessContext(r); ac != nil {
|
||||
_ = s.store.AddAuditLog(ac.Kind, fmt.Sprintf("%d", ac.UserID), "apikey_status", "apikey", fmt.Sprintf("%d", keyID), fmt.Sprintf("status=%d", status), r.RemoteAddr)
|
||||
}
|
||||
writeJSON(w, http.StatusOK, `{"error":0,"message":"ok"}`)
|
||||
return
|
||||
}
|
||||
@@ -191,6 +197,9 @@ func (s *Server) HandleAdminCreateAPIKey(w http.ResponseWriter, r *http.Request)
|
||||
}{0, "ok", key, tenantID}
|
||||
b, _ := json.Marshal(resp)
|
||||
writeJSON(w, http.StatusOK, string(b))
|
||||
if ac := GetAccessContext(r); ac != nil {
|
||||
_ = s.store.AddAuditLog(ac.Kind, fmt.Sprintf("%d", ac.UserID), "apikey_create", "tenant", fmt.Sprintf("%d", tenantID), req.Scope, r.RemoteAddr)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) HandleTenantEnroll(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
176
internal/server/user_api.go
Normal file
176
internal/server/user_api.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// Admin user management
|
||||
// GET /api/v1/admin/users?tenant=1
|
||||
// POST /api/v1/admin/users {tenant, role, email, password}
|
||||
// POST /api/v1/admin/users/{id}?status=0|1
|
||||
// POST /api/v1/admin/users/{id}/password {password}
|
||||
func IsValidGlobalUsername(v string) bool {
|
||||
if len(v) < 6 {
|
||||
return false
|
||||
}
|
||||
for _, r := range v {
|
||||
if r > unicode.MaxASCII || !unicode.IsLetter(r) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Server) HandleAdminUsers(w http.ResponseWriter, r *http.Request) {
|
||||
if s.store == nil {
|
||||
writeJSON(w, http.StatusInternalServerError, `{"error":1,"message":"store not ready"}`)
|
||||
return
|
||||
}
|
||||
// list
|
||||
if r.Method == http.MethodGet {
|
||||
tenantID := int64(0)
|
||||
_ = r.ParseForm()
|
||||
fmt.Sscanf(r.Form.Get("tenant"), "%d", &tenantID)
|
||||
if tenantID <= 0 {
|
||||
writeJSON(w, http.StatusBadRequest, `{"error":1,"message":"tenant required"}`)
|
||||
return
|
||||
}
|
||||
users, err := s.store.ListUsers(tenantID)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, `{"error":1,"message":"list users failed"}`)
|
||||
return
|
||||
}
|
||||
// strip password hash
|
||||
out := make([]map[string]any, 0, len(users))
|
||||
for _, u := range users {
|
||||
out = append(out, map[string]any{
|
||||
"id": u.ID,
|
||||
"tenant_id": u.TenantID,
|
||||
"role": u.Role,
|
||||
"email": u.Email,
|
||||
"status": u.Status,
|
||||
"created_at": u.CreatedAt,
|
||||
})
|
||||
}
|
||||
resp := struct {
|
||||
Error int `json:"error"`
|
||||
Message string `json:"message"`
|
||||
Users interface{} `json:"users"`
|
||||
}{0, "ok", out}
|
||||
b, _ := json.Marshal(resp)
|
||||
writeJSON(w, http.StatusOK, string(b))
|
||||
return
|
||||
}
|
||||
|
||||
// update status or password
|
||||
if r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/admin/users/") {
|
||||
parts := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
|
||||
var id int64
|
||||
// /api/v1/admin/users/{id}/password
|
||||
if strings.HasSuffix(r.URL.Path, "/password") && len(parts) >= 5 {
|
||||
_, _ = fmt.Sscanf(parts[len(parts)-2], "%d", &id)
|
||||
} else if strings.HasSuffix(r.URL.Path, "/delete") && len(parts) >= 5 {
|
||||
_, _ = fmt.Sscanf(parts[len(parts)-2], "%d", &id)
|
||||
} else {
|
||||
_, _ = fmt.Sscanf(parts[len(parts)-1], "%d", &id)
|
||||
}
|
||||
if id <= 0 {
|
||||
writeJSON(w, http.StatusBadRequest, `{"error":1,"message":"bad request"}`)
|
||||
return
|
||||
}
|
||||
// /password
|
||||
if strings.HasSuffix(r.URL.Path, "/password") {
|
||||
var req struct {
|
||||
Password string `json:"password"`
|
||||
}
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
if req.Password == "" || len(req.Password) < 6 {
|
||||
writeJSON(w, http.StatusBadRequest, `{"error":1,"message":"password too short"}`)
|
||||
return
|
||||
}
|
||||
if err := s.store.UpdateUserPassword(id, req.Password); err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, `{"error":1,"message":"update password failed"}`)
|
||||
return
|
||||
}
|
||||
if ac := GetAccessContext(r); ac != nil {
|
||||
_ = s.store.AddAuditLog(ac.Kind, fmt.Sprintf("%d", ac.UserID), "user_password", "user", fmt.Sprintf("%d", id), "", r.RemoteAddr)
|
||||
}
|
||||
writeJSON(w, http.StatusOK, `{"error":0,"message":"ok"}`)
|
||||
return
|
||||
}
|
||||
// delete
|
||||
if strings.HasSuffix(r.URL.Path, "/delete") {
|
||||
if err := s.store.UpdateUserStatus(id, 0); err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, `{"error":1,"message":"delete failed"}`)
|
||||
return
|
||||
}
|
||||
if ac := GetAccessContext(r); ac != nil {
|
||||
_ = s.store.AddAuditLog(ac.Kind, fmt.Sprintf("%d", ac.UserID), "user_delete", "user", fmt.Sprintf("%d", id), "", r.RemoteAddr)
|
||||
}
|
||||
writeJSON(w, http.StatusOK, `{"error":0,"message":"ok"}`)
|
||||
return
|
||||
}
|
||||
// status
|
||||
st := r.URL.Query().Get("status")
|
||||
if st == "" {
|
||||
writeJSON(w, http.StatusBadRequest, `{"error":1,"message":"status required"}`)
|
||||
return
|
||||
}
|
||||
status := 0
|
||||
if st == "1" {
|
||||
status = 1
|
||||
}
|
||||
if err := s.store.UpdateUserStatus(id, status); err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, `{"error":1,"message":"update status failed"}`)
|
||||
return
|
||||
}
|
||||
if ac := GetAccessContext(r); ac != nil {
|
||||
_ = s.store.AddAuditLog(ac.Kind, fmt.Sprintf("%d", ac.UserID), "user_status", "user", fmt.Sprintf("%d", id), fmt.Sprintf("status=%d", status), r.RemoteAddr)
|
||||
}
|
||||
writeJSON(w, http.StatusOK, `{"error":0,"message":"ok"}`)
|
||||
return
|
||||
}
|
||||
|
||||
// create
|
||||
if r.Method != http.MethodPost {
|
||||
writeJSON(w, http.StatusMethodNotAllowed, `{"error":1,"message":"method not allowed"}`)
|
||||
return
|
||||
}
|
||||
var req struct {
|
||||
TenantID int64 `json:"tenant"`
|
||||
Role string `json:"role"`
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.TenantID <= 0 || req.Role == "" || req.Email == "" || req.Password == "" {
|
||||
writeJSON(w, http.StatusBadRequest, `{"error":1,"message":"bad request"}`)
|
||||
return
|
||||
}
|
||||
if len(req.Password) < 6 {
|
||||
writeJSON(w, http.StatusBadRequest, `{"error":1,"message":"password too short"}`)
|
||||
return
|
||||
}
|
||||
if !IsValidGlobalUsername(req.Email) {
|
||||
writeJSON(w, http.StatusBadRequest, `{"error":1,"message":"username must be letters only and >=6"}`)
|
||||
return
|
||||
}
|
||||
if exists, err := s.store.UserEmailExistsGlobal(req.Email); err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, `{"error":1,"message":"check user failed"}`)
|
||||
return
|
||||
} else if exists {
|
||||
writeJSON(w, http.StatusBadRequest, `{"error":1,"message":"username exists"}`)
|
||||
return
|
||||
}
|
||||
if _, err := s.store.CreateUser(req.TenantID, req.Role, req.Email, req.Password, 1); err != nil {
|
||||
writeJSON(w, http.StatusInternalServerError, `{"error":1,"message":"create user failed"}`)
|
||||
return
|
||||
}
|
||||
if ac := GetAccessContext(r); ac != nil {
|
||||
_ = s.store.AddAuditLog(ac.Kind, fmt.Sprintf("%d", ac.UserID), "user_create", "tenant", fmt.Sprintf("%d", req.TenantID), req.Email, r.RemoteAddr)
|
||||
}
|
||||
writeJSON(w, http.StatusOK, `{"error":0,"message":"ok"}`)
|
||||
}
|
||||
@@ -27,6 +27,18 @@ type Tenant struct {
|
||||
CreatedAt int64
|
||||
}
|
||||
|
||||
type AuditLog struct {
|
||||
ID int64 `json:"id"`
|
||||
ActorType string `json:"actor_type"`
|
||||
ActorID string `json:"actor_id"`
|
||||
Action string `json:"action"`
|
||||
TargetType string `json:"target_type"`
|
||||
TargetID string `json:"target_id"`
|
||||
Detail string `json:"detail"`
|
||||
IP string `json:"ip"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID int64
|
||||
TenantID int64
|
||||
@@ -102,6 +114,12 @@ func Open(dbPath string) (*Store, error) {
|
||||
if err := s.ensureSubnetPool(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.ensureSettings(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.backfillNodeIdentity(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
@@ -182,6 +200,11 @@ func (s *Store) migrate() error {
|
||||
ip TEXT,
|
||||
created_at INTEGER NOT NULL
|
||||
);`,
|
||||
`CREATE TABLE IF NOT EXISTS system_settings (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT,
|
||||
updated_at INTEGER NOT NULL
|
||||
);`,
|
||||
`CREATE TABLE IF NOT EXISTS subnet_pool (
|
||||
subnet TEXT PRIMARY KEY,
|
||||
status INTEGER NOT NULL DEFAULT 0,
|
||||
@@ -282,11 +305,11 @@ func (s *Store) CreateTenantWithUsers(name, adminPassword, operatorPassword stri
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
admin, err := s.CreateUser(ten.ID, "admin", "admin@local", adminPassword, 1)
|
||||
admin, err := s.CreateUser(ten.ID, "admin", "admin@"+name, adminPassword, 1)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
op, err := s.CreateUser(ten.ID, "operator", "operator@local", operatorPassword, 1)
|
||||
op, err := s.CreateUser(ten.ID, "operator", "operator@"+name, operatorPassword, 1)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@@ -494,6 +517,88 @@ func (s *Store) IncEnrollAttempt(code string) {
|
||||
_, _ = s.DB.Exec(`UPDATE enroll_tokens SET attempts=attempts+1 WHERE token_hash=?`, h)
|
||||
}
|
||||
|
||||
func (s *Store) ensureSettings() error {
|
||||
defaults := map[string]string{
|
||||
"advanced_impersonate": "0",
|
||||
"advanced_force_network": "0",
|
||||
"advanced_cross_tenant": "0",
|
||||
}
|
||||
now := time.Now().Unix()
|
||||
for k, v := range defaults {
|
||||
_, _ = s.DB.Exec(`INSERT OR IGNORE INTO system_settings(key,value,updated_at) VALUES(?,?,?)`, k, v, now)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) GetSetting(key string) (string, bool, error) {
|
||||
row := s.DB.QueryRow(`SELECT value FROM system_settings WHERE key=?`, key)
|
||||
var v string
|
||||
if err := row.Scan(&v); err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
return v, true, nil
|
||||
}
|
||||
|
||||
func (s *Store) ListSettings() (map[string]string, error) {
|
||||
rows, err := s.DB.Query(`SELECT key,value FROM system_settings`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
out := map[string]string{}
|
||||
for rows.Next() {
|
||||
var k, v string
|
||||
if err := rows.Scan(&k, &v); err == nil {
|
||||
out[k] = v
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *Store) SetSetting(key, value string) error {
|
||||
now := time.Now().Unix()
|
||||
_, err := s.DB.Exec(`INSERT INTO system_settings(key,value,updated_at) VALUES(?,?,?) ON CONFLICT(key) DO UPDATE SET value=excluded.value, updated_at=excluded.updated_at`, key, value, now)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) AddAuditLog(actorType, actorID, action, targetType, targetID, detail, ip string) error {
|
||||
now := time.Now().Unix()
|
||||
_, err := s.DB.Exec(`INSERT INTO audit_logs(actor_type,actor_id,action,target_type,target_id,detail,ip,created_at) VALUES(?,?,?,?,?,?,?,?)`, actorType, actorID, action, targetType, targetID, detail, ip, now)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) ListAuditLogs(tenantID int64, limit, offset int) ([]AuditLog, error) {
|
||||
q := `SELECT id,actor_type,actor_id,action,target_type,target_id,detail,ip,created_at FROM audit_logs`
|
||||
args := []any{}
|
||||
if tenantID > 0 {
|
||||
// limit to logs related to this tenant
|
||||
q += ` WHERE (target_type='tenant' AND target_id=?)`
|
||||
args = append(args, fmt.Sprintf("%d", tenantID))
|
||||
}
|
||||
q += ` ORDER BY id DESC`
|
||||
if limit > 0 {
|
||||
q += ` LIMIT ?`
|
||||
args = append(args, limit)
|
||||
}
|
||||
if offset > 0 {
|
||||
q += ` OFFSET ?`
|
||||
args = append(args, offset)
|
||||
}
|
||||
rows, err := s.DB.Query(q, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
out := []AuditLog{}
|
||||
for rows.Next() {
|
||||
var a AuditLog
|
||||
if err := rows.Scan(&a.ID, &a.ActorType, &a.ActorID, &a.Action, &a.TargetType, &a.TargetID, &a.Detail, &a.IP, &a.CreatedAt); err == nil {
|
||||
out = append(out, a)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// ListTenants returns all tenants (admin)
|
||||
func (s *Store) ListTenants() ([]Tenant, error) {
|
||||
rows, err := s.DB.Query(`SELECT id,name,status,subnet,created_at FROM tenants ORDER BY id DESC`)
|
||||
@@ -662,6 +767,15 @@ func (s *Store) UserEmailExists(tenantID int64, email string) (bool, error) {
|
||||
return c > 0, nil
|
||||
}
|
||||
|
||||
func (s *Store) UserEmailExistsGlobal(email string) (bool, error) {
|
||||
row := s.DB.QueryRow(`SELECT COUNT(1) FROM users WHERE email=?`, email)
|
||||
var c int
|
||||
if err := row.Scan(&c); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return c > 0, nil
|
||||
}
|
||||
|
||||
func (s *Store) VerifyUserPassword(tenantID int64, email, password string) (*User, error) {
|
||||
u, err := s.GetUserByEmail(tenantID, email)
|
||||
if err != nil {
|
||||
@@ -682,6 +796,21 @@ func (s *Store) VerifyUserPassword(tenantID int64, email, password string) (*Use
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (s *Store) VerifyUserPasswordGlobal(email, password string) (*User, error) {
|
||||
row := s.DB.QueryRow(`SELECT id,tenant_id,role,email,password_hash,status,created_at FROM users WHERE email=? ORDER BY id LIMIT 1`, email)
|
||||
var u User
|
||||
if err := row.Scan(&u.ID, &u.TenantID, &u.Role, &u.Email, &u.PasswordHash, &u.Status, &u.CreatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if u.PasswordHash == "" {
|
||||
return nil, errors.New("password not set")
|
||||
}
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(u.PasswordHash), []byte(password)); err != nil {
|
||||
return nil, errors.New("invalid password")
|
||||
}
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
func (s *Store) CreateSessionToken(userID, tenantID int64, role string, ttl time.Duration) (string, int64, error) {
|
||||
tok := randToken()
|
||||
h := hashTokenString(tok)
|
||||
@@ -742,6 +871,11 @@ func (s *Store) UpdateUserPassword(id int64, password string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Store) UpdateUserEmail(id int64, email string) error {
|
||||
_, err := s.DB.Exec(`UPDATE users SET email=? WHERE id=?`, email, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func hashTokenBytes(b []byte) string {
|
||||
h := sha256.Sum256(b)
|
||||
return hex.EncodeToString(h[:])
|
||||
@@ -767,5 +901,40 @@ func randUUID() string {
|
||||
)
|
||||
}
|
||||
|
||||
func (s *Store) backfillNodeIdentity() error {
|
||||
rows, err := s.DB.Query(`SELECT id,tenant_id,node_uuid,virtual_ip FROM nodes ORDER BY id`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
type rowNode struct {
|
||||
id int64
|
||||
tenantID int64
|
||||
uuid string
|
||||
vip string
|
||||
}
|
||||
var list []rowNode
|
||||
for rows.Next() {
|
||||
var r rowNode
|
||||
if err := rows.Scan(&r.id, &r.tenantID, &r.uuid, &r.vip); err == nil {
|
||||
list = append(list, r)
|
||||
}
|
||||
}
|
||||
for _, n := range list {
|
||||
if n.uuid == "" {
|
||||
if _, err := s.DB.Exec(`UPDATE nodes SET node_uuid=? WHERE id=?`, randUUID(), n.id); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(n.vip) == "" {
|
||||
vip, err := s.AllocateNodeIP(n.tenantID)
|
||||
if err == nil && vip != "" {
|
||||
_, _ = s.DB.Exec(`UPDATE nodes SET virtual_ip=? WHERE id=?`, vip, n.id)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// helper to avoid unused import (net)
|
||||
var _ = net.IPv4len
|
||||
|
||||
Reference in New Issue
Block a user