Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dist.accelerate component #838

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
accelerate
aiobotocore
ax-platform[mysql]==0.2.3
black==23.3.0
Expand All @@ -20,7 +21,7 @@ protobuf==3.20.3
pyre-extensions
pyre-check
pytest
pytorch-lightning==1.5.10
pytorch-lightning==2.2.0
torch-model-archiver>=0.4.2
torch>=1.10.0
torchmetrics<0.11.0
Expand Down
115 changes: 115 additions & 0 deletions torchx/components/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,3 +350,118 @@ def parse_nnodes(j: str) -> Tuple[int, int, int, str]:
f"Invalid format for -j, usage example: 1:2x4 or 1x4 or 4. Given: {j}"
)
return int(min_nnodes), int(max_nnodes), int(nproc_per_node), nnodes_rep

def accelerate(
    *script_args: str,
    script: Optional[str] = None,
    image: str = torchx.IMAGE,
    name: str = "/",
    h: Optional[str] = None,
    cpu: int = 2,
    gpu: int = 0,
    memMB: int = 1024,
    j: str = "1x2",
    env: Optional[Dict[str, str]] = None,
    max_retries: int = 0,
    main_process_port: int = 29500,
    accelerate_args: Optional[List[str]] = None,
    mounts: Optional[List[str]] = None,
    debug: bool = False,
) -> specs.AppDef:
    """
    A component that uses HuggingFace accelerate to launch the job.

    The job is run as ``bash -c "accelerate launch <flags> <script> <script_args>"``
    on every replica.

    Args:
        script_args: arguments to the main module
        script: script or binary to run within the image
        image: image (e.g. docker)
        name: job name override in the following format: ``{experimentname}/{runname}`` or ``{experimentname}/`` or ``/{runname}`` or ``{runname}``.
            Uses the script or module name if ``{runname}`` not specified.
        cpu: number of cpus per replica
        gpu: number of gpus per replica
        memMB: cpu memory in MB per replica
        h: a registered named resource (if specified takes precedence over cpu, gpu, memMB)
        j: [{min_nnodes}:]{nnodes}x{nproc_per_node}, for gpu hosts, nproc_per_node must not exceed num gpus.
            Elastic ranges (``min_nnodes != nnodes``) are not supported by this component.
        env: environment variables to be passed to the run (e.g. ENV1=v1,ENV2=v2,ENV3=v3)
        max_retries: the number of scheduler retries allowed
        main_process_port: the port on rank0's host to use for coordinating the workers.
            Always forwarded to ``accelerate launch`` as ``--main_process_port`` and
            registered in the role's port map under ``accelerate``.
        accelerate_args: extra arguments appended to ``accelerate launch``, placed before the script
        mounts: mounts to mount into the worker environment/container (ex. type=<bind/volume>,src=/host,dst=/job[,readonly]).
            See scheduler documentation for more info.
        debug: whether to run with preset debug flags enabled
    """

    # min_nnodes: minimum number of nodes (equals max_nnodes for non-elastic jobs)
    # max_nnodes: number of nodes to launch
    # nproc_per_node: number of processes on each node
    min_nnodes, max_nnodes, nproc_per_node, _ = parse_nnodes(j)

    assert min_nnodes == max_nnodes, "accelerate component doesn't support elasticity"

    if max_nnodes == 1:
        # single-node job: the main process is addressed on the local host
        main_process_ip = "localhost"
    else:
        # for multi-node, rely on the rank0_env environment variable set by
        # the schedulers (see scheduler implementation for the actual env var this maps to)
        # some schedulers (e.g. aws batch) make the rank0's ip-addr available on all BUT on rank0
        # so default to "localhost" if the env var is not set or is empty
        # in bash, main_process_ip resolves to something to the effect of
        #   ${TORCHX_RANK0_HOST:=localhost}
        # use $$ in the prefix to escape the '$' literal (rather than a string Template substitution argument)
        main_process_ip = _noquote(f"$${{{macros.rank0_env}:=localhost}}")

    argname = StructuredNameArgument.parse_from(
        name=name,
        m=None,
        script=script,
    )

    # ``script`` is optional in the signature but is placed verbatim on the
    # command line below, so fail with a clear message instead of emitting None.
    if not script:
        raise ValueError("accelerate component requires `script` to be specified")

    # work on a copy so that the dict supplied by the caller is never mutated
    env = dict(env) if env is not None else {}

    if debug:
        env.update(_TORCH_DEBUG_FLAGS)

    env["TORCHX_TRACKING_EXPERIMENT_NAME"] = argname.experiment_name

    cmd = [
        "accelerate",
        "launch",
        "--main_process_ip",
        main_process_ip,
        f"--main_process_port={main_process_port}",
        f"--num_machines={max_nnodes}",
        # accelerate expects the total process count across all machines
        f"--num_processes={nproc_per_node*max_nnodes}",
        f"--machine_rank={macros.replica_id}",
        f"--max_restarts={max_retries}",
    ]
    if accelerate_args is not None:
        cmd += accelerate_args
    cmd += [script, *script_args]

    return specs.AppDef(
        name=argname.run_name,
        roles=[
            specs.Role(
                name=get_role_name(script, None),
                image=image,
                min_replicas=min_nnodes,
                entrypoint="bash",
                num_replicas=int(max_nnodes),
                resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h),
                args=["-c", _args_join(cmd)],
                env=env,
                port_map={
                    "accelerate": main_process_port,
                },
                max_retries=max_retries,
                mounts=specs.parse_mounts(mounts) if mounts else [],
            )
        ],
    )
