Update tutorials code in NeuronSDK 2.21 release (#41)
1. Fix typos
2. Add tutorial code for Single program, multiple data tensor addition using multiple Neuron Cores
aws-qieqingy authored Jan 7, 2025
1 parent 9919484 commit c72a66b
Showing 7 changed files with 293 additions and 3 deletions.
@@ -15,7 +15,7 @@
 # NKI_EXAMPLE_31_BEGIN
 @nki.jit
 def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask=False,
-                                           mixed_percision=True):
+                                           mixed_precision=True):
   """
   Fused self attention kernel for small head dimension Stable Diffusion workload,
   simplified for this tutorial.
@@ -38,14 +38,14 @@ def fused_self_attn_for_SD_small_head_size(q_ref, k_ref, v_ref, use_causal_mask=
   IO tensor dtypes:
    - This kernel assumes all IO tensors have the same dtype
-   - If mixed_percision is True, then all Tensor Engine operation will be performed in
+   - If mixed_precision is True, then all Tensor Engine operations will be performed in
      bfloat16 and accumulation will be performed in float32. Otherwise the intermediates
      will be in the same type as the inputs.
   """
   # Use q_ref dtype as the intermediate tensor dtype
   # Assume all IO tensors have the same dtype
   kernel_dtype = q_ref.dtype
-  pe_in_dt = nl.bfloat16 if mixed_percision else np.float32
+  pe_in_dt = nl.bfloat16 if mixed_precision else np.float32
   assert q_ref.dtype == k_ref.dtype == v_ref.dtype

   # Shape checking
@@ -0,0 +1,34 @@
"""
Copyright (C) 2024, Amazon.com. All Rights Reserved
JAX implementation for SPMD tensor addition with multiple Neuron cores NKI tutorial.
"""
# NKI_EXAMPLE_50_BEGIN
import jax
import jax.numpy as jnp
# NKI_EXAMPLE_50_END

from spmd_multiple_nc_tensor_addition_nki_kernels import nki_tensor_add_nc2

# NKI_EXAMPLE_50_BEGIN
if __name__ == "__main__":

  seed_a, seed_b = jax.random.split(jax.random.PRNGKey(42))
  a = jax.random.uniform(seed_a, (512, 2048), dtype=jnp.bfloat16)
  b = jax.random.uniform(seed_b, (512, 2048), dtype=jnp.bfloat16)

  output_nki = nki_tensor_add_nc2(a, b)
  print(f"output_nki={output_nki}")

  output_jax = a + b
  print(f"output_jax={output_jax}")

  allclose = jnp.allclose(output_jax, output_nki, atol=1e-4, rtol=1e-2)
  if allclose:
    print("NKI and JAX match")
  else:
    print("NKI and JAX differ")

  assert allclose
# NKI_EXAMPLE_50_END
@@ -0,0 +1,61 @@
"""
Copyright (C) 2024, Amazon.com. All Rights Reserved
NKI implementation for SPMD tensor addition with multiple Neuron cores NKI tutorial.
"""
import numpy as np
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
from spmd_tensor_addition_nki_kernels import nki_tensor_add_kernel_


# NKI_EXAMPLE_48_BEGIN
def nki_tensor_add_nc2(a_input, b_input):
"""NKI kernel caller to compute element-wise addition of two input tensors using multiple Neuron cores.
This kernel caller lifts tile-size restriction, by applying the kernel on tiles of the inputs/outputs.
a_input and b_input are sharded across Neuron cores, directly utilizing Trn2 architecture capabilities
Args:
a_input: a first input tensor, of shape [N*128, M*512]
b_input: a second input tensor, of shape [N*128, M*512]
Returns:
a tensor of shape [N*128, M*512], the result of a_input + b_input
"""

# The SPMD launch grid denotes the number of kernel instances.
# In this case, we use a 2D grid where the size of each invocation is 128x512
# Since we're sharding across neuron cores on the 1st dimension we want to do our slicing at
# 128 per core * 2 cores = 256
grid_x = a_input.shape[0] // (128 * 2)
grid_y = a_input.shape[1] // 512

# In addition, we distribute the kernel to physical neuron cores around the first dimension
# of the spmd grid.
# This means:
# Physical NC [0]: kernel[n, m] where n is even
# Physical NC [1]: kernel[n, m] where n is odd
# notice, by specifying this information in the SPMD grid, we can use multiple neuron cores
# without updating the original `nki_tensor_add_kernel_` kernel.
return nki_tensor_add_kernel_[nl.spmd_dim(grid_x, nl.nc(2)), grid_y](a_input, b_input)
# NKI_EXAMPLE_48_END
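
# A minimal NumPy sketch of what the SPMD launch above does, assuming (as the
# comments in the caller describe) that `nl.spmd_dim(grid_x, nl.nc(2))` expands
# the first grid dimension to size grid_x * 2 and runs instance n on physical
# NC [n % 2]. Illustrative only; it assumes a.shape is a multiple of (256, 512).
def _emulate_nki_tensor_add_nc2(a, b):
  grid_x = a.shape[0] // (128 * 2)
  grid_y = a.shape[1] // 512
  c = np.empty_like(a)
  for n in range(grid_x * 2):    # logical grid dim 0, i.e. program_id(0)
    for m in range(grid_y):      # logical grid dim 1, i.e. program_id(1)
      # instance (n, m) would run on physical NC [n % 2]
      rows = slice(n * 128, (n + 1) * 128)
      cols = slice(m * 512, (m + 1) * 512)
      c[rows, cols] = a[rows, cols] + b[rows, cols]  # one 128x512 tile per instance
  return c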

if __name__ == "__main__":
  a = np.random.rand(512, 2048).astype(np.float16)
  b = np.random.rand(512, 2048).astype(np.float16)

  output_nki = nki_tensor_add_nc2(a, b)
  print(f"output_nki={output_nki}")

  output_np = a + b
  print(f"output_np={output_np}")

  allclose = np.allclose(output_np, output_nki, atol=1e-4, rtol=1e-2)
  if allclose:
    print("NKI and NumPy match")
  else:
    print("NKI and NumPy differ")

  assert allclose
@@ -0,0 +1,35 @@
"""
Copyright (C) 2024, Amazon.com. All Rights Reserved
PyTorch implementation for SPMD tensor addition with multiple Neuron cores NKI tutorial.
"""
# NKI_EXAMPLE_49_BEGIN
import torch
from torch_xla.core import xla_model as xm
# NKI_EXAMPLE_49_END

from spmd_multiple_nc_tensor_addition_nki_kernels import nki_tensor_add_nc2


# NKI_EXAMPLE_49_BEGIN
if __name__ == "__main__":
  device = xm.xla_device()

  a = torch.rand((512, 2048), dtype=torch.bfloat16).to(device=device)
  b = torch.rand((512, 2048), dtype=torch.bfloat16).to(device=device)

  output_nki = nki_tensor_add_nc2(a, b)
  print(f"output_nki={output_nki}")

  output_torch = a + b
  print(f"output_torch={output_torch}")

  allclose = torch.allclose(output_torch, output_nki, atol=1e-4, rtol=1e-2)
  if allclose:
    print("NKI and Torch match")
  else:
    print("NKI and Torch differ")

  assert allclose
# NKI_EXAMPLE_49_END
@@ -0,0 +1,34 @@
"""
Copyright (C) 2024, Amazon.com. All Rights Reserved
JAX implementation for SPMD tensor addition NKI tutorial.
"""
# NKI_EXAMPLE_30_BEGIN
import jax
import jax.numpy as jnp
# NKI_EXAMPLE_30_END

from spmd_tensor_addition_nki_kernels import nki_tensor_add

# NKI_EXAMPLE_30_BEGIN
if __name__ == "__main__":

  seed_a, seed_b = jax.random.split(jax.random.PRNGKey(42))
  a = jax.random.uniform(seed_a, (256, 1024), dtype=jnp.bfloat16)
  b = jax.random.uniform(seed_b, (256, 1024), dtype=jnp.bfloat16)

  output_nki = nki_tensor_add(a, b)
  print(f"output_nki={output_nki}")

  output_jax = a + b
  print(f"output_jax={output_jax}")

  allclose = jnp.allclose(output_jax, output_nki, atol=1e-4, rtol=1e-2)
  if allclose:
    print("NKI and JAX match")
  else:
    print("NKI and JAX differ")

  assert allclose
# NKI_EXAMPLE_30_END
@@ -0,0 +1,91 @@
"""
Copyright (C) 2024, Amazon.com. All Rights Reserved
NKI implementation for SPMD tensor addition NKI tutorial.
"""
import numpy as np
# NKI_EXAMPLE_27_BEGIN
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl


@nki.jit
def nki_tensor_add_kernel_(a_input, b_input):
"""NKI kernel to compute element-wise addition of two input tensors
This kernel assumes strict input/output sizes can be uniformly tiled to [128,512]
Args:
a_input: a first input tensor
b_input: a second input tensor
Returns:
c_output: an output tensor
"""
# Create output tensor shared between all SPMD instances as result tensor
c_output = nl.ndarray(a_input.shape, dtype=a_input.dtype, buffer=nl.shared_hbm)

# Calculate tile offsets based on current 'program'
offset_i_x = nl.program_id(0) * 128
offset_i_y = nl.program_id(1) * 512

# Generate tensor indices to index tensors a and b
ix = offset_i_x + nl.arange(128)[:, None]
iy = offset_i_y + nl.arange(512)[None, :]

# Load input data from device memory (HBM) to on-chip memory (SBUF)
# We refer to an indexed portion of a tensor as an intermediate tensor
a_tile = nl.load(a_input[ix, iy])
b_tile = nl.load(b_input[ix, iy])

# compute a + b
c_tile = a_tile + b_tile

# store the addition results back to device memory (c_output)
nl.store(c_output[ix, iy], value=c_tile)

# Transfer the ownership of `c_output` to the caller
return c_output
# NKI_EXAMPLE_27_END
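
# A minimal NumPy sketch of the ix/iy indexing above, assuming nl.arange
# broadcasts the way np.arange does: a (128, 1) row index plus a (1, 512)
# column index select exactly one 128x512 tile per SPMD instance.
def _tile_indexing_sketch():
  a = np.arange(256 * 1024, dtype=np.float32).reshape(256, 1024)
  offset_i_x = 1 * 128                       # as if program_id(0) == 1
  offset_i_y = 1 * 512                       # as if program_id(1) == 1
  ix = offset_i_x + np.arange(128)[:, None]  # shape (128, 1)
  iy = offset_i_y + np.arange(512)[None, :]  # shape (1, 512)
  tile = a[ix, iy]                           # broadcast indexing yields (128, 512)
  assert tile.shape == (128, 512)
  assert tile[0, 0] == a[128, 512]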


# NKI_EXAMPLE_28_BEGIN
def nki_tensor_add(a_input, b_input):
"""NKI kernel caller to compute element-wise addition of two input tensors
This kernel caller lifts tile-size restriction, by applying the kernel on tiles of the inputs/outputs
Args:
a_input: a first input tensor, of shape [N*128, M*512]
b_input: a second input tensor, of shape [N*128, M*512]
Returns:
a tensor of shape [N*128, M*512], the result of a_input + b_input
"""

# The SPMD launch grid denotes the number of kernel instances.
# In this case, we use a 2D grid where the size of each invocation is 128x512
grid_x = a_input.shape[0] // 128
grid_y = a_input.shape[1] // 512

return nki_tensor_add_kernel_[grid_x, grid_y](a_input, b_input)
# NKI_EXAMPLE_28_END
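
# Worked example (a sketch): for the 256x1024 inputs used below, grid_x = 256 // 128 = 2
# and grid_y = 1024 // 512 = 2, so nki_tensor_add_kernel_[2, 2] launches four
# instances, one per 128x512 tile of the output.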

if __name__ == "__main__":
  a = np.random.rand(256, 1024).astype(np.float16)
  b = np.random.rand(256, 1024).astype(np.float16)

  output_nki = nki_tensor_add(a, b)
  print(f"output_nki={output_nki}")

  output_np = a + b
  print(f"output_np={output_np}")

  allclose = np.allclose(output_np, output_nki, atol=1e-4, rtol=1e-2)
  if allclose:
    print("NKI and NumPy match")
  else:
    print("NKI and NumPy differ")

  assert allclose
@@ -0,0 +1,35 @@
"""
Copyright (C) 2024, Amazon.com. All Rights Reserved
PyTorch implementation for SPMD tensor addition NKI tutorial.
"""
# NKI_EXAMPLE_29_BEGIN
import torch
from torch_xla.core import xla_model as xm
# NKI_EXAMPLE_29_END

from spmd_tensor_addition_nki_kernels import nki_tensor_add


# NKI_EXAMPLE_29_BEGIN
if __name__ == "__main__":
  device = xm.xla_device()

  a = torch.rand((256, 1024), dtype=torch.bfloat16).to(device=device)
  b = torch.rand((256, 1024), dtype=torch.bfloat16).to(device=device)

  output_nki = nki_tensor_add(a, b)
  print(f"output_nki={output_nki}")

  output_torch = a + b
  print(f"output_torch={output_torch}")

  allclose = torch.allclose(output_torch, output_nki, atol=1e-4, rtol=1e-2)
  if allclose:
    print("NKI and Torch match")
  else:
    print("NKI and Torch differ")

  assert allclose
# NKI_EXAMPLE_29_END
