Skip to content

Commit

Permalink
allow for stage skipping during end-to-end
Browse files Browse the repository at this point in the history
  • Loading branch information
aszala committed Aug 25, 2023
1 parent 0fc9f84 commit 3af99cb
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
10 changes: 6 additions & 4 deletions hirest_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,9 @@ def __init__(self, args, data_path, video_dir=None, video_feature_dir=None, asr_
data.append(task_datum)

elif task == 'moment_segmentation':
if len(video_ann['steps']) == 0:
continue
if not args.end_to_end:
if len(video_ann['steps']) == 0:
continue

if 'train' in str(data_path):

Expand Down Expand Up @@ -267,8 +268,9 @@ def __init__(self, args, data_path, video_dir=None, video_feature_dir=None, asr_
data.append(task_datum)

elif task == 'step_captioning':
if len(video_ann['steps']) == 0:
continue
if not args.end_to_end:
if len(video_ann['steps']) == 0:
continue

target_text = []
original_bounds = []
Expand Down
7 changes: 5 additions & 2 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,8 @@ def train(self):
if args.end_to_end:
import shutil

shutil.copyfile(f"{args.data_dir}/all_data_test.json", f"{args.data_dir}/all_data_test_original.json")

if 'moment_retrieval' in self.tasks:
moments = self.evaluate(self.test_moment_retrieval_loader, has_target=False)

Expand Down Expand Up @@ -413,7 +415,6 @@ def train(self):
{"index": i, "heading": "", "absolute_bounds": [i, i+1]}
)

shutil.copyfile(f"{args.data_dir}/all_data_test.json", f"{args.data_dir}/all_data_test_original.json")
with open(f"{args.data_dir}/all_data_test.json", 'w') as f:
json.dump(test, f, indent=2)

Expand Down Expand Up @@ -461,6 +462,7 @@ def train(self):
batch_size=args.eval_batch_size,
task='step_captioning',
)

if 'step_captioning' in self.tasks:
moments = self.evaluate(self.test_step_captioning_loader, has_target=False)

Expand All @@ -484,7 +486,8 @@ def train(self):


shutil.move(f"{args.data_dir}/all_data_test.json", f"{args.data_dir}/temp3.json")
shutil.move(f"{args.data_dir}/all_data_test_original.json", f"{args.data_dir}/all_data_test.json")

shutil.move(f"{args.data_dir}/all_data_test_original.json", f"{args.data_dir}/all_data_test.json")

else:
if 'moment_retrieval' in self.tasks:
Expand Down

0 comments on commit 3af99cb

Please sign in to comment.