From 362ff036642ef4b1d8b7542f866ba2402b0df638 Mon Sep 17 00:00:00 2001 From: Vince Buffalo Date: Thu, 31 Aug 2023 13:51:24 -0700 Subject: [PATCH] refactored progress and downloads --- src/lib.rs | 2 + src/lib/api/zenodo.rs | 1 + src/lib/data.rs | 79 ++++++++------------------------ src/lib/download.rs | 95 +++++++++++++++++++++++++++++++++++++++ src/lib/progress.rs | 37 +++++++++++++++ src/lib/project.rs | 13 ++++-- src/lib/remote.rs | 21 +++------ src/lib/utils.rs | 102 +++++++++++++++++++++--------------------- 8 files changed, 220 insertions(+), 130 deletions(-) create mode 100644 src/lib/download.rs create mode 100644 src/lib/progress.rs diff --git a/src/lib.rs b/src/lib.rs index 199aaa8..187fbe8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,8 @@ pub mod lib { pub mod zenodo; } pub mod project; + pub mod download; + pub mod progress; pub mod macros; pub mod remote; pub mod utils; diff --git a/src/lib/api/zenodo.rs b/src/lib/api/zenodo.rs index a303a55..b0075c2 100644 --- a/src/lib/api/zenodo.rs +++ b/src/lib/api/zenodo.rs @@ -692,6 +692,7 @@ mod tests { tracked: true, md5: md5.to_string(), size, + url: None }; let path_context = Path::new("path/to/datafile"); diff --git a/src/lib/data.rs b/src/lib/data.rs index 684655e..56aa901 100644 --- a/src/lib/data.rs +++ b/src/lib/data.rs @@ -8,19 +8,16 @@ use crate::lib::data::serde::{Serializer,Deserializer}; use log::{info, trace, debug}; use chrono::prelude::*; use std::collections::{HashMap,BTreeMap}; -use futures::future::join_all; use futures::stream::FuturesUnordered; use futures::StreamExt; +use futures::future::join_all; use std::fs; -use trauma::downloader::{DownloaderBuilder,StyleOptions,ProgressBarOpts}; -use std::time::Duration; -use std::thread; -use indicatif::{ProgressBar, ProgressStyle}; use colored::*; use crate::{print_warn,print_info}; use crate::lib::utils::{format_mod_time,compute_md5, md5_status,pluralize}; use crate::lib::remote::{authenticate_remote,Remote,RemoteFile,RemoteStatusCode}; +use crate::lib::progress::Progress; // The status of a local data file, *conditioned* on it being in the manifest. #[derive(Debug,PartialEq,Clone)] @@ -682,12 +679,8 @@ impl DataCollection { self.authenticate_remotes()?; let mut all_remote_files = HashMap::new(); - let pb = ProgressBar::new(self.remotes.len() as u64); - pb.set_style(ProgressStyle::default_bar() - .progress_chars("=> ") - .template("{spinner:.green} [{bar:40.green/white}] {pos:>}/{len} ({percent}%) eta {eta_precise:.green} {msg}")? - ); - pb.set_message("Fetching remote files..."); + let pb = Progress::new(self.remotes.len() as u64)?; + pb.bar.set_message("Fetching remote files..."); // Convert remotes into Futures, so that they can be awaited in parallel let fetch_futures: Vec<_> = self.remotes.iter().map(|(path, remote)| { @@ -704,15 +697,15 @@ impl DataCollection { for result in results { match result { Ok((key, value)) => { - pb.set_message(format!("Fetching remote files... {} done.", key.0)); + pb.bar.set_message(format!("Fetching remote files... {} done.", key.0)); all_remote_files.insert(key, value); - pb.inc(1); + pb.bar.inc(1); }, Err(e) => return Err(e), // Handle errors as needed } } - pb.finish_with_message("Fetching completed."); + pb.bar.finish_with_message("Fetching completed."); Ok(all_remote_files) } // Merge all local and remote files. @@ -794,33 +787,21 @@ impl DataCollection { let mut statuses = BTreeMap::new(); - let pb = ProgressBar::new(statuses_futures.len() as u64); - pb.set_style(ProgressStyle::default_bar() - .progress_chars("=> ") - .template("{spinner:.green} [{bar:40.green/white}] {pos:>}/{len} ({percent}%) eta {eta_precise:.green} {msg}")? - ); + let pb = Progress::new(statuses.len() as u64)?; - - let pb_clone = pb.clone(); - thread::spawn(move || { - loop { - pb_clone.tick(); - thread::sleep(Duration::from_millis(20)); - } - }); // process the futures as they become ready - pb.set_message("Calculating MD5s..."); + pb.bar.set_message("Calculating MD5s..."); while let Some(result) = statuses_futures.next().await { if let Ok((key, value)) = result { - pb.set_message(format!("Calculating MD5s... {} done.", &value.name)); + pb.bar.set_message(format!("Calculating MD5s... {} done.", &value.name)); statuses.entry(key).or_insert_with(Vec::new).push(value); - pb.inc(1); + pb.bar.inc(1); } else { result?; } } - pb.finish_with_message("Complete."); + pb.bar.finish_with_message("Complete."); Ok(statuses) } @@ -1004,37 +985,13 @@ impl DataCollection { if do_download { if let Some(remote) = self.remotes.get(dir) { - let info = remote.get_download_info(merged_file, path_context, overwrite)?; - let download = info.trauma_download()?; + let download = remote.get_download_info(merged_file, path_context, overwrite)?; downloads.push(download); } } } } - let style = ProgressBarOpts::new( - Some("{spinner:.green} [{bar:40.green/white}] {pos:>}/{len} ({percent}%) eta {eta_precise:.green} {msg}".to_string()), - Some("=> ".to_string()), - true, true); - - let style_clone = style.clone(); - let style_opts = StyleOptions::new(style, style_clone); - - let total_files = downloads.len(); - if !downloads.is_empty() { - let downloader = DownloaderBuilder::new() - .style_options(style_opts) - .build(); - downloader.download(&downloads).await; - println!("Downloaded {}.", pluralize(total_files as u64, "file")); - for download in downloads { - let filename = PathBuf::from(&download.filename); - let name_str = filename.file_name().ok_or(anyhow!("Internal Error: could not extract filename from download"))?; - println!(" - {}", name_str.to_string_lossy()); - } - } else { - println!("No files downloaded."); - } let num_skipped = overwrite_skipped.len() + current_skipped.len() + messy_skipped.len(); @@ -1089,7 +1046,7 @@ mod tests { let nonexistent_path = "some/nonexistent/path".to_string(); let path_context = Path::new(""); - let result = DataFile::new(nonexistent_path, &path_context); + let result = DataFile::new(nonexistent_path, None, &path_context); match result { Ok(_) => assert!(false, "Expected an error, but got Ok"), Err(err) => { @@ -1109,7 +1066,7 @@ mod tests { // Make a DataFile let path = file.path().to_string_lossy().to_string(); - let data_file = DataFile::new(path, &path_context).unwrap(); + let data_file = DataFile::new(path, None, &path_context).unwrap(); // Compare MD5s let expected_md5 = "d3feb335769173b2db573413b0f6abf4".to_string(); @@ -1128,7 +1085,7 @@ mod tests { // Make a DataFile let path = file.path().to_string_lossy().to_string(); - let data_file = DataFile::new(path, &path_context).unwrap(); + let data_file = DataFile::new(path, None, &path_context).unwrap(); // Let's also check size assert!(data_file.size == 11, "Size mismatch {:?} != {:?}!", @@ -1147,7 +1104,7 @@ mod tests { // Make a DataFile let path = file.path().to_string_lossy().to_string(); - let mut data_file = DataFile::new(path, &path_context).unwrap(); + let mut data_file = DataFile::new(path, None, &path_context).unwrap(); // Now, we change the data. writeln!(file, "Modified mock data.").unwrap(); @@ -1176,7 +1133,7 @@ mod tests { // Make a DataFile let path = file.path().to_string_lossy().to_string(); - let mut data_file = DataFile::new(path, &path_context).unwrap(); + let mut data_file = DataFile::new(path, None, &path_context).unwrap(); // Now, we change the data. writeln!(file, "Modified mock data.").unwrap(); diff --git a/src/lib/download.rs b/src/lib/download.rs new file mode 100644 index 0000000..200ac9a --- /dev/null +++ b/src/lib/download.rs @@ -0,0 +1,95 @@ +use anyhow::{anyhow,Result,Context}; +use std::path::PathBuf; +use reqwest::Url; + +use trauma::downloader::{DownloaderBuilder,StyleOptions,ProgressBarOpts}; +use trauma::download::Download; + +use crate::lib::progress::{DEFAULT_PROGRESS_STYLE, DEFAULT_PROGRESS_INC}; +use crate::lib::utils::pluralize; + +pub struct Downloads { + list: Vec, +} + + +pub trait Downloadable { + fn to_url(self) -> Result; +} + +impl Downloadable for String { + fn to_url(self) -> Result { + let url = Url::parse(&self).context(format!("Download URL '{}' is not valid.", &self))?; + Ok(url) + } +} + +impl Downloadable for Url { + fn to_url(self) -> Result { + Ok(self) + } +} + +impl Downloads { + pub fn new() -> Self { + let list = Vec::new(); + Downloads { list } + } + + pub fn add(&mut self, item: T, filename: Option<&str>) -> Result<&Download> { + let url = item.to_url()?; + + let resolved_filename = match filename { + Some(name) => name.to_string(), + None => { + url.path_segments() + .ok_or_else(|| anyhow::anyhow!("Error parsing URL."))? + .last() + .ok_or_else(|| anyhow::anyhow!("Error getting filename from download URL."))? + .to_string() + } + }; + + + let download = Download { url, filename: resolved_filename }; + self.list.push(download); + Ok(self.list.last().ok_or(anyhow::anyhow!("Failed to add download"))?) + } + + pub fn default_style(&self) -> Result { + let style = ProgressBarOpts::new( + Some(DEFAULT_PROGRESS_STYLE.to_string()), + Some(DEFAULT_PROGRESS_INC.to_string()), + true, true); + + let style_clone = style.clone(); + Ok(StyleOptions::new(style, style_clone)) + } + + + pub async fn download_all(&self, success_status: Option<&str>, + no_downloads_message: Option<&str>) -> Result<()> { + let downloads = &self.list; + let total_files = downloads.len(); + if !downloads.is_empty() { + let downloader = DownloaderBuilder::new() + .style_options(self.default_style()?) + .build(); + downloader.download(&downloads).await; + println!("Downloaded {}.", pluralize(total_files as u64, "file")); + for download in downloads { + if let Some(msg) = success_status { + let filename = PathBuf::from(&download.filename); + let name_str = filename.file_name().ok_or(anyhow!("Internal Error: could not extract filename from download"))?; + //println!(" - {}", name_str.to_string_lossy()); + println!("{}", msg.replace("{}", &name_str.to_string_lossy())); + } + } + } else { + if no_downloads_message.is_some() { + println!("{}", no_downloads_message.unwrap_or("")); + } + } + Ok(()) + } +} diff --git a/src/lib/progress.rs b/src/lib/progress.rs new file mode 100644 index 0000000..4dd3462 --- /dev/null +++ b/src/lib/progress.rs @@ -0,0 +1,37 @@ +use indicatif::{ProgressBar, ProgressStyle}; +use std::time::Duration; +use std::thread; +use anyhow::{anyhow,Result}; + +// these are separated since some APIs don't overload +// indicatif bars, but take the same primitives. +pub const DEFAULT_PROGRESS_STYLE: &str = "{spinner:.green} [{bar:40.green/white}] {pos:>}/{len} ({percent}%) eta {eta_precise:.green} {msg}"; +pub const DEFAULT_PROGRESS_INC: &str = "=>"; + +pub fn default_progress_style() -> Result { + let style = ProgressStyle::default_bar() + .progress_chars(DEFAULT_PROGRESS_INC) + .template(DEFAULT_PROGRESS_STYLE)?; + Ok(style) +} + +pub struct Progress { + pub bar: ProgressBar, + spinner: thread::JoinHandle<()> +} + +impl Progress { + pub fn new(len: u64) -> Result { + let bar = ProgressBar::new(len as u64); + bar.set_style(default_progress_style()?); + + let bar_clone = bar.clone(); + let spinner = thread::spawn(move || { + loop { + bar_clone.tick(); + thread::sleep(Duration::from_millis(20)); + } + }); + Ok(Progress { bar, spinner }) + } +} diff --git a/src/lib/project.rs b/src/lib/project.rs index a673199..509f34c 100644 --- a/src/lib/project.rs +++ b/src/lib/project.rs @@ -8,7 +8,10 @@ use std::io::{Read, Write}; #[allow(unused_imports)] use log::{info, trace, debug}; use dirs; +use trauma::download::Download; +use reqwest::Url; +use crate::lib::download::Downloads; #[allow(unused_imports)] use crate::{print_warn,print_info}; use crate::lib::data::{DataFile,DataCollection}; @@ -382,9 +385,13 @@ impl Project { Ok(()) } - pub async fn get(&mut self, url: &str, filename: &str) -> Result<()> { - let data_file = DataFile::new(filename.to_string(), Some(url), &self.path_context())?; - info!("Adding file '{}'.", filename); + pub async fn get(&mut self, url: &str, filename: Option<&str>, path: Option<&str>) -> Result<()> { + let mut downloads = Downloads::new(); + let download = downloads.add(url.to_string(), filename)?; + + let file_path = &download.filename; + let data_file = DataFile::new(file_path.clone(), Some(url), &self.path_context())?; + info!("Adding file '{}'.", &file_path); self.data.register(data_file)?; Ok(()) } diff --git a/src/lib/remote.rs b/src/lib/remote.rs index d5767a3..a435999 100644 --- a/src/lib/remote.rs +++ b/src/lib/remote.rs @@ -1,4 +1,5 @@ use serde_yaml; +use trauma::download::Download; use std::fs; use std::fs::File; use std::path::Path; @@ -8,7 +9,6 @@ use anyhow::{anyhow,Result}; #[allow(unused_imports)] use log::{info, trace, debug}; use std::collections::HashMap; -use trauma::download::Download; use serde_derive::{Serialize,Deserialize}; use reqwest::Url; @@ -17,22 +17,11 @@ use crate::lib::api::figshare::FigShareAPI; use crate::lib::api::dryad::DataDryadAPI; use crate::lib::api::zenodo::ZenodoAPI; use crate::lib::project::LocalMetadata; +use crate::lib::download::{Downloads}; const AUTHKEYS: &str = ".scidataflow_authkeys.yml"; -#[derive(Debug, Clone, PartialEq)] -pub struct DownloadInfo { - pub url: String, - pub path: String, -} - -impl DownloadInfo { - pub fn trauma_download(&self) -> Result { - Ok(Download::new(&Url::parse(&self.url)?, &self.path)) - } -} - #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct RemoteFile { pub name: String, @@ -204,7 +193,7 @@ impl Remote { // Get Download info: the URL (with token) and destination // TODO: could be struct, if some APIs require more authentication // Note: requires each API actually *check* overwrite. - pub fn get_download_info(&self, merged_file: &MergedFile, path_context: &Path, overwrite: bool) -> Result { + pub fn get_download_info(&self, merged_file: &MergedFile, path_context: &Path, overwrite: bool) -> Result { // if local DataFile is none, not in manifest; // do not download let data_file = match &merged_file.local { @@ -228,7 +217,9 @@ impl Remote { Remote::DataDryadAPI(_) => service_not_implemented!("DataDryad"), }?; let save_path = &data_file.full_path(path_context)?; - Ok( DownloadInfo { url: authenticated_url, path:save_path.to_string_lossy().to_string() }) + let url = Url::parse(&authenticated_url)?; + let filename = save_path.to_string_lossy().to_string(); + Ok( Download { url, filename }) } } diff --git a/src/lib/utils.rs b/src/lib/utils.rs index 7ffc115..b81b5ce 100644 --- a/src/lib/utils.rs +++ b/src/lib/utils.rs @@ -61,64 +61,64 @@ pub fn compute_md5(file_path: &Path) -> Result> { md5.consume(&buffer[..bytes_read]); } - + let result = md5.compute(); Ok(Some(format!("{:x}", result))) } /* -pub fn print_fixed_width(rows: HashMap>, nspaces: Option, indent: Option, color: bool) { - let indent = indent.unwrap_or(0); - let nspaces = nspaces.unwrap_or(6); + pub fn print_fixed_width(rows: HashMap>, nspaces: Option, indent: Option, color: bool) { + let indent = indent.unwrap_or(0); + let nspaces = nspaces.unwrap_or(6); + + let max_cols = rows.values() + .flat_map(|v| v.iter()) + .filter_map(|entry| { + match &entry.cols { + None => None, + Some(cols) => Some(cols.len()) + } + }) + .max() + .unwrap_or(0); - let max_cols = rows.values() - .flat_map(|v| v.iter()) - .filter_map(|entry| { - match &entry.cols { - None => None, - Some(cols) => Some(cols.len()) - } - }) - .max() - .unwrap_or(0); + let mut max_lengths = vec![0; max_cols]; - let mut max_lengths = vec![0; max_cols]; - - // compute max lengths across all rows - for entry in rows.values().flat_map(|v| v.iter()) { - if let Some(cols) = &entry.cols { - for (i, col) in cols.iter().enumerate() { - max_lengths[i] = max_lengths[i].max(col.width()); - } - } - } - // print status table - let mut keys: Vec<&String> = rows.keys().collect(); - keys.sort(); - for (key, value) in &rows { - let pretty_key = if color { key.bold().to_string() } else { key.clone() }; - println!("[{}]", pretty_key); - - // Print the rows with the correct widths - for row in value { - let mut fixed_row = Vec::new(); - let tracked = &row.tracked; - let local_status = &row.local_status; - let remote_status = &row.remote_status; - if let Some(cols) = &row.cols { - for (i, col) in cols.iter().enumerate() { - // push a fixed-width column to vector - let fixed_col = format!("{:width$}", col, width = max_lengths[i]); - fixed_row.push(fixed_col); - } - } - let spacer = " ".repeat(nspaces); - let status_line = fixed_row.join(&spacer); - println!("{}{}", " ".repeat(indent), status_line); - } - println!(); - } +// compute max lengths across all rows +for entry in rows.values().flat_map(|v| v.iter()) { +if let Some(cols) = &entry.cols { +for (i, col) in cols.iter().enumerate() { +max_lengths[i] = max_lengths[i].max(col.width()); +} +} +} +// print status table +let mut keys: Vec<&String> = rows.keys().collect(); +keys.sort(); +for (key, value) in &rows { +let pretty_key = if color { key.bold().to_string() } else { key.clone() }; +println!("[{}]", pretty_key); + +// Print the rows with the correct widths +for row in value { +let mut fixed_row = Vec::new(); +let tracked = &row.tracked; +let local_status = &row.local_status; +let remote_status = &row.remote_status; +if let Some(cols) = &row.cols { +for (i, col) in cols.iter().enumerate() { +// push a fixed-width column to vector +let fixed_col = format!("{:width$}", col, width = max_lengths[i]); +fixed_row.push(fixed_col); +} +} +let spacer = " ".repeat(nspaces); +let status_line = fixed_row.join(&spacer); +println!("{}{}", " ".repeat(indent), status_line); +} +println!(); +} } - */ +*/ // More specialized version of print_fixed_width() for statuses. // Handles coloring, manual annotation, etc pub fn print_fixed_width_status(rows: BTreeMap>, nspaces: Option,