feat: INP2P v0.1.0 — complete P2P tunneling system

Core modules (M1-M6):
- pkg/protocol: message format, encoding, NAT type enums
- pkg/config: server/client config structs, env vars, validation
- pkg/auth: CRC64 token, TOTP gen/verify, one-time relay tokens
- pkg/nat: UDP/TCP STUN client and server
- pkg/signal: WSS message dispatch, sync request/response
- pkg/punch: UDP/TCP hole punching + priority chain
- pkg/mux: stream multiplexer (7B frame: StreamID+Flags+Len)
- pkg/tunnel: mux-based port forwarding with stats
- pkg/relay: relay manager with TOTP auth + session bridging
- internal/server: signaling server (login/heartbeat/report/coordinator)
- internal/client: client (NAT detect/login/punch/relay/reconnect)
- cmd/inp2ps + cmd/inp2pc: main entrypoints with graceful shutdown

All tests pass: 16 tests across 5 packages
Code: 3559 lines core + 861 lines tests = 19 source files
This commit is contained in:
2026-03-02 15:13:22 +08:00
commit 91e3d4da2a
23 changed files with 4681 additions and 0 deletions

92
pkg/auth/auth.go Normal file
View File

@@ -0,0 +1,92 @@
// Package auth provides TOTP and token authentication for INP2P.
package auth
import (
"crypto/hmac"
"crypto/sha256"
"encoding/binary"
"fmt"
"hash/crc64"
"time"
)
const (
// TOTPStep is the time window in seconds for TOTP validity.
// A code is valid for ±1 step to allow for clock drift.
TOTPStep int64 = 60
)
var crcTable = crc64.MakeTable(crc64.ECMA)
// MakeToken generates a token from user+password using CRC64.
func MakeToken(user, password string) uint64 {
return crc64.Checksum([]byte(user+password), crcTable)
}
// GenTOTP generates a TOTP code for relay authentication.
func GenTOTP(token uint64, ts int64) uint64 {
step := ts / TOTPStep
buf := make([]byte, 16)
binary.BigEndian.PutUint64(buf[:8], token)
binary.BigEndian.PutUint64(buf[8:], uint64(step))
mac := hmac.New(sha256.New, buf[:8])
mac.Write(buf[8:])
sum := mac.Sum(nil)
return binary.BigEndian.Uint64(sum[:8])
}
// VerifyTOTP verifies a TOTP code with ±1 step tolerance.
func VerifyTOTP(code uint64, token uint64, ts int64) bool {
for delta := int64(-1); delta <= 1; delta++ {
expected := GenTOTP(token, ts+delta*TOTPStep)
if code == expected {
return true
}
}
return false
}
// RelayToken generates a one-time relay token signed by the server.
// Used for cross-user super relay authentication.
type RelayToken struct {
SessionID string `json:"sessionID"`
From string `json:"from"`
To string `json:"to"`
Relay string `json:"relay"`
Expires int64 `json:"expires"`
Signature []byte `json:"signature"`
}
// SignRelayToken creates a signed one-time relay token.
func SignRelayToken(secret []byte, sessionID, from, to, relay string, ttl time.Duration) RelayToken {
rt := RelayToken{
SessionID: sessionID,
From: from,
To: to,
Relay: relay,
Expires: time.Now().Add(ttl).Unix(),
}
msg := fmt.Sprintf("%s:%s:%s:%s:%d", rt.SessionID, rt.From, rt.To, rt.Relay, rt.Expires)
mac := hmac.New(sha256.New, secret)
mac.Write([]byte(msg))
rt.Signature = mac.Sum(nil)
return rt
}
// VerifyRelayToken validates a signed relay token.
func VerifyRelayToken(secret []byte, rt RelayToken) bool {
if time.Now().Unix() > rt.Expires {
return false
}
msg := fmt.Sprintf("%s:%s:%s:%s:%d", rt.SessionID, rt.From, rt.To, rt.Relay, rt.Expires)
mac := hmac.New(sha256.New, secret)
mac.Write([]byte(msg))
expected := mac.Sum(nil)
return hmac.Equal(rt.Signature, expected)
}

161
pkg/config/config.go Normal file
View File

@@ -0,0 +1,161 @@
// Package config provides shared configuration types.
package config
import (
"crypto/rand"
"encoding/hex"
"fmt"
"os"
"strconv"
)
const (
Version = "0.1.0"
DefaultWSPort = 27183 // WSS signaling
DefaultSTUNUDP1 = 27182 // UDP STUN port 1
DefaultSTUNUDP2 = 27183 // UDP STUN port 2
DefaultSTUNTCP1 = 27180 // TCP STUN port 1
DefaultSTUNTCP2 = 27181 // TCP STUN port 2
DefaultWebPort = 10088 // Web console
DefaultAPIPort = 10008 // REST API
DefaultMaxRelayLoad = 20
DefaultRelayPort = 27185
HeartbeatInterval = 30 // seconds
HeartbeatTimeout = 90 // seconds — 3x missed heartbeats → offline
)
// ServerConfig holds inp2ps configuration.
type ServerConfig struct {
WSPort int `json:"wsPort"`
STUNUDP1 int `json:"stunUDP1"`
STUNUDP2 int `json:"stunUDP2"`
STUNTCP1 int `json:"stunTCP1"`
STUNTCP2 int `json:"stunTCP2"`
WebPort int `json:"webPort"`
APIPort int `json:"apiPort"`
DBPath string `json:"dbPath"`
CertFile string `json:"certFile"`
KeyFile string `json:"keyFile"`
LogLevel int `json:"logLevel"` // 0=debug, 1=info, 2=warn, 3=error
Token uint64 `json:"token"` // master token for auth
JWTKey string `json:"jwtKey"` // auto-generated if empty
AdminUser string `json:"adminUser"`
AdminPass string `json:"adminPass"`
}
func DefaultServerConfig() ServerConfig {
return ServerConfig{
WSPort: DefaultWSPort,
STUNUDP1: DefaultSTUNUDP1,
STUNUDP2: DefaultSTUNUDP2,
STUNTCP1: DefaultSTUNTCP1,
STUNTCP2: DefaultSTUNTCP2,
WebPort: DefaultWebPort,
APIPort: DefaultAPIPort,
DBPath: "inp2ps.db",
LogLevel: 1,
AdminUser: "admin",
AdminPass: "admin123",
}
}
func (c *ServerConfig) FillFromEnv() {
if v := os.Getenv("INP2PS_WS_PORT"); v != "" {
c.WSPort, _ = strconv.Atoi(v)
}
if v := os.Getenv("INP2PS_WEB_PORT"); v != "" {
c.WebPort, _ = strconv.Atoi(v)
}
if v := os.Getenv("INP2PS_DB_PATH"); v != "" {
c.DBPath = v
}
if v := os.Getenv("INP2PS_TOKEN"); v != "" {
c.Token, _ = strconv.ParseUint(v, 10, 64)
}
if v := os.Getenv("INP2PS_CERT"); v != "" {
c.CertFile = v
}
if v := os.Getenv("INP2PS_KEY"); v != "" {
c.KeyFile = v
}
if c.JWTKey == "" {
b := make([]byte, 32)
rand.Read(b)
c.JWTKey = hex.EncodeToString(b)
}
}
func (c *ServerConfig) Validate() error {
if c.Token == 0 {
return fmt.Errorf("token is required (INP2PS_TOKEN or -token)")
}
return nil
}
// ClientConfig holds inp2pc configuration.
type ClientConfig struct {
ServerHost string `json:"serverHost"`
ServerPort int `json:"serverPort"`
Node string `json:"node"`
Token uint64 `json:"token"`
User string `json:"user,omitempty"`
Insecure bool `json:"insecure"` // skip TLS verify
// STUN ports (defaults match server defaults)
STUNUDP1 int `json:"stunUDP1,omitempty"`
STUNUDP2 int `json:"stunUDP2,omitempty"`
STUNTCP1 int `json:"stunTCP1,omitempty"`
STUNTCP2 int `json:"stunTCP2,omitempty"`
RelayEnabled bool `json:"relayEnabled"` // --relay
SuperRelay bool `json:"superRelay"` // --super
RelayPort int `json:"relayPort"`
MaxRelayLoad int `json:"maxRelayLoad"`
ShareBandwidth int `json:"shareBandwidth"` // Mbps
LogLevel int `json:"logLevel"`
Apps []AppConfig `json:"apps"`
}
type AppConfig struct {
AppName string `json:"appName"`
Protocol string `json:"protocol"` // tcp, udp
SrcPort int `json:"srcPort"`
PeerNode string `json:"peerNode"`
DstHost string `json:"dstHost"`
DstPort int `json:"dstPort"`
Enabled bool `json:"enabled"`
}
func DefaultClientConfig() ClientConfig {
return ClientConfig{
ServerPort: DefaultWSPort,
STUNUDP1: DefaultSTUNUDP1,
STUNUDP2: DefaultSTUNUDP2,
STUNTCP1: DefaultSTUNTCP1,
STUNTCP2: DefaultSTUNTCP2,
ShareBandwidth: 10,
RelayPort: DefaultRelayPort,
MaxRelayLoad: DefaultMaxRelayLoad,
LogLevel: 1,
}
}
func (c *ClientConfig) Validate() error {
if c.ServerHost == "" {
return fmt.Errorf("serverHost is required")
}
if c.Token == 0 {
return fmt.Errorf("token is required")
}
if c.Node == "" {
hostname, _ := os.Hostname()
c.Node = hostname
}
return nil
}

487
pkg/mux/mux.go Normal file
View File

