-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update tutorials code in NeuronSDK 2.21 release (#41)
1. Fix typos 2. Add tutorial code for Single program, multiple data tensor addition using multiple Neuron Cores
- Loading branch information
1 parent
9919484
commit c72a66b
Showing
7 changed files
with
293 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
34 changes: 34 additions & 0 deletions
34
src/nki_samples/tutorials/tensor_addition/spmd_multiple_nc_tensor_addition_jax.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
""" | ||
Copyright (C) 2024, Amazon.com. All Rights Reserved | ||
JAX implementation for SPMD tensor addition with multiple Neuron cores NKI tutorial. | ||
""" | ||
# NKI_EXAMPLE_50_BEGIN | ||
import jax | ||
import jax.numpy as jnp | ||
# NKI_EXAMPLE_50_END | ||
|
||
from spmd_multiple_nc_tensor_addition_nki_kernels import nki_tensor_add_nc2 | ||
|
||
# NKI_EXAMPLE_50_BEGIN | ||
if __name__ == "__main__": | ||
|
||
seed_a, seed_b = jax.random.split(jax.random.PRNGKey(42)) | ||
a = jax.random.uniform(seed_a, (512, 2048), dtype=jnp.bfloat16) | ||
b = jax.random.uniform(seed_b, (512, 2048), dtype=jnp.bfloat16) | ||
|
||
output_nki = nki_tensor_add_nc2(a, b) | ||
print(f"output_nki={output_nki}") | ||
|
||
output_jax = a + b | ||
print(f"output_jax={output_jax}") | ||
|
||
allclose = jnp.allclose(output_jax, output_nki, atol=1e-4, rtol=1e-2) | ||
if allclose: | ||
print("NKI and JAX match") | ||
else: | ||
print("NKI and JAX differ") | ||
|
||
assert allclose | ||
# NKI_EXAMPLE_50_END |
61 changes: 61 additions & 0 deletions
61
src/nki_samples/tutorials/tensor_addition/spmd_multiple_nc_tensor_addition_nki_kernels.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
""" | ||
Copyright (C) 2024, Amazon.com. All Rights Reserved | ||
NKI implementation for SPMD tensor addition with multiple Neuron cores NKI tutorial. | ||
""" | ||
import numpy as np | ||
import neuronxcc.nki as nki | ||
import neuronxcc.nki.language as nl | ||
from spmd_tensor_addition_nki_kernels import nki_tensor_add_kernel_ | ||
|
||
|
||
# NKI_EXAMPLE_48_BEGIN | ||
def nki_tensor_add_nc2(a_input, b_input): | ||
"""NKI kernel caller to compute element-wise addition of two input tensors using multiple Neuron cores. | ||
This kernel caller lifts tile-size restriction, by applying the kernel on tiles of the inputs/outputs. | ||
a_input and b_input are sharded across Neuron cores, directly utilizing Trn2 architecture capabilities | ||
Args: | ||
a_input: a first input tensor, of shape [N*128, M*512] | ||
b_input: a second input tensor, of shape [N*128, M*512] | ||
Returns: | ||
a tensor of shape [N*128, M*512], the result of a_input + b_input | ||
""" | ||
|
||
# The SPMD launch grid denotes the number of kernel instances. | ||
# In this case, we use a 2D grid where the size of each invocation is 128x512 | ||
# Since we're sharding across neuron cores on the 1st dimension we want to do our slicing at | ||
# 128 per core * 2 cores = 256 | ||
grid_x = a_input.shape[0] // (128 * 2) | ||
grid_y = a_input.shape[1] // 512 | ||
|
||
# In addition, we distribute the kernel to physical neuron cores around the first dimension | ||
# of the spmd grid. | ||
# This means: | ||
# Physical NC [0]: kernel[n, m] where n is even | ||
# Physical NC [1]: kernel[n, m] where n is odd | ||
# notice, by specifying this information in the SPMD grid, we can use multiple neuron cores | ||
# without updating the original `nki_tensor_add_kernel_` kernel. | ||
return nki_tensor_add_kernel_[nl.spmd_dim(grid_x, nl.nc(2)), grid_y](a_input, b_input) | ||
# NKI_EXAMPLE_48_END | ||
|
||
if __name__ == "__main__": | ||
a = np.random.rand(512, 2048).astype(np.float16) | ||
b = np.random.rand(512, 2048).astype(np.float16) | ||
|
||
output_nki = nki_tensor_add_nc2(a, b) | ||
print(f"output_nki={output_nki}") | ||
|
||
output_np = a + b | ||
print(f"output_np={output_np}") | ||
|
||
allclose = np.allclose(output_np, output_nki, atol=1e-4, rtol=1e-2) | ||
if allclose: | ||
print("NKI and NumPy match") | ||
else: | ||
print("NKI and NumPy differ") | ||
|
||
assert allclose |
35 changes: 35 additions & 0 deletions
35
src/nki_samples/tutorials/tensor_addition/spmd_multiple_nc_tensor_addition_torch.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
""" | ||
Copyright (C) 2024, Amazon.com. All Rights Reserved | ||
PyTorch implementation for SPMD tensor addition with multiple Neuron cores NKI tutorial. | ||
""" | ||
# NKI_EXAMPLE_49_BEGIN | ||
import torch | ||
from torch_xla.core import xla_model as xm | ||
# NKI_EXAMPLE_49_END | ||
|
||
from spmd_multiple_nc_tensor_addition_nki_kernels import nki_tensor_add_nc2 | ||
|
||
|
||
# NKI_EXAMPLE_49_BEGIN | ||
if __name__ == "__main__": | ||
device = xm.xla_device() | ||
|
||
a = torch.rand((512, 2048), dtype=torch.bfloat16).to(device=device) | ||
b = torch.rand((512, 2048), dtype=torch.bfloat16).to(device=device) | ||
|
||
output_nki = nki_tensor_add_nc2(a, b) | ||
print(f"output_nki={output_nki}") | ||
|
||
output_torch = a + b | ||
print(f"output_torch={output_torch}") | ||
|
||
allclose = torch.allclose(output_torch, output_nki, atol=1e-4, rtol=1e-2) | ||
if allclose: | ||
print("NKI and Torch match") | ||
else: | ||
print("NKI and Torch differ") | ||
|
||
assert allclose | ||
# NKI_EXAMPLE_49_END |
34 changes: 34 additions & 0 deletions
34
src/nki_samples/tutorials/tensor_addition/spmd_tensor_addition_jax.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
""" | ||
Copyright (C) 2024, Amazon.com. All Rights Reserved | ||
JAX implementation for SPMD tensor addition NKI tutorial. | ||
""" | ||
# NKI_EXAMPLE_30_BEGIN | ||
import jax | ||
import jax.numpy as jnp | ||
# NKI_EXAMPLE_30_END | ||
|
||
from spmd_tensor_addition_nki_kernels import nki_tensor_add | ||
|
||
# NKI_EXAMPLE_30_BEGIN | ||
if __name__ == "__main__": | ||
|
||
seed_a, seed_b = jax.random.split(jax.random.PRNGKey(42)) | ||
a = jax.random.uniform(seed_a, (256, 1024), dtype=jnp.bfloat16) | ||
b = jax.random.uniform(seed_b, (256, 1024), dtype=jnp.bfloat16) | ||
|
||
output_nki = nki_tensor_add(a, b) | ||
print(f"output_nki={output_nki}") | ||
|
||
output_jax = a + b | ||
print(f"output_jax={output_jax}") | ||
|
||
allclose = jnp.allclose(output_jax, output_nki, atol=1e-4, rtol=1e-2) | ||
if allclose: | ||
print("NKI and JAX match") | ||
else: | ||
print("NKI and JAX differ") | ||
|
||
assert allclose | ||
# NKI_EXAMPLE_30_END |
91 changes: 91 additions & 0 deletions
91
src/nki_samples/tutorials/tensor_addition/spmd_tensor_addition_nki_kernels.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
""" | ||
Copyright (C) 2024, Amazon.com. All Rights Reserved | ||
NKI implementation for SPMD tensor addition NKI tutorial. | ||
""" | ||
import numpy as np | ||
# NKI_EXAMPLE_27_BEGIN | ||
import neuronxcc.nki as nki | ||
import neuronxcc.nki.language as nl | ||
|
||
|
||
@nki.jit | ||
def nki_tensor_add_kernel_(a_input, b_input): | ||
"""NKI kernel to compute element-wise addition of two input tensors | ||
This kernel assumes strict input/output sizes can be uniformly tiled to [128,512] | ||
Args: | ||
a_input: a first input tensor | ||
b_input: a second input tensor | ||
Returns: | ||
c_output: an output tensor | ||
""" | ||
# Create output tensor shared between all SPMD instances as result tensor | ||
c_output = nl.ndarray(a_input.shape, dtype=a_input.dtype, buffer=nl.shared_hbm) | ||
|
||
# Calculate tile offsets based on current 'program' | ||
offset_i_x = nl.program_id(0) * 128 | ||
offset_i_y = nl.program_id(1) * 512 | ||
|
||
# Generate tensor indices to index tensors a and b | ||
ix = offset_i_x + nl.arange(128)[:, None] | ||
iy = offset_i_y + nl.arange(512)[None, :] | ||
|
||
# Load input data from device memory (HBM) to on-chip memory (SBUF) | ||
# We refer to an indexed portion of a tensor as an intermediate tensor | ||
a_tile = nl.load(a_input[ix, iy]) | ||
b_tile = nl.load(b_input[ix, iy]) | ||
|
||
# compute a + b | ||
c_tile = a_tile + b_tile | ||
|
||
# store the addition results back to device memory (c_output) | ||
nl.store(c_output[ix, iy], value=c_tile) | ||
|
||
# Transfer the ownership of `c_output` to the caller | ||
return c_output | ||
# NKI_EXAMPLE_27_END | ||
|
||
|
||
# NKI_EXAMPLE_28_BEGIN | ||
def nki_tensor_add(a_input, b_input): | ||
"""NKI kernel caller to compute element-wise addition of two input tensors | ||
This kernel caller lifts tile-size restriction, by applying the kernel on tiles of the inputs/outputs | ||
Args: | ||
a_input: a first input tensor, of shape [N*128, M*512] | ||
b_input: a second input tensor, of shape [N*128, M*512] | ||
Returns: | ||
a tensor of shape [N*128, M*512], the result of a_input + b_input | ||
""" | ||
|
||
# The SPMD launch grid denotes the number of kernel instances. | ||
# In this case, we use a 2D grid where the size of each invocation is 128x512 | ||
grid_x = a_input.shape[0] // 128 | ||
grid_y = a_input.shape[1] // 512 | ||
|
||
return nki_tensor_add_kernel_[grid_x, grid_y](a_input, b_input) | ||
# NKI_EXAMPLE_28_END | ||
|
||
if __name__ == "__main__": | ||
a = np.random.rand(256, 1024).astype(np.float16) | ||
b = np.random.rand(256, 1024).astype(np.float16) | ||
|
||
output_nki = nki_tensor_add(a, b) | ||
print(f"output_nki={output_nki}") | ||
|
||
output_np = a + b | ||
print(f"output_np={output_np}") | ||
|
||
allclose = np.allclose(output_np, output_nki, atol=1e-4, rtol=1e-2) | ||
if allclose: | ||
print("NKI and NumPy match") | ||
else: | ||
print("NKI and NumPy differ") | ||
|
||
assert allclose |
35 changes: 35 additions & 0 deletions
35
src/nki_samples/tutorials/tensor_addition/spmd_tensor_addition_torch.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
""" | ||
Copyright (C) 2024, Amazon.com. All Rights Reserved | ||
PyTorch implementation for SPMD tensor addition NKI tutorial. | ||
""" | ||
# NKI_EXAMPLE_29_BEGIN | ||
import torch | ||
from torch_xla.core import xla_model as xm | ||
# NKI_EXAMPLE_29_END | ||
|
||
from spmd_tensor_addition_nki_kernels import nki_tensor_add | ||
|
||
|
||
# NKI_EXAMPLE_29_BEGIN | ||
if __name__ == "__main__": | ||
device = xm.xla_device() | ||
|
||
a = torch.rand((256, 1024), dtype=torch.bfloat16).to(device=device) | ||
b = torch.rand((256, 1024), dtype=torch.bfloat16).to(device=device) | ||
|
||
output_nki = nki_tensor_add(a, b) | ||
print(f"output_nki={output_nki}") | ||
|
||
output_torch = a + b | ||
print(f"output_torch={output_torch}") | ||
|
||
allclose = torch.allclose(output_torch, output_nki, atol=1e-4, rtol=1e-2) | ||
if allclose: | ||
print("NKI and Torch match") | ||
else: | ||
print("NKI and Torch differ") | ||
|
||
assert allclose | ||
# NKI_EXAMPLE_29_END |