Merge pull request #752 from k-ivey/consistent-word-swap
Consistent word swap
qiyanjun authored Nov 5, 2023
2 parents 189d11f + bebf70f commit db4ae20
Showing 3 changed files with 187 additions and 13 deletions.
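In short, the change adds a consistent flag to WordSwapChangeLocation and WordSwapChangeName so that repeated mentions of the same entity all receive the same replacement within an augmented output. A minimal usage sketch, not part of the diff below, assuming textattack and its flair dependency are installed (outputs vary because replacements are drawn at random):

from textattack.augmentation import Augmenter
from textattack.transformations.word_swaps import (
    WordSwapChangeLocation,
    WordSwapChangeName,
)

# Every mention of "New York" should map to one randomly chosen city.
location_augmenter = Augmenter(transformation=WordSwapChangeLocation(consistent=True))
print(location_augmenter.augment("I am in New York. I love living in New York."))

# Every mention of "Anthony Davis" should map to one randomly chosen name.
name_augmenter = Augmenter(transformation=WordSwapChangeName(consistent=True))
print(name_augmenter.augment("My name is Anthony Davis. Anthony Davis plays basketball."))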
56 changes: 56 additions & 0 deletions tests/test_transformations.py
@@ -33,6 +33,34 @@ def test_word_swap_change_location():
assert entity_original == entity_augmented


def test_word_swap_change_location_consistent():
from flair.data import Sentence
from flair.models import SequenceTagger

from textattack.augmentation import Augmenter
from textattack.transformations.word_swaps import WordSwapChangeLocation

augmenter = Augmenter(transformation=WordSwapChangeLocation(consistent=True))
s = "I am in New York. I love living in New York."
s_augmented = augmenter.augment(s)
augmented_text = Sentence(s_augmented[0])
tagger = SequenceTagger.load("flair/ner-english")
original_text = Sentence(s)
tagger.predict(original_text)
tagger.predict(augmented_text)

entity_original = []
entity_augmented = []

for entity in original_text.get_spans("ner"):
entity_original.append(entity.tag)
for entity in augmented_text.get_spans("ner"):
entity_augmented.append(entity.tag)

assert entity_original == entity_augmented
assert s_augmented[0].count("New York") == 0


def test_word_swap_change_name():
from flair.data import Sentence
from flair.models import SequenceTagger
@@ -59,6 +87,34 @@ def test_word_swap_change_name():
assert entity_original == entity_augmented


def test_word_swap_change_name_consistent():
from flair.data import Sentence
from flair.models import SequenceTagger

from textattack.augmentation import Augmenter
from textattack.transformations.word_swaps import WordSwapChangeName

augmenter = Augmenter(transformation=WordSwapChangeName(consistent=True))
s = "My name is Anthony Davis. Anthony Davis plays basketball."
s_augmented = augmenter.augment(s)
augmented_text = Sentence(s_augmented[0])
tagger = SequenceTagger.load("flair/ner-english")
original_text = Sentence(s)
tagger.predict(original_text)
tagger.predict(augmented_text)

entity_original = []
entity_augmented = []

for entity in original_text.get_spans("ner"):
entity_original.append(entity.tag)
for entity in augmented_text.get_spans("ner"):
entity_augmented.append(entity.tag)

assert entity_original == entity_augmented
assert s_augmented[0].count("Anthony") == 0 or s_augmented[0].count("Davis") == 0


def test_chinese_morphonym_character_swap():
from textattack.augmentation import Augmenter
from textattack.transformations.word_swaps.chn_transformations import (
113 changes: 101 additions & 12 deletions textattack/transformations/word_swaps/word_swap_change_location.py
@@ -2,6 +2,8 @@
Word Swap by Changing Location
-------------------------------
"""
from collections import defaultdict

import more_itertools as mit
import numpy as np

@@ -25,12 +27,15 @@ def idx_to_words(ls, words):


class WordSwapChangeLocation(WordSwap):
def __init__(self, n=3, confidence_score=0.7, language="en", **kwargs):
def __init__(
self, n=3, confidence_score=0.7, language="en", consistent=False, **kwargs
):
"""Transformation that changes recognized locations of a sentence to
another location that is given in the location map.
:param n: Number of new locations to generate
:param confidence_score: Location will only be changed if it's above the confidence score
:param consistent: Whether to change all instances of the same location to the same new location
>>> from textattack.transformations import WordSwapChangeLocation
>>> from textattack.augmentation import Augmenter
@@ -44,6 +49,7 @@ def __init__(self, n=3, confidence_score=0.7, language="en", **kwargs):
self.n = n
self.confidence_score = confidence_score
self.language = language
self.consistent = consistent
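A short construction sketch using the parameters documented above (the values are arbitrary, chosen only for illustration):

from textattack.transformations.word_swaps import WordSwapChangeLocation

# Propose up to 5 replacement locations per swap, only touch locations the
# NER tagger scores above 0.8, and replace repeated mentions consistently.
transformation = WordSwapChangeLocation(n=5, confidence_score=0.8, consistent=True)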

def _get_transformations(self, current_text, indices_to_modify):
words = current_text.words
@@ -64,26 +70,55 @@ def _get_transformations(self, current_text, indices_to_modify):
location_idx = [list(group) for group in mit.consecutive_groups(location_idx)]
location_words = idx_to_words(location_idx, words)

if self.consistent:
location_to_indices = self._build_location_to_indicies_map(
location_words, current_text
)

transformed_texts = []
for location in location_words:
idx = location[0]
word = location[1].capitalize()
word = self._capitalize(location[1])

# If doing consistent replacements, only replace the
# word if it hasn't been replaced in a previous iteration
if self.consistent and word not in location_to_indices:
continue

replacement_words = self._get_new_location(word)
for r in replacement_words:
if r == word:
continue
text = current_text

# if original location is more than a single word, remain only the starting word
if len(idx) > 1:
index = idx[1]
for i in idx[1:]:
text = text.delete_word_at_index(index)

# replace the starting word with new location
text = text.replace_word_at_index(idx[0], r)
if self.consistent:
indices_to_delete = []
if len(idx) > 1:
for i in location_to_indices[word]:
for j in range(1, len(idx)):
indices_to_delete.append(i + j)

transformed_texts.append(
current_text.replace_words_at_indices(
location_to_indices[word] + indices_to_delete,
([r] * len(location_to_indices[word]))
+ ([""] * len(indices_to_delete)),
)
)
else:
# If the original location is more than a single word, keep only the starting word
# and replace the starting word with the new word
indices_to_delete = idx[1:]
transformed_texts.append(
current_text.replace_words_at_indices(
[idx[0]] + indices_to_delete,
[r] + [""] * len(indices_to_delete),
)
)

if self.consistent:
# Delete this word to mark it as replaced
del location_to_indices[word]

transformed_texts.append(text)
return transformed_texts
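For intuition about the consistent branch above, here is an illustrative sketch of the single AttackedText call it builds for a two-mention location; the sentence and indices are made up for the example, and the empty strings delete the leftover words of a multi-word location, mirroring the diff:

from textattack.shared import AttackedText

text = AttackedText("I am in New York. I love living in New York.")
# "New" sits at word indices 3 and 9, "York" at 4 and 10.
# Swap both mentions to the same one-word city and drop the trailing "York"s.
swapped = text.replace_words_at_indices(
    [3, 9, 4, 10], ["Chicago", "Chicago", "", ""]
)
print(swapped.text)  # roughly: "I am in Chicago. I love living in Chicago."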

def _get_new_location(self, word):
@@ -101,3 +136,57 @@ def _get_new_location(self, word):
elif word in NAMED_ENTITIES["city"]:
return np.random.choice(NAMED_ENTITIES["city"], self.n)
return []

def _capitalize(self, string):
"""Capitalizes all words in the string."""
return " ".join(word.capitalize() for word in string.split())

def _build_location_to_indicies_map(self, location_words, text):
"""Returns a map of each location and the starting indicies of all
appearances of that location in the text."""

location_to_indices = defaultdict(list)
if len(location_words) == 0:
return location_to_indices

location_words.sort(
# Sort by the number of words in the location
key=lambda index_location_pair: index_location_pair[0][-1]
- index_location_pair[0][0]
+ 1,
reverse=True,
)
max_length = location_words[0][0][-1] - location_words[0][0][0] + 1

for idx, location in location_words:
words_in_location = idx[-1] - idx[0] + 1
found = False
location_start = idx[0]

# Check each window of n words containing the original tagged location
# for n from the max_length down to the original location length.
# This prevents cases where the NER tagger misses a word in a location
# (e.g. it does not tag "New" in "New York")
for length in range(max_length, words_in_location, -1):
for start in range(
location_start - length + words_in_location,
location_start + 1,
):
if start + length > len(text.words):
break

expanded_location = self._capitalize(
" ".join(text.words[start : start + length])
)
if expanded_location in location_to_indices:
location_to_indices[expanded_location].append(start)
found = True
break

if found:
break

if not found:
location_to_indices[self._capitalize(location)].append(idx[0])

return location_to_indices
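An illustrative sketch of what this helper returns; the [indices, phrase] pairs imitate the output of idx_to_words earlier in the file, and the second pair simulates the tagger missing "New" in the second mention:

from textattack.shared import AttackedText
from textattack.transformations.word_swaps import WordSwapChangeLocation

text = AttackedText("I am in New York. I love living in New York.")
transformation = WordSwapChangeLocation(consistent=True)

# The tagger found "New York" (indices 3-4) in the first sentence but only
# "York" (index 10) in the second; the expanding window groups them anyway.
location_words = [[[3, 4], "New York"], [[10], "York"]]
mapping = transformation._build_location_to_indicies_map(location_words, text)
print(dict(mapping))  # expected: {'New York': [3, 9]}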
31 changes: 30 additions & 1 deletion textattack/transformations/word_swaps/word_swap_change_name.py
@@ -3,6 +3,8 @@
-------------------------------
"""

from collections import defaultdict

import numpy as np

from textattack.shared.data import PERSON_NAMES
@@ -18,6 +20,7 @@ def __init__(
last_only=False,
confidence_score=0.7,
language="en",
consistent=False,
**kwargs
):
"""Transforms an input by replacing names of recognized name entity.
@@ -26,6 +29,7 @@
:param first_only: Whether to change first name only
:param last_only: Whether to change last name only
:param confidence_score: Name will only be changed when it's above confidence score
:param consistent: Whether to change all instances of the same name to the same new name
>>> from textattack.transformations import WordSwapChangeName
>>> from textattack.augmentation import Augmenter
@@ -42,6 +46,7 @@ def __init__(
self.last_only = last_only
self.confidence_score = confidence_score
self.language = language
self.consistent = consistent

def _get_transformations(self, current_text, indices_to_modify):
transformed_texts = []
@@ -52,14 +57,38 @@ def _get_transformations(self, current_text, indices_to_modify):
else:
model_name = "flair/ner-multi-fast"

if self.consistent:
word_to_indices = defaultdict(list)
for i in indices_to_modify:
word_to_replace = current_text.words[i].capitalize()
word_to_indices[word_to_replace].append(i)

for i in indices_to_modify:
word_to_replace = current_text.words[i].capitalize()
# If we're doing consistent replacements, only replace the word
# if it hasn't already been replaced in a previous iteration
if self.consistent and word_to_replace not in word_to_indices:
continue
word_to_replace_ner = current_text.ner_of_word_index(i, model_name)

replacement_words = self._get_replacement_words(
word_to_replace, word_to_replace_ner
)

for r in replacement_words:
transformed_texts.append(current_text.replace_word_at_index(i, r))
if self.consistent:
transformed_texts.append(
current_text.replace_words_at_indices(
word_to_indices[word_to_replace],
[r] * len(word_to_indices[word_to_replace]),
)
)
else:
transformed_texts.append(current_text.replace_word_at_index(i, r))

# Delete this word to mark it as replaced
if self.consistent and len(replacement_words) != 0:
del word_to_indices[word_to_replace]

return transformed_texts
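A sketch of how the new flag composes with an existing option of this class; first_only comes from the constructor shown above, and the combination is illustrative rather than taken from the new tests:

from textattack.augmentation import Augmenter
from textattack.transformations.word_swaps import WordSwapChangeName

# Swap only first names, giving every mention of the same first name the
# same replacement within a given augmented output.
augmenter = Augmenter(
    transformation=WordSwapChangeName(first_only=True, consistent=True)
)
print(augmenter.augment("My name is Anthony Davis. Anthony Davis plays basketball."))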