16 changes: 16 additions & 0 deletions torchx/components/integration_tests/component_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,22 @@ def get_app_def(self) -> AppDef:
)


class AccelerateComponentProvider(ComponentProvider):
    """Supplies an app definition built with the ``dist.accelerate`` component."""

    def get_app_def(self) -> AppDef:
        """Return a ``2x2`` accelerate job that runs the compute_world_size example."""
        example_script = "torchx/examples/apps/compute_world_size/main.py"
        job_env = {"LOGLEVEL": "INFO"}
        app_def = dist_components.accelerate(
            script=example_script,
            name="accelerate-compute",
            image=self._image,
            cpu=1,
            j="2x2",
            max_retries=3,
            main_process_port=19501,
            env=job_env,
        )
        return app_def


class ServeComponentProvider(ComponentProvider):
# TODO(aivanou): Remove dryrun and test e2e serve component+app
def get_app_def(self) -> AppDef:
Expand Down
39 changes: 38 additions & 1 deletion torchx/components/test/dist_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
# LICENSE file in the root directory of this source tree.

from torchx.components.component_test_base import ComponentTestCase
from torchx.components.dist import _TORCH_DEBUG_FLAGS, ddp, parse_nnodes, spmd
from torchx.components.dist import (
_TORCH_DEBUG_FLAGS, ddp, parse_nnodes, spmd,
accelerate,
)


class DDPTest(ComponentTestCase):
Expand Down Expand Up @@ -148,3 +151,37 @@ def test_spmd_call_by_module_or_script_with_run_name(self) -> None:
"default-experiment",
appdef.roles[0].env["TORCHX_TRACKING_EXPERIMENT_NAME"],
)

class AccelerateTest(ComponentTestCase):
    """Unit tests for the ``accelerate`` component of ``torchx.components.dist``."""

    def test_ddp(self) -> None:
        """Run the shared component validation against ``dist.accelerate``."""
        import torchx.components.dist as dist

        self.validate(dist, "accelerate")

    def test_basic(self) -> None:
        """Script args, accelerate args and env entries all reach the single role."""
        appdef = accelerate(
            "--script_arg",
            script="foo.py",
            j="2x2",
            accelerate_args=["--accelerate_arg"],
            env={"a": "b"},
        )
        roles = appdef.roles
        self.assertEqual(len(roles), 1)

        only_role = roles[0]
        self.assertEqual(only_role.num_replicas, 2)

        # Flatten the role args so flags can be checked by substring.
        joined_args = " ".join(only_role.args)
        for expected_flag in ("--script_arg", "--accelerate_arg"):
            self.assertIn(expected_flag, joined_args)
        self.assertIn("a", only_role.env)

    def test_mounts(self) -> None:
        """The mount option list is parsed into exactly one mount on the role."""
        mount_opts = ["type=bind", "src=/dst", "dst=/dst", "readonly"]
        appdef = accelerate(script="foo.py", mounts=mount_opts)
        self.assertEqual(len(appdef.roles[0].mounts), 1)

    def test_debug(self) -> None:
        """``debug=True`` puts every preset debug flag into the role env."""
        role_env = accelerate(script="foo.py", debug=True).roles[0].env
        for flag_name, flag_value in _TORCH_DEBUG_FLAGS.items():
            self.assertEqual(role_env[flag_name], flag_value)
4 changes: 2 additions & 2 deletions torchx/examples/apps/lightning/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def __init__(
m.fc.out_features = 200
self.model: ResNet = m

self.train_acc = Accuracy()
self.val_acc = Accuracy()
self.train_acc = Accuracy(task="multiclass", num_classes=1000)
self.val_acc = Accuracy(task="multiclass", num_classes=1000)

# pyre-fixme[14]
def forward(self, x: torch.Tensor) -> torch.Tensor:
Expand Down
8 changes: 4 additions & 4 deletions torchx/examples/apps/lightning/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,18 @@
import time
from typing import Dict

from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.profiler.base import BaseProfiler
from pytorch_lightning.loggers.logger import Logger
from pytorch_lightning.profilers import Profiler


class SimpleLoggingProfiler(BaseProfiler):
class SimpleLoggingProfiler(Profiler):
"""
This profiler records the duration of actions (in seconds) and reports the
mean duration of each action to the specified logger. Reported metrics are
in the format `duration_<event>`.
"""

def __init__(self, logger: LightningLoggerBase) -> None:
def __init__(self, logger: Logger) -> None:
super().__init__()

self.current_actions: Dict[str, float] = {}
Expand Down