From 57864fc6c597fa17d1c620502e768cbeafba7b02 Mon Sep 17 00:00:00 2001 From: k-ivey Date: Thu, 5 Oct 2023 18:00:03 -0400 Subject: [PATCH 1/5] Add consistent parameter to name and location word swap --- tests/test_transformations.py | 58 +++++++++++++++++++ .../word_swaps/word_swap_change_location.py | 54 +++++++++++++---- .../word_swaps/word_swap_change_name.py | 31 +++++++++- 3 files changed, 130 insertions(+), 13 deletions(-) diff --git a/tests/test_transformations.py b/tests/test_transformations.py index 506d267a6..d95d9facd 100644 --- a/tests/test_transformations.py +++ b/tests/test_transformations.py @@ -33,6 +33,36 @@ def test_word_swap_change_location(): assert entity_original == entity_augmented +def test_word_swap_change_location_consistent(): + from flair.data import Sentence + from flair.models import SequenceTagger + + from textattack.augmentation import Augmenter + from textattack.transformations.word_swaps import WordSwapChangeLocation + + augmenter = Augmenter(transformation=WordSwapChangeLocation(consistent=True)) + s = "I am in New York. I love living in New York." + s_augmented = augmenter.augment(s) + augmented_text = Sentence(s_augmented[0]) + tagger = SequenceTagger.load("flair/ner-english") + original_text = Sentence(s) + tagger.predict(original_text) + tagger.predict(augmented_text) + + entity_original = [] + entity_augmented = [] + + for entity in original_text.get_spans("ner"): + entity_original.append(entity.tag) + for entity in augmented_text.get_spans("ner"): + entity_augmented.append(entity.tag) + + print(entity_original) + + assert entity_original == entity_augmented + assert s_augmented[0].count("New York") == 0 + + def test_word_swap_change_name(): from flair.data import Sentence from flair.models import SequenceTagger @@ -59,6 +89,34 @@ def test_word_swap_change_name(): assert entity_original == entity_augmented +def test_word_swap_change_name_consistent(): + from flair.data import Sentence + from flair.models import SequenceTagger + + from textattack.augmentation import Augmenter + from textattack.transformations.word_swaps import WordSwapChangeName + + augmenter = Augmenter(transformation=WordSwapChangeName(consistent=True)) + s = "My name is Anthony Davis. Anthony Davis plays basketball." + s_augmented = augmenter.augment(s) + augmented_text = Sentence(s_augmented[0]) + tagger = SequenceTagger.load("flair/ner-english") + original_text = Sentence(s) + tagger.predict(original_text) + tagger.predict(augmented_text) + + entity_original = [] + entity_augmented = [] + + for entity in original_text.get_spans("ner"): + entity_original.append(entity.tag) + for entity in augmented_text.get_spans("ner"): + entity_augmented.append(entity.tag) + + assert entity_original == entity_augmented + assert s_augmented[0].count("Anthony") == 0 or s_augmented[0].count("Davis") == 0 + + def test_chinese_morphonym_character_swap(): from textattack.augmentation import Augmenter from textattack.transformations.word_swaps.chn_transformations import ( diff --git a/textattack/transformations/word_swaps/word_swap_change_location.py b/textattack/transformations/word_swaps/word_swap_change_location.py index 14f82ff6a..7c2bbff05 100644 --- a/textattack/transformations/word_swaps/word_swap_change_location.py +++ b/textattack/transformations/word_swaps/word_swap_change_location.py @@ -2,6 +2,8 @@ Word Swap by Changing Location ------------------------------- """ +from collections import defaultdict + import more_itertools as mit import numpy as np @@ -25,12 +27,15 @@ def idx_to_words(ls, words): class WordSwapChangeLocation(WordSwap): - def __init__(self, n=3, confidence_score=0.7, language="en", **kwargs): + def __init__( + self, n=3, confidence_score=0.7, language="en", consistent=False, **kwargs + ): """Transformation that changes recognized locations of a sentence to another location that is given in the location map. :param n: Number of new locations to generate :param confidence_score: Location will only be changed if it's above the confidence score + :param consistent: Whether to change all instances of the same location to the same new location >>> from textattack.transformations import WordSwapChangeLocation >>> from textattack.augmentation import Augmenter @@ -44,6 +49,7 @@ def __init__(self, n=3, confidence_score=0.7, language="en", **kwargs): self.n = n self.confidence_score = confidence_score self.language = language + self.consistent = consistent def _get_transformations(self, current_text, indices_to_modify): words = current_text.words @@ -64,26 +70,46 @@ def _get_transformations(self, current_text, indices_to_modify): location_idx = [list(group) for group in mit.consecutive_groups(location_idx)] location_words = idx_to_words(location_idx, words) + if self.consistent: + location_to_indices = defaultdict(list) + for idx, location in location_words: + location_to_indices[self._capitalize(location)].append(idx[0]) + transformed_texts = [] for location in location_words: idx = location[0] - word = location[1].capitalize() + word = self._capitalize(location[1]) replacement_words = self._get_new_location(word) for r in replacement_words: if r == word: continue - text = current_text - # if original location is more than a single word, remain only the starting word - if len(idx) > 1: - index = idx[1] - for i in idx[1:]: - text = text.delete_word_at_index(index) + if self.consistent: + # If we're doing consistent replacements, only replace the word + # if it hasn't already been replaced in a previous iteration + if word not in location_to_indices: + continue + + indices_to_delete = [] + if len(idx) > 1: + for i in location_to_indices[word]: + for j in range(1, len(idx)): + indices_to_delete.append(i + j) + + transformed_texts.append(current_text.replace_words_at_indices( + location_to_indices[word] + indices_to_delete, + ([r] * len(location_to_indices[word])) + + ([""] * len(indices_to_delete)), + )) + + # Delete this word to mark it as replaced + del location_to_indices[word] + else: + # If the original location is more than a single word, keep only the starting word + # and replace the starting word with the new word + indices_to_delete = idx[1:] + transformed_texts.append(current_text.replace_words_at_indices([idx[0]] + indices_to_delete, [r] + [""] * len(indices_to_delete))) - # replace the starting word with new location - text = text.replace_word_at_index(idx[0], r) - - transformed_texts.append(text) return transformed_texts def _get_new_location(self, word): @@ -101,3 +127,7 @@ def _get_new_location(self, word): elif word in NAMED_ENTITIES["city"]: return np.random.choice(NAMED_ENTITIES["city"], self.n) return [] + + def _capitalize(self, string): + """Capitalizes all words in the string.""" + return " ".join(word.capitalize() for word in string.split()) diff --git a/textattack/transformations/word_swaps/word_swap_change_name.py b/textattack/transformations/word_swaps/word_swap_change_name.py index c4feeff48..429d05bc5 100644 --- a/textattack/transformations/word_swaps/word_swap_change_name.py +++ b/textattack/transformations/word_swaps/word_swap_change_name.py @@ -3,6 +3,8 @@ ------------------------------- """ +from collections import defaultdict + import numpy as np from textattack.shared.data import PERSON_NAMES @@ -18,6 +20,7 @@ def __init__( last_only=False, confidence_score=0.7, language="en", + consistent=False, **kwargs ): """Transforms an input by replacing names of recognized name entity. @@ -26,6 +29,7 @@ def __init__( :param first_only: Whether to change first name only :param last_only: Whether to change last name only :param confidence_score: Name will only be changed when it's above confidence score + :param consistent: Whether to change all instances of the same name to the same new name >>> from textattack.transformations import WordSwapChangeName >>> from textattack.augmentation import Augmenter @@ -42,6 +46,7 @@ def __init__( self.last_only = last_only self.confidence_score = confidence_score self.language = language + self.consistent = consistent def _get_transformations(self, current_text, indices_to_modify): transformed_texts = [] @@ -52,14 +57,38 @@ def _get_transformations(self, current_text, indices_to_modify): else: model_name = "flair/ner-multi-fast" + if self.consistent: + word_to_indices = defaultdict(list) + for i in indices_to_modify: + word_to_replace = current_text.words[i].capitalize() + word_to_indices[word_to_replace].append(i) + for i in indices_to_modify: word_to_replace = current_text.words[i].capitalize() + # If we're doing consistent replacements, only replace the word + # if it hasn't already been replaced in a previous iteration + if self.consistent and word_to_replace not in word_to_indices: + continue word_to_replace_ner = current_text.ner_of_word_index(i, model_name) + replacement_words = self._get_replacement_words( word_to_replace, word_to_replace_ner ) + for r in replacement_words: - transformed_texts.append(current_text.replace_word_at_index(i, r)) + if self.consistent: + transformed_texts.append( + current_text.replace_words_at_indices( + word_to_indices[word_to_replace], + [r] * len(word_to_indices[word_to_replace]), + ) + ) + else: + transformed_texts.append(current_text.replace_word_at_index(i, r)) + + # Delete this word to mark it as replaced + if self.consistent and len(replacement_words) != 0: + del word_to_indices[word_to_replace] return transformed_texts From cc59ec55715c92df085f866bbcf382c1e88389f6 Mon Sep 17 00:00:00 2001 From: k-ivey Date: Thu, 5 Oct 2023 18:16:20 -0400 Subject: [PATCH 2/5] Remove debugging print statement --- tests/test_transformations.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_transformations.py b/tests/test_transformations.py index d95d9facd..0b78674d8 100644 --- a/tests/test_transformations.py +++ b/tests/test_transformations.py @@ -57,8 +57,6 @@ def test_word_swap_change_location_consistent(): for entity in augmented_text.get_spans("ner"): entity_augmented.append(entity.tag) - print(entity_original) - assert entity_original == entity_augmented assert s_augmented[0].count("New York") == 0 From 93e522e8ab7f91e4a4d992bbdc3bf39a422a1cc6 Mon Sep 17 00:00:00 2001 From: k-ivey Date: Thu, 5 Oct 2023 18:44:46 -0400 Subject: [PATCH 3/5] Fix formatting --- .../word_swaps/word_swap_change_location.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/textattack/transformations/word_swaps/word_swap_change_location.py b/textattack/transformations/word_swaps/word_swap_change_location.py index 7c2bbff05..f38a2cb05 100644 --- a/textattack/transformations/word_swaps/word_swap_change_location.py +++ b/textattack/transformations/word_swaps/word_swap_change_location.py @@ -96,11 +96,13 @@ def _get_transformations(self, current_text, indices_to_modify): for j in range(1, len(idx)): indices_to_delete.append(i + j) - transformed_texts.append(current_text.replace_words_at_indices( - location_to_indices[word] + indices_to_delete, - ([r] * len(location_to_indices[word])) - + ([""] * len(indices_to_delete)), - )) + transformed_texts.append( + current_text.replace_words_at_indices( + location_to_indices[word] + indices_to_delete, + ([r] * len(location_to_indices[word])) + + ([""] * len(indices_to_delete)), + ) + ) # Delete this word to mark it as replaced del location_to_indices[word] @@ -108,7 +110,12 @@ def _get_transformations(self, current_text, indices_to_modify): # If the original location is more than a single word, keep only the starting word # and replace the starting word with the new word indices_to_delete = idx[1:] - transformed_texts.append(current_text.replace_words_at_indices([idx[0]] + indices_to_delete, [r] + [""] * len(indices_to_delete))) + transformed_texts.append( + current_text.replace_words_at_indices( + [idx[0]] + indices_to_delete, + [r] + [""] * len(indices_to_delete), + ) + ) return transformed_texts From ea3fc964c69b9eb6ce53b7392a73136eece53738 Mon Sep 17 00:00:00 2001 From: k-ivey Date: Wed, 1 Nov 2023 20:27:49 -0400 Subject: [PATCH 4/5] Fix word swap for multi-word locations --- .../word_swaps/word_swap_change_location.py | 75 ++++++++++++++++--- 1 file changed, 64 insertions(+), 11 deletions(-) diff --git a/textattack/transformations/word_swaps/word_swap_change_location.py b/textattack/transformations/word_swaps/word_swap_change_location.py index f38a2cb05..2a22f5b39 100644 --- a/textattack/transformations/word_swaps/word_swap_change_location.py +++ b/textattack/transformations/word_swaps/word_swap_change_location.py @@ -71,25 +71,26 @@ def _get_transformations(self, current_text, indices_to_modify): location_words = idx_to_words(location_idx, words) if self.consistent: - location_to_indices = defaultdict(list) - for idx, location in location_words: - location_to_indices[self._capitalize(location)].append(idx[0]) + location_to_indices = self._build_location_to_indicies_map( + location_words, current_text + ) transformed_texts = [] for location in location_words: idx = location[0] word = self._capitalize(location[1]) + + # If doing consistent replacements, only replace the + # word if it hasn't been replaced in a previous iteration + if self.consistent and word not in location_to_indices: + continue + replacement_words = self._get_new_location(word) for r in replacement_words: if r == word: continue if self.consistent: - # If we're doing consistent replacements, only replace the word - # if it hasn't already been replaced in a previous iteration - if word not in location_to_indices: - continue - indices_to_delete = [] if len(idx) > 1: for i in location_to_indices[word]: @@ -103,9 +104,6 @@ def _get_transformations(self, current_text, indices_to_modify): + ([""] * len(indices_to_delete)), ) ) - - # Delete this word to mark it as replaced - del location_to_indices[word] else: # If the original location is more than a single word, keep only the starting word # and replace the starting word with the new word @@ -117,6 +115,10 @@ def _get_transformations(self, current_text, indices_to_modify): ) ) + if self.consistent: + # Delete this word to mark it as replaced + del location_to_indices[word] + return transformed_texts def _get_new_location(self, word): @@ -138,3 +140,54 @@ def _get_new_location(self, word): def _capitalize(self, string): """Capitalizes all words in the string.""" return " ".join(word.capitalize() for word in string.split()) + + def _build_location_to_indicies_map(self, location_words, text): + """Returns a map of each location and the starting indicies of all + appearances of that location in the text.""" + + location_to_indices = defaultdict(list) + if len(location_words) == 0: + return location_to_indices + + location_words.sort( + # Sort by the number of words in the location + key=lambda index_location_pair: index_location_pair[0][-1] + - index_location_pair[0][0] + + 1, + reverse=True, + ) + max_length = location_words[0][0][-1] - location_words[0][0][0] + 1 + + for idx, location in location_words: + + words_in_location = idx[-1] - idx[0] + 1 + found = False + location_start = idx[0] + + # Check each window of n words containing the original tagged location + # for n from the max_length down to the original location length. + # This prevents cases where the NER tagger misses a word in a location + # (e.g. it does not tag "New" in "New York") + for length in range(max_length, words_in_location, -1): + for start in range( + location_start - length + words_in_location, + location_start + 1, + ): + if start + length > len(text.words): + break + + expanded_location = self._capitalize( + " ".join(text.words[start : start + length]) + ) + if expanded_location in location_to_indices: + location_to_indices[expanded_location].append(start) + found = True + break + + if found: + break + + if not found: + location_to_indices[self._capitalize(location)].append(idx[0]) + + return location_to_indices From bebf70ff09369668df9cfc491556659277c8a232 Mon Sep 17 00:00:00 2001 From: k-ivey Date: Thu, 2 Nov 2023 14:23:19 -0400 Subject: [PATCH 5/5] Update formatting to match newest version of black --- .../transformations/word_swaps/word_swap_change_location.py | 1 - 1 file changed, 1 deletion(-) diff --git a/textattack/transformations/word_swaps/word_swap_change_location.py b/textattack/transformations/word_swaps/word_swap_change_location.py index 2a22f5b39..916b97d39 100644 --- a/textattack/transformations/word_swaps/word_swap_change_location.py +++ b/textattack/transformations/word_swaps/word_swap_change_location.py @@ -159,7 +159,6 @@ def _build_location_to_indicies_map(self, location_words, text): max_length = location_words[0][0][-1] - location_words[0][0][0] + 1 for idx, location in location_words: - words_in_location = idx[-1] - idx[0] + 1 found = False location_start = idx[0]