Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Template tweaks #123

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ quoted-string = "0.6.1"
rand = { version = "0.8.5", features = ["small_rng"] }
rangemap = { version = "1.4.0" }
rcgen = { version = "0.11.1", features = ["x509-parser"] }
rhai = { version = "1.15.1", features = ["sync"] }
rhai = { version = "1.15.1", features = ["sync", "metadata"] }
rusqlite = { version = "0.29.0", features = ["serde_json", "time", "bundled", "uuid", "array", "load_extension", "column_decltype", "vtab", "functions", "chrono"] }
rustls = { version = "0.21.0", features = ["dangerous_configuration", "quic"] }
rustls-pemfile = "1.0.2"
Expand Down
90 changes: 42 additions & 48 deletions crates/corro-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@ pub mod sub;
use std::{net::SocketAddr, ops::Deref, path::Path};

use corro_api_types::{
sqlite::ChangeType, ChangeId, ColumnName, ExecResponse, ExecResult, RowId, SqliteValue,
Statement,
sqlite::ChangeType, ChangeId, ColumnName, ExecResponse, ExecResult, QueryEvent, RowId,
SqliteValue, Statement,
};
use http::uri::PathAndQuery;
use hyper::{client::HttpConnector, http::HeaderName, Body, StatusCode};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use sub::SubscriptionStream;
use tracing::{debug, warn};
use tracing::warn;
use uuid::Uuid;

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum QueryEvent<T> {
pub enum TypedQueryEvent<T> {
Columns(Vec<ColumnName>),
Row(RowId, T),
#[serde(rename = "eoq")]
Expand Down Expand Up @@ -54,26 +54,14 @@ impl CorrosionApiClient {

if !res.status().is_success() {
let status = res.status();
match hyper::body::to_bytes(res.into_body()).await {
Ok(b) => match serde_json::from_slice(&b) {
Ok(res) => match res {
ExecResult::Error { error } => return Err(Error::ResponseError(error)),
res => return Err(Error::UnexpectedResult(res)),
},
Err(e) => {
debug!(
error = %e,
"could not deserialize response body, sending generic error..."
);
return Err(Error::UnexpectedStatusCode(status));
}
},
Err(e) => {
debug!(
error = %e,
"could not aggregate response body bytes, sending generic error..."
);
return Err(Error::UnexpectedStatusCode(status));
let b = hyper::body::to_bytes(res.into_body()).await?;
match serde_json::from_slice::<QueryEvent>(&b) {
Ok(QueryEvent::Error(error)) => {
return Err(Error::ResponseError(error.into_string()))
}
Ok(res) => return Err(Error::UnexpectedQueryResult(res)),
Err(_) => {
return Err(make_unexpected_status_error(status, b));
}
}
}
Expand Down Expand Up @@ -109,11 +97,7 @@ impl CorrosionApiClient {
.header(hyper::header::ACCEPT, "application/json")
.body(Body::from(serde_json::to_vec(statement)?))?;

let res = self.api_client.request(req).await?;

if !res.status().is_success() {
return Err(Error::UnexpectedStatusCode(res.status()));
}
let res = check_res(self.api_client.request(req).await?).await?;

// TODO: make that header name a const in corro-types
let id = res
Expand Down Expand Up @@ -166,11 +150,7 @@ impl CorrosionApiClient {
.header(hyper::header::ACCEPT, "application/json")
.body(hyper::Body::empty())?;

let res = self.api_client.request(req).await?;

if !res.status().is_success() {
return Err(Error::UnexpectedStatusCode(res.status()));
}
let res = check_res(self.api_client.request(req).await?).await?;

Ok(SubscriptionStream::new(
id,
Expand All @@ -197,11 +177,7 @@ impl CorrosionApiClient {
.header(hyper::header::ACCEPT, "application/json")
.body(Body::from(serde_json::to_vec(statements)?))?;

let res = self.api_client.request(req).await?;

if !res.status().is_success() {
return Err(Error::UnexpectedStatusCode(res.status()));
}
let res = check_res(self.api_client.request(req).await?).await?;

let bytes = hyper::body::to_bytes(res.into_body()).await?;

Expand All @@ -216,11 +192,7 @@ impl CorrosionApiClient {
.header(hyper::header::ACCEPT, "application/json")
.body(Body::from(serde_json::to_vec(statements)?))?;

let res = self.api_client.request(req).await?;

if !res.status().is_success() {
return Err(Error::UnexpectedStatusCode(res.status()));
}
let res = check_res(self.api_client.request(req).await?).await?;

let bytes = hyper::body::to_bytes(res.into_body()).await?;

Expand Down Expand Up @@ -314,6 +286,25 @@ impl CorrosionApiClient {
}
}

async fn check_res(res: hyper::Response<Body>) -> Result<hyper::Response<Body>, Error> {
let status = res.status();
if !status.is_success() {
let b = hyper::body::to_bytes(res.into_body())
.await
.unwrap_or_default();

return Err(make_unexpected_status_error(status, b));
}
Ok(res)
}

fn make_unexpected_status_error(status: StatusCode, body: bytes::Bytes) -> Error {
Error::UnexpectedStatusCode {
status,
body: String::from_utf8(body.to_vec()).unwrap_or_default(),
}
}

#[derive(Clone)]
pub struct CorrosionClient {
api_client: CorrosionApiClient,
Expand Down Expand Up @@ -355,14 +346,17 @@ pub enum Error {
#[error(transparent)]
Serde(#[from] serde_json::Error),

#[error("received unexpected response code: {0}")]
UnexpectedStatusCode(StatusCode),
#[error("received unexpected response code: {status}, body: {body}")]
UnexpectedStatusCode { status: StatusCode, body: String },

#[error("{0}")]
ResponseError(String),

#[error("unexpected result: {0:?}")]
UnexpectedResult(ExecResult),
#[error("unexpected exec result: {0:?}")]
UnexpectedExecResult(ExecResult),

#[error("unexpected query event: {0:?}")]
UnexpectedQueryResult(QueryEvent),

#[error("could not retrieve subscription id from headers")]
ExpectedQueryId,
Expand Down
10 changes: 5 additions & 5 deletions crates/corro-client/src/sub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use tokio_util::{
use tracing::error;
use uuid::Uuid;

use super::QueryEvent;
use super::TypedQueryEvent;

pin_project! {
pub struct IoBodyStream {
Expand Down Expand Up @@ -121,7 +121,7 @@ where
fn poll_stream(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<QueryEvent<T>, SubscriptionError>>> {
) -> Poll<Option<Result<TypedQueryEvent<T>, SubscriptionError>>> {
let stream = loop {
match self.stream.as_mut() {
None => match ready!(self.as_mut().poll_request(cx)) {
Expand All @@ -140,11 +140,11 @@ where
match res {
Some(Ok(b)) => match serde_json::from_slice(&b) {
Ok(evt) => {
if let QueryEvent::EndOfQuery { change_id, .. } = &evt {
if let TypedQueryEvent::EndOfQuery { change_id, .. } = &evt {
self.observed_eoq = true;
self.last_change_id = *change_id;
}
if let QueryEvent::Change(_, _, _, change_id) = &evt {
if let TypedQueryEvent::Change(_, _, _, change_id) = &evt {
if matches!(self.last_change_id, Some(id) if id.0 + 1 != change_id.0) {
return Poll::Ready(Some(Err(SubscriptionError::MissedChange)));
}
Expand Down Expand Up @@ -220,7 +220,7 @@ impl<T> Stream for SubscriptionStream<T>
where
T: DeserializeOwned + Unpin,
{
type Item = Result<QueryEvent<T>, SubscriptionError>;
type Item = Result<TypedQueryEvent<T>, SubscriptionError>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
// first, check if we need to wait for a backoff...
Expand Down
53 changes: 26 additions & 27 deletions crates/corro-tpl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::sync::Arc;

use compact_str::ToCompactString;
use corro_client::sub::SubscriptionStream;
use corro_client::{CorrosionApiClient, QueryEvent};
use corro_client::{CorrosionApiClient, TypedQueryEvent};
use corro_types::api::ColumnName;
use corro_types::api::RowId;
use corro_types::api::SqliteParam;
Expand Down Expand Up @@ -191,13 +191,13 @@ impl Cell {
struct SqliteValueWrap(SqliteValue);

impl SqliteValueWrap {
fn to_json(&mut self) -> String {
fn to_json(&mut self) -> Dynamic {
match &self.0 {
SqliteValue::Null => "null".into(),
SqliteValue::Integer(i) => i.to_string(),
SqliteValue::Real(f) => f.to_string(),
SqliteValue::Text(t) => enquote::enquote('"', t),
SqliteValue::Blob(b) => hex::encode(b.as_slice()),
SqliteValue::Null => Dynamic::UNIT,
SqliteValue::Integer(i) => Dynamic::from(*i),
SqliteValue::Real(f) => Dynamic::from(*f),
SqliteValue::Text(t) => Dynamic::from(t.to_string()),
SqliteValue::Blob(b) => Dynamic::from(b.to_vec()),
}
}

Expand Down Expand Up @@ -263,15 +263,15 @@ impl QueryResponseIter {
};
match res {
Some(Ok(evt)) => match evt {
QueryEvent::Columns(cols) => {
TypedQueryEvent::Columns(cols) => {
self.columns = Some(Arc::new(
cols.into_iter()
.enumerate()
.map(|(i, name)| (name, i as u16))
.collect(),
))
}
QueryEvent::EndOfQuery { .. } => {
TypedQueryEvent::EndOfQuery { .. } => {
match self.body.take() {
None => {
self.done = true;
Expand All @@ -292,24 +292,23 @@ impl QueryResponseIter {
self.done = true;
return None;
}
QueryEvent::Row(rowid, cells) | QueryEvent::Change(_, rowid, cells, _) => {
match self.columns.as_ref() {
Some(columns) => {
return Some(Ok(Row {
id: rowid,
columns: columns.clone(),
cells: Arc::new(cells),
}));
}
None => {
self.done = true;
return Some(Err(Box::new(EvalAltResult::from(
"did not receive columns data",
))));
}
TypedQueryEvent::Row(rowid, cells)
| TypedQueryEvent::Change(_, rowid, cells, _) => match self.columns.as_ref() {
Some(columns) => {
return Some(Ok(Row {
id: rowid,
columns: columns.clone(),
cells: Arc::new(cells),
}));
}
}
QueryEvent::Error(e) => {
None => {
self.done = true;
return Some(Err(Box::new(EvalAltResult::from(
"did not receive columns data",
))));
}
},
TypedQueryEvent::Error(e) => {
self.done = true;
return Some(Err(Box::new(EvalAltResult::from(e))));
}
Expand Down Expand Up @@ -424,7 +423,7 @@ async fn wait_for_rows(
};

match row_recv {
Some(Ok(QueryEvent::Change(_, _, cells, _))) => {
Some(Ok(TypedQueryEvent::Change(_, _, cells, _))) => {
trace!("got an updated row! {cells:?}");

if let Err(_e) = tx.send(TemplateCommand::Render).await {
Expand Down
Loading
Loading