Skip to content

Commit

Permalink
feat(query): Fetch query response in JSON format.
Browse files Browse the repository at this point in the history
Allows row deserialization into a `T: Deserialize`,
which eliminates the limitations of `Query::fetch`:

* when the table schema is not known: `SELECT * from ?`
* when the table schema is not specified: `DESCRIBE TABLE ?`
* when we read less columns than we select
  • Loading branch information
pravic committed Nov 27, 2024
1 parent 5f3b985 commit 55da617
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 0 deletions.
56 changes: 56 additions & 0 deletions src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use serde::{Deserialize, Serialize};
use std::fmt::Display;
use url::Url;

#[cfg(feature = "watch")]
use crate::watch;
use crate::{
error::{Error, Result},
headers::with_request_headers,
Expand Down Expand Up @@ -90,6 +92,60 @@ impl Query {
Ok(RowCursor::new(response))
}

/// Executes the query, returning a [`watch::RowJsonCursor`] to obtain results.
#[cfg(feature = "watch")]
pub fn fetch_json<T>(mut self) -> Result<watch::RowJsonCursor<T>> {
self.sql.append(" FORMAT JSONEachRowWithProgress");

let response = self.do_execute(true)?;
Ok(watch::RowJsonCursor::new(response))
}

/// Executes the query and returns just a single row.
///
/// Note that `T` must be owned.
#[cfg(feature = "watch")]
pub async fn fetch_json_one<T>(self) -> Result<T>
where
T: for<'b> Deserialize<'b>,
{
match self.fetch_json()?.next().await {
Ok(Some(row)) => Ok(row),
Ok(None) => Err(Error::RowNotFound),
Err(err) => Err(err),
}
}

/// Executes the query and returns at most one row.
///
/// Note that `T` must be owned.
#[cfg(feature = "watch")]
pub async fn fetch_json_optional<T>(self) -> Result<Option<T>>
where
T: for<'b> Deserialize<'b>,
{
self.fetch_json()?.next().await
}

/// Executes the query and returns all the generated results,
/// collected into a [`Vec`].
///
/// Note that `T` must be owned.
#[cfg(feature = "watch")]
pub async fn fetch_json_all<T>(self) -> Result<Vec<T>>
where
T: for<'b> Deserialize<'b>,
{
let mut result = Vec::new();
let mut cursor = self.fetch_json::<T>()?;

while let Some(row) = cursor.next().await? {
result.push(row);
}

Ok(result)
}

/// Executes the query and returns just a single row.
///
/// Note that `T` must be owned.
Expand Down
18 changes: 18 additions & 0 deletions src/watch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use sha1::{Digest, Sha1};
use crate::{
cursor::JsonCursor,
error::{Error, Result},
response::Response,
row::Row,
sql::{Bind, SqlBuilder},
Client, Compression,
Expand Down Expand Up @@ -165,6 +166,23 @@ impl EventCursor {
}
}

/// A cursor that emits rows in JSON format.
pub struct RowJsonCursor<T>(JsonCursor<T>);

impl<T> RowJsonCursor<T> {
pub(crate) fn new(response: Response) -> Self {
Self(JsonCursor::new(response))
}

/// Emits the next row.
pub async fn next<'a, 'b: 'a>(&'a mut self) -> Result<Option<T>>
where
T: Deserialize<'b>,
{
self.0.next().await
}
}

// === RowCursor ===

/// A cursor that emits `(Version, T)`.
Expand Down
95 changes: 95 additions & 0 deletions tests/it/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,98 @@ async fn prints_query() {
"SELECT ?fields FROM test WHERE a = ? AND b < ?"
);
}

#[cfg(feature = "watch")]
#[tokio::test]
async fn fetches_json_row() {
let client = prepare_database!();

let value = client
.query("SELECT 1,2,3")
.fetch_json_one::<serde_json::Value>()
.await
.unwrap();

assert_eq!(value, serde_json::json!({ "1": 1, "2": 2, "3": 3}));

let value = client
.query("SELECT (1,2,3) as data")
.fetch_json_one::<serde_json::Value>()
.await
.unwrap();

assert_eq!(value, serde_json::json!({ "data": [1,2,3]}));
}

#[cfg(feature = "watch")]
#[tokio::test]
async fn fetches_json_struct() {
let client = prepare_database!();

#[derive(Debug, Deserialize, PartialEq)]
struct Row {
one: i8,
two: String,
three: f32,
four: bool,
}

let value = client
.query("SELECT -1 as one, '2' as two, 3.0 as three, false as four")
.fetch_json_one::<Row>()
.await
.unwrap();

assert_eq!(
value,
Row {
one: -1,
two: "2".to_owned(),
three: 3.0,
four: false,
}
);
}

#[cfg(feature = "watch")]
#[tokio::test]
async fn describes_table() {
let client = prepare_database!();

let columns = client
.query("DESCRIBE TABLE system.users")
.fetch_json_all::<serde_json::Value>()
.await
.unwrap();
for c in &columns {
println!("{c}");
}
let columns = columns
.into_iter()
.map(|row| {
let column_name = row
.as_object()
.expect("JSONEachRow")
.get("name")
.expect("`system.users` must contain the `name` column");
(column_name.as_str().unwrap().to_owned(), row)
})
.collect::<std::collections::HashMap<String, serde_json::Value>>();
dbg!(&columns);

let name_column = columns
.get("name")
.expect("`system.users` must contain the `name` column");
assert_eq!(
name_column.as_object().unwrap().get("type").unwrap(),
&serde_json::json!("String")
);

let id_column = columns
.get("id")
.expect("`system.users` must contain the `id` column");
assert_eq!(
id_column.as_object().unwrap().get("type").unwrap(),
&serde_json::json!("UUID")
);
}

0 comments on commit 55da617

Please sign in to comment.