
Zipformer Training Issues: Gradient too small and CUDA out of memory #1751

Closed
bsshruthi22 opened this issue Sep 10, 2024 · 14 comments

@bsshruthi22

Hello all,
We are training a Zipformer model on about 3400 hours of Tamil data.
We were facing this issue:

RuntimeError:
grad_scale is too small, exiting: 1.4901161193847656e-08

We have an NVIDIA A6000 50 GB GPU.
As suggested in the terminal output, we reduced the max duration step by step from 1000 to 600, 500, 400, 300, and 150,
and changed the learning rate from 0.045 to 0.04, 0.035, and 0.02.

Finally we used a max duration of 150 and a learning rate of 0.02 and trained for 4 epochs.
Then we faced the CUDA out-of-memory error below.
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 6.86 GiB (GPU 0; 47.54 GiB total capacity; 42.23 GiB already allocated; 2.80 GiB free; 43.04 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Kindly suggest whether any parameter changes are required so that training can continue.

@csukuangfj
Collaborator

Could you use
https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/local/display_manifest_statistics.py#L48
to get the statistics of your data?

Sample output:

## train-clean-100
Cuts count: 85617
Total duration (hours): 303.8
Speech duration (hours): 303.8 (100.0%)
***
Duration statistics (seconds):
mean 12.8
std 3.8
min 1.3
0.1% 1.9
0.5% 2.2
1% 2.5
5% 4.2
10% 6.4
25% 11.4
50% 13.8
75% 15.3
90% 16.7
95% 17.3
99% 18.1
99.5% 18.4
99.9% 18.8
max 27.2
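
If it helps, here is a minimal sketch of computing similar statistics directly with lhotse; the manifest path below is only a placeholder for your own cuts file:

from lhotse import CutSet

# Placeholder path; point this at your training cuts manifest.
cuts = CutSet.from_file("data/fbank/cuts_train.jsonl.gz")

# Prints the cut count, total duration, and duration percentiles,
# similar to the sample output above.
cuts.describe()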

@csukuangfj
Collaborator

By the way, have you enabled

train_cuts = train_cuts.filter(remove_short_and_long_utt)

and if yes, what are the thresholds you are using?
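
For reference, the filter in the librispeech recipe looks roughly like the sketch below; the 1 s / 20 s thresholds are the recipe defaults and may need adjusting for other corpora:

from lhotse.cut import Cut

def remove_short_and_long_utt(c: Cut) -> bool:
    # Keep only utterances in a reasonable duration range; a single very
    # long cut in a batch is a common cause of CUDA OOM during training.
    return 1.0 <= c.duration <= 20.0

# train_cuts = train_cuts.filter(remove_short_and_long_utt)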

@epicyclism

I think I had the CUDA problem you report; reading the error and the PyTorch documentation suggested adding

export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

before starting to train. This worked for me.

@bsshruthi22
Author

Could you use https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/local/display_manifest_statistics.py#L48 to get the statistics of your data?

Attached is the manifest statistics file:
Manifests_statistics.docx

@bsshruthi22
Author

I think I had the CUDA problem you report; reading the error and the PyTorch documentation suggested adding

export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

before starting to train. This worked for me.

This option was not available in the version of PyTorch I was using, so I used export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512 instead.

@bsshruthi22
Author

By the way, have you enabled

train_cuts = train_cuts.filter(remove_short_and_long_utt)

and if yes, what are the thresholds you are using?

I have not explicitly set any thresholds; it must be using the default values. I will check where it is set in lhotse.

@csukuangfj
Collaborator

I have not explicitly set any thresholds; it must be using the default values. I will check where it is set in lhotse.

Default values are fine for your dataset as long as you have used train_cuts.filter in your code; otherwise, a very long cut may lead to OOM.

@bsshruthi22
Author

bsshruthi22 commented Sep 13, 2024

After resuming from the 4th epoch, I faced this error again. I had set PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512.

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 6.86 GiB (GPU 0; 47.54 GiB total capacity; 43.98 GiB already allocated; 914.00 MiB free; 44.95 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Please advise whether we can train on an NVIDIA A6000 50 GB GPU.
Last year we trained on about 245 hours of data on a Quadro RTX 4000 8 GB GPU for 30 epochs without any issue.

@csukuangfj
Collaborator

Please advise whether we can train on an NVIDIA A6000 50 GB GPU.

Yes, you can. We are using a 32 GB V100.


Which sampler are you using? Are you using

train_sampler = DynamicBucketingSampler(

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 6.86 GiB (GPU 0; 47.54 GiB total capacity; 43.98 GiB already allocated; 914.00 MiB free; 44.95 GiB reserved in total by PyTorch)

Have you changed train.py or any other files? If yes, could you post the git diff?
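
For context, the sampler is typically constructed along these lines in the recipe's asr_datamodule.py; the values below are only illustrative, and max_duration is the option that most directly controls per-batch GPU memory:

from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler

# Placeholder path; in the recipe the cuts come from the data module.
train_cuts = CutSet.from_file("data/fbank/cuts_train.jsonl.gz")

train_sampler = DynamicBucketingSampler(
    train_cuts,
    max_duration=150.0,  # total seconds of audio per batch; lower this on OOM
    shuffle=True,
    num_buckets=30,
)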

@bsshruthi22
Author

I have not changed the sampler; I am using all default settings. I have only changed the sampling rate to 8 kHz and the train, dev, and test folder names.

@bsshruthi22
Author

Is there anything you can advise on this?

@lingjzhu

If you have removed all outlier audios, then you can try running torch.cuda.empty_cache() every few hundred iterations to release unused memory. Works for me for a 300M model on a 48G GPU.
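
A self-contained sketch of that idea; the tiny model, dummy data, and the 300-iteration interval are only placeholders, and in the recipe the call would go inside the training loop in train.py:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(80, 500).to(device)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

for batch_idx in range(1000):
    x = torch.randn(16, 80, device=device)
    loss = model(x).pow(2).mean()
    optim.zero_grad()
    loss.backward()
    optim.step()

    # Release unused cached blocks every few hundred iterations to reduce
    # fragmentation; this does not free memory held by live tensors.
    if batch_idx % 300 == 0 and device == "cuda":
        torch.cuda.empty_cache()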

@JinZr JinZr closed this as completed Nov 5, 2024
@bsshruthi22
Author

I tried emptying the cache as suggested; it generated a few checkpoints and then ran into the same issue before completing an epoch. So I am thinking of reducing the data to around 1000 hours and checking, as I am not able to work out how to solve the error.

@danpovey
Collaborator

danpovey commented Nov 5, 2024

It should dump the offending batch if your script is fairly up to date. You could load the .pt file with torch.load() and look at the characteristics of the batch; e.g., it might contain very short or very long utterances.
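
A sketch of inspecting such a dumped batch; the filename is a placeholder, and the exact keys depend on the dataset class, but the recipes typically save a dict with "inputs" and "supervisions":

import torch

batch = torch.load("zipformer/exp/batch-xxxx.pt", map_location="cpu")
print(batch.keys())

# Feature tensor: (batch_size, num_frames, num_features)
print(batch["inputs"].shape)

# Look for unusually short or long utterances in the batch.
sup = batch["supervisions"]
print(sup["num_frames"])
print(sup["text"])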
