Add new sampler: weighted sampler (#1344)
* add file

* add a weighted data source to enable sampling based on per-sample weight; do not allow duplicated samples within the same epoch

* add a weighted sampler; do not allow lazy mode; do not allow duplicated cuts in the same batch

* modify init file accordingly

* add more documentation

* use numpy for sampling; pre-compute the indexes in __iter__ to save time

* add more documentation

* minor changes to the arguments

* remove unused file

* add test

* add more docs

* fix isort

* inherit from SimpleCutSampler; remove duplicated code

* minor fix

* Add changes requested in code review

---------

Co-authored-by: Piotr Żelasko <[email protected]>
marcoyang1998 and pzelasko authored Jun 5, 2024
1 parent cf6cde8 commit 4d57d53
Showing 4 changed files with 274 additions and 1 deletion.
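
The core of this change is pre-computing each epoch's sample order by drawing cut indexes from a multinomial distribution without replacement, so no cut repeats within an epoch. A minimal standalone sketch of that idea (the helper name make_epoch_indices and the example weights are illustrative, not part of the PR):

import numpy as np

def make_epoch_indices(weights, num_samples, seed=0):
    # Normalize the raw per-cut weights into a probability distribution.
    weights = np.asarray(weights, dtype=float)
    probs = weights / weights.sum()
    rng = np.random.default_rng(seed)
    # replace=False guarantees every drawn index (and hence cut) is unique within the epoch.
    return rng.choice(len(probs), size=num_samples, p=probs, replace=False)

indices = make_epoch_indices([0.1, 3.0, 1.5, 0.4], num_samples=3)
print(indices)  # e.g. [1 2 0] -- higher-weighted items are more likely to be drawn
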
2 changes: 2 additions & 0 deletions lhotse/dataset/sampling/__init__.py
@@ -12,6 +12,7 @@
from .simple import SimpleCutSampler
from .stateless import StatelessSampler
from .utils import find_pessimistic_batches, report_padding_ratio_estimate
from .weighted_simple import WeightedSimpleCutSampler
from .zip import ZipSampler

__all__ = [
@@ -25,6 +26,7 @@
"DynamicBucketingSampler",
"RoundRobinSampler",
"SimpleCutSampler",
"WeightedSimpleCutSampler",
"StatelessSampler",
"ZipSampler",
"find_pessimistic_batches",
77 changes: 76 additions & 1 deletion lhotse/dataset/sampling/data_source.py
@@ -1,6 +1,8 @@
import random
from collections import deque
from typing import Optional
from typing import List, Optional

import numpy as np

from lhotse import CutSet
from lhotse.cut import Cut
@@ -98,3 +100,76 @@ def __next__(self) -> Cut:

def __len__(self) -> int:
return len(self._shuffled_items)


class WeightedDataSource(DataSource):
"""
An iterator wrapper over CutSet that helps with the sampling process:
it allows for deterministic re-shuffling of elements and "returning"
sampled elements to be yielded again.
Every cut has a sampling weight. At the beginning of each epoch, we
pre-compute the indexes by sampling from a multinomial distribution without
replacement. The data source is exhausted once the number of drawn cuts
exceeds ``num_samples``.
"""

def __init__(self, items: CutSet, weights: List, num_samples: int):
"""The constructor of the weighted data source
Args:
items (CutSet): The cutset itself
weights (List): A list of values representing the weight of each cut. All values must be positive
num_samples (int): The number of samples to be drawn. Must be smaller than the total number of cuts
"""
super().__init__(items=items)
assert len(items) == len(weights), "The length should match"
assert num_samples < len(
weights
), "The number of samples to be drawn should not exceed the dataset size"

# normalize the weight
weights = np.array(weights)
weights = weights / weights.sum()

self.weights = weights
self.num_samples = num_samples
self.sampled_indexes = None

def reset(self) -> None:
"""Reset the iterable state of DataSource."""
self._iter = None
self.sampled_indexes = None
self._reusable.clear()
self._remaining_duration = self._total_duration
self.remaining_cuts = self._total_cuts

def fast_forward(self, steps: int) -> None:
"""Advance the data source by ``steps`` amount of steps."""
assert steps >= 0
iter(self)
for i in range(steps):
next(self.sampled_indexes)

def __iter__(self) -> "WeightedDataSource":
self.reset()
self._iter = iter(self._shuffled_items)
self.sampled_indexes = np.random.choice(
len(self.weights),
self.num_samples,
p=self.weights,
replace=False,
)
self.sampled_indexes = iter(self.sampled_indexes)
return self

def __next__(self) -> Cut:
if self._reusable:
next_cut = self._reusable.popleft()
else:
next_cut = self._orig_items[next(self.sampled_indexes)]

if not self.is_lazy:
self._remaining_duration -= next_cut.duration
self.remaining_cuts -= 1
return next_cut
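
Taken together, iterating a WeightedDataSource means: __iter__ resets the state and pre-computes the sampled indexes, and __next__ maps each index back to a cut while updating the remaining-duration bookkeeping; StopIteration propagates once num_samples indexes have been consumed. A hedged usage sketch (the manifest path and the duration-based weights are illustrative assumptions; the CutSet must be eager and num_samples must stay below the number of cuts):

from lhotse import CutSet
from lhotse.dataset.sampling.data_source import WeightedDataSource

cuts = CutSet.from_file("cuts.jsonl.gz")     # assumed eager (non-lazy) manifest
weights = [cut.duration for cut in cuts]     # e.g. weight each cut by its duration

source = WeightedDataSource(cuts, weights=weights, num_samples=1000)
for cut in source:                           # yields at most 1000 distinct cuts per epoch
    ...                                      # hand the cut to downstream batching logic
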
147 changes: 147 additions & 0 deletions lhotse/dataset/sampling/weighted_simple.py
@@ -0,0 +1,147 @@
import warnings
from typing import Any, Dict, List, Optional

from lhotse import CutSet, Seconds
from lhotse.dataset.sampling.base import TimeConstraint
from lhotse.dataset.sampling.data_source import WeightedDataSource
from lhotse.dataset.sampling.simple import SimpleCutSampler


class WeightedSimpleCutSampler(SimpleCutSampler):
"""
Samples cuts from a CutSet, where the sampling probability of each cut is given by a list of weights.
To enable global sampling, cuts must be in eager mode.
When performing sampling, it avoids having duplicated cuts in the same batch.
The sampler terminates once the number of sampled cuts reaches :attr:`num_samples`.
When one of :attr:`max_frames`, :attr:`max_samples`, or :attr:`max_duration` is specified,
the batch size is dynamic.
Example usage:
>>> dataset = K2SpeechRecognitionDataset(cuts)
>>> weights = get_weights(cuts)
>>> sampler = WeightedSimpleCutSampler(cuts, weights, num_samples=100, max_duration=200.0)
>>> loader = DataLoader(dataset, sampler=sampler, batch_size=None)
>>> for epoch in range(start_epoch, n_epochs):
... sampler.set_epoch(epoch)
... train(loader)
"""

def __init__(
self,
cuts: CutSet,
cuts_weight: List,
num_samples: int,
max_duration: Seconds = None,
max_cuts: Optional[int] = None,
shuffle: bool = False,
drop_last: bool = False,
world_size: Optional[int] = None,
rank: Optional[int] = None,
seed: int = 0,
):
"""
WeightedSimpleCutSampler's constructor
:param cuts: the ``CutSet`` to sample data from.
:param cuts_weight: the weight of each cut for sampling.
:param num_samples: the number of samples to be drawn.
:param max_duration: The maximum total recording duration from ``cuts``.
:param max_cuts: The maximum number of cuts sampled to form a mini-batch.
By default, this constraint is off.
:param shuffle: When ``True``, the cuts will be shuffled at the start of iteration.
Convenient when mini-batch loop is inside an outer epoch-level loop, e.g.:
`for epoch in range(10): for batch in dataset: ...` as every epoch will see a
different cuts order.
:param drop_last: When ``True``, the last batch is dropped if it's incomplete.
:param world_size: Total number of distributed nodes. We will try to infer it by default.
:param rank: Index of distributed node. We will try to infer it by default.
:param seed: Random seed used to consistently shuffle the dataset across different processes.
"""
super().__init__(
cuts=cuts,
drop_last=drop_last,
shuffle=shuffle,
world_size=world_size,
rank=rank,
max_duration=max_duration,
max_cuts=max_cuts,
seed=seed,
)
assert not cuts.is_lazy, "This sampler does not support lazy mode!"
self.data_source = WeightedDataSource(
cuts, weights=cuts_weight, num_samples=num_samples
)

self.weights = cuts_weight
self.num_samples = num_samples

def state_dict(self) -> Dict[str, Any]:
"""
Return the current state of the sampler in a state_dict.
Together with ``load_state_dict()``, this can be used to restore the
training loop's state to the one stored in the state_dict.
"""
state_dict = super().state_dict()
state_dict.update(
{
"time_constraint": self.time_constraint.state_dict(),
"weights": self.weights,
"num_samples": self.num_samples,
}
)
return state_dict

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""
Restore the state of the sampler that is described in a state_dict.
This will result in the sampler yielding batches from where the previous training left it off.
.. caution::
The samplers are expected to be initialized with the same CutSets,
but this is not explicitly checked anywhere.
.. caution::
The input ``state_dict`` is being mutated: we remove each consumed key, and expect
it to be empty at the end of loading. If you don't want this behavior, pass a copy
into this function (e.g., one created with ``copy.deepcopy``).
.. note::
For implementers of sub-classes of CutSampler: the flag ``self._just_restored_state`` has to be
handled in ``__iter__`` to make it avoid resetting the just-restored state (only once).
"""
time_constraint = TimeConstraint(**state_dict.pop("time_constraint"))
if self.time_constraint != time_constraint:
warnings.warn(
"SimpleCutSampler.load_state_dict(): Inconsistent time_constraint:\n"
f"expected {self.time_constraint}\n"
f"received {time_constraint}\n"
f"We will overwrite the settings with the received state_dict."
)
self.time_constraint = time_constraint

super().load_state_dict(state_dict)

# Restore the data source's state
self.data_source.fast_forward(self.diagnostics.current_epoch_stats.total_cuts)

self.weights = state_dict.pop("weights")
self.num_samples = state_dict.pop("num_samples")

def __iter__(self) -> "WeightedSimpleCutSampler":
"""
Prepare the dataset for iterating over a new epoch. Will shuffle the data if requested.
"""
# Restored state with load_state_dict()? Skip resetting only this once.
if self._just_restored_state:
return self
# Why reset the current epoch?
# Either we are iterating the epoch for the first time and it's a no-op,
# or we are iterating the same epoch again, in which case setting more steps
# than are actually available per epoch would have broken the checkpoint restoration.
self.diagnostics.reset_current_epoch()
# Reset the state to the beginning of the epoch.
iter(self.data_source)
return self
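
Because the sampler extends SimpleCutSampler's state_dict()/load_state_dict() with the weights and num_samples, a training run can be checkpointed and resumed mid-epoch; on restore, the data source is fast-forwarded by the number of cuts already consumed in the current epoch. A rough sketch of that resume path (the checkpoint file name and the surrounding loop are illustrative, not prescribed by the PR):

import torch

# While training: periodically persist the sampler state alongside the model state.
torch.save({"sampler": sampler.state_dict()}, "checkpoint.pt")

# When resuming: rebuild the sampler with the same cuts/weights, then restore it.
sampler = WeightedSimpleCutSampler(cuts, weights, num_samples=1000, max_duration=200.0)
state = torch.load("checkpoint.pt")
sampler.load_state_dict(state["sampler"])    # fast-forwards to where the epoch left off
for batch in sampler:                        # __iter__ skips the reset once after restore
    ...
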
49 changes: 49 additions & 0 deletions test/dataset/sampling/test_sampling.py
@@ -25,6 +25,7 @@
BucketingSampler,
CutPairsSampler,
SimpleCutSampler,
WeightedSimpleCutSampler,
ZipSampler,
)
from lhotse.dataset.sampling.base import SamplingDiagnostics, TimeConstraint
@@ -1024,6 +1025,54 @@ def test_cut_pairs_sampler_lazy_shuffle(sampler_cls):
assert [c.id for c in sampled_src_cuts] != [c.id for c in lazy_cuts]


def test_weighted_sampler_num_samples():
cut_set = DummyManifest(CutSet, begin_id=0, end_id=100)
weight = [random.random() for i in range(100)]
num_samples = 32

sampler = WeightedSimpleCutSampler(
cut_set,
weight,
num_samples=num_samples,
max_duration=10.0,
drop_last=True,
)

sampled_cuts = []
num_cuts = 0
for batch in sampler:
sampled_cuts.extend(batch)
num_cuts += len(batch)

assert num_cuts <= num_samples


def test_weighted_sampler_across_epochs():
cut_set = DummyManifest(CutSet, begin_id=0, end_id=100)
weight = [random.random() for i in range(100)]
num_samples = 32

sampler = WeightedSimpleCutSampler(
cut_set,
weight,
num_samples=num_samples,
max_duration=10.0,
drop_last=True,
)

# 1st epoch
sampler.set_epoch(1)
batch = next(iter(sampler))
cut_ids1 = [c.id for c in batch]

# 2nd epoch
sampler.set_epoch(2)
batch = next(iter(sampler))
cut_ids2 = [c.id for c in batch]

assert set(cut_ids1) != set(cut_ids2)


@pytest.mark.parametrize("datasize", [10, 1000, 20000])
@pytest.mark.parametrize("bufsize", [100, 1000, 10000])
def test_streaming_shuffle(datasize, bufsize):
