Skip to content

Commit

Permalink
Minor format change; also fixed a metric recipe error.
Browse files Browse the repository at this point in the history
  • Loading branch information
qiyanjun committed Sep 30, 2023
1 parent 12f3de0 commit 5b66ec2
Show file tree
Hide file tree
Showing 15 changed files with 64 additions and 32 deletions.
2 changes: 1 addition & 1 deletion tests/test_attacked_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_window_around_index(self, attacked_text):

def test_big_window_around_index(self, attacked_text):
assert (
attacked_text.text_window_around_index(0, 10**5) + "."
attacked_text.text_window_around_index(0, 10 ** 5) + "."
) == attacked_text.text

def test_window_around_index_start(self, attacked_text):
Expand Down
43 changes: 36 additions & 7 deletions tests/test_metric_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ def test_use():
from textattack import AttackArgs, Attacker
from textattack.attack_recipes import DeepWordBugGao2018
from textattack.datasets import HuggingFaceDataset
from textattack.metrics.quality_metrics import USEMetric
from textattack.metrics.recipe import AdvancedAttackMetric
from textattack.metrics.quality_metrics import MeteorMetric
from textattack.models.wrappers import HuggingFaceModelWrapper

model = transformers.AutoModelForSequenceClassification.from_pretrained(
Expand All @@ -51,12 +50,42 @@ def test_use():
disable_stdout=True,
)
attacker = Attacker(attack, dataset, attack_args)

results = attacker.attack_dataset()

usem = USEMetric().calculate(results)
usem = MeteorMetric().calculate(results)

assert usem["avg_attack_meteor_score"] == 0.71


assert usem["avg_attack_use_score"] == 0.76
def test_metric_recipe():
    """End-to-end check of ``AdvancedAttackMetric`` on a single SST-2 example.

    Builds a DeepWordBug attack against a DistilBERT sentiment model, attacks
    one training example, and verifies the metric values that
    ``AdvancedAttackMetric.calculate()`` reports for the resulting attack.
    """
    import transformers

    from textattack import AttackArgs, Attacker
    from textattack.attack_recipes import DeepWordBugGao2018
    from textattack.datasets import HuggingFaceDataset
    from textattack.metrics.recipe import AdvancedAttackMetric
    from textattack.models.wrappers import HuggingFaceModelWrapper

    model = transformers.AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased-finetuned-sst-2-english"
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "distilbert-base-uncased-finetuned-sst-2-english"
    )
    model_wrapper = HuggingFaceModelWrapper(model, tokenizer)
    attack = DeepWordBugGao2018.build(model_wrapper)
    dataset = HuggingFaceDataset("glue", "sst2", split="train")
    attack_args = AttackArgs(
        num_examples=1,
        log_to_csv="log.csv",
        checkpoint_interval=5,
        checkpoint_dir="checkpoints",
        disable_stdout=True,
    )
    attacker = Attacker(attack, dataset, attack_args)
    results = attacker.attack_dataset()

    # AdvancedAttackMetric.calculate() merges every selected sub-metric's
    # result dict into ONE flat dict (via dict.update), so scores are keyed
    # directly by name — there is no nested "use"/"meteor_score" level.
    adv_score = AdvancedAttackMetric(["use", "perplexity"]).calculate(results)
    assert adv_score["avg_attack_use_score"] == 0.76
    adv_score = AdvancedAttackMetric(["meteor_score", "perplexity"]).calculate(results)
    assert adv_score["avg_attack_meteor_score"] == 0.71
4 changes: 2 additions & 2 deletions tests/test_word_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_embedding_paragramcf():
word_embedding = WordEmbedding.counterfitted_GLOVE_embedding()
assert pytest.approx(word_embedding[0][0]) == -0.022007
assert pytest.approx(word_embedding["fawn"][0]) == -0.022007
assert word_embedding[10**9] is None
assert word_embedding[10 ** 9] is None


def test_embedding_gensim():
Expand All @@ -37,7 +37,7 @@ def test_embedding_gensim():
word_embedding = GensimWordEmbedding(keyed_vectors)
assert pytest.approx(word_embedding[0][0]) == 1
assert pytest.approx(word_embedding["bye-bye"][0]) == -1 / np.sqrt(2)
assert word_embedding[10**9] is None
assert word_embedding[10 ** 9] is None

