Skip to content

Commit

Permalink
Merge pull request #10 from raphaelmansuy:feat/claude
Browse files Browse the repository at this point in the history
chore: Update Cargo.lock and Cargo.toml for Hiramu 0.1.8 release
  • Loading branch information
raphaelmansuy committed Apr 6, 2024
2 parents 37bce39 + e542279 commit 03f39f9
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 21 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
[package]
name = "hiramu"
version = "0.1.7"
version = "0.1.8"
edition = "2021"
license = "MIT"
description = "A Rust AI Engineering Toolbox"
description = "A Rust AI Engineering Toolbox to Access Ollama, AWS Bedrock"
repository = "https://github.com/raphaelmansuy/hiramu"
keywords = ["api", "client", "async"]
keywords = ["api", "client", "async","aws bedrock", "ollama"]


# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand Down
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ To start using Hiramu in your Rust project, add the following to your `Cargo.tom

```toml
[dependencies]
hiramu = "0.1.7"
hiramu = "0.1.8"
```

## Examples
Expand Down Expand Up @@ -94,6 +94,7 @@ async fn main() {
```rust
use hiramu::ollama::ollama_client::OllamaClient;
use hiramu::ollama::model::{GenerateRequest, GenerateRequestBuilder};
use futures::stream::TryStreamExt;

#[tokio::main]
async fn main() {
Expand Down Expand Up @@ -132,7 +133,7 @@ async fn main() {
let mut conversation_request = ConversationRequest::default();
conversation_request
.messages
.push(Message::new_user_message("Hello, Claude!"));
.push(Message::new_user_message("Hello, Claude!".to_owned()));

let chat_options = ChatOptions::default()
.with_temperature(0.7)
Expand Down
73 changes: 58 additions & 15 deletions src/bedrock/models/claude/claude_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ use crate::bedrock::models::claude::claude_request_message::{
use crate::bedrock::models::claude::error::ClaudeError;
use futures::stream::Stream;
use futures::TryStreamExt;
use serde_json::Value;

use super::claude_request_message::{
ContentBlockDelta, ContentBlockStart, ContentBlockStop, MessageDelta, MessageStart,
MessageStop, StreamResultData,
};

pub type ClaudeOptions = BedrockClientOptions;

Expand Down Expand Up @@ -47,25 +53,62 @@ impl ClaudeClient {
&self,
request: &ConversationRequest,
options: &ChatOptions,
) -> Result<impl Stream<Item = Result<StreamResult, ClaudeError>>, ClaudeError> {
) -> Result<impl Stream<Item = Result<StreamResultData, ClaudeError>>, ClaudeError> {
let model_id = options.model_id.to_string();
let payload = serde_json::to_value(request);
let payload = serde_json::to_value(request).map_err(|err| ClaudeError::Json(err))?;

let payload = match payload {
Ok(payload) => payload,
Err(err) => return Err(ClaudeError::Json(err)),
};
let response = self.client.generate_raw_stream(model_id, payload).await?;

let response = self.client.generate_raw_stream(model_id, payload).await;
let stream = response
.map_err(ClaudeError::from)
.and_then(|chunk| async move {
let stream_result = deserialize_stream_result(chunk)?;
Ok(stream_result)
});

let response = match response {
Ok(response) => response,
Err(err) => return Err(ClaudeError::from(err)),
};
Ok(stream)
}
}

fn deserialize_stream_result(value: Value) -> Result<StreamResultData, ClaudeError> {
let stream_result: StreamResult = serde_json::from_value(value)
.map_err(|err| ClaudeError::Deserialization(err.to_string()))?;

Ok(response
.map_ok(|value| serde_json::from_value(value).map_err(ClaudeError::Json))
.map_err(|err| ClaudeError::Unknown(err.to_string()))
.and_then(futures::future::ready))
match stream_result.result_type.as_str() {
"message_start" => {
let message_start: MessageStart = serde_json::from_value(stream_result.data)
.map_err(|err| ClaudeError::Deserialization(err.to_string()))?;
Ok(StreamResultData::MessageStart(message_start))
}
"content_block_start" => {
let content_block_start: ContentBlockStart = serde_json::from_value(stream_result.data)
.map_err(|err| ClaudeError::Deserialization(err.to_string()))?;
Ok(StreamResultData::ContentBlockStart(content_block_start))
}
"content_block_delta" => {
let content_block_delta: ContentBlockDelta = serde_json::from_value(stream_result.data)
.map_err(|err| ClaudeError::Deserialization(err.to_string()))?;
Ok(StreamResultData::ContentBlockDelta(content_block_delta))
}
"content_block_stop" => {
let content_block_stop: ContentBlockStop =
serde_json::from_value(stream_result.data)
.map_err(|err| ClaudeError::Deserialization(err.to_string()))?;
Ok(StreamResultData::ContentBlockStop(content_block_stop))
}
"message_delta" => {
let message_delta: MessageDelta = serde_json::from_value(stream_result.data)
.map_err(|err| ClaudeError::Deserialization(err.to_string()))?;
Ok(StreamResultData::MessageDelta(message_delta))
}
"message_stop" => {
let message_stop: MessageStop = serde_json::from_value(stream_result.data)
.map_err(|err| ClaudeError::Deserialization(err.to_string()))?;
Ok(StreamResultData::MessageStop(message_stop))
}
_ => Err(ClaudeError::Deserialization(format!(
"Unknown StreamResult type: {}",
stream_result.result_type
))),
}
}
13 changes: 13 additions & 0 deletions src/bedrock/models/claude/claude_request_message.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use serde::{Deserialize, Serialize};

use super::ClaudeError;

pub struct ChatOptions {
pub model_id: String,
pub temperature: Option<f32>,
Expand Down Expand Up @@ -202,6 +204,17 @@ pub struct StreamResult {
pub data: serde_json::Value,
}


#[derive(Debug, Serialize, Deserialize)]
pub enum StreamResultData {
MessageStart(MessageStart),
ContentBlockStart(ContentBlockStart),
ContentBlockDelta(ContentBlockDelta),
ContentBlockStop(ContentBlockStop),
MessageDelta(MessageDelta),
MessageStop(MessageStop),
}

#[derive(Debug, Serialize, Deserialize)]
pub struct MessageStart {
pub message: Message,
Expand Down
3 changes: 3 additions & 0 deletions src/bedrock/models/claude/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,7 @@ pub enum ClaudeError {

#[error("Unknown error: {0}")]
Unknown(String),

#[error("Deserialization error: {0}")]
Deserialization(String),
}

0 comments on commit 03f39f9

Please sign in to comment.