[BUG] Convergence Issue: Training BERT for Embedding with Zero2 and 3 as compared to Torchrun #6911

Open
dawnik17 opened this issue Dec 24, 2024 · 12 comments
Labels
bug Something isn't working training

Comments

@dawnik17

dawnik17 commented Dec 24, 2024

Describe the bug
There is a convergence issue when training with ZeRO-2/3 as compared to running the same training with torch DDP. I'm attaching my DeepSpeed config and training screenshots below. I'm training BERT for an embedding task.

I use reentrant True while training.

Expected behavior
The training curves (loss and gradient) should be similar.

Deepspeed Config

{
  "bf16": {
    "enabled": false
  },
  "fp16": {
    "enabled": false
  },
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": "auto",
      "betas": "auto",
      "eps": "auto",
      "weight_decay": "auto"
    }
  },
  "scheduler": {
    "type": "WarmupCosineLR",
    "params": {
      "warmup_min_ratio": 0.0001,
      "warmup_num_steps": "auto",
      "cos_min_ratio": 0.0001,
      "total_num_steps": "auto",
      "warmup_type": "linear"
    }
  },
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": 1e6,
    "stage3_prefetch_bucket_size": 0.94e6,
    "stage3_param_persistence_threshold": 1e4,
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": true
  },
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "steps_per_print": 2000,
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "wall_clock_breakdown": false
}

Training Curves
The green curve is ZeRO-3 and the purple curve is torch DDP.

Image Image
dawnik17 added the bug and training labels on Dec 24, 2024
@tjruwase
Contributor

@dawnik17, thanks for reporting this issue. Can you please provide more details that would enable us to reproduce the problem?

@dawnik17
Author

dawnik17 commented Dec 24, 2024

Hi @tjruwase, I am using BERT mini (L=4, dim=256) from here.
To train it on the embedding task, I'm using the siglip loss from here.
I'm training on 5 A100 GPUs with per device batch size 7680 and max learning rate of 5e-4.

An update from my side: I started another training run with a revised DeepSpeed config after going through other GitHub issues. Below is the part I've changed from the previous config (mainly, I've set "overlap_comm" to false).

  "zero_optimization": {
    "stage": 3,
    "overlap_comm": false,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": 1e6,
    "stage3_prefetch_bucket_size": 5e8,
    "stage3_param_persistence_threshold": 1e5,
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": false
  }
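
For readers comparing the two configs, the changed keys can be listed mechanically. The following is a minimal sketch, not part of DeepSpeed: `diff_config` is a hypothetical helper, and the two dictionaries simply restate the `zero_optimization` values posted in this thread.

```python
# Hypothetical helper (not a DeepSpeed API): report which keys differ
# between two "zero_optimization" sections.
def diff_config(old: dict, new: dict) -> dict:
    """Return {key: (old_value, new_value)} for every key whose value changed."""
    changed = {}
    for key in sorted(set(old) | set(new)):
        if old.get(key) != new.get(key):
            changed[key] = (old.get(key), new.get(key))
    return changed


# Values copied from the original config in the issue description.
original = {
    "stage": 3,
    "overlap_comm": True,
    "contiguous_gradients": True,
    "sub_group_size": 1e9,
    "reduce_bucket_size": 1e6,
    "stage3_prefetch_bucket_size": 0.94e6,
    "stage3_param_persistence_threshold": 1e4,
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": True,
}

# Values copied from the revised config in this comment.
revised = {
    "stage": 3,
    "overlap_comm": False,
    "contiguous_gradients": True,
    "sub_group_size": 1e9,
    "reduce_bucket_size": 1e6,
    "stage3_prefetch_bucket_size": 5e8,
    "stage3_param_persistence_threshold": 1e5,
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": False,
}

changes = diff_config(original, revised)
for key, (before, after) in changes.items():
    print(f"{key}: {before} -> {after}")
```

Besides "overlap_comm", the diff also surfaces the prefetch bucket size, the persistence threshold, and the weight-gathering flag, so the revised run changes more than one variable at a time.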

I'm not sure if this will solve the issue. The training curves look identical so far, but I can't really tell yet because the deviation only starts after about 1000 steps. I'm attaching a screenshot of the training so far. Please let me know what you think and what else I can try. Thanks! :)

Image

@tjruwase
Contributor

@dawnik17, thanks for the update. We have seen reports of potential bugs in overlap_comm.

The new loss curve looks promising. Hopefully that works and unblocks you.

What about the grad_norm curve? That seemed to show the error quicker.

@dawnik17
Author

dawnik17 commented Dec 24, 2024

@tjruwase Unfortunately, the grad_norm curve has started to show deviation. The training loss curve has started to diverge now as well. As expected, the divergence starts at the same step for both.

Image Image
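
One way to pin down "the same number of steps" from logged values rather than from screenshots is to find the first step at which two curves differ by more than a relative tolerance. This is a rough sketch under stated assumptions: `first_divergence_step` is a hypothetical helper, the 5% tolerance is arbitrary, and the numbers are made up for illustration, not taken from the runs above.

```python
def first_divergence_step(baseline, candidate, rel_tol=0.05):
    """Return the first index where candidate deviates from baseline by more
    than rel_tol (relative to the baseline value), or None if it never does."""
    for step, (ref, val) in enumerate(zip(baseline, candidate)):
        scale = max(abs(ref), 1e-12)  # guard against division by zero
        if abs(val - ref) / scale > rel_tol:
            return step
    return None


# Made-up values for illustration only.
ddp_loss = [2.00, 1.50, 1.20, 1.00, 0.90]
zero3_loss = [2.00, 1.51, 1.21, 1.15, 1.10]
ddp_grad_norm = [1.0, 1.1, 1.2, 1.4, 1.6]
zero3_grad_norm = [1.0, 1.1, 1.2, 1.0, 0.9]

loss_step = first_divergence_step(ddp_loss, zero3_loss)
grad_step = first_divergence_step(ddp_grad_norm, zero3_grad_norm)
print("loss diverges at step:", loss_step)
print("grad_norm diverges at step:", grad_step)
```

If both metrics report the same step on the real logs, that supports the observation that the loss and grad_norm deviations share a single cause.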

@tjruwase
Contributor

@dawnik17, can you share the command line for your run?

@dawnik17
Author

dawnik17 commented Dec 24, 2024

@tjruwase I'm launching the run with the deepspeed command like so:

deepspeed run.py --experiment_name v1 >> run.log

I'm reading the hyperparameters from a YAML file.

@tjruwase
Contributor

@dawnik17, I am unfamiliar with this BERT codebase, so I will need very specific instructions. In particular:

  1. I am unable to find run.py after cloning the repo.
  2. I also noticed that you referred to a different repo for the siglip loss. How are you combining these two repos?
  3. Finally, I don't know how you are reading hyperparameters from the YAML file.

Thanks!

@dawnik17
Author

dawnik17 commented Dec 24, 2024

@tjruwase Let me share the codebase with you in a couple hours

@dawnik17
Author

@tjruwase I have added all the relevant files here - https://github.com/dawnik17/debug_deepspeed/tree/main
I have made a few quick edits to the code to flatten the directory structure which could lead to a few errors, but it should give you the exact context.

@tjruwase
Contributor

  1. What is the correct setting for this?
  2. I notice both fp16 and bf16 are disabled. Can you confirm that this is fp32 training for both ddp and deepspeed?
  3. Can you try setting zero_stage=0 in ds_config?

@dawnik17
Author

dawnik17 commented Dec 24, 2024

@tjruwase
For 1. It's the training dataset path. The training dataset has positive (query, passage) pairs.
For 2. Yes, training in both ddp and deepspeed is using fp32.
For 3. Sure, I'll try zero_stage=0 and keep you posted on how it progresses.

Thanks! :)
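
Points 2 and 3 can be sketched on the config as a plain dictionary. This is a minimal sketch, not taken from the shared repo: `with_zero_stage` and `is_fp32` are hypothetical helpers, and the sample config is a trimmed-down stand-in for the one posted above. It assumes the "enabled" flags are plain booleans, as they are in this thread.

```python
import copy
import json


def with_zero_stage(config: dict, stage: int) -> dict:
    """Return a deep copy of config with zero_optimization.stage set to stage."""
    new_config = copy.deepcopy(config)
    new_config.setdefault("zero_optimization", {})["stage"] = stage
    return new_config


def is_fp32(config: dict) -> bool:
    """True when neither the bf16 nor the fp16 section is enabled."""
    return not any(
        config.get(section, {}).get("enabled", False)
        for section in ("bf16", "fp16")
    )


# Trimmed-down stand-in for the config in this thread.
ds_config = {
    "bf16": {"enabled": False},
    "fp16": {"enabled": False},
    "zero_optimization": {"stage": 3, "overlap_comm": False},
}

stage0_config = with_zero_stage(ds_config, 0)
print("fp32 training:", is_fp32(stage0_config))
print(json.dumps(stage0_config, indent=2))
```

Copying the config instead of editing it in place keeps the ZeRO-3 settings available for a side-by-side rerun. How DeepSpeed treats the stage3_* keys when the stage is 0 is not checked here.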

@dawnik17
Author

dawnik17 commented Dec 24, 2024

@tjruwase An observation: although the loss curves deviate, they are uncannily similar in shape. (The signal pattern in both curves is exactly the same.)

  • The purple is with torch ddp and the dark grey is with zero3.
  • Maybe the issue is just with the grad_norm not increasing.
Image

PS: an update on your 3rd point: zero_stage = 0 follows the same loss curve as ZeRO-3. Also, I get the following message after every step with ZeRO-3: Invalidate trace cache @ step 119 and module 0: cache has only 119 modules
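
The "same signal pattern" observation could be quantified by correlating the step-to-step changes of the two loss curves, which ignores a constant offset between them. This is a rough sketch using only the standard library; `pearson` and `deltas` are hypothetical helpers and the sample values are made up, not read from the plots.

```python
import math


def pearson(xs, ys):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)


def deltas(curve):
    """Step-to-step changes of a curve."""
    return [b - a for a, b in zip(curve, curve[1:])]


# Made-up values: the second curve drifts upward but wiggles in the same way.
ddp_loss = [1.00, 0.90, 0.95, 0.80, 0.85, 0.70]
zero3_loss = [1.00, 0.95, 1.00, 0.90, 0.95, 0.85]

pattern_similarity = pearson(deltas(ddp_loss), deltas(zero3_loss))
print(f"correlation of step-to-step changes: {pattern_similarity:.3f}")
```

A value near 1 on the real logs would be consistent with both runs seeing the same batches in the same order while the update magnitude differs.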
