Merge pull request #695 from gmurro/master

Extra quality metrics

jxmorris12 authored Dec 15, 2022
2 parents 5fc1274 + 44c669a commit 8a7c88d
Showing 6 changed files with 239 additions and 4 deletions.
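In short, this PR adds three quality metrics — BERTScoreMetric, MeteorMetric, and SBERTMetric — that follow the same pattern as the existing Perplexity and USEMetric classes: instantiate, then call calculate(results) on a list of AttackResult objects. A minimal sketch of how the three combine (assuming `results` was produced by an Attacker, as in the docstring examples in the files below):

    # Sketch: computing all three new quality metrics on one attack run.
    # Assumes `results` comes from textattack.Attacker(...).attack_dataset().
    from textattack.metrics.quality_metrics import (
        BERTScoreMetric,
        MeteorMetric,
        SBERTMetric,
    )

    quality = {}
    quality.update(BERTScoreMetric().calculate(results))  # avg_attack_bert_score
    quality.update(MeteorMetric().calculate(results))  # avg_attack_meteor_score
    quality.update(SBERTMetric().calculate(results))  # avg_attack_sentence_bert_similarity
    print(quality)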
21 changes: 17 additions & 4 deletions textattack/constraints/semantics/bert_score.py
@@ -59,13 +59,26 @@ def __init__(
             model_type=model_name, idf=False, device=utils.device, num_layers=num_layers
         )
 
+    def _sim_score(self, starting_text, transformed_text):
+        """Returns the BERTScore similarity between the starting text and the
+        transformed text.
+
+        Args:
+            starting_text: The ``AttackedText`` to use as a starting point.
+            transformed_text: A transformed ``AttackedText``.
+
+        Returns:
+            The similarity between the starting and transformed text as measured by BERTScore.
+        """
+        cand = transformed_text.text
+        ref = starting_text.text
+        result = self._bert_scorer.score([cand], [ref])
+        return result[BERTScore.SCORE_TYPE2IDX[self.score_type]].item()
+
     def _check_constraint(self, transformed_text, reference_text):
         """Return `True` if the BERTScore between `transformed_text` and
         `reference_text` meets the minimum BERTScore."""
-        cand = transformed_text.text
-        ref = reference_text.text
-        result = self._bert_scorer.score([cand], [ref])
-        score = result[BERTScore.SCORE_TYPE2IDX[self.score_type]].item()
+        score = self._sim_score(reference_text, transformed_text)
         if score >= self.min_bert_score:
             return True
         else:
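The hunk above factors the scoring logic out of _check_constraint into a reusable _sim_score helper, which the new BERTScoreMetric below calls directly on original/perturbed pairs. A hedged sketch of a direct call (argument order is starting text first, then transformed text; the two AttackedText values are assumed to be in scope):

    # Sketch: reusing the constraint's scorer outside constraint checking.
    from textattack.constraints.semantics import BERTScore

    constraint = BERTScore(min_bert_score=0.5)
    score = constraint._sim_score(original_text, perturbed_text)  # AttackedText pair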
3 changes: 3 additions & 0 deletions textattack/metrics/__init__.py
@@ -12,3 +12,6 @@
 
 from .quality_metrics import Perplexity
 from .quality_metrics import USEMetric
+from .quality_metrics import SBERTMetric
+from .quality_metrics import BERTScoreMetric
+from .quality_metrics import MeteorMetric
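These re-exports (together with the matching ones in the subpackage __init__.py below) make the new classes importable from either level:

    from textattack.metrics import SBERTMetric  # top-level re-export
    from textattack.metrics.quality_metrics import BERTScoreMetric  # subpackage path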
3 changes: 3 additions & 0 deletions textattack/metrics/quality_metrics/__init__.py
@@ -10,3 +10,6 @@
 
 from .perplexity import Perplexity
 from .use import USEMetric
+from .sentence_bert import SBERTMetric
+from .bert_score import BERTScoreMetric
+from .meteor_score import MeteorMetric
73 changes: 73 additions & 0 deletions textattack/metrics/quality_metrics/bert_score.py
@@ -0,0 +1,73 @@
"""
BERTScoreMetric class:
-------------------------------------------------------
Class for calculating BERTScore on AttackResults
"""

from textattack.attack_results import FailedAttackResult, SkippedAttackResult
from textattack.constraints.semantics import BERTScore
from textattack.metrics import Metric


class BERTScoreMetric(Metric):
def __init__(self, **kwargs):
self.use_obj = BERTScore(min_bert_score=0.5, model_name="microsoft/deberta-large-mnli", num_layers=18)
self.original_candidates = []
self.successful_candidates = []
self.all_metrics = {}

def calculate(self, results):
"""Calculates average BERT score on all successfull attacks.
Args:
results (``AttackResult`` objects):
Attack results for each instance in dataset
Example::
>> import textattack
>> import transformers
>> model = transformers.AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
>> tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
>> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
>> dataset = textattack.datasets.HuggingFaceDataset("glue", "sst2", split="train")
>> attack_args = textattack.AttackArgs(
num_examples=1,
log_to_csv="log.csv",
checkpoint_interval=5,
checkpoint_dir="checkpoints",
disable_stdout=True
)
>> attacker = textattack.Attacker(attack, dataset, attack_args)
>> results = attacker.attack_dataset()
>> bertscorem = textattack.metrics.quality_metrics.BERTScoreMetric().calculate(results)
"""

self.results = results

for i, result in enumerate(self.results):
if isinstance(result, FailedAttackResult):
continue
elif isinstance(result, SkippedAttackResult):
continue
else:
self.original_candidates.append(result.original_result.attacked_text)
self.successful_candidates.append(result.perturbed_result.attacked_text)

sbert_scores = []
for c in range(len(self.original_candidates)):
sbert_scores.append(
self.use_obj._sim_score(
self.original_candidates[c], self.successful_candidates[c]
)
)

self.all_metrics["avg_attack_bert_score"] = round(
sum(sbert_scores) / len(sbert_scores), 2
)

return self.all_metrics
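For context on what _sim_score indexes into: bert-score's BERTScorer.score returns a tuple of (precision, recall, F1) tensors, one value per candidate/reference pair, and BERTScore.SCORE_TYPE2IDX selects the tensor matching the configured score_type (assumed here to map precision/recall/f1 to indices 0/1/2). A standalone sketch with the same model settings as the constructor above:

    # Sketch: the per-example computation behind BERTScoreMetric, via bert-score.
    import bert_score

    scorer = bert_score.BERTScorer(
        model_type="microsoft/deberta-large-mnli", idf=False, num_layers=18
    )
    P, R, F1 = scorer.score(["the perturbed sentence"], ["the original sentence"])
    print(F1.item())  # BERTScoreMetric averages this over all successful attacks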
70 changes: 70 additions & 0 deletions textattack/metrics/quality_metrics/meteor_score.py
@@ -0,0 +1,70 @@
"""
MeteorMetric class:
-------------------------------------------------------
Class for calculating METEOR score on AttackResults
"""

from textattack.attack_results import FailedAttackResult, SkippedAttackResult
import nltk
from textattack.metrics import Metric


class MeteorMetric(Metric):
def __init__(self, **kwargs):
self.original_candidates = []
self.successful_candidates = []
self.all_metrics = {}

def calculate(self, results):
"""Calculates average Metero score on all successfull attacks.
Args:
results (``AttackResult`` objects):
Attack results for each instance in dataset
Example::
>> import textattack
>> import transformers
>> model = transformers.AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
>> tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
>> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
>> dataset = textattack.datasets.HuggingFaceDataset("glue", "sst2", split="train")
>> attack_args = textattack.AttackArgs(
num_examples=1,
log_to_csv="log.csv",
checkpoint_interval=5,
checkpoint_dir="checkpoints",
disable_stdout=True
)
>> attacker = textattack.Attacker(attack, dataset, attack_args)
>> results = attacker.attack_dataset()
>> sbertm = textattack.metrics.quality_metrics.MeteorMetric().calculate(results)
"""

self.results = results

for i, result in enumerate(self.results):
if isinstance(result, FailedAttackResult):
continue
elif isinstance(result, SkippedAttackResult):
continue
else:
self.original_candidates.append(result.original_result.attacked_text.text)
self.successful_candidates.append(result.perturbed_result.attacked_text.text)

meteor_scores = []
for c in range(len(self.original_candidates)):
meteor_scores.append(
nltk.translate.meteor([nltk.word_tokenize(self.original_candidates[c])], nltk.word_tokenize(self.successful_candidates[c]))
)

self.all_metrics["avg_attack_meteor_score"] = round(
sum(meteor_scores) / len(meteor_scores), 2
)

return self.all_metrics
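nltk.translate.meteor takes a list of tokenized references plus a tokenized hypothesis, and it depends on NLTK data packages that are not bundled with the library (assumed here: punkt for word_tokenize and wordnet for METEOR's stem/synonym matching). A minimal standalone sketch of the per-example computation:

    # Sketch: the per-example METEOR computation used by MeteorMetric above.
    import nltk

    nltk.download("punkt")  # tokenizer models for nltk.word_tokenize
    nltk.download("wordnet")  # lexical database used by METEOR matching

    original = "the movie was surprisingly good"
    perturbed = "the film was surprisingly good"
    score = nltk.translate.meteor(
        [nltk.word_tokenize(original)], nltk.word_tokenize(perturbed)
    )
    print(round(score, 2))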
73 changes: 73 additions & 0 deletions textattack/metrics/quality_metrics/sentence_bert.py
@@ -0,0 +1,73 @@
"""
USEMetric class:
-------------------------------------------------------
Class for calculating SentenceBERT similarity on AttackResults
"""

from textattack.attack_results import FailedAttackResult, SkippedAttackResult
from textattack.constraints.semantics.sentence_encoders import BERT
from textattack.metrics import Metric


class SBERTMetric(Metric):
def __init__(self, **kwargs):
self.use_obj = BERT(model_name="all-MiniLM-L6-v2", metric="cosine")
self.original_candidates = []
self.successful_candidates = []
self.all_metrics = {}

def calculate(self, results):
"""Calculates average Sentence BERT similarity on all successfull attacks.
Args:
results (``AttackResult`` objects):
Attack results for each instance in dataset
Example::
>> import textattack
>> import transformers
>> model = transformers.AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
>> tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
>> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
>> dataset = textattack.datasets.HuggingFaceDataset("glue", "sst2", split="train")
>> attack_args = textattack.AttackArgs(
num_examples=1,
log_to_csv="log.csv",
checkpoint_interval=5,
checkpoint_dir="checkpoints",
disable_stdout=True
)
>> attacker = textattack.Attacker(attack, dataset, attack_args)
>> results = attacker.attack_dataset()
>> sbertm = textattack.metrics.quality_metrics.SBERTMetric().calculate(results)
"""

self.results = results

for i, result in enumerate(self.results):
if isinstance(result, FailedAttackResult):
continue
elif isinstance(result, SkippedAttackResult):
continue
else:
self.original_candidates.append(result.original_result.attacked_text)
self.successful_candidates.append(result.perturbed_result.attacked_text)

sbert_scores = []
for c in range(len(self.original_candidates)):
sbert_scores.append(
self.use_obj._sim_score(
self.original_candidates[c], self.successful_candidates[c]
).item()
)

self.all_metrics["avg_attack_sentence_bert_similarity"] = round(
sum(sbert_scores) / len(sbert_scores), 2
)

return self.all_metrics
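The BERT class imported here is TextAttack's Sentence-BERT sentence encoder, so this metric is effectively the cosine similarity between SBERT embeddings of the original and perturbed texts. A rough standalone equivalent (assuming a recent sentence-transformers release, whose util.cos_sim is its cosine-similarity helper):

    # Sketch: roughly what SBERTMetric's per-example _sim_score computes.
    from sentence_transformers import SentenceTransformer, util

    model = SentenceTransformer("all-MiniLM-L6-v2")
    embeddings = model.encode(
        ["the movie was surprisingly good", "the film was surprisingly good"],
        convert_to_tensor=True,
    )
    print(util.cos_sim(embeddings[0], embeddings[1]).item())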