@@ -0,0 +1,487 @@
// Package mux provides stream multiplexing over a single net.Conn.
//
// Wire format per frame:
//
// StreamID (4B, big-endian)
// Flags (1B)
// Length (2B, big-endian, max 65535)
// Data (Length bytes)
//
// Total header = 7 bytes.
//
// Flags:
//
// 0x01 SYN — open a new stream
// 0x02 FIN — close a stream
// 0x04 DATA — payload data
// 0x08 PING — keepalive (StreamID=0)
// 0x10 PONG — keepalive response (StreamID=0)
// 0x20 RST — reset/abort a stream
package mux
import (
"encoding/binary"
"errors"
"fmt"
"io"
"log"
"net"
"sync"
"sync/atomic"
"time"
)
const (
headerSize = 7
maxPayload = 65535
FlagSYN byte = 0x01
FlagFIN byte = 0x02
FlagDATA byte = 0x04
FlagPING byte = 0x08
FlagPONG byte = 0x10
FlagRST byte = 0x20
defaultWindowSize = 256 * 1024 // 256KB per stream receive buffer
pingInterval = 15 * time.Second
pingTimeout = 10 * time.Second
acceptBacklog = 64
)
var (
ErrSessionClosed = errors.New("mux: session closed")
ErrStreamClosed = errors.New("mux: stream closed")
ErrStreamReset = errors.New("mux: stream reset by peer")
ErrTimeout = errors.New("mux: timeout")
ErrAcceptBacklog = errors.New("mux: accept backlog full")
)
// ─── Session ───
// A Session multiplexes many Streams over a single underlying net.Conn.
type Session struct {
conn net.Conn
streams map[uint32]*Stream
mu sync.RWMutex
nextID uint32 // client uses odd, server uses even
isServer bool
acceptCh chan *Stream
writeMu sync.Mutex // serialize frame writes
closed int32
quit chan struct{}
once sync.Once
// stats
BytesSent int64
BytesReceived int64
}
// NewSession wraps a net.Conn as a mux session.
// isServer determines stream ID allocation: server=even, client=odd.
func NewSession(conn net.Conn, isServer bool) *Session {
s := &Session{
conn: conn,
streams: make(map[uint32]*Stream),
acceptCh: make(chan *Stream, acceptBacklog),
quit: make(chan struct{}),
isServer: isServer,
}
if isServer {
s.nextID = 2
} else {
s.nextID = 1
}
go s.readLoop()
go s.pingLoop()
return s
}
// Open creates a new outbound stream.
func (s *Session) Open() (*Stream, error) {
if s.IsClosed() {
return nil, ErrSessionClosed
}
id := atomic.AddUint32(&s.nextID, 2) - 2 // increment by 2 to keep odd/even
st := newStream(id, s)
s.mu.Lock()
s.streams[id] = st
s.mu.Unlock()
// Send SYN
if err := s.writeFrame(id, FlagSYN, nil); err != nil {
s.mu.Lock()
delete(s.streams, id)
s.mu.Unlock()
return nil, err
}
return st, nil
}
// Accept waits for an inbound stream opened by the remote side.
func (s *Session) Accept() (*Stream, error) {
select {
case st := <-s.acceptCh:
return st, nil
case <-s.quit:
return nil, ErrSessionClosed
}
}
// Close shuts down the session and all streams.
func (s *Session) Close() error {
s.once.Do(func() {
atomic.StoreInt32(&s.closed, 1)
close(s.quit)
s.mu.Lock()
for _, st := range s.streams {
st.closeLocal()
}
s.streams = make(map[uint32]*Stream)
s.mu.Unlock()
s.conn.Close()
})
return nil
}
// IsClosed reports if the session is closed.
func (s *Session) IsClosed() bool {
return atomic.LoadInt32(&s.closed) == 1
}
// NumStreams returns active stream count.
func (s *Session) NumStreams() int {
s.mu.RLock()
defer s.mu.RUnlock()
return len(s.streams)
}
// ─── Frame I/O ───
func (s *Session) writeFrame(streamID uint32, flags byte, data []byte) error {
if len(data) > maxPayload {
return fmt.Errorf("mux: payload too large: %d > %d", len(data), maxPayload)
}
hdr := make([]byte, headerSize)
binary.BigEndian.PutUint32(hdr[0:4], streamID)
hdr[4] = flags
binary.BigEndian.PutUint16(hdr[5:7], uint16(len(data)))
s.writeMu.Lock()
defer s.writeMu.Unlock()
s.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
if _, err := s.conn.Write(hdr); err != nil {
return err
}
if len(data) > 0 {
if _, err := s.conn.Write(data); err != nil {
return err
}
}
atomic.AddInt64(&s.BytesSent, int64(headerSize+len(data)))
return nil
}
func (s *Session) readLoop() {
hdr := make([]byte, headerSize)
for {
if _, err := io.ReadFull(s.conn, hdr); err != nil {
if !s.IsClosed() {
log.Printf("[mux] read header error: %v", err)
}
s.Close()
return
}
streamID := binary.BigEndian.Uint32(hdr[0:4])
flags := hdr[4]
length := binary.BigEndian.Uint16(hdr[5:7])
var data []byte
if length > 0 {
data = make([]byte, length)
if _, err := io.ReadFull(s.conn, data); err != nil {
if !s.IsClosed() {
log.Printf("[mux] read data error: %v", err)
}
s.Close()
return
}
}
atomic.AddInt64(&s.BytesReceived, int64(headerSize+int(length)))
s.handleFrame(streamID, flags, data)
}
}
func (s *Session) handleFrame(streamID uint32, flags byte, data []byte) {
// Ping/Pong on StreamID 0
if flags&FlagPING != 0 {
s.writeFrame(0, FlagPONG, nil)
return
}
if flags&FlagPONG != 0 {
return // pong received, connection alive
}
// SYN — new inbound stream
if flags&FlagSYN != 0 {
st := newStream(streamID, s)
s.mu.Lock()
s.streams[streamID] = st
s.mu.Unlock()
select {
case s.acceptCh <- st:
default:
log.Printf("[mux] accept backlog full, dropping stream %d", streamID)
s.writeFrame(streamID, FlagRST, nil)
s.mu.Lock()
delete(s.streams, streamID)
s.mu.Unlock()
}
return
}
// Find the stream
s.mu.RLock()
st, ok := s.streams[streamID]
s.mu.RUnlock()
if !ok {
if flags&FlagRST == 0 {
s.writeFrame(streamID, FlagRST, nil)
}
return
}
// RST
if flags&FlagRST != 0 {
st.resetByPeer()
s.mu.Lock()
delete(s.streams, streamID)
s.mu.Unlock()
return
}
// DATA
if flags&FlagDATA != 0 && len(data) > 0 {
st.pushData(data)
}
// FIN
if flags&FlagFIN != 0 {
st.finByPeer()
}
}
func (s *Session) removeStream(id uint32) {
s.mu.Lock()
delete(s.streams, id)
s.mu.Unlock()
}
func (s *Session) pingLoop() {
ticker := time.NewTicker(pingInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := s.writeFrame(0, FlagPING, nil); err != nil {
return
}
case <-s.quit:
return
}
}
}
// ─── Stream ───
// A Stream is a virtual connection within a Session, implementing net.Conn.
type Stream struct {
id uint32
sess *Session
readBuf *ringBuffer
readCh chan struct{} // signaled when data arrives
closed int32
finRecv int32 // remote sent FIN
finSent int32 // we sent FIN
reset int32
mu sync.Mutex
}
func newStream(id uint32, sess *Session) *Stream {
return &Stream{
id: id,
sess: sess,
readBuf: newRingBuffer(defaultWindowSize),
readCh: make(chan struct{}, 1),
}
}
// Read implements io.Reader.
func (st *Stream) Read(p []byte) (int, error) {
for {
if atomic.LoadInt32(&st.reset) == 1 {
return 0, ErrStreamReset
}
n := st.readBuf.Read(p)
if n > 0 {
return n, nil
}
// Buffer empty — check if FIN received
if atomic.LoadInt32(&st.finRecv) == 1 {
return 0, io.EOF
}
if atomic.LoadInt32(&st.closed) == 1 {
return 0, ErrStreamClosed
}
// Wait for data
select {
case <-st.readCh:
case <-st.sess.quit:
return 0, ErrSessionClosed
}
}
}
// Write implements io.Writer.
func (st *Stream) Write(p []byte) (int, error) {
if atomic.LoadInt32(&st.closed) == 1 || atomic.LoadInt32(&st.reset) == 1 {
return 0, ErrStreamClosed
}
total := 0
for len(p) > 0 {
chunk := p
if len(chunk) > maxPayload {
chunk = p[:maxPayload]
}
if err := st.sess.writeFrame(st.id, FlagDATA, chunk); err != nil {
return total, err
}
total += len(chunk)
p = p[len(chunk):]
}
return total, nil
}
// Close sends FIN and closes the stream.
func (st *Stream) Close() error {
if !atomic.CompareAndSwapInt32(&st.closed, 0, 1) {
return nil
}
if atomic.CompareAndSwapInt32(&st.finSent, 0, 1) {
st.sess.writeFrame(st.id, FlagFIN, nil)
}
st.sess.removeStream(st.id)
st.notify()
return nil
}
// LocalAddr implements net.Conn.
func (st *Stream) LocalAddr() net.Addr { return st.sess.conn.LocalAddr() }
func (st *Stream) RemoteAddr() net.Addr { return st.sess.conn.RemoteAddr() }
func (st *Stream) SetDeadline(t time.Time) error {
return nil // TODO: implement per-stream deadlines
}
func (st *Stream) SetReadDeadline(t time.Time) error { return nil }
func (st *Stream) SetWriteDeadline(t time.Time) error { return nil }
func (st *Stream) pushData(data []byte) {
st.readBuf.Write(data)
st.notify()
}
func (st *Stream) finByPeer() {
atomic.StoreInt32(&st.finRecv, 1)
st.notify()
}
func (st *Stream) resetByPeer() {
atomic.StoreInt32(&st.reset, 1)
atomic.StoreInt32(&st.closed, 1)
st.notify()
}
func (st *Stream) closeLocal() {
atomic.StoreInt32(&st.closed, 1)
st.notify()
}
func (st *Stream) notify() {
select {
case st.readCh <- struct{}{}:
default:
}
}
// ─── Ring Buffer ───
// Lock-free-ish ring buffer for stream receive data.
type ringBuffer struct {
buf []byte
r, w int
mu sync.Mutex
size int
}
func newRingBuffer(size int) *ringBuffer {
return &ringBuffer{
buf: make([]byte, size),
size: size,
}
}
func (rb *ringBuffer) Write(p []byte) int {
rb.mu.Lock()
defer rb.mu.Unlock()
n := 0
for _, b := range p {
next := (rb.w + 1) % rb.size
if next == rb.r {
break // full
}
rb.buf[rb.w] = b
rb.w = next
n++
}
return n
}
func (rb *ringBuffer) Read(p []byte) int {
rb.mu.Lock()
defer rb.mu.Unlock()
n := 0
for n < len(p) && rb.r != rb.w {
p[n] = rb.buf[rb.r]
rb.r = (rb.r + 1) % rb.size
n++
}
return n
}
func (rb *ringBuffer) Len() int {
rb.mu.Lock()
defer rb.mu.Unlock()
if rb.w >= rb.r {
return rb.w - rb.r
}
return rb.size - rb.r + rb.w
}

