From 635603e22d6c97bb14c7a4c51d0cc322bc53da92 Mon Sep 17 00:00:00 2001 From: Naftoli Gugenheim <98384+nafg@users.noreply.github.com> Date: Thu, 19 Sep 2024 01:32:19 -0400 Subject: [PATCH] Fix #3103 Only last response is generated into Endpoint code (#3151) * gen: failing test: code is only generated for last response * Fix #3103 Only last response is generated into Endpoint code --- .../zio/http/gen/openapi/EndpointGen.scala | 2 +- .../zio/http/gen/scala/CodeGenSpec.scala | 68 +++++++++++++++++++ 2 files changed, 69 insertions(+), 1 deletion(-) diff --git a/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala b/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala index 6c9209b6b0..f22afb38b8 100644 --- a/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala +++ b/zio-http-gen/src/main/scala/zio/http/gen/openapi/EndpointGen.scala @@ -460,7 +460,7 @@ final case class EndpointGen(config: Config) { val (outImports: Iterable[List[Code.Import]], outCodes: Iterable[Code.OutCode]) = // TODO: ignore default for now. Not sure how to handle it - op.responses.collect { + op.responses.toSeq.collect { case (OpenAPI.StatusOrDefault.StatusValue(status), OpenAPI.ReferenceOr.Reference(ResponseRef(key), _, _)) => val response = resolveResponseRef(openAPI, key) val (imports, code) = diff --git a/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala b/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala index df89ae8bec..e9a9e3a940 100644 --- a/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala +++ b/zio-http-gen/src/test/scala/zio/http/gen/scala/CodeGenSpec.scala @@ -3,6 +3,7 @@ package zio.http.gen.scala import java.nio.file._ import scala.annotation.nowarn +import scala.collection.immutable.ListMap import scala.jdk.CollectionConverters._ import scala.meta._ import scala.meta.parsers._ @@ -975,5 +976,72 @@ object CodeGenSpec extends ZIOSpecDefault { } } } @@ TestAspect.exceptScala3, + test("Generate all responses") { + val oapi = + OpenAPI( + openapi = "3.0.0", + info = OpenAPI.Info( + title = "XXX", + description = None, + termsOfService = None, + contact = None, + license = None, + version = "1.0.0", + ), + paths = ListMap( + OpenAPI.Path + .fromString(name = "/api/a/b") + .map { path => + path -> OpenAPI.PathItem( + ref = None, + summary = None, + description = None, + get = None, + put = None, + post = Some( + OpenAPI.Operation( + summary = None, + description = None, + externalDocs = None, + operationId = None, + requestBody = None, + responses = Map( + OpenAPI.StatusOrDefault.StatusValue(status = Status.Ok) -> + OpenAPI.ReferenceOr.Or(value = OpenAPI.Response()), + OpenAPI.StatusOrDefault.StatusValue(Status.BadRequest) -> + OpenAPI.ReferenceOr.Or(OpenAPI.Response()), + OpenAPI.StatusOrDefault.StatusValue(Status.Unauthorized) -> + OpenAPI.ReferenceOr.Or(OpenAPI.Response()), + ), + ), + ), + delete = None, + options = None, + head = None, + patch = None, + trace = None, + ) + } + .toSeq: _*, + ), + components = None, + externalDocs = None, + ) + + val maybeEndpointCode = + EndpointGen + .fromOpenAPI(oapi, Config.default) + .files + .flatMap(_.objects) + .flatMap(_.endpoints) + .collectFirst { + case (field, code) if field.name == "post" => code + } + + assertTrue( + maybeEndpointCode.is(_.some).outCodes.length == 1 && + maybeEndpointCode.is(_.some).errorsCode.length == 2, + ) + }, ) @@ java11OrNewer @@ flaky @@ blocking // Downloading scalafmt on CI is flaky }