[Draft] Add adversarial defense benchmark #694

Closed
wants to merge 18 commits

76 changes: 44 additions & 32 deletions docs/2notebook/1_Introduction_and_Transformations.ipynb
@@ -77,19 +77,19 @@
"source": [
"from textattack.transformations import WordSwap\n",
"\n",
"\n",
"class BananaWordSwap(WordSwap):\n",
" \"\"\" Transforms an input by replacing any word with 'banana'.\n",
" \"\"\"\n",
" \n",
" \"\"\"Transforms an input by replacing any word with 'banana'.\"\"\"\n",
"\n",
" # We don't need a constructor, since our class doesn't require any parameters.\n",
"\n",
" def _get_replacement_words(self, word):\n",
" \"\"\" Returns 'banana', no matter what 'word' was originally.\n",
" \n",
" Returns a list with one item, since `_get_replacement_words` is intended to\n",
" return a list of candidate replacement words.\n",
" \"\"\"Returns 'banana', no matter what 'word' was originally.\n",
"\n",
" Returns a list with one item, since `_get_replacement_words` is intended to\n",
" return a list of candidate replacement words.\n",
" \"\"\"\n",
" return ['banana']"
" return [\"banana\"]"
]
},
{
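For context: the pattern this hunk reformats is the whole extension surface of `WordSwap` — subclasses only override `_get_replacement_words`, which returns a list of candidate strings. A minimal sketch of another subclass following the same pattern (hypothetical, assumes textattack is installed; not part of this diff):

from textattack.transformations import WordSwap


class ReverseWordSwap(WordSwap):
    """Hypothetical example: replaces any word with its reversed spelling."""

    def _get_replacement_words(self, word):
        # As with BananaWordSwap above, return a one-item list of candidates.
        return [word[::-1]]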
@@ -133,17 +133,23 @@
"import transformers\n",
"from textattack.models.wrappers import HuggingFaceModelWrapper\n",
"\n",
"model = transformers.AutoModelForSequenceClassification.from_pretrained(\"textattack/bert-base-uncased-ag-news\")\n",
"tokenizer = transformers.AutoTokenizer.from_pretrained(\"textattack/bert-base-uncased-ag-news\")\n",
"model = transformers.AutoModelForSequenceClassification.from_pretrained(\n",
" \"textattack/bert-base-uncased-ag-news\"\n",
")\n",
"tokenizer = transformers.AutoTokenizer.from_pretrained(\n",
" \"textattack/bert-base-uncased-ag-news\"\n",
")\n",
"\n",
"model_wrapper = HuggingFaceModelWrapper(model, tokenizer)\n",
"\n",
"# Create the goal function using the model\n",
"from textattack.goal_functions import UntargetedClassification\n",
"\n",
"goal_function = UntargetedClassification(model_wrapper)\n",
"\n",
"# Import the dataset\n",
"from textattack.datasets import HuggingFaceDataset\n",
"\n",
"dataset = HuggingFaceDataset(\"ag_news\", None, \"test\")"
]
},
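A quick sanity check of the cell above (a sketch, not part of this diff; the example sentence is made up): `HuggingFaceModelWrapper` can be called directly on a list of raw strings.

# Hypothetical check: the wrapper tokenizes the strings and returns the
# model's logits, one row per input; this AG News model has 4 classes.
logits = model_wrapper(["Stocks rallied after the quarterly earnings report."])
print(logits.shape)  # expect (1, 4), assuming the wrapper returns a logits tensor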
@@ -166,14 +172,16 @@
"outputs": [],
"source": [
"from textattack.search_methods import GreedySearch\n",
"from textattack.constraints.pre_transformation import RepeatModification, StopwordModification\n",
"from textattack.constraints.pre_transformation import (\n",
" RepeatModification,\n",
" StopwordModification,\n",
")\n",
"from textattack import Attack\n",
"\n",
"# We're going to use our Banana word swap class as the attack transformation.\n",
"transformation = BananaWordSwap() \n",
"transformation = BananaWordSwap()\n",
"# We'll constrain modification of already modified indices and stopwords\n",
"constraints = [RepeatModification(),\n",
" StopwordModification()]\n",
"constraints = [RepeatModification(), StopwordModification()]\n",
"# We'll use the Greedy search method\n",
"search_method = GreedySearch()\n",
"# Now, let's make the attack from the 4 components:\n",
@@ -517,8 +525,8 @@
}
],
"source": [
"from tqdm import tqdm # tqdm provides us a nice progress bar.\n",
"from textattack.loggers import CSVLogger # tracks a dataframe for us.\n",
"from tqdm import tqdm # tqdm provides us a nice progress bar.\n",
"from textattack.loggers import CSVLogger # tracks a dataframe for us.\n",
"from textattack.attack_results import SuccessfulAttackResult\n",
"from textattack import Attacker\n",
"from textattack import AttackArgs\n",
@@ -530,14 +538,14 @@
"\n",
"attack_results = attacker.attack_dataset()\n",
"\n",
"#The following legacy tutorial code shows how the Attack API works in detail.\n",
"# The following legacy tutorial code shows how the Attack API works in detail.\n",
"\n",
"#logger = CSVLogger(color_method='html')\n",
"# logger = CSVLogger(color_method='html')\n",
"\n",
"#num_successes = 0\n",
"#i = 0\n",
"#while num_successes < 10:\n",
" #result = next(results_iterable)\n",
"# num_successes = 0\n",
"# i = 0\n",
"# while num_successes < 10:\n",
"# result = next(results_iterable)\n",
"# example, ground_truth_output = dataset[i]\n",
"# i += 1\n",
"# result = attack.attack(example, ground_truth_output)\n",
@@ -652,15 +660,19 @@
],
"source": [
"import pandas as pd\n",
"pd.options.display.max_colwidth = 480 # increase colum width so we can actually read the examples\n",
"\n",
"logger = CSVLogger(color_method='html')\n",
"pd.options.display.max_colwidth = (\n",
" 480 # increase colum width so we can actually read the examples\n",
")\n",
"\n",
"logger = CSVLogger(color_method=\"html\")\n",
"\n",
"for result in attack_results:\n",
" logger.log_attack_result(result)\n",
"\n",
"from IPython.core.display import display, HTML\n",
"display(HTML(logger.df[['original_text', 'perturbed_text']].to_html(escape=False)))"
"\n",
"display(HTML(logger.df[[\"original_text\", \"perturbed_text\"]].to_html(escape=False)))"
]
},
{
@@ -867,10 +879,10 @@
"# For AG News, labels are 0: World, 1: Sports, 2: Business, 3: Sci/Tech\n",
"\n",
"custom_dataset = [\n",
" ('Malaria deaths in Africa fall by 5% from last year', 0),\n",
" ('Washington Nationals defeat the Houston Astros to win the World Series', 1),\n",
" ('Exxon Mobil hires a new CEO', 2),\n",
" ('Microsoft invests $1 billion in OpenAI', 3),\n",
" (\"Malaria deaths in Africa fall by 5% from last year\", 0),\n",
" (\"Washington Nationals defeat the Houston Astros to win the World Series\", 1),\n",
" (\"Exxon Mobil hires a new CEO\", 2),\n",
" (\"Microsoft invests $1 billion in OpenAI\", 3),\n",
"]\n",
"\n",
"attack_args = AttackArgs(num_examples=4)\n",
@@ -881,14 +893,14 @@
"\n",
"results_iterable = attacker.attack_dataset()\n",
"\n",
"logger = CSVLogger(color_method='html')\n",
"logger = CSVLogger(color_method=\"html\")\n",
"\n",
"for result in results_iterable:\n",
" logger.log_attack_result(result)\n",
"\n",
"from IPython.core.display import display, HTML\n",
" \n",
"display(HTML(logger.df[['original_text', 'perturbed_text']].to_html(escape=False)))"
"\n",
"display(HTML(logger.df[[\"original_text\", \"perturbed_text\"]].to_html(escape=False)))"
]
}
],
86 changes: 54 additions & 32 deletions docs/2notebook/2_Constraints.ipynb
@@ -100,6 +100,7 @@
],
"source": [
"import tensorflow as tf\n",
"\n",
"print(tf.__version__)"
]
},
@@ -149,10 +150,11 @@
"!pip3 install .\n",
"\n",
"import nltk\n",
"nltk.download('punkt') # The NLTK tokenizer\n",
"nltk.download('maxent_ne_chunker') # NLTK named-entity chunker\n",
"nltk.download('words') # NLTK list of words\n",
"nltk.download('averaged_perceptron_tagger')"
"\n",
"nltk.download(\"punkt\") # The NLTK tokenizer\n",
"nltk.download(\"maxent_ne_chunker\") # NLTK named-entity chunker\n",
"nltk.download(\"words\") # NLTK list of words\n",
"nltk.download(\"averaged_perceptron_tagger\")"
]
},
{
@@ -205,8 +207,10 @@
}
],
"source": [
"sentence = ('In 2017, star quarterback Tom Brady led the Patriots to the Super Bowl, '\n",
" 'but lost to the Philadelphia Eagles.')\n",
"sentence = (\n",
" \"In 2017, star quarterback Tom Brady led the Patriots to the Super Bowl, \"\n",
" \"but lost to the Philadelphia Eagles.\"\n",
")\n",
"\n",
"# 1. Tokenize using the NLTK tokenizer.\n",
"tokens = nltk.word_tokenize(sentence)\n",
@@ -285,6 +289,7 @@
"source": [
"import functools\n",
"\n",
"\n",
"@functools.lru_cache(maxsize=2**14)\n",
"def get_entities(sentence):\n",
" tokens = nltk.word_tokenize(sentence)\n",
@@ -379,9 +384,10 @@
"source": [
"from textattack.constraints import Constraint\n",
"\n",
"\n",
"class NamedEntityConstraint(Constraint):\n",
" \"\"\" A constraint that ensures `transformed_text` only substitutes named entities from `current_text` with other named entities.\n",
" \"\"\"\n",
" \"\"\"A constraint that ensures `transformed_text` only substitutes named entities from `current_text` with other named entities.\"\"\"\n",
"\n",
" def _check_constraint(self, transformed_text, current_text):\n",
" transformed_entities = get_entities(transformed_text.text)\n",
" current_entities = get_entities(current_text.text)\n",
@@ -390,26 +396,27 @@
" if len(current_entities) == 0:\n",
" return False\n",
" if len(current_entities) != len(transformed_entities):\n",
" # If the two sentences have a different number of entities, then \n",
" # they definitely don't have the same labels. In this case, the \n",
" # If the two sentences have a different number of entities, then\n",
" # they definitely don't have the same labels. In this case, the\n",
" # constraint is violated, and we return False.\n",
" return False\n",
" else:\n",
" # Here we compare all of the words, in order, to make sure that they match.\n",
" # If we find two words that don't match, this means a word was swapped \n",
" # If we find two words that don't match, this means a word was swapped\n",
" # between `current_text` and `transformed_text`. That word must be a named entity to fulfill our\n",
" # constraint.\n",
" current_word_label = None\n",
" transformed_word_label = None\n",
" for (word_1, label_1), (word_2, label_2) in zip(current_entities, transformed_entities):\n",
" for (word_1, label_1), (word_2, label_2) in zip(\n",
" current_entities, transformed_entities\n",
" ):\n",
" if word_1 != word_2:\n",
" # Finally, make sure that words swapped between `x` and `x_adv` are named entities. If \n",
" # Finally, make sure that words swapped between `x` and `x_adv` are named entities. If\n",
" # they're not, then we also return False.\n",
" if (label_1 not in ['NNP', 'NE']) or (label_2 not in ['NNP', 'NE']):\n",
" return False \n",
" if (label_1 not in [\"NNP\", \"NE\"]) or (label_2 not in [\"NNP\", \"NE\"]):\n",
" return False\n",
" # If we get here, all of the labels match up. Return True!\n",
" return True\n",
" "
" return True"
]
},
{
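A quick way to exercise the constraint by hand (hypothetical snippet, not part of this diff): wrap two sentences in `AttackedText` — the class whose `.text` attribute `_check_constraint` reads — and compare them directly. The `False` argument is the base `Constraint`'s `compare_against_original` flag, mirroring its use further down in this file.

from textattack.shared import AttackedText

constraint = NamedEntityConstraint(False)
current = AttackedText("Tom Brady led the Patriots to the Super Bowl.")
transformed = AttackedText("Tom Brady led the Broncos to the Super Bowl.")
print(constraint._check_constraint(transformed, current))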
@@ -638,17 +645,23 @@
"import transformers\n",
"from textattack.models.wrappers import HuggingFaceModelWrapper\n",
"\n",
"model = transformers.AutoModelForSequenceClassification.from_pretrained(\"textattack/albert-base-v2-ag-news\")\n",
"tokenizer = transformers.AutoTokenizer.from_pretrained(\"textattack/albert-base-v2-ag-news\")\n",
"model = transformers.AutoModelForSequenceClassification.from_pretrained(\n",
" \"textattack/albert-base-v2-ag-news\"\n",
")\n",
"tokenizer = transformers.AutoTokenizer.from_pretrained(\n",
" \"textattack/albert-base-v2-ag-news\"\n",
")\n",
"\n",
"model_wrapper = HuggingFaceModelWrapper(model, tokenizer)\n",
"\n",
"# Create the goal function using the model\n",
"from textattack.goal_functions import UntargetedClassification\n",
"\n",
"goal_function = UntargetedClassification(model_wrapper)\n",
"\n",
"# Import the dataset\n",
"from textattack.datasets import HuggingFaceDataset\n",
"\n",
"dataset = HuggingFaceDataset(\"ag_news\", None, \"test\")"
]
},
@@ -663,23 +676,27 @@
"from textattack.transformations import WordSwapEmbedding\n",
"from textattack.search_methods import GreedyWordSwapWIR\n",
"from textattack import Attack\n",
"from textattack.constraints.pre_transformation import RepeatModification, StopwordModification\n",
"from textattack.constraints.pre_transformation import (\n",
" RepeatModification,\n",
" StopwordModification,\n",
")\n",
"\n",
"# We're going to the `WordSwapEmbedding` transformation. Using the default settings, this\n",
"# will try substituting words with their neighbors in the counter-fitted embedding space. \n",
"transformation = WordSwapEmbedding(max_candidates=20) \n",
"# will try substituting words with their neighbors in the counter-fitted embedding space.\n",
"transformation = WordSwapEmbedding(max_candidates=20)\n",
"\n",
"# We'll use the greedy search with word importance ranking method again\n",
"search_method = GreedyWordSwapWIR()\n",
"\n",
"# Our constraints will be the same as Tutorial 1, plus the named entity constraint\n",
"constraints = [RepeatModification(),\n",
" StopwordModification(),\n",
" NamedEntityConstraint(False)]\n",
"constraints = [\n",
" RepeatModification(),\n",
" StopwordModification(),\n",
" NamedEntityConstraint(False),\n",
"]\n",
"\n",
"# Now, let's make the attack using these parameters. \n",
"attack = Attack(goal_function, constraints, transformation, search_method)\n",
"\n"
"# Now, let's make the attack using these parameters.\n",
"attack = Attack(goal_function, constraints, transformation, search_method)"
]
},
{
@@ -800,11 +817,13 @@
}
],
"source": [
"from textattack.loggers import CSVLogger # tracks a dataframe for us.\n",
"from textattack.loggers import CSVLogger # tracks a dataframe for us.\n",
"from textattack.attack_results import SuccessfulAttackResult\n",
"from textattack import Attacker, AttackArgs\n",
"\n",
"attack_args = AttackArgs(num_successful_examples=5, log_to_csv=\"results.csv\", csv_coloring_style=\"html\")\n",
"attack_args = AttackArgs(\n",
" num_successful_examples=5, log_to_csv=\"results.csv\", csv_coloring_style=\"html\"\n",
")\n",
"attacker = Attacker(attack, dataset, attack_args)\n",
"\n",
"attacker.attack_dataset()"
@@ -833,13 +852,16 @@
"outputs": [],
"source": [
"import pandas as pd\n",
"pd.options.display.max_colwidth = 480 # increase column width so we can actually read the examples\n",
"\n",
"pd.options.display.max_colwidth = (\n",
" 480 # increase column width so we can actually read the examples\n",
")\n",
"\n",
"from IPython.core.display import display, HTML\n",
"\n",
"logger = attacker.attack_log_manager.loggers[0]\n",
"successes = logger.df[logger.df[\"result_type\"] == \"Successful\"]\n",
"display(HTML(successes[['original_text', 'perturbed_text']].to_html(escape=False)))"
"display(HTML(successes[[\"original_text\", \"perturbed_text\"]].to_html(escape=False)))"
]
},
{