Skip to content

Commit

Permalink
poc: tcp proxy
Browse files Browse the repository at this point in the history
  • Loading branch information
nadiamoe committed Sep 7, 2023
1 parent f120430 commit 252baf5
Show file tree
Hide file tree
Showing 3 changed files with 381 additions and 0 deletions.
67 changes: 67 additions & 0 deletions pkg/agent/protocol/tcp/handler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package tcp

import (
"errors"
"hash/crc64"
"io"
"net"
"time"
)

// ConnMeta holds metadata about a TCP connection.
type ConnMeta struct {
Opened time.Time
ClientAddress net.Addr
ServerAddress net.Addr
}

// Hash returns a semi-unique number to every connection.
// The implementation of Hash is not guaranteed to be stable between updates of this package.
func (c ConnMeta) Hash() uint64 {
// We use CRC64 as this hash does not need to be cryptographically secure, and it's easy to get an uint64 from it.
hash := crc64.New(crc64.MakeTable(crc64.ISO))
_, _ = hash.Write([]byte(c.Opened.String()))
_, _ = hash.Write([]byte(c.ClientAddress.String()))
_, _ = hash.Write([]byte(c.ServerAddress.String()))

return hash.Sum64()
}

// Handler is an object capable of acting when TCP messages are either sent or received.
type Handler interface {
// HandleUpward forwards data from the client to the server. Proxy will call HandleUpward once for every
// connection, expecting it to keep consuming data until an error occurs, in which case the Proxy will close both
// upstream and downstream connections. If ErrTerminate is returned, the connection is still closed but no error
// message is logged.
HandleUpward(client io.Reader, server io.Writer, meta ConnMeta) error
// HandleDownward provides is the equivalent of HandleUpward for data sent from the server to the client.
HandleDownward(server io.Reader, client io.Writer, meta ConnMeta) error
}

// ErrTerminate may be returned by Handler implementations that wish to willingly terminate a connection. Connection
// will be closed, but no error log will be generated.
var ErrTerminate = errors.New("connection terminated by proxy handler")

// ForwardHandler is a handler that forwards data between client and server without taking any actions.
type ForwardHandler struct{}

func (ForwardHandler) HandleUpward(client io.Reader, server io.Writer, _ ConnMeta) error {
_, err := io.Copy(server, client)
return err
}

func (ForwardHandler) HandleDownward(server io.Reader, client io.Writer, _ ConnMeta) error {
_, err := io.Copy(client, server)
return err
}

// RejectHandler is a handler that closes connections immediately after being opened.
type RejectHandler struct{}

func (RejectHandler) HandleUpward(client io.Reader, server io.Writer, _ ConnMeta) error {
return ErrTerminate
}

func (RejectHandler) HandleDownward(server io.Reader, client io.Writer, _ ConnMeta) error {
return ErrTerminate
}
90 changes: 90 additions & 0 deletions pkg/agent/protocol/tcp/proxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package tcp

import (
"errors"
"fmt"
"log"
"net"
"time"
)

// Proxy implements a TCP transparent proxy between a client and a server.
type Proxy struct {
l net.Listener
upstream net.Addr
handler Handler
}

func NewProxy(l net.Listener, upstream net.Addr, handler Handler) *Proxy {
return &Proxy{
l: l,
upstream: upstream,
handler: handler,
}
}

func (p *Proxy) Start() error {
for {
conn, err := p.l.Accept()
if err != nil {
return err
}

go func() {
err := p.handleConn(conn)
// TODO: Better error handling
log.Printf("handling connection: %v", err)
}()
}
}

func (p *Proxy) Stop() error {
// TODO: Harvest open connections and close them.
return nil
}

func (p *Proxy) handleConn(downstreamConn net.Conn) error {
defer func() {
_ = downstreamConn.Close()
}()

upstreamConn, err := net.Dial("tcp", p.upstream.String())
if err != nil {
return fmt.Errorf("opening upstream connection: %w", err)
}

defer func() {
_ = upstreamConn.Close()
}()

metadata := ConnMeta{
Opened: time.Now(),
ClientAddress: downstreamConn.RemoteAddr(),
ServerAddress: upstreamConn.RemoteAddr(),
}

errChan := make(chan error, 2)
go func() {
errChan <- func() error {
err := p.handler.HandleUpward(downstreamConn, upstreamConn, metadata)
if err != nil && !errors.Is(err, ErrTerminate) {
return err
}

return nil
}()
}()
go func() {
errChan <- func() error {
err := p.handler.HandleDownward(upstreamConn, downstreamConn, metadata)
if err != nil && !errors.Is(err, ErrTerminate) {
return err
}

return nil
}()
}()

err = <-errChan
return fmt.Errorf("forwarding data: %w", err)
}
224 changes: 224 additions & 0 deletions pkg/agent/protocol/tcp/proxy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
package tcp_test

import (
"bufio"
"errors"
"fmt"
"io"
"net"
"testing"
"time"

"github.com/grafana/xk6-disruptor/pkg/agent/protocol/tcp"
)

const localv4 = "127.0.0.1:0"

