Skip to content

Commit

Permalink
fixed the shuffle bug QData#791.
Browse files Browse the repository at this point in the history
Added random seed for reproducibility.
  • Loading branch information
ToldoDM committed May 14, 2024
1 parent ad88963 commit a8e0c7b
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions textattack/datasets/huggingface_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class HuggingFaceDataset(Dataset):
Factor to divide ground-truth outputs by. Generally, TextAttack goal functions require model outputs between 0 and 1.
Some datasets are regression tasks, in which case this is necessary.
shuffle (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether to shuffle the underlying dataset.
random_seed (:obj:`int`, `optional`, defaults to :obj:`123`): Random seed for reproducibility. Used for shuffling.
.. note::
Generally not recommended to shuffle the underlying dataset. Shuffling can be performed using DataLoader or by shuffling the order of indices we attack.
Expand All @@ -108,6 +109,7 @@ def __init__(
label_names=None,
output_scale_factor=None,
shuffle=False,
random_seed=123,
):
if isinstance(name_or_dataset, datasets.Dataset):
self._dataset = name_or_dataset
Expand Down Expand Up @@ -149,7 +151,7 @@ def __init__(

self.shuffled = shuffle
if shuffle:
self._dataset.shuffle()
self.shuffle(random_seed=random_seed)

def _format_as_dict(self, example):
input_dict = collections.OrderedDict(
Expand Down Expand Up @@ -189,6 +191,6 @@ def __getitem__(self, i):
self._format_as_dict(self._dataset[j]) for j in range(i.start, i.stop)
]

def shuffle(self):
self._dataset.shuffle()
def shuffle(self, random_seed=123):
self._dataset = self._dataset.shuffle(seed=random_seed).flatten_indices()
self.shuffled = True

0 comments on commit a8e0c7b

Please sign in to comment.