Skip to content

Commit

Permalink
Add support for custom tpu topology flag
Browse files Browse the repository at this point in the history
  • Loading branch information
Obliviour committed Jan 17, 2024
1 parent 4909767 commit 3cf2991
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 3 deletions.
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,10 @@ workload.

# More advanced facts:

* Workload create accepts a --env-file flag to allow specifying the container's
* `xpk cluster create` accepts a `--tpu-topology` flag to allow for custom TPU topologies.
See https://cloud.google.com/kubernetes-engine/docs/concepts/tpus#topology for more details.

* `xpk workload create` accepts a `--env-file` flag to allow specifying the container's
environment from a file. Usage is the same as Docker's
[--env-file flag](https://docs.docker.com/engine/reference/commandline/run/#env)

Expand All @@ -383,7 +386,7 @@ environment from a file. Usage is the same as Docker's
MY_ENV_VAR=hello
```

* Workload create accepts a --debug-dump-gcs flag which is a path to GCS bucket.
* `xpk workload create` accepts a `--debug-dump-gcs` flag, which is a path to a GCS bucket.
Passing this flag sets the XLA_FLAGS='--xla_dump_to=/tmp/xla_dump/' and uploads
hlo dumps to the specified GCS bucket for each worker.

Expand Down
57 changes: 56 additions & 1 deletion xpk.py
Original file line number Diff line number Diff line change
Expand Up @@ -1318,7 +1318,11 @@ def run_gke_node_pool_create_command(args, system) -> int:
f' {args.custom_tpu_nodepool_arguments}'
)
if system.accelerator_type == AcceleratorType['TPU']:
command += (f' --tpu-topology={system.topology}')
tpu_topology, return_code = get_tpu_topology(system, args)
if return_code > 0:
xpk_print('Parsing tpu_topology failed!')
return return_code
command += (f' --tpu-topology={tpu_topology}')
elif system.accelerator_type == AcceleratorType['GPU']:
command += f' --accelerator type={system.gke_accelerator},count={str(system.chips_per_vm)}'
task = f'NodepoolCreate-{node_pool_name}'
Expand Down Expand Up @@ -1556,6 +1560,35 @@ def default_subcommand_function(_args) -> int: # args is unused, so pylint: dis
return 0


def get_tpu_topology(system: SystemCharacteristics, args) -> tuple[str, int]:
  """Resolve the TPU topology to use for node pool creation.

  A topology supplied through `args.tpu_topology` takes precedence over the
  default topology recorded on `system`.

  Args:
    system: SystemCharacteristics for the device type.
    args: User provided arguments for running commands.

  Returns:
    A tuple of:
    str: topology to use (None when the accelerator is not a TPU).
    int: 0 if successful and 1 otherwise.
  """
  # Guard clause: a topology is only meaningful for TPU accelerators.
  if system.accelerator_type != AcceleratorType['TPU']:
    xpk_print(
        'tpu_topology argument is only supported when the AcceleratorType is TPU.'
        f' The AcceleratorType you are using is: {system.accelerator_type}'
    )
    return None, 1

  default_topology = system.topology
  custom_topology = args.tpu_topology
  if custom_topology is None:
    # No override given on the command line; use the system's own topology.
    return default_topology, 0

  xpk_print(
      f'Using custom tpu topology of {custom_topology} for {system.device_type}'
      ' in node pool creation.'
  )
  return custom_topology, 0


def cluster_create(args) -> int:
"""Function around cluster creation.
Expand Down Expand Up @@ -2518,6 +2551,18 @@ def directory_path_type(value):
return value


def tpu_topology_type(value, pat=re.compile(r'^[0-9]+(x[0-9]+){1,2}$')):
  """Argparse `type=` validator for the `--tpu-topology` flag.

  A valid topology has two or three dimensions, each written with ASCII
  digits only and separated by a lowercase `x`, for example `2x4` or `4x4x8`.
  Dimensions may not contain `.` or `*`, so malformed values such as `2x*`
  are rejected here rather than being passed on to node pool creation.

  Args:
    value: Raw string supplied on the command line.
    pat: Compiled pattern that `value` must fully match.

  Returns:
    `value` unchanged when it matches `pat`.

  Raises:
    argparse.ArgumentTypeError: If `value` does not match `pat`.
  """
  if pat.fullmatch(value) is None:
    raise argparse.ArgumentTypeError(
        f'Custom TPU Topology must match the pattern `{pat.pattern}` such as 1x2x3'
        f' or 10x10. TPU Topology set through `--tpu-topology` is currently {value}.'
        ' See https://cloud.google.com/kubernetes-engine/docs/concepts/tpus#topology'
        ' for more details.'
    )
  return value


#### "cluster" command parser. ####
cluster_parser = xpk_subcommands.add_parser(
'cluster',
Expand Down Expand Up @@ -2604,6 +2649,16 @@ def directory_path_type(value):
)

### Optional Arguments
# Optional override for the TPU slice topology. The value is validated by
# tpu_topology_type; when the flag is omitted the default stays None.
cluster_create_optional_arguments.add_argument(
    '--tpu-topology',
    type=tpu_topology_type,
    default=None,
    help=(
        'The slice topology to create the TPU slice with. This only supports'
        ' TPUs. By default, tpu node pool creation will use the tpu-topology'
        ' defined in the SystemCharacteristics within xpk code.'
    ),
)
cluster_create_optional_arguments.add_argument(
'--host-maintenance-interval',
type=str,
Expand Down

0 comments on commit 3cf2991

Please sign in to comment.