Skip to content

Commit

Permalink
Merge pull request #310 from ArshdeepSekhon/master
Browse files Browse the repository at this point in the history
changes to eval cli : option to eval on entire dataset
  • Loading branch information
qiyanjun authored Nov 28, 2020
2 parents 43e7577 + 7635c7b commit efe3ac7
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 7 deletions.
6 changes: 4 additions & 2 deletions textattack/commands/attack/attack_args_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def add_dataset_args(parser):
type=int,
required=False,
default="5",
help="The number of examples to process.",
help="The number of examples to process, -1 for entire dataset",
)

parser.add_argument(
Expand Down Expand Up @@ -419,6 +419,8 @@ def parse_dataset_from_args(args):
dataset.examples = dataset.examples[args.num_examples_offset :]
else:
raise ValueError("Must supply pretrained model or dataset")
if args.num_examples == -1 or args.num_examples > len(dataset):
args.num_examples = len(dataset)
return dataset


Expand Down Expand Up @@ -480,7 +482,7 @@ def parse_logger_from_args(args):
if args.log_to_txt == "" or args.log_to_txt:
attack_log_manager.add_output_file(os.path.join(out_dir_txt, filename_txt))

# if "--log-to-csv" specified in terminal command(with or without arg), save to a csv file
# if "--log-to-csv" specified in terminal command(with or without arg), save to a csv file
if args.log_to_csv == "" or args.log_to_csv:
# "--csv-style used to swtich from 'fancy' to 'plain'
color_method = None if args.csv_style == "plain" else "file"
Expand Down
3 changes: 1 addition & 2 deletions textattack/commands/attack/run_attack_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def attack_from_queue(args, in_queue, out_queue):
def run(args, checkpoint=None):
pytorch_multiprocessing_workaround()

dataset = parse_dataset_from_args(args)
num_total_examples = args.num_examples

if args.checkpoint_resume:
Expand Down Expand Up @@ -109,8 +110,6 @@ def run(args, checkpoint=None):
# We reserve the first GPU for coordinating workers.
num_gpus = torch.cuda.device_count()

dataset = parse_dataset_from_args(args)

textattack.shared.logger.info(f"Running on {num_gpus} GPUs")
start_time = time.time()

Expand Down
3 changes: 2 additions & 1 deletion textattack/commands/attack/run_attack_single_threaded.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def run(args, checkpoint=None):
)
print(checkpoint, "\n")
else:
if not args.interactive:
dataset = parse_dataset_from_args(args)
num_remaining_attacks = args.num_examples
worklist = deque(range(0, args.num_examples))
worklist_tail = worklist[-1]
Expand Down Expand Up @@ -103,7 +105,6 @@ def run(args, checkpoint=None):

else:
# Not interactive? Use default dataset.
dataset = parse_dataset_from_args(args)

pbar = tqdm.tqdm(total=num_remaining_attacks, smoothing=0)
if args.checkpoint_resume:
Expand Down
5 changes: 3 additions & 2 deletions textattack/commands/eval_model/eval_model_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def test_model_on_dataset(self, args):
preds = []
ground_truth_outputs = []
i = 0
while i < min(args.num_examples, len(dataset)):

while i < args.num_examples:
dataset_batch = dataset[
i : min(args.num_examples, i + args.model_batch_size)
]
Expand Down Expand Up @@ -110,7 +111,6 @@ def register_subcommand(main_parser: ArgumentParser):

add_model_args(parser)
add_dataset_args(parser)

parser.add_argument("--random-seed", default=765, type=int)

parser.add_argument(
Expand All @@ -119,4 +119,5 @@ def register_subcommand(main_parser: ArgumentParser):
default=256,
help="Batch size for model inference.",
)

parser.set_defaults(func=EvalModelCommand())

0 comments on commit efe3ac7

Please sign in to comment.