package telnetproxy import ( "bufio" "bytes" "context" "fmt" "io" "net" "net/http" "net/http/httptest" "net/url" "strings" "sync" "testing" "time" "github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" "github.com/caddyserver/caddy/v2/modules/caddyhttp" "github.com/gorilla/websocket" ) // mockTelnetServer listens on a random port and echoes back what it receives, allowing testing of data flow. type mockTelnetServer struct { listener net.Listener port int close chan struct{} data bytes.Buffer mu sync.Mutex t *testing.T } func newMockTelnetServer(t *testing.T) *mockTelnetServer { l, err := net.Listen("tcp", "127.0.0.1:0") // Let the OS choose a port if err != nil { t.Fatalf("Failed to create listener: %v", err) } _, portStr, err := net.SplitHostPort(l.Addr().String()) if err != nil { t.Fatalf("Failed to get port: %v", err) } port := 0 _, err = fmt.Sscan(portStr, &port) if err != nil { t.Fatalf("Failed to get port from string: %v", err) } s := &mockTelnetServer{ listener: l, port: port, close: make(chan struct{}), t: t, } go s.serve() return s } func (s *mockTelnetServer) serve() { for { select { case <-s.close: return default: conn, err := s.listener.Accept() if err != nil { return } go s.handleConnection(conn) } } } func (s *mockTelnetServer) handleConnection(conn net.Conn) { defer conn.Close() r := bufio.NewReader(conn) for { // Read until newline data, err := r.ReadBytes('\n') if err != nil { if err != io.EOF { s.t.Logf("mockTelnetServer read error: %v", err) } return } // Add to buffer s.mu.Lock() s.data.Write(data) s.mu.Unlock() // Echo back _, err = conn.Write(data) if err != nil { return // End connection on write error } } } func (s *mockTelnetServer) getPort() int { return s.port } func (s *mockTelnetServer) getRecievedData() []byte { s.mu.Lock() defer s.mu.Unlock() return s.data.Bytes() } func (s *mockTelnetServer) closeServer() { close(s.close) s.listener.Close() } func (s *mockTelnetServer) resetData() { s.mu.Lock() defer s.mu.Unlock() s.data.Reset() } func TestTelnetProxyCaddyfile(t *testing.T) { tests := []struct { name string input string expectError bool secret string }{ { name: "valid config", input: `telnet_proxy testsecret`, secret: "testsecret", }, { name: "missing secret", input: `telnet_proxy`, expectError: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { d := caddyfile.NewTestDispenser(tt.input) tp := new(TelnetProxy) err := tp.UnmarshalCaddyfile(d) if tt.expectError { if err == nil { t.Fatal("expected error, got nil") } } else if err != nil { t.Fatalf("expected no error, got: %v", err) } else if tt.secret != tp.Secret { t.Fatalf("expected secret %s, got %s", tt.secret, tp.Secret) } }) } } func TestTelnetProxy(t *testing.T) { // setup mock telnet server telnetServer := newMockTelnetServer(t) defer telnetServer.closeServer() // Caddy context setup ctx, cancel := context.WithCancel(context.Background()) defer cancel() caddyCtx := caddy.Context{Context: ctx} tp := TelnetProxy{ Secret: "testsecret", } // Create HTTP test handler mux := http.NewServeMux() mux.HandleFunc("/telnet", func(w http.ResponseWriter, r *http.Request) { tp.ServeHTTP(w, r, caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { return nil })) }) server := httptest.NewServer(mux) defer server.Close() // Create websocket client u, _ := url.Parse(server.URL) u.Scheme = "ws" u.Path = "/telnet" // Test valid connection and sending data t.Run("Valid Connection", func(t *testing.T) { conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) if err != nil { t.Fatalf("failed to open websocket: %v", err) } defer conn.Close() // Send initial connection setup (authentication and host/port) initialMsg := map[string]string{ "secret": "testsecret", "host_port": fmt.Sprintf("127.0.0.1:%d", telnetServer.getPort()), } err = conn.WriteJSON(initialMsg) if err != nil { t.Fatalf("failed to write initial message: %v", err) } // Send test data through the websocket client. testData := []byte("Test Data From Client\n") err = conn.WriteMessage(websocket.TextMessage, testData) if err != nil { t.Fatalf("failed to write data over websocket %v", err) } time.Sleep(time.Second) // Wait for the proxy to process // Get the received data from the telnet server recievedData := telnetServer.getRecievedData() if !bytes.Equal(recievedData, testData) { t.Fatalf("received data does not match. expected: %s, got: %s", string(testData), string(recievedData)) } // Get data from the telnet server back to the client _, message, err := conn.ReadMessage() if err != nil { t.Fatalf("failed to read data from websocket %v", err) } if !bytes.Equal(message, testData) { t.Fatalf("received data from the telnet server back to client does not match. expected: %s, got: %s", string(testData), string(message)) } telnetServer.resetData() }) // Test failing authentication t.Run("Invalid Authentication", func(t *testing.T){ conn2, _, err := websocket.DefaultDialer.Dial(u.String(), nil) if err != nil { t.Fatalf("failed to open websocket: %v", err) } defer conn2.Close() initialMsg := map[string]string{ "secret": "wrong_secret", "host_port": fmt.Sprintf("127.0.0.1:%d", telnetServer.getPort()), } err = conn2.WriteJSON(initialMsg) if err != nil { t.Fatalf("failed to write initial message: %v", err) } _, _, err = conn2.ReadMessage() // Check if the connection was closed if !websocket.IsCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure){ t.Fatalf("authentication failed connection was not closed. error: %v", err) } }) // Test an invalid host/port t.Run("Invalid Host", func(t *testing.T) { conn3, _, err := websocket.DefaultDialer.Dial(u.String(), nil) if err != nil { t.Fatalf("failed to open websocket: %v", err) } defer conn3.Close() initialMsg := map[string]string{ "secret": "testsecret", "host_port": "not_a_valid_host", } err = conn3.WriteJSON(initialMsg) if err != nil { t.Fatalf("failed to write initial message: %v", err) } _, _, err = conn3.ReadMessage() if !websocket.IsCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure){ t.Fatalf("Invalid connection not closed. Error: %v", err) } }) // Test Connection Timeout t.Run("Connection Timeout", func(t *testing.T) { tp = TelnetProxy{ Secret: "testsecret", } mux := http.NewServeMux() mux.HandleFunc("/telnet", func(w http.ResponseWriter, r *http.Request) { tp.ServeHTTP(w, r, caddyhttp.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { return nil })) }) server := httptest.NewServer(mux) defer server.Close() // Create websocket client u, _ := url.Parse(server.URL) u.Scheme = "ws" u.Path = "/telnet" conn4, _, err := websocket.DefaultDialer.Dial(u.String(), nil) if err != nil { t.Fatalf("failed to open websocket: %v", err) } defer conn4.Close() initialMsg := map[string]string{ "secret": "testsecret", "host_port": fmt.Sprintf("127.0.0.1:%d", telnetServer.getPort()), } err = conn4.WriteJSON(initialMsg) if err != nil { t.Fatalf("failed to write initial message: %v", err) } // Wait longer than the timeout. time.Sleep(130*time.Second) _, _, err = conn4.ReadMessage() if !websocket.IsCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure){ t.Fatalf("Timeout failed to close the connection. Error: %v", err) } }) }