[Inductor cutlass backend] Enabled nonzero workspace and Cutlass StreamK
Enable nonzero workspace and Cutlass StreamK for Inductor Cutlass GEMM ops.

This is a simpler rewrite of my original version of pytorch#119005, using peterbell10's workspace allocation mechanism from pytorch#117992.

Test Plan:
 - Additional unit test in test_cutlass_backend.py that specifically tests a StreamK GEMM with a workspace requirement
 - CI

ghstack-source-id: 24d06299f90a1e31af6b097316b76689e4944df2
Pull Request resolved: pytorch#125406
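
Example usage (not part of the commit): the new StreamK path can be exercised with the same Inductor config keys the new unit test uses. A minimal sketch, assuming a placeholder path for a local CUTLASS checkout in `cuda.cutlass_dir`:

```python
import torch
import torch._inductor.config as config

def mm(a, b):
    return a @ b

a = torch.randn(1024, 256, device="cuda", dtype=torch.float16)
b = torch.randn(256, 1024, device="cuda", dtype=torch.float16)

# Restrict autotuning to CUTLASS stream-k GEMM kernels. The cutlass_dir value
# below is a placeholder for a local CUTLASS checkout (assumption).
with config.patch(
    {
        "max_autotune": True,
        "max_autotune_gemm_backends": "CUTLASS",
        "cuda.cutlass_dir": "/path/to/cutlass",
        "cuda.cutlass_op_allowlist_regex": "stream_k",  # only stream-k kernels
    }
):
    y = torch.compile(mm)(a, b)
```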
kadeng committed May 2, 2024
1 parent a86f97f commit d89a24b
Showing 5 changed files with 119 additions and 28 deletions.
57 changes: 55 additions & 2 deletions test/inductor/test_cutlass_backend.py
@@ -169,8 +169,8 @@ def test_max_autotune_cutlass_backend_regular_mm(
def mm(a, b):
return a @ b

a = torch.randn(100, 10).cuda().half()
b = torch.randn(10, 100).cuda().half()
a = torch.randn(128, 16).cuda().half()
b = torch.randn(16, 128).cuda().half()

with config.patch(
{
@@ -185,6 +185,59 @@ def mm(a, b):
Y = mm(a, b)
torch.testing.assert_close(Y_compiled, Y)

@unittest.skipIf(not SM90OrLater, "need sm_90")
@unittest.skipIf(config.is_fbcode(), "fbcode requires different CUTLASS path setup")
@unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
def test_max_autotune_cutlass_backend_regular_mm_streamk(
self, dynamic: bool = False, max_autotune_gemm_backends: str = "CUTLASS"
):
"""
Make sure autotuning mm in subprocesses works without crashes.
"""

if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
return

torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

def mm(a, b):
return a @ b

a = torch.randn(128, 16).cuda().half()
b = torch.randn(16, 128).cuda().half()

with config.patch(
{
"max_autotune": True,
"autotune_in_subproc": True,
"max_autotune_gemm_backends": max_autotune_gemm_backends,
"cuda.cutlass_dir": _CUTLASS_DIR,
"cuda.cutlass_max_profiling_configs": 2,
"cuda.cutlass_op_allowlist_regex": "stream_k", # only stream-k GEMM Kernels
}
):
for M, K, N in (
(128, 16, 128),
(1024, 256, 1024),
(
16384,
1024,
16384,
),
(
16384,
1408,
16384,
),
):
a = torch.randn(M, K).cuda().half()
b = torch.randn(K, N).cuda().half()
Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b)
Y = mm(a, b)
# we need relaxed numerical limits due to the sheer size of the
# matmuls involved. Many small addition differences add up.
torch.testing.assert_close(Y_compiled, Y, atol=0.01, rtol=0.01)

def _test_max_autotune_cutlass_backend_epilogue_fusion(
self,
dynamic: bool = False,
76 changes: 56 additions & 20 deletions torch/_inductor/autotune_process.py
@@ -53,6 +53,10 @@ class Pong:
pass


class NonzeroWorkspaceNotSupportedError(Exception):
pass


@contextlib.contextmanager
def set_cuda_visible_device(device: Optional[int]):
"""
@@ -505,8 +509,12 @@ def benchmark(
if debug:
create_tensor_elapse = time.time() - start_ts # type: ignore[possibly-undefined]
start_ts = time.time()

fn = self.make_run_fn(*input_tensors, output_tensor=output_tensor)
try:
fn = self.make_run_fn(*input_tensors, output_tensor=output_tensor)
except NonzeroWorkspaceNotSupportedError:
# Skipping all ops with nonzero workspace requirements
log.info("Skipping op due to nonzero workspace requirement")
return float("inf")

if debug:
load_elapse = time.time() - start_ts # type: ignore[possibly-undefined]
@@ -652,6 +660,7 @@ def __init__(
self.workspace_size: int = 0
self.workspace: Optional[torch.Tensor] = None
self.DLL: Optional[DLLWrapper] = None
self._workspace_size_updated = False
self.hash_key: str = ""
self.source_file: str = ""
self.hash_key, self.source_file = CUDACodeCache.write(self.source_code, "so")
Expand All @@ -660,15 +669,14 @@ def precompile(self):
# Prepopulate CUDACodeCache
# may happen in separate Threadpool
log.debug("Precompiling %s", self)
CUDACodeCache.load(self.source_code, "so")
CUDACodeCache.compile(self.source_code, "so")
log.debug("Done precompiling %s", self)

def make_run_fn(
self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
) -> Callable[[], None]:
self.DLL, self.hash_key, self.source_file = CUDACodeCache.load(
self.source_code, "so"
)
self.ensure_dll_loaded()
self.update_workspace_size()
args = [
c_void_p(tensor.data_ptr())
for tensor in list(input_tensors) + [output_tensor]
@@ -682,9 +690,35 @@ def make_run_fn(
args,
self.extra_args,
)
stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream)
run_method = getattr(self.DLL, self.kernel_name)
workspace_ptr = c_void_p(0)
if self.workspace_size > 0:
self.workspace = torch.zeros(
(self.workspace_size + 7) // 8,
dtype=torch.float64,
device=output_tensor.device,
)
workspace_ptr = c_void_p(self.workspace.data_ptr())

# Generate partial function.
return functools.partial(
run_method,
*args,
*self.extra_args,
None, # null workspace size ptr
workspace_ptr, # set workspace ptr,
stream_ptr,
)

def update_workspace_size(self) -> None:
if self._workspace_size_updated:
return
self.ensure_dll_loaded()
args = [c_void_p(None) for _ in range(len(self.input_tensor_meta) + 1)]
stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream)

run_method = getattr(self.DLL, self.kernel_name)
# Retrieve workspace_size and initialize workspace.
c_workspace_size = c_size_t()
run_method(
@@ -696,23 +730,25 @@ def make_run_fn(
None, # null workspace ptr
stream_ptr,
)
torch.cuda.synchronize() # shake out any CUDA errors
self.workspace_size = c_workspace_size.value
# TODO: Support non-zero workspace_size.
assert self.workspace_size == 0, (
"Things need to be fixed to support non-zero workspace_size: "
"1) max autotune cache needs to store workspace size; "
"2) memory allocation needs to allocate / deallocate workspace correctly; "
log.debug(
"update_workspace_size called: new workspace size=%d, self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", # noqa: B950
self.workspace_size,
self.kernel_name,
self.source_file,
self.hash_key,
self.DLL,
args,
self.extra_args,
)
self._workspace_size_updated = True

# Generate partial function.
return functools.partial(
run_method,
*args,
*self.extra_args,
None, # null workspace size ptr
None, # set workspace ptr, TODO: update it to a real ptr if workspace_size > 0
stream_ptr,
)
def ensure_dll_loaded(self):
if self.DLL is None:
self.DLL, self.hash_key, self.source_file = CUDACodeCache.load(
self.source_code, "so"
)

def cleanup_run_fn(self) -> None:
if self.DLL is not None:
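For context, the `update_workspace_size` / `make_run_fn` split above follows a two-phase launch convention: call the kernel once with a `size_t` out-parameter and a null workspace pointer to query the required size, then allocate a device buffer and pass its pointer on the real run. A rough ctypes sketch, assuming a hypothetical launcher with the signature `run_method(*tensor_ptrs, workspace_size_ptr, workspace_ptr, stream_ptr)`:

```python
import torch
from ctypes import c_void_p, c_size_t, byref

def launch_with_workspace(run_method, tensors):
    """Hypothetical sketch of the query-then-allocate pattern used above."""
    stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream)

    # Phase 1: query the required workspace size. Tensor pointers may be null
    # here; only the size out-parameter matters, and the workspace ptr is null.
    ws_size = c_size_t()
    run_method(*[c_void_p(None) for _ in tensors], byref(ws_size), None, stream_ptr)
    torch.cuda.synchronize()  # shake out any CUDA errors from the query call

    # Phase 2: allocate the workspace (rounded up to 8-byte elements, as in
    # make_run_fn above) and launch for real, passing a null size pointer.
    workspace, workspace_ptr = None, c_void_p(0)
    if ws_size.value > 0:
        workspace = torch.zeros(
            (ws_size.value + 7) // 8, dtype=torch.float64, device="cuda"
        )
        workspace_ptr = c_void_p(workspace.data_ptr())
    args = [c_void_p(t.data_ptr()) for t in tensors]
    run_method(*args, None, workspace_ptr, stream_ptr)
    return workspace  # keep the buffer alive until the kernel has finished
```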
8 changes: 7 additions & 1 deletion torch/_inductor/codegen/cuda/cuda_kernel.py
@@ -169,7 +169,10 @@ def call_kernel(
call_args.append("None")

if node.get_workspace_size() > 0:
call_args.append(f"c_void_p({node.get_name()}_workspace.data_ptr())")
wrapper.generate_workspace_allocation(
node.get_workspace_size(), V.graph.scheduler.current_device, False
)
call_args.append("c_void_p(workspace.data_ptr())")
else:
call_args.append("None")

@@ -180,6 +183,8 @@ def call_kernel(
cuda=True,
triton=False,
)
if node.get_workspace_size() > 0:
wrapper.writeline(wrapper.make_free_by_names(["workspace"]))

def dtype(self, node: IRNode) -> Optional[str]:
"""
@@ -379,6 +384,7 @@ def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType
return {"backend": "CUDA", "op_type": "unknown"}

def output_node(self) -> TensorBox:
self.bmreq.update_workspace_size()
return TensorBox.create(
CUDATemplateBuffer(
layout=self.layout,
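To illustrate the codegen change above (this is not the literal generated output): the wrapper now allocates the workspace buffer right before the kernel call, passes its pointer instead of a hard-coded `None`, and frees it immediately afterwards. A hedged sketch with a stand-in for the generated launcher:

```python
import torch
from ctypes import c_void_p

def generated_cuda_gemm(a_ptr, b_ptr, y_ptr, ws_size_ptr, ws_ptr, stream_ptr):
    """Stand-in for the compiled CUTLASS kernel launcher (hypothetical)."""

a = torch.randn(128, 16, device="cuda", dtype=torch.float16)
b = torch.randn(16, 128, device="cuda", dtype=torch.float16)
y = torch.empty(128, 128, device="cuda", dtype=torch.float16)

workspace_size = 4096  # placeholder; would come from node.get_workspace_size()
# Allocated just before the kernel call, as generate_workspace_allocation emits.
workspace = torch.empty(workspace_size, dtype=torch.uint8, device="cuda")
generated_cuda_gemm(
    c_void_p(a.data_ptr()),
    c_void_p(b.data_ptr()),
    c_void_p(y.data_ptr()),
    None,                            # workspace size ptr (only used while tuning)
    c_void_p(workspace.data_ptr()),  # workspace ptr
    c_void_p(torch.cuda.current_stream().cuda_stream),
)
del workspace  # freed right after the call, mirroring make_free_by_names(["workspace"])
```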
4 changes: 0 additions & 4 deletions torch/_inductor/codegen/wrapper.py
@@ -1493,10 +1493,6 @@ def codegen_deferred_allocation(self, name, layout):
)

def codegen_allocation(self, buffer):
assert (
buffer.get_workspace_size() == 0
), "Only support zero workspace size for now!"

name = buffer.get_name()

if name in V.graph.removed_buffers or name in self.allocated:
Expand Down
2 changes: 1 addition & 1 deletion torch/_inductor/ir.py
@@ -3707,7 +3707,7 @@ def __init__(
self.template = template

def get_workspace_size(self):
return self.workspace_size if self.workspace_size is not None else 0
return self.workspace_size


@dataclasses.dataclass