From 2da4beb003329d83c74a52d66a009110ef1fd180 Mon Sep 17 00:00:00 2001 From: Andrew Haines Date: Wed, 15 May 2024 21:16:52 +0100 Subject: [PATCH] fix: handle `X-Forwarded-*` headers correctly (#4334) Signed-off-by: Andrew Haines --- runtime/context.go | 16 +++++-- runtime/context_test.go | 102 ++++++++++++++++++++++++++++++++++++---- 2 files changed, 103 insertions(+), 15 deletions(-) diff --git a/runtime/context.go b/runtime/context.go index 7c9ff657bfd..5dd4e447862 100644 --- a/runtime/context.go +++ b/runtime/context.go @@ -148,6 +148,12 @@ func annotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcM var pairs []string for key, vals := range req.Header { key = textproto.CanonicalMIMEHeaderKey(key) + switch key { + case xForwardedFor, xForwardedHost: + // Handled separately below + continue + } + for _, val := range vals { // For backwards-compatibility, pass through 'authorization' header with no prefix. if key == "Authorization" { @@ -181,15 +187,15 @@ func annotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcM pairs = append(pairs, strings.ToLower(xForwardedHost), req.Host) } + xff := req.Header.Values(xForwardedFor) if addr := req.RemoteAddr; addr != "" { if remoteIP, _, err := net.SplitHostPort(addr); err == nil { - if fwd := req.Header.Get(xForwardedFor); fwd == "" { - pairs = append(pairs, strings.ToLower(xForwardedFor), remoteIP) - } else { - pairs = append(pairs, strings.ToLower(xForwardedFor), fmt.Sprintf("%s, %s", fwd, remoteIP)) - } + xff = append(xff, remoteIP) } } + if len(xff) > 0 { + pairs = append(pairs, strings.ToLower(xForwardedFor), strings.Join(xff, ", ")) + } if timeout != 0 { ctx, _ = context.WithTimeout(ctx, timeout) diff --git a/runtime/context_test.go b/runtime/context_test.go index 84aec91426d..7af4721cc59 100644 --- a/runtime/context_test.go +++ b/runtime/context_test.go @@ -113,17 +113,20 @@ func TestAnnotateContext_ForwardGrpcBinaryMetadata(t *testing.T) { } } -func TestAnnotateContext_XForwardedFor(t *testing.T) { +func TestAnnotateContext_AddsXForwardedHeaders(t *testing.T) { ctx := context.Background() expectedRPCName := "/example.Example/Example" request, err := http.NewRequestWithContext(ctx, "GET", "http://bar.foo.example.com", nil) if err != nil { t.Fatalf("http.NewRequestWithContext(ctx, %q, %q, nil) failed with %v; want success", "GET", "http://bar.foo.example.com", err) } - request.Header.Add("X-Forwarded-For", "192.0.2.100") // client - request.RemoteAddr = "192.0.2.200:12345" // proxy + request.RemoteAddr = "192.0.2.100:12345" // client - annotated, err := runtime.AnnotateContext(ctx, runtime.NewServeMux(), request, expectedRPCName) + serveMux := runtime.NewServeMux(runtime.WithIncomingHeaderMatcher(func(key string) (string, bool) { + return key, true + })) + + annotated, err := runtime.AnnotateContext(ctx, serveMux, request, expectedRPCName) if err != nil { t.Errorf("runtime.AnnotateContext(ctx, %#v) failed with %v; want success", request, err) return @@ -135,8 +138,46 @@ func TestAnnotateContext_XForwardedFor(t *testing.T) { if got, want := md["x-forwarded-host"], []string{"bar.foo.example.com"}; !reflect.DeepEqual(got, want) { t.Errorf(`md["host"] = %v; want %v`, got, want) } + if got, want := md["x-forwarded-for"], []string{"192.0.2.100"}; !reflect.DeepEqual(got, want) { + t.Errorf(`md["x-forwarded-for"] = %v want %v`, got, want) + } + if m, ok := runtime.RPCMethod(annotated); !ok { + t.Errorf("runtime.RPCMethod(annotated) failed with no value; want %s", expectedRPCName) + } else if m != expectedRPCName { + t.Errorf("runtime.RPCMethod(annotated) failed with %s; want %s", m, expectedRPCName) + } +} + +func TestAnnotateContext_AppendsToExistingXForwardedHeaders(t *testing.T) { + ctx := context.Background() + expectedRPCName := "/example.Example/Example" + request, err := http.NewRequestWithContext(ctx, "GET", "http://bar.foo.example.com", nil) + if err != nil { + t.Fatalf("http.NewRequestWithContext(ctx, %q, %q, nil) failed with %v; want success", "GET", "http://bar.foo.example.com", err) + } + request.Header.Add("X-Forwarded-Host", "qux.example.com") + request.Header.Add("X-Forwarded-For", "192.0.2.100") // client + request.Header.Add("X-Forwarded-For", "192.0.2.101, 192.0.2.102") // intermediate proxies + request.RemoteAddr = "192.0.2.200:12345" // final proxy + + serveMux := runtime.NewServeMux(runtime.WithIncomingHeaderMatcher(func(key string) (string, bool) { + return key, true + })) + + annotated, err := runtime.AnnotateContext(ctx, serveMux, request, expectedRPCName) + if err != nil { + t.Errorf("runtime.AnnotateContext(ctx, %#v) failed with %v; want success", request, err) + return + } + md, ok := metadata.FromOutgoingContext(annotated) + if !ok || len(md) != emptyForwardMetaCount+1 { + t.Errorf("Expected %d metadata items in context; got %v", emptyForwardMetaCount+1, md) + } + if got, want := md["x-forwarded-host"], []string{"qux.example.com"}; !reflect.DeepEqual(got, want) { + t.Errorf(`md["host"] = %v; want %v`, got, want) + } // Note: it must be in order client, proxy1, proxy2 - if got, want := md["x-forwarded-for"], []string{"192.0.2.100, 192.0.2.200"}; !reflect.DeepEqual(got, want) { + if got, want := md["x-forwarded-for"], []string{"192.0.2.100, 192.0.2.101, 192.0.2.102, 192.0.2.200"}; !reflect.DeepEqual(got, want) { t.Errorf(`md["x-forwarded-for"] = %v want %v`, got, want) } if m, ok := runtime.RPCMethod(annotated); !ok { @@ -356,17 +397,20 @@ func TestAnnotateIncomingContext_ForwardGrpcBinaryMetadata(t *testing.T) { } } -func TestAnnotateIncomingContext_XForwardedFor(t *testing.T) { +func TestAnnotateIncomingContext_AddsXForwardedHeaders(t *testing.T) { ctx := context.Background() expectedRPCName := "/example.Example/Example" request, err := http.NewRequestWithContext(ctx, "GET", "http://bar.foo.example.com", nil) if err != nil { t.Fatalf("http.NewRequestWithContext(ctx, %q, %q, nil) failed with %v; want success", "GET", "http://bar.foo.example.com", err) } - request.Header.Add("X-Forwarded-For", "192.0.2.100") // client - request.RemoteAddr = "192.0.2.200:12345" // proxy + request.RemoteAddr = "192.0.2.100:12345" // client - annotated, err := runtime.AnnotateIncomingContext(ctx, runtime.NewServeMux(), request, expectedRPCName) + serveMux := runtime.NewServeMux(runtime.WithIncomingHeaderMatcher(func(key string) (string, bool) { + return key, true + })) + + annotated, err := runtime.AnnotateIncomingContext(ctx, serveMux, request, expectedRPCName) if err != nil { t.Errorf("runtime.AnnotateIncomingContext(ctx, %#v) failed with %v; want success", request, err) return @@ -378,8 +422,46 @@ func TestAnnotateIncomingContext_XForwardedFor(t *testing.T) { if got, want := md["x-forwarded-host"], []string{"bar.foo.example.com"}; !reflect.DeepEqual(got, want) { t.Errorf(`md["host"] = %v; want %v`, got, want) } + if got, want := md["x-forwarded-for"], []string{"192.0.2.100"}; !reflect.DeepEqual(got, want) { + t.Errorf(`md["x-forwarded-for"] = %v want %v`, got, want) + } + if m, ok := runtime.RPCMethod(annotated); !ok { + t.Errorf("runtime.RPCMethod(annotated) failed with no value; want %s", expectedRPCName) + } else if m != expectedRPCName { + t.Errorf("runtime.RPCMethod(annotated) failed with %s; want %s", m, expectedRPCName) + } +} + +func TestAnnotateIncomingContext_AppendsToExistingXForwardedHeaders(t *testing.T) { + ctx := context.Background() + expectedRPCName := "/example.Example/Example" + request, err := http.NewRequestWithContext(ctx, "GET", "http://bar.foo.example.com", nil) + if err != nil { + t.Fatalf("http.NewRequestWithContext(ctx, %q, %q, nil) failed with %v; want success", "GET", "http://bar.foo.example.com", err) + } + request.Header.Add("X-Forwarded-Host", "qux.example.com") + request.Header.Add("X-Forwarded-For", "192.0.2.100") // client + request.Header.Add("X-Forwarded-For", "192.0.2.101, 192.0.2.102") // intermediate proxies + request.RemoteAddr = "192.0.2.200:12345" // final proxy + + serveMux := runtime.NewServeMux(runtime.WithIncomingHeaderMatcher(func(key string) (string, bool) { + return key, true + })) + + annotated, err := runtime.AnnotateIncomingContext(ctx, serveMux, request, expectedRPCName) + if err != nil { + t.Errorf("runtime.AnnotateIncomingContext(ctx, %#v) failed with %v; want success", request, err) + return + } + md, ok := metadata.FromIncomingContext(annotated) + if !ok || len(md) != emptyForwardMetaCount+1 { + t.Errorf("Expected %d metadata items in context; got %v", emptyForwardMetaCount+1, md) + } + if got, want := md["x-forwarded-host"], []string{"qux.example.com"}; !reflect.DeepEqual(got, want) { + t.Errorf(`md["host"] = %v; want %v`, got, want) + } // Note: it must be in order client, proxy1, proxy2 - if got, want := md["x-forwarded-for"], []string{"192.0.2.100, 192.0.2.200"}; !reflect.DeepEqual(got, want) { + if got, want := md["x-forwarded-for"], []string{"192.0.2.100, 192.0.2.101, 192.0.2.102, 192.0.2.200"}; !reflect.DeepEqual(got, want) { t.Errorf(`md["x-forwarded-for"] = %v want %v`, got, want) } if m, ok := runtime.RPCMethod(annotated); !ok {