Skip to content

Commit

Permalink
Merge pull request #506 from QData/augment-test
Browse files Browse the repository at this point in the history
add more tests for augment function
  • Loading branch information
qiyanjun authored Aug 2, 2021
2 parents c30728b + e98c801 commit fd7117d
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
22 changes: 22 additions & 0 deletions tests/test_misc.py → tests/test_augment_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ def test_easydata_augmenter():
assert augmented_s in augmented_text_list


def test_easydata_augmenter2():
from textattack.augmentation import EasyDataAugmenter

augmenter = EasyDataAugmenter(
pct_words_to_swap=0.01, transformations_per_example=64
)
s = "hello hello hello derek"
augmented_text_list = augmenter.augment(s)
augmented_s = "derek hello hello hello"
assert augmented_s in augmented_text_list


def test_wordnet_augmenter():
from textattack.augmentation import WordNetAugmenter

Expand All @@ -69,3 +81,13 @@ def test_wordnet_augmenter():
augmented_text_list = augmenter.augment(s)
augmented_s = "The firedrake warrior is a panda"
assert augmented_s in augmented_text_list


def test_deletion_augmenter():
from textattack.augmentation import DeletionAugmenter

augmenter = DeletionAugmenter(pct_words_to_swap=0.1, transformations_per_example=10)
s = "The United States"
augmented_text_list = augmenter.augment(s)
augmented_s = "United States"
assert augmented_s in augmented_text_list
7 changes: 7 additions & 0 deletions textattack/constraints/semantics/word_embedding_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ def _check_constraint(self, transformed_text, reference_text):
"Cannot apply part-of-speech constraint without `newly_modified_indices`"
)

# FIXME The index i is sometimes larger than the number of tokens - 1
if any(
i >= len(reference_text.words) or i >= len(transformed_text.words)
for i in indices
):
return False

for i in indices:
ref_word = reference_text.words[i]
transformed_word = transformed_text.words[i]
Expand Down

0 comments on commit fd7117d

Please sign in to comment.