Skip to content

Commit

Permalink
dist.accelerate component
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Mar 1, 2024
1 parent 42de5b5 commit db744a1
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 8 deletions.
3 changes: 2 additions & 1 deletion dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
accelerate
aiobotocore
ax-platform[mysql]==0.2.3
black==23.3.0
Expand All @@ -20,7 +21,7 @@ protobuf==3.20.3
pyre-extensions
pyre-check
pytest
pytorch-lightning==1.5.10
pytorch-lightning==2.2.0
torch-model-archiver>=0.4.2
torch>=1.10.0
torchmetrics<0.11.0
Expand Down
115 changes: 115 additions & 0 deletions torchx/components/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,3 +350,118 @@ def parse_nnodes(j: str) -> Tuple[int, int, int, str]:
f"Invalid format for -j, usage example: 1:2x4 or 1x4 or 4. Given: {j}"
)
return int(min_nnodes), int(max_nnodes), int(nproc_per_node), nnodes_rep

def accelerate(
*script_args: str,
script: Optional[str] = None,
image: str = torchx.IMAGE,
name: str = "/",
h: Optional[str] = None,
cpu: int = 2,
gpu: int = 0,
memMB: int = 1024,
j: str = "1x2",
env: Optional[Dict[str, str]] = None,
max_retries: int = 0,
main_process_port: int = 29500,
accelerate_args: Optional[List[str]] = None,
mounts: Optional[List[str]] = None,
debug: bool = False,
) -> specs.AppDef:
"""
A component that uses HuggingFace accelerate to launch the job.
Args:
script_args: arguments to the main module
script: script or binary to run within the image
image: image (e.g. docker)
name: job name override in the following format: ``{experimentname}/{runname}`` or ``{experimentname}/`` or ``/{runname}`` or ``{runname}``.
Uses the script or module name if ``{runname}`` not specified.
cpu: number of cpus per replica
gpu: number of gpus per replica
memMB: cpu memory in MB per replica
h: a registered named resource (if specified takes precedence over cpu, gpu, memMB)
j: [{min_nnodes}:]{nnodes}x{nproc_per_node}, for gpu hosts, nproc_per_node must not exceed num gpus
env: environment varibles to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
max_retries: the number of scheduler retries allowed
main_process_port: the port on rank0's host to use for coordinating the workers.
Only takes effect when running multi-node. When running single node, this parameter
is ignored and a random free port is chosen.
mounts: mounts to mount into the worker environment/container (ex. type=<bind/volume>,src=/host,dst=/job[,readonly]).
See scheduler documentation for more info.
debug: whether to run with preset debug flags enabled
"""

# nnodes: number of nodes or minimum nodes for elastic launch
# max_nnodes: maximum number of nodes for elastic launch
# nproc_per_node: number of processes on each node
min_nnodes, max_nnodes, nproc_per_node, nnodes_rep = parse_nnodes(j)

assert min_nnodes == max_nnodes, "accelerate component doesn't support elasticity"

if max_nnodes == 1:
# using port 0 makes elastic chose a free random port which is ok
# for single-node jobs since all workers run under a single agent
# When nnodes is 0 and max_nnodes is 1, it's stil a single node job
# but pending until the resources become available
main_process_ip = "localhost"
else:
# for multi-node, rely on the rank0_env environment variable set by
# the schedulers (see scheduler implementation for the actual env var this maps to)
# some schedulers (e.g. aws batch) make the rank0's ip-addr available on all BUT on rank0
# so default to "localhost" if the env var is not set or is empty
# rdzv_endpoint bash resolves to something to the effect of
# ${TORCHX_RANK0_HOST:=localhost}:29500
# use $$ in the prefix to escape the '$' literal (rather than a string Template substitution argument)
main_process_ip = _noquote(f"$${{{macros.rank0_env}:=localhost}}")

argname = StructuredNameArgument.parse_from(
name=name,
m=None,
script=script,
)

if env is None:
env = {}

if debug:
env.update(_TORCH_DEBUG_FLAGS)

env["TORCHX_TRACKING_EXPERIMENT_NAME"] = argname.experiment_name

cmd = [
"accelerate",
"launch",
f"--main_process_ip",
main_process_ip,
f"--main_process_port={main_process_port}",
f"--num_machines={max_nnodes}",
f"--num_processes={nproc_per_node*max_nnodes}",
f"--machine_rank={macros.replica_id}",
f"--max_restarts={max_retries}",
]
if accelerate_args is not None:
cmd += accelerate_args
cmd += [script]
cmd += script_args