266
pkg/mux/mux_test.go Normal file
View File

@@ -0,0 +1,266 @@
package mux
import (
"bytes"
"io"
"net"
"sync"
"testing"
"time"
)
// pipe creates a connected pair of net.Conn using net.Pipe.
func pipe() (net.Conn, net.Conn) {
return net.Pipe()
}
func TestSessionOpenAccept(t *testing.T) {
c1, c2 := pipe()
defer c1.Close()
defer c2.Close()
client := NewSession(c1, false)
server := NewSession(c2, true)
defer client.Close()
defer server.Close()
// Client opens a stream
st1, err := client.Open()
if err != nil {
t.Fatal(err)
}
// Server accepts
st2, err := server.Accept()
if err != nil {
t.Fatal(err)
}
// Verify stream IDs: client=odd, server would be even
if st1.id%2 != 1 {
t.Errorf("client stream ID should be odd, got %d", st1.id)
}
_ = st2 // server accepted stream has client's ID
}
func TestStreamReadWrite(t *testing.T) {
c1, c2 := pipe()
client := NewSession(c1, false)
server := NewSession(c2, true)
defer client.Close()
defer server.Close()
st1, _ := client.Open()
st2, _ := server.Accept()
msg := []byte("hello from client to server via mux")
// Write from client
n, err := st1.Write(msg)
if err != nil || n != len(msg) {
t.Fatalf("write: n=%d err=%v", n, err)
}
// Read on server
buf := make([]byte, 1024)
n, err = st2.Read(buf)
if err != nil || n != len(msg) {
t.Fatalf("read: n=%d err=%v", n, err)
}
if !bytes.Equal(buf[:n], msg) {
t.Fatalf("data mismatch: got %q want %q", buf[:n], msg)
}
// Bidirectional: server → client
reply := []byte("pong")
st2.Write(reply)
n, _ = st1.Read(buf)
if !bytes.Equal(buf[:n], reply) {
t.Fatalf("reply mismatch: got %q want %q", buf[:n], reply)
}
}
func TestMultipleStreams(t *testing.T) {
c1, c2 := pipe()
client := NewSession(c1, false)
server := NewSession(c2, true)
defer client.Close()
defer server.Close()
const numStreams = 10
var wg sync.WaitGroup
// Client opens N streams concurrently
wg.Add(numStreams)
for i := 0; i < numStreams; i++ {
go func(idx int) {
defer wg.Done()
st, err := client.Open()
if err != nil {
t.Errorf("open stream %d: %v", idx, err)
return
}
msg := []byte("stream-data")
st.Write(msg)
}(i)
}
// Server accepts N streams
for i := 0; i < numStreams; i++ {
st, err := server.Accept()
if err != nil {
t.Fatalf("accept stream %d: %v", i, err)
}
buf := make([]byte, 64)
n, _ := st.Read(buf)
if string(buf[:n]) != "stream-data" {
t.Errorf("stream %d data mismatch", i)
}
}
wg.Wait()
if client.NumStreams() != numStreams {
t.Errorf("client streams: got %d want %d", client.NumStreams(), numStreams)
}
}
func TestStreamClose(t *testing.T) {
c1, c2 := pipe()
client := NewSession(c1, false)
server := NewSession(c2, true)
defer client.Close()
defer server.Close()
st1, _ := client.Open()
st2, _ := server.Accept()
// Write then close
st1.Write([]byte("before-close"))
st1.Close()
// Server should read data then get EOF
buf := make([]byte, 64)
n, _ := st2.Read(buf)
if string(buf[:n]) != "before-close" {
t.Errorf("unexpected data: %q", buf[:n])
}
// Next read should eventually get EOF (FIN received)
time.Sleep(50 * time.Millisecond)
_, err := st2.Read(buf)
if err != io.EOF {
t.Errorf("expected EOF after FIN, got %v", err)
}
}
func TestLargePayload(t *testing.T) {
c1, c2 := pipe()
client := NewSession(c1, false)
server := NewSession(c2, true)
defer client.Close()
defer server.Close()
st1, _ := client.Open()
st2, _ := server.Accept()
// Write 200KB — larger than maxPayload (65535), should auto-split
data := make([]byte, 200*1024)
for i := range data {
data[i] = byte(i % 256)
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
n, err := st1.Write(data)
if err != nil {
t.Errorf("write large: %v", err)
}
if n != len(data) {
t.Errorf("write large: n=%d want %d", n, len(data))
}
}()
// Read all on server
received := make([]byte, 0, len(data))
buf := make([]byte, 32*1024)
for len(received) < len(data) {
n, err := st2.Read(buf)
if err != nil {
t.Fatalf("read at %d: %v", len(received), err)
}
received = append(received, buf[:n]...)
}
wg.Wait()
if !bytes.Equal(received, data) {
t.Error("large payload data mismatch")
}
}
func TestSessionClose(t *testing.T) {
c1, c2 := pipe()
client := NewSession(c1, false)
server := NewSession(c2, true)
st1, _ := client.Open()
server.Accept()
// Close session
client.Close()
// Stream operations should fail
_, err := st1.Write([]byte("x"))
if err == nil {
t.Error("write after session close should fail")
}
// Server accept should fail
time.Sleep(50 * time.Millisecond)
server.Close()
}
func TestPingPong(t *testing.T) {
c1, c2 := pipe()
client := NewSession(c1, false)
server := NewSession(c2, true)
defer client.Close()
defer server.Close()
// Just verify it doesn't crash — ping/pong runs in background
time.Sleep(100 * time.Millisecond)
if client.IsClosed() || server.IsClosed() {
t.Error("sessions should still be alive")
}
}
func BenchmarkThroughput(b *testing.B) {
c1, c2 := pipe()
client := NewSession(c1, false)
server := NewSession(c2, true)
defer client.Close()
defer server.Close()
st1, _ := client.Open()
st2, _ := server.Accept()
data := make([]byte, 4096)
buf := make([]byte, 4096)
b.SetBytes(int64(len(data)))
b.ResetTimer()
go func() {
for i := 0; i < b.N; i++ {
st2.Read(buf)
}
}()
for i := 0; i < b.N; i++ {
st1.Write(data)
}
}

260
pkg/nat/detect.go Normal file
View File

