From 7c152d92b0d7c65ce1ba5e758c200754aa64c22f Mon Sep 17 00:00:00 2001 From: Jack Morris Date: Wed, 21 Dec 2022 11:51:59 -0500 Subject: [PATCH] format after #695 --- textattack/metrics/quality_metrics/bert_score.py | 4 +++- .../metrics/quality_metrics/meteor_score.py | 16 ++++++++++++---- .../metrics/quality_metrics/sentence_bert.py | 5 +++-- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/textattack/metrics/quality_metrics/bert_score.py b/textattack/metrics/quality_metrics/bert_score.py index d8dd5b740..e4f9e7947 100644 --- a/textattack/metrics/quality_metrics/bert_score.py +++ b/textattack/metrics/quality_metrics/bert_score.py @@ -13,7 +13,9 @@ class BERTScoreMetric(Metric): def __init__(self, **kwargs): - self.use_obj = BERTScore(min_bert_score=0.5, model_name="microsoft/deberta-large-mnli", num_layers=18) + self.use_obj = BERTScore( + min_bert_score=0.5, model_name="microsoft/deberta-large-mnli", num_layers=18 + ) self.original_candidates = [] self.successful_candidates = [] self.all_metrics = {} diff --git a/textattack/metrics/quality_metrics/meteor_score.py b/textattack/metrics/quality_metrics/meteor_score.py index fea0153c8..ffb92f0c8 100644 --- a/textattack/metrics/quality_metrics/meteor_score.py +++ b/textattack/metrics/quality_metrics/meteor_score.py @@ -6,8 +6,9 @@ """ -from textattack.attack_results import FailedAttackResult, SkippedAttackResult import nltk + +from textattack.attack_results import FailedAttackResult, SkippedAttackResult from textattack.metrics import Metric @@ -54,13 +55,20 @@ def calculate(self, results): elif isinstance(result, SkippedAttackResult): continue else: - self.original_candidates.append(result.original_result.attacked_text.text) - self.successful_candidates.append(result.perturbed_result.attacked_text.text) + self.original_candidates.append( + result.original_result.attacked_text.text + ) + self.successful_candidates.append( + result.perturbed_result.attacked_text.text + ) meteor_scores = [] for c in range(len(self.original_candidates)): meteor_scores.append( - nltk.translate.meteor([nltk.word_tokenize(self.original_candidates[c])], nltk.word_tokenize(self.successful_candidates[c])) + nltk.translate.meteor( + [nltk.word_tokenize(self.original_candidates[c])], + nltk.word_tokenize(self.successful_candidates[c]), + ) ) self.all_metrics["avg_attack_meteor_score"] = round( diff --git a/textattack/metrics/quality_metrics/sentence_bert.py b/textattack/metrics/quality_metrics/sentence_bert.py index 7bb157e26..f96660af6 100644 --- a/textattack/metrics/quality_metrics/sentence_bert.py +++ b/textattack/metrics/quality_metrics/sentence_bert.py @@ -13,13 +13,14 @@ class SBERTMetric(Metric): def __init__(self, **kwargs): - self.use_obj = BERT(model_name="all-MiniLM-L6-v2", metric="cosine") + self.use_obj = BERT(model_name="all-MiniLM-L6-v2", metric="cosine") self.original_candidates = [] self.successful_candidates = [] self.all_metrics = {} def calculate(self, results): - """Calculates average Sentence BERT similarity on all successfull attacks. + """Calculates average Sentence BERT similarity on all successfull + attacks. Args: results (``AttackResult`` objects):