diff --git a/crates/corro-agent/src/api/public/mod.rs b/crates/corro-agent/src/api/public/mod.rs index 874883af..565f01db 100644 --- a/crates/corro-agent/src/api/public/mod.rs +++ b/crates/corro-agent/src/api/public/mod.rs @@ -466,14 +466,25 @@ async fn build_query_rows_response( pub async fn api_v1_queries( Extension(agent): Extension, + headers: axum::headers::HeaderMap, axum::extract::Json(stmt): axum::extract::Json, ) -> impl IntoResponse { + // https://github.com/ndjson/ndjson-spec#33-mediatype-and-file-extensions + let ndjson = headers.get("accept").map(|a| a == "application/x-ndjson").unwrap_or(false); + let (mut tx, body) = hyper::Body::channel(); // TODO: timeout on data send instead of infinitely waiting for channel space. let (data_tx, mut data_rx) = channel(512); tokio::spawn(async move { + if !ndjson { + if let Err(e) = tx.send_data(bytes::Bytes::from_static(b"[")).await { + error!("could not send data through body's channel: {e}"); + return; + } + } + let mut buf = BytesMut::new(); while let Some(row_res) = data_rx.recv().await { @@ -493,13 +504,23 @@ pub async fn api_v1_queries( } } - buf.extend_from_slice(b"\n"); + if !matches!(row_res, QueryEvent::EndOfQuery { .. }) { + buf.extend_from_slice(if ndjson { b"\n" } else { b"," }); + } if let Err(e) = tx.send_data(buf.split().freeze()).await { error!("could not send data through body's channel: {e}"); return; } } + + if !ndjson { + if let Err(e) = tx.send_data(bytes::Bytes::from_static(b"]")).await { + error!("could not send data through body's channel: {e}"); + return; + } + } + debug!("query body channel done"); }); @@ -613,6 +634,7 @@ pub async fn api_v1_db_schema( #[cfg(test)] mod tests { + use axum::http::HeaderValue; use bytes::Bytes; use corro_types::{api::RowId, config::Config, schema::SqliteType, base::Version}; use futures::Stream; @@ -781,8 +803,12 @@ mod tests { println!("transaction body: {body:?}"); + let mut headers = axum::headers::HeaderMap::new(); + headers.insert("Accept", HeaderValue::from_static("application/x-ndjson")); + let res = api_v1_queries( Extension(agent.clone()), + headers, axum::Json(Statement::Simple("select * from tests".into())), ) .await