# test query functionality
assert pytest.approx(word_embedding.get_cos_sim(1, 3)) == 0
Expand Down
4 changes: 2 additions & 2 deletions textattack/attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ def __init__(
constraints: List[Union[Constraint, PreTransformationConstraint]],
transformation: Transformation,
search_method: SearchMethod,
transformation_cache_size=2**15,
constraint_cache_size=2**15,
transformation_cache_size=2 ** 15,
constraint_cache_size=2 ** 15,
):
"""Initialize an attack object.
Expand Down
4 changes: 2 additions & 2 deletions textattack/attack_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,8 +507,8 @@ class _CommandLineAttackArgs:
interactive: bool = False
parallel: bool = False
model_batch_size: int = 32
model_cache_size: int = 2**18
constraint_cache_size: int = 2**18
model_cache_size: int = 2 ** 18
constraint_cache_size: int = 2 ** 18

@classmethod
def _add_parser_args(cls, parser):
Expand Down
2 changes: 1 addition & 1 deletion textattack/constraints/grammaticality/cola.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(

self.max_diff = max_diff
self.model_name = model_name
self._reference_score_cache = lru.LRU(2**10)
self._reference_score_cache = lru.LRU(2 ** 10)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = HuggingFaceModelWrapper(model, tokenizer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(self):
self.sess, self.graph, self.PBTXT_PATH, self.CKPT_PATH
)

self.lm_cache = lru.LRU(2**18)
self.lm_cache = lru.LRU(2 ** 18)

def clear_cache(self):
self.lm_cache.clear()
Expand Down
2 changes: 1 addition & 1 deletion textattack/constraints/grammaticality/part_of_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
self.language_nltk = language_nltk
self.language_stanza = language_stanza

self._pos_tag_cache = lru.LRU(2**14)
self._pos_tag_cache = lru.LRU(2 ** 14)
if tagger_type == "flair":
if tagset == "universal":
self._flair_pos_tagger = SequenceTagger.load("upos-fast")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, embedding=None, **kwargs):
def clear_cache(self):
self._get_thought_vector.cache_clear()

@functools.lru_cache(maxsize=2**10)
@functools.lru_cache(maxsize=2 ** 10)
def _get_thought_vector(self, text):
"""Sums the embeddings of all the words in ``text`` into a "thought
vector"."""
Expand Down
2 changes: 1 addition & 1 deletion textattack/goal_functions/goal_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(
use_cache=True,
query_budget=float("inf"),
model_batch_size=32,
model_cache_size=2**20,
model_cache_size=2 ** 20,
):
validators.validate_model_goal_function_compatibility(
self.__class__, model_wrapper.model.__class__
Expand Down
2 changes: 1 addition & 1 deletion textattack/goal_functions/text/minimize_bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def extra_repr_keys(self):
return ["maximizable", "target_bleu"]


@functools.lru_cache(maxsize=2**12)
@functools.lru_cache(maxsize=2 ** 12)
def get_bleu(a, b):
ref = a.words
hyp = b.words
Expand Down
4 changes: 2 additions & 2 deletions textattack/goal_functions/text/non_overlapping_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ def _get_score(self, model_output, _):
return num_words_diff / len(get_words_cached(self.ground_truth_output))


@functools.lru_cache(maxsize=2**12)
@functools.lru_cache(maxsize=2 ** 12)
def get_words_cached(s):
return np.array(words_from_text(s))


@functools.lru_cache(maxsize=2**12)
@functools.lru_cache(maxsize=2 ** 12)
def word_difference_score(s1, s2):
"""Returns the number of words that are non-overlapping between s1 and
s2."""
Expand Down
2 changes: 1 addition & 1 deletion textattack/metrics/attack_metrics/words_perturbed.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def calculate(self, results):
self.total_attacks = len(self.results)
self.all_num_words = np.zeros(len(self.results))
self.perturbed_word_percentages = np.zeros(len(self.results))
self.num_words_changed_until_success = np.zeros(2**16)
self.num_words_changed_until_success = np.zeros(2 ** 16)
self.max_words_changed = 0

for i, result in enumerate(self.results):
Expand Down
16 changes: 11 additions & 5 deletions textattack/metrics/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
"""
import random

from textattack.metrics.quality_metrics.bert_score import BERTScoreMetric
from textattack.metrics.quality_metrics.meteor_score import MeteorMetric
from textattack.metrics.quality_metrics.perplexity import Perplexity
from textattack.metrics.quality_metrics.sentence_bert import SBERTMetric
from textattack.metrics.quality_metrics.use import USEMetric

from .metric import Metric


Expand All @@ -18,13 +24,13 @@ def __init__(self, choices=["use"]):
def calculate(self, results):
advanced_metrics = {}
if "use" in self.achoices:
advanced_metrics["use"] = USEMetric().calculate(results)
advanced_metrics.update(USEMetric().calculate(results))
if "perplexity" in self.achoices:
advanced_metrics["perplexity"] = Perplexity().calculate(results)
advanced_metrics.update(Perplexity().calculate(results))
if "bert_score" in self.achoices:
advanced_metrics["bert_score"] = BERTScoreMetric().calculate(results)
advanced_metrics.update(BERTScoreMetric().calculate(results))
if "meteor_score" in self.achoices:
advanced_metrics["meteor_score"] = MeteorMetric().calculate(results)
advanced_metrics.update(MeteorMetric().calculate(results))
if "sbert_score" in self.achoices:
advanced_metrics["sbert_score"] = SBERTMetric().calculate(results)
advanced_metrics.update(SBERTMetric().calculate(results))
return advanced_metrics
5 changes: 1 addition & 4 deletions textattack/shared/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,7 @@
r"^textattack.models.helpers.word_cnn_for_classification.*",
r"^transformers.modeling_\w*\.\w*ForSequenceClassification$",
],
(
NonOverlappingOutput,
MinimizeBleu,
): [
(NonOverlappingOutput, MinimizeBleu,): [
r"^textattack.models.helpers.t5_for_text_to_text.*",
],
}
Expand Down

0 comments on commit 5b66ec2

Please sign in to comment.