package mux import ( "bytes" "io" "net" "sync" "testing" "time" ) // pipe creates a connected pair of net.Conn using net.Pipe. func pipe() (net.Conn, net.Conn) { return net.Pipe() } func TestSessionOpenAccept(t *testing.T) { c1, c2 := pipe() defer c1.Close() defer c2.Close() client := NewSession(c1, false) server := NewSession(c2, true) defer client.Close() defer server.Close() // Client opens a stream st1, err := client.Open() if err != nil { t.Fatal(err) } // Server accepts st2, err := server.Accept() if err != nil { t.Fatal(err) } // Verify stream IDs: client=odd, server would be even if st1.id%2 != 1 { t.Errorf("client stream ID should be odd, got %d", st1.id) } _ = st2 // server accepted stream has client's ID } func TestStreamReadWrite(t *testing.T) { c1, c2 := pipe() client := NewSession(c1, false) server := NewSession(c2, true) defer client.Close() defer server.Close() st1, _ := client.Open() st2, _ := server.Accept() msg := []byte("hello from client to server via mux") // Write from client n, err := st1.Write(msg) if err != nil || n != len(msg) { t.Fatalf("write: n=%d err=%v", n, err) } // Read on server buf := make([]byte, 1024) n, err = st2.Read(buf) if err != nil || n != len(msg) { t.Fatalf("read: n=%d err=%v", n, err) } if !bytes.Equal(buf[:n], msg) { t.Fatalf("data mismatch: got %q want %q", buf[:n], msg) } // Bidirectional: server → client reply := []byte("pong") st2.Write(reply) n, _ = st1.Read(buf) if !bytes.Equal(buf[:n], reply) { t.Fatalf("reply mismatch: got %q want %q", buf[:n], reply) } } func TestMultipleStreams(t *testing.T) { c1, c2 := pipe() client := NewSession(c1, false) server := NewSession(c2, true) defer client.Close() defer server.Close() const numStreams = 10 var wg sync.WaitGroup // Client opens N streams concurrently wg.Add(numStreams) for i := 0; i < numStreams; i++ { go func(idx int) { defer wg.Done() st, err := client.Open() if err != nil { t.Errorf("open stream %d: %v", idx, err) return } msg := []byte("stream-data") st.Write(msg) }(i) } // Server accepts N streams for i := 0; i < numStreams; i++ { st, err := server.Accept() if err != nil { t.Fatalf("accept stream %d: %v", i, err) } buf := make([]byte, 64) n, _ := st.Read(buf) if string(buf[:n]) != "stream-data" { t.Errorf("stream %d data mismatch", i) } } wg.Wait() if client.NumStreams() != numStreams { t.Errorf("client streams: got %d want %d", client.NumStreams(), numStreams) } } func TestStreamClose(t *testing.T) { c1, c2 := pipe() client := NewSession(c1, false) server := NewSession(c2, true) defer client.Close() defer server.Close() st1, _ := client.Open() st2, _ := server.Accept() // Write then close st1.Write([]byte("before-close")) st1.Close() // Server should read data then get EOF buf := make([]byte, 64) n, _ := st2.Read(buf) if string(buf[:n]) != "before-close" { t.Errorf("unexpected data: %q", buf[:n]) } // Next read should eventually get EOF (FIN received) time.Sleep(50 * time.Millisecond) _, err := st2.Read(buf) if err != io.EOF { t.Errorf("expected EOF after FIN, got %v", err) } } func TestLargePayload(t *testing.T) { c1, c2 := pipe() client := NewSession(c1, false) server := NewSession(c2, true) defer client.Close() defer server.Close() st1, _ := client.Open() st2, _ := server.Accept() // Write 200KB — larger than maxPayload (65535), should auto-split data := make([]byte, 200*1024) for i := range data { data[i] = byte(i % 256) } var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() n, err := st1.Write(data) if err != nil { t.Errorf("write large: %v", err) } if n != len(data) { t.Errorf("write large: n=%d want %d", n, len(data)) } }() // Read all on server received := make([]byte, 0, len(data)) buf := make([]byte, 32*1024) for len(received) < len(data) { n, err := st2.Read(buf) if err != nil { t.Fatalf("read at %d: %v", len(received), err) } received = append(received, buf[:n]...) } wg.Wait() if !bytes.Equal(received, data) { t.Error("large payload data mismatch") } } func TestSessionClose(t *testing.T) { c1, c2 := pipe() client := NewSession(c1, false) server := NewSession(c2, true) st1, _ := client.Open() server.Accept() // Close session client.Close() // Stream operations should fail _, err := st1.Write([]byte("x")) if err == nil { t.Error("write after session close should fail") } // Server accept should fail time.Sleep(50 * time.Millisecond) server.Close() } func TestPingPong(t *testing.T) { c1, c2 := pipe() client := NewSession(c1, false) server := NewSession(c2, true) defer client.Close() defer server.Close() // Just verify it doesn't crash — ping/pong runs in background time.Sleep(100 * time.Millisecond) if client.IsClosed() || server.IsClosed() { t.Error("sessions should still be alive") } } func BenchmarkThroughput(b *testing.B) { c1, c2 := pipe() client := NewSession(c1, false) server := NewSession(c2, true) defer client.Close() defer server.Close() st1, _ := client.Open() st2, _ := server.Accept() data := make([]byte, 4096) buf := make([]byte, 4096) b.SetBytes(int64(len(data))) b.ResetTimer() go func() { for i := 0; i < b.N; i++ { st2.Read(buf) } }() for i := 0; i < b.N; i++ { st1.Write(data) } }