diff --git a/Cargo.toml b/Cargo.toml index 94d2aeb..f4e83a9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,4 +47,5 @@ httpmock = "0.6.8" cd = "0.2.1" indicatif = { version = "0.17.6", features = ["futures"] } tokio-util = { version = "0.7.8", features = ["codec"] } +csv = "1.2.2" diff --git a/src/lib/data.rs b/src/lib/data.rs index 56aa901..825049d 100644 --- a/src/lib/data.rs +++ b/src/lib/data.rs @@ -4,6 +4,7 @@ use std::fs::{metadata}; use serde_derive::{Serialize,Deserialize}; use serde; use crate::lib::data::serde::{Serializer,Deserializer}; +use crate::lib::download::Downloads; #[allow(unused_imports)] use log::{info, trace, debug}; use chrono::prelude::*; @@ -194,10 +195,13 @@ impl MergedFile { self.local.as_ref().map(|data_file| data_file.tracked) } - pub fn local_md5(&self, path_context: &Path) -> Option { - self.local.as_ref() - .and_then(|local| local.get_md5(path_context).ok()) - .flatten() + pub async fn local_md5(&self, path_context: &Path) -> Option { + if let Some(local) = &self.local { + if let Ok(md5_result) = local.get_md5(path_context).await { + return md5_result; + } + } + None } pub fn remote_md5(&self) -> Option { @@ -212,8 +216,8 @@ impl MergedFile { self.local.as_ref().map(|local| local.md5.clone()) } - pub fn local_remote_md5_mismatch(&self, path_context: &Path) -> Option { - let local_md5 = self.local_md5(path_context); + pub async fn local_remote_md5_mismatch(&self, path_context: &Path) -> Option { + let local_md5 = self.local_md5(path_context).await; let remote_md5 = self.remote_md5(); match (remote_md5, local_md5) { (Some(remote), Some(local)) => Some(remote != local), @@ -227,18 +231,21 @@ impl MergedFile { .get_mod_time(path_context).ok()) } - pub fn status(&self, path_context: &Path) -> Result { + pub async fn status(&self, path_context: &Path) -> Result { //let tracked = self.local.as_ref().map_or(None,|df| Some(df.tracked)); // local status, None if no local file found - let local_status = self.local - .as_ref() - 
.and_then(|local| local.status(path_context).ok()); + let local_status = if let Some(local) = &self.local { + local.status(path_context).await.ok() + } else { + None + }; + // TODO fix path_context //info!("{:?} local status: {:?} ({:?})", self.name(), local_status, &path_context); - let md5_mismatch = self.local_remote_md5_mismatch(path_context); - + let md5_mismatch = self.local_remote_md5_mismatch(path_context).await; + if !self.has_remote().unwrap_or(false) { return Ok(RemoteStatusCode::NotExists) } @@ -285,13 +292,15 @@ impl MergedFile { // Create a StatusEntry, for printing the status to the user. pub async fn status_entry(&self, path_context: &Path, include_remotes: bool) -> Result { let tracked = self.local.as_ref().map(|df| df.tracked); - let local_status = self.local - .as_ref() - .and_then(|local| local.status(path_context).ok()); + let local_status = if let Some(local) = self.local.as_ref() { + local.status(path_context).await.ok() + } else { + None + }; - let remote_status = if include_remotes { Some(self.status(path_context)?) } else { None }; + let remote_status = if include_remotes { Some(self.status(path_context).await?) } else { None }; //let remote_status = if self.remote_service.is_some() { Some(self.status(path_context)?) 
} else { None }; - + let remote_service = if include_remotes { self.remote_service.clone() } else { None }; if self.local.is_none() && self.remote.is_none() { @@ -304,7 +313,7 @@ impl MergedFile { remote_status, tracked, remote_service, - local_md5: self.local_md5(path_context), + local_md5: self.local_md5(path_context).await, remote_md5: self.remote_md5(), manifest_md5: self.manifest_md5(), local_mod_time: self.local_mod_time(path_context) @@ -314,12 +323,12 @@ impl MergedFile { impl DataFile { - pub fn new(path: String, url: Option<&str>, path_context: &Path) -> Result { + pub async fn new(path: String, url: Option<&str>, path_context: &Path) -> Result { let full_path = path_context.join(&path); if !full_path.exists() { return Err(anyhow!("File '{}' does not exist.", path)) } - let md5 = match compute_md5(&full_path)? { + let md5 = match compute_md5(&full_path).await? { Some(md5) => md5, None => return Err(anyhow!("Could not compute MD5 as file does not exist")), }; @@ -357,8 +366,8 @@ impl DataFile { .to_string()) } - pub fn get_md5(&self, path_context: &Path) -> Result> { - compute_md5(&self.full_path(path_context)?) + pub async fn get_md5(&self, path_context: &Path) -> Result> { + compute_md5(&self.full_path(path_context)?).await } pub fn get_mod_time(&self, path_context: &Path) -> Result> { @@ -381,16 +390,16 @@ impl DataFile { // Returns true if the file does not exist. - pub fn is_changed(&self, path_context: &Path) -> Result { - match self.get_md5(path_context)? { + pub async fn is_changed(&self, path_context: &Path) -> Result { + match self.get_md5(path_context).await? 
{ Some(new_md5) => Ok(new_md5 != self.md5), None => Ok(true), } } - pub fn status(&self, path_context: &Path) -> Result { + pub async fn status(&self, path_context: &Path) -> Result { let is_alive = self.is_alive(path_context); - let is_changed = self.is_changed(path_context)?; + let is_changed = self.is_changed(path_context).await?; let local_status = match (is_changed, is_alive) { (false, true) => LocalStatusCode::Current, (true, true) => LocalStatusCode::Modified, @@ -403,8 +412,8 @@ impl DataFile { Ok(local_status) } - pub fn update(&mut self, path_context: &Path) -> Result<()> { - self.update_md5(path_context)?; + pub async fn update(&mut self, path_context: &Path) -> Result<()> { + self.update_md5(path_context).await?; self.update_size(path_context)?; Ok(()) } @@ -415,8 +424,8 @@ impl DataFile { Ok(()) } - pub fn update_md5(&mut self, path_context: &Path) -> Result<()> { - let new_md5 = match self.get_md5(path_context)? { + pub async fn update_md5(&mut self, path_context: &Path) -> Result<()> { + let new_md5 = match self.get_md5(path_context).await? 
{ Some(md5) => md5, None => return Err(anyhow!("Cannot update MD5: file does not exist")), }; @@ -542,11 +551,11 @@ impl DataCollection { } } - pub fn update(&mut self, filename: Option<&String>, path_context: &Path) -> Result<()> { + pub async fn update(&mut self, filename: Option<&String>, path_context: &Path) -> Result<()> { match filename { Some(file) => { if let Some(data_file) = self.files.get_mut(file) { - data_file.update(path_context)?; + data_file.update(path_context).await?; debug!("rehashed file {:?}", data_file.path); } } @@ -555,7 +564,7 @@ impl DataCollection { let all_files: Vec<_> = self.files.keys().cloned().collect(); for file in all_files { if let Some(data_file) = self.files.get_mut(&file) { - data_file.update(path_context)?; + data_file.update(path_context).await?; debug!("rehashed file {:?}", data_file.path); } @@ -708,6 +717,7 @@ impl DataCollection { pb.bar.finish_with_message("Fetching completed."); Ok(all_remote_files) } + // Merge all local and remote files. // // Use a fetch to get all remote files (as RemoteFile), and merge these @@ -787,10 +797,9 @@ impl DataCollection { let mut statuses = BTreeMap::new(); - let pb = Progress::new(statuses.len() as u64)?; + let pb = Progress::new(statuses_futures.len() as u64)?; // process the futures as they become ready - pb.bar.set_message("Calculating MD5s..."); while let Some(result) = statuses_futures.next().await { if let Ok((key, value)) = result { pb.bar.set_message(format!("Calculating MD5s... {} done.", &value.name)); @@ -836,7 +845,7 @@ impl DataCollection { // now we need to figure out whether to push the file, // which depends on the RemoteStatusCode and whether // we should overwrite (TODO) - let do_upload = match merged_file.status(path_context)? { + let do_upload = match merged_file.status(path_context).await? { RemoteStatusCode::NoLocal => { // A file exists on the remote, but not locally: there // is nothing to push in this case (or count!) 
@@ -931,7 +940,7 @@ impl DataCollection { pub async fn pull(&mut self, path_context: &Path, overwrite: bool) -> Result<()> { let all_files = self.merge(true).await?; - let mut downloads = Vec::new(); + let mut downloads = Downloads::new(); let mut current_skipped = Vec::new(); let mut messy_skipped = Vec::new(); @@ -944,7 +953,7 @@ impl DataCollection { let path = merged_file.name()?; - let do_download = match merged_file.status(path_context)? { + let do_download = match merged_file.status(path_context).await? { RemoteStatusCode::NoLocal => { return Err(anyhow!("Internal error: execution should not have reached this point, please report.\n\ 'sdf pull' filtered by MergedFile.can_download() but found a RemoteStatusCode::NoLocal status.")); @@ -986,12 +995,14 @@ impl DataCollection { if do_download { if let Some(remote) = self.remotes.get(dir) { let download = remote.get_download_info(merged_file, path_context, overwrite)?; - downloads.push(download); + downloads.list.push(download); } } } } + // now retrieve all the files in the queue. 
+ downloads.retrieve(Some(" - {}"), Some("No files downloaded.")).await?; let num_skipped = overwrite_skipped.len() + current_skipped.len() + messy_skipped.len(); @@ -1046,7 +1057,7 @@ mod tests { let nonexistent_path = "some/nonexistent/path".to_string(); let path_context = Path::new(""); - let result = DataFile::new(nonexistent_path, None, &path_context); + let result = DataFile::new(nonexistent_path, None, &path_context).await; match result { Ok(_) => assert!(false, "Expected an error, but got Ok"), Err(err) => { @@ -1066,11 +1077,11 @@ mod tests { // Make a DataFile let path = file.path().to_string_lossy().to_string(); - let data_file = DataFile::new(path, None, &path_context).unwrap(); + let data_file = DataFile::new(path, None, &path_context).await.unwrap(); // Compare MD5s let expected_md5 = "d3feb335769173b2db573413b0f6abf4".to_string(); - let observed_md5 = data_file.get_md5(&path_context).unwrap().unwrap(); + let observed_md5 = data_file.get_md5(&path_context).await.unwrap().unwrap(); assert!(observed_md5 == expected_md5, "MD5 mismatch!"); } @@ -1085,7 +1096,7 @@ mod tests { // Make a DataFile let path = file.path().to_string_lossy().to_string(); - let data_file = DataFile::new(path, None, &path_context).unwrap(); + let data_file = DataFile::new(path, None, &path_context).await.unwrap(); // Let's also check size assert!(data_file.size == 11, "Size mismatch {:?} != {:?}!", @@ -1104,14 +1115,14 @@ mod tests { // Make a DataFile let path = file.path().to_string_lossy().to_string(); - let mut data_file = DataFile::new(path, None, &path_context).unwrap(); + let mut data_file = DataFile::new(path, None, &path_context).await.unwrap(); // Now, we change the data. 
writeln!(file, "Modified mock data.").unwrap(); // Make sure the file MD5 is right let expected_md5 = "c6526ab1de615b49e53398ae5588bd00".to_string(); - let observed_md5 = data_file.get_md5(&path_context).unwrap().unwrap(); + let observed_md5 = data_file.get_md5(&path_context).await.unwrap().unwrap(); assert!(observed_md5 == expected_md5); // Make sure the old MD5 is in the DataFile @@ -1119,7 +1130,7 @@ mod tests { assert!(data_file.md5 == old_md5, "DataFile.md5 mismatch!"); // Now update - data_file.update_md5(path_context).unwrap(); + data_file.update_md5(path_context).await.unwrap(); assert!(data_file.md5 == expected_md5, "DataFile.update_md5() failed!"); } @@ -1133,7 +1144,7 @@ mod tests { // Make a DataFile let path = file.path().to_string_lossy().to_string(); - let mut data_file = DataFile::new(path, None, &path_context).unwrap(); + let mut data_file = DataFile::new(path, None, &path_context).await.unwrap(); // Now, we change the data. writeln!(file, "Modified mock data.").unwrap(); diff --git a/src/lib/download.rs b/src/lib/download.rs index 200ac9a..c70f883 100644 --- a/src/lib/download.rs +++ b/src/lib/download.rs @@ -9,7 +9,7 @@ use crate::lib::progress::{DEFAULT_PROGRESS_STYLE, DEFAULT_PROGRESS_INC}; use crate::lib::utils::pluralize; pub struct Downloads { - list: Vec, + pub list: Vec, } @@ -36,7 +36,8 @@ impl Downloads { Downloads { list } } - pub fn add(&mut self, item: T, filename: Option<&str>) -> Result<&Download> { + pub fn add(&mut self, item: T, filename: Option<&str>, + overwrite: bool) -> Result> { let url = item.to_url()?; let resolved_filename = match filename { @@ -50,10 +51,14 @@ impl Downloads { } }; + let file_path = PathBuf::from(&resolved_filename); + if file_path.exists() && !overwrite { + return Ok(None); + } let download = Download { url, filename: resolved_filename }; self.list.push(download); - Ok(self.list.last().ok_or(anyhow::anyhow!("Failed to add download"))?) 
+ Ok(Some(self.list.last().ok_or(anyhow::anyhow!("Failed to add download"))?)) } pub fn default_style(&self) -> Result { @@ -67,7 +72,7 @@ impl Downloads { } - pub async fn download_all(&self, success_status: Option<&str>, + pub async fn retrieve(&self, success_status: Option<&str>, no_downloads_message: Option<&str>) -> Result<()> { let downloads = &self.list; let total_files = downloads.len(); diff --git a/src/lib/progress.rs b/src/lib/progress.rs index 4dd3462..46e7b82 100644 --- a/src/lib/progress.rs +++ b/src/lib/progress.rs @@ -1,12 +1,13 @@ use indicatif::{ProgressBar, ProgressStyle}; use std::time::Duration; use std::thread; -use anyhow::{anyhow,Result}; +use std::sync::mpsc::{self, Sender, Receiver}; +use anyhow::Result; // these are separated since some APIs don't overload // indicatif bars, but take the same primitives. pub const DEFAULT_PROGRESS_STYLE: &str = "{spinner:.green} [{bar:40.green/white}] {pos:>}/{len} ({percent}%) eta {eta_precise:.green} {msg}"; -pub const DEFAULT_PROGRESS_INC: &str = "=>"; +pub const DEFAULT_PROGRESS_INC: &str = "=> "; pub fn default_progress_style() -> Result { let style = ProgressStyle::default_bar() @@ -17,21 +18,38 @@ pub fn default_progress_style() -> Result { pub struct Progress { pub bar: ProgressBar, - spinner: thread::JoinHandle<()> + stop_spinner: Sender<()>, + #[allow(dead_code)] + spinner: Option> } impl Progress { pub fn new(len: u64) -> Result { - let bar = ProgressBar::new(len as u64); + let bar = ProgressBar::new(len); bar.set_style(default_progress_style()?); + let (tx, rx): (Sender<()>, Receiver<()>) = mpsc::channel(); + let bar_clone = bar.clone(); let spinner = thread::spawn(move || { loop { + if rx.try_recv().is_ok() { + break; + } bar_clone.tick(); thread::sleep(Duration::from_millis(20)); } }); - Ok(Progress { bar, spinner }) - } + Ok(Progress { bar, stop_spinner: tx, spinner: Some(spinner) }) + } +} + +impl Drop for Progress { + fn drop(&mut self) { + self.stop_spinner.send(()).unwrap(); + if let 
Some(spinner) = self.spinner.take() { + spinner.join().expect("Failed to join spinner thread"); + } + } } + diff --git a/src/lib/project.rs b/src/lib/project.rs index 509f34c..75ef23e 100644 --- a/src/lib/project.rs +++ b/src/lib/project.rs @@ -7,9 +7,8 @@ use std::path::{Path,PathBuf}; use std::io::{Read, Write}; #[allow(unused_imports)] use log::{info, trace, debug}; +use csv::{ReaderBuilder, StringRecord}; use dirs; -use trauma::download::Download; -use reqwest::Url; use crate::lib::download::Downloads; #[allow(unused_imports)] @@ -278,9 +277,10 @@ impl Project { Ok(()) } - pub fn is_clean(&self) -> Result { + // TODO + pub async fn is_clean(&self) -> Result { for data_file in self.data.files.values() { - let status = data_file.status(&self.path_context())?; + let status = data_file.status(&self.path_context()).await?; if status != LocalStatusCode::Current { return Ok(false); } @@ -308,11 +308,11 @@ impl Project { } */ - pub fn add(&mut self, files: &Vec) -> Result<()> { + pub async fn add(&mut self, files: &Vec) -> Result<()> { let mut num_added = 0; for filepath in files { let filename = self.relative_path_string(Path::new(&filepath.clone()))?; - let data_file = DataFile::new(filename.clone(), None, &self.path_context())?; + let data_file = DataFile::new(filename.clone(), None, &self.path_context()).await?; info!("Adding file '{}'.", filename); self.data.register(data_file)?; num_added += 1; @@ -321,9 +321,9 @@ impl Project { self.save() } - pub fn update(&mut self, filepath: Option<&String>) -> Result<()> { + pub async fn update(&mut self, filepath: Option<&String>) -> Result<()> { let path_context = self.path_context(); - self.data.update(filepath, &path_context)?; + self.data.update(filepath, &path_context).await?; self.save() } @@ -385,19 +385,90 @@ impl Project { Ok(()) } - pub async fn get(&mut self, url: &str, filename: Option<&str>, path: Option<&str>) -> Result<()> { + pub async fn get(&mut self, url: &str, filename: Option<&str>, + overwrite: bool) 
-> Result<()> { let mut downloads = Downloads::new(); - let download = downloads.add(url.to_string(), filename)?; - - let file_path = &download.filename; - let data_file = DataFile::new(file_path.clone(), Some(url), &self.path_context())?; - info!("Adding file '{}'.", &file_path); - self.data.register(data_file)?; - Ok(()) + let download = downloads.add(url.to_string(), filename, overwrite)?; + if let Some(dl) = download { + let filepath = dl.filename.clone(); + + // get the file + downloads.retrieve(None, None).await?; + + // convert to relative path (based on where we are) + let filepath = self.relative_path_string(Path::new(&filepath))?; + + let data_file = DataFile::new(filepath.clone(), Some(url), &self.path_context()).await?; + + // Note: we do not use Project::add() since this works off strings, + // and we need to pass the URL, etc. + self.data.register(data_file)?; + self.save()?; + Ok(()) + } else { + Err(anyhow!("The file at '{}' was not downloaded because it would overwrite a file.\n\ + Use 'sdf get --overwrite' to overwrite it.", url)) + } } - pub async fn get_from_file(&mut self, filename: &str, column: u64) -> Result<()> { - // TODO + pub async fn bulk(&mut self, filename: &str, column: Option, + header: bool, overwrite: bool) -> Result<()> { + let extension = std::path::Path::new(filename) + .extension() + .and_then(std::ffi::OsStr::to_str); + + let delimiter = match extension { + Some("csv") => b',', + Some("tsv") => b'\t', + _ => return Err(anyhow!("Unsupported file type: {:?}", extension)), + }; + + let file = File::open(filename)?; + let mut reader = ReaderBuilder::new() + .delimiter(delimiter) + .has_headers(header) + .from_reader(file); + + let column = column.unwrap_or(0) as usize; + + let mut downloads = Downloads::new(); + let mut filepaths = Vec::new(); + let mut urls = Vec::new(); + let mut skipped = Vec::new(); + let mut num_lines = 0; + for result in reader.records() { + let record: StringRecord = result?; + if let Some(url) = 
record.get(column) { + num_lines += 1; + let url = url.to_string(); + let download = downloads.add(url.clone(), None, overwrite)?; + if let Some(dl) = download { + let filepath = dl.filename.clone(); + filepaths.push(filepath); + urls.push(url.clone()); + } else { + skipped.push(url.clone()); + } + } + } + + // grab all the files + downloads.retrieve(None, None).await?; + + let mut num_added = 0; + for (filepath, url) in filepaths.iter().zip(urls.iter()) { + let rel_file_path = self.relative_path_string(Path::new(&filepath))?; + let data_file = DataFile::new(rel_file_path.clone(), Some(url), &self.path_context()).await?; + self.data.register(data_file)?; + num_added += 1; + } + let num_skipped = skipped.len(); + println!("{} URLs found in '{}'.\n\ + {} files were downloaded and added.\n\ + {} files were skipped because they existed (and --overwrite was not specified).", + num_lines, filename, + num_added, num_skipped); + self.save()?; Ok(()) } diff --git a/src/lib/remote.rs b/src/lib/remote.rs index a435999..763c336 100644 --- a/src/lib/remote.rs +++ b/src/lib/remote.rs @@ -17,8 +17,6 @@ use crate::lib::api::figshare::FigShareAPI; use crate::lib::api::dryad::DataDryadAPI; use crate::lib::api::zenodo::ZenodoAPI; use crate::lib::project::LocalMetadata; -use crate::lib::download::{Downloads}; - const AUTHKEYS: &str = ".scidataflow_authkeys.yml"; diff --git a/src/lib/utils.rs b/src/lib/utils.rs index b81b5ce..7de0813 100644 --- a/src/lib/utils.rs +++ b/src/lib/utils.rs @@ -41,7 +41,7 @@ pub fn ensure_exists(path: &Path) -> Result<()> { } /// Compute the MD5 of a file returning None if the file is empty. 
-pub fn compute_md5(file_path: &Path) -> Result> { +pub async fn compute_md5(file_path: &Path) -> Result> { const BUFFER_SIZE: usize = 1024; let mut file = match File::open(file_path) { diff --git a/src/main.rs b/src/main.rs index b6f1c60..db3f7f0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,6 +3,8 @@ use anyhow::Result; use structopt::StructOpt; #[allow(unused_imports)] use log::{info, trace, debug}; +use tokio::runtime::Builder; + use scidataflow::lib::project::Project; use scidataflow::logging_setup::setup; @@ -74,7 +76,32 @@ enum Commands { #[structopt(long)] name: Option }, - + #[structopt(name = "get")] + /// Download a file from a URL. + Get { + /// The URL of the file to download. + url: String, + #[structopt(long)] + name: Option, + /// Overwrite local files if they exist. + #[structopt(long)] + overwrite: bool + }, + #[structopt(name = "bulk")] + /// Download a bunch of files from links stored in a file. + Bulk { + /// A TSV or CSV file containing a column of URLs. Type inferred from suffix. + filename: String, + /// Which column contains links (default: first). + #[structopt(long)] + column: Option, + /// The TSV or CSV starts with a header (i.e. skip first line). + #[structopt(long)] + header: bool, + /// Overwrite local files if they exist. + #[structopt(long)] + overwrite: bool, + }, #[structopt(name = "status")] /// Show status of data. Status { @@ -139,7 +166,7 @@ enum Commands { #[structopt(name = "push")] /// Push all tracked files to remote. Push { - // Overwrite remote files if they exit. + /// Overwrite remote files if they exist. #[structopt(long)] overwrite: bool, }, @@ -147,7 +174,7 @@ enum Commands { #[structopt(name = "pull")] /// Pull in all tracked files from the remote. Pull { - // Overwrite local files if they exit. + /// Overwrite local files if they exist. 
#[structopt(long)] overwrite: bool, @@ -175,16 +202,27 @@ pub fn print_errors(response: Result<()>) { } } -#[tokio::main] -async fn main() { +fn main() { setup(); - match run().await { - Ok(_) => {} - Err(e) => { - eprintln!("Error: {:?}", e); - std::process::exit(1); + + let ncores = 4; + + let runtime = Builder::new_multi_thread() + .worker_threads(ncores) + .enable_all() + .build() + .unwrap(); + + + runtime.block_on(async { + match run().await { + Ok(_) => {} + Err(e) => { + eprintln!("Error: {:?}", e); + std::process::exit(1); + } } - } + }); } async fn run() -> Result<()> { @@ -192,11 +230,19 @@ async fn run() -> Result<()> { match &cli.command { Some(Commands::Add { filenames }) => { let mut proj = Project::new()?; - proj.add(filenames) + proj.add(filenames).await } Some(Commands::Config { name, email, affiliation }) => { Project::set_config(name, email, affiliation) } + Some(Commands::Get { url, name, overwrite }) => { + let mut proj = Project::new()?; + proj.get(url, name.as_deref(), *overwrite).await + } + Some(Commands::Bulk { filename, column, header, overwrite }) => { + let mut proj = Project::new()?; + proj.bulk(filename, *column, *header, *overwrite).await + } Some(Commands::Init { name }) => { Project::init(name.clone()) } @@ -211,7 +257,7 @@ async fn run() -> Result<()> { } Some(Commands::Update { filename }) => { let mut proj = Project::new()?; - proj.update(filename.as_ref()) + proj.update(filename.as_ref()).await } Some(Commands::Link { dir, service, key, name, link_only }) => { let mut proj = Project::new()?; diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 4fafb04..bfedc37 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -175,7 +175,7 @@ impl Drop for TestEnvironment { } } -pub fn setup(do_add: bool) -> TestFixture { +pub async fn setup(do_add: bool) -> TestFixture { lazy_static! 
{ static ref INIT_LOGGING: Once = Once::new(); } @@ -209,7 +209,7 @@ pub fn setup(do_add: bool) -> TestFixture { .collect(); // add those files - let _ = project.add(&add_files); + let _ = project.add(&add_files).await; } TestFixture { env: test_env, project } diff --git a/tests/test_project.rs b/tests/test_project.rs index d1ef9d9..0087d83 100644 --- a/tests/test_project.rs +++ b/tests/test_project.rs @@ -16,9 +16,9 @@ mod tests { use std::path::PathBuf; use scidataflow::lib::data::LocalStatusCode; - #[test] - fn test_fixture() { - let fixture = setup(false); + #[tokio::test] + async fn test_fixture() { + let fixture = setup(false).await; // test that the fixtures were created //info!("files: {:?}", test_env.files); if let Some(fixtures) = &fixture.env.files { @@ -28,9 +28,9 @@ mod tests { } } - #[test] - fn test_init() { - let fixture = setup(false); + #[tokio::test] + async fn test_init() { + let fixture = setup(false).await; // test that init() creates the data manifest let data_manifest = fixture.env.get_file_path("data_manifest.yml"); info!("Checking for file at path: {}", data_manifest.display()); // Add this log @@ -39,7 +39,7 @@ mod tests { #[tokio::test] async fn test_add_status_current() { - let mut fixture = setup(false); + let mut fixture = setup(false).await; let path_context = fixture.project.path_context(); let statuses = get_statuses(&mut fixture, &path_context).await; @@ -77,7 +77,7 @@ mod tests { #[tokio::test] async fn test_add_update_status_modified() { - let mut fixture = setup(true); + let mut fixture = setup(true).await; let path_context = fixture.project.path_context(); let statuses = get_statuses_map(&mut fixture, &path_context).await; @@ -100,7 +100,7 @@ mod tests { let re_add_files = vec![file_to_check.to_string_lossy().to_string()]; for file in &re_add_files { - let result = fixture.project.update(Some(&file)); + let result = fixture.project.update(Some(&file)).await; assert!(result.is_ok(), "re-adding raised Error!"); } @@ -113,13 
+113,14 @@ mod tests { #[tokio::test] async fn test_add_already_added_error() { - let mut fixture = setup(true); + let mut fixture = setup(true).await; + println!("DATAAAA: {:?}", fixture.project.data); if let Some(files) = &fixture.env.files { for file in files { let mut file_list = Vec::new(); file_list.push(file.path.clone()); - let result = fixture.project.add(&file_list); + let result = fixture.project.add(&file_list).await; // check that we get match result {