-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding NXD and NKI example #26
Open
EmilyWebber
wants to merge
3
commits into
aws-neuron:main
Choose a base branch
from
EmilyWebber:main
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 2 commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# TinyLLama inference with NeuronX Distributed and Neuron Kernel Interface | ||
In this example you can test [TinyLlama](https://huggingface.co/TinyLlama) from Hugging Face on AWS Trainium. This example was built on a trn1.2xlarge machine using this AMI: Deep Learning AMI Neuron (Ubuntu 22.04) 20240927. | ||
|
||
This example pulls largely from the Llama2 inference example from NeuronX Distributed available [here](https://github.com/aws-neuron/neuronx-distributed/tree/main/examples/inference/llama2). However, it adds support for 1/ TinyLlama and 2/ Neuron Kernel Interface (NKI). | ||
|
||
### Setup | ||
To run this example, first clone the repository with `git clone https://github.com/aws-neuron/nki-samples.git`. | ||
|
||
Next, `cd` into `nki_samples/nki_university/nki_and_nxd_llama_inference`. | ||
|
||
Then install the requirements with `pip install -r requirements.txt`. | ||
|
||
### Download the model | ||
|
||
You'll need to download the TinyLlama model from Hugging Face. You can do this through the `transformers` SDK like this. | ||
|
||
``` | ||
# Load model directly | ||
from transformers import AutoTokenizer, AutoModelForCausalLM | ||
|
||
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama_v1.1") | ||
model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama_v1.1") | ||
``` | ||
|
||
After that, save the model into a local directory. This needs to be the same directory you set at the top of `run_llama.py`. | ||
|
||
``` | ||
model.save_pretrained('/home/ubuntu/models/Tiny-Llama') | ||
tokenizer.save_pretrained('/home/ubuntu/models/Tiny-Llama') | ||
``` | ||
|
||
### Test the script | ||
Once you've installed all the packages and downloaded your model, you should be ready to test the script. This is done with `python run_llama.py`. | ||
|
||
This script will take at least 30 minutes to complete because it does the following: 1/ compile your model 2/ load to Neuron device 3/ test on Neuron 4/ compare accuracy 5/ run benchmark suite. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: either |
||
|
||
### Write your NKI kernel | ||
Your NKI kernels can operate like normal Python functions inside of this project, such as within `llama2/neuron_modeling_llama.py`. Your script already has a sample kernel, `nki_tensor_add_`, which simply takes the addition of the hidden and residual states during the forward pass. This is available in `llama2/neuron_modeling_llama.py`. This kernel has been tested and confirmed for both accuracy and performance. | ||
|
||
|
Empty file.
106 changes: 106 additions & 0 deletions
106
nki_university/nki_and_nxd_llama_inference/llama2/llama2_runner.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import torch | ||
from llama2.neuron_modeling_llama import ( | ||
NeuronLlamaConfig, | ||
NeuronLlamaForCausalLM, | ||
NeuronLlamaModel, | ||
) | ||
from runner import InferenceRunner | ||
from transformers import AutoTokenizer | ||
|
||
from neuronx_distributed.parallel_layers.checkpointing import _invoke_preshard_hook | ||
from neuronx_distributed.quantization.quantization_config import QuantizationType | ||
from neuronx_distributed.quantization.quantization_utils import ( | ||
quantize_pytorch_model_per_channel_symmetric, | ||
quantize_pytorch_model_per_tensor_symmetric, | ||
) | ||
|
||
|
||
class LlamaRunner(InferenceRunner): | ||
def load_hf_model(self): | ||
return NeuronLlamaForCausalLM.load_hf_model(self.model_path) | ||
|
||
def load_neuron_model_on_cpu(self, max_prompt_length, sequence_length, batch_size, **kwargs): | ||
self.config = self.get_config_for_nxd( | ||
batch_size, | ||
1, | ||
max_prompt_length=max_prompt_length, | ||
sequence_length=sequence_length, | ||
enable_bucketing=False, | ||
**kwargs) | ||
self.config.torch_dtype = torch.float32 | ||
|
||
neuron_model = NeuronLlamaModel(self.config) | ||
|
||
state_dict = NeuronLlamaForCausalLM.get_state_dict(self.model_path, config=self.config) | ||
_invoke_preshard_hook(neuron_model, state_dict) | ||
|
||
neuron_model.load_state_dict(state_dict, strict=False) | ||
|
||
if self.config.torch_dtype == torch.bfloat16: | ||
neuron_model.bfloat16() | ||
|
||
model = NeuronLlamaForCausalLM(None, self.config) | ||
model.context_encoding_model.model = neuron_model | ||
model.token_generation_model.model = neuron_model | ||
return model | ||
|
||
def generate_quantized_hf_checkpoints_on_cpu(self, max_prompt_length, sequence_length, batch_size, **kwargs): | ||
config = self.get_config_for_nxd(batch_size, 1, max_prompt_length, sequence_length, **kwargs) | ||
config.torch_dtype = torch.float32 | ||
|
||
quantized_state_dict = NeuronLlamaForCausalLM.generate_quantized_state_dict( | ||
model_path=self.model_path, config=config | ||
) | ||
return quantized_state_dict | ||
|
||
def load_quantized_neuron_model_on_cpu(self, max_prompt_length, sequence_length, batch_size, **kwargs): | ||
model = self.load_neuron_model_on_cpu(max_prompt_length, sequence_length, batch_size, **kwargs) | ||
|
||
quantization_type = QuantizationType(kwargs.get("quantization_type", "per_tensor_symmetric")) | ||
if quantization_type == QuantizationType.PER_TENSOR_SYMMETRIC: | ||
return quantize_pytorch_model_per_tensor_symmetric(model, inplace=True) | ||
elif quantization_type == QuantizationType.PER_CHANNEL_SYMMETRIC: | ||
return quantize_pytorch_model_per_channel_symmetric(model, inplace=True) | ||
else: | ||
raise RuntimeError(f"quantization_type: {quantization_type} not supported") | ||
|
||
def load_neuron_model(self, traced_model_path): | ||
config = NeuronLlamaConfig.from_pretrained(traced_model_path) | ||
model = NeuronLlamaForCausalLM.from_pretrained("", config) | ||
self.config = config | ||
|
||
model.load(traced_model_path) | ||
if config.torch_dtype == torch.bfloat16: | ||
model.bfloat16() | ||
|
||
return model | ||
|
||
def load_tokenizer(self, padding_side=None): | ||
tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path) | ||
if not hasattr(self.config, 'pad_token_id') or self.config.pad_token_id is None: | ||
# Use eos_token as pad_token which works for both llama2 and llama3 | ||
tokenizer.pad_token = tokenizer.eos_token | ||
else: | ||
tokenizer.pad_token_id = self.config.pad_token_id | ||
tokenizer.padding_side = padding_side if padding_side else self.get_padding_side() | ||
return tokenizer | ||
|
||
def get_config_cls(self): | ||
return NeuronLlamaConfig | ||
|
||
def get_model_cls(self): | ||
return NeuronLlamaForCausalLM | ||
|
||
def get_padding_side(self): | ||
return "right" | ||
|
||
def get_default_hf_generation_config_kwargs(self): | ||
config = super().get_default_hf_generation_config_kwargs() | ||
# set to eos_token_id as that's done in load_tokenizer | ||
config['pad_token_id'] = self.generation_config.eos_token_id | ||
|
||
return config | ||
|
||
|
||
if __name__ == "__main__": | ||
LlamaRunner.cmd_execute() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe “instance” rather than “machine”