From b9cb68e9e9a1f912446cc9053d4757abda51bd7f Mon Sep 17 00:00:00 2001 From: Vinicius Fortuna Date: Fri, 15 Mar 2024 16:52:29 -0400 Subject: [PATCH] Modularize connection handling (#173) --- service/metrics/metrics.go | 9 +++- service/tcp.go | 97 ++++++++++++++++++++++++++------------ 2 files changed, 74 insertions(+), 32 deletions(-) diff --git a/service/metrics/metrics.go b/service/metrics/metrics.go index 667b41bc..5b95785c 100644 --- a/service/metrics/metrics.go +++ b/service/metrics/metrics.go @@ -53,8 +53,13 @@ func (c *measuredConn) Write(b []byte) (int, error) { return n, err } -func (c *measuredConn) ReadFrom(r io.Reader) (int64, error) { - n, err := io.Copy(c.StreamConn, r) +func (c *measuredConn) ReadFrom(r io.Reader) (n int64, err error) { + if rf, ok := c.StreamConn.(io.ReaderFrom); ok { + // Prefer ReadFrom if we are calling ReadFrom. Otherwise io.Copy will try WriteTo first. + n, err = rf.ReadFrom(r) + } else { + n, err = io.Copy(c.StreamConn, r) + } *c.writeCount += n return n, err } diff --git a/service/tcp.go b/service/tcp.go index c7010828..59186bfd 100644 --- a/service/tcp.go +++ b/service/tcp.go @@ -239,25 +239,22 @@ func (h *tcpHandler) Handle(ctx context.Context, clientConn transport.StreamConn logger.Debugf("Done with status %v, duration %v", status, connDuration) } -func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, clientConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, *onet.ConnectionError) { - // Set a deadline to receive the address to the target. - clientConn.SetReadDeadline(time.Now().Add(h.readTimeout)) - - // 1. Find the cipher and acess key id. +func (h *tcpHandler) authenticate(clientConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, transport.StreamConn, *onet.ConnectionError) { + // TODO(fortuna): Offer alternative transports. + // Find the cipher and acess key id. cipherEntry, clientReader, clientSalt, timeToCipher, keyErr := findAccessKey(clientConn, remoteIP(clientConn), h.ciphers) h.m.AddTCPCipherSearch(keyErr == nil, timeToCipher) if keyErr != nil { logger.Debugf("Failed to find a valid cipher after reading %v bytes: %v", proxyMetrics.ClientProxy, keyErr) const status = "ERR_CIPHER" - h.absorbProbe(listenerPort, clientConn, status, proxyMetrics) - return "", onet.NewConnectionError(status, "Failed to find a valid cipher", keyErr) + return "", nil, onet.NewConnectionError(status, "Failed to find a valid cipher", keyErr) } var id string if cipherEntry != nil { id = cipherEntry.ID } - // 2. Check if the connection is a replay. + // Check if the connection is a replay. isServerSalt := cipherEntry.SaltGenerator.IsServerSalt(clientSalt) // Only check the cache if findAccessKey succeeded and the salt is unrecognized. if isServerSalt || !h.replayCache.Add(cipherEntry.ID, clientSalt) { @@ -267,38 +264,39 @@ func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, cli } else { status = "ERR_REPLAY_CLIENT" } - h.absorbProbe(listenerPort, clientConn, status, proxyMetrics) logger.Debugf(status+": %v sent %d bytes", clientConn.RemoteAddr(), proxyMetrics.ClientProxy) - return id, onet.NewConnectionError(status, "Replay detected", nil) + return id, nil, onet.NewConnectionError(status, "Replay detected", nil) } - - // 3. Read target address and dial it. ssr := shadowsocks.NewReader(clientReader, cipherEntry.CryptoKey) - tgtAddr, err := socks.ReadAddr(ssr) + ssw := shadowsocks.NewWriter(clientConn, cipherEntry.CryptoKey) + ssw.SetSaltGenerator(cipherEntry.SaltGenerator) + return id, transport.WrapConn(clientConn, ssr, ssw), nil +} - // Clear the deadline for the target address - clientConn.SetReadDeadline(time.Time{}) +func getProxyRequest(clientConn transport.StreamConn) (string, error) { + // TODO(fortuna): Use Shadowsocks proxy, HTTP CONNECT or SOCKS5 based on first byte: + // case 1, 3 or 4: Shadowsocks (address type) + // case 5: SOCKS5 (protocol version) + // case "C": HTTP CONNECT (first char of method) + tgtAddr, err := socks.ReadAddr(clientConn) if err != nil { - // Drain to prevent a close on cipher error. - io.Copy(io.Discard, clientConn) - return id, onet.NewConnectionError("ERR_READ_ADDRESS", "Failed to get target address", err) + return "", err } - tgtConn, dialErr := h.dialer.DialStream(ctx, tgtAddr.String()) + return tgtAddr.String(), nil +} + +func proxyConnection(ctx context.Context, dialer transport.StreamDialer, tgtAddr string, clientConn transport.StreamConn) *onet.ConnectionError { + tgtConn, dialErr := dialer.DialStream(ctx, tgtAddr) if dialErr != nil { // We don't drain so dial errors and invalid addresses are communicated quickly. - return id, ensureConnectionError(dialErr, "ERR_CONNECT", "Failed to connect to target") + return ensureConnectionError(dialErr, "ERR_CONNECT", "Failed to connect to target") } - tgtConn = metrics.MeasureConn(tgtConn, &proxyMetrics.ProxyTarget, &proxyMetrics.TargetProxy) defer tgtConn.Close() - - // 4. Bridge the client and target connections logger.Debugf("proxy %s <-> %s", clientConn.RemoteAddr().String(), tgtConn.RemoteAddr().String()) - ssw := shadowsocks.NewWriter(clientConn, cipherEntry.CryptoKey) - ssw.SetSaltGenerator(cipherEntry.SaltGenerator) fromClientErrCh := make(chan error) go func() { - _, fromClientErr := ssr.WriteTo(tgtConn) + _, fromClientErr := io.Copy(tgtConn, clientConn) if fromClientErr != nil { // Drain to prevent a close in the case of a cipher error. io.Copy(io.Discard, clientConn) @@ -310,19 +308,58 @@ func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, cli tgtConn.CloseWrite() fromClientErrCh <- fromClientErr }() - _, fromTargetErr := ssw.ReadFrom(tgtConn) + _, fromTargetErr := io.Copy(clientConn, tgtConn) // Send FIN to client. clientConn.CloseWrite() tgtConn.CloseRead() fromClientErr := <-fromClientErrCh if fromClientErr != nil { - return id, onet.NewConnectionError("ERR_RELAY_CLIENT", "Failed to relay traffic from client", fromClientErr) + return onet.NewConnectionError("ERR_RELAY_CLIENT", "Failed to relay traffic from client", fromClientErr) } if fromTargetErr != nil { - return id, onet.NewConnectionError("ERR_RELAY_TARGET", "Failed to relay traffic from target", fromTargetErr) + return onet.NewConnectionError("ERR_RELAY_TARGET", "Failed to relay traffic from target", fromTargetErr) } - return id, nil + return nil +} + +func (h *tcpHandler) handleConnection(ctx context.Context, listenerPort int, outerConn transport.StreamConn, proxyMetrics *metrics.ProxyMetrics) (string, *onet.ConnectionError) { + // Set a deadline to receive the address to the target. + readDeadline := time.Now().Add(h.readTimeout) + if deadline, ok := ctx.Deadline(); ok { + outerConn.SetDeadline(deadline) + if deadline.Before(readDeadline) { + readDeadline = deadline + } + } + outerConn.SetReadDeadline(readDeadline) + + id, innerConn, authErr := h.authenticate(outerConn, proxyMetrics) + if authErr != nil { + // Drain to protect against probing attacks. + h.absorbProbe(listenerPort, outerConn, authErr.Status, proxyMetrics) + return id, authErr + } + + // Read target address and dial it. + tgtAddr, err := getProxyRequest(innerConn) + // Clear the deadline for the target address + outerConn.SetReadDeadline(time.Time{}) + if err != nil { + // Drain to prevent a close on cipher error. + io.Copy(io.Discard, outerConn) + return id, onet.NewConnectionError("ERR_READ_ADDRESS", "Failed to get target address", err) + } + + dialer := transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) { + tgtConn, err := h.dialer.DialStream(ctx, tgtAddr) + if err != nil { + return nil, err + } + tgtConn = metrics.MeasureConn(tgtConn, &proxyMetrics.ProxyTarget, &proxyMetrics.TargetProxy) + return tgtConn, nil + }) + return id, proxyConnection(ctx, dialer, tgtAddr, innerConn) } // Keep the connection open until we hit the authentication deadline to protect against probing attacks