Skip to content

Commit

Permalink
Merge pull request #710 from QData/format
Browse files Browse the repository at this point in the history
format after #695
  • Loading branch information
jxmorris12 committed Dec 21, 2022
2 parents 8a7c88d + 7c152d9 commit edbcb83
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 7 deletions.
4 changes: 3 additions & 1 deletion textattack/metrics/quality_metrics/bert_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@

class BERTScoreMetric(Metric):
def __init__(self, **kwargs):
self.use_obj = BERTScore(min_bert_score=0.5, model_name="microsoft/deberta-large-mnli", num_layers=18)
self.use_obj = BERTScore(
min_bert_score=0.5, model_name="microsoft/deberta-large-mnli", num_layers=18
)
self.original_candidates = []
self.successful_candidates = []
self.all_metrics = {}
Expand Down
16 changes: 12 additions & 4 deletions textattack/metrics/quality_metrics/meteor_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
"""

from textattack.attack_results import FailedAttackResult, SkippedAttackResult
import nltk

from textattack.attack_results import FailedAttackResult, SkippedAttackResult
from textattack.metrics import Metric


Expand Down Expand Up @@ -54,13 +55,20 @@ def calculate(self, results):
elif isinstance(result, SkippedAttackResult):
continue
else:
self.original_candidates.append(result.original_result.attacked_text.text)
self.successful_candidates.append(result.perturbed_result.attacked_text.text)
self.original_candidates.append(
result.original_result.attacked_text.text
)
self.successful_candidates.append(
result.perturbed_result.attacked_text.text
)

meteor_scores = []
for c in range(len(self.original_candidates)):
meteor_scores.append(
nltk.translate.meteor([nltk.word_tokenize(self.original_candidates[c])], nltk.word_tokenize(self.successful_candidates[c]))
nltk.translate.meteor(
[nltk.word_tokenize(self.original_candidates[c])],
nltk.word_tokenize(self.successful_candidates[c]),
)
)

self.all_metrics["avg_attack_meteor_score"] = round(
Expand Down
5 changes: 3 additions & 2 deletions textattack/metrics/quality_metrics/sentence_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@

class SBERTMetric(Metric):
def __init__(self, **kwargs):
self.use_obj = BERT(model_name="all-MiniLM-L6-v2", metric="cosine")
self.use_obj = BERT(model_name="all-MiniLM-L6-v2", metric="cosine")
self.original_candidates = []
self.successful_candidates = []
self.all_metrics = {}

def calculate(self, results):
"""Calculates average Sentence BERT similarity on all successfull attacks.
"""Calculates average Sentence BERT similarity on all successfull
attacks.
Args:
results (``AttackResult`` objects):
Expand Down

0 comments on commit edbcb83

Please sign in to comment.