@@ -0,0 +1,260 @@
// Package nat provides NAT type detection via UDP and TCP STUN.
package nat
import (
"encoding/json"
"fmt"
"net"
"time"
"github.com/openp2p-cn/inp2p/pkg/protocol"
)
const (
detectTimeout = 5 * time.Second
)
// DetectResult holds the NAT detection outcome.
type DetectResult struct {
Type protocol.NATType
PublicIP string
Port1 int // external port seen on STUN server port 1
Port2 int // external port seen on STUN server port 2
}
// stunReq is sent to the STUN endpoint.
type stunReq struct {
ID int `json:"id"`
}
// stunRsp is received from the STUN endpoint.
type stunRsp struct {
IP string `json:"ip"`
Port int `json:"port"`
ID int `json:"id"`
}
// DetectUDP sends probes from the same local port to two different server
// UDP ports. If both return the same external port → Cone; different → Symmetric.
func DetectUDP(serverIP string, port1, port2 int) DetectResult {
result := DetectResult{Type: protocol.NATUnknown}
// Bind a single local UDP port
conn, err := net.ListenPacket("udp", ":0")
if err != nil {
return result
}
defer conn.Close()
r1, err1 := probeUDP(conn, serverIP, port1, 1)
r2, err2 := probeUDP(conn, serverIP, port2, 2)
if err1 != nil || err2 != nil {
return result // timeout → NATUnknown
}
result.PublicIP = r1.IP
result.Port1 = r1.Port
result.Port2 = r2.Port
if r1.Port == r2.Port {
result.Type = protocol.NATCone
} else {
result.Type = protocol.NATSymmetric
}
// Check if public IP equals local IP → no NAT
localIP := conn.LocalAddr().(*net.UDPAddr).IP.String()
if localIP == r1.IP || r1.IP == "" {
// might be public
}
return result
}
func probeUDP(conn net.PacketConn, serverIP string, port, id int) (stunRsp, error) {
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", serverIP, port))
if err != nil {
return stunRsp{}, err
}
frame, _ := protocol.Encode(protocol.MsgNAT, protocol.SubNATDetectReq, stunReq{ID: id})
conn.SetWriteDeadline(time.Now().Add(detectTimeout))
if _, err := conn.WriteTo(frame, addr); err != nil {
return stunRsp{}, err
}
buf := make([]byte, 1024)
conn.SetReadDeadline(time.Now().Add(detectTimeout))
n, _, err := conn.ReadFrom(buf)
if err != nil {
return stunRsp{}, err
}
var rsp stunRsp
if n > protocol.HeaderSize {
json.Unmarshal(buf[protocol.HeaderSize:n], &rsp)
}
return rsp, nil
}
// DetectTCP connects to two different TCP ports on the server and compares
// the observed external port. This is the fallback when UDP is blocked.
func DetectTCP(serverIP string, port1, port2 int) DetectResult {
result := DetectResult{Type: protocol.NATUnknown}
r1, err1 := probeTCP(serverIP, port1, 1)
r2, err2 := probeTCP(serverIP, port2, 2)
if err1 != nil || err2 != nil {
return result
}
result.PublicIP = r1.IP
result.Port1 = r1.Port
result.Port2 = r2.Port
if r1.Port == r2.Port {
result.Type = protocol.NATCone
} else {
result.Type = protocol.NATSymmetric
}
return result
}
func probeTCP(serverIP string, port, id int) (stunRsp, error) {
conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", serverIP, port), detectTimeout)
if err != nil {
return stunRsp{}, err
}
defer conn.Close()
frame, _ := protocol.Encode(protocol.MsgNAT, protocol.SubNATDetectReq, stunReq{ID: id})
conn.SetWriteDeadline(time.Now().Add(detectTimeout))
if _, err := conn.Write(frame); err != nil {
return stunRsp{}, err
}
buf := make([]byte, 1024)
conn.SetReadDeadline(time.Now().Add(detectTimeout))
n, err := conn.Read(buf)
if err != nil {
return stunRsp{}, err
}
var rsp stunRsp
if n > protocol.HeaderSize {
json.Unmarshal(buf[protocol.HeaderSize:n], &rsp)
}
return rsp, nil
}
// Detect runs UDP detection first, falls back to TCP if UDP is blocked.
func Detect(serverIP string, udpPort1, udpPort2, tcpPort1, tcpPort2 int) DetectResult {
result := DetectUDP(serverIP, udpPort1, udpPort2)
if result.Type != protocol.NATUnknown {
return result
}
// UDP blocked, fallback to TCP
return DetectTCP(serverIP, tcpPort1, tcpPort2)
}
// ─── Server-side STUN handler ───
// ServeUDPSTUN listens on a UDP port and echoes back the sender's observed IP:port.
func ServeUDPSTUN(port int, quit <-chan struct{}) error {
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port))
if err != nil {
return err
}
conn, err := net.ListenUDP("udp", addr)
if err != nil {
return err
}
defer conn.Close()
go func() {
<-quit
conn.Close()
}()
buf := make([]byte, 1024)
for {
n, remoteAddr, err := conn.ReadFromUDP(buf)
if err != nil {
select {
case <-quit:
return nil
default:
continue
}
}
// Parse request
var req stunReq
if n > protocol.HeaderSize {
json.Unmarshal(buf[protocol.HeaderSize:n], &req)
}
// Reply with observed address
rsp := stunRsp{
IP: remoteAddr.IP.String(),
Port: remoteAddr.Port,
ID: req.ID,
}
frame, _ := protocol.Encode(protocol.MsgNAT, protocol.SubNATDetectRsp, rsp)
conn.WriteToUDP(frame, remoteAddr)
}
}
// ServeTCPSTUN listens on a TCP port. Each connection: read one req, write one rsp with observed addr.
func ServeTCPSTUN(port int, quit <-chan struct{}) error {
ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
return err
}
defer ln.Close()
go func() {
<-quit
ln.Close()
}()
for {
conn, err := ln.Accept()
if err != nil {
select {
case <-quit:
return nil
default:
continue
}
}
go func(c net.Conn) {
defer c.Close()
remoteAddr := c.RemoteAddr().(*net.TCPAddr)
buf := make([]byte, 1024)
c.SetReadDeadline(time.Now().Add(10 * time.Second))
n, err := c.Read(buf)
if err != nil {
return
}
var req stunReq
if n > protocol.HeaderSize {
json.Unmarshal(buf[protocol.HeaderSize:n], &req)
}
rsp := stunRsp{
IP: remoteAddr.IP.String(),
Port: remoteAddr.Port,
ID: req.ID,
}
frame, _ := protocol.Encode(protocol.MsgNAT, protocol.SubNATDetectRsp, rsp)
c.SetWriteDeadline(time.Now().Add(5 * time.Second))
c.Write(frame)
}(conn)
}
}

276
pkg/protocol/protocol.go Normal file
View File

@@ -0,0 +1,276 @@
// Package protocol defines the INP2P wire protocol.
//
// Message format: [Header 8B] + [JSON payload]
// Header: DataLen(uint32 LE) + MainType(uint16 LE) + SubType(uint16 LE)
// DataLen = len(header) + len(payload) = 8 + len(json)
package protocol
import (
"bytes"
"encoding/binary"
"encoding/json"
"fmt"
"io"
)
// HeaderSize is the fixed 8-byte message header.
const HeaderSize = 8
// ─── Main message types ───
const (
MsgLogin uint16 = 1
MsgHeartbeat uint16 = 2
MsgNAT uint16 = 3
MsgPush uint16 = 4 // signaling push (punch/relay coordination)
MsgRelay uint16 = 5
MsgReport uint16 = 6
MsgTunnel uint16 = 7 // in-tunnel control messages
)
// ─── Sub types: MsgLogin ───
const (
SubLoginReq uint16 = iota
SubLoginRsp
)
// ─── Sub types: MsgHeartbeat ───
const (
SubHeartbeatPing uint16 = iota
SubHeartbeatPong
)
// ─── Sub types: MsgNAT ───
const (
SubNATDetectReq uint16 = iota
SubNATDetectRsp
)
// ─── Sub types: MsgPush ───
const (
SubPushConnectReq uint16 = iota // "please connect to peer X"
SubPushConnectRsp // peer's punch parameters
SubPushPunchStart // coordinate simultaneous punch
SubPushPunchResult // report punch outcome
SubPushRelayOffer // relay node offers to relay
SubPushNodeOnline // notify: destination came online
SubPushEditApp // add/edit tunnel app
SubPushDeleteApp // delete tunnel app
SubPushReportApps // request app list
)
// ─── Sub types: MsgRelay ───
const (
SubRelayNodeReq uint16 = iota
SubRelayNodeRsp
SubRelayDataReq // establish data channel through relay
SubRelayDataRsp
)
// ─── Sub types: MsgReport ───
const (
SubReportBasic uint16 = iota // OS, version, MAC, etc.
SubReportApps // running tunnels
SubReportConnect // connection result
)
// ─── NAT types ───
type NATType int
const (
NATNone NATType = 0 // public IP, no NAT
NATCone NATType = 1 // full/restricted/port-restricted cone
NATSymmetric NATType = 2 // symmetric (port changes per dest)
NATUnknown NATType = 314 // detection failed / UDP blocked
)
func (n NATType) String() string {
switch n {
case NATNone:
return "None"
case NATCone:
return "Cone"
case NATSymmetric:
return "Symmetric"
default:
return "Unknown"
}
}
// CanPunch returns true if at least one side is Cone (or has public IP).
func CanPunch(a, b NATType) bool {
return a == NATNone || b == NATNone || a == NATCone || b == NATCone
}
// ─── Header ───
type Header struct {
DataLen uint32
MainType uint16
SubType uint16
}
// ─── Encode / Decode ───
// Encode packs header + JSON payload into a byte slice.
func Encode(mainType, subType uint16, payload interface{}) ([]byte, error) {
var jsonData []byte
if payload != nil {
var err error
jsonData, err = json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("marshal payload: %w", err)
}
}
h := Header{
DataLen: uint32(HeaderSize + len(jsonData)),
MainType: mainType,
SubType: subType,
}
buf := new(bytes.Buffer)
buf.Grow(int(h.DataLen))
if err := binary.Write(buf, binary.LittleEndian, h); err != nil {
return nil, err
}
buf.Write(jsonData)
return buf.Bytes(), nil
}
// DecodeHeader reads the 8-byte header from r.
func DecodeHeader(data []byte) (Header, error) {
if len(data) < HeaderSize {
return Header{}, io.ErrShortBuffer
}
var h Header
err := binary.Read(bytes.NewReader(data[:HeaderSize]), binary.LittleEndian, &h)
return h, err
}
// DecodePayload unmarshals the JSON portion after the header.
func DecodePayload(data []byte, v interface{}) error {
if len(data) <= HeaderSize {
return nil // empty payload is valid
}
return json.Unmarshal(data[HeaderSize:], v)
}
// ─── Common message structs ───
// LoginReq is sent by client on WSS connect.
type LoginReq struct {
Node string `json:"node"`
Token uint64 `json:"token"`
User string `json:"user,omitempty"`
Version string `json:"version"`
NATType NATType `json:"natType"`
ShareBandwidth int `json:"shareBandwidth"`
RelayEnabled bool `json:"relayEnabled"` // --relay flag
SuperRelay bool `json:"superRelay"` // --super flag
PublicIP string `json:"publicIP,omitempty"`
}
type LoginRsp struct {
Error int `json:"error"`
Detail string `json:"detail,omitempty"`
Ts int64 `json:"ts"`
Token uint64 `json:"token"`
User string `json:"user"`
Node string `json:"node"`
}
// ReportBasic is the initial system info report after login.
type ReportBasic struct {
OS string `json:"os"`
Mac string `json:"mac"`
LanIP string `json:"lanIP"`
Version string `json:"version"`
HasIPv4 int `json:"hasIPv4"`
HasUPNPorNATPMP int `json:"hasUPNPorNATPMP"`
IPv6 string `json:"IPv6,omitempty"`
}
type ReportBasicRsp struct {
Error int `json:"error"`
}
// PunchParams carries the information needed for hole-punching.
type PunchParams struct {
IP string `json:"ip"`
Port int `json:"port"`
NATType NATType `json:"natType"`
Token uint64 `json:"token"` // TOTP for auth
IPv6 string `json:"ipv6,omitempty"`
HasIPv4 int `json:"hasIPv4"`
LinkMode string `json:"linkMode"` // "udp" or "tcp"
}
// ConnectReq is pushed by server to coordinate a connection.
type ConnectReq struct {
From string `json:"from"`
To string `json:"to"`
FromIP string `json:"fromIP"`
Peer PunchParams `json:"peer"`
AppName string `json:"appName,omitempty"`
Protocol string `json:"protocol"` // "tcp" or "udp"
SrcPort int `json:"srcPort"`
DstHost string `json:"dstHost"`
DstPort int `json:"dstPort"`
}
type ConnectRsp struct {
Error int `json:"error"`
Detail string `json:"detail,omitempty"`
From string `json:"from"`
To string `json:"to"`
Peer PunchParams `json:"peer,omitempty"`
}
// RelayNodeReq asks the server for a relay node.
type RelayNodeReq struct {
PeerNode string `json:"peerNode"`
}
type RelayNodeRsp struct {
RelayName string `json:"relayName"`
RelayIP string `json:"relayIP"`
RelayPort int `json:"relayPort"`
RelayToken uint64 `json:"relayToken"`
Mode string `json:"mode"` // "private", "super", "server"
Error int `json:"error"`
}
// AppConfig defines a tunnel application.
type AppConfig struct {
AppName string `json:"appName"`
Protocol string `json:"protocol"` // "tcp" or "udp"
SrcPort int `json:"srcPort"`
PeerNode string `json:"peerNode"`
DstHost string `json:"dstHost"`
DstPort int `json:"dstPort"`
Enabled int `json:"enabled"`
RelayNode string `json:"relayNode,omitempty"` // force specific relay
}
// ReportConnect is the connection result reported to server.
type ReportConnect struct {
PeerNode string `json:"peerNode"`
NATType NATType `json:"natType"`
PeerNATType NATType `json:"peerNatType"`
LinkMode string `json:"linkMode"` // "udppunch", "tcppunch", "relay"
Error string `json:"error,omitempty"`
RTT int `json:"rtt,omitempty"` // milliseconds
RelayNode string `json:"relayNode,omitempty"`
Protocol string `json:"protocol,omitempty"`
SrcPort int `json:"srcPort,omitempty"`
DstPort int `json:"dstPort,omitempty"`
DstHost string `json:"dstHost,omitempty"`
Version string `json:"version,omitempty"`
ShareBandwidth int `json:"shareBandWidth,omitempty"`
}

