Skip to content

Commit

Permalink
feat: task based files
Browse files Browse the repository at this point in the history
  • Loading branch information
Threated committed Apr 30, 2024
1 parent e303922 commit 48fd9be
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 71 deletions.
3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@ reqwest = { version = "0.12", features = ["rustls-tls", "stream"], default-featu
clap = { version = "4.5", features = ["derive", "env"] }
tokio = { version = "1", features = ["macros", "rt-multi-thread", "signal", "fs", "io-std"] }
once_cell = "1"
beam-lib = { git = "https://github.com/samply/beam", branch = "fix/relative-urls", features = ["http-util", "sockets"] }
beam-lib = { git = "https://github.com/samply/beam", branch = "fix/relative-urls", features = ["http-util"] }
serde_json = "1"
serde = { version = "1", features = ["derive"] }
tokio-util = { version = "0.7", features = ["io"] }
futures-util = { version = "0.3", default-features = false, features = ["std"] }
anyhow = "1"
sync_wrapper = { version = "1", features = ["futures"] }

[features]
server = ["dep:axum", "dep:axum-extra"]
Expand Down
121 changes: 70 additions & 51 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,16 @@ mod server;

use std::{path::Path, process::ExitCode, time::SystemTime};

use beam_lib::{AppId, BeamClient, BlockingOptions, SocketTask};
use anyhow::{anyhow, bail, Context, Result};
use beam_lib::{AppId, BeamClient, BlockingOptions, MsgId, RawString, TaskRequest};
use clap::Parser;
use config::{Config, SendArgs, Mode, ReceiveMode};
use config::{Config, Mode, ReceiveMode, SendArgs};
use futures_util::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
use once_cell::sync::Lazy;
use reqwest::{Client, Upgraded, Url};
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use reqwest::{Client, Url};
use serde::{Deserialize, Serialize};
use sync_wrapper::SyncStream;
use tokio::io::AsyncRead;
use tokio_util::io::ReaderStream;
use reqwest::header::{HeaderName, HeaderValue, HeaderMap};
use anyhow::{anyhow, bail, Context, Result};
use tokio::io::{AsyncReadExt, AsyncWriteExt};

pub static CONFIG: Lazy<Config> = Lazy::new(Config::parse);

Expand All @@ -31,18 +29,26 @@ pub static CLIENT: Lazy<Client> = Lazy::new(Client::new);
#[tokio::main]
async fn main() -> ExitCode {
let work = match &CONFIG.mode {
Mode::Send(send_args) if send_args.file.to_string_lossy() == "-" => send_file(tokio::io::stdin(), send_args).boxed(),
Mode::Send(send_args) if send_args.file.to_string_lossy() == "-" => async {
let mut buf = String::new();
tokio::io::stdin().read_to_string(&mut buf).await?;
send_file(buf, send_args).await
}.boxed(),
Mode::Send(send_args) => tokio::fs::File::open(&send_args.file)
.err_into()
.and_then(|f| send_file(f, send_args))
.boxed(),
.and_then(|mut f| {
let send_args = send_args.clone();
async move {
let mut buf = String::new();
f.read_to_string(&mut buf).await?;
send_file(buf, &send_args).await
}}).boxed(),
Mode::Receive { count, mode } => stream_tasks()
.and_then(connect_socket)
.inspect_ok(|(task, _)| eprintln!("Receiving file from: {}", task.from))
.and_then(move |(task, inc)| match mode {
ReceiveMode::Print => print_file(task, inc).boxed(),
ReceiveMode::Save { outdir, naming } => save_file(outdir, task, inc, naming).boxed(),
ReceiveMode::Callback { url } => forward_file(task, inc, url).boxed(),
.inspect_ok(|task| eprintln!("Receiving file from: {}", task.from))
.and_then(move |task| match mode {
ReceiveMode::Print => print_file(task).boxed(),
ReceiveMode::Save { outdir, naming } => save_file(outdir, task, naming).boxed(),
ReceiveMode::Callback { url } => forward_file(task, url).boxed(),
})
.take(*count as usize)
.for_each(|v| {
Expand All @@ -54,7 +60,7 @@ async fn main() -> ExitCode {
.map(Ok)
.boxed(),
#[cfg(feature = "server")]
Mode::Server { bind_addr, api_key} => server::serve(bind_addr, api_key).boxed(),
Mode::Server { bind_addr, api_key } => server::serve(bind_addr, api_key).boxed(),
};
let result = tokio::select! {
res = work => res,
Expand All @@ -71,51 +77,60 @@ async fn main() -> ExitCode {
}
}

pub async fn save_file(dir: &Path, socket_task: SocketTask, mut incoming: impl AsyncRead + Unpin, naming_scheme: &str) -> Result<()> {
pub async fn save_file(dir: &Path, task: TaskRequest<RawString>, naming_scheme: &str) -> Result<()> {
let ts = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs();
let from = socket_task.from.as_ref().split('.').take(2).collect::<Vec<_>>().join(".");
let meta: FileMeta = serde_json::from_value(socket_task.metadata).context("Failed to deserialize metadata")?;
let from = task.from.as_ref().split('.').take(2).collect::<Vec<_>>().join(".");
let meta: FileMeta = serde_json::from_value(task.metadata).context("Failed to deserialize metadata")?;
let filename = naming_scheme
.replace("%f", &from)
.replace("%t", &ts.to_string())
// Save because deserialize implementation of suggested_name does path traversal check
.replace("%n", meta.suggested_name.as_deref().unwrap_or(""));
let mut file = tokio::fs::File::create(dir.join(filename)).await?;
tokio::io::copy(&mut incoming, &mut file).await?;
file.write_all(task.body.0.as_bytes()).await?;
Ok(())
}

async fn send_file(mut stream: impl AsyncRead + Unpin, meta @ SendArgs { to, .. }: &SendArgs) -> Result<()> {
async fn send_file(body: impl Into<String>, meta @ SendArgs { to, .. }: &SendArgs) -> Result<()> {
let to = AppId::new_unchecked(format!(
"{to}.{}",
CONFIG.beam_id.as_ref().splitn(3, '.').nth(2).expect("Invalid app id")
CONFIG
.beam_id
.as_ref()
.splitn(3, '.')
.nth(2)
.expect("Invalid app id")
));
let mut conn = BEAM_CLIENT
.create_socket_with_metadata(&to, meta.to_file_meta())
BEAM_CLIENT
.post_task(&TaskRequest {
id: MsgId::new(),
from: CONFIG.beam_id.clone(),
to: vec![to],
body: RawString::from(body.into()),
ttl: "60s".to_string(),
failure_strategy: beam_lib::FailureStrategy::Discard,
metadata: serde_json::to_value(meta.to_file_meta())?,
})
.await?;
tokio::io::copy(&mut stream, &mut conn).await?;
Ok(())
}

pub fn stream_tasks() -> impl Stream<Item = Result<SocketTask>> {
pub fn stream_tasks() -> impl Stream<Item = Result<TaskRequest<RawString>>> {
static BLOCK: Lazy<BlockingOptions> = Lazy::new(|| BlockingOptions::from_count(1));
futures_util::stream::repeat_with(move || {
BEAM_CLIENT.get_socket_tasks(&BLOCK)
}).filter_map(|v| async {
match v.await {
Ok(mut v) => Some(Ok(v.pop()?)),
Err(e) => Some(Err(anyhow::Error::from(e)).context("Failed to get socket tasks from beam")),
}
})
}

pub async fn connect_socket(socket_task: SocketTask) -> Result<(SocketTask, Upgraded)> {
let id = socket_task.id;
Ok((socket_task, BEAM_CLIENT.connect_socket(&id).await.with_context(|| format!("Failed to connect to socket {id}"))?))
futures_util::stream::repeat_with(move || BEAM_CLIENT.poll_pending_tasks(&BLOCK)).filter_map(
|v| async {
match v.await {
Ok(mut v) => Some(Ok(v.pop()?)),
Err(e) => Some(
Err(anyhow::Error::from(e)).context("Failed to get socket tasks from beam"),
),
}
},
)
}

pub async fn forward_file(socket_task: SocketTask, incoming: impl AsyncRead + Unpin + Send + 'static, cb: &Url) -> Result<()> {
let FileMeta { suggested_name, meta } = serde_json::from_value(socket_task.metadata).context("Failed to deserialize metadata")?;
pub async fn forward_file(task: TaskRequest<RawString>, cb: &Url) -> Result<()> {
let FileMeta { suggested_name, meta } = serde_json::from_value(task.metadata).context("Failed to deserialize metadata")?;
let mut headers = HeaderMap::with_capacity(2);
if let Some(meta) = meta {
headers.append(HeaderName::from_static("metadata"), HeaderValue::from_bytes(&serde_json::to_vec(&meta)?)?);
Expand All @@ -126,24 +141,28 @@ pub async fn forward_file(socket_task: SocketTask, incoming: impl AsyncRead + Un
let res = CLIENT
.post(cb.clone())
.headers(headers)
.body(reqwest::Body::wrap_stream(SyncStream::new(ReaderStream::new(incoming))))
.body(task.body.0)
.send()
.await;
match res {
Ok(r) if !r.status().is_success() => bail!("Got unsuccessful status code from callback server: {}", r.status()),
Ok(r) if !r.status().is_success() => bail!(
"Got unsuccessful status code from callback server: {}",
r.status()
),
Err(e) => bail!("Failed to send file to {cb}: {e}"),
_ => Ok(())
_ => Ok(()),
}
}

pub async fn print_file(socket_task: SocketTask, mut incoming: impl AsyncRead + Unpin) -> Result<()> {
eprintln!("Incoming file from {}", socket_task.from);
tokio::io::copy(&mut incoming, &mut tokio::io::stdout()).await?;
eprintln!("Done printing file from {}", socket_task.from);
pub async fn print_file(task: TaskRequest<RawString>) -> Result<()> {
eprintln!("Incoming file from {}", task.from);
tokio::io::stdout()
.write_all(task.body.0.as_bytes())
.await?;
eprintln!("Done printing file from {}", task.from);
Ok(())
}


fn validate_filename(name: &str) -> Result<&str> {
if name.chars().all(|c| c.is_alphanumeric() || ['_', '.', '-'].contains(&c)) {
Ok(name)
Expand Down
37 changes: 19 additions & 18 deletions src/server.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
use std::{io, net::SocketAddr, sync::Arc};
use std::{net::SocketAddr, sync::Arc};

use axum::{
extract::{Path, State, Request}, http::{HeaderMap, StatusCode}, routing::post, Router
extract::{Path, State}, http::{HeaderMap, StatusCode}, routing::post, Router
};
use axum_extra::{headers::{authorization, Authorization}, TypedHeader};
use beam_lib::AppId;
use futures_util::TryStreamExt as _;
use beam_lib::{AppId, MsgId, RawString, TaskRequest};
use tokio::net::TcpListener;
use tokio_util::io::StreamReader;

use crate::{FileMeta, BEAM_CLIENT, CONFIG};

Expand All @@ -28,7 +26,7 @@ async fn send_file(
auth: TypedHeader<Authorization<authorization::Basic>>,
headers: HeaderMap,
State(api_key): State<AppState>,
req: Request,
req: String,
) -> Result<(), StatusCode> {
if auth.password() != api_key.as_ref() {
return Err(StatusCode::UNAUTHORIZED);
Expand All @@ -37,22 +35,25 @@ async fn send_file(
"{other_proxy_name}.{}",
CONFIG.beam_id.as_ref().splitn(3, '.').nth(2).expect("Invalid app id")
));
let mut conn = BEAM_CLIENT
.create_socket_with_metadata(&to, FileMeta {
meta: headers.get("metadata").and_then(|v| serde_json::from_slice(v.as_bytes()).map_err(|e| eprintln!("Failed to deserialize metadata: {e}. Skipping metadata")).ok()),
suggested_name: headers.get("filename").and_then(|v| v.to_str().map(Into::into).ok()),
})
let meta = FileMeta {
meta: headers.get("metadata").and_then(|v| serde_json::from_slice(v.as_bytes()).map_err(|e| eprintln!("Failed to deserialize metadata: {e}. Skipping metadata")).ok()),
suggested_name: headers.get("filename").and_then(|v| v.to_str().map(Into::into).ok()),
};
let task = TaskRequest {
id: MsgId::new(),
from: CONFIG.beam_id.clone(),
to: vec![to],
body: RawString(req),
ttl: "30s".to_string(),
failure_strategy: beam_lib::FailureStrategy::Discard,
metadata: serde_json::to_value(meta).unwrap(),
};
BEAM_CLIENT
.post_task(&task)
.await
.map_err(|e| {
eprintln!("Failed to tunnel request: {e}");
StatusCode::INTERNAL_SERVER_ERROR
})?;
tokio::spawn(async move {
let mut reader = StreamReader::new(req.into_body().into_data_stream().map_err(|err| io::Error::new(io::ErrorKind::Other, err)));
if let Err(e) = tokio::io::copy(&mut reader, &mut conn).await {
// TODO: Some of these are normal find out which
eprintln!("Error sending file: {e}")
}
});
Ok(())
}

0 comments on commit 48fd9be

Please sign in to comment.