
k2SSL: a Faster and Better Framework for Self-Supervised Speech Representation Learning #1500

Merged 18 commits into k2-fsa:master on Apr 4, 2024

Conversation

@yfyeung (Collaborator) commented Feb 18, 2024

In this PR, we decoupled HuBERT from fairseq, removing the dependency on the fairseq library while maintaining full equivalence with the original pre-training logic (model architecture, data normalization, masking strategy, loss computation, etc.). We compared the outputs of selected layers to verify this equivalence. We also support the checkpoints released with fairseq (hubert_base_ls960, hubert_large_ll60k, hubert_xtralarge_ll60k).
Then we optimized the pre-training loss, significantly reducing peak memory usage and even slightly improving performance. Unfortunately, this change made the original HuBERT unstable in half precision. We therefore adopted ScaledAdam as the optimizer and Eden as the scheduler, and replaced the Transformer encoder with the Zipformer encoder. This further reduced peak memory usage and improved performance while remaining stable in half precision.
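
For context, a minimal sketch of the HuBERT-style masked-prediction objective that pre-training optimizes is shown below (cosine-similarity logits against learnable label embeddings, cross-entropy over the masked frames only). It illustrates the idea only; it is not the exact loss code in this PR and does not reproduce the memory optimization mentioned above.

# Sketch of a HuBERT-style masked-prediction loss (illustration only;
# not the exact loss in this PR).
import torch
import torch.nn.functional as F

def masked_prediction_loss(
    features: torch.Tensor,     # (B, T, D) encoder outputs
    targets: torch.Tensor,      # (B, T) k-means label ids in [0, num_classes)
    mask: torch.Tensor,         # (B, T) bool, True where the input was masked
    label_embed: torch.Tensor,  # (num_classes, D) learnable label embeddings
    logit_temp: float = 0.1,
) -> torch.Tensor:
    x = features[mask]          # (N_masked, D): only masked frames are scored
    y = targets[mask]           # (N_masked,)
    # Cosine-similarity logits against every label embedding.
    logits = F.cosine_similarity(
        x.unsqueeze(1), label_embed.unsqueeze(0), dim=-1
    ) / logit_temp              # (N_masked, num_classes)
    return F.cross_entropy(logits, y, reduction="sum")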

@yfyeung requested a review from csukuangfj on February 18, 2024 05:19
@kobenaxie (Contributor)

Hi @yfyeung,

  • How do we get the k-means label file to train the Zipformer-based HuBERT pre-training model?
  • Can we use fbank as the model input, like w2v-BERT?

@yfyeung (Collaborator, Author) commented Feb 21, 2024

  • How do we get the k-means label file to train the Zipformer-based HuBERT pre-training model?

For LibriSpeech, we directly use the k-means labels from hubert_base_ls960.
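
For anyone who needs to produce such labels themselves, here is a minimal sketch of the clustering step (fit k-means on intermediate-layer HuBERT features, then assign one cluster id per frame). It follows the general fairseq simple_kmeans recipe rather than any script in this PR; the feature source, the 500-cluster setting, and the function names are illustrative.

# Sketch of the k-means labelling step (illustrative; mirrors the general
# fairseq simple_kmeans recipe, not a script shipped with this PR).
import numpy as np
from sklearn.cluster import MiniBatchKMeans

def fit_kmeans(features: np.ndarray, n_clusters: int = 500) -> MiniBatchKMeans:
    # `features` is (num_frames, feat_dim), e.g. intermediate-layer features
    # dumped from hubert_base_ls960 over a subset of LibriSpeech.
    km = MiniBatchKMeans(n_clusters=n_clusters, batch_size=10000, n_init=20)
    km.fit(features)
    return km

def label_utterance(km: MiniBatchKMeans, utt_feats: np.ndarray) -> str:
    # One integer label per frame, stored as a space-separated string so it
    # can later be attached to the manifest.
    return " ".join(map(str, km.predict(utt_feats)))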

  • Can we use fbank as the model input, like w2v-BERT?

Yes, you can replace the ConvFeatureExtractionModel with the Conv2dSubsampling.
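
For concreteness, a rough stand-in for such an fbank frontend is sketched below: a 2-D convolutional subsampling module over 80-dim fbank frames in place of the waveform ConvFeatureExtractionModel. It is hand-written for illustration, not icefall's actual Conv2dSubsampling class, and the channel and stride choices are arbitrary.

# Rough stand-in for a Conv2dSubsampling-style fbank frontend
# (illustrative; not the exact icefall class).
import torch
import torch.nn as nn

class FbankFrontend(nn.Module):
    """Maps (N, T, 80) fbank frames to (N, T', d_model) with ~4x subsampling."""

    def __init__(self, num_mel_bins: int = 80, d_model: int = 768):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
        self.out = nn.Linear(32 * ((num_mel_bins + 3) // 4), d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.unsqueeze(1)                     # (N, 1, T, 80)
        x = self.conv(x)                       # (N, 32, ~T/4, 20)
        n, c, t, f = x.shape
        x = x.permute(0, 2, 1, 3).reshape(n, t, c * f)
        return self.out(x)                     # (N, ~T/4, d_model)

Note that the k-means labels follow the waveform frontend's roughly 50 Hz frame rate, so the fbank frontend's subsampling factor (and/or the labels) would need to be adjusted so the two frame rates match.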

@kafan1986

@yfyeung What are the approximate changes in WER, training time, and inference speed if k2SSL is used with, say, HuBERT base?

@danpovey (Collaborator) commented Apr 4, 2024

Guys, I just noticed this; it seems like a great contribution.
I'd rather not have these things wait so long... let me merge it now, and if there are any changes we want, we can make them later on.

@danpovey merged commit 87843e9 into k2-fsa:master on Apr 4, 2024
143 checks passed
@yfyeung deleted the k2ssl branch on April 5, 2024 04:42
@yfyeung restored the k2ssl branch on April 5, 2024 04:43
@teowenshen (Contributor) commented Apr 12, 2024

Hi there @yfyeung, first of all thank you for creating this SSL recipe!

I tried running your zipformer/ code, but my model diverged at epoch 33 and pretraining ended with a 'Grad scale is small' error.

Throughout pretraining before the divergence, I noticed my grad scale tended to fluctuate between 0.125 and 2.

Did you face the same issues?

EDIT: I was also wondering if you tried toggling the loss reduction to mean instead of sum. Maybe that will stabilise training?

My commands are below; I adapted the batch size to my setup, keeping the product accum_grad * max_duration * world_size the same.

# pretraining
python zipformer/pretrain.py \
    --world-size 4 \
    --use-fp16 1 \
    --num-epochs 50 \
    --manifest-dir data/raw \
    --max-duration 350 \
    --accum-grad 2 \
    --exp-dir zipformer/exp2/pretrain

As per your explanation, I used the same 500 k-means labels from simple_kmeans.

@yfyeung (Collaborator, Author) commented Apr 12, 2024

I tried running your zipformer/ code, but my model diverged at epoch 33 and pretraining ended with a 'Grad scale is small' error. [...] Did you face the same issues?

Hi, hope this message finds you well.

My training command is as follows:

./zipformer/pretrain.py \
  --world-size 8 \
  --num-epochs 291 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir zipformer/exp_pretrain \
  --full-libri 1 \
  --max-duration 600 \
  --accum-grad 1 \
  --do-normalize 0 \
  --mask-prob 0.8 \
  --dropout-input 0.1 \
  --dropout-features 0.1 \
  --feature-grad-mult 0.1 \
  --untie-final-proj 1 \
  --num-encoder-layers 2,2,3,4,3,2 \
  --feedforward-dim 512,768,1024,1536,1024,768 \
  --encoder-dim 192,256,448,768,448,192 \
  --encoder-unmasked-dim 192,192,256,256,256,192 \
  --base-lr 0.045

EDIT: I was also wondering if you tried toggling the loss reduction to mean instead of sum. Maybe that will stabilise training?

Regarding your question about switching the loss reduction from sum to mean to stabilize training: mean reduction is typically used to keep the gradient scale uniform across GPUs and accumulation steps, while sum reduction is preferred for larger batch sizes because it keeps the gradient estimate tied to how much data each batch actually contains. There is no single choice that is optimal for both large batch sizes and multi-GPU setups at the same time.
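
To make the scaling difference concrete, here is a toy sketch (hypothetical tensors, not code from this recipe) of how the two reductions differ:

# Toy comparison of "sum" vs "mean" loss reduction (hypothetical values;
# not code from this recipe).
import torch
import torch.nn.functional as F

logits = torch.randn(8, 500)                 # 8 masked frames, 500 classes
targets = torch.randint(0, 500, (8,))

# "sum": the loss (and hence the gradient) scales with the number of frames
# actually present in the batch, which favors large effective batch sizes.
loss_sum = F.cross_entropy(logits, targets, reduction="sum")

# "mean": every micro-batch contributes a loss of similar magnitude regardless
# of how many frames it holds, which keeps scaling uniform across GPUs and
# gradient-accumulation steps.
loss_mean = F.cross_entropy(logits, targets, reduction="mean")

assert torch.allclose(loss_sum, loss_mean * logits.shape[0])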

Fine-tuning command is:

./zipformer/finetune.py \
  --world-size 8 \
  --num-epochs 222 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir zipformer/exp_finetune \
  --pretrained-dir zipformer/exp_pretrain/epoch-291.pt \
  --full-libri 0 \
  --max-duration 600 \
  --accum-grad 1 \
  --do-normalize 0 \
  --mask-prob 0.65 \
  --mask-channel-prob 0.5 \
  --mask-channel-length 64 \
  --feature-grad-mult 0.0 \
  --num-encoder-layers 2,2,3,4,3,2 \
  --feedforward-dim 512,768,1024,1536,1024,768 \
  --encoder-dim 192,256,448,768,448,192 \
  --encoder-unmasked-dim 192,192,256,256,256,192 \
  --base-lr 0.002

Decoding first sweeps the two key parameters --epoch and --avg with greedy search to identify the top K (epoch, avg) candidates:

for ((epoch=100; epoch<=222; epoch+=1)); do
  for ((avg=1; avg<=$epoch-1; avg+=1)); do
    ./zipformer/decode.py \
        --epoch $epoch \
        --avg $avg \
        --exp-dir ./zipformer/exp_finetune \
        --do-normalize 0 \
        --max-duration 1000 \
        --decoding-method greedy_search \
        --num-encoder-layers 2,2,3,4,3,2 \
        --feedforward-dim 512,768,1024,1536,1024,768 \
        --encoder-dim 192,256,448,768,448,192 \
        --encoder-unmasked-dim 192,192,256,256,256,192
  done
done

Then use modified beam search on these top K candidates:

epoch=
avg=
./zipformer/decode.py \
      --epoch $epoch \
      --avg $avg \
      --exp-dir ./zipformer/exp_finetune \
      --do-normalize 0 \
      --max-duration 1000 \
      --decoding-method modified_beam_search \
      --beam-size 8 \
      --num-encoder-layers 2,2,3,4,3,2 \
      --feedforward-dim 512,768,1024,1536,1024,768 \
      --encoder-dim 192,256,448,768,448,192 \
      --encoder-unmasked-dim 192,192,256,256,256,192

@teowenshen (Contributor)

I see! Thanks for the explanation!

Meanwhile, can you share your finetuning and decoding commands as well?

@yfyeung (Collaborator, Author) commented Apr 12, 2024

I see! Thanks for the explanation!

Meanwhile, can you share your finetuning and decoding commands as well?

Sure, I have updated my comment above. The idea is that the cheap greedy-search sweep prunes the (epoch, avg) search space, so the more expensive modified beam search only needs to run on the surviving candidates.

@danpovey (Collaborator)

@teowenshen is there any chance you can run from your --start-epoch=33 with the --inf-check=True option, assuming pretrain.py supports these options like train.py, and show us the log? If the options are not there we should add them. I want to see where the inf grad is coming from, maybe we can fix it with more info.

@danpovey (Collaborator)

Also, @yfyeung we normally have a README.md and/or RESULTS.md that show typical sequences of training and testing commands, and associated results. Is there any chance of adding those?
Is a link to a paper going to come later?

@teowenshen (Contributor)

I want to see where the inf grad is coming from, maybe we can fix it with more info.

Yes, please find the logs for epoch 33 as attached.

librispeech_SSL_zipformer_pretrain_ep33_infcheck.txt

I couldn't run --print-diagnostics 1 due to this error:

Error getting eigenvalues, trying another method.
Error getting eigenvalues, trying another method.
Error getting eigenvalues, trying another method.
Error getting eigenvalues, trying another method.
/workspace/icefall/icefall/diagnostics.py:255: UserWarning: ComplexHalf support is experimental and many operators don't support it yet. (Triggered internally at /opt/conda/conda-bld/pytorch_1695392067780/work/aten/src/ATen/EmptyTensor.cpp:31.)
  eigs, _ = torch.linalg.eig(stats)
/workspace/icefall/icefall/diagnostics.py:255: UserWarning: ComplexHalf support is experimental and many operators don't support it yet. (Triggered internally at /opt/conda/conda-bld/pytorch_1695392067780/work/aten/src/ATen/EmptyTensor.cpp:31.)
  eigs, _ = torch.linalg.eig(stats)
/workspace/icefall/icefall/diagnostics.py:255: UserWarning: ComplexHalf support is experimental and many operators don't support it yet. (Triggered internally at /opt/conda/conda-bld/pytorch_1695392067780/work/aten/src/ATen/EmptyTensor.cpp:31.)
  eigs, _ = torch.linalg.eig(stats)
/workspace/icefall/icefall/diagnostics.py:255: UserWarning: ComplexHalf support is experimental and many operators don't support it yet. (Triggered internally at /opt/conda/conda-bld/pytorch_1695392067780/work/aten/src/ATen/EmptyTensor.cpp:31.)
  eigs, _ = torch.linalg.eig(stats)
Traceback (most recent call last):
  File "/mnt/host/icefall-k2ssl/egs/librispeech/SSL/zipformer/pretrain.py", line 1380, in <module>
    main()
  File "/mnt/host/icefall-k2ssl/egs/librispeech/SSL/zipformer/pretrain.py", line 1371, in main
    mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
  File "/opt/conda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 246, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
  File "/opt/conda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 202, in start_processes
    while not context.join():
  File "/opt/conda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 163, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/workspace/icefall/icefall/diagnostics.py", line 248, in print_diagnostics
    eigs, _ = torch.linalg.eigh(stats)
RuntimeError: "linalg_eigh_cuda" not implemented for 'Half'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 74, in _wrap
    fn(i, *args)
  File "/mnt/host/icefall-k2ssl/egs/librispeech/SSL/zipformer/pretrain.py", line 1276, in run
    diagnostic.print_diagnostics()
  File "/workspace/icefall/icefall/diagnostics.py", line 517, in print_diagnostics
    self.diagnostics[k].print_diagnostics()
  File "/workspace/icefall/icefall/diagnostics.py", line 255, in print_diagnostics
    eigs, _ = torch.linalg.eig(stats)
RuntimeError: torch.linalg.eig: input tensor should not contain infs or NaNs.

@danpovey (Collaborator) commented Apr 13, 2024 via email

The error was unusual: it was an infinity in the forward pass. This is because you used the wav2vec2 frontend, which doesn't have any balancers or similar code to stop large values from appearing. ScaledAdam can make large values appear faster than Adam would, although even with Adam they'll appear eventually unless steps are taken to prevent it.

   x = conv(x)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/container.py", line 215, in forward
    input = module(input)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1581, in _call_impl
    hook_result = hook(self, args, result)
  File "/workspace/icefall/icefall/hooks.py", line 41, in forward_hook
    raise ValueError(
ValueError: The sum of module.feature_extractor.conv_layers.2.0.output is not finite: tensor([[[  -6.5234,   -6.5078,   -6.6094,  ...,   -6.5820,   -6.5469,
            -6.5469],
         [  -0.7900,   -0.6479,   -0.5444,  ...,   -0.9287,   -0.9971,
            -0.9380],
         [  -7.3672,   -8.1250,   -8.5938,  ...,   -7.8672,   -7.9023,
            -7.8047],

Anyway, PR #1593 should fix the issue without causing any model incompatibility. I haven't tested it though.
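
For illustration only, one generic way to keep a frontend's activations bounded in fp16 is to softly limit their magnitude in the forward pass, as sketched below; this is a sketch of the general idea, not the change made in #1593.

# Crude illustration of bounding a frontend's activations so fp16 values
# cannot grow without limit (not the fix applied in #1593).
import torch
import torch.nn as nn

class SoftClamp(nn.Module):
    """Smoothly limits |x| to roughly `limit` while staying differentiable."""

    def __init__(self, limit: float = 30.0):
        super().__init__()
        self.limit = limit

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.limit * torch.tanh(x / self.limit)

# Usage sketch: insert after each conv block of the waveform frontend,
# e.g. nn.Sequential(conv, norm, activation, SoftClamp()).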

@yfyeung (Collaborator, Author) commented Apr 13, 2024

Also, @yfyeung we normally have a README.md and/or RESULTS.md that show typical sequences of training and testing commands, and associated results. Is there any chance of adding those?
Is a link to a paper going to come later?

Sure, I will add those after the anonymity period ends, including the model checkpoint, tensorboard, pre-training logs, fine-tuning logs, decoding logs, and RESULTS.md. If things go well, also a link to the paper.

@sanjuktasr

How did you prepare the input data, i.e. the manifest dir, for zipformer pretraining?

@yfyeung (Collaborator, Author) commented Jul 26, 2024

How did you prepare the input data, i.e. the manifest dir, for zipformer pretraining?

I add the k-means labels to the custom field of the CutSet. We may release results and more utilities next month.
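
A minimal lhotse sketch of what that looks like is below; the field name "kmeans", the space-separated label string, the lookup_labels helper, and the paths are all assumptions for illustration, so check the recipe's dataset code for the exact format.

# Sketch: attach frame-level k-means labels to each cut's `custom` field
# (field name, label format, paths, and lookup_labels are assumed here).
from lhotse import CutSet

cuts = CutSet.from_file("data/raw/librispeech_cuts_train-clean-100.jsonl.gz")

def attach_kmeans(cut):
    labels = lookup_labels(cut.id)  # hypothetical map: cut id -> "12 7 7 431 ..."
    cut.custom = dict(cut.custom or {}, kmeans=labels)
    return cut

cuts = cuts.map(attach_kmeans)
cuts.to_file("data/raw/librispeech_cuts_train-clean-100-kmeans.jsonl.gz")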

@sanjuktasr

#1705
Hi @yfyeung, how do we prepare the data manifest dir in the pretrain input format?

@yfyeung (Collaborator, Author) commented Jul 26, 2024

#1705 Hi @yfyeung, how do we prepare the data manifest dir in the pretrain input format?

Hi, check out the code in the dataset at https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/SSL/hubert/dataset.py#L80-L85 for the specific format. It only has one extra field compared with the plain wav CutSet.
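
Schematically, reading that extra field back on the dataset side amounts to something like the following outline (field name, string format, and padding value are assumptions; see the linked lines for the actual code):

# Outline of reading the labels back when collating a batch
# (see the linked dataset.py for the actual code).
import torch

def collate_kmeans(cuts):
    # Each cut is assumed to carry a space-separated label string under
    # custom["kmeans"]; convert to int tensors and pad to a common length.
    labels = [
        torch.tensor([int(t) for t in cut.custom["kmeans"].split()], dtype=torch.int64)
        for cut in cuts
    ]
    return torch.nn.utils.rnn.pad_sequence(
        labels, batch_first=True, padding_value=-100
    )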

@sanjuktasr

Hi @yfyeung

I add the k-means labels to the custom field of the CutSet. We may release results and more utilities next month.

How did you obtain the kmeans? Is there any codebase available?

@yfyeung (Collaborator, Author) commented Jul 26, 2024

How did you obtain the kmeans? Is there any codebase available?

The same way as in fairseq, with its simple_kmeans pipeline (dump intermediate-layer HuBERT features, fit k-means on them, then dump frame-level cluster labels).