204
pkg/punch/punch.go Normal file
View File

@@ -0,0 +1,204 @@
// Package punch implements UDP and TCP hole-punching.
package punch
import (
"fmt"
"log"
"net"
"time"
"github.com/openp2p-cn/inp2p/pkg/protocol"
)
const (
punchTimeout = 5 * time.Second
punchRetries = 5
handshakeMagic = "INP2P-PUNCH"
handshakeAck = "INP2P-PUNCH-ACK"
)
// Result holds the outcome of a punch attempt.
type Result struct {
Conn net.Conn
Mode string // "udp" or "tcp"
RTT time.Duration
PeerAddr string
Error error
}
// Config for a punch attempt.
type Config struct {
PeerIP string
PeerPort int
PeerNAT protocol.NATType
SelfNAT protocol.NATType
SelfPort int // local port to bind (0 = auto)
IsInitiator bool
}
// AttemptUDP tries to establish a UDP connection via hole-punching.
// Both sides must call this simultaneously (coordinated by server).
func AttemptUDP(cfg Config) Result {
if !protocol.CanPunch(cfg.SelfNAT, cfg.PeerNAT) {
return Result{Error: fmt.Errorf("cannot UDP punch: self=%s peer=%s", cfg.SelfNAT, cfg.PeerNAT)}
}
localAddr := &net.UDPAddr{Port: cfg.SelfPort}
conn, err := net.ListenUDP("udp", localAddr)
if err != nil {
return Result{Error: fmt.Errorf("listen UDP: %w", err)}
}
peerAddr := &net.UDPAddr{
IP: net.ParseIP(cfg.PeerIP),
Port: cfg.PeerPort,
}
start := time.Now()
// Send punch packets
for i := 0; i < punchRetries; i++ {
conn.SetWriteDeadline(time.Now().Add(time.Second))
conn.WriteTo([]byte(handshakeMagic), peerAddr)
time.Sleep(200 * time.Millisecond)
}
// Listen for response
buf := make([]byte, 256)
conn.SetReadDeadline(time.Now().Add(punchTimeout))
n, from, err := conn.ReadFromUDP(buf)
if err != nil {
conn.Close()
return Result{Error: fmt.Errorf("UDP punch timeout: %w", err)}
}
// Verify handshake
msg := string(buf[:n])
if msg != handshakeMagic && msg != handshakeAck {
conn.Close()
return Result{Error: fmt.Errorf("unexpected punch data: %q", msg)}
}
// Send ack
conn.WriteTo([]byte(handshakeAck), from)
rtt := time.Since(start)
log.Printf("[punch] UDP punch ok: peer=%s rtt=%s", from, rtt)
return Result{
Conn: conn,
Mode: "udp",
RTT: rtt,
PeerAddr: from.String(),
}
}
// AttemptTCP tries TCP hole-punching using simultaneous SYN.
// This works by having both sides dial each other at the same time.
func AttemptTCP(cfg Config) Result {
if !protocol.CanPunch(cfg.SelfNAT, cfg.PeerNAT) {
return Result{Error: fmt.Errorf("cannot TCP punch: self=%s peer=%s", cfg.SelfNAT, cfg.PeerNAT)}
}
peerAddr := fmt.Sprintf("%s:%d", cfg.PeerIP, cfg.PeerPort)
start := time.Now()
// TCP simultaneous open: keep trying to dial the peer
var conn net.Conn
var err error
for i := 0; i < punchRetries*2; i++ {
d := net.Dialer{Timeout: time.Second, LocalAddr: &net.TCPAddr{Port: cfg.SelfPort}}
conn, err = d.Dial("tcp", peerAddr)
if err == nil {
break
}
time.Sleep(300 * time.Millisecond)
}
if err != nil {
return Result{Error: fmt.Errorf("TCP punch failed: %w", err)}
}
// TCP handshake for INP2P
conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
conn.Write([]byte(handshakeMagic))
buf := make([]byte, 256)
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
n, err := conn.Read(buf)
if err != nil {
conn.Close()
return Result{Error: fmt.Errorf("TCP handshake read: %w", err)}
}
msg := string(buf[:n])
if msg != handshakeMagic && msg != handshakeAck {
conn.Close()
return Result{Error: fmt.Errorf("TCP unexpected handshake: %q", msg)}
}
conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
conn.Write([]byte(handshakeAck))
rtt := time.Since(start)
log.Printf("[punch] TCP punch ok: peer=%s rtt=%s", conn.RemoteAddr(), rtt)
return Result{
Conn: conn,
Mode: "tcp",
RTT: rtt,
PeerAddr: conn.RemoteAddr().String(),
}
}
// AttemptDirect tries to directly connect when one side has a public IP.
func AttemptDirect(cfg Config) Result {
addr := fmt.Sprintf("%s:%d", cfg.PeerIP, cfg.PeerPort)
start := time.Now()
conn, err := net.DialTimeout("tcp", addr, punchTimeout)
if err != nil {
return Result{Error: fmt.Errorf("direct connect failed: %w", err)}
}
rtt := time.Since(start)
log.Printf("[punch] direct connect ok: peer=%s rtt=%s", addr, rtt)
return Result{
Conn: conn,
Mode: "tcp-direct",
RTT: rtt,
PeerAddr: addr,
}
}
// Connect tries all punch methods in priority order and returns the first success.
func Connect(cfg Config) Result {
methods := []struct {
name string
fn func(Config) Result
}{
{"UDP-punch", AttemptUDP},
{"TCP-punch", AttemptTCP},
}
// If peer has public IP, try direct first
if cfg.PeerNAT == protocol.NATNone {
r := AttemptDirect(cfg)
if r.Error == nil {
return r
}
log.Printf("[punch] direct failed: %v", r.Error)
}
for _, m := range methods {
log.Printf("[punch] trying %s to %s:%d", m.name, cfg.PeerIP, cfg.PeerPort)
r := m.fn(cfg)
if r.Error == nil {
return r
}
log.Printf("[punch] %s failed: %v", m.name, r.Error)
}
return Result{Error: fmt.Errorf("all punch methods exhausted")}
}

415
pkg/relay/relay.go Normal file
View File

