From 3af99cbfeea5b0c1cdb53cd222caff68cfd7caa5 Mon Sep 17 00:00:00 2001 From: Abhay Zala Date: Fri, 25 Aug 2023 19:38:59 -0400 Subject: [PATCH] allow for stage skipping during end-to-end --- hirest_dataset.py | 10 ++++++---- run.py | 7 +++++-- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/hirest_dataset.py b/hirest_dataset.py index 9a0b193..bf87d76 100644 --- a/hirest_dataset.py +++ b/hirest_dataset.py @@ -184,8 +184,9 @@ def __init__(self, args, data_path, video_dir=None, video_feature_dir=None, asr_ data.append(task_datum) elif task == 'moment_segmentation': - if len(video_ann['steps']) == 0: - continue + if not args.end_to_end: + if len(video_ann['steps']) == 0: + continue if 'train' in str(data_path): @@ -267,8 +268,9 @@ def __init__(self, args, data_path, video_dir=None, video_feature_dir=None, asr_ data.append(task_datum) elif task == 'step_captioning': - if len(video_ann['steps']) == 0: - continue + if not args.end_to_end: + if len(video_ann['steps']) == 0: + continue target_text = [] original_bounds = [] diff --git a/run.py b/run.py index c30e7f4..834361b 100644 --- a/run.py +++ b/run.py @@ -383,6 +383,8 @@ def train(self): if args.end_to_end: import shutil + shutil.copyfile(f"{args.data_dir}/all_data_test.json", f"{args.data_dir}/all_data_test_original.json") + if 'moment_retrieval' in self.tasks: moments = self.evaluate(self.test_moment_retrieval_loader, has_target=False) @@ -413,7 +415,6 @@ def train(self): {"index": i, "heading": "", "absolute_bounds": [i, i+1]} ) - shutil.copyfile(f"{args.data_dir}/all_data_test.json", f"{args.data_dir}/all_data_test_original.json") with open(f"{args.data_dir}/all_data_test.json", 'w') as f: json.dump(test, f, indent=2) @@ -461,6 +462,7 @@ def train(self): batch_size=args.eval_batch_size, task='step_captioning', ) + if 'step_captioning' in self.tasks: moments = self.evaluate(self.test_step_captioning_loader, has_target=False) @@ -484,7 +486,8 @@ def train(self): shutil.move(f"{args.data_dir}/all_data_test.json", f"{args.data_dir}/temp3.json") - shutil.move(f"{args.data_dir}/all_data_test_original.json", f"{args.data_dir}/all_data_test.json") + + shutil.move(f"{args.data_dir}/all_data_test_original.json", f"{args.data_dir}/all_data_test.json") else: if 'moment_retrieval' in self.tasks: