Skip to content

Commit

Permalink
gemini support and bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jorge-menjivar committed Dec 20, 2023
1 parent 7785445 commit ced314c
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 85 deletions.
20 changes: 12 additions & 8 deletions apps/desktop/providers/messages.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -190,16 +190,20 @@ export const MessagesProvider = ({

const { id } = event.payload;

const message = liveMessages.current.find((message) => message.id === id);
setMessages((messages) => {
const message = messages.find((message) => message.id === id);

if (!message) return;
if (!message) return messages;

storageCreateMessage(
liveDatabase.current,
liveSession.current.user!,
message,
liveMessages.current.filter((message) => message.id !== id),
);
storageCreateMessage(
liveDatabase.current!,
liveSession.current!.user!,
message,
messages.filter((message) => message.id !== id),
);

return messages;
});

setMessageIsStreaming(false);
});
Expand Down
202 changes: 144 additions & 58 deletions apps/desktop/src-tauri/src/lib/services/google/stream.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,43 @@
use log::{debug, error};
use tauri::Manager;

use crate::models::{
ai_models::{AiModel, ModelParams},
conversations::Message,
use std::sync::Mutex;

use crate::{
models::{
ai_models::{AiModel, ModelParams},
conversations::Message,
},
ControllerState,
};
use log::debug;
use tauri::{Manager, State};
use tokio_stream::StreamExt;

#[tauri::command(rename_all = "snake_case")]
pub async fn stream_google(
pub async fn stream_google<'a>(
_saved_settings: serde_json::Value,
_model: AiModel,
model: AiModel,
system_prompt: String,
params: ModelParams,
api_key: Option<String>,
messages: Vec<Message>,
_token_count: u32,
assistant_message_id: String,
handle: tauri::AppHandle,
) {
state: State<'a, Mutex<ControllerState>>,
) -> Result<(), ()> {
if api_key.is_none() {
return;
return Err(());
}

let mut messages_to_send = vec![];
let mut messages_to_send = vec![
serde_json::json!({
"role": "user",
"parts": [{"text": system_prompt}],
}),
serde_json::json!({
"role": "model",
"parts": [{"text": "Understood."}],
}),
];

for message in messages.iter() {
let role = if message.role == "user" {
Expand All @@ -31,61 +46,69 @@ pub async fn stream_google(
"model"
};
messages_to_send.push(serde_json::json!({
"author": role,
"content": message.content,
"role": role,
"parts": [{"text": message.content}],
}));
}

let examples = vec![
serde_json::json!({
"input": { "content": "Hi" },
"output": { "content": "Hi, how can I help you today?" }
}),
serde_json::json!({
"input": { "content": "Tell me a joke" },
"output": { "content": "Why don't scientists trust atoms? Because they make up everything!" }
}),
serde_json::json!({
"input": { "content": "What can you do?" },
"output": { "content": "I can assist with a variety of tasks, including answering questions, providing information, and more." }
}),
];

let prompt = serde_json::json!({
"context": system_prompt,
"examples": examples,
"messages": messages_to_send,
});
let mut body = serde_json::json!({
"prompt": prompt,
"candidateCount": 1,
"topP": 0.95,
"topK": 40,
"contents": messages_to_send,
"generationConfig": {
"topP": 0.95,
"topK": 40,
},
"safetySettings": [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE"
}
],
});

if let Some(temperature) = params.temperature {
body["temperature"] = serde_json::json!(temperature);
body["generationConfig"]["temperature"] = serde_json::json!(temperature);
}

if let Some(max_tokens) = params.max_tokens {
body["maxOutputTokens"] = serde_json::json!(max_tokens);
body["generationConfig"]["maxOutputTokens"] = serde_json::json!(max_tokens);
}

if let Some(stop) = params.stop {
body["stopSequences"] = serde_json::json!(stop);
body["generationConfig"]["stopSequences"] = serde_json::json!(stop);
}

if let Some(top_k) = params.top_k {
body["topK"] = serde_json::json!(top_k);
body["generationConfig"]["topK"] = serde_json::json!(top_k);
}

if let Some(top_p) = params.top_p {
body["topP"] = serde_json::json!(top_p);
body["generationConfig"]["topP"] = serde_json::json!(top_p);
}

let client = reqwest::Client::new();
let url = format!(
"https://generativelanguage.googleapis.com/v1beta/models/{}:streamGenerateContent?key={}",
model.id,
api_key.unwrap()
);

// pretty json
println!("body: {}", serde_json::to_string_pretty(&body).unwrap());

debug!("URL: {}", url);

let url = ("https://generativelanguage.googleapis.com/v1beta3/models/chat-bison-001:generateMessage?key=".to_owned() + &api_key.unwrap()).to_string();
let client = reqwest::Client::new();

let res = client
.post(url)
Expand All @@ -97,40 +120,103 @@ pub async fn stream_google(
match res {
Ok(response) if response.status() != 200 => {
let result: serde_json::Value = response.json().await.unwrap();
error!("Result: {:?}", result);

if let Some(_) = result.get("error") {
return;
return Err(());
} else {
let error_message = format!(
"PaLM API returned an error: {}",
"Google API returned an error: {}",
result
.get("value")
.map_or_else(|| result["statusText"].to_string(), |v| v.to_string())
);
debug!("{}", error_message);
panic!("{}", error_message);
}
}
Ok(response) => {
let result: serde_json::Value = response.json().await.unwrap();
Ok(_) => (),
Err(_) => panic!("Network error while trying to reach Google API"),
}

let text = result["candidates"][0]["content"]
.as_str()
.unwrap()
.to_string();
let mut stream = res.unwrap().bytes_stream();

let mut buffer = "".to_string();
while let Some(event) = stream.next().await {
if state.lock().unwrap().0 == "abort" {
// Reset the state
state.lock().unwrap().0 = "run".to_string();

// Tell the client to save the message to the database since the stream is done
handle
.emit_all(
"completion-stream",
"post-message",
Some(serde_json::json!({
"text": text,
"id": assistant_message_id,
})),
)
.unwrap();
return Ok(());
}

match event {
Ok(raw_event) => {
let event = String::from_utf8(raw_event.to_vec()).unwrap();

buffer += &event;

// Use grep style "grep "text" to get the text
let matched_string = buffer
.split("\n")
.filter(|line| line.contains("text"))
.map(|line| line.replace("text: ", ""))
.collect::<Vec<String>>()
.join("\n");

let trimmed_text = matched_string.trim().to_string();

if trimmed_text == "" {
continue;
}

if !trimmed_text.ends_with('"') {
continue;
}

debug!("Matched String: {}", trimmed_text);

let parsed_text = format!("{{{}}}", trimmed_text);

let json_value: serde_json::Value = serde_json::from_str(&parsed_text).unwrap();

debug!(
"JSON Value: {}",
serde_json::to_string_pretty(&json_value).unwrap()
);

let text = json_value["text"].as_str();

match text {
Some(text) => {
handle
.emit_all(
"completion-stream",
Some(serde_json::json!({
"text": text.to_string(),
"id": assistant_message_id,
})),
)
.unwrap();
}
None => (),
}

buffer = "".to_string();
}
Err(e) => {
// Handle the error case
eprintln!("Error reading line: {}", e);
return Err(());
}
}
Err(_) => panic!("Network error while trying to reach PaLM API"),
}

handle
Expand All @@ -142,5 +228,5 @@ pub async fn stream_google(
)
.unwrap();

return;
return Ok(());
}
4 changes: 2 additions & 2 deletions apps/desktop/types/settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ export const SystemSettings: Settings = {
secret: true,
},
'google.key': {
name: 'PaLM 2 API Key',
description: 'The API key to use for PaLM 2 models.',
name: 'Google API Key',
description: 'The API key to use for Google models.',
type: 'string',
secret: true,
},
Expand Down
Loading

0 comments on commit ced314c

Please sign in to comment.