Skip to content

Commit

Permalink
Fix word swap for multi-word locations
Browse files Browse the repository at this point in the history
  • Loading branch information
k-ivey committed Nov 2, 2023
1 parent 51a6835 commit ea3fc96
Showing 1 changed file with 64 additions and 11 deletions.
75 changes: 64 additions & 11 deletions textattack/transformations/word_swaps/word_swap_change_location.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,25 +71,26 @@ def _get_transformations(self, current_text, indices_to_modify):
location_words = idx_to_words(location_idx, words)

if self.consistent:
location_to_indices = defaultdict(list)
for idx, location in location_words:
location_to_indices[self._capitalize(location)].append(idx[0])
location_to_indices = self._build_location_to_indicies_map(
location_words, current_text
)

transformed_texts = []
for location in location_words:
idx = location[0]
word = self._capitalize(location[1])

# If doing consistent replacements, only replace the
# word if it hasn't been replaced in a previous iteration
if self.consistent and word not in location_to_indices:
continue

replacement_words = self._get_new_location(word)
for r in replacement_words:
if r == word:
continue

if self.consistent:
# If we're doing consistent replacements, only replace the word
# if it hasn't already been replaced in a previous iteration
if word not in location_to_indices:
continue

indices_to_delete = []
if len(idx) > 1:
for i in location_to_indices[word]:
Expand All @@ -103,9 +104,6 @@ def _get_transformations(self, current_text, indices_to_modify):
+ ([""] * len(indices_to_delete)),
)
)

# Delete this word to mark it as replaced
del location_to_indices[word]
else:
# If the original location is more than a single word, keep only the starting word
# and replace the starting word with the new word
Expand All @@ -117,6 +115,10 @@ def _get_transformations(self, current_text, indices_to_modify):
)
)

if self.consistent:
# Delete this word to mark it as replaced
del location_to_indices[word]

return transformed_texts

def _get_new_location(self, word):
Expand All @@ -138,3 +140,54 @@ def _get_new_location(self, word):
def _capitalize(self, string):
"""Capitalizes all words in the string."""
return " ".join(word.capitalize() for word in string.split())

def _build_location_to_indicies_map(self, location_words, text):
"""Returns a map of each location and the starting indicies of all
appearances of that location in the text."""

location_to_indices = defaultdict(list)
if len(location_words) == 0:
return location_to_indices

location_words.sort(
# Sort by the number of words in the location
key=lambda index_location_pair: index_location_pair[0][-1]
- index_location_pair[0][0]
+ 1,
reverse=True,
)
max_length = location_words[0][0][-1] - location_words[0][0][0] + 1

for idx, location in location_words:

words_in_location = idx[-1] - idx[0] + 1
found = False
location_start = idx[0]

# Check each window of n words containing the original tagged location
# for n from the max_length down to the original location length.
# This prevents cases where the NER tagger misses a word in a location
# (e.g. it does not tag "New" in "New York")
for length in range(max_length, words_in_location, -1):
for start in range(
location_start - length + words_in_location,
location_start + 1,
):
if start + length > len(text.words):
break

expanded_location = self._capitalize(
" ".join(text.words[start : start + length])
)
if expanded_location in location_to_indices:
location_to_indices[expanded_location].append(start)
found = True
break

if found:
break

if not found:
location_to_indices[self._capitalize(location)].append(idx[0])

return location_to_indices

0 comments on commit ea3fc96

Please sign in to comment.