// caddy-websocket-proxy/telnetproxy_test.go
// 2025-01-07 22:23:21 +00:00
//
// 331 lines
// 8.4 KiB
// Go

package telnetproxy
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync"
"testing"
"time"
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
"github.com/gorilla/websocket"
)
// mockTelnetServer is a small TCP echo server for tests. It listens on a
// loopback port picked by the OS; data read from a client is appended to data
// (under mu) and written straight back to that same client, so a test can
// inspect both directions of a proxied stream.
type mockTelnetServer struct {
	listener net.Listener  // accepts client connections
	port     int           // TCP port the listener is bound to
	close    chan struct{} // closed by closeServer to signal shutdown

	mu   sync.Mutex   // guards data
	data bytes.Buffer // everything recorded from all client connections

	t *testing.T // used for logging from background goroutines
}
// newMockTelnetServer starts a mockTelnetServer on an OS-assigned loopback
// port and begins accepting connections on a background goroutine. The test
// aborts via t.Fatalf if the listener cannot be created. The caller must call
// closeServer when finished.
func newMockTelnetServer(t *testing.T) *mockTelnetServer {
	t.Helper()
	l, err := net.Listen("tcp", "127.0.0.1:0") // Let the OS choose a port
	if err != nil {
		t.Fatalf("Failed to create listener: %v", err)
	}
	// A "tcp" listener reports its address as *net.TCPAddr, which already
	// carries the port as an int; no need to split and re-parse the string.
	addr, ok := l.Addr().(*net.TCPAddr)
	if !ok {
		l.Close() // don't leak the socket when bailing out
		t.Fatalf("Failed to get port: unexpected listener address type %T", l.Addr())
	}
	s := &mockTelnetServer{
		listener: l,
		port:     addr.Port,
		close:    make(chan struct{}),
		t:        t,
	}
	go s.serve()
	return s
}
// serve accepts connections until the listener is closed and hands each one to
// handleConnection on its own goroutine.
//
// Accept blocks, so checking s.close before each call cannot interrupt a
// pending Accept. Shutdown really happens when closeServer closes the listener
// and Accept fails. closeServer closes s.close before the listener, so s.close
// is consulted only after a failure. That separates the expected shutdown
// error from an unexpected one, which is logged rather than silently dropped.
func (s *mockTelnetServer) serve() {
	for {
		conn, err := s.listener.Accept()
		if err != nil {
			select {
			case <-s.close:
				// Listener closed by closeServer: normal shutdown.
			default:
				s.t.Logf("mockTelnetServer accept error: %v", err)
			}
			return
		}
		go s.handleConnection(conn)
	}
}
// handleConnection records everything read from conn in s.data and echoes it
// back to conn, processing input one '\n'-delimited line at a time. If the
// peer stops mid-line, the trailing bytes are still recorded and echoed before
// the handler returns. Read errors other than io.EOF are logged through s.t.
// The connection is always closed on return.
func (s *mockTelnetServer) handleConnection(conn net.Conn) {
	defer conn.Close()
	r := bufio.NewReader(conn)
	for {
		// ReadBytes can return a partial line together with a non-nil error
		// (e.g. io.EOF when the peer closes without a final newline); keep it.
		data, err := r.ReadBytes('\n')
		if len(data) > 0 {
			s.mu.Lock()
			s.data.Write(data)
			s.mu.Unlock()
			// Echo back. A write failure ends the connection. If the read
			// already failed, fall through so that error is still reported.
			if _, werr := conn.Write(data); werr != nil && err == nil {
				return
			}
		}
		if err != nil {
			if err != io.EOF {
				s.t.Logf("mockTelnetServer read error: %v", err)
			}
			return
		}
	}
}
// getPort reports the loopback TCP port the server listens on, as captured
// when the server was created.
func (s *mockTelnetServer) getPort() (port int) {
	port = s.port
	return port
}
// getRecievedData returns a snapshot of everything the server has recorded so
// far. The misspelled name is kept for existing callers.
//
// It returns a copy because bytes.Buffer.Bytes aliases the buffer's internal
// storage. Handing that slice out would let connection handlers and resetData
// overwrite it after the lock is released, which is a data race.
func (s *mockTelnetServer) getRecievedData() []byte {
	s.mu.Lock()
	defer s.mu.Unlock()
	return append([]byte(nil), s.data.Bytes()...)
}
// closeServer stops the server. It first closes the shutdown channel that
// serve consults, then closes the listener, which unblocks any pending Accept.
// Connections that were already accepted are not closed here. Call it at most
// once: a second call panics on the double channel close.
func (s *mockTelnetServer) closeServer() {
	shutdown := s.close
	close(shutdown)
	_ = s.listener.Close() // the close error carries no useful signal for a test fixture
}
// resetData discards everything recorded so far, so a later check only sees
// new traffic. The buffer keeps its allocated storage for reuse.
func (s *mockTelnetServer) resetData() {
	s.mu.Lock()
	s.data.Reset()
	s.mu.Unlock()
}
func TestTelnetProxyCaddyfile(t *testing.T) {
tests := []struct {
name string
input string
expectError bool
secret string
}{
{
name: "valid config",
input: `telnet_proxy testsecret`,
secret: "testsecret",
},
{
name: "missing secret",
input: `telnet_proxy`,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
d := caddyfile.NewTestDispenser(tt.input)
tp := new(TelnetProxy)
err := tp.UnmarshalCaddyfile(d)
if tt.expectError {
if err == nil {
t.Fatal("expected error, got nil")
}
} else if err != nil {
t.Fatalf("expected no error, got: %v", err)
} else if tt.secret != tp.Secret {
t.Fatalf("expected secret %s, got %s", tt.secret, tp.Secret)
}
})
}
}
// TestTelnetProxy exercises TelnetProxy end to end. A websocket client dials
// an httptest server whose /telnet handler delegates to TelnetProxy.ServeHTTP.
// The client sends a JSON handshake object with "secret" and "host_port" keys.
// The subtests then check:
//   - Valid Connection: client data reaches the mock telnet server and its
//     echo comes back over the websocket;
//   - Invalid Authentication: a wrong secret gets the session closed;
//   - Invalid Host: an unusable host_port gets the session closed;
//   - Connection Timeout: an idle session is eventually closed (this sleeps
//     130s, so it is skipped under -short).
//
// Every blocking websocket read carries a deadline, so a misbehaving proxy
// fails the test instead of hanging it until the global go test timeout.
func TestTelnetProxy(t *testing.T) {
	// setup mock telnet server. Registered as a cleanup first so it runs last,
	// after the proxy servers started below have been shut down.
	telnetServer := newMockTelnetServer(t)
	t.Cleanup(telnetServer.closeServer)
	telnetAddr := fmt.Sprintf("127.0.0.1:%d", telnetServer.getPort())

	// Caddy context setup
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	caddyCtx := caddy.Context{Context: ctx}
	// NOTE(review): caddyCtx was never used, which Go rejects at compile time
	// ("declared and not used"). It presumably was meant to provision the
	// proxy (e.g. tp.Provision(caddyCtx)); confirm against TelnetProxy and
	// wire it up. Until then, the blank assignment keeps the test compiling.
	_ = caddyCtx

	// ioTimeout bounds each blocking websocket read and the wait for client
	// data to show up at the mock telnet server.
	const ioTimeout = 5 * time.Second

	// startProxy serves p at /telnet on a new httptest server and returns the
	// ws:// URL to dial. The server shuts down when the calling (sub)test ends.
	startProxy := func(t *testing.T, p *TelnetProxy) string {
		t.Helper()
		mux := http.NewServeMux()
		mux.HandleFunc("/telnet", func(w http.ResponseWriter, r *http.Request) {
			p.ServeHTTP(w, r, caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
				return nil
			}))
		})
		server := httptest.NewServer(mux)
		t.Cleanup(server.Close)
		u, err := url.Parse(server.URL)
		if err != nil {
			t.Fatalf("failed to parse test server URL %q: %v", server.URL, err)
		}
		u.Scheme = "ws"
		u.Path = "/telnet"
		return u.String()
	}

	// dialAndHandshake opens a websocket to wsURL and sends the JSON handshake
	// as the first message. The connection is closed when the (sub)test ends.
	dialAndHandshake := func(t *testing.T, wsURL, secret, hostPort string) *websocket.Conn {
		t.Helper()
		conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
		if err != nil {
			t.Fatalf("failed to open websocket: %v", err)
		}
		t.Cleanup(func() { conn.Close() })
		initialMsg := map[string]string{
			"secret":    secret,
			"host_port": hostPort,
		}
		if err := conn.WriteJSON(initialMsg); err != nil {
			t.Fatalf("failed to write initial message: %v", err)
		}
		return conn
	}

	// setReadDeadline arms conn's read deadline ioTimeout from now.
	setReadDeadline := func(t *testing.T, conn *websocket.Conn) {
		t.Helper()
		if err := conn.SetReadDeadline(time.Now().Add(ioTimeout)); err != nil {
			t.Fatalf("failed to set read deadline: %v", err)
		}
	}

	// expectClosed reads from conn and fails with failFormat (one %v for the
	// read error) unless the read reports a going-away or normal-closure close.
	expectClosed := func(t *testing.T, conn *websocket.Conn, failFormat string) {
		t.Helper()
		setReadDeadline(t, conn)
		_, _, err := conn.ReadMessage()
		if !websocket.IsCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) {
			t.Fatalf(failFormat, err)
		}
	}

	tp := TelnetProxy{
		Secret: "testsecret",
	}
	wsURL := startProxy(t, &tp)

	// Test valid connection and sending data
	t.Run("Valid Connection", func(t *testing.T) {
		conn := dialAndHandshake(t, wsURL, "testsecret", telnetAddr)
		// Leave the shared buffer empty for later checks even if this fails.
		defer telnetServer.resetData()

		// Send test data through the websocket client.
		testData := []byte("Test Data From Client\n")
		if err := conn.WriteMessage(websocket.TextMessage, testData); err != nil {
			t.Fatalf("failed to write data over websocket %v", err)
		}

		// Poll for the data at the telnet server instead of sleeping a fixed
		// amount: faster when the proxy is quick, bounded when it is not.
		deadline := time.Now().Add(ioTimeout)
		receivedData := telnetServer.getRecievedData()
		for len(receivedData) < len(testData) && time.Now().Before(deadline) {
			time.Sleep(10 * time.Millisecond)
			receivedData = telnetServer.getRecievedData()
		}
		if !bytes.Equal(receivedData, testData) {
			t.Fatalf("received data does not match. expected: %s, got: %s", string(testData), string(receivedData))
		}

		// Get data from the telnet server back to the client
		setReadDeadline(t, conn)
		_, message, err := conn.ReadMessage()
		if err != nil {
			t.Fatalf("failed to read data from websocket %v", err)
		}
		if !bytes.Equal(message, testData) {
			t.Fatalf("received data from the telnet server back to client does not match. expected: %s, got: %s", string(testData), string(message))
		}
	})

	// Test failing authentication
	t.Run("Invalid Authentication", func(t *testing.T) {
		conn := dialAndHandshake(t, wsURL, "wrong_secret", telnetAddr)
		expectClosed(t, conn, "authentication failed connection was not closed. error: %v")
	})

	// Test an invalid host/port
	t.Run("Invalid Host", func(t *testing.T) {
		conn := dialAndHandshake(t, wsURL, "testsecret", "not_a_valid_host")
		expectClosed(t, conn, "Invalid connection not closed. Error: %v")
	})

	// Test Connection Timeout
	t.Run("Connection Timeout", func(t *testing.T) {
		if testing.Short() {
			t.Skip("skipping idle-timeout test (sleeps 130s) in -short mode")
		}
		// Use a dedicated proxy and server rather than reassigning the outer
		// tp, which handlers from the earlier subtests may still be reading.
		timeoutProxy := TelnetProxy{
			Secret: "testsecret",
		}
		conn := dialAndHandshake(t, startProxy(t, &timeoutProxy), "testsecret", telnetAddr)
		// Wait longer than the timeout.
		// NOTE(review): assumes the proxy's idle timeout is under 130s; confirm.
		time.Sleep(130 * time.Second)
		expectClosed(t, conn, "Timeout failed to close the connection. Error: %v")
	})
}