diff --git a/tests/test_transformations.py b/tests/test_transformations.py
index 506d267a..0b78674d 100644
--- a/tests/test_transformations.py
+++ b/tests/test_transformations.py
@@ -33,6 +33,34 @@ def test_word_swap_change_location():
     assert entity_original == entity_augmented
 
 
+def test_word_swap_change_location_consistent():
+    from flair.data import Sentence
+    from flair.models import SequenceTagger
+
+    from textattack.augmentation import Augmenter
+    from textattack.transformations.word_swaps import WordSwapChangeLocation
+
+    augmenter = Augmenter(transformation=WordSwapChangeLocation(consistent=True))
+    s = "I am in New York. I love living in New York."
+    s_augmented = augmenter.augment(s)
+    augmented_text = Sentence(s_augmented[0])
+    tagger = SequenceTagger.load("flair/ner-english")
+    original_text = Sentence(s)
+    tagger.predict(original_text)
+    tagger.predict(augmented_text)
+
+    entity_original = []
+    entity_augmented = []
+
+    for entity in original_text.get_spans("ner"):
+        entity_original.append(entity.tag)
+    for entity in augmented_text.get_spans("ner"):
+        entity_augmented.append(entity.tag)
+
+    assert entity_original == entity_augmented
+    assert s_augmented[0].count("New York") == 0
+
+
 def test_word_swap_change_name():
     from flair.data import Sentence
     from flair.models import SequenceTagger
@@ -59,6 +87,34 @@ def test_word_swap_change_name():
     assert entity_original == entity_augmented
 
 
