Skip to content

Commit

Permalink
wip: unify request and response using enum
Browse files Browse the repository at this point in the history
  • Loading branch information
junkurihara committed Apr 11, 2024
1 parent f3754c3 commit c62227e
Showing 1 changed file with 163 additions and 3 deletions.
166 changes: 163 additions & 3 deletions httpsig-hyper/src/hyper_http.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::error::{HyperSigError, HyperSigResult};
use http::{Request, Response};
use http::{HeaderMap, Request, Response};
use http_body::Body;
use httpsig::prelude::{
message_component::{
Expand Down Expand Up @@ -235,6 +235,7 @@ where
{
type Error = HyperSigError;

/// Set the http message signature from given http signature params and signing key
async fn set_message_signature<T>(
&mut self,
signature_params: &HttpSignatureParams,
Expand All @@ -245,7 +246,9 @@ where
Self: Sized,
T: SigningKey + Sync,
{
todo!()
self
.set_message_signatures(&[(&signature_params, signing_key, signature_name)])
.await
}

async fn set_message_signatures<T>(
Expand All @@ -257,14 +260,38 @@ where
T: SigningKey + Sync,
{
todo!()
// let vec_signature_headers_fut = params_key_name.iter().flat_map(|(params, key, name)| {
// build_signature_base_from_response(self, params).map(|base| async move { base.build_signature_headers(*key, *name) })
// });
// let vec_signature_headers = futures::future::join_all(vec_signature_headers_fut)
// .await
// .into_iter()
// .collect::<Result<Vec<_>, _>>()?;
// vec_signature_headers.iter().try_for_each(|headers| {
// self
// .headers_mut()
// .append("signature-input", headers.signature_input_header_value().parse()?);
// self
// .headers_mut()
// .append("signature", headers.signature_header_value().parse()?);
// Ok(()) as Result<(), HyperSigError>
// })
}

/// Verify the http message signature with given verifying key if the request has signature and signature-input headers
/// Return Ok(()) if the signature is valid.
/// If invalid for the given key or error occurs (like the case where the request does not have signature and/or signature-input headers), return Err.
/// If key_id is given, it is used to match the key id in signature params
async fn verify_message_signature<T>(&self, verifying_key: &T, key_id: Option<&str>) -> Result<SignatureName, Self::Error>
where
Self: Sized,
T: VerifyingKey + Sync,
{
todo!()
self
.verify_message_signatures(&[(verifying_key, key_id)])
.await?
.pop()
.unwrap()
}

async fn verify_message_signatures<T>(
Expand Down Expand Up @@ -297,6 +324,12 @@ where

/* --------------------------------------- */

/// A type to represent either http request or response
enum RequestOrResponse<'a, B> {
Request(&'a Request<B>),
Response(&'a Response<B>),
}

/// Extract signature and signature-input with signature-name indication from http request
fn extract_signature_headers_with_name<B>(req: &Request<B>) -> HyperSigResult<HttpSignatureHeadersMap> {
if !(req.headers().contains_key("signature-input") && req.headers().contains_key("signature")) {
Expand Down Expand Up @@ -338,6 +371,62 @@ fn build_signature_base_from_request<B>(
HttpSignatureBase::try_new(&component_lines, signature_params).map_err(|e| e.into())
}

/// Build signature base from hyper http request/response and signature params
/// - req_or_res: the hyper http request or response
/// - signature_params: the http signature params
/// - req_for_param: corresponding request to be considered in the signature base in response
fn build_signature_base<B>(
req_or_res: &RequestOrResponse<B>,
signature_params: &HttpSignatureParams,
req_for_param: Option<&Request<B>>,
) -> HyperSigResult<HttpSignatureBase> {
let component_lines = signature_params
.covered_components
.iter()
.map(|component_id| {
if component_id.params.0.contains(&HttpMessageComponentParam::Req) {
if matches!(req_or_res, RequestOrResponse::Request(_)) {
return Err(HyperSigError::InvalidComponentParam(
"`req` is not allowed in request".to_string(),
));
}
if req_for_param.is_none() {
return Err(HyperSigError::InvalidComponentParam(
"`req` is required for the param".to_string(),
));
}
let req = RequestOrResponse::Request(req_for_param.unwrap());
extract_http_message_component(&req, component_id)
} else {
extract_http_message_component(req_or_res, component_id)
}
})
.collect::<Result<Vec<_>, _>>()?;

HttpSignatureBase::try_new(&component_lines, signature_params).map_err(|e| e.into())
}

/// Extract http field from hyper http request/response
fn extract_http_field<B>(req_or_res: &RequestOrResponse<B>, id: &HttpMessageComponentId) -> HyperSigResult<HttpMessageComponent> {
let HttpMessageComponentName::HttpField(header_name) = &id.name else {
return Err(HyperSigError::InvalidComponentName(
"invalid http message component name as http field".to_string(),
));
};
let headers = match req_or_res {
RequestOrResponse::Request(req) => req.headers(),
RequestOrResponse::Response(res) => res.headers(),
};

let field_values = headers
.get_all(header_name)
.iter()
.map(|v| v.to_str().map(|s| s.to_owned()))
.collect::<Result<Vec<_>, _>>()?;

HttpMessageComponent::try_from((id, field_values.as_slice())).map_err(|e| e.into())
}

/// Extract http field from hyper http request
fn extract_http_field_from_request<B>(req: &Request<B>, id: &HttpMessageComponentId) -> HyperSigResult<HttpMessageComponent> {
let HttpMessageComponentName::HttpField(header_name) = &id.name else {
Expand Down Expand Up @@ -420,6 +509,66 @@ fn extract_derived_component_from_request<B>(
HttpMessageComponent::try_from((id, field_values.as_slice())).map_err(|e| e.into())
}

/// Extract derived component from hyper http request/response
fn extract_derived_component<B>(
req_or_res: &RequestOrResponse<B>,
id: &HttpMessageComponentId,
) -> HyperSigResult<HttpMessageComponent> {
let HttpMessageComponentName::Derived(derived_id) = &id.name else {
return Err(HyperSigError::InvalidComponentName(
"invalid http message component name as derived component".to_string(),
));
};
if !id.params.0.is_empty() {
return Err(HyperSigError::InvalidComponentParam(
"derived component does not allow parameters for request".to_string(),
));
}

// let field_values: Vec<String> = match derived_id {
// DerivedComponentName::Method => vec![req.method().as_str().to_string()],
// DerivedComponentName::TargetUri => vec![req.uri().to_string()],
// DerivedComponentName::Authority => vec![req.uri().authority().map(|s| s.to_string()).unwrap_or("".to_string())],
// DerivedComponentName::Scheme => vec![req.uri().scheme_str().unwrap_or("").to_string()],
// DerivedComponentName::RequestTarget => match *req.method() {
// http::Method::CONNECT => vec![req.uri().authority().map(|s| s.to_string()).unwrap_or("".to_string())],
// http::Method::OPTIONS => vec!["*".to_string()],
// _ => vec![req.uri().path_and_query().map(|s| s.to_string()).unwrap_or("".to_string())],
// },
// DerivedComponentName::Path => vec![{
// let p = req.uri().path();
// if p.is_empty() {
// "/".to_string()
// } else {
// p.to_string()
// }
// }],
// DerivedComponentName::Query => vec![req.uri().query().map(|v| format!("?{v}")).unwrap_or("?".to_string())],
// DerivedComponentName::QueryParam => {
// let query = req.uri().query().unwrap_or("");
// query
// .split('&')
// .filter(|s| !s.is_empty())
// .map(|s| s.to_string())
// .collect::<Vec<_>>()
// }
// DerivedComponentName::Status => {
// return Err(HyperSigError::InvalidComponentName(
// "`status` is only for response".to_string(),
// ))
// }
// DerivedComponentName::SignatureParams => req
// .headers()
// .get_all("signature-input")
// .iter()
// .map(|v| v.to_str().unwrap_or("").to_string())
// .collect::<Vec<_>>(),
// };

// HttpMessageComponent::try_from((id, field_values.as_slice())).map_err(|e| e.into())
todo!()
}

/* --------------------------------------- */
/// Extract http message component from hyper http request
fn extract_http_message_component_from_request<B>(
Expand All @@ -432,6 +581,17 @@ fn extract_http_message_component_from_request<B>(
}
}

/// Extract http message component from hyper http request
fn extract_http_message_component<B>(
req_or_res: &RequestOrResponse<B>,
target_component_id: &HttpMessageComponentId,
) -> HyperSigResult<HttpMessageComponent> {
match &target_component_id.name {
HttpMessageComponentName::HttpField(_) => extract_http_field(req_or_res, target_component_id),
HttpMessageComponentName::Derived(_) => extract_derived_component(req_or_res, target_component_id),
}
}

/* --------------------------------------- */
#[cfg(test)]
mod tests {
Expand Down

0 comments on commit c62227e

Please sign in to comment.