
Commit

Allow duplicate cut IDs in a CutSet (CutSet is list-like instead of dict-like) (#1279)

* Allow duplicate cut IDs in a CutSet (CutSet is list-like instead of dict-like)

* Remove BaseIterable altogether (was renamed from ImitatesDict and is no longer needed)

* Clean up the duplicate checking fn
pzelasko authored Jan 31, 2024
1 parent e043228 commit 455b20e
Showing 11 changed files with 150 additions and 168 deletions.
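
The headline change: a `CutSet` (and the other manifest types) now keeps its items in a list rather than an ID-keyed dict, so duplicate IDs are no longer rejected. A minimal sketch of the new behavior, assuming `cut_a` and `cut_b` are pre-existing `Cut` objects (placeholders, not part of this diff):

```python
from lhotse import CutSet

# List-like semantics: duplicates are kept instead of raising at construction time.
cuts = CutSet([cut_a, cut_b, cut_a])
assert len(cuts) == 3            # every element is retained
assert cuts[0].id == cuts[2].id  # the same cut may appear more than once
```

Previously, construction went through the `index_by_id_and_check()` helper, which rejected duplicate IDs; the constructors below no longer use it.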
50 changes: 25 additions & 25 deletions lhotse/audio/recording_set.py
@@ -18,15 +18,14 @@
Seconds,
exactly_one_not_null,
ifnone,
index_by_id_and_check,
split_manifest_lazy,
split_sequence,
)


class RecordingSet(Serializable, AlgorithmMixin):
"""
:class:`~lhotse.audio.RecordingSet` represents a collection of recordings, indexed by recording IDs.
:class:`~lhotse.audio.RecordingSet` represents a collection of recordings.
It does not contain any annotation such as the transcript or the speaker identity --
just the information needed to retrieve a recording such as its path, URL, number of channels,
and some recording metadata (duration, number of samples).
@@ -86,7 +85,7 @@ class RecordingSet(Serializable, AlgorithmMixin):
>>> recs_24k = recs.resample(24000)
"""

def __init__(self, recordings: Optional[Mapping[str, Recording]] = None) -> None:
def __init__(self, recordings: Optional[Iterable[Recording]] = None) -> None:
self.recordings = ifnone(recordings, {})

def __eq__(self, other: "RecordingSet") -> bool:
@@ -99,11 +98,11 @@ def data(self) -> Union[Dict[str, Recording], Iterable[Recording]]:

@property
def ids(self) -> Iterable[str]:
return self.recordings.keys()
return (r.id for r in self)

@staticmethod
def from_recordings(recordings: Iterable[Recording]) -> "RecordingSet":
return RecordingSet(recordings=index_by_id_and_check(recordings))
return RecordingSet(list(recordings))

from_items = from_recordings

@@ -254,24 +253,24 @@ def load_audio(
offset_seconds: float = 0.0,
duration_seconds: Optional[float] = None,
) -> np.ndarray:
return self.recordings[recording_id].load_audio(
return self[recording_id].load_audio(
channels=channels, offset=offset_seconds, duration=duration_seconds
)

def with_path_prefix(self, path: Pathlike) -> "RecordingSet":
return RecordingSet.from_recordings(r.with_path_prefix(path) for r in self)

def num_channels(self, recording_id: str) -> int:
return self.recordings[recording_id].num_channels
return self[recording_id].num_channels

def sampling_rate(self, recording_id: str) -> int:
return self.recordings[recording_id].sampling_rate
return self[recording_id].sampling_rate

def num_samples(self, recording_id: str) -> int:
return self.recordings[recording_id].num_samples
return self[recording_id].num_samples

def duration(self, recording_id: str) -> Seconds:
return self.recordings[recording_id].duration
return self[recording_id].duration

def perturb_speed(self, factor: float, affix_id: bool = True) -> "RecordingSet":
"""
@@ -368,24 +367,25 @@ def resample(self, sampling_rate: int) -> "RecordingSet":
def __repr__(self) -> str:
return f"RecordingSet(len={len(self)})"

def __contains__(self, item: Union[str, Recording]) -> bool:
if isinstance(item, str):
return item in self.recordings
def __getitem__(self, index_or_id: Union[int, str]) -> Recording:
try:
return self.recordings[index_or_id] # int passed, eager manifest, fast
except TypeError:
# either lazy manifest or str passed, both are slow
if self.is_lazy:
return next(item for idx, item in enumerate(self) if idx == index_or_id)
else:
# string id passed, support just for backward compatibility, not recommended
return next(item for item in self if item.id == index_or_id)

def __contains__(self, other: Union[str, Recording]) -> bool:
if isinstance(other, str):
return any(other == item.id for item in self)
else:
return item.id in self.recordings

def __getitem__(self, recording_id_or_index: Union[int, str]) -> Recording:
if isinstance(recording_id_or_index, str):
return self.recordings[recording_id_or_index]
# ~100x faster than list(dict.values())[index] for 100k elements
return next(
val
for idx, val in enumerate(self.recordings.values())
if idx == recording_id_or_index
)
return any(other.id == item.id for item in self)

def __iter__(self) -> Iterable[Recording]:
return iter(self.recordings.values())
yield from self.recordings

def __len__(self) -> int:
return len(self.recordings)
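
For reference, a short sketch of the list-like `RecordingSet` access paths after this change; `rec1` and `rec2` stand in for real `Recording` objects and are not part of the diff:

```python
from lhotse import RecordingSet

recs = RecordingSet.from_recordings([rec1, rec2])  # recordings are now stored as a list
r_by_pos = recs[0]       # positional indexing: fast on an eager manifest
r_by_id = recs[rec1.id]  # ID lookup still works, but is now a linear scan
assert rec1.id in recs   # __contains__ also scans and compares IDs
```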
103 changes: 56 additions & 47 deletions lhotse/cut/set.py
@@ -17,7 +17,6 @@
Iterable,
List,
Literal,
Mapping,
Optional,
Sequence,
Set,
@@ -45,7 +44,7 @@
from lhotse.features.io import FeaturesWriter, LilcomChunkyWriter
from lhotse.lazy import (
AlgorithmMixin,
ImitatesDict,
Dillable,
LazyFlattener,
LazyIteratorChain,
LazyManifestIterator,
@@ -62,10 +61,10 @@
Seconds,
compute_num_frames,
compute_num_samples,
deprecated,
exactly_one_not_null,
fastcopy,
ifnone,
index_by_id_and_check,
split_manifest_lazy,
split_sequence,
uuid4,
@@ -76,10 +75,15 @@

class CutSet(Serializable, AlgorithmMixin):
"""
:class:`~lhotse.cut.CutSet` represents a collection of cuts, indexed by cut IDs.
:class:`~lhotse.cut.CutSet` represents a collection of cuts.
CutSet ties together all types of data -- audio, features and supervisions, and is suitable to represent
training/dev/test sets.
CutSet can be either "lazy" (acts as an iterable) which is best for representing full datasets,
or "eager" (acts as a list), which is best for representing individual mini-batches (and sometimes test/dev datasets).
Almost all operations are available for both modes, but some of them are more efficient depending on the mode
(e.g. indexing an "eager" manifest is O(1)).
.. note::
:class:`~lhotse.cut.CutSet` is the basic building block of PyTorch-style Datasets for speech/audio processing tasks.
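
The eager/lazy distinction described in the docstring above can be sketched as follows; the manifest path is illustrative, and `from_jsonl_lazy()` is the existing lazy-loading helper assumed here (`to_eager()` appears later in this diff):

```python
from lhotse import CutSet

lazy_cuts = CutSet.from_jsonl_lazy("cuts.jsonl.gz")  # lazy: streamed from disk, acts as an iterable
eager_cuts = lazy_cuts.to_eager()                    # eager: held in memory, acts as a list
first = eager_cuts[0]                                # O(1) positional indexing on the eager set
```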
@@ -242,34 +246,32 @@ class CutSet(Serializable, AlgorithmMixin):
- :class:`~lhotse.cut.Cut`
"""

def __init__(
self, cuts: Optional[Union[Mapping[str, Cut], ImitatesDict]] = None
) -> None:
self.cuts = ifnone(cuts, {})
def __init__(self, cuts: Optional[Iterable[Cut]] = None) -> None:
self.cuts = ifnone(cuts, [])

def __eq__(self, other: "CutSet") -> bool:
return self.cuts == other.cuts

@property
def data(self) -> Union[Dict[str, Cut], Iterable[Cut]]:
def data(self) -> Iterable[Cut]:
"""Alias property for ``self.cuts``"""
return self.cuts

@property
def mixed_cuts(self) -> Dict[str, MixedCut]:
return {id_: cut for id_, cut in self.cuts.items() if isinstance(cut, MixedCut)}
def mixed_cuts(self) -> "CutSet":
return CutSet.from_cuts(cut for cut in self.cuts if isinstance(cut, MixedCut))

@property
def simple_cuts(self) -> Dict[str, MonoCut]:
return {id_: cut for id_, cut in self.cuts.items() if isinstance(cut, MonoCut)}
def simple_cuts(self) -> "CutSet":
return CutSet.from_cuts(cut for cut in self.cuts if isinstance(cut, MonoCut))

@property
def multi_cuts(self) -> Dict[str, MultiCut]:
return {id_: cut for id_, cut in self.cuts.items() if isinstance(cut, MultiCut)}
def multi_cuts(self) -> "CutSet":
return CutSet.from_cuts(cut for cut in self.cuts if isinstance(cut, MultiCut))

@property
def ids(self) -> Iterable[str]:
return self.cuts.keys()
return (c.id for c in self.cuts)

@property
def speakers(self) -> FrozenSet[str]:
@@ -307,7 +309,8 @@ def from_files(

@staticmethod
def from_cuts(cuts: Iterable[Cut]) -> "CutSet":
return CutSet(cuts=index_by_id_and_check(cuts))
"""Left for backward compatibility, where it implicitly created an "eager" CutSet."""
return CutSet(list(cuts))

from_items = from_cuts

@@ -827,7 +830,7 @@ def split(
:return: A list of :class:`~lhotse.CutSet` pieces.
"""
return [
CutSet.from_cuts(subset)
CutSet(subset)
for subset in split_sequence(
self,
num_splits=num_splits,
@@ -925,14 +928,14 @@ def subset(
cut_ids = list(cut_ids) # Remember the original order
id_set = frozenset(cut_ids) # Make a set for quick lookup
# Iteration makes it possible to subset lazy manifests
cuts = CutSet.from_cuts(cut for cut in self if cut.id in id_set)
cuts = CutSet([cut for cut in self if cut.id in id_set])
if len(cuts) < len(cut_ids):
logging.warning(
f"In CutSet.subset(cut_ids=...): expected {len(cut_ids)} cuts but got {len(cuts)} "
f"instead ({len(cut_ids) - len(cuts)} cut IDs were not in the CutSet)."
)
# Restore the requested cut_ids order.
return CutSet.from_cuts(cuts[cid] for cid in cut_ids)
return cuts.sort_like(cut_ids)

def filter_supervisions(
self, predicate: Callable[[SupervisionSegment], bool]
@@ -1142,7 +1145,7 @@ def trim_to_unsupervised_segments(self) -> "CutSet":
)
for span in segments:
cuts.append(cut.truncate(offset=span.start, duration=span.duration))
return CutSet.from_cuts(cuts)
return CutSet(cuts)

def trim_to_supervision_groups(
self,
@@ -1245,26 +1248,31 @@ def sort_by_recording_id(self, ascending: bool = True) -> "CutSet":
This is advantageous before calling `save_audios()` on a `trim_to_supervision()`
processed `CutSet`; also make sure that `set_caching_enabled(True)` was called.
"""
return CutSet.from_cuts(
return CutSet(
sorted(self, key=(lambda cut: cut.recording.id), reverse=not ascending)
)

def sort_by_duration(self, ascending: bool = False) -> "CutSet":
"""
Sort the CutSet according to cuts duration and return the result. Descending by default.
"""
return CutSet.from_cuts(
return CutSet(
sorted(self, key=(lambda cut: cut.duration), reverse=not ascending)
)

def sort_like(self, other: "CutSet") -> "CutSet":
def sort_like(self, other: Union["CutSet", Sequence[str]]) -> "CutSet":
"""
Sort the CutSet according to the order of cut IDs in ``other`` and return the result.
"""
other_ids = list(other.ids if isinstance(other, CutSet) else other)
assert set(self.ids) == set(
other.ids
other_ids
), "sort_like() expects both CutSet's to have identical cut IDs."
return CutSet.from_cuts(self[cid] for cid in other.ids)
index_map: Dict[str, int] = {v: index for index, v in enumerate(other_ids)}
ans: List[Cut] = [None] * len(other_ids)
for cut in self:
ans[index_map[cut.id]] = cut
return CutSet(ans)

def index_supervisions(
self, index_mixed_tracks: bool = False, keep_ids: Optional[Set[str]] = None
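
A hedged example of the extended `sort_like()` above, which now accepts either another `CutSet` or a plain sequence of cut IDs; the IDs here are placeholders for the sketch:

```python
# Assumes `cuts` is an eager CutSet containing exactly the cuts with IDs "a", "b", "c".
reordered = cuts.sort_like(["c", "a", "b"])
assert list(reordered.ids) == ["c", "a", "b"]
```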
@@ -1397,7 +1405,7 @@ def compute_offset():
preserve_id=preserve_id,
)
)
return CutSet.from_cuts(truncated_cuts)
return CutSet(truncated_cuts)

def extend_by(
self,
@@ -1513,13 +1521,11 @@ def sample(self, n_cuts: int = 1) -> Union[Cut, "CutSet"]:
When ``n_cuts`` is 1, will return a single cut instance; otherwise will return a ``CutSet``.
"""
assert n_cuts > 0
# TODO: We might want to make this more efficient in the future
# by holding a cached list of cut ids as a member of CutSet...
cut_indices = random.sample(range(len(self)), min(n_cuts, len(self)))
cuts = [self[idx] for idx in cut_indices]
if n_cuts == 1:
return cuts[0]
return CutSet.from_cuts(cuts)
return CutSet(cuts)

def resample(self, sampling_rate: int, affix_id: bool = False) -> "CutSet":
"""
@@ -2194,7 +2200,7 @@ def file_storage_path(cut: Cut, storage_path: Pathlike) -> Path:
progress = partial(
tqdm, desc="Storing audio recordings", total=len(self)
)
return CutSet.from_cuts(
return CutSet(
progress(
cut.save_audio(
storage_path=file_storage_path(cut, storage_path),
Expand All @@ -2204,7 +2210,7 @@ def file_storage_path(cut: Cut, storage_path: Pathlike) -> Path:
)
for cut in self
)
)
).to_eager()

# Parallel execution: prepare the CutSet splits
cut_sets = self.split(num_jobs, shuffle=shuffle_on_split)
@@ -2495,25 +2501,28 @@ def __repr__(self) -> str:
len_val = "<unknown>"
return f"CutSet(len={len_val}) [underlying data type: {type(self.data)}]"

def __contains__(self, item: Union[str, Cut]) -> bool:
if isinstance(item, str):
return item in self.cuts
def __contains__(self, other: Union[str, Cut]) -> bool:
if isinstance(other, str):
return any(other == item.id for item in self)
else:
return item.id in self.cuts

def __getitem__(self, cut_id_or_index: Union[int, str]) -> "Cut":
if isinstance(cut_id_or_index, str):
return self.cuts[cut_id_or_index]
# ~100x faster than list(dict.values())[index] for 100k elements
return next(
val for idx, val in enumerate(self.cuts.values()) if idx == cut_id_or_index
)
return any(other.id == item.id for item in self)

def __getitem__(self, index_or_id: Union[int, str]) -> Cut:
try:
return self.cuts[index_or_id] # int passed, eager manifest, fast
except TypeError:
# either lazy manifest or str passed, both are slow
if self.is_lazy:
return next(item for idx, item in enumerate(self) if idx == index_or_id)
else:
# string id passed, support just for backward compatibility, not recommended
return next(item for item in self if item.id == index_or_id)

def __len__(self) -> int:
return len(self.cuts)

def __iter__(self) -> Iterable[Cut]:
return iter(self.cuts.values())
yield from self.cuts


def mix(
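
The new `__getitem__`/`__contains__` shown above fall back from positional access to scanning; a rough usage sketch, with `some_id` as a placeholder cut ID:

```python
cut = cuts[10]            # int index: direct list access on an eager CutSet
cut = cuts[some_id]       # str ID: linear scan, kept only for backward compatibility
has_it = some_id in cuts  # __contains__ iterates and compares IDs
```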
@@ -2993,7 +3002,7 @@ def create_cut_set_eager(
else [],
)
)
cuts = CutSet.from_cuts(cuts)
cuts = CutSet(cuts)
if output_path is not None:
cuts.to_file(output_path)
return cuts
@@ -3391,7 +3400,7 @@ def _export_to_shar_single(
return writer.output_paths


class LazyCutMixer(ImitatesDict):
class LazyCutMixer(Dillable):
"""
Iterate over cuts from ``cuts`` CutSet while mixing randomly sampled ``mix_in_cuts`` into them.
A typical application would be data augmentation with noise, music, babble, etc.
