Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent OOM when receiving large request streams #3174

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 27 additions & 17 deletions zio-http/jvm/src/main/scala/zio/http/netty/AsyncBodyReader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -95,23 +95,33 @@ private[netty] abstract class AsyncBodyReader extends SimpleChannelInboundHandle
onLastMessage()
}

state match {
case State.Buffering =>
// `connect` method hasn't been called yet, add all incoming content to the buffer
buffer0.addAll(content)
case State.Direct(callback) if isLast && buffer0.knownSize == 0 =>
// Buffer is empty, we can just use the array directly
callback(Chunk.fromArray(content), isLast = true)
case State.Direct(callback: UnsafeAsync.Aggregating) =>
// We're aggregating the full response, only call the callback on the last message
buffer0.addAll(content)
if (isLast) callback(result(buffer0), isLast = true)
case State.Direct(callback) =>
// We're streaming, emit chunks as they come
callback(Chunk.fromArray(content), isLast)
}

if (!isLast) ctx.read(): Unit
val readMore =
state match {
case State.Buffering =>
// `connect` method hasn't been called yet, add all incoming content to the buffer
buffer0.addAll(content)
true
case State.Direct(callback) if isLast && buffer0.knownSize == 0 =>
// Buffer is empty, we can just use the array directly
callback(Chunk.fromArray(content), isLast = true)
false
case State.Direct(callback: UnsafeAsync.Aggregating) =>
// We're aggregating the full response, only call the callback on the last message
buffer0.addAll(content)
if (isLast) callback(result(buffer0), isLast = true)
!isLast
case State.Direct(callback: UnsafeAsync.Streaming) =>
// We're streaming, emit chunks as they come
callback(Chunk.fromArray(content), isLast)
// ctx.read will be called when the chunk is consumed
false
case State.Direct(callback) =>
// We're streaming, emit chunks as they come
callback(Chunk.fromArray(content), isLast)
!isLast
}

if (readMore) ctx.read(): Unit
}
}

Expand Down
27 changes: 18 additions & 9 deletions zio-http/jvm/src/main/scala/zio/http/netty/NettyBody.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,13 @@ object NettyBody extends BodyEncoding {
unsafeAsync: UnsafeAsync => Unit,
knownContentLength: Option[Long],
contentTypeHeader: Option[Header.ContentType] = None,
readMore: () => Unit = () => (),
): Body = {
AsyncBody(
unsafeAsync,
knownContentLength,
contentTypeHeader.map(Body.ContentType.fromHeader),
readMore,
)
}

Expand Down Expand Up @@ -92,6 +94,7 @@ object NettyBody extends BodyEncoding {
unsafeAsync: UnsafeAsync => Unit,
knownContentLength: Option[Long],
override val contentType: Option[Body.ContentType] = None,
nettyRead: () => Unit,
) extends Body {

override def asArray(implicit trace: Trace): Task[Array[Byte]] = asChunk.map {
Expand All @@ -110,12 +113,14 @@ object NettyBody extends BodyEncoding {
}

override def asStream(implicit trace: Trace): ZStream[Any, Throwable, Byte] = {
asyncUnboundedStream[Any, Throwable, Byte](emit =>
try {
unsafeAsync(new UnsafeAsync.Streaming(emit))
} catch {
case e: Throwable => emit(ZIO.fail(Option(e)))
},
asyncUnboundedStream[Any, Throwable, Byte](
emit =>
try {
unsafeAsync(new UnsafeAsync.Streaming(emit))
} catch {
case e: Throwable => emit(ZIO.fail(Option(e)))
},
ZIO.succeed(nettyRead()),
)
}

Expand All @@ -137,16 +142,20 @@ object NettyBody extends BodyEncoding {
}

/**
* Code ported from zio.stream to use an unbounded queue
* Code ported from zio.stream to use an unbounded queue On top of that the
* nettyRead() function is added. It is used to call netty ctx.read() when the
* queue is empty
*/
private def asyncUnboundedStream[R, E, A](
register: ZStream.Emit[R, E, A, Unit] => Unit,
nettyRead: UIO[Unit],
)(implicit trace: Trace): ZStream[R, E, A] =
ZStream.unwrapScoped[R](for {
queue <- ZIO.acquireRelease(Queue.unbounded[Take[E, A]])(_.shutdown)
runtime <- ZIO.runtime[R]
} yield {
val rtm = runtime.unsafe
val maybeRead = ZChannel.fromZIO(nettyRead.whenZIODiscard(queue.isEmpty))
val rtm = runtime.unsafe
register { k =>
try {
rtm
Expand All @@ -166,7 +175,7 @@ object NettyBody extends BodyEncoding {
maybeError =>
ZChannel.fromZIO(queue.shutdown) *>
maybeError.fold[ZChannel[Any, Any, Any, Any, E, Chunk[A], Unit]](ZChannel.unit)(ZChannel.fail(_)),
a => ZChannel.write(a) *> loop,
a => ZChannel.write(a) *> maybeRead *> loop,
),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ private[netty] object NettyBodyWriter {
val stream = ZStream.fromFile(body.file)
val s = StreamBody(stream, None, contentType = body.contentType)
NettyBodyWriter.writeAndFlush(s, None, ctx)
case AsyncBody(async, _, _) =>
case AsyncBody(async, _, _, _) =>
async(
new UnsafeAsync {
override def apply(message: Chunk[Byte], isLast: Boolean): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ private[netty] object NettyResponse {
callback => responseHandler.connect(callback),
knownContentLength,
contentType,
() => ctx.read(): Unit,
)
Response(status, headers, data)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,12 @@ private[zio] final case class ServerInboundHandler(
case nettyReq: HttpRequest =>
val knownContentLength = headers.get(Header.ContentLength).map(_.length)
val handler = addAsyncBodyHandler(ctx)
val body = NettyBody.fromAsync(async => handler.connect(async), knownContentLength, contentTypeHeader)
val body = NettyBody.fromAsync(
async => handler.connect(async),
knownContentLength,
contentTypeHeader,
() => ctx.read(): Unit,
)

Request(
body = body,
Expand Down
Loading