-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
AWS stream bytes decoder, event parser, and frame decoder
- Loading branch information
1 parent
b67242b
commit 38cbced
Showing
3 changed files
with
101 additions
and
0 deletions.
There are no files selected for viewing
25 changes: 25 additions & 0 deletions
25
...ain/scala/io/cequence/openaiscala/anthropic/service/impl/AwsEventStreamBytesDecoder.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
package io.cequence.openaiscala.anthropic.service.impl | ||
|
||
import akka.NotUsed | ||
import akka.stream.scaladsl.Flow | ||
|
||
import java.util.Base64 | ||
import play.api.libs.json.{JsString, JsValue, Json} | ||
|
||
object AwsEventStreamBytesDecoder { | ||
def flow: Flow[JsValue, JsValue, NotUsed] = Flow[JsValue].map { eventJson => | ||
// eventJson might look like: | ||
// { ":message-type":"event", ":event-type":"...", "bytes":"base64string" } | ||
|
||
val base64Str = (eventJson \ "bytes").asOpt[String] | ||
base64Str match { | ||
case Some(encoded) => | ||
val decoded = Base64.getDecoder.decode(encoded) | ||
Json.parse(decoded) | ||
case None => | ||
// If there's no "bytes" field, return the original JSON (or handle differently) | ||
eventJson | ||
} | ||
} | ||
} | ||
|
20 changes: 20 additions & 0 deletions
20
...main/scala/io/cequence/openaiscala/anthropic/service/impl/AwsEventStreamEventParser.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
package io.cequence.openaiscala.anthropic.service.impl | ||
|
||
import akka.NotUsed | ||
import play.api.libs.json.{JsValue, Json} | ||
import akka.stream._ | ||
import akka.stream.scaladsl.Flow | ||
import akka.util.ByteString | ||
|
||
object AwsEventStreamEventParser { | ||
def flow: Flow[ByteString, Option[JsValue], NotUsed] = Flow[ByteString].map { frame => | ||
val rawString = new String(frame.toArray) | ||
|
||
if (rawString.contains("message-type")) { | ||
val jsonString = rawString.dropWhile(_ != '{').takeWhile(_ != '}') + "}" | ||
Some(Json.parse(jsonString)) | ||
} else | ||
None | ||
} | ||
} | ||
|
56 changes: 56 additions & 0 deletions
56
...ain/scala/io/cequence/openaiscala/anthropic/service/impl/AwsEventStreamFrameDecoder.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
package io.cequence.openaiscala.anthropic.service.impl | ||
|
||
import akka.stream._ | ||
import akka.stream.stage._ | ||
import akka.util.ByteString | ||
|
||
class AwsEventStreamFrameDecoder extends GraphStage[FlowShape[ByteString, ByteString]] { | ||
val in = Inlet[ByteString]("AwsEventStreamFrameDecoder.in") | ||
val out = Outlet[ByteString]("AwsEventStreamFrameDecoder.out") | ||
override val shape = FlowShape(in, out) | ||
|
||
private implicit val order = java.nio.ByteOrder.BIG_ENDIAN | ||
|
||
override def createLogic(attrs: Attributes): GraphStageLogic = new GraphStageLogic(shape) { | ||
var buffer = ByteString.empty | ||
|
||
setHandler(in, new InHandler { | ||
override def onPush(): Unit = { | ||
buffer ++= grab(in) | ||
emitFrames() | ||
} | ||
override def onUpstreamFinish(): Unit = { | ||
emitFrames() | ||
if (buffer.isEmpty) completeStage() | ||
else failStage(new RuntimeException("Truncated frame at stream end")) | ||
} | ||
}) | ||
|
||
setHandler(out, new OutHandler { | ||
override def onPull(): Unit = { | ||
if (!hasBeenPulled(in)) pull(in) | ||
} | ||
}) | ||
|
||
def emitFrames(): Unit = { | ||
while (buffer.size >= 4) { | ||
val totalLength = buffer.iterator.getInt | ||
println("buffer size: " + buffer.size) | ||
println("total length: " + totalLength) | ||
println("buffer: " + buffer.utf8String) | ||
|
||
if (buffer.size < 4 + totalLength) { | ||
// not enough data yet | ||
return | ||
} | ||
val frame = buffer.slice(4, 4 + totalLength) | ||
buffer = buffer.drop(4 + totalLength) | ||
emit(out, frame) | ||
} | ||
|
||
if (!hasBeenPulled(in) && !isClosed(in)) { | ||
pull(in) | ||
} | ||
} | ||
} | ||
} |