@@ -0,0 +1,415 @@
// Package relay implements relay/super-relay node capabilities.
//
// Relay flow:
// 1. Client A asks server for relay (RelayNodeReq)
// 2. Server finds relay R, generates TOTP/token, responds to A (RelayNodeRsp)
// 3. Server pushes RelayOffer to R with session info
// 4. A connects to R:relayPort, sends RelayHandshake{SessionID, Role="from", Token}
// 5. B connects to R:relayPort, sends RelayHandshake{SessionID, Role="to", Token}
// (B gets the session info via server push)
// 6. R verifies both tokens, bridges A↔B
package relay
import (
"encoding/binary"
"encoding/json"
"fmt"
"io"
"log"
"net"
"sync"
"sync/atomic"
"time"
"github.com/openp2p-cn/inp2p/pkg/auth"
)
const (
handshakeTimeout = 10 * time.Second
pairTimeout = 30 * time.Second // how long to wait for the second peer
headerLen = 4 // uint32 LE length prefix for handshake JSON
)
// RelayHandshake is sent by each peer when connecting to a relay node.
type RelayHandshake struct {
SessionID string `json:"sessionID"`
Role string `json:"role"` // "from" or "to"
Token uint64 `json:"token"` // TOTP or one-time token
Node string `json:"node"` // sender's node name
}
// Node represents a relay-capable node's metadata (used by server).
type Node struct {
Name string
IP string
Port int
Token uint64
Mode string // "private" (same user), "super" (shared)
Bandwidth int
LastUsed time.Time
ActiveLoad int32
}
// pendingSession waits for both peers to arrive.
type pendingSession struct {
id string
from string
to string
token uint64
connFrom net.Conn
connTo net.Conn
mu sync.Mutex
done chan struct{}
created time.Time
}
// Manager manages relay sessions on this node.
type Manager struct {
enabled bool
superRelay bool
maxLoad int
token uint64 // this node's auth token
port int
listener net.Listener
pending map[string]*pendingSession // sessionID → pending
pMu sync.Mutex
sessions map[string]*Session // sessionID → active session
sMu sync.RWMutex
quit chan struct{}
}
// Session represents an active relay bridging two peers.
type Session struct {
ID string
From string
To string
ConnA net.Conn
ConnB net.Conn
BytesFwd int64
StartTime time.Time
closed int32
}
// NewManager creates a relay manager.
func NewManager(port int, enabled, superRelay bool, maxLoad int, token uint64) *Manager {
return &Manager{
enabled: enabled,
superRelay: superRelay,
maxLoad: maxLoad,
token: token,
port: port,
pending: make(map[string]*pendingSession),
sessions: make(map[string]*Session),
quit: make(chan struct{}),
}
}
func (m *Manager) IsEnabled() bool { return m.enabled }
func (m *Manager) IsSuperRelay() bool { return m.superRelay }
func (m *Manager) ActiveSessions() int {
m.sMu.RLock()
defer m.sMu.RUnlock()
return len(m.sessions)
}
func (m *Manager) CanAcceptRelay() bool {
return m.enabled && m.ActiveSessions() < m.maxLoad
}
// Start begins listening for relay connections.
func (m *Manager) Start() error {
if !m.enabled {
return nil
}
ln, err := net.Listen("tcp", fmt.Sprintf(":%d", m.port))
if err != nil {
return fmt.Errorf("relay listen :%d: %w", m.port, err)
}
m.listener = ln
log.Printf("[relay] listening on :%d (super=%v, maxLoad=%d)", m.port, m.superRelay, m.maxLoad)
go m.acceptLoop()
go m.cleanupLoop()
return nil
}
func (m *Manager) acceptLoop() {
for {
conn, err := m.listener.Accept()
if err != nil {
select {
case <-m.quit:
return
default:
continue
}
}
go m.handleConn(conn)
}
}
func (m *Manager) handleConn(conn net.Conn) {
// Read handshake with timeout
conn.SetReadDeadline(time.Now().Add(handshakeTimeout))
// Length-prefixed JSON: [4B len][JSON]
var length uint32
if err := binary.Read(conn, binary.LittleEndian, &length); err != nil {
log.Printf("[relay] handshake read len: %v", err)
conn.Close()
return
}
if length > 4096 {
log.Printf("[relay] handshake too large: %d", length)
conn.Close()
return
}
buf := make([]byte, length)
if _, err := io.ReadFull(conn, buf); err != nil {
log.Printf("[relay] handshake read body: %v", err)
conn.Close()
return
}
conn.SetReadDeadline(time.Time{}) // clear deadline
var hs RelayHandshake
if err := json.Unmarshal(buf, &hs); err != nil {
log.Printf("[relay] handshake parse: %v", err)
conn.Close()
return
}
// Verify TOTP
if !auth.VerifyTOTP(hs.Token, m.token, time.Now().Unix()) {
log.Printf("[relay] handshake denied: %s (TOTP mismatch)", hs.Node)
sendRelayResult(conn, 1, "auth failed")
conn.Close()
return
}
log.Printf("[relay] handshake ok: session=%s role=%s node=%s", hs.SessionID, hs.Role, hs.Node)
// Find or create pending session
m.pMu.Lock()
ps, exists := m.pending[hs.SessionID]
if !exists {
ps = &pendingSession{
id: hs.SessionID,
token: hs.Token,
done: make(chan struct{}),
created: time.Now(),
}
m.pending[hs.SessionID] = ps
}
m.pMu.Unlock()
ps.mu.Lock()
switch hs.Role {
case "from":
ps.from = hs.Node
ps.connFrom = conn
case "to":
ps.to = hs.Node
ps.connTo = conn
default:
ps.mu.Unlock()
log.Printf("[relay] unknown role: %s", hs.Role)
conn.Close()
return
}
// Check if both peers have arrived
bothReady := ps.connFrom != nil && ps.connTo != nil
ps.mu.Unlock()
if bothReady {
// Both peers connected — bridge them
m.pMu.Lock()
delete(m.pending, hs.SessionID)
m.pMu.Unlock()
sendRelayResult(ps.connFrom, 0, "ok")
sendRelayResult(ps.connTo, 0, "ok")
m.bridge(ps)
} else {
// Wait for the other peer
select {
case <-ps.done:
// Woken up by the other peer's arrival
case <-time.After(pairTimeout):
log.Printf("[relay] session %s timeout waiting for pair", hs.SessionID)
m.pMu.Lock()
delete(m.pending, hs.SessionID)
m.pMu.Unlock()
sendRelayResult(conn, 1, "pair timeout")
conn.Close()
case <-m.quit:
conn.Close()
}
}
}
// relayResult is sent back to each peer after handshake.
type relayResult struct {
Error int `json:"error"`
Detail string `json:"detail,omitempty"`
}
func sendRelayResult(conn net.Conn, errCode int, detail string) {
data, _ := json.Marshal(relayResult{Error: errCode, Detail: detail})
length := uint32(len(data))
conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
binary.Write(conn, binary.LittleEndian, length)
conn.Write(data)
conn.SetWriteDeadline(time.Time{})
}
func (m *Manager) bridge(ps *pendingSession) {
sess := &Session{
ID: ps.id,
From: ps.from,
To: ps.to,
ConnA: ps.connFrom,
ConnB: ps.connTo,
StartTime: time.Now(),
}
m.sMu.Lock()
m.sessions[ps.id] = sess
m.sMu.Unlock()
log.Printf("[relay] bridging %s ↔ %s (session %s)", ps.from, ps.to, ps.id)
go func() {
defer func() {
sess.Close()
m.sMu.Lock()
delete(m.sessions, ps.id)
m.sMu.Unlock()
log.Printf("[relay] session %s ended, %d bytes forwarded", ps.id, atomic.LoadInt64(&sess.BytesFwd))
}()
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
n, _ := io.Copy(sess.ConnB, sess.ConnA)
atomic.AddInt64(&sess.BytesFwd, n)
}()
go func() {
defer wg.Done()
n, _ := io.Copy(sess.ConnA, sess.ConnB)
atomic.AddInt64(&sess.BytesFwd, n)
}()
wg.Wait()
}()
}
func (m *Manager) cleanupLoop() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
m.pMu.Lock()
for id, ps := range m.pending {
if time.Since(ps.created) > pairTimeout {
delete(m.pending, id)
if ps.connFrom != nil {
ps.connFrom.Close()
}
if ps.connTo != nil {
ps.connTo.Close()
}
}
}
m.pMu.Unlock()
case <-m.quit:
return
}
}
}
// Close shuts down a session.
func (s *Session) Close() {
if !atomic.CompareAndSwapInt32(&s.closed, 0, 1) {
return
}
if s.ConnA != nil {
s.ConnA.Close()
}
if s.ConnB != nil {
s.ConnB.Close()
}
}
// Stop shuts down the relay manager.
func (m *Manager) Stop() {
close(m.quit)
if m.listener != nil {
m.listener.Close()
}
m.sMu.Lock()
for _, s := range m.sessions {
s.Close()
}
m.sMu.Unlock()
}
// ─── Client-side helper ───
// ConnectToRelay connects to a relay node and performs the handshake.
func ConnectToRelay(relayAddr string, sessionID, role, node string, token uint64) (net.Conn, error) {
conn, err := net.DialTimeout("tcp", relayAddr, 10*time.Second)
if err != nil {
return nil, fmt.Errorf("dial relay %s: %w", relayAddr, err)
}
hs := RelayHandshake{
SessionID: sessionID,
Role: role,
Token: token,
Node: node,
}
data, _ := json.Marshal(hs)
conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
length := uint32(len(data))
if err := binary.Write(conn, binary.LittleEndian, length); err != nil {
conn.Close()
return nil, err
}
if _, err := conn.Write(data); err != nil {
conn.Close()
return nil, err
}
// Read result
conn.SetReadDeadline(time.Now().Add(pairTimeout + 5*time.Second))
if err := binary.Read(conn, binary.LittleEndian, &length); err != nil {
conn.Close()
return nil, fmt.Errorf("read relay result: %w", err)
}
buf := make([]byte, length)
if _, err := io.ReadFull(conn, buf); err != nil {
conn.Close()
return nil, fmt.Errorf("read relay result body: %w", err)
}
conn.SetReadDeadline(time.Time{})
var result relayResult
json.Unmarshal(buf, &result)
if result.Error != 0 {
conn.Close()
return nil, fmt.Errorf("relay denied: %s", result.Detail)
}
log.Printf("[relay] connected to relay %s, session=%s role=%s", relayAddr, sessionID, role)
return conn, nil
}

189
pkg/relay/relay_test.go Normal file
View File

