diff --git a/textattack/constraints/semantics/bert_score.py b/textattack/constraints/semantics/bert_score.py
index 9f0c65e0c..f9ff51c22 100644
--- a/textattack/constraints/semantics/bert_score.py
+++ b/textattack/constraints/semantics/bert_score.py
@@ -59,13 +59,26 @@ def __init__(
             model_type=model_name, idf=False, device=utils.device, num_layers=num_layers
         )
 
+    def _sim_score(self, starting_text, transformed_text):
+        """Returns the metric similarity between the embedding of the starting
+        text and the transformed text.
+
+        Args:
+            starting_text: The ``AttackedText`` to use as a starting point.
+            transformed_text: A transformed ``AttackedText``.
+
+        Returns:
+            The similarity between the starting and transformed text using the BERTScore metric.
+        """
+        cand = transformed_text.text
+        ref = starting_text.text
+        result = self._bert_scorer.score([cand], [ref])
+        return result[BERTScore.SCORE_TYPE2IDX[self.score_type]].item()
+
     def _check_constraint(self, transformed_text, reference_text):
         """Return `True` if BERT Score between `transformed_text` and
         `reference_text` is lower than minimum BERT Score."""
-        cand = transformed_text.text
-        ref = reference_text.text
-        result = self._bert_scorer.score([cand], [ref])
-        score = result[BERTScore.SCORE_TYPE2IDX[self.score_type]].item()
+        score = self._sim_score(reference_text, transformed_text)
         if score >= self.min_bert_score:
             return True
         else:
diff --git a/textattack/metrics/__init__.py b/textattack/metrics/__init__.py
index e1df932b0..e4ab29546 100644
--- a/textattack/metrics/__init__.py
+++ b/textattack/metrics/__init__.py
@@ -12,3 +12,6 @@
 
 from .quality_metrics import Perplexity
 from .quality_metrics import USEMetric
+from .quality_metrics import SBERTMetric
+from .quality_metrics import BERTScoreMetric
+from .quality_metrics import MeteorMetric
diff --git a/textattack/metrics/quality_metrics/__init__.py b/textattack/metrics/quality_metrics/__init__.py
index 6ba13465e..6eaa41c73 100644
--- a/textattack/metrics/quality_metrics/__init__.py
+++ b/textattack/metrics/quality_metrics/__init__.py
@@ -10,3 +10,6 @@
 
 from .perplexity import Perplexity
 from .use import USEMetric
