344 lines
9.3 KiB
Go
344 lines
9.3 KiB
Go
package store
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"database/sql"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"time"
|
|
|
|
_ "modernc.org/sqlite"
|
|
)
|
|
|
|
type Store struct {
|
|
DB *sql.DB
|
|
}
|
|
|
|
// Tenant is one row of the tenants table, as produced by CreateTenant and
// the tenant lookups (VerifyNodeSecret, GetTenantByToken, VerifyAPIKey).
type Tenant struct {
	// ID is tenants.id.
	ID int64

	// Name is tenants.name, which the schema declares UNIQUE.
	Name string

	// Status is tenants.status; CreateTenant stores 1.
	Status int

	// Subnet is the /24 taken from subnet_pool, e.g. "10.10.1.0/24".
	Subnet string
}
|
|
|
|
type APIKey struct {
|
|
ID int64
|
|
TenantID int64
|
|
Hash string
|
|
Scope string
|
|
Expires *time.Time
|
|
Status int
|
|
}
|
|
|
|
// NodeCredential is what CreateNodeCredential hands back for a newly
// registered node.
type NodeCredential struct {
	// NodeID is nodes.id of the inserted row.
	NodeID int64

	// NodeName is the name the node was registered under.
	NodeName string

	// Secret is the plaintext node secret. Only its SHA-256 digest is stored,
	// so this value is the only copy.
	Secret string

	// VirtualIP is not filled in by CreateNodeCredential.
	VirtualIP string

	// TenantID is the owning tenant.
	TenantID int64
}
|
|
|
|
// EnrollToken mirrors a row of the enroll_tokens table.
type EnrollToken struct {
	// ID is enroll_tokens.id.
	ID int64

	// TenantID is the tenant a node enrolling with this token joins.
	TenantID int64

	// Hash is the lowercase hex SHA-256 digest of the enrollment code.
	Hash string

	// ExpiresAt is a Unix timestamp in seconds; ConsumeEnrollToken rejects
	// the token once the current time is later than this.
	ExpiresAt int64

	// UsedAt is set to the consumption time by a successful
	// ConsumeEnrollToken; nil otherwise.
	UsedAt *int64

	// MaxAttempt caps Attempts: a token with Attempts >= MaxAttempt is
	// rejected.
	MaxAttempt int

	// Attempts counts recorded attempts.
	Attempts int

	// Status is enroll_tokens.status; only 1 is accepted as enabled.
	Status int
}
|
|
|
|
// Open opens (or creates) the SQLite database at dbPath and prepares it:
// it switches to WAL journaling, enables foreign-key enforcement, creates
// any missing tables and seeds the subnet pool on first use.
//
// If any step after sql.Open fails, the database handle is closed before
// the error is returned, so a failed Open never leaks it. Returned errors
// wrap the underlying cause (use errors.Is / errors.As).
//
// NOTE(review): SQLite's foreign_keys pragma is a per-connection setting,
// while *sql.DB is a connection pool. The PRAGMA below reaches only the
// pooled connection that happens to execute it, so other connections may
// not enforce the FOREIGN KEY clauses. Consider passing it via the DSN
// (driver-specific syntax) — confirm against modernc.org/sqlite docs.
func Open(dbPath string) (*Store, error) {
	db, err := sql.Open("sqlite", dbPath)
	if err != nil {
		return nil, fmt.Errorf("open sqlite database %q: %w", dbPath, err)
	}

	// fail releases the handle and reports which initialization step broke.
	fail := func(step string, cause error) (*Store, error) {
		_ = db.Close() // the step's error is the one worth reporting
		return nil, fmt.Errorf("%s: %w", step, cause)
	}

	if _, err := db.Exec(`PRAGMA journal_mode=WAL;`); err != nil {
		return fail("enable WAL journal mode", err)
	}
	if _, err := db.Exec(`PRAGMA foreign_keys=ON;`); err != nil {
		return fail("enable foreign keys", err)
	}

	s := &Store{DB: db}
	if err := s.migrate(); err != nil {
		return fail("migrate schema", err)
	}
	if err := s.ensureSubnetPool(); err != nil {
		return fail("seed subnet pool", err)
	}
	return s, nil
}
|
|
|
|
// migrate creates every table the store uses if it does not already exist.
//
// All CREATE TABLE statements run inside one transaction (SQLite DDL is
// transactional), so a failure leaves the schema as it was before the call.
// The returned error names the table whose statement failed.
func (s *Store) migrate() error {
	tables := []struct {
		name string
		ddl  string
	}{
		{"tenants", `CREATE TABLE IF NOT EXISTS tenants (
	id INTEGER PRIMARY KEY AUTOINCREMENT,
	name TEXT NOT NULL UNIQUE,
	status INTEGER NOT NULL DEFAULT 1,
	subnet TEXT NOT NULL UNIQUE,
	created_at INTEGER NOT NULL
);`},
		{"users", `CREATE TABLE IF NOT EXISTS users (
	id INTEGER PRIMARY KEY AUTOINCREMENT,
	tenant_id INTEGER NOT NULL,
	role TEXT NOT NULL,
	email TEXT,
	password_hash TEXT,
	status INTEGER NOT NULL DEFAULT 1,
	created_at INTEGER NOT NULL,
	FOREIGN KEY(tenant_id) REFERENCES tenants(id)
);`},
		{"api_keys", `CREATE TABLE IF NOT EXISTS api_keys (
	id INTEGER PRIMARY KEY AUTOINCREMENT,
	tenant_id INTEGER NOT NULL,
	key_hash TEXT NOT NULL UNIQUE,
	scope TEXT,
	expires_at INTEGER,
	status INTEGER NOT NULL DEFAULT 1,
	created_at INTEGER NOT NULL,
	FOREIGN KEY(tenant_id) REFERENCES tenants(id)
);`},
		{"nodes", `CREATE TABLE IF NOT EXISTS nodes (
	id INTEGER PRIMARY KEY AUTOINCREMENT,
	tenant_id INTEGER NOT NULL,
	node_name TEXT NOT NULL,
	node_pubkey TEXT,
	node_secret_hash TEXT,
	virtual_ip TEXT,
	status INTEGER NOT NULL DEFAULT 1,
	last_seen INTEGER,
	FOREIGN KEY(tenant_id) REFERENCES tenants(id)
);`},
		{"enroll_tokens", `CREATE TABLE IF NOT EXISTS enroll_tokens (
	id INTEGER PRIMARY KEY AUTOINCREMENT,
	tenant_id INTEGER NOT NULL,
	token_hash TEXT NOT NULL UNIQUE,
	expires_at INTEGER NOT NULL,
	used_at INTEGER,
	max_attempt INTEGER NOT NULL DEFAULT 5,
	attempts INTEGER NOT NULL DEFAULT 0,
	status INTEGER NOT NULL DEFAULT 1,
	created_at INTEGER NOT NULL,
	FOREIGN KEY(tenant_id) REFERENCES tenants(id)
);`},
		{"peering_policies", `CREATE TABLE IF NOT EXISTS peering_policies (
	id INTEGER PRIMARY KEY AUTOINCREMENT,
	src_tenant_id INTEGER NOT NULL,
	dst_tenant_id INTEGER NOT NULL,
	rules TEXT,
	expires_at INTEGER,
	status INTEGER NOT NULL DEFAULT 1,
	created_at INTEGER NOT NULL,
	FOREIGN KEY(src_tenant_id) REFERENCES tenants(id),
	FOREIGN KEY(dst_tenant_id) REFERENCES tenants(id)
);`},
		{"audit_logs", `CREATE TABLE IF NOT EXISTS audit_logs (
	id INTEGER PRIMARY KEY AUTOINCREMENT,
	actor_type TEXT,
	actor_id TEXT,
	action TEXT,
	target_type TEXT,
	target_id TEXT,
	detail TEXT,
	ip TEXT,
	created_at INTEGER NOT NULL
);`},
		{"subnet_pool", `CREATE TABLE IF NOT EXISTS subnet_pool (
	subnet TEXT PRIMARY KEY,
	status INTEGER NOT NULL DEFAULT 0,
	reserved INTEGER NOT NULL DEFAULT 0,
	tenant_id INTEGER,
	updated_at INTEGER NOT NULL
);`},
	}

	tx, err := s.DB.Begin()
	if err != nil {
		return fmt.Errorf("begin schema migration: %w", err)
	}
	defer func() { _ = tx.Rollback() }() // no-op once Commit has succeeded

	for _, tbl := range tables {
		if _, err := tx.Exec(tbl.ddl); err != nil {
			return fmt.Errorf("create table %s: %w", tbl.name, err)
		}
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit schema migration: %w", err)
	}
	return nil
}
|
|
|
|
// ensureSubnetPool seeds subnet_pool with the 256 /24 blocks of 10.10.0.0/16
// when the table is empty, and does nothing otherwise.
//
// 10.10.1.0/24 .. 10.10.254.0/24 are inserted as allocatable (reserved=0).
// 10.10.0.0/24 and 10.10.255.0/24 are inserted with reserved=1, which
// AllocateSubnet skips. All rows go in through one transaction. Because an
// existing non-empty pool is never revisited, a half-written seed would
// otherwise be permanent. Any insert or commit error is returned.
func (s *Store) ensureSubnetPool() error {
	var count int
	if err := s.DB.QueryRow(`SELECT COUNT(1) FROM subnet_pool;`).Scan(&count); err != nil {
		return fmt.Errorf("count subnet pool rows: %w", err)
	}
	if count > 0 {
		return nil
	}

	tx, err := s.DB.Begin()
	if err != nil {
		return fmt.Errorf("begin subnet pool seed: %w", err)
	}
	defer func() { _ = tx.Rollback() }() // no-op once Commit has succeeded

	// OR IGNORE: if another process seeded the pool between the COUNT above
	// and this transaction, keep its rows rather than failing Open.
	stmt, err := tx.Prepare(`INSERT OR IGNORE INTO subnet_pool(subnet,status,reserved,tenant_id,updated_at) VALUES(?,?,?,?,?)`)
	if err != nil {
		return fmt.Errorf("prepare subnet pool insert: %w", err)
	}
	defer stmt.Close()

	now := time.Now().Unix()
	for i := 0; i <= 255; i++ {
		reserved := 0
		if i == 0 || i == 255 {
			reserved = 1 // never handed out by AllocateSubnet
		}
		sn := fmt.Sprintf("10.10.%d.0/24", i)
		if _, err := stmt.Exec(sn, 0, reserved, nil, now); err != nil {
			return fmt.Errorf("insert subnet %s: %w", sn, err)
		}
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit subnet pool seed: %w", err)
	}
	return nil
}
|
|
|
|
// AllocateSubnet returns the first subnet in subnet_pool that is neither
// allocated (status=0) nor reserved (reserved=0).
//
// "First" means SQLite's text order on the subnet column, so
// "10.10.10.0/24" sorts before "10.10.2.0/24". The pool is only read; the
// returned subnet is not marked as taken here, so concurrent callers may
// receive the same value.
//
// When no free subnet exists, the error message starts with
// "no subnet available" and the error wraps sql.ErrNoRows. Existing
// errors.Is(err, sql.ErrNoRows) checks therefore keep working.
func (s *Store) AllocateSubnet() (string, error) {
	var subnet string
	err := s.DB.QueryRow(`SELECT subnet FROM subnet_pool WHERE status=0 AND reserved=0 ORDER BY subnet LIMIT 1`).Scan(&subnet)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return "", fmt.Errorf("no subnet available: %w", err)
	case err != nil:
		return "", err
	case subnet == "":
		return "", errors.New("no subnet available")
	}
	return subnet, nil
}
|
|
|
|
func (s *Store) CreateTenant(name string) (*Tenant, error) {
|
|
sn, err := s.AllocateSubnet()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
now := time.Now().Unix()
|
|
res, err := s.DB.Exec(`INSERT INTO tenants(name,status,subnet,created_at) VALUES(?,?,?,?)`, name, 1, sn, now)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
id, _ := res.LastInsertId()
|
|
_, _ = s.DB.Exec(`UPDATE subnet_pool SET status=1, tenant_id=?, updated_at=? WHERE subnet=?`, id, now, sn)
|
|
return &Tenant{ID: id, Name: name, Status: 1, Subnet: sn}, nil
|
|
}
|
|
|
|
// CreateNodeCredential registers an active node (status 1) named nodeName
// under tenantID and returns its credential.
//
// A fresh random secret is generated. Only its SHA-256 hex digest is stored
// in nodes.node_secret_hash, so the Secret in the returned value is the only
// plaintext copy. VirtualIP is left empty; no virtual IP is assigned here.
func (s *Store) CreateNodeCredential(tenantID int64, nodeName string) (*NodeCredential, error) {
	secret := randToken()
	h := hashTokenString(secret)
	res, err := s.DB.Exec(`INSERT INTO nodes(tenant_id,node_name,node_secret_hash,status) VALUES(?,?,?,1)`, tenantID, nodeName, h)
	if err != nil {
		return nil, err
	}
	id, err := res.LastInsertId()
	if err != nil {
		return nil, fmt.Errorf("read id of new node %q: %w", nodeName, err)
	}
	return &NodeCredential{NodeID: id, NodeName: nodeName, Secret: secret, TenantID: tenantID}, nil
}
|
|
|
|
func (s *Store) VerifyNodeSecret(nodeName, secret string) (*Tenant, error) {
|
|
h := hashTokenString(secret)
|
|
row := s.DB.QueryRow(`SELECT t.id,t.name,t.status,t.subnet FROM nodes n JOIN tenants t ON n.tenant_id=t.id WHERE n.node_name=? AND n.node_secret_hash=? AND n.status=1`, nodeName, h)
|
|
var t Tenant
|
|
if err := row.Scan(&t.ID, &t.Name, &t.Status, &t.Subnet); err != nil {
|
|
return nil, err
|
|
}
|
|
return &t, nil
|
|
}
|
|
|
|
// GetTenantByToken resolves a numeric API token to the tenant that owns it.
//
// The token is serialized as 8 big-endian bytes, SHA-256 hashed (hashToken),
// and matched against api_keys.key_hash. Only keys with status=1 qualify.
// A key must also have no expiry (expires_at NULL) or an expires_at
// (Unix seconds) that is not yet in the past. Returns sql.ErrNoRows (from
// Scan) when no usable key matches.
//
// NOTE(review): CreateAPIKey stores the digest of a hex *string* token, so
// keys it issues will never match here. This presumably serves a separate,
// numeric key scheme — confirm with callers.
func (s *Store) GetTenantByToken(token uint64) (*Tenant, error) {
	h := hashToken(token)
	now := time.Now().Unix()
	row := s.DB.QueryRow(`SELECT t.id,t.name,t.status,t.subnet FROM api_keys k JOIN tenants t ON k.tenant_id=t.id WHERE k.key_hash=? AND k.status=1 AND (k.expires_at IS NULL OR k.expires_at >= ?)`, h, now)
	var t Tenant
	if err := row.Scan(&t.ID, &t.Name, &t.Status, &t.Subnet); err != nil {
		return nil, err
	}
	return &t, nil
}
|
|
|
|
func (s *Store) CreateAPIKey(tenantID int64, scope string, ttl time.Duration) (string, error) {
|
|
token := randToken()
|
|
h := hashTokenString(token)
|
|
now := time.Now().Unix()
|
|
if ttl > 0 {
|
|
e := time.Now().Add(ttl)
|
|
_, err := s.DB.Exec(`INSERT INTO api_keys(tenant_id,key_hash,scope,expires_at,status,created_at) VALUES(?,?,?,?,1,?)`, tenantID, h, scope, e.Unix(), now)
|
|
return token, err
|
|
}
|
|
_, err := s.DB.Exec(`INSERT INTO api_keys(tenant_id,key_hash,scope,expires_at,status,created_at) VALUES(?,?,?,?,1,?)`, tenantID, h, scope, nil, now)
|
|
return token, err
|
|
}
|
|
|
|
// CreateEnrollToken issues a one-time enrollment code for tenantID and
// returns the plaintext code. Only its SHA-256 hex digest is stored.
//
// The token expires at now+ttl (Unix seconds) and starts with zero attempts.
// Neither ttl nor maxAttempt is validated. A maxAttempt <= 0 yields a token
// that ConsumeEnrollToken always rejects, because it requires
// attempts < max_attempt. The code is returned even when the insert fails,
// so callers must check err.
func (s *Store) CreateEnrollToken(tenantID int64, ttl time.Duration, maxAttempt int) (string, error) {
	code := randToken()
	expiresAt := time.Now().Add(ttl).Unix()
	createdAt := time.Now().Unix()

	_, err := s.DB.Exec(
		`INSERT INTO enroll_tokens(tenant_id,token_hash,expires_at,max_attempt,attempts,status,created_at) VALUES(?,?,?,?,0,1,?)`,
		tenantID, hashTokenString(code), expiresAt, maxAttempt, createdAt,
	)
	return code, err
}
|
|
|
|
// ConsumeEnrollToken validates an enrollment code and marks it used.
//
// The code is looked up by its SHA-256 hex digest. It is rejected if any of
// the following holds:
//   - it was already used;
//   - it is disabled (status != 1);
//   - it has reached max_attempt attempts;
//   - the current Unix time is past expires_at.
//
// On success used_at is set to now and attempts is incremented, both in the
// database and in the returned struct. Returns sql.ErrNoRows (from Scan)
// when no token matches.
//
// The final UPDATE re-checks every condition in its WHERE clause and must
// affect exactly one row. Two concurrent callers presenting the same code
// therefore cannot both succeed; the loser gets "token already used".
func (s *Store) ConsumeEnrollToken(code string) (*EnrollToken, error) {
	h := hashTokenString(code)
	now := time.Now().Unix()

	var et EnrollToken
	var used sql.NullInt64
	row := s.DB.QueryRow(`SELECT id,tenant_id,expires_at,used_at,max_attempt,attempts,status FROM enroll_tokens WHERE token_hash=?`, h)
	if err := row.Scan(&et.ID, &et.TenantID, &et.ExpiresAt, &used, &et.MaxAttempt, &et.Attempts, &et.Status); err != nil {
		return nil, err
	}

	// These checks give a precise reason for the common rejections; the
	// conditional UPDATE below is what actually enforces them atomically.
	switch {
	case used.Valid:
		return nil, errors.New("token already used")
	case et.Status != 1:
		return nil, errors.New("token disabled")
	case et.Attempts >= et.MaxAttempt:
		return nil, errors.New("token attempts exceeded")
	case now > et.ExpiresAt:
		return nil, errors.New("token expired")
	}

	res, err := s.DB.Exec(`UPDATE enroll_tokens SET used_at=?, attempts=attempts+1 WHERE id=? AND used_at IS NULL AND status=1 AND attempts<max_attempt AND expires_at>=?`, now, et.ID, now)
	if err != nil {
		return nil, err
	}
	n, err := res.RowsAffected()
	if err != nil {
		return nil, err
	}
	if n != 1 {
		// Another caller changed the row between the SELECT and this UPDATE.
		return nil, errors.New("token already used")
	}

	et.Hash = h
	et.Attempts++
	et.UsedAt = &now
	return &et, nil
}
|
|
|
|
func (s *Store) IncEnrollAttempt(code string) {
|
|
h := hashTokenString(code)
|
|
_, _ = s.DB.Exec(`UPDATE enroll_tokens SET attempts=attempts+1 WHERE token_hash=?`, h)
|
|
}
|
|
|
|
func hashToken(token uint64) string {
|
|
b := make([]byte, 8)
|
|
for i := uint(0); i < 8; i++ {
|
|
b[7-i] = byte(token >> (i * 8))
|
|
}
|
|
return hashTokenBytes(b)
|
|
}
|
|
|
|
// hashTokenString returns the lowercase hex SHA-256 digest of the token's
// UTF-8 bytes. This is the form stored for API keys, node secrets and
// enrollment codes.
func hashTokenString(token string) string {
	digest := sha256.Sum256([]byte(token))
	return hex.EncodeToString(digest[:])
}
|
|
|
|
// VerifyAPIKey resolves a plaintext API key (as returned by CreateAPIKey) to
// the tenant that owns it.
//
// The key is SHA-256 hashed and matched against api_keys.key_hash. Only
// keys with status=1 qualify. A key must also have no expiry
// (expires_at NULL) or an expires_at (Unix seconds) that is not yet in the
// past, so keys issued with a ttl stop working once it elapses. Returns
// sql.ErrNoRows (from Scan) when no usable key matches.
func (s *Store) VerifyAPIKey(token string) (*Tenant, error) {
	h := hashTokenString(token)
	now := time.Now().Unix()
	row := s.DB.QueryRow(`SELECT t.id,t.name,t.status,t.subnet FROM api_keys k JOIN tenants t ON k.tenant_id=t.id WHERE k.key_hash=? AND k.status=1 AND (k.expires_at IS NULL OR k.expires_at >= ?)`, h, now)
	var t Tenant
	if err := row.Scan(&t.ID, &t.Name, &t.Status, &t.Subnet); err != nil {
		return nil, err
	}
	return &t, nil
}
|
|
|
|
func hashTokenBytes(b []byte) string {
|
|
h := sha256.Sum256(b)
|
|
return hex.EncodeToString(h[:])
|
|
}
|
|
|
|
// randToken returns a fresh secret made of 24 bytes from crypto/rand,
// hex-encoded as 48 lowercase characters. It backs API keys, node secrets
// and enrollment codes.
//
// It panics if the system CSPRNG fails. Carrying on with an unfilled buffer
// would mint a predictable (all-zero) credential. Go 1.24+ crypto/rand.Read
// already aborts the program in that case, so this only makes older
// toolchains fail just as safely.
func randToken() string {
	b := make([]byte, 24)
	if _, err := rand.Read(b); err != nil {
		panic(fmt.Sprintf("store: crypto/rand failed: %v", err))
	}
	return hex.EncodeToString(b)
}
|
|
|
|
// Blank reference that keeps the "net" import compiling; net does not appear
// to be used anywhere else in this file. If that holds, delete this
// declaration and the import together.
const _ = net.IPv4len
|