@@ -0,0 +1,189 @@
package relay
import (
"fmt"
"net"
"sync"
"testing"
"time"
"github.com/openp2p-cn/inp2p/pkg/auth"
)
func TestRelayBridge(t *testing.T) {
token := auth.MakeToken("test", "pass")
mgr := NewManager(29700, true, false, 10, token)
if err := mgr.Start(); err != nil {
t.Fatal(err)
}
defer mgr.Stop()
sessionID := "test-session-1"
totp := auth.GenTOTP(token, time.Now().Unix())
var wg sync.WaitGroup
var connA, connB net.Conn
var errA, errB error
// Peer A connects as "from"
wg.Add(1)
go func() {
defer wg.Done()
connA, errA = ConnectToRelay(
fmt.Sprintf("127.0.0.1:%d", 29700),
sessionID, "from", "nodeA", totp,
)
}()
// Peer B connects as "to" after a short delay
wg.Add(1)
go func() {
defer wg.Done()
time.Sleep(200 * time.Millisecond)
connB, errB = ConnectToRelay(
fmt.Sprintf("127.0.0.1:%d", 29700),
sessionID, "to", "nodeB", totp,
)
}()
wg.Wait()
if errA != nil {
t.Fatalf("connA error: %v", errA)
}
if errB != nil {
t.Fatalf("connB error: %v", errB)
}
defer connA.Close()
defer connB.Close()
// Test data flow: A → B
msg := []byte("hello through relay")
connA.Write(msg)
buf := make([]byte, 256)
connB.SetReadDeadline(time.Now().Add(3 * time.Second))
n, err := connB.Read(buf)
if err != nil {
t.Fatalf("read from B: %v", err)
}
if string(buf[:n]) != string(msg) {
t.Errorf("got %q, want %q", buf[:n], msg)
}
// Test data flow: B → A
reply := []byte("relay pong")
connB.Write(reply)
connA.SetReadDeadline(time.Now().Add(3 * time.Second))
n, err = connA.Read(buf)
if err != nil {
t.Fatalf("read from A: %v", err)
}
if string(buf[:n]) != string(reply) {
t.Errorf("got %q, want %q", buf[:n], reply)
}
// Verify session count
if mgr.ActiveSessions() != 1 {
t.Errorf("active sessions: got %d want 1", mgr.ActiveSessions())
}
t.Logf("✅ Relay bridge OK: A↔B bidirectional, %d active sessions", mgr.ActiveSessions())
}
func TestRelayLargeData(t *testing.T) {
token := auth.MakeToken("test", "pass")
mgr := NewManager(29701, true, false, 10, token)
if err := mgr.Start(); err != nil {
t.Fatal(err)
}
defer mgr.Stop()
sessionID := "test-large-data"
totp := auth.GenTOTP(token, time.Now().Unix())
var wg sync.WaitGroup
var connA, connB net.Conn
wg.Add(2)
go func() {
defer wg.Done()
var err error
connA, err = ConnectToRelay("127.0.0.1:29701", sessionID, "from", "bigA", totp)
if err != nil {
t.Errorf("connA: %v", err)
}
}()
go func() {
defer wg.Done()
time.Sleep(100 * time.Millisecond)
var err error
connB, err = ConnectToRelay("127.0.0.1:29701", sessionID, "to", "bigB", totp)
if err != nil {
t.Errorf("connB: %v", err)
}
}()
wg.Wait()
if connA == nil || connB == nil {
t.Fatal("connection failed")
}
defer connA.Close()
defer connB.Close()
// Send 1MB through relay
const dataSize = 1024 * 1024
data := make([]byte, dataSize)
for i := range data {
data[i] = byte(i % 256)
}
wg.Add(1)
go func() {
defer wg.Done()
connA.Write(data)
}()
// Read exact amount on B side
received := make([]byte, dataSize)
total := 0
connB.SetReadDeadline(time.Now().Add(10 * time.Second))
for total < dataSize {
n, err := connB.Read(received[total:])
if err != nil {
t.Fatalf("read at %d: %v", total, err)
}
total += n
}
wg.Wait()
if len(received) != len(data) {
t.Fatalf("size mismatch: got %d want %d", len(received), len(data))
}
for i := 0; i < len(data); i++ {
if received[i] != data[i] {
t.Fatalf("data mismatch at byte %d", i)
break
}
}
t.Logf("✅ 1MB relay transfer OK")
}
func TestRelayAuthDenied(t *testing.T) {
token := auth.MakeToken("real", "token")
mgr := NewManager(29702, true, false, 10, token)
if err := mgr.Start(); err != nil {
t.Fatal(err)
}
defer mgr.Stop()
// Use wrong TOTP
wrongToken := auth.GenTOTP(auth.MakeToken("wrong", "creds"), time.Now().Unix())
_, err := ConnectToRelay("127.0.0.1:29702", "bad-session", "from", "badNode", wrongToken)
if err == nil {
t.Fatal("expected auth denied, got success")
}
t.Logf("✅ Auth denied correctly: %v", err)
}

180
pkg/signal/conn.go Normal file
View File

@@ -0,0 +1,180 @@
// Package signal provides the WSS signaling connection between client and server.
package signal
import (
"encoding/json"
"fmt"
"log"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/openp2p-cn/inp2p/pkg/protocol"
)
// Conn wraps a WebSocket connection with message framing.
type Conn struct {
ws *websocket.Conn
writeMu sync.Mutex
handlers map[msgKey]Handler
hMu sync.RWMutex
quit chan struct{}
once sync.Once
Node string
Token uint64
// waiters for synchronous request-response
waiters map[msgKey]chan []byte
wMu sync.Mutex
}
type msgKey struct {
main uint16
sub uint16
}
// Handler processes an incoming message. data includes header + payload.
type Handler func(data []byte) error
// NewConn wraps an existing websocket.
func NewConn(ws *websocket.Conn) *Conn {
return &Conn{
ws: ws,
handlers: make(map[msgKey]Handler),
waiters: make(map[msgKey]chan []byte),
quit: make(chan struct{}),
}
}
// OnMessage registers a handler for a specific (MainType, SubType).
func (c *Conn) OnMessage(mainType, subType uint16, h Handler) {
c.hMu.Lock()
c.handlers[msgKey{mainType, subType}] = h
c.hMu.Unlock()
}
// Write sends a message with the given type and JSON payload.
func (c *Conn) Write(mainType, subType uint16, payload interface{}) error {
frame, err := protocol.Encode(mainType, subType, payload)
if err != nil {
return err
}
return c.WriteRaw(frame)
}
// WriteRaw sends raw bytes.
func (c *Conn) WriteRaw(data []byte) error {
c.writeMu.Lock()
defer c.writeMu.Unlock()
c.ws.SetWriteDeadline(time.Now().Add(10 * time.Second))
return c.ws.WriteMessage(websocket.BinaryMessage, data)
}
// Request sends a message and waits for a specific response type.
func (c *Conn) Request(mainType, subType uint16, payload interface{},
rspMain, rspSub uint16, timeout time.Duration) ([]byte, error) {
ch := make(chan []byte, 1)
key := msgKey{rspMain, rspSub}
c.wMu.Lock()
c.waiters[key] = ch
c.wMu.Unlock()
defer func() {
c.wMu.Lock()
delete(c.waiters, key)
c.wMu.Unlock()
}()
if err := c.Write(mainType, subType, payload); err != nil {
return nil, err
}
select {
case data := <-ch:
return data, nil
case <-time.After(timeout):
return nil, fmt.Errorf("request timeout %d:%d → %d:%d", mainType, subType, rspMain, rspSub)
case <-c.quit:
return nil, fmt.Errorf("connection closed")
}
}
// ReadLoop reads messages and dispatches to handlers. Blocks until error or Close().
func (c *Conn) ReadLoop() error {
for {
_, msg, err := c.ws.ReadMessage()
if err != nil {
select {
case <-c.quit:
return nil
default:
return err
}
}
if len(msg) < protocol.HeaderSize {
continue
}
h, err := protocol.DecodeHeader(msg)
if err != nil {
continue
}
key := msgKey{h.MainType, h.SubType}
// Check waiters first (synchronous request-response)
c.wMu.Lock()
if ch, ok := c.waiters[key]; ok {
delete(c.waiters, key)
c.wMu.Unlock()
select {
case ch <- msg:
default:
}
continue
}
c.wMu.Unlock()
// Dispatch to registered handler
c.hMu.RLock()
handler, ok := c.handlers[key]
c.hMu.RUnlock()
if ok {
if err := handler(msg); err != nil {
log.Printf("[signal] handler %d:%d error: %v", h.MainType, h.SubType, err)
}
}
}
}
// Close gracefully shuts down the connection.
func (c *Conn) Close() {
c.once.Do(func() {
close(c.quit)
c.ws.Close()
})
}
// IsClosed reports whether the connection has been closed.
func (c *Conn) IsClosed() bool {
select {
case <-c.quit:
return true
default:
return false
}
}
// ─── Helpers ───
// ParsePayload is a convenience to unmarshal JSON from a raw message.
func ParsePayload[T any](data []byte) (T, error) {
var v T
if len(data) <= protocol.HeaderSize {
return v, nil
}
err := json.Unmarshal(data[protocol.HeaderSize:], &v)
return v, err
}

233
pkg/tunnel/tunnel.go Normal file
View File