// Test_Proxy_Forwards tests the tcp.Proxy using tcp.ForwardHandler, ensuring messages are forwarded to and from the
// proxy.
func Test_Proxy_Forwards(t *testing.T) {
t.Parallel()

upstreamL, err := net.Listen("tcp", localv4)
if err != nil {
t.Fatalf("creating upstream listener: %v", err)
}

serverCh := make(chan string)
serverErr := make(chan error)
go func() {
serverErr <- echoServer(upstreamL, serverCh)
}()

proxyL, err := net.Listen("tcp", localv4)
if err != nil {
t.Fatalf("creating proxy listener: %v", err)
}

proxy := tcp.NewProxy(proxyL, upstreamL.Addr(), tcp.ForwardHandler{})
go func() {
err := proxy.Start()
if err != nil {
// t.Fatal cannot be used inside a goroutine.
t.Errorf("couldn't start poxy: %v", err)
}
}()

proxyConn, err := net.Dial("tcp", proxyL.Addr().String())
if err != nil {
t.Fatalf("dialing proxy address: %v", err)
}

bufReader := bufio.NewReader(proxyConn)

// Write a first line.
_, err = fmt.Fprintln(proxyConn, "a line")
if err != nil {
t.Fatalf("writing to proxy conn: %v", err)
}

// Check the server received the line
select {
case <-time.After(time.Second):
t.Fatalf("upstream did not receive the line before the deadline")
case serverLine := <-serverCh:
if serverLine != "a line\n" {
t.Fatalf("upstream received unexpected data %q", serverLine)
}
}

// Check we received the echoed data
clientLine, err := bufReader.ReadString('\n')
if err != nil {
t.Fatalf("reading upstream response from proxyconn: %v", err)
}
if clientLine != "a line\n" {
t.Fatalf("downstream received unexpected data %q", clientLine)
}

// Write a second line.
_, err = fmt.Fprintln(proxyConn, "another line")
if err != nil {
t.Fatalf("writing to proxy conn: %v", err)
}

// Check the server received the line
select {
case <-time.After(time.Second):
t.Fatalf("upstream did not receive the line before the deadline")
case serverLine := <-serverCh:
if serverLine != "another line\n" {
t.Fatalf("upstream received unexpected data %q", serverLine)
}
}

// Check we received the echoed data
clientLine, err = bufReader.ReadString('\n')
if err != nil {
t.Fatalf("reading upstream response from proxyconn: %v", err)
}
if clientLine != "another line\n" {
t.Fatalf("downstream received unexpected data %q", clientLine)
}

// Close the connection to the proxy.
_ = proxyConn.Close()

select {
case <-time.After(time.Second):
t.Fatalf("upstream connection was not closed")
case line, ok := <-serverCh:
if ok {
t.Fatalf("upstream receive unexpected data: %q", line)
}
}

select {
case <-time.After(time.Second):
t.Fatalf("server did not terminate")
case err = <-serverErr:
if err != nil {
t.Fatalf("server returned an error: %v", err)
}
}
}

// Test_Proxy_Forwards tests the tcp.Proxy using tcp.RejectHandler, ensuring both client and server connections are
// closed properly and cleanly when handlers return errors.
func Test_Proxy_Rejects(t *testing.T) {
t.Parallel()

upstreamL, err := net.Listen("tcp", localv4)
if err != nil {
t.Fatalf("creating upstream listener: %v", err)
}

serverCh := make(chan string)
serverErr := make(chan error)
go func() {
serverErr <- echoServer(upstreamL, serverCh)
}()

proxyL, err := net.Listen("tcp", localv4)
if err != nil {
t.Fatalf("creating proxy listener: %v", err)
}

proxy := tcp.NewProxy(proxyL, upstreamL.Addr(), tcp.RejectHandler{})
go func() {
err := proxy.Start()
if err != nil {
// t.Fatal cannot be used inside a goroutine.
t.Errorf("couldn't start poxy: %v", err)
}
}()

proxyConn, err := net.Dial("tcp", proxyL.Addr().String())
if err != nil {
t.Fatalf("dialing proxy address: %v", err)
}

// Attempt to write a first line.
_, err = fmt.Fprintln(proxyConn, "a line")
if err != nil {
t.Fatalf("error writing data: %v", err)
}

singleByte := make([]byte, 1)
_, err = proxyConn.Read(singleByte)
if err == nil {
t.Fatalf("expected connection to be closed by rejectHandler: %v", err)
}

select {
case <-time.After(time.Second):
t.Fatalf("upstream connection was not closed")
case line, ok := <-serverCh:
if ok {
t.Fatalf("upstream receive unexpected data: %q", line)
}
}

select {
case <-time.After(time.Second):
t.Fatalf("server did not terminate")
case err = <-serverErr:
if err != nil {
t.Fatalf("server returned an error: %v", err)
}
}
}

// echoServer is a helper function for testing that accepts a single connection from the given listener, and pushes
// each received line to lineCh. When the connection is closed, it also closes lineCh.
func echoServer(l net.Listener, lineCh chan string) error {
defer close(lineCh)

conn, err := l.Accept()
if err != nil {
return fmt.Errorf("accepting conn: %w", err)
}

reader := bufio.NewReader(conn)
for {
line, err := reader.ReadString('\n')
if errors.Is(err, io.EOF) {
return nil
}
if err != nil {
return fmt.Errorf("reading from conn: %w", err)
}

_, err = conn.Write([]byte(line))
if err != nil {
return fmt.Errorf("echoing back to conn: %w", err)
}

select {
case lineCh <- line:
continue
case <-time.After(time.Second):
return fmt.Errorf("reader did not consume line %q", line)
}
}
}

0 comments on commit 252baf5

Please sign in to comment.