Skip to content

Commit

Permalink
handle torch_compile set to auto (axolotl-ai-cloud#2172) [skip ci]
Browse files Browse the repository at this point in the history
* handle torch_compile set to auto

* update docs [skip ci]

* add tests
  • Loading branch information
winglian authored Dec 17, 2024
1 parent 10cfecf commit 3798229
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 4 deletions.
3 changes: 2 additions & 1 deletion docs/config.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,8 @@ comet_experiment_config: # Dictionary for additional configuration settings, see
output_dir: ./completed-model

# Whether to use torch.compile and which backend to use
torch_compile: # bool
# setting to `auto` will enable torch compile when torch>=2.5.1
torch_compile: # Optional[Union[Literal["auto"], bool]]
torch_compile_backend: # Optional[str]

# Training hyperparameters
Expand Down
4 changes: 2 additions & 2 deletions src/axolotl/utils/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,8 @@ def validate_config(
) = merge_input_args()

if capabilities or env_capabilities:
if (capabilities and not env_capabilities) or (
env_capabilities and not capabilities
if (capabilities and env_capabilities is None) or (
env_capabilities and capabilities is None
):
raise ValueError(
"Both capabilities and env_capabilities must be provided or not provided."
Expand Down
21 changes: 20 additions & 1 deletion src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ class Config:
special_tokens: Optional[SpecialTokensConfig] = None
tokens: Optional[List[str]] = None

torch_compile: Optional[bool] = None
torch_compile: Optional[Union[Literal["auto"], bool]] = None
torch_compile_backend: Optional[str] = None
torch_compile_mode: Optional[
Literal["default", "reduce-overhead", "max-autotune"]
Expand Down Expand Up @@ -1582,3 +1582,22 @@ def check_adopt_torch_version(cls, data):
"ADOPT optimizer is incompatible with torch version < 2.5.1"
)
return data

@model_validator(mode="before")
@classmethod
def check_torch_compile_auto(cls, data):
if data.get("torch_compile") == "auto":
env_capabilities = data.get("env_capabilities", {})
if env_capabilities.get("torch_version"):
if version.parse(
env_capabilities.get("torch_version")
) >= version.parse("2.5.1"):
LOG.info(
"torch.compile is available, setting torch_compile to True"
)
data["torch_compile"] = True
else:
data["torch_compile"] = False
else:
data["torch_compile"] = False
return data
40 changes: 40 additions & 0 deletions tests/patched/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1196,6 +1196,46 @@ def test_torch_version_adopt_req(self, minimal_cfg):
)


class TestTorchCompileValidation(BaseValidation):
"""
test suite for when torch_compile is set to 'auto'
"""

def test_torch_compile_auto(self, minimal_cfg):
cfg = (
DictDefault(
{
"torch_compile": "auto",
}
)
| minimal_cfg
)

env_capabilities = {"torch_version": "2.5.1"}
capabilities = {"bf16": True}
updated_cfg = validate_config(
cfg, capabilities=capabilities, env_capabilities=env_capabilities
)

assert updated_cfg.torch_compile is True

env_capabilities = {"torch_version": "2.4.1"}
capabilities = {"bf16": True}
updated_cfg = validate_config(
cfg, capabilities=capabilities, env_capabilities=env_capabilities
)

assert updated_cfg.torch_compile is False

env_capabilities = {}
capabilities = {"bf16": True}
updated_cfg = validate_config(
cfg, capabilities=capabilities, env_capabilities=env_capabilities
)

assert updated_cfg.torch_compile is False


class TestValidationCheckModelConfig(BaseValidation):
"""
Test the validation for the config when the model config is available
Expand Down

0 comments on commit 3798229

Please sign in to comment.