@@ -0,0 +1,233 @@
// Package tunnel provides P2P tunnel with mux-based port forwarding.
package tunnel
import (
"fmt"
"io"
"log"
"net"
"sync"
"sync/atomic"
"time"
"github.com/openp2p-cn/inp2p/pkg/mux"
)
// Tunnel represents a P2P tunnel that multiplexes port forwards over one connection.
type Tunnel struct {
PeerNode string
PeerIP string
LinkMode string // "udppunch", "tcppunch", "relay", "direct"
RTT time.Duration
sess *mux.Session
listeners map[int]*forwarder // srcPort → forwarder
mu sync.Mutex
closed int32
stats Stats
}
type forwarder struct {
listener net.Listener
dstHost string
dstPort int
quit chan struct{}
}
// Stats tracks tunnel traffic.
type Stats struct {
BytesSent int64
BytesReceived int64
Connections int64
ActiveStreams int32
}
// New creates a tunnel from an established P2P connection.
// isInitiator: the side that opened the P2P connection is the mux client.
func New(peerNode string, conn net.Conn, linkMode string, rtt time.Duration, isInitiator bool) *Tunnel {
return &Tunnel{
PeerNode: peerNode,
PeerIP: conn.RemoteAddr().String(),
LinkMode: linkMode,
RTT: rtt,
sess: mux.NewSession(conn, !isInitiator), // initiator=client, responder=server
listeners: make(map[int]*forwarder),
}
}
// ListenAndForward starts a local listener that forwards connections through the tunnel.
// Each accepted connection opens a mux stream to the peer, which connects to dstHost:dstPort.
func (t *Tunnel) ListenAndForward(protocol string, srcPort int, dstHost string, dstPort int) error {
addr := fmt.Sprintf(":%d", srcPort)
ln, err := net.Listen(protocol, addr)
if err != nil {
return fmt.Errorf("listen %s %s: %w", protocol, addr, err)
}
fwd := &forwarder{
listener: ln,
dstHost: dstHost,
dstPort: dstPort,
quit: make(chan struct{}),
}
t.mu.Lock()
t.listeners[srcPort] = fwd
t.mu.Unlock()
log.Printf("[tunnel] LISTEN %s:%d → %s(%s:%d) via %s", protocol, srcPort, t.PeerNode, dstHost, dstPort, t.LinkMode)
go t.acceptLoop(fwd)
return nil
}
func (t *Tunnel) acceptLoop(fwd *forwarder) {
for {
conn, err := fwd.listener.Accept()
if err != nil {
select {
case <-fwd.quit:
return
default:
if atomic.LoadInt32(&t.closed) == 1 {
return
}
log.Printf("[tunnel] accept error: %v", err)
continue
}
}
atomic.AddInt64(&t.stats.Connections, 1)
go t.handleLocalConn(conn, fwd.dstHost, fwd.dstPort)
}
}
func (t *Tunnel) handleLocalConn(local net.Conn, dstHost string, dstPort int) {
defer local.Close()
// Open a mux stream
stream, err := t.sess.Open()
if err != nil {
log.Printf("[tunnel] mux open error: %v", err)
return
}
defer stream.Close()
atomic.AddInt32(&t.stats.ActiveStreams, 1)
defer atomic.AddInt32(&t.stats.ActiveStreams, -1)
// Send destination info as first message on the stream
// Format: "host:port\n"
header := fmt.Sprintf("%s:%d\n", dstHost, dstPort)
if _, err := stream.Write([]byte(header)); err != nil {
log.Printf("[tunnel] stream write header: %v", err)
return
}
// Bidirectional copy
t.bridge(local, stream)
}
// AcceptAndConnect handles incoming mux streams (called on the responder side).
// It reads the destination header and connects to the local target.
func (t *Tunnel) AcceptAndConnect() {
for {
stream, err := t.sess.Accept()
if err != nil {
if !t.sess.IsClosed() {
log.Printf("[tunnel] mux accept error: %v", err)
}
return
}
go t.handleRemoteStream(stream)
}
}
func (t *Tunnel) handleRemoteStream(stream *mux.Stream) {
defer stream.Close()
atomic.AddInt32(&t.stats.ActiveStreams, 1)
defer atomic.AddInt32(&t.stats.ActiveStreams, -1)
// Read destination header: "host:port\n"
buf := make([]byte, 256)
n := 0
for n < len(buf) {
nn, err := stream.Read(buf[n : n+1])
if err != nil {
log.Printf("[tunnel] read dest header: %v", err)
return
}
n += nn
if buf[n-1] == '\n' {
break
}
}
dest := string(buf[:n-1]) // trim \n
// Connect to local destination
conn, err := net.DialTimeout("tcp", dest, 5*time.Second)
if err != nil {
log.Printf("[tunnel] connect to %s failed: %v", dest, err)
return
}
defer conn.Close()
log.Printf("[tunnel] stream → %s connected", dest)
// Bidirectional copy
t.bridge(conn, stream)
}
func (t *Tunnel) bridge(a, b io.ReadWriter) {
var wg sync.WaitGroup
wg.Add(2)
copyAndCount := func(dst io.Writer, src io.Reader, counter *int64) {
defer wg.Done()
n, _ := io.Copy(dst, src)
atomic.AddInt64(counter, n)
}
go copyAndCount(a, b, &t.stats.BytesReceived)
go copyAndCount(b, a, &t.stats.BytesSent)
wg.Wait()
}
// Close shuts down the tunnel and all listeners.
func (t *Tunnel) Close() {
if !atomic.CompareAndSwapInt32(&t.closed, 0, 1) {
return
}
t.mu.Lock()
for port, fwd := range t.listeners {
close(fwd.quit)
fwd.listener.Close()
log.Printf("[tunnel] stopped :%d", port)
}
t.mu.Unlock()
t.sess.Close()
log.Printf("[tunnel] closed → %s", t.PeerNode)
}
// GetStats returns traffic statistics.
func (t *Tunnel) GetStats() Stats {
return Stats{
BytesSent: atomic.LoadInt64(&t.stats.BytesSent),
BytesReceived: atomic.LoadInt64(&t.stats.BytesReceived),
Connections: atomic.LoadInt64(&t.stats.Connections),
ActiveStreams: atomic.LoadInt32(&t.stats.ActiveStreams),
}
}
// IsAlive returns true if the tunnel is open.
func (t *Tunnel) IsAlive() bool {
return atomic.LoadInt32(&t.closed) == 0 && !t.sess.IsClosed()
}
// NumStreams returns active mux streams.
func (t *Tunnel) NumStreams() int {
return t.sess.NumStreams()
}

176
pkg/tunnel/tunnel_test.go Normal file
View File

@@ -0,0 +1,176 @@
package tunnel
import (
"fmt"
"io"
"net"
"testing"
"time"
)
func TestEndToEndForward(t *testing.T) {
// 1. Start a "target" TCP server (simulates SSH on the remote side)
targetLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer targetLn.Close()
targetPort := targetLn.Addr().(*net.TCPAddr).Port
go func() {
for {
conn, err := targetLn.Accept()
if err != nil {
return
}
go func(c net.Conn) {
defer c.Close()
buf := make([]byte, 1024)
n, _ := c.Read(buf)
c.Write([]byte("ECHO:" + string(buf[:n])))
}(conn)
}
}()
// 2. Create a connected pair (simulates a P2P punch connection)
c1, c2 := net.Pipe()
// 3. Create tunnels on both sides
initiator := New("remote-node", c1, "test", 0, true)
responder := New("local-node", c2, "test", 0, false)
defer initiator.Close()
defer responder.Close()
// Responder accepts incoming mux streams and connects to local targets
go responder.AcceptAndConnect()
// 4. Initiator listens on a local port and forwards to remote target
localLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
localPort := localLn.Addr().(*net.TCPAddr).Port
localLn.Close() // free the port so tunnel can use it
err = initiator.ListenAndForward("tcp", localPort, "127.0.0.1", targetPort)
if err != nil {
t.Fatal(err)
}
time.Sleep(50 * time.Millisecond)
// 5. Connect to the tunnel's local port
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", localPort))
if err != nil {
t.Fatal(err)
}
defer conn.Close()
// 6. Send data and verify echo
conn.Write([]byte("hello-tunnel"))
conn.SetReadDeadline(time.Now().Add(3 * time.Second))
buf := make([]byte, 1024)
n, err := conn.Read(buf)
if err != nil {
t.Fatal(err)
}
got := string(buf[:n])
want := "ECHO:hello-tunnel"
if got != want {
t.Errorf("got %q, want %q", got, want)
}
}
func TestMultipleConnections(t *testing.T) {
// Target server: echoes back with a prefix
targetLn, _ := net.Listen("tcp", "127.0.0.1:0")
defer targetLn.Close()
targetPort := targetLn.Addr().(*net.TCPAddr).Port
go func() {
for {
conn, err := targetLn.Accept()
if err != nil {
return
}
go func(c net.Conn) {
defer c.Close()
io.Copy(c, c) // pure echo
}(conn)
}
}()
c1, c2 := net.Pipe()
initiator := New("peer", c1, "test", 0, true)
responder := New("me", c2, "test", 0, false)
defer initiator.Close()
defer responder.Close()
go responder.AcceptAndConnect()
localLn, _ := net.Listen("tcp", "127.0.0.1:0")
localPort := localLn.Addr().(*net.TCPAddr).Port
localLn.Close()
initiator.ListenAndForward("tcp", localPort, "127.0.0.1", targetPort)
time.Sleep(50 * time.Millisecond)
// Open 5 concurrent connections through the tunnel
const N = 5
done := make(chan bool, N)
for i := 0; i < N; i++ {
go func(idx int) {
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", localPort))
if err != nil {
t.Errorf("conn %d: dial: %v", idx, err)
done <- false
return
}
defer conn.Close()
msg := fmt.Sprintf("msg-%d", idx)
conn.Write([]byte(msg))
conn.SetReadDeadline(time.Now().Add(3 * time.Second))
buf := make([]byte, 256)
n, err := conn.Read(buf)
if err != nil || string(buf[:n]) != msg {
t.Errorf("conn %d: got %q, want %q, err=%v", idx, buf[:n], msg, err)
done <- false
return
}
done <- true
}(i)
}
for i := 0; i < N; i++ {
if ok := <-done; !ok {
t.Errorf("connection %d failed", i)
}
}
stats := initiator.GetStats()
if stats.Connections != N {
t.Errorf("connections: got %d want %d", stats.Connections, N)
}
}
func TestTunnelStats(t *testing.T) {
c1, c2 := net.Pipe()
initiator := New("peer", c1, "test", 0, true)
responder := New("me", c2, "test", 0, false)
defer initiator.Close()
defer responder.Close()
if !initiator.IsAlive() || !responder.IsAlive() {
t.Error("tunnels should be alive")
}
initiator.Close()
time.Sleep(50 * time.Millisecond)
if initiator.IsAlive() {
t.Error("initiator should be dead after close")
}
}