Skip to content

Commit

Permalink
Fix rare race condition in ZClient causing healthy connections to be …
Browse files Browse the repository at this point in the history
…discarded (#2924)

Fulfill onComplete promise before invoking final callback
  • Loading branch information
kyri-petrou committed Jun 23, 2024
1 parent 97dc4ae commit 7b5d2a1
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 16 deletions.
18 changes: 11 additions & 7 deletions zio-http/jvm/src/main/scala/zio/http/netty/AsyncBodyReader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ abstract class AsyncBodyReader extends SimpleChannelInboundHandler[HttpContent](
val _ = ctx.channel().config().setAutoRead(previousAutoRead)
}

protected def onLastMessage(): Unit = ()

override def channelRead0(
ctx: ChannelHandlerContext,
msg: HttpContent,
Expand All @@ -87,6 +89,12 @@ abstract class AsyncBodyReader extends SimpleChannelInboundHandler[HttpContent](
val isLast = msg.isInstanceOf[LastHttpContent]
val content = ByteBufUtil.getBytes(msg.content())

if (isLast) {
readingDone = true
ctx.channel().pipeline().remove(this)
onLastMessage()
}

state match {
case State.Buffering =>
// `connect` method hasn't been called yet, add all incoming content to the buffer
Expand All @@ -103,13 +111,7 @@ abstract class AsyncBodyReader extends SimpleChannelInboundHandler[HttpContent](
callback(Chunk.fromArray(content), isLast)
}

if (isLast) {
readingDone = true
ctx.channel().pipeline().remove(this)
} else {
ctx.read()
}
()
if (!isLast) ctx.read(): Unit
}
}

Expand Down Expand Up @@ -137,6 +139,8 @@ abstract class AsyncBodyReader extends SimpleChannelInboundHandler[HttpContent](
}

object AsyncBodyReader {
private val FnUnit = () => ()

sealed trait State

object State {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,16 @@ final class ClientResponseStreamHandler(

private implicit val unsafe: Unsafe = Unsafe.unsafe

override def onLastMessage(): Unit =
if (keepAlive)
onComplete.unsafe.done(Exit.succeed(ChannelState.forStatus(status)))
else
onComplete.unsafe.done(Exit.succeed(ChannelState.Invalid))

override def channelRead0(ctx: ChannelHandlerContext, msg: HttpContent): Unit = {
val isLast = msg.isInstanceOf[LastHttpContent]
super.channelRead0(ctx, msg)

if (isLast) {
if (keepAlive)
onComplete.unsafe.done(Exit.succeed(ChannelState.forStatus(status)))
else {
onComplete.unsafe.done(Exit.succeed(ChannelState.Invalid))
ctx.close(): Unit
}
}
if (isLast && !keepAlive) ctx.close(): Unit
}

override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit =
Expand Down

0 comments on commit 7b5d2a1

Please sign in to comment.