diff --git a/runtime/handler_test.go b/runtime/handler_test.go index 4ef78e7f0ef..34d70539f6e 100644 --- a/runtime/handler_test.go +++ b/runtime/handler_test.go @@ -178,36 +178,68 @@ func (c *CustomMarshaler) NewDecoder(r io.Reader) runtime.Decoder { return c func (c *CustomMarshaler) NewEncoder(w io.Writer) runtime.Encoder { return c.m.NewEncoder(w) } func (c *CustomMarshaler) ContentType(v interface{}) string { return "Custom-Content-Type" } +// marshalerStreamContentType implements Marshaler, but with the addition of a custom StreamContentType. +type marshalerStreamContentType struct { + runtime.Marshaler + CustomStreamContentType string +} + +func (m marshalerStreamContentType) StreamContentType(interface{}) string { + return m.CustomStreamContentType +} + func TestForwardResponseStreamCustomMarshaler(t *testing.T) { type msg struct { pb proto.Message err error } + marshaler := &CustomMarshaler{&runtime.JSONPb{}} + tests := []struct { - name string - msgs []msg - statusCode int + name string + marshaler runtime.Marshaler + msgs []msg + statusCode int + wantContentType string }{{ - name: "encoding", + name: "encoding", + marshaler: marshaler, msgs: []msg{ {&pb.SimpleMessage{Id: "One"}, nil}, {&pb.SimpleMessage{Id: "Two"}, nil}, }, - statusCode: http.StatusOK, + statusCode: http.StatusOK, + wantContentType: "Custom-Content-Type", }, { name: "empty", + marshaler: marshaler, statusCode: http.StatusOK, }, { - name: "error", - msgs: []msg{{nil, status.Errorf(codes.OutOfRange, "400")}}, - statusCode: http.StatusBadRequest, + name: "error", + marshaler: marshaler, + msgs: []msg{{nil, status.Errorf(codes.OutOfRange, "400")}}, + statusCode: http.StatusBadRequest, + wantContentType: "Custom-Content-Type", }, { - name: "stream_error", + name: "stream_error", + marshaler: marshaler, msgs: []msg{ {&pb.SimpleMessage{Id: "One"}, nil}, {nil, status.Errorf(codes.OutOfRange, "400")}, }, - statusCode: http.StatusOK, + statusCode: http.StatusOK, + wantContentType: "Custom-Content-Type", + }, { + name: "stream_content_type", + marshaler: marshalerStreamContentType{ + Marshaler: marshaler, + CustomStreamContentType: "Stream-Content-Type", + }, + msgs: []msg{ + {&pb.SimpleMessage{Id: "One"}, nil}, + }, + statusCode: http.StatusOK, + wantContentType: "Stream-Content-Type", }} newTestRecv := func(t *testing.T, msgs []msg) func() (proto.Message, error) { @@ -224,14 +256,13 @@ func TestForwardResponseStreamCustomMarshaler(t *testing.T) { } } ctx := runtime.NewServerMetadataContext(context.Background(), runtime.ServerMetadata{}) - marshaler := &CustomMarshaler{&runtime.JSONPb{}} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { recv := newTestRecv(t, tt.msgs) req := httptest.NewRequest("GET", "http://example.com/foo", nil) resp := httptest.NewRecorder() - runtime.ForwardResponseStream(ctx, runtime.NewServeMux(), marshaler, resp, req, recv) + runtime.ForwardResponseStream(ctx, runtime.NewServeMux(), tt.marshaler, resp, req, recv) w := resp.Result() if w.StatusCode != tt.statusCode { @@ -245,8 +276,8 @@ func TestForwardResponseStreamCustomMarshaler(t *testing.T) { t.Errorf("Failed to read response body with %v", err) } w.Body.Close() - if len(body) > 0 && w.Header.Get("Content-Type") != "Custom-Content-Type" { - t.Errorf("Content-Type %s want Custom-Content-Type", w.Header.Get("Content-Type")) + if w.Header.Get("Content-Type") != tt.wantContentType { + t.Errorf("Content-Type %q want %q", w.Header.Get("Content-Type"), tt.wantContentType) } var want []byte @@ -254,7 +285,7 @@ func TestForwardResponseStreamCustomMarshaler(t *testing.T) { if msg.err != nil { t.Skip("checking error encodings") } - b, err := marshaler.Marshal(map[string]proto.Message{"result": msg.pb}) + b, err := tt.marshaler.Marshal(map[string]proto.Message{"result": msg.pb}) if err != nil { t.Errorf("marshaler.Marshal() failed %v", err) }