+def test_word_swap_change_name_consistent():
+    from flair.data import Sentence
+    from flair.models import SequenceTagger
+
+    from textattack.augmentation import Augmenter
+    from textattack.transformations.word_swaps import WordSwapChangeName
+
+    augmenter = Augmenter(transformation=WordSwapChangeName(consistent=True))
+    s = "My name is Anthony Davis. Anthony Davis plays basketball."
+    s_augmented = augmenter.augment(s)
+    augmented_text = Sentence(s_augmented[0])
+    tagger = SequenceTagger.load("flair/ner-english")
+    original_text = Sentence(s)
+    tagger.predict(original_text)
+    tagger.predict(augmented_text)
+
+    entity_original = []
+    entity_augmented = []
+
+    for entity in original_text.get_spans("ner"):
+        entity_original.append(entity.tag)
+    for entity in augmented_text.get_spans("ner"):
+        entity_augmented.append(entity.tag)
+
+    assert entity_original == entity_augmented
+    assert s_augmented[0].count("Anthony") == 0 or s_augmented[0].count("Davis") == 0
+
+
 def test_chinese_morphonym_character_swap():
     from textattack.augmentation import Augmenter
     from textattack.transformations.word_swaps.chn_transformations import (
diff --git a/textattack/transformations/word_swaps/word_swap_change_location.py b/textattack/transformations/word_swaps/word_swap_change_location.py
index 14f82ff6..916b97d3 100644
--- a/textattack/transformations/word_swaps/word_swap_change_location.py
+++ b/textattack/transformations/word_swaps/word_swap_change_location.py
@@ -2,6 +2,8 @@
 Word Swap by Changing Location
 -------------------------------
 """
+from collections import defaultdict
+
 import more_itertools as mit
 import numpy as np
 
@@ -25,12 +27,15 @@ def idx_to_words(ls, words):
 
 
 class WordSwapChangeLocation(WordSwap):
-    def __init__(self, n=3, confidence_score=0.7, language="en", **kwargs):
+    def __init__(
+        self, n=3, confidence_score=0.7, language="en", consistent=False, **kwargs
+    ):
         """Transformation that changes recognized locations of a sentence to
         another location that is given in the location map.
 
         :param n: Number of new locations to generate
         :param confidence_score: Location will only be changed if it's above the confidence score
+        :param consistent: Whether to change all instances of the same location to the same new location
 
         >>> from textattack.transformations import WordSwapChangeLocation
         >>> from textattack.augmentation import Augmenter
@@ -44,6 +49,7 @@ def __init__(self, n=3, confidence_score=0.7, language="en", **kwargs):
         self.n = n
         self.confidence_score = confidence_score
         self.language = language
+        self.consistent = consistent
 
     def _get_transformations(self, current_text, indices_to_modify):
         words = current_text.words
@@ -64,26 +70,55 @@ def _get_transformations(self, current_text, indices_to_modify):
         location_idx = [list(group) for group in mit.consecutive_groups(location_idx)]
         location_words = idx_to_words(location_idx, words)
 
+        if self.consistent:
+            location_to_indices = self._build_location_to_indices_map(
+                location_words, current_text
+            )
+
         transformed_texts = []
         for location in location_words:
             idx = location[0]
-            word = location[1].capitalize()
+            word = self._capitalize(location[1])
+
+            # If doing consistent replacements, only replace the
+            # word if it hasn't been replaced in a previous iteration
+            if self.consistent and word not in location_to_indices:
+                continue
+
             replacement_words = self._get_new_location(word)
             for r in replacement_words:
                 if r == word:
                     continue
-                text = current_text
-
-                # if original location is more than a single word, remain only the starting word
-                if len(idx) > 1:
-                    index = idx[1]
-                    for i in idx[1:]:
-                        text = text.delete_word_at_index(index)
 
-                # replace the starting word with new location
-                text = text.replace_word_at_index(idx[0], r)
+                if self.consistent:
+                    indices_to_delete = []
+                    if len(idx) > 1:
+                        for i in location_to_indices[word]:
+                            for j in range(1, len(idx)):
+                                indices_to_delete.append(i + j)
+
+                    transformed_texts.append(
+                        current_text.replace_words_at_indices(
+                            location_to_indices[word] + indices_to_delete,
+                            ([r] * len(location_to_indices[word]))
+                            + ([""] * len(indices_to_delete)),
+                        )
+                    )
+                else:
+                    # If the original location is more than a single word, keep only the starting word
+                    # and replace the starting word with the new word
+                    indices_to_delete = idx[1:]
+                    transformed_texts.append(
+                        current_text.replace_words_at_indices(
+                            [idx[0]] + indices_to_delete,
+                            [r] + [""] * len(indices_to_delete),
+                        )
+                    )
+
+            if self.consistent:
+                # Delete this word to mark it as replaced
+                del location_to_indices[word]
 
-                transformed_texts.append(text)
         return transformed_texts
 
     def _get_new_location(self, word):
@@ -101,3 +136,57 @@ def _get_new_location(self, word):
         elif word in NAMED_ENTITIES["city"]:
             return np.random.choice(NAMED_ENTITIES["city"], self.n)
         return []
+
+    def _capitalize(self, string):
+        """Capitalizes all words in the string."""
+        return " ".join(word.capitalize() for word in string.split())
+
+    def _build_location_to_indices_map(self, location_words, text):
+        """Returns a map of each location and the starting indices of all
+        appearances of that location in the text."""
+
+        location_to_indices = defaultdict(list)
+        if len(location_words) == 0:
+            return location_to_indices
+
+        location_words.sort(
+            # Sort by the number of words in the location
+            key=lambda index_location_pair: index_location_pair[0][-1]
+            - index_location_pair[0][0]
+            + 1,
+            reverse=True,
+        )
+        max_length = location_words[0][0][-1] - location_words[0][0][0] + 1
+
+        for idx, location in location_words:
+            words_in_location = idx[-1] - idx[0] + 1
+            found = False
+            location_start = idx[0]
+
+            # Check each window of n words containing the original tagged location
+            # for n from the max_length down to the original location length.
+            # This prevents cases where the NER tagger misses a word in a location
+            # (e.g. it does not tag "New" in "New York")
+            for length in range(max_length, words_in_location, -1):
+                for start in range(
+                    location_start - length + words_in_location,
+                    location_start + 1,
+                ):
+                    if start + length > len(text.words):
+                        break
+
+                    expanded_location = self._capitalize(
+                        " ".join(text.words[start : start + length])
+                    )
+                    if expanded_location in location_to_indices:
+                        location_to_indices[expanded_location].append(start)
+                        found = True
+                        break
+
+                if found:
+                    break
+
+            if not found:
+                location_to_indices[self._capitalize(location)].append(idx[0])
+
+        return location_to_indices
diff --git a/textattack/transformations/word_swaps/word_swap_change_name.py b/textattack/transformations/word_swaps/word_swap_change_name.py
index c4feeff4..429d05bc 100644
--- a/textattack/transformations/word_swaps/word_swap_change_name.py
+++ b/textattack/transformations/word_swaps/word_swap_change_name.py
@@ -3,6 +3,8 @@
 ------------------------------
 """
 
+from collections import defaultdict
+
 import numpy as np
 
 from textattack.shared.data import PERSON_NAMES
@@ -18,6 +20,7 @@ def __init__(
         last_only=False,
         confidence_score=0.7,
         language="en",
+        consistent=False,
         **kwargs
     ):
         """Transforms an input by replacing names of recognized name entity.
@@ -26,6 +29,7 @@ def __init__(
         :param first_only: Whether to change first name only
         :param last_only: Whether to change last name only
         :param confidence_score: Name will only be changed when it's above confidence score
+        :param consistent: Whether to change all instances of the same name to the same new name
 
         >>> from textattack.transformations import WordSwapChangeName
         >>> from textattack.augmentation import Augmenter
@@ -42,6 +46,7 @@ def __init__(
         self.last_only = last_only
         self.confidence_score = confidence_score
         self.language = language
+        self.consistent = consistent
 
     def _get_transformations(self, current_text, indices_to_modify):
         transformed_texts = []
@@ -52,14 +57,38 @@ def _get_transformations(self, current_text, indices_to_modify):
         else:
             model_name = "flair/ner-multi-fast"
 
+        if self.consistent:
+            word_to_indices = defaultdict(list)
+            for i in indices_to_modify:
+                word_to_replace = current_text.words[i].capitalize()
+                word_to_indices[word_to_replace].append(i)
+
         for i in indices_to_modify:
             word_to_replace = current_text.words[i].capitalize()
+            # If we're doing consistent replacements, only replace the word
+            # if it hasn't already been replaced in a previous iteration
+            if self.consistent and word_to_replace not in word_to_indices:
+                continue
             word_to_replace_ner = current_text.ner_of_word_index(i, model_name)
+
             replacement_words = self._get_replacement_words(
                 word_to_replace, word_to_replace_ner
             )
+
             for r in replacement_words:
-                transformed_texts.append(current_text.replace_word_at_index(i, r))
+                if self.consistent:
+                    transformed_texts.append(
+                        current_text.replace_words_at_indices(
+                            word_to_indices[word_to_replace],
+                            [r] * len(word_to_indices[word_to_replace]),
+                        )
+                    )
+                else:
+                    transformed_texts.append(current_text.replace_word_at_index(i, r))
+
+            # Delete this word to mark it as replaced
+            if self.consistent and len(replacement_words) != 0:
+                del word_to_indices[word_to_replace]
 
         return transformed_texts