return specs.AppDef(
name=argname.run_name,
roles=[
specs.Role(
name=get_role_name(script, None),
image=image,
min_replicas=min_nnodes,
entrypoint="bash",
num_replicas=int(max_nnodes),
resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h),
args=["-c", _args_join(cmd)],
env=env,
port_map={
"accelerate": main_process_port,
},
max_retries=max_retries,
mounts=specs.parse_mounts(mounts) if mounts else [],
)
],
)
16 changes: 16 additions & 0 deletions torchx/components/integration_tests/component_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,22 @@ def get_app_def(self) -> AppDef:
)


class AccelerateComponentProvider(ComponentProvider):
def get_app_def(self) -> AppDef:
return dist_components.accelerate(
script="torchx/examples/apps/compute_world_size/main.py",
name="accelerate-compute",
image=self._image,
cpu=1,
j="2x2",
max_retries=3,
main_process_port=19501,
env={
"LOGLEVEL": "INFO",
},
)


class ServeComponentProvider(ComponentProvider):
# TODO(aivanou): Remove dryrun and test e2e serve component+app
def get_app_def(self) -> AppDef:
Expand Down
39 changes: 38 additions & 1 deletion torchx/components/test/dist_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
# LICENSE file in the root directory of this source tree.

from torchx.components.component_test_base import ComponentTestCase
from torchx.components.dist import _TORCH_DEBUG_FLAGS, ddp, parse_nnodes, spmd
from torchx.components.dist import (
_TORCH_DEBUG_FLAGS, ddp, parse_nnodes, spmd,
accelerate,
)


class DDPTest(ComponentTestCase):
Expand Down Expand Up @@ -148,3 +151,37 @@ def test_spmd_call_by_module_or_script_with_run_name(self) -> None:
"default-experiment",
appdef.roles[0].env["TORCHX_TRACKING_EXPERIMENT_NAME"],
)

class AccelerateTest(ComponentTestCase):
def test_ddp(self) -> None:
import torchx.components.dist as dist

self.validate(dist, "accelerate")

def test_basic(self) -> None:
app = accelerate(
"--script_arg",
script="foo.py",
j="2x2",
accelerate_args=["--accelerate_arg"],
env={"a": "b"}
)
self.assertEqual(len(app.roles), 1)
role = app.roles[0]
self.assertEqual(role.num_replicas, 2)
args = " ".join(role.args)
self.assertIn("--script_arg", args)
self.assertIn("--accelerate_arg", args)
self.assertIn("a", role.env)

def test_mounts(self) -> None:
app = accelerate(
script="foo.py", mounts=["type=bind", "src=/dst", "dst=/dst", "readonly"]
)
self.assertEqual(len(app.roles[0].mounts), 1)

def test_debug(self) -> None:
app = accelerate(script="foo.py", debug=True)
env = app.roles[0].env
for k, v in _TORCH_DEBUG_FLAGS.items():
self.assertEqual(env[k], v)
4 changes: 2 additions & 2 deletions torchx/examples/apps/lightning/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def __init__(
m.fc.out_features = 200
self.model: ResNet = m

self.train_acc = Accuracy()
self.val_acc = Accuracy()
self.train_acc = Accuracy(task="multiclass", num_classes=1000)
self.val_acc = Accuracy(task="multiclass", num_classes=1000)

# pyre-fixme[14]
def forward(self, x: torch.Tensor) -> torch.Tensor:
Expand Down
8 changes: 4 additions & 4 deletions torchx/examples/apps/lightning/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,18 @@
import time
from typing import Dict

from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.profiler.base import BaseProfiler
from pytorch_lightning.loggers.logger import Logger
from pytorch_lightning.profilers import Profiler


class SimpleLoggingProfiler(BaseProfiler):
class SimpleLoggingProfiler(Profiler):
"""
This profiler records the duration of actions (in seconds) and reports the
mean duration of each action to the specified logger. Reported metrics are
in the format `duration_<event>`.
"""

def __init__(self, logger: LightningLoggerBase) -> None:
def __init__(self, logger: Logger) -> None:
super().__init__()

self.current_actions: Dict[str, float] = {}
Expand Down

0 comments on commit db744a1

Please sign in to comment.