Support for Video Features, for example How2Sign #1359
Hi @kerolos, thanks for opening this discussion! I can help you get your video recipe set up.
There is one AV recipe currently for GRID AV corpus: https://github.com/lhotse-speech/lhotse/blob/master/lhotse/recipes/grid.py
Once you create a recording, you can load the video, process it with some module, and save + attach it as a custom field to the cut. For example:

video_recording = Recording.from_file("/path/to/-fZc293MpJk_0-1-rgb_front.mp4")  # lhotse will auto-construct the video recording manifest
video_cut = video_recording.to_cut()
video_frames = video_cut.load_video()  # video_frames is a uint8 np.array with shape (T, C, H, W) [or some other permutation, I don't remember off the top of my head]
video_features = compute_some_features(video_frames)  # video_features is an np.array with arbitrary shape

# Option 1 -> save to some storage directly.
# temporal_dim indicates which dimension in video_features corresponds to time; set it accordingly.
# frame_shift is the interval between feature frames in seconds, i.e. 1 / fps.
with NumpyHdf5Writer("video_features.h5") as writer:
    video_cut.video_features = writer.store_array(
        video_cut.id, video_features, frame_shift=1.0 / video_recording.video.fps, temporal_dim=0
    )

# Option 2 -> holds data in memory, write to some storage later (useful if you're going to use Lhotse Shar format):
video_cut = video_cut.attach_tensor(
    "video_features", video_features, frame_shift=1.0 / video_recording.video.fps, temporal_dim=0
)

If you save the final features, I would use one of the numpy format writers in lhotse (e.g. NumpyHdf5Writer). That said, video features would likely require better compression for very large datasets, which is something we can explore later. You can access the fps via recording.video.fps.
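For completeness, a minimal sketch of reading the stored features back later (the saved cut set path is a placeholder; "video_features" is the custom field name used above):

from lhotse import CutSet

cuts = CutSet.from_file("cuts_with_video_features.jsonl.gz")  # hypothetical path where the cuts were saved
cut = next(iter(cuts))
# Custom array fields get an auto-generated load_<field>() accessor on the cut.
feats = cut.load_video_features()
print(feats.shape, feats.dtype)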
In the beginning, I tried to load the video in mp4 format (this mp4 has no audio; it comes from a sign language dataset):

recording = Recording.from_file("test_rgb_front_clips/raw_videos/_fZbAxSSbX4_0-5-rgb_front.mp4")

I got this error:
I would like to use the Lhotse SHAR format to save SHAR files from the JSONL manifests. I have two options.

Option 1: write the shards directly with CutSet.to_shar:

from lhotse import RecordingSet, SupervisionSet, CutSet
from lhotse.shar import SharWriter

# src_dir is a pathlib.Path pointing at the directory with the JSONL manifests (defined elsewhere).
output_dir = "./data-shar"
recordings_manifest = src_dir / 'recordings.jsonl'
supervisions_manifest = src_dir / 'supervisions.jsonl'
recordings = RecordingSet.from_jsonl(recordings_manifest)
supervisions = SupervisionSet.from_jsonl(supervisions_manifest)
cuts = CutSet.from_manifests(recordings, supervisions).trim_to_supervisions()
try:
    shards = cuts.to_shar(output_dir, fields={"recording": "mp4"}, shard_size=15)
except AssertionError as e:
    print(f"Error: {e}")

Error:
Option 2: extract the features myself (MediaPipe Holistic) and write them with ArrayTarWriter:

from pathlib import Path

import cv2
import logging
import mediapipe as mp
from tqdm import tqdm

from lhotse import load_manifest_lazy, Recording
from lhotse.shar import ArrayTarWriter

# src_dir is a pathlib.Path pointing at the directory with the JSONL manifests (defined elsewhere).
output_dir = Path("./data-shar")
recordings_manifest_path = src_dir / 'recordings.jsonl'
supervisions_manifest_path = src_dir / 'supervisions.jsonl'
recordings_manifest = load_manifest_lazy(recordings_manifest_path)
supervisions_manifest = load_manifest_lazy(supervisions_manifest_path)
tar_path = str(output_dir / "video_features.%06d.tar")
with ArrayTarWriter(tar_path, shard_size=15) as writer, tqdm(total=len(recordings_manifest)) as pbar, mp.solutions.holistic.Holistic(
        static_image_mode=False, model_complexity=0, min_detection_confidence=0.5, min_tracking_confidence=0.5) as holistic:
    for recording in recordings_manifest:
        try:
            video_recording = Recording.from_dict(recording.to_dict())
            video_cut = video_recording.to_cut()
            video_frames = video_cut.load_video()  # loaded only to check that the video decodes; not used below
            video_path = video_recording.sources[0].source  # get the video path from sources
            logging.info(f"Loading video frames from {video_path}")
            # Get FPS using OpenCV
            cap = cv2.VideoCapture(video_path)
            fps = cap.get(cv2.CAP_PROP_FPS)
            cap.release()
            # extract_features_from_video is my MediaPipe-based extractor (defined elsewhere; see the sketch below)
            video_features = extract_features_from_video(video_path, holistic)
            if video_features is None:
                logging.error(f"Failed to load video frames for recording ID: {video_recording.id}, video path: {video_path}")
                continue
            # Attach features to video_cut; frame_shift is the interval between feature frames in seconds
            video_cut = video_cut.attach_tensor("video_features", video_features, frame_shift=float(1.0 / fps), temporal_dim=0)
            # Store the features using ArrayTarWriter (the raw array plus the manifest attached above)
            writer.write(video_cut.id, video_features, video_cut.video_features)
        except Exception as e:
            logging.error(f"Error processing recording ID {recording.id}: {e}")
        pbar.update(1)

I can save the features as video_features.000000.tar (inside this tar, each video has two files, e.g. -fZc293MpJk_0-1-rgb_front.json and -fZc293MpJk_0-1-rgb_front.npy). Hint: I have not compressed with "lilcom" in ArrayTarWriter and also not saved
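For reference, a minimal sketch of what the extract_features_from_video helper used above could look like (illustrative only; the choice of pose landmarks and the output layout are assumptions, not part of lhotse or this thread):

import cv2
import numpy as np

def extract_features_from_video(video_path, holistic):
    """Run MediaPipe Holistic on every frame; return a (num_frames, num_values) float32 array or None."""
    cap = cv2.VideoCapture(video_path)
    per_frame = []
    while True:
        ok, frame_bgr = cap.read()
        if not ok:
            break
        # MediaPipe expects RGB images.
        results = holistic.process(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
        if results.pose_landmarks is not None:
            # 33 pose landmarks x (x, y, z, visibility) = 132 values per frame.
            values = [v for lm in results.pose_landmarks.landmark for v in (lm.x, lm.y, lm.z, lm.visibility)]
        else:
            values = [0.0] * (33 * 4)
        per_frame.append(values)
    cap.release()
    if not per_frame:
        return None
    return np.asarray(per_frame, dtype=np.float32)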
I also want to be able to use the from_shar function and later train with a DataLoader using Lhotse Shar:

cuts_nodata = CutSet.from_shar(fields={"cuts": shards["cuts"]})

or

cuts = CutSet.from_shar(
    fields={
        "cuts": shards["cuts"],
        "recording": shards["recording"],
    },
)

In this tutorial (examples/04-lhotse-shar.ipynb), in the "Implementation note: the use of IterableDataset" section, how should the code be modified so that the DynamicBucketingSampler reads the existing features from the feature shards instead of extracting new ones from the recording shards? Thanks in advance @pzelasko
Video loading features depend on you having a recent version of pytorch, torchaudio, and a compatible ffmpeg version to load videos. Based on the call stack I think maybe you don't have this backend available. Try updating your torch/torchaudio and setting the env var
Try using
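A quick way to check the environment (a sketch; the ffmpeg introspection helper is only present in newer torchaudio releases, so treat that call as an assumption):

import torch
import torchaudio

print("torch:", torch.__version__)
print("torchaudio:", torchaudio.__version__)
try:
    # Reports the FFmpeg library versions torchaudio was able to link against.
    from torchaudio.utils import ffmpeg_utils
    print("ffmpeg:", ffmpeg_utils.get_versions())
except Exception as e:
    print("FFmpeg backend not available:", e)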
As for your other question:
Yeah we'll need to add mp4 support for AudioTarWriter. I don't have the bandwidth for this right now but I can help you get started. First we'll need to add
You have two options. There is a high-level utility
I think what you want is, after executing the suggestions before, this:

cuts = CutSet.from_shar(
    fields={
        "cuts": shards["cuts"],
        "video_features": shards["video_features"],
    },
)
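As a side note, the fields values are just lists of shard paths, so the same thing can be written with explicit paths, or (if your lhotse version supports it) by pointing at the shard directory; the paths below are placeholders:

# Explicit shard lists: entries must pair up one-to-one across fields.
cuts = CutSet.from_shar(
    fields={
        "cuts": ["data-shar/cuts.000000.jsonl.gz", "data-shar/cuts.000001.jsonl.gz"],
        "video_features": ["data-shar/video_features.000000.tar", "data-shar/video_features.000001.tar"],
    },
)

# Or let lhotse discover consistently named shards in a directory.
cuts = CutSet.from_shar(in_dir="data-shar")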
Thanks a lot @pzelasko, I really appreciate your help and support:

from lhotse import RecordingSet, SupervisionSet, CutSet
from lhotse.shar import SharWriter
import torch
print(torch.cuda.is_available())
import torchaudio
print(torchaudio.get_audio_backend())
#torchaudio.set_audio_backend("ffmpeg")
#torchaudio.set_audio_backend("sox_io")  # or "soundfile"
output_dir = "./data-shar"  # same output directory as in the earlier snippets
recordings_manifest = "./recordings.jsonl"
supervisions_manifest = "./supervisions.jsonl"
# Load recordings and supervisions
recordings = RecordingSet.from_jsonl(recordings_manifest)
supervisions = SupervisionSet.from_jsonl(supervisions_manifest)
# Create CutSet and trim to supervisions
cuts = CutSet.from_manifests(recordings, supervisions).trim_to_supervisions()
# Write shards
shards = cuts.to_shar(output_dir, fields={"recording": "mp4"}, shard_size=15)
print("Shards created:", shards)

I got this error:
Then I was able to create the required SHAR files manually, and I checked them against the files created by this tutorial for a speech dataset (https://github.com/lhotse-speech/lhotse/blob/master/examples/04-lhotse-shar.ipynb). The JSON cut file looks like this:

{"id": "-fZc293MpJk_0-1-rgb_front", "start": 0.0, "duration": 6.53, "channel": 0, "supervisions": [{"id": "-fZc293MpJk_0-1-rgb_front", "recording_id": "-fZc293MpJk_0-1-rgb_front", "start": 0.0, "duration": 6.53, "channel": 0, "text": "hi", "language": "English", "speaker": "-fZc293MpJk"}], "recording": {"id": "-fZc293MpJk_0-1-rgb_front", "sources": [{"type": "shar", "channels": [0], "source": ""}], "sampling_rate": 24, "num_samples": 17, "duration": 6.53, "channel_ids": [0]}, "type": "MonoCut"}
{"id": "-fZc293MpJk_2-1-rgb_front", "start": 0.0, "duration": 13.03, "channel": 0, "supervisions": [{"id": "-fZc293MpJk_2-1-rgb_front", "recording_id": "-fZc293MpJk_2-1-rgb_front", "start": 0.0, "duration": 13.03, "channel": 0, "text": "the aileron is the control surface in the wing that is controlled by lateral movement right and left of the stick", "language": "English", "speaker": "-fZc293MpJk"}], "recording": {"id": "-fZc293MpJk_2-1-rgb_front", "sources": [{"type": "shar", "channels": [0], "source": ""}], "sampling_rate": 24, "num_samples": 412, "duration": 13.03, "channel_ids": [0]}, "type": "MonoCut"}
a) Modify the DataModule python script to read the Shar data:

class SignLanguageDataModule:
    def train_cuts(self) -> CutSet:
        logging.info("About to get train cuts")
        train_path = Path("data/test_V2/shar_out")
        # The glob patterns assume the shard naming used above (cuts.%06d.jsonl.gz / video_features.%06d.tar).
        cuts_video_feat_train = CutSet.from_shar(
            fields={
                "cuts": sorted(map(str, train_path.glob("cuts.*.jsonl.gz"))),
                "video_features": sorted(map(str, train_path.glob("video_features.*.tar"))),
            },
            shuffle_shards=True,
            stateful_shuffle=True,
            seed="randomized",
        ).repeat()
        # After .repeat() the CutSet is infinite and lazy, so len() and integer indexing
        # are unavailable; peek at the first cut through an iterator instead.
        first_cut = next(iter(cuts_video_feat_train))
        features_array = first_cut.load_video_features()
        print("Features first array shape:", features_array.shape)
        print("Features first array:", features_array)
        # logging.info(f"train_cuts size: {len(cuts_video_feat_train)}")  # len() is not defined for an infinite lazy CutSet
        return cuts_video_feat_train
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
logging.info(
f"Using cut concatenation with duration factor "
f"{self.args.duration_factor} and gap {self.args.gap}."
)
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
input_strategy=eval(self.args.input_strategy)(),
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
shuffle=True,
max_duration=10.0,
num_buckets=10,
rank=0,
world_size=1,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=10.0,
shuffle=self.args.shuffle,
)
logging.info(f"train_sampler created: {train_sampler}")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# Note: this seed / worker_init_fn pair is currently unused; the DataLoader below
# passes make_worker_init_fn(seed=0) instead.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_iter_dataset = IterableDatasetWrapper(
dataset=train,
sampler=train_sampler,
)
train_dl = DataLoader(
train_iter_dataset,
batch_size=None,
num_workers=self.args.num_workers,
worker_init_fn=make_worker_init_fn(seed=0),
)
logging.info(f"train_dl created: {train_dl}")
return train_dl

train_sign.py:

signData = SignLanguageDataModule(args)
train_cuts = signData.train_cuts()
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint
# saved in the middle of an epoch
sampler_state_dict = checkpoints["sampler"]
else:
sampler_state_dict = None
train_dl = signData.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
params=params,
)
def scan_pessimistic_batches_for_oom(
model: Union[nn.Module, DDP],
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
params: AttributeDict,
):
from lhotse.dataset import find_pessimistic_batches
logging.info(
"Sanity check -- see if any of the batches in epoch 1 would cause OOM."
)
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
)
loss.backward()
optimizer.zero_grad()
except Exception as e:
if "CUDA out of memory" in str(e):
logging.error(
"Your GPU ran out of memory with the current "
"max_duration setting. We recommend decreasing "
"max_duration and trying again.\n"
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
display_and_save_batch(batch, params=params, sp=sp)
raise
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
)
def compute_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
sp: spm.SentencePieceProcessor,
batch: dict,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute loss given the model and its inputs.
Args:
params:
Parameters for training. See :func:`get_params`.
model:
The model for training. It is an instance of Zipformer in our case.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
is_training:
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
batch_idx_train = params.batch_idx_train
warm_step = params.warm_step
texts = batch["supervisions"]["text"]
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model(
x=feature,
x_lens=feature_lens,
y=y,
prune_range=params.prune_range,
am_scale=params.am_scale,
lm_scale=params.lm_scale,
)
loss = 0.0
if params.use_transducer:
s = params.simple_loss_scale
# take down the scale on the simple loss from 1.0 at the start
# to params.simple_loss scale by warm_step.
simple_loss_scale = (
s
if batch_idx_train >= warm_step
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
)
pruned_loss_scale = (
1.0
if batch_idx_train >= warm_step
else 0.1 + 0.9 * (batch_idx_train / warm_step)
)
loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
if params.use_ctc:
loss += params.ctc_loss_scale * ctc_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
if params.use_transducer:
info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item()
if params.use_ctc:
info["ctc_loss"] = ctc_loss.detach().cpu().item()
return loss, info

Thanks in advance
In the dataset class, instead of calling the existing input strategy, do:

video_features, video_features_lens = lhotse.dataset.collation.collate_custom_field(cuts, "video_features")
batch["inputs"] = video_features

You might need to work out some details but hopefully this can get you started. You can also safely remove the
Extend Lhotse to support video features for tasks such as sign language recognition (e.g., How2Sign) and human activity recognition. This enhancement will be useful for the Icefall platform.
Details
With the recent support for video in PR #1151, I am interested in developing a new recipe to handle video data and extract features using tools like MediaPipe.
Objectives
Recipe Addition: add a new recipe to the lhotse/recipes directory.
Feature Extraction:
Implementation Steps
Create Manifest Files (a minimal sketch follows this list):
Recordings manifest (recordings.jsonl):
Supervisions manifest (supervisions.jsonl):
Feature Extraction Script: compute_features_sign_language.py:
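As a starting point, a minimal sketch of building the recordings and supervisions manifests for one clip (illustrative only; the file path is a placeholder and the text/speaker values are copied from the How2Sign example earlier in this thread):

from lhotse import Recording, RecordingSet, SupervisionSegment, SupervisionSet

# Lhotse builds the (video) recording manifest directly from the mp4 file.
recording = Recording.from_file("raw_videos/-fZc293MpJk_0-1-rgb_front.mp4")

# One supervision spanning the whole clip, carrying the sentence-level text.
supervision = SupervisionSegment(
    id=recording.id,
    recording_id=recording.id,
    start=0.0,
    duration=recording.duration,
    text="hi",
    language="English",
    speaker="-fZc293MpJk",
)

RecordingSet.from_recordings([recording]).to_file("recordings.jsonl")
SupervisionSet.from_segments([supervision]).to_file("supervisions.jsonl")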
Questions
… and load them later for training?
4. The frames per second is also not fixed; it varies between 24 fps and 50 fps. How can I deal with that?
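Regarding question 4, a hedged note based on the attach_tensor usage earlier in this thread: the frame_shift is stored per cut in its TemporalArray manifest, so each recording can simply carry its own 1 / fps. A minimal sketch (variable names follow the earlier snippet):

import cv2

# fps differs per video, so compute the frame_shift (seconds per feature frame) per recording.
cap = cv2.VideoCapture(video_path)
fps = cap.get(cv2.CAP_PROP_FPS)
cap.release()

video_cut = video_cut.attach_tensor(
    "video_features",
    video_features,          # features computed for this particular video
    frame_shift=1.0 / fps,   # stored per cut, so mixed 24-50 fps data is fine
    temporal_dim=0,
)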
I would appreciate any guidance or support on implementing this feature and utilizing it within the Icefall platform @pzelasko .
Thank you!