+from .sentence_bert import SBERTMetric
+from .bert_score import BERTScoreMetric
+from .meteor_score import MeteorMetric
diff --git a/textattack/metrics/quality_metrics/bert_score.py b/textattack/metrics/quality_metrics/bert_score.py
new file mode 100644
index 000000000..e4f9e7947
--- /dev/null
+++ b/textattack/metrics/quality_metrics/bert_score.py
@@ -0,0 +1,75 @@
+"""
+
+BERTScoreMetric class:
+-------------------------------------------------------
+Class for calculating BERTScore on AttackResults
+
+"""
+
+from textattack.attack_results import FailedAttackResult, SkippedAttackResult
+from textattack.constraints.semantics import BERTScore
+from textattack.metrics import Metric
+
+
+class BERTScoreMetric(Metric):
+    def __init__(self, **kwargs):
+        self.bert_score_obj = BERTScore(
+            min_bert_score=0.5, model_name="microsoft/deberta-large-mnli", num_layers=18
+        )
+        self.original_candidates = []
+        self.successful_candidates = []
+        self.all_metrics = {}
+
+    def calculate(self, results):
+        """Calculates the average BERTScore over all successful attacks.
+
+        Args:
+            results (``AttackResult`` objects):
+                Attack results for each instance in dataset
+
+        Example::
+
+
+            >> import textattack
+            >> import transformers
+            >> model = transformers.AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
+            >> tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
+            >> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
+            >> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
+            >> dataset = textattack.datasets.HuggingFaceDataset("glue", "sst2", split="train")
+            >> attack_args = textattack.AttackArgs(
+                num_examples=1,
+                log_to_csv="log.csv",
+                checkpoint_interval=5,
+                checkpoint_dir="checkpoints",
+                disable_stdout=True
+            )
+            >> attacker = textattack.Attacker(attack, dataset, attack_args)
+            >> results = attacker.attack_dataset()
+            >> bertscorem = textattack.metrics.quality_metrics.BERTScoreMetric().calculate(results)
+        """
+
+        self.results = results
+
+        for result in self.results:
+            if isinstance(result, FailedAttackResult):
+                continue
+            elif isinstance(result, SkippedAttackResult):
+                continue
+            else:
+                self.original_candidates.append(result.original_result.attacked_text)
+                self.successful_candidates.append(result.perturbed_result.attacked_text)
+
+        bert_scores = []
+        for c in range(len(self.original_candidates)):
+            bert_scores.append(
+                self.bert_score_obj._sim_score(
+                    self.original_candidates[c], self.successful_candidates[c]
+                )
+            )
+
+        self.all_metrics["avg_attack_bert_score"] = round(
+            sum(bert_scores) / len(bert_scores), 2
+        )
+
+        return self.all_metrics
diff --git a/textattack/metrics/quality_metrics/meteor_score.py b/textattack/metrics/quality_metrics/meteor_score.py
new file mode 100644
index 000000000..ffb92f0c8
--- /dev/null
+++ b/textattack/metrics/quality_metrics/meteor_score.py
@@ -0,0 +1,78 @@
+"""
+
+MeteorMetric class:
+-------------------------------------------------------
+Class for calculating METEOR score on AttackResults
+
+"""
+
+import nltk
+
+from textattack.attack_results import FailedAttackResult, SkippedAttackResult
+from textattack.metrics import Metric
+
+
+class MeteorMetric(Metric):
+    def __init__(self, **kwargs):
+        self.original_candidates = []
+        self.successful_candidates = []
+        self.all_metrics = {}
+
+    def calculate(self, results):
+        """Calculates the average METEOR score over all successful attacks.
+
+        Args:
+            results (``AttackResult`` objects):
+                Attack results for each instance in dataset
+
+        Example::
+
+
+            >> import textattack
+            >> import transformers
+            >> model = transformers.AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
+            >> tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
+            >> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
+            >> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
+            >> dataset = textattack.datasets.HuggingFaceDataset("glue", "sst2", split="train")
+            >> attack_args = textattack.AttackArgs(
+                num_examples=1,
+                log_to_csv="log.csv",
+                checkpoint_interval=5,
+                checkpoint_dir="checkpoints",
+                disable_stdout=True
+            )
+            >> attacker = textattack.Attacker(attack, dataset, attack_args)
+            >> results = attacker.attack_dataset()
+            >> meteorm = textattack.metrics.quality_metrics.MeteorMetric().calculate(results)
+        """
+
+        self.results = results
+
+        for result in self.results:
+            if isinstance(result, FailedAttackResult):
+                continue
+            elif isinstance(result, SkippedAttackResult):
+                continue
+            else:
+                self.original_candidates.append(
+                    result.original_result.attacked_text.text
+                )
+                self.successful_candidates.append(
+                    result.perturbed_result.attacked_text.text
+                )
+
+        meteor_scores = []
+        for c in range(len(self.original_candidates)):
+            meteor_scores.append(
+                nltk.translate.meteor(
+                    [nltk.word_tokenize(self.original_candidates[c])],
+                    nltk.word_tokenize(self.successful_candidates[c]),
+                )
+            )
+
+        self.all_metrics["avg_attack_meteor_score"] = round(
+            sum(meteor_scores) / len(meteor_scores), 2
+        )
+
+        return self.all_metrics
diff --git a/textattack/metrics/quality_metrics/sentence_bert.py b/textattack/metrics/quality_metrics/sentence_bert.py
new file mode 100644
index 000000000..f96660af6
--- /dev/null
+++ b/textattack/metrics/quality_metrics/sentence_bert.py
@@ -0,0 +1,74 @@
+"""
+
+SBERTMetric class:
+-------------------------------------------------------
+Class for calculating SentenceBERT similarity on AttackResults
+
+"""
+
+from textattack.attack_results import FailedAttackResult, SkippedAttackResult
+from textattack.constraints.semantics.sentence_encoders import BERT
+from textattack.metrics import Metric
+
+
+class SBERTMetric(Metric):
+    def __init__(self, **kwargs):
+        self.sbert_obj = BERT(model_name="all-MiniLM-L6-v2", metric="cosine")
+        self.original_candidates = []
+        self.successful_candidates = []
+        self.all_metrics = {}
+
+    def calculate(self, results):
+        """Calculates the average Sentence-BERT similarity over all
+        successful attacks.
+
+        Args:
+            results (``AttackResult`` objects):
+                Attack results for each instance in dataset
+
+        Example::
+
+
+            >> import textattack
+            >> import transformers
+            >> model = transformers.AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
+            >> tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
+            >> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
+            >> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
+            >> dataset = textattack.datasets.HuggingFaceDataset("glue", "sst2", split="train")
+            >> attack_args = textattack.AttackArgs(
+                num_examples=1,
+                log_to_csv="log.csv",
+                checkpoint_interval=5,
+                checkpoint_dir="checkpoints",
+                disable_stdout=True
+            )
+            >> attacker = textattack.Attacker(attack, dataset, attack_args)
+            >> results = attacker.attack_dataset()
+            >> sbertm = textattack.metrics.quality_metrics.SBERTMetric().calculate(results)
+        """
+
+        self.results = results
+
+        for result in self.results:
+            if isinstance(result, FailedAttackResult):
+                continue
+            elif isinstance(result, SkippedAttackResult):
+                continue
+            else:
+                self.original_candidates.append(result.original_result.attacked_text)
+                self.successful_candidates.append(result.perturbed_result.attacked_text)
+
+        sbert_scores = []
+        for c in range(len(self.original_candidates)):
+            sbert_scores.append(
+                self.sbert_obj._sim_score(
+                    self.original_candidates[c], self.successful_candidates[c]
+                ).item()
+            )
+
+        self.all_metrics["avg_attack_sentence_bert_similarity"] = round(
+            sum(sbert_scores) / len(sbert_scores), 2
+        )
+
+        return self.all_metrics