
DeepSpeed with ZeRO3 strategy cannot build 'fused_adam' #6892

Open
LeonardoZini opened this issue Dec 18, 2024 · 3 comments
Labels
bug Something isn't working training

Comments

@LeonardoZini

LeonardoZini commented Dec 18, 2024

Describe the bug
I am using DeepSpeed with the Hugging Face Trainer to fine-tune an LLM. With the ZeRO-2 strategy I don't have any problems, but I also need to shard the parameters because I'm working with long-context sequences.
When using ZeRO-3, the trainer raises an exception at the beginning of training: RuntimeError: Error building extension 'fused_adam'

I installed DeepSpeed with the command TORCH_CUDA_ARCH_LIST="8.0;8.6;9.0" DS_BUILD_OPS=1 pip install deepspeed --global-option="build_ext" (I also tried version 0.15.4).
I also tried TORCH_CUDA_ARCH_LIST="8.0;8.6;9.0" DS_BUILD_FUSED_ADAM=1 pip install deepspeed --global-option="build_ext".

What changes things is specifying "offload_optimizer" and "offload_param" in the DeepSpeed JSON config file: then no error is thrown, but I lose any reference to the model parameters (the weights are empty tensors).
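For reference, the offload settings mentioned above live under `zero_optimization` in the DeepSpeed config. The sketch below writes them as a Python dict (same structure as the JSON file) plus a simplified, hypothetical helper `optimizer_runs_on_cpu`. The assumption it encodes is that when the optimizer is offloaded to CPU or NVMe, DeepSpeed constructs its CPU Adam (the `cpu_adam` op, compiled with g++ only) instead of the CUDA `fused_adam` op. That would be consistent with the output log further down, where only `cpu_adam` is JIT-built once offloading is enabled.

```python
# Minimal ZeRO-3 section mirroring the keys mentioned above. A real config
# also needs batch size, precision and optimizer settings (omitted here).
ds_config = {
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {"device": "cpu"},
        "offload_param": {"device": "cpu"},
    },
}


def optimizer_runs_on_cpu(config):
    """Hypothetical, simplified check: True when the optimizer is offloaded to
    CPU or NVMe, the case in which a CPU Adam is used instead of fused_adam."""
    zero = config.get("zero_optimization") or {}
    offload = zero.get("offload_optimizer") or {}
    return offload.get("device", "none") in ("cpu", "nvme")
```

Without the `offload_optimizer` block the helper returns False, which corresponds to the path in the traceback below where `FusedAdam` is constructed and the CUDA build is attempted.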

I am using a SLURM scheduler, and one thing I noticed is that the ds_report output differs: outside SLURM fused_adam appears to be installed, while inside SLURM it does not.

pip env

accelerate==0.34.1
aiohappyeyeballs==2.4.4
aiohttp==3.11.10
aiosignal==1.3.2
annotated-types==0.7.0
attrs==24.3.0
bitsandbytes==0.45.0
certifi==2024.12.14
charset-normalizer==3.4.0
click==8.1.7
contourpy==1.3.1
cycler==0.12.1
datasets==3.2.0
deepspeed==0.16.1
dill==0.3.8
docker-pycreds==0.4.0
einops==0.8.0
filelock==3.13.1
flash-attn==2.7.2.post1
fonttools==4.55.3
frozenlist==1.5.0
fsspec==2024.2.0
gitdb==4.0.11
GitPython==3.1.43
hjson==3.1.0
huggingface-hub==0.27.0
idna==3.10
Jinja2==3.1.3
kiwisolver==1.4.7
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.10.0
mdurl==0.1.2
mpmath==1.3.0
msgpack==1.1.0
multidict==6.1.0
multiprocess==0.70.16
networkx==3.2.1
ninja==1.11.1.3
numpy==1.26.3
nvidia-cublas-cu11==11.11.3.6
nvidia-cuda-cupti-cu11==11.8.87
nvidia-cuda-nvrtc-cu11==11.8.89
nvidia-cuda-runtime-cu11==11.8.89
nvidia-cudnn-cu11==9.1.0.70
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.3.0.86
nvidia-cusolver-cu11==11.4.1.48
nvidia-cusparse-cu11==11.7.5.86
nvidia-ml-py==12.560.30
nvidia-nccl-cu11==2.20.5
nvidia-nvtx-cu11==11.8.86
packaging==24.2
pandas==2.2.3
peft==0.14.0
pillow==10.2.0
platformdirs==4.3.6
propcache==0.2.1
protobuf==5.29.1
psutil==6.1.0
py-cpuinfo==9.0.0
pyarrow==18.1.0
pydantic==2.10.3
pydantic_core==2.27.1
Pygments==2.18.0
pyparsing==3.2.0
python-dateutil==2.9.0.post0
pytz==2024.2
PyYAML==6.0.2
regex==2024.11.6
requests==2.32.3
rich==13.9.4
safetensors==0.4.5
sentry-sdk==2.19.2
setproctitle==1.3.4
six==1.17.0
smmap==5.0.1
sympy==1.13.1
tokenizers==0.21.0
torch==2.4.1+cu118
torchaudio==2.4.1+cu118
torchvision==0.19.1+cu118
tqdm==4.67.1
transformers==4.47.0
triton==3.0.0
trl==0.13.0
typing_extensions==4.12.2
tzdata==2024.2
urllib3==2.2.3
wandb==0.19.1
xxhash==3.5.0
yarl==1.18.3

ds_report output

[2024-12-18 16:33:42,017] [WARNING] [real_accelerator.py:174:get_accelerator] Setting accelerator to CPU. If you have GPU or other accelerator, we were unable to detect it.
[2024-12-18 16:33:42,031] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cpu (auto detect)
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
deepspeed_not_implemented  [NO] ....... [OKAY]
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-devel package with yum
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
deepspeed_ccl_comm ..... [NO] ....... [OKAY]
deepspeed_shm_comm ..... [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
fused_adam ............. [YES] ...... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/leonardo/home/userexternal/lzini000/svgen2/lib/python3.11/site-packages/torch']
torch version .................... 2.4.1+cu118
deepspeed install path ........... ['/leonardo/home/userexternal/lzini000/svgen2/lib/python3.11/site-packages/deepspeed']
deepspeed info ................... 0.16.1, unknown, unknown
deepspeed wheel compiled w. ...... torch 2.4 
shared memory (/dev/shm) size .... 251.49 GB

ds_report output in SLURM

[2024-12-18 15:48:04,705] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-devel package with yum
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
fp_quantizer ........... [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
 [WARNING]  gds requires the dev libaio .so object and headers but these were not found.
 [WARNING]  gds: please install the libaio-devel package with yum
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
gds .................... [NO] ....... [NO]
transformer_inference .. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.4
 [WARNING]  using untested triton version (3.0.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/leonardo/home/userexternal/lzini000/svgen2/lib/python3.11/site-packages/torch']
torch version .................... 2.4.1+cu118
deepspeed install path ........... ['/leonardo/home/userexternal/lzini000/svgen2/lib/python3.11/site-packages/deepspeed']
deepspeed info ................... 0.16.1, unknown, unknown
torch cuda version ............... 11.8
torch hip version ................ None
nvcc version ..................... 11.8
deepspeed wheel compiled w. ...... torch 2.4, cuda 11.8
shared memory (/dev/shm) size .... 251.43 GB
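One way to see exactly which ops differ between the two reports is to parse the op tables. A minimal sketch, assuming the `name ..... [INSTALLED] ..... [COMPATIBLE]` line layout shown above, with ANSI color codes stripped first; `parse_op_table` and `diff_reports` are hypothetical helpers, not part of DeepSpeed. Note also that the first report was generated with the accelerator auto-detected as CPU ("Setting accelerator to CPU"), so it lists a CPU-side op set, and its [YES] for fused_adam may not refer to the CUDA build that is needed inside SLURM.

```python
import re

# Terminal color codes such as "\x1b[93m" ... "\x1b[0m".
ANSI_RE = re.compile(r"\x1b\[[0-9;]*m")
# e.g. "fused_adam ............. [NO] ....... [OKAY]"
OP_RE = re.compile(r"^(\w+)[\s.]*\[(\w+)\][\s.]*\[(\w+)\]")


def parse_op_table(report_text):
    """Map op name -> (installed, compatible) from ds_report-style text."""
    ops = {}
    for raw in report_text.splitlines():
        line = ANSI_RE.sub("", raw).strip()
        match = OP_RE.match(line)  # warning and header lines do not match
        if match:
            name, installed, compatible = match.groups()
            ops[name] = (installed == "YES", compatible == "OKAY")
    return ops


def diff_reports(report_a, report_b):
    """Ops whose status differs; a missing op shows up as None on that side."""
    a, b = parse_op_table(report_a), parse_op_table(report_b)
    return {
        name: (a.get(name), b.get(name))
        for name in sorted(set(a) | set(b))
        if a.get(name) != b.get(name)
    }
```

Running `diff_reports` on the two saved outputs lists every op that is present in only one report or whose [YES]/[NO] status changed.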

output log

Loading trainer...
Number of model parameter: 4568002560
Start training...
trainable params: 27,262,976 || all params: 8,057,663,488 || trainable%: 0.3383
trainable params: 1,078,075,392 || all params: 8,057,663,488 || trainable%: 13.3795
[2024-12-18 17:33:04,522] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-12-18 17:33:05,517] [INFO] [comm.py:652:init_distributed] cdb=None
Loading trainer...
Start training...
[1/3] g++ -MMD -MF cpu_adam.o.d -DTORCH_EXTENSION_NAME=cpu_adam -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/leonardo/home/userexternal/lzini000/svgen2/lib/python3.11/site-packages/deepspeed/ops/csrc/includes -isystem /leonardo/home/userexternal/lzini000/svgen2/lib/python3.11/site-packages/torch/include -isystem /leonardo/home/userexternal/lzini000/svgen2/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /leonardo/home/userexternal/lzini000/svgen2/lib/python3.11/site-packages/torch/include/TH -isystem /leonardo/home/userexternal/lzini000/svgen2/lib/python3.11/site-packages/torch/include/THC -isystem /leonardo/home/userexternal/lzini000/.pyenv/versions/3.11.9/include/python3.11 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -O3 -std=c++17 -g -Wno-reorder -L/leonardo/prod/opt/compilers/cuda/11.8/none/lib64 -lcudart -lcublas -g -march=native -fopenmp -D__AVX512__ -D__ENABLE_CUDA__ -DBF16_AVAILABLE -c /leonardo/home/userexternal/lzini000/svgen2/lib/python3.11/site-packages/deepspeed/ops/csrc/adam/cpu_adam.cpp -o cpu_adam.o 
[2/3] g++ -MMD -MF cpu_adam_impl.o.d -DTORCH_EXTENSION_NAME=cpu_adam -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/leonardo/home/userexternal/lzini000/svgen2/lib/python3.11/site-packages/deepspeed/ops/csrc/includes -isystem /leonardo/home/userexternal/lzini000/svgen2/lib/python3.11/site-packages/torch/include -isystem /leonardo/home/userexternal/lzini000/svgen2/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /leonardo/home/userexternal/lzini000/svgen2/lib/python3.11/site-packages/torch/include/TH -isystem /leonardo/home/userexternal/lzini000/svgen2/lib/python3.11/site-packages/torch/include/THC -isystem /leonardo/home/userexternal/lzini000/.pyenv/versions/3.11.9/include/python3.11 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -O3 -std=c++17 -g -Wno-reorder -L/leonardo/prod/opt/compilers/cuda/11.8/none/lib64 -lcudart -lcublas -g -march=native -fopenmp -D__AVX512__ -D__ENABLE_CUDA__ -DBF16_AVAILABLE -c /leonardo/home/userexternal/lzini000/svgen2/lib/python3.11/site-packages/deepspeed/ops/csrc/adam/cpu_adam_impl.cpp -o cpu_adam_impl.o 
[3/3] g++ cpu_adam.o cpu_adam_impl.o -shared -lcurand -L/leonardo/prod/opt/compilers/cuda/11.8/none/lib64 -L/leonardo/home/userexternal/lzini000/svgen2/lib/python3.11/site-packages/torch/lib -lc10 -ltorch_cpu -ltorch -ltorch_python -o cpu_adam.so
Time to load cpu_adam op: 30.90044140815735 seconds
Time to load cpu_adam op: 30.946683883666992 seconds
Parameter Offload: Total persistent parameters: 266240 in 65 params
{'train_runtime': 72.9273, 'train_samples_per_second': 0.11, 'train_steps_per_second': 0.055, 'train_loss': 0.5185592025518417, 'epoch': 1.0}
Training ended
Number of model parameter: 266240

As highlighted in these logs, the number of parameters goes from 4568002560 before the training loop to 266240 after it (the "Parameter Offload" line makes me think this is related).

System info :

  • OS: Linux kernel 4.18.0-425.19.2.el8_7.x86_64
  • GPU count and types: cluster of NVIDIA A100 (nodes of 4 gpus)
  • Python version: 3.11.9

Launcher context
I am launching with torchrun:
srun torchrun --nnodes=1 --nproc-per-node=2 --rdzv-endpoint=$MASTER_ADDR:$MASTER_PORT --rdzv-id=$SLURM_JOB_NAME --rdzv-backend="c10d" --max_restarts=$MAX_RESTARTS trainer.py

@LeonardoZini LeonardoZini added bug Something isn't working training labels Dec 18, 2024
@LeonardoZini LeonardoZini changed the title [BUG] DeepSpeed with ZeRO3 strategy cannot build 'fused_adam' DeepSpeed with ZeRO3 strategy cannot build 'fused_adam' Dec 18, 2024
@tjruwase
Contributor

@LeonardoZini, can you please share the log showing the fused_adam build error message?

As highlighted in these logs, the number of parameters goes from 4568002560 before the training loop to 266240 after it (the "Parameter Offload" line makes me think this is related).

With ZeRO stage 3 model sharding, special handling is required to access parameters. See the following links for more details.

  1. Using deepspeed.zero.GatheredParameters context manager
  2. https://deepspeed.readthedocs.io/en/stable/zero3.html#debugging
  3. https://deepspeed.readthedocs.io/en/stable/zero3.html#modifying-partitioned-states
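Under ZeRO-3 each rank keeps only a shard of every large parameter, so outside of the forward/backward pass the local tensor of a partitioned parameter is empty and `sum(p.numel() ...)` only counts parameters that stay whole on every rank. In the log, the post-training count (266240) equals "Total persistent parameters: 266240 in 65 params", which fits this explanation. A minimal sketch of a count that falls back to DeepSpeed's `ds_numel` attribute (the full, unpartitioned size recorded on partitioned parameters; likely the same fallback behind the 8,057,663,488 figure printed by PEFT in the log). `count_parameters` is a hypothetical helper.

```python
def count_parameters(params):
    """Count parameter elements, using the full size for ZeRO-3-partitioned ones."""
    total = 0
    for p in params:
        n = p.numel()
        # A partitioned parameter has an empty local tensor (numel() == 0);
        # DeepSpeed stores its unpartitioned element count in `ds_numel`.
        if n == 0 and hasattr(p, "ds_numel"):
            n = p.ds_numel
        total += n
    return total
```

Usage would be `count_parameters(model.parameters())`. This only fixes the count; to read or modify actual weight values, the access still has to happen inside the `deepspeed.zero.GatheredParameters` context manager from item 1.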

@LeonardoZini
Author

These are the logs:

Building extension module fused_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
[rank1]: Traceback (most recent call last):
[rank1]:   File "[..]/svgen2/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2105, in _run_ninja_build
[rank1]:     subprocess.run(
[rank1]:   File "[..]/.pyenv/versions/3.11.9/lib/python3.11/subprocess.py", line 571, in run
[rank1]:     raise CalledProcessError(retcode, process.args,
[rank1]: subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.

[rank1]: The above exception was the direct cause of the following exception:

[rank1]: Traceback (most recent call last):
[rank1]:   File "[..]/svgen/Parser/trainer.py", line 163, in <module>
[rank1]:     trainer.train() 
[rank1]:     ^^^^^^^^^^^^^^^
[rank1]:   File "[..]/svgen2/lib/python3.11/site-packages/transformers/trainer.py", line 2164, in train
[rank1]:     return inner_training_loop(
[rank1]:            ^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "[..]/svgen2/lib/python3.11/site-packages/transformers/trainer.py", line 2325, in _inner_training_loop
[rank1]:     model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
[rank1]:                                                ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "[..]/svgen2/lib/python3.11/site-packages/accelerate/accelerator.py", line 1318, in prepare
[rank1]:     result = self._prepare_deepspeed(*args)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "[..]/svgen2/lib/python3.11/site-packages/accelerate/accelerator.py", line 1815, in _prepare_deepspeed
[rank1]:     engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs)
[rank1]:                                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "[..]/svgen2/lib/python3.11/site-packages/deepspeed/__init__.py", line 193, in initialize
[rank1]:     engine = DeepSpeedEngine(args=args,
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "[..]/svgen2/lib/python3.11/site-packages/deepspeed/runtime/engine.py", line 313, in __init__
[rank1]:     self._configure_optimizer(optimizer, model_parameters)
[rank1]:   File "[..]/svgen2/lib/python3.11/site-packages/deepspeed/runtime/engine.py", line 1276, in _configure_optimizer
[rank1]:     basic_optimizer = self._configure_basic_optimizer(model_parameters)
[rank1]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "[..]/svgen2/lib/python3.11/site-packages/deepspeed/runtime/engine.py", line 1353, in _configure_basic_optimizer
[rank1]:     optimizer = FusedAdam(
[rank1]:                 ^^^^^^^^^^
[rank1]:   File "/[..]/svgen2/lib/python3.11/site-packages/deepspeed/ops/adam/fused_adam.py", line 94, in __init__
[rank1]:     fused_adam_cuda = FusedAdamBuilder().load()
[rank1]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "[..]/svgen2/lib/python3.11/site-packages/deepspeed/ops/op_builder/builder.py", line 531, in load
[rank1]:     return self.jit_load(verbose)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "[..]/svgen2/lib/python3.11/site-packages/deepspeed/ops/op_builder/builder.py", line 578, in jit_load
[rank1]:     op_module = load(name=self.name,
[rank1]:                 ^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "[..]/svgen2/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 1312, in load
[rank1]:     return _jit_compile(
[rank1]:            ^^^^^^^^^^^^^
[rank1]:   File "[..]/svgen2/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 1722, in _jit_compile
[rank1]:     _write_ninja_file_and_build_library(
[rank1]:   File "[..]/svgen2/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 1834, in _write_ninja_file_and_build_library
[rank1]:     _run_ninja_build(
[rank1]:   File "[..]/svgen2/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2121, in _run_ninja_build
[rank1]:     raise RuntimeError(message) from e
[rank1]: RuntimeError: Error building extension 'fused_adam'
Loading extension module fused_adam...
[rank0]: Traceback (most recent call last):
[rank0]:   File "[..]/svgen/Parser/trainer.py", line 163, in <module>
[rank0]:     trainer.train() 
[rank0]:     ^^^^^^^^^^^^^^^
[rank0]:   File "[..]/svgen2/lib/python3.11/site-packages/transformers/trainer.py", line 2164, in train
[rank0]:     return inner_training_loop(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "[..]/svgen2/lib/python3.11/site-packages/transformers/trainer.py", line 2325, in _inner_training_loop
[rank0]:     model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
[rank0]:                                                ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "[..]/svgen2/lib/python3.11/site-packages/accelerate/accelerator.py", line 1318, in prepare
[rank0]:     result = self._prepare_deepspeed(*args)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "[..]/svgen2/lib/python3.11/site-packages/accelerate/accelerator.py", line 1815, in _prepare_deepspeed
[rank0]:     engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs)
[rank0]:                                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "[..]/svgen2/lib/python3.11/site-packages/deepspeed/__init__.py", line 193, in initialize
[rank0]:     engine = DeepSpeedEngine(args=args,
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "[..]/svgen2/lib/python3.11/site-packages/deepspeed/runtime/engine.py", line 313, in __init__
[rank0]:     self._configure_optimizer(optimizer, model_parameters)
[rank0]:   File "[..]/svgen2/lib/python3.11/site-packages/deepspeed/runtime/engine.py", line 1276, in _configure_optimizer
[rank0]:     basic_optimizer = self._configure_basic_optimizer(model_parameters)
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "[..]/svgen2/lib/python3.11/site-packages/deepspeed/runtime/engine.py", line 1353, in _configure_basic_optimizer
[rank0]:     optimizer = FusedAdam(
[rank0]:                 ^^^^^^^^^^
[rank0]:   File "[..]/svgen2/lib/python3.11/site-packages/deepspeed/ops/adam/fused_adam.py", line 94, in __init__
[rank0]:     fused_adam_cuda = FusedAdamBuilder().load()
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "[..]/svgen2/lib/python3.11/site-packages/deepspeed/ops/op_builder/builder.py", line 531, in load
[rank0]:     return self.jit_load(verbose)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "[..]/svgen2/lib/python3.11/site-packages/deepspeed/ops/op_builder/builder.py", line 578, in jit_load
[rank0]:     op_module = load(name=self.name,
[rank0]:                 ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "[..]/svgen2/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 1312, in load
[rank0]:     return _jit_compile(
[rank0]:            ^^^^^^^^^^^^^
[rank0]:   File "[..]/svgen2/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 1747, in _jit_compile
[rank0]:     return _import_module_from_library(name, build_directory, is_python_module)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "[..]/svgen2/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2141, in _import_module_from_library
[rank0]:     module = importlib.util.module_from_spec(spec)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "<frozen importlib._bootstrap>", line 573, in module_from_spec
[rank0]:   File "<frozen importlib._bootstrap_external>", line 1233, in create_module
[rank0]:   File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
[rank0]: ImportError: [..]/.cache/torch_extensions/py311_cu118/fused_adam/fused_adam.so: cannot open shared object file: No such file or directory
[rank0]:[W1218 18:06:25.780776374 ProcessGroupNCCL.cpp:1168] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present,  but this warning has only been added since PyTorch 2.4 (function operator())

and this one:

[1/2] /leonardo/prod/opt/compilers/cuda/11.8/none/bin/nvcc --generate-dependencies-with-compile --dependency-output multi_tensor_adam.cuda.o.d -ccbin gcc -DTORCH_EXTENSION_NAME=fused_adam -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I[..]/svgen2/lib/python3.11/site-packages/deepspeed/ops/csrc/includes -I[..]/svgen2/lib/python3.11/site-packages/deepspeed/ops/csrc/adam -isystem [..]/svgen2/lib/python3.11/site-packages/torch/include -isystem [..]/svgen2/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem [..]/svgen2/lib/python3.11/site-packages/torch/include/TH -isystem [..]/svgen2/lib/python3.11/site-packages/torch/include/THC -isystem /leonardo/prod/opt/compilers/cuda/11.8/none/include -isystem [..]/.pyenv/versions/3.11.9/include/python3.11 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 --compiler-options '-fPIC' -O3 -DVERSION_GE_1_1 -DVERSION_GE_1_3 -DVERSION_GE_1_5 -lineinfo --use_fast_math -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_80,code=compute_80 -DBF16_AVAILABLE -U__CUDA_NO_BFLOAT16_OPERATORS__ -U__CUDA_NO_BFLOAT162_OPERATORS__ -U__CUDA_NO_BFLOAT16_CONVERSIONS__ -std=c++17 -c [..]/svgen2/lib/python3.11/site-packages/deepspeed/ops/csrc/adam/multi_tensor_adam.cu -o multi_tensor_adam.cuda.o 
FAILED: multi_tensor_adam.cuda.o 
/leonardo/prod/opt/compilers/cuda/11.8/none/bin/nvcc --generate-dependencies-with-compile --dependency-output multi_tensor_adam.cuda.o.d -ccbin gcc -DTORCH_EXTENSION_NAME=fused_adam -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I[..]/svgen2/lib/python3.11/site-packages/deepspeed/ops/csrc/includes -I[..]/svgen2/lib/python3.11/site-packages/deepspeed/ops/csrc/adam -isystem /leonardo/home/userexternal/lzini000/svgen2/lib/python3.11/site-packages/torch/include -isystem [..]/svgen2/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem [..]/svgen2/lib/python3.11/site-packages/torch/include/TH -isystem [..]/svgen2/lib/python3.11/site-packages/torch/include/THC -isystem /leonardo/prod/opt/compilers/cuda/11.8/none/include -isystem [..]/.pyenv/versions/3.11.9/include/python3.11 -D_GLIBCXX_USE_CXX11_ABI=0 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80 --compiler-options '-fPIC' -O3 -DVERSION_GE_1_1 -DVERSION_GE_1_3 -DVERSION_GE_1_5 -lineinfo --use_fast_math -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_80,code=compute_80 -DBF16_AVAILABLE -U__CUDA_NO_BFLOAT16_OPERATORS__ -U__CUDA_NO_BFLOAT162_OPERATORS__ -U__CUDA_NO_BFLOAT16_CONVERSIONS__ -std=c++17 -c [..]/svgen2/lib/python3.11/site-packages/deepspeed/ops/csrc/adam/multi_tensor_adam.cu -o multi_tensor_adam.cuda.o 
/leonardo/prod/spack/5.2/install/0.21/linux-rhel8-icelake/gcc-8.5.0/gcc-12.2.0-gmhym3kmbzqlpwkzhgab2xsoygsdwxcl/lib/gcc/x86_64-pc-linux-gnu/12.2.0/../../../../include/c++/12.2.0/bits/locale_facets_nonio.tcc: In member function '_InIter std::time_get<_CharT, _InIter>::get(iter_type, iter_type, std::ios_base&, std::ios_base::iostate&, tm*, const char_type*, const char_type*) const':
/leonardo/prod/spack/5.2/install/0.21/linux-rhel8-icelake/gcc-8.5.0/gcc-12.2.0-gmhym3kmbzqlpwkzhgab2xsoygsdwxcl/lib/gcc/x86_64-pc-linux-gnu/12.2.0/../../../../include/c++/12.2.0/bits/locale_facets_nonio.tcc:1477:77: error: invalid type argument of unary '*' (have 'int')
 1477 |       if ((void*)(this->*(&time_get::do_get)) == (void*)(&time_get::do_get))
      |                                                                             ^   
/leonardo/prod/spack/5.2/install/0.21/linux-rhel8-icelake/gcc-8.5.0/gcc-12.2.0-gmhym3kmbzqlpwkzhgab2xsoygsdwxcl/lib/gcc/x86_64-pc-linux-gnu/12.2.0/../../../../include/c++/12.2.0/bits/stl_map.h: In member function 'std::pair<typename std::_Rb_tree<_Key, std::pair<const _Key, _Val>, std::_Select1st<std::pair<const _Key, _Val> >, _Compare, typename __gnu_cxx::__alloc_traits<_Allocator>::rebind<std::pair<const _Key, _Val> >::other>::iterator, bool> std::map<_Key, _Tp, _Compare, _Alloc>::emplace(_Args&& ...)':
/leonardo/prod/spack/5.2/install/0.21/linux-rhel8-icelake/gcc-8.5.0/gcc-12.2.0-gmhym3kmbzqlpwkzhgab2xsoygsdwxcl/lib/gcc/x86_64-pc-linux-gnu/12.2.0/../../../../include/c++/12.2.0/bits/stl_map.h:593:29: error: parameter packs not expanded with '...':
  593 |                 if constexpr (__usable_key<decltype(__a)>)
      |                             ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~                             
/leonardo/prod/spack/5.2/install/0.21/linux-rhel8-icelake/gcc-8.5.0/gcc-12.2.0-gmhym3kmbzqlpwkzhgab2xsoygsdwxcl/lib/gcc/x86_64-pc-linux-gnu/12.2.0/../../../../include/c++/12.2.0/bits/stl_map.h:593:29: note:         '_Args'
ninja: build stopped: subcommand failed.

Thank you for the references!

@tjruwase
Contributor

@LeonardoZini, I noticed a compiler error in your build log
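The compiler diagnostics in the build log are easy to miss among the long nvcc command lines. A small sketch (hypothetical helper, standard `re` only) that extracts just the `file:line:col: error:` lines from a saved build log. Applied to the log above, it reports the two errors in `locale_facets_nonio.tcc` and `stl_map.h`, both headers from the GCC 12.2.0 installation that nvcc 11.8 is invoking as its host compiler (`-ccbin gcc`); checking that GCC version against the host compilers supported by CUDA 11.8 seems a reasonable next step.

```python
import re

# GCC/nvcc diagnostics of the form "path:line:col: error: message".
ERROR_RE = re.compile(
    r"^(?P<path>.+?):(?P<line>\d+):(?P<col>\d+): error: (?P<msg>.*)$"
)


def extract_compiler_errors(log_text):
    """Return (file name, line number, message) for each 'error:' diagnostic."""
    errors = []
    for raw in log_text.splitlines():
        match = ERROR_RE.match(raw.strip())
        if match:
            # Keep only the file name; the spack/site-packages prefixes are long.
            name = match.group("path").rsplit("/", 1)[-1]
            errors.append((name, int(match.group("line")), match.group("msg")))
    return errors
```

"note:" lines and "In member function" context lines are deliberately skipped, so the result lists only the errors that stopped the ninja build.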

Can you try the following to see if it would reproduce the compile-time error?

>>> import torch, deepspeed
>>> from deepspeed.ops.adam.fused_adam import FusedAdam
>>> x = FusedAdam([torch.empty(100)])
