-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
381 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
package tcp | ||
|
||
import ( | ||
"errors" | ||
"hash/crc64" | ||
"io" | ||
"net" | ||
"time" | ||
) | ||
|
||
// ConnMeta holds metadata about a TCP connection. | ||
type ConnMeta struct { | ||
Opened time.Time | ||
ClientAddress net.Addr | ||
ServerAddress net.Addr | ||
} | ||
|
||
// Hash returns a semi-unique number to every connection. | ||
// The implementation of Hash is not guaranteed to be stable between updates of this package. | ||
func (c ConnMeta) Hash() uint64 { | ||
// We use CRC64 as this hash does not need to be cryptographically secure, and it's easy to get an uint64 from it. | ||
hash := crc64.New(crc64.MakeTable(crc64.ISO)) | ||
_, _ = hash.Write([]byte(c.Opened.String())) | ||
_, _ = hash.Write([]byte(c.ClientAddress.String())) | ||
_, _ = hash.Write([]byte(c.ServerAddress.String())) | ||
|
||
return hash.Sum64() | ||
} | ||
|
||
// Handler is an object capable of acting when TCP messages are either sent or received. | ||
type Handler interface { | ||
// HandleUpward forwards data from the client to the server. Proxy will call HandleUpward once for every | ||
// connection, expecting it to keep consuming data until an error occurs, in which case the Proxy will close both | ||
// upstream and downstream connections. If ErrTerminate is returned, the connection is still closed but no error | ||
// message is logged. | ||
HandleUpward(client io.Reader, server io.Writer, meta ConnMeta) error | ||
// HandleDownward provides is the equivalent of HandleUpward for data sent from the server to the client. | ||
HandleDownward(server io.Reader, client io.Writer, meta ConnMeta) error | ||
} | ||
|
||
// ErrTerminate may be returned by Handler implementations that wish to willingly terminate a connection. Connection | ||
// will be closed, but no error log will be generated. | ||
var ErrTerminate = errors.New("connection terminated by proxy handler") | ||
|
||
// ForwardHandler is a handler that forwards data between client and server without taking any actions. | ||
type ForwardHandler struct{} | ||
|
||
func (ForwardHandler) HandleUpward(client io.Reader, server io.Writer, _ ConnMeta) error { | ||
_, err := io.Copy(server, client) | ||
return err | ||
} | ||
|
||
func (ForwardHandler) HandleDownward(server io.Reader, client io.Writer, _ ConnMeta) error { | ||
_, err := io.Copy(client, server) | ||
return err | ||
} | ||
|
||
// RejectHandler is a handler that closes connections immediately after being opened. | ||
type RejectHandler struct{} | ||
|
||
func (RejectHandler) HandleUpward(client io.Reader, server io.Writer, _ ConnMeta) error { | ||
return ErrTerminate | ||
} | ||
|
||
func (RejectHandler) HandleDownward(server io.Reader, client io.Writer, _ ConnMeta) error { | ||
return ErrTerminate | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
package tcp | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
"log" | ||
"net" | ||
"time" | ||
) | ||
|
||
// Proxy implements a TCP transparent proxy between a client and a server. | ||
type Proxy struct { | ||
l net.Listener | ||
upstream net.Addr | ||
handler Handler | ||
} | ||
|
||
func NewProxy(l net.Listener, upstream net.Addr, handler Handler) *Proxy { | ||
return &Proxy{ | ||
l: l, | ||
upstream: upstream, | ||
handler: handler, | ||
} | ||
} | ||
|
||
func (p *Proxy) Start() error { | ||
for { | ||
conn, err := p.l.Accept() | ||
if err != nil { | ||
return err | ||
} | ||
|
||
go func() { | ||
err := p.handleConn(conn) | ||
// TODO: Better error handling | ||
log.Printf("handling connection: %v", err) | ||
}() | ||
} | ||
} | ||
|
||
func (p *Proxy) Stop() error { | ||
// TODO: Harvest open connections and close them. | ||
return nil | ||
} | ||
|
||
func (p *Proxy) handleConn(downstreamConn net.Conn) error { | ||
defer func() { | ||
_ = downstreamConn.Close() | ||
}() | ||
|
||
upstreamConn, err := net.Dial("tcp", p.upstream.String()) | ||
if err != nil { | ||
return fmt.Errorf("opening upstream connection: %w", err) | ||
} | ||
|
||
defer func() { | ||
_ = upstreamConn.Close() | ||
}() | ||
|
||
metadata := ConnMeta{ | ||
Opened: time.Now(), | ||
ClientAddress: downstreamConn.RemoteAddr(), | ||
ServerAddress: upstreamConn.RemoteAddr(), | ||
} | ||
|
||
errChan := make(chan error, 2) | ||
go func() { | ||
errChan <- func() error { | ||
err := p.handler.HandleUpward(downstreamConn, upstreamConn, metadata) | ||
if err != nil && !errors.Is(err, ErrTerminate) { | ||
return err | ||
} | ||
|
||
return nil | ||
}() | ||
}() | ||
go func() { | ||
errChan <- func() error { | ||
err := p.handler.HandleDownward(upstreamConn, downstreamConn, metadata) | ||
if err != nil && !errors.Is(err, ErrTerminate) { | ||
return err | ||
} | ||
|
||
return nil | ||
}() | ||
}() | ||
|
||
err = <-errChan | ||
return fmt.Errorf("forwarding data: %w", err) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,224 @@ | ||
package tcp_test | ||
|
||
import ( | ||
"bufio" | ||
"errors" | ||
"fmt" | ||
"io" | ||
"net" | ||
"testing" | ||
"time" | ||
|
||
"github.com/grafana/xk6-disruptor/pkg/agent/protocol/tcp" | ||
) | ||
|
||
const localv4 = "127.0.0.1:0" | ||
|
||
// Test_Proxy_Forwards tests the tcp.Proxy using tcp.ForwardHandler, ensuring messages are forwarded to and from the | ||
// proxy. | ||
func Test_Proxy_Forwards(t *testing.T) { | ||
t.Parallel() | ||
|
||
upstreamL, err := net.Listen("tcp", localv4) | ||
if err != nil { | ||
t.Fatalf("creating upstream listener: %v", err) | ||
} | ||
|
||
serverCh := make(chan string) | ||
serverErr := make(chan error) | ||
go func() { | ||
serverErr <- echoServer(upstreamL, serverCh) | ||
}() | ||
|
||
proxyL, err := net.Listen("tcp", localv4) | ||
if err != nil { | ||
t.Fatalf("creating proxy listener: %v", err) | ||
} | ||
|
||
proxy := tcp.NewProxy(proxyL, upstreamL.Addr(), tcp.ForwardHandler{}) | ||
go func() { | ||
err := proxy.Start() | ||
if err != nil { | ||
// t.Fatal cannot be used inside a goroutine. | ||
t.Errorf("couldn't start poxy: %v", err) | ||
} | ||
}() | ||
|
||
proxyConn, err := net.Dial("tcp", proxyL.Addr().String()) | ||
if err != nil { | ||
t.Fatalf("dialing proxy address: %v", err) | ||
} | ||
|
||
bufReader := bufio.NewReader(proxyConn) | ||
|
||
// Write a first line. | ||
_, err = fmt.Fprintln(proxyConn, "a line") | ||
if err != nil { | ||
t.Fatalf("writing to proxy conn: %v", err) | ||
} | ||
|
||
// Check the server received the line | ||
select { | ||
case <-time.After(time.Second): | ||
t.Fatalf("upstream did not receive the line before the deadline") | ||
case serverLine := <-serverCh: | ||
if serverLine != "a line\n" { | ||
t.Fatalf("upstream received unexpected data %q", serverLine) | ||
} | ||
} | ||
|
||
// Check we received the echoed data | ||
clientLine, err := bufReader.ReadString('\n') | ||
if err != nil { | ||
t.Fatalf("reading upstream response from proxyconn: %v", err) | ||
} | ||
if clientLine != "a line\n" { | ||
t.Fatalf("downstream received unexpected data %q", clientLine) | ||
} | ||
|
||
// Write a second line. | ||
_, err = fmt.Fprintln(proxyConn, "another line") | ||
if err != nil { | ||
t.Fatalf("writing to proxy conn: %v", err) | ||
} | ||
|
||
// Check the server received the line | ||
select { | ||
case <-time.After(time.Second): | ||
t.Fatalf("upstream did not receive the line before the deadline") | ||
case serverLine := <-serverCh: | ||
if serverLine != "another line\n" { | ||
t.Fatalf("upstream received unexpected data %q", serverLine) | ||
} | ||
} | ||
|
||
// Check we received the echoed data | ||
clientLine, err = bufReader.ReadString('\n') | ||
if err != nil { | ||
t.Fatalf("reading upstream response from proxyconn: %v", err) | ||
} | ||
if clientLine != "another line\n" { | ||
t.Fatalf("downstream received unexpected data %q", clientLine) | ||
} | ||
|
||
// Close the connection to the proxy. | ||
_ = proxyConn.Close() | ||
|
||
select { | ||
case <-time.After(time.Second): | ||
t.Fatalf("upstream connection was not closed") | ||
case line, ok := <-serverCh: | ||
if ok { | ||
t.Fatalf("upstream receive unexpected data: %q", line) | ||
} | ||
} | ||
|
||
select { | ||
case <-time.After(time.Second): | ||
t.Fatalf("server did not terminate") | ||
case err = <-serverErr: | ||
if err != nil { | ||
t.Fatalf("server returned an error: %v", err) | ||
} | ||
} | ||
} | ||
|
||
// Test_Proxy_Forwards tests the tcp.Proxy using tcp.RejectHandler, ensuring both client and server connections are | ||
// closed properly and cleanly when handlers return errors. | ||
func Test_Proxy_Rejects(t *testing.T) { | ||
t.Parallel() | ||
|
||
upstreamL, err := net.Listen("tcp", localv4) | ||
if err != nil { | ||
t.Fatalf("creating upstream listener: %v", err) | ||
} | ||
|
||
serverCh := make(chan string) | ||
serverErr := make(chan error) | ||
go func() { | ||
serverErr <- echoServer(upstreamL, serverCh) | ||
}() | ||
|
||
proxyL, err := net.Listen("tcp", localv4) | ||
if err != nil { | ||
t.Fatalf("creating proxy listener: %v", err) | ||
} | ||
|
||
proxy := tcp.NewProxy(proxyL, upstreamL.Addr(), tcp.RejectHandler{}) | ||
go func() { | ||
err := proxy.Start() | ||
if err != nil { | ||
// t.Fatal cannot be used inside a goroutine. | ||
t.Errorf("couldn't start poxy: %v", err) | ||
} | ||
}() | ||
|
||
proxyConn, err := net.Dial("tcp", proxyL.Addr().String()) | ||
if err != nil { | ||
t.Fatalf("dialing proxy address: %v", err) | ||
} | ||
|
||
// Attempt to write a first line. | ||
_, err = fmt.Fprintln(proxyConn, "a line") | ||
if err != nil { | ||
t.Fatalf("error writing data: %v", err) | ||
} | ||
|
||
singleByte := make([]byte, 1) | ||
_, err = proxyConn.Read(singleByte) | ||
if err == nil { | ||
t.Fatalf("expected connection to be closed by rejectHandler: %v", err) | ||
} | ||
|
||
select { | ||
case <-time.After(time.Second): | ||
t.Fatalf("upstream connection was not closed") | ||
case line, ok := <-serverCh: | ||
if ok { | ||
t.Fatalf("upstream receive unexpected data: %q", line) | ||
} | ||
} | ||
|
||
select { | ||
case <-time.After(time.Second): | ||
t.Fatalf("server did not terminate") | ||
case err = <-serverErr: | ||
if err != nil { | ||
t.Fatalf("server returned an error: %v", err) | ||
} | ||
} | ||
} | ||
|
||
// echoServer is a helper function for testing that accepts a single connection from the given listener, and pushes | ||
// each received line to lineCh. When the connection is closed, it also closes lineCh. | ||
func echoServer(l net.Listener, lineCh chan string) error { | ||
defer close(lineCh) | ||
|
||
conn, err := l.Accept() | ||
if err != nil { | ||
return fmt.Errorf("accepting conn: %w", err) | ||
} | ||
|
||
reader := bufio.NewReader(conn) | ||
for { | ||
line, err := reader.ReadString('\n') | ||
if errors.Is(err, io.EOF) { | ||
return nil | ||
} | ||
if err != nil { | ||
return fmt.Errorf("reading from conn: %w", err) | ||
} | ||
|
||
_, err = conn.Write([]byte(line)) | ||
if err != nil { | ||
return fmt.Errorf("echoing back to conn: %w", err) | ||
} | ||
|
||
select { | ||
case lineCh <- line: | ||
continue | ||
case <-time.After(time.Second): | ||
return fmt.Errorf("reader did not consume line %q", line) | ||
} | ||
} | ||
} |