Reduce nvFuser's Hot Path Host Latency #3507
Another example:

import torch
import thunder
from transformers.models.llama import LlamaForCausalLM, LlamaConfig
from functools import partial, wraps
from typing import Callable
from collections import OrderedDict
import inspect
LLAMA_3_2_1B_CFG = {
    "architectures": ["LlamaForCausalLM"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "bos_token_id": 128000,
    "eos_token_id": 128001,
    "head_dim": 64,
    "hidden_act": "silu",
    "hidden_size": 2048,
    "initializer_range": 0.02,
    "intermediate_size": 8192,
    "max_position_embeddings": 131072,
    "mlp_bias": False,
    "model_type": "llama",
    "num_attention_heads": 32,
    "num_hidden_layers": 16,
    "num_key_value_heads": 8,
    "pretraining_tp": 1,
    "rms_norm_eps": 1e-05,
    "rope_scaling": {
        "factor": 32.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    "rope_theta": 500000.0,
    "tie_word_embeddings": True,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.45.0.dev0",
    "use_cache": True,
    "vocab_size": 128256,
    "_commit_hash": "4e20de362430cd3b72f300e6b0f18e50e7166e08",
}
args = dict(
    input_ids=torch.ones(1, 2048, dtype=torch.int64, device="cuda"),
    labels=torch.ones(1, 2048, dtype=torch.int64, device="cuda"),
)
config = LlamaConfig(**LLAMA_3_2_1B_CFG)
with torch.device("cuda"):
model = LlamaForCausalLM(config).to(torch.bfloat16)
def cuda_timer(warmup_iters: int = 2, timing_iters: int = 10):
    def decorator(fn: Callable) -> Callable:
        @wraps(fn)
        def wrapper(*args, **kwargs) -> float:
            # Warm up to keep compilation and cache-population costs out of the measurement.
            for _ in range(warmup_iters):
                fn(*args, **kwargs)
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            torch.cuda.synchronize()
            start.record()
            for _ in range(timing_iters):
                fn(*args, **kwargs)
            end.record()
            torch.cuda.synchronize()
            # Average time per iteration in milliseconds.
            kernel_time = start.elapsed_time(end) / timing_iters
            return kernel_time
        return wrapper
    return decorator
@cuda_timer()
def run_model(mymodel, args):
    res = mymodel(**args)
    res.loss.backward()
def eager(fn):
    return fn
executors = OrderedDict()
executors['torch-eager'] = eager
executors['Thunder-nvFuser'] = thunder.jit
executors['Thunder-torch.compile'] = partial(thunder.jit, executors=['apex', 'cudnn', 'sdpa', 'torchcompile'])
executors['Thunder-torch'] = partial(thunder.jit, executors=['apex', 'cudnn', 'sdpa'])
executors['torch.compile'] = torch.compile
#print(inspect.signature(model.forward, follow_wrapped=True))
for name, func in executors.items():
    exec_model = func(model)
    kernel_time = run_model(exec_model, args)
    print(f"{name} {kernel_time:.03f} ms")

DGX H100-80GB Results:
L40 Results:
import torch
import thunder
from transformers.models.llama import LlamaForCausalLM, LlamaConfig
from typing import Callable
from functools import partial, wraps
from collections import OrderedDict
LLAMA_3_2_1B_CFG = {
    "architectures": ["LlamaForCausalLM"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "bos_token_id": 128000,
    "eos_token_id": 128001,
    "head_dim": 64,
    "hidden_act": "silu",
    "hidden_size": 2048,
    "initializer_range": 0.02,
    "intermediate_size": 8192,
    "max_position_embeddings": 131072,
    "mlp_bias": False,
    "model_type": "llama",
    "num_attention_heads": 32,
    "num_hidden_layers": 16,
    "num_key_value_heads": 8,
    "pretraining_tp": 1,
    "rms_norm_eps": 1e-05,
    "rope_scaling": {
        "factor": 32.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    "rope_theta": 500000.0,
    "tie_word_embeddings": True,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.45.0.dev0",
    "use_cache": True,
    "vocab_size": 128256,
    "_commit_hash": "4e20de362430cd3b72f300e6b0f18e50e7166e08",
}
config = LlamaConfig(**LLAMA_3_2_1B_CFG)
config.num_hidden_layers = 1
with torch.device("cuda"):
model = LlamaForCausalLM(config).to(torch.bfloat16).requires_grad_(False).eval()
args = dict(
    cache_positions=torch.arange(6, device="cuda"),
    input_ids=torch.tensor([[128000, 791, 1401, 311, 2324, 374]], device="cuda"),
    attention_mask=torch.ones(1, 6, dtype=torch.int64, device="cuda"),
    inputs_embeds=None,
    use_cache=True,
    return_dict=True,
)
def cuda_timer(warmup_iters: int = 10, timing_iters: int = 40):
    def decorator(fn: Callable) -> Callable:
        @wraps(fn)
        def wrapper(*args, **kwargs) -> float:
            for _ in range(warmup_iters):
                fn(*args, **kwargs)
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            torch.cuda.synchronize()
            start.record()
            for _ in range(timing_iters):
                fn(*args, **kwargs)
            end.record()
            torch.cuda.synchronize()
            kernel_time = start.elapsed_time(end) / timing_iters
            return kernel_time
        return wrapper
    return decorator
@cuda_timer()
def run_model(mymodel, args):
    res = mymodel(**args)
def eager(fn):
    return fn
executors = OrderedDict()
executors['Thunder-nvFuser'] = thunder.jit
#executors['Thunder-torch.compile'] = partial(thunder.jit, executors=['apex', 'cudnn', 'sdpa', 'torchcompile'])
executors['Thunder-torch'] = partial(thunder.jit, executors=['apex', 'cudnn', 'sdpa'])
executors['torch.compile'] = torch.compile
executors['torch-eager'] = eager
#print(inspect.signature(model.forward, follow_wrapped=True))
for name, func in executors.items():
    exec_model = func(model)
    kernel_time = run_model(exec_model, args)
    print(f"{name} {kernel_time:.03f} ms")

DGX H100 Results:
If I remove
On a DGX H100, nvFuser's current hot path latency is around 20 to 30 us, which is too expensive for inference-sized kernels that can be on the order of <10 us. nvFuser's cache hierarchy may need a shortcut that returns a list of kernels and launch parameters instead of recalculating various things. The complete picture of what nvFuser is recalculating still needs to be enumerated.
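As a rough illustration of the shortcut idea, here is a minimal Python sketch of a last-level cache that keys on input metadata and, on a hit, hands back an already-compiled kernel plus its launch parameters. The names here (input_key, HotPathCache, compile_fn, launch_fn) are hypothetical and are not nvFuser's actual API.

from typing import Any, Dict, Tuple

def input_key(inputs) -> Tuple:
    # Hypothetical key: the tensor metadata that would influence scheduling and
    # launch-parameter decisions (inputs are assumed to be torch tensors).
    return tuple(
        (tuple(t.shape), t.dtype, tuple(t.stride()), str(t.device)) for t in inputs
    )

class HotPathCache:
    """Sketch of a last-level cache: a hit skips segmentation, scheduling, and
    launch-parameter recomputation and goes straight to the kernel launch."""

    def __init__(self):
        self.entries: Dict[Tuple, Tuple[Any, Any]] = {}

    def run(self, inputs, compile_fn, launch_fn):
        key = input_key(inputs)
        entry = self.entries.get(key)
        if entry is None:
            # Cold path: full compilation and launch-parameter derivation.
            entry = compile_fn(inputs)  # -> (compiled_kernel, launch_params)
            self.entries[key] = entry
        compiled_kernel, launch_params = entry
        # Hot path: relaunch with the cached parameters, recomputing nothing.
        return launch_fn(compiled_kernel, launch_params, inputs)

Keying on shapes, dtypes, strides, and device mirrors what would typically invalidate a compiled kernel; anything else nvFuser currently recomputes per call would have to be folded into either the key or the cached launch parameters for a shortcut like this to be safe.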
There are three general latency cases that nvFuser cares about, and this issue concerns the third one:
DGX H100-80GB Results:
Repro: