
Reduce nvFuser's Hot Path Host Latency #3507

Open

kevinstephano opened this issue Dec 2, 2024 · 4 comments

On a DGX H100, nvFuser's current hot-path latency is around 20 to 30 us, which is too expensive for inference-sized kernels that can be on the order of <10 us. nvFuser's cache hierarchy may need a shortcut that returns a list of kernels and launch parameters instead of recalculating various things on every call; the complete picture of what nvFuser is recalculating still needs to be enumerated. A sketch of such a shortcut follows the list below.

There are three general latency cases that nvFuser cares about, and this issue is about the third:

  1. Cold startup time with NVRTC kernel compilation.
  2. A dynamic shape change, which requires determining whether a suitable kernel already exists and calculating new launch parameters.
  3. **Hot-path caching of kernels and launch parameters in the steady state of execution.**
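
One way to picture such a shortcut is a cache keyed purely on input metadata that hands back the previously compiled kernels together with their launch parameters, so a hit skips scheduling, heuristics, and argument recomputation entirely. Below is a minimal Python sketch of the idea; `HotPathCache`, `slow_path`, and the kernel `launch` method are hypothetical illustrations, not nvFuser APIs.

```python
# Hypothetical sketch of a hot-path shortcut cache (not nvFuser's actual API).
# The "slow path" does segmentation, scheduling, heuristics, and compilation;
# a hit returns cached kernels plus launch parameters and launches directly.
from typing import Callable, NamedTuple


class HotPathEntry(NamedTuple):
    kernels: list        # previously compiled kernel handles
    launch_params: list  # one set of launch parameters per kernel


class HotPathCache:
    def __init__(self):
        self._entries: dict[tuple, HotPathEntry] = {}

    @staticmethod
    def _key(tensors) -> tuple:
        # Key on metadata only (dtype, shape, strides); tensor data is never read.
        return tuple((t.dtype, tuple(t.shape), tuple(t.stride())) for t in tensors)

    def execute(self, tensors, slow_path: Callable):
        key = self._key(tensors)
        entry = self._entries.get(key)
        if entry is None:
            # Miss: pay the full cost once and remember the result.
            kernels, launch_params = slow_path(tensors)
            entry = self._entries[key] = HotPathEntry(kernels, launch_params)
        # Hit: launch with the cached parameters, no recomputation.
        return [k.launch(p, tensors) for k, p in zip(entry.kernels, entry.launch_params)]
```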

DGX H100-80GB Results:

| Execution Type  | Wall Clock Time (ms) | CPU Overhead (ms) | Kernel Time (ms) | Kernels | Overhead / Kernel (us) |
| --------------- | -------------------- | ----------------- | ---------------- | ------- | ---------------------- |
| Thunder-nvFuser | 1.281                | 0.984             | 0.297            | 32      | 30.7                   |
| Thunder-torch   | 1.232                | 0.814             | 0.413            | 77      | 10.5                   |
| torch.compile   | 0.455                | 0.163             | 0.292            | 24      | 6.8                    |
| torch-eager     | 1.014                | 0.616             | 0.398            | 65      | 9.4                    |

Repro:

import torch
from transformers.models.llama import LlamaForCausalLM, LlamaConfig
from typing import Callable
from functools import partial, wraps

LLAMA_3_2_1B_CFG = {
    "architectures": ["LlamaForCausalLM"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "bos_token_id": 128000,
    "eos_token_id": 128001,
    "head_dim": 64,
    "hidden_act": "silu",
    "hidden_size": 2048,
    "initializer_range": 0.02,
    "intermediate_size": 8192,
    "max_position_embeddings": 131072,
    "mlp_bias": False,
    "model_type": "llama",
    "num_attention_heads": 32,
    "num_hidden_layers": 16,
    "num_key_value_heads": 8,
    "pretraining_tp": 1,
    "rms_norm_eps": 1e-05,
    "rope_scaling": {
        "factor": 32.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    "rope_theta": 500000.0,
    "tie_word_embeddings": True,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.45.0.dev0",
    "use_cache": True,
    "vocab_size": 128256,
    "_commit_hash": "4e20de362430cd3b72f300e6b0f18e50e7166e08",
}


config = LlamaConfig(**LLAMA_3_2_1B_CFG)
config.num_hidden_layers = 1

with torch.device("cuda"):
    model = LlamaForCausalLM(config).to(torch.bfloat16).requires_grad_(False).eval()

args = dict(
    cache_positions=torch.arange(6, device="cuda"),
    input_ids=torch.tensor([[128000, 791, 1401, 311, 2324, 374]], device="cuda"),
    attention_mask=torch.ones(1, 6, dtype=torch.int64, device="cuda"),
    inputs_embeds=None,
    use_cache=True,
    return_dict=True,
)

def cuda_timer(warmup_iters: int = 2, timing_iters: int = 10):
    def decorator(fn: Callable) -> Callable:
        @wraps(fn)
        def wrapper(*args, **kwargs) -> float:
            for _ in range(warmup_iters):
                fn(*args, **kwargs)

            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            start.record()
            for _ in range(timing_iters):
                fn(*args, **kwargs)
            end.record()
            torch.cuda.synchronize()

            kernel_time = start.elapsed_time(end) / timing_iters
            return kernel_time
        return wrapper
    return decorator

import thunder
# from thunder.dev_utils.nvtx_profile_transform import NvtxProfileTransform
# jm = thunder.jit(model, transformers=[NvtxProfileTransform()])
jm = thunder.jit(model)

@cuda_timer()
def run_model():
    jm(**args)

kernel_time = run_model()

print(f"Thunder-nvFuser {kernel_time:.03f} ms")

tc = torch.compile(model)

@cuda_timer()
def run_tc_model():
    tc(**args)

tc_kernel_time = run_tc_model()

print(f"torch.compile {tc_kernel_time:.03f} ms")

thunder_eager = thunder.jit(model,
                    executors=["sdpa", "torch"],
                    # transformers=[NvtxProfileTransform()]),
                    ) #, transforms=(CUDAGraphTransform(),))

@cuda_timer()
def run_te_model():
    thunder_eager(**args)

kernel_time = run_te_model()

print(f"Thunder-eager {kernel_time:.03f} ms")

@cuda_timer()
def run_eager_model():
    model(**args)

kernel_time = run_eager_model()

print(f"torch-eager {kernel_time:.03f} ms")

kevinstephano commented Dec 2, 2024

Another example:

import torch
import thunder
from transformers.models.llama import LlamaForCausalLM, LlamaConfig
from functools import partial, wraps
from typing import Callable
from collections import OrderedDict
import inspect

LLAMA_3_2_1B_CFG = {
    "architectures": ["LlamaForCausalLM"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "bos_token_id": 128000,
    "eos_token_id": 128001,
    "head_dim": 64,
    "hidden_act": "silu",
    "hidden_size": 2048,
    "initializer_range": 0.02,
    "intermediate_size": 8192,
    "max_position_embeddings": 131072,
    "mlp_bias": False,
    "model_type": "llama",
    "num_attention_heads": 32,
    "num_hidden_layers": 16,
    "num_key_value_heads": 8,
    "pretraining_tp": 1,
    "rms_norm_eps": 1e-05,
    "rope_scaling": {
        "factor": 32.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    "rope_theta": 500000.0,
    "tie_word_embeddings": True,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.45.0.dev0",
    "use_cache": True,
    "vocab_size": 128256,
    "_commit_hash": "4e20de362430cd3b72f300e6b0f18e50e7166e08",
}

args = dict(
    input_ids=torch.ones(1, 2048, dtype=torch.int64, device="cuda"),
    labels=torch.ones(1, 2048, dtype=torch.int64, device="cuda"),
)

config = LlamaConfig(**LLAMA_3_2_1B_CFG)

with torch.device("cuda"):
    model = LlamaForCausalLM(config).to(torch.bfloat16)

def cuda_timer(warmup_iters: int = 2, timing_iters: int = 10):
    def decorator(fn: Callable) -> Callable:
        @wraps(fn)
        def wrapper(*args, **kwargs) -> float:
            for _ in range(warmup_iters):
                fn(*args, **kwargs)

            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            start.record()
            for _ in range(timing_iters):
                fn(*args, **kwargs)
            end.record()
            torch.cuda.synchronize()

            kernel_time = start.elapsed_time(end) / timing_iters
            return kernel_time
        return wrapper
    return decorator

@cuda_timer()
def run_model(mymodel, args) :
    res = mymodel(**args)
    res.loss.backward()

def eager(fn):
    return fn

executors = OrderedDict()
executors['torch-eager'] = eager
executors['Thunder-nvFuser'] = thunder.jit
executors['Thunder-torch.compile'] = partial(thunder.jit, executors=['apex', 'cudnn', 'sdpa', 'torchcompile'])
executors['Thunder-torch'] = partial(thunder.jit, executors=['apex', 'cudnn', 'sdpa'])
executors['torch.compile'] = torch.compile

#print(inspect.signature(model.forward, follow_wrapped=True))

for name, func in executors.items():
    exec_model = func(model)
    kernel_time = run_model(exec_model, args)
    print(f"{name} {kernel_time:.03f} ms")

DGX H100-80GB Results:

| Execution Type        | Wall Clock Time (ms) |
| --------------------- | -------------------- |
| torch-eager           | 53.775               |
| Thunder-nvFuser       | 37.977               |
| Thunder-torch.compile | 38.979               |
| Thunder-torch         | 96.625               |
| torch.compile         | 37.963               |

L40 Results:

| Execution Type        | Wall Clock Time (ms) |
| --------------------- | -------------------- |
| torch-eager           | 202.513              |
| Thunder-nvFuser       | 188.717              |
| Thunder-torch.compile | 195.734              |
| Thunder-torch         | 295.482              |
| torch.compile         | 186.761              |

kevinstephano commented Dec 2, 2024

import torch
import thunder
from functools import partial, wraps
from typing import Callable
from collections import OrderedDict

class MyFusion(torch.nn.Module):
    def __init__(self):
        super(MyFusion, self).__init__()

    def forward(self, input1, input2):
        arg1 = input1

        for _ in range(1):
            out = arg1 + input2
            arg1 = out

        return out

inputs = [
        torch.randn(8, 8, device='cuda'),
        torch.randn(8, 8, device='cuda'),
        ]

model = MyFusion().cuda()

def cuda_timer(warmup_iters: int = 2, timing_iters: int = 10):
    def decorator(fn: Callable) -> Callable:
        @wraps(fn)
        def wrapper(*args, **kwargs) -> float:
            for _ in range(warmup_iters):
                fn(*args, **kwargs)

            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            start.record()
            for _ in range(timing_iters):
                fn(*args, **kwargs)
            end.record()
            torch.cuda.synchronize()

            kernel_time = start.elapsed_time(end) / timing_iters
            return kernel_time
        return wrapper
    return decorator

@cuda_timer()
def run_model(mymodel, args) :
    res = mymodel(*args)

def eager(fn):
    return fn

executors = OrderedDict()
executors['Thunder-nvFuser'] = thunder.jit
executors['Thunder-torch.compile'] = partial(thunder.jit, executors=['apex', 'cudnn', 'sdpa', 'torchcompile'])
executors['Thunder-torch'] = partial(thunder.jit, executors=['apex', 'cudnn', 'sdpa'])
executors['torch.compile'] = torch.compile
executors['torch-eager'] = eager

#print(inspect.signature(model.forward, follow_wrapped=True))

for name, func in executors.items():
    exec_model = func(model)
    kernel_time = run_model(exec_model, inputs)
    print(f"{name} {kernel_time:.03f} ms")

DGX H100-80GB Results:

| Execution Type        | Wall Clock Time (ms) |
| --------------------- | -------------------- |
| Thunder-nvFuser       | 0.072                |
| Thunder-torch.compile | 0.075                |
| Thunder-torch         | 0.042                |
| torch.compile         | 0.026                |
| torch-eager           | 0.007                |

It is unknown whether the Thunder-nvFuser executor and the Thunder-torch.compile executor paths are doing equivalent work, despite showing similar latency numbers. A guess is that ~10 us of nvFuser's overhead comes from recomputing launch parameters, which is more expensive than whatever torch.compile does there, since going through some caching and allocating outputs should be required by both. It is known that torch.compile uses pool allocations for outputs, and that difference would show up in fusions that have more outputs; a sketch of the pooling idea follows the profile image below.
Image
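
For reference, output pooling could look roughly like the following hedged sketch. This is not how torch.compile or nvFuser actually implement it; `OutputPool` and `outputs_for` are hypothetical names. The point is that output tensors for a known fusion are reused (or carved out of a pre-allocated pool) instead of being individually allocated on every invocation.

```python
import torch


class OutputPool:
    """Hypothetical per-fusion output pool: reuse the same output tensors across
    steady-state iterations instead of allocating them on every call. Reuse is
    only valid once the caller no longer holds the previous iteration's outputs;
    a real implementation would sit on top of a caching allocator."""

    def __init__(self):
        self._cache = {}  # (fusion_id, output metadata) -> list of tensors

    def outputs_for(self, fusion_id, metas):
        # metas: sequence of (shape, dtype, device) tuples, one per fusion output.
        key = (fusion_id, tuple(metas))
        outs = self._cache.get(key)
        if outs is None:
            outs = [torch.empty(shape, dtype=dtype, device=device)
                    for shape, dtype, device in metas]
            self._cache[key] = outs
        return outs


# Usage sketch: the executor asks the pool for outputs before launching kernels.
# pool = OutputPool()
# outs = pool.outputs_for(fusion_id=0, metas=[((8, 8), torch.float32, "cuda")])
```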

Some possible changes:

  1. Remove the recomputeArgs step when an ExecutorEntry was previously calculated (~6 us; see the sketch after this list).
  2. Move output allocations to a pool allocation.
  3. Look at the cost of cache lookups based on inputs.
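
For item 1, the change could amount to guarding the recomputation with a fingerprint of the input metadata the cached entry was built for, and only recomputing kernel arguments and launch parameters when that fingerprint changes. The Python below is only an illustrative model; nvFuser's actual ExecutorEntry and recomputeArgs live in C++, and the names here are made up.

```python
# Illustrative model of skipping an expensive recompute step once an
# executor entry exists for the current input metadata (hypothetical names).
def input_fingerprint(tensors):
    # Launch parameters depend on metadata (dtype/shape/strides), not data.
    return tuple((t.dtype, tuple(t.shape), tuple(t.stride())) for t in tensors)


class ExecutorEntryShim:
    def __init__(self):
        self.kernel_args = None
        self.launch_params = None
        self._fingerprint = None

    def prepare(self, tensors, recompute_args):
        fp = input_fingerprint(tensors)
        if fp != self._fingerprint:
            # Metadata changed: pay the recomputation cost once.
            self.kernel_args, self.launch_params = recompute_args(tensors)
            self._fingerprint = fp
        # Steady state: reuse the previously computed arguments as-is.
        return self.kernel_args, self.launch_params
```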

Todo:

  • Look more closely at segmentation.
  • Look at multi-output cases to see how expensive the output allocation is.

Below is the trace with recomputeArgs removed. We see a ~6 us decrease: the non-API time goes down from ~20 us to ~14 us.
Image

kevinstephano commented Dec 3, 2024

Breaking down a step of Thunder-nvFuser Host Overhead from the above simple example:

| Contributor | Time (us) |
| ----------- | --------- |
| Thunder     | 45        |
| nvFuser     | 25        |
| CUDA API    | 4.5       |

Image

kevinstephano commented Dec 3, 2024

import torch
import thunder
from transformers.models.llama import LlamaForCausalLM, LlamaConfig
from typing import Callable
from functools import partial, wraps
from collections import OrderedDict

LLAMA_3_2_1B_CFG = {
    "architectures": ["LlamaForCausalLM"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "bos_token_id": 128000,
    "eos_token_id": 128001,
    "head_dim": 64,
    "hidden_act": "silu",
    "hidden_size": 2048,
    "initializer_range": 0.02,
    "intermediate_size": 8192,
    "max_position_embeddings": 131072,
    "mlp_bias": False,
    "model_type": "llama",
    "num_attention_heads": 32,
    "num_hidden_layers": 16,
    "num_key_value_heads": 8,
    "pretraining_tp": 1,
    "rms_norm_eps": 1e-05,
    "rope_scaling": {
        "factor": 32.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    "rope_theta": 500000.0,
    "tie_word_embeddings": True,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.45.0.dev0",
    "use_cache": True,
    "vocab_size": 128256,
    "_commit_hash": "4e20de362430cd3b72f300e6b0f18e50e7166e08",
}


config = LlamaConfig(**LLAMA_3_2_1B_CFG)
config.num_hidden_layers = 1

with torch.device("cuda"):
    model = LlamaForCausalLM(config).to(torch.bfloat16).requires_grad_(False).eval()

args = dict(
    cache_positions=torch.arange(6, device="cuda"),
    input_ids=torch.tensor([[128000, 791, 1401, 311, 2324, 374]], device="cuda"),
    attention_mask=torch.ones(1, 6, dtype=torch.int64, device="cuda"),
    inputs_embeds=None,
    use_cache=True,
    return_dict=True,
)

def cuda_timer(warmup_iters: int = 10, timing_iters: int = 40):
    def decorator(fn: Callable) -> Callable:
        @wraps(fn)
        def wrapper(*args, **kwargs) -> float:
            for _ in range(warmup_iters):
                fn(*args, **kwargs)

            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            start.record()
            for _ in range(timing_iters):
                fn(*args, **kwargs)
            end.record()
            torch.cuda.synchronize()

            kernel_time = start.elapsed_time(end) / timing_iters
            return kernel_time
        return wrapper
    return decorator

@cuda_timer()
def run_model(mymodel, args) :
    res = mymodel(**args)

def eager(fn):
    return fn

executors = OrderedDict()
executors['Thunder-nvFuser'] = thunder.jit
#executors['Thunder-torch.compile'] = partial(thunder.jit, executors=['apex', 'cudnn', 'sdpa', 'torchcompile'])
executors['Thunder-torch'] = partial(thunder.jit, executors=['apex', 'cudnn', 'sdpa'])
executors['torch.compile'] = torch.compile
executors['torch-eager'] = eager

#print(inspect.signature(model.forward, follow_wrapped=True))

for name, func in executors.items():
    exec_model = func(model)
    kernel_time = run_model(exec_model, args)
    print(f"{name} {kernel_time:.03f} ms")

DGX H100 Results:

| Execution Type        | Wall Clock Time (ms) |
| --------------------- | -------------------- |
| Thunder-nvFuser       | 1.082                |
| Thunder-torch.compile | N/A                  |
| Thunder-torch         | 0.965                |
| torch.compile         | 0.313                |
| torch-eager           | 0.834                |

If I remove recomputeArgs from nvFuser, the picture changes:

| Execution Type        | Wall Clock Time (ms) |
| --------------------- | -------------------- |
| Thunder-nvFuser       | 0.926                |
| Thunder-torch.compile | N/A                  |
| Thunder-torch         | 0.965                |
| torch.compile         | 0.313                |
| torch-eager           | 0.834                |
