If training's got you in a stew, take a REST and speed right through.
🎉 2024-3-14: REST is accepted to NAACL 2024!
REST is a retrieval-based speculative decoding method designed to boost the generation speed of LLMs. Instead of relying on a draft language model, as standard speculative decoding does, REST retrieves draft tokens from a datastore. Unlike blockwise parallel decoding and Medusa, REST requires no extra training steps. It works as a plug-and-play solution that can accelerate any pre-existing language model.
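The retrieve-then-verify idea can be illustrated with a toy sketch. This is not the DraftRetriever implementation: the dictionary datastore and the helper names (`build_datastore`, `retrieve_draft`, `accept_prefix`) are hypothetical, and the real system uses its own index format and verifies drafts inside the model.

```python
from typing import Dict, List, Sequence, Tuple

Store = Dict[Tuple[int, ...], List[int]]


def build_datastore(corpus: Sequence[Sequence[int]],
                    max_ctx: int = 3, draft_len: int = 4) -> Store:
    """Map each context of up to max_ctx tokens to the tokens that followed it."""
    store: Store = {}
    for seq in corpus:
        for end in range(1, len(seq)):
            continuation = list(seq[end:end + draft_len])
            for n in range(1, max_ctx + 1):
                if end - n < 0:
                    break
                # Toy policy (assumption): keep only the first continuation seen.
                store.setdefault(tuple(seq[end - n:end]), continuation)
    return store


def retrieve_draft(store: Store, context: Sequence[int],
                   max_ctx: int = 3) -> List[int]:
    """Return the continuation stored for the longest matching suffix of context."""
    for n in range(min(max_ctx, len(context)), 0, -1):
        draft = store.get(tuple(context[-n:]))
        if draft:
            return draft
    return []


def accept_prefix(draft: Sequence[int], verified: Sequence[int]) -> List[int]:
    """Keep draft tokens only while they agree with the model's own predictions."""
    accepted: List[int] = []
    for d, v in zip(draft, verified):
        if d != v:
            break
        accepted.append(d)
    return accepted
```

Here `verified` stands in for the tokens the base model would have produced at each draft position. Because only the agreeing prefix is kept, the output matches what the model would have generated on its own, and no draft model or extra training is involved.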
- Introduction
- Contents
- Installation
- Build datastores
- Inference
- Citation
- Other Models and Datastore
- Acknowledgements
conda create -n rest python=3.9
conda activate rest
pip3 install -r requirements.txt # pay attention to the PyTorch CUDA version
pip3 install DraftRetriever/wheels/draftretriever-0.1.0-cp39-cp39-manylinux_2_34_x86_64.whl
Build a chat datastore using data from ShareGPT within 10 minutes (requires 465MB disk storage)
cd datastore
python3 get_datastore_chat.py --model-path lmsys/vicuna-7b-v1.5 # get datastore_chat_small.idx in this folder
Build a Python code generation datastore from The Stack within 20 minutes (requires 924MB disk storage)
cd datastore
python3 get_datastore_code.py --model-path codellama/CodeLlama-7b-instruct-hf # get datastore_stack_small.idx in this folder
(Optional) Build a chat datastore using data from UltraChat (requires 12GB disk storage)
cd datastore
python3 get_datastore_chat.py --model-path lmsys/vicuna-7b-v1.5 --large-datastore True # get datastore_chat_large.idx in this folder
(Optional) Build a Python code generation datastore from The Stack (requires 27GB disk storage)
cd datastore
python3 get_datastore_code.py --model-path codellama/CodeLlama-7b-instruct-hf --large-datastore True # get datastore_stack_large.idx in this folder
cd llm_judge
RAYON_NUM_THREADS=6 CUDA_VISIBLE_DEVICES=0 python3 gen_model_answer_rest.py --model-path lmsys/vicuna-7b-v1.5 --model-id vicuna-7b-v1.5 --datastore-path ../datastore/datastore_chat_small.idx
cd human_eval
RAYON_NUM_THREADS=6 CUDA_VISIBLE_DEVICES=0 python3 rest_test.py --model-path codellama/CodeLlama-7b-instruct-hf --datastore-path ../datastore/datastore_stack_small.idx
RAYON_NUM_THREADS=6 CUDA_VISIBLE_DEVICES=0 python3 -m rest.inference.cli --datastore-path datastore/datastore_chat_small.idx --base-model lmsys/vicuna-7b-v1.5
Note that the RAYON_NUM_THREADS environment variable controls the maximum number of threads used for retrieval. You can adjust it based on your machine.
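If you launch the scripts from Python rather than from the shell, one way to pick a thread count is to derive it from the number of CPU cores. The helper name `retrieval_env` and the rule of leaving two cores free are assumptions for illustration, not project recommendations.

```python
import os
from typing import Dict, Optional


def retrieval_env(threads: Optional[int] = None, reserve: int = 2) -> Dict[str, str]:
    """Return a copy of the current environment with RAYON_NUM_THREADS set."""
    if threads is None:
        # Assumed heuristic: leave `reserve` cores for the rest of the process.
        threads = max(1, (os.cpu_count() or 1) - reserve)
    env = dict(os.environ)
    env["RAYON_NUM_THREADS"] = str(threads)
    return env
```

The returned dictionary can be passed as the `env` argument of `subprocess.run` when starting any of the commands above; the calling process's own environment is left unchanged.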
The examples above use Vicuna and CodeLlama by default, but you can use any LLaMA-based model by changing the "--model-path" argument. You can also build the datastore from any data you like. To use architectures other than LLaMA, modify the file model/modeling_llama_kv.py to match the corresponding model.
Note: For models with a vocabulary size larger than 65535 (the range of u16), change this line in the writer from
self.index_file.write_u16::<LittleEndian>(item as u16)?;
to
self.index_file.write_u32::<LittleEndian>(item as u32)?;
Also change this line in the reader from
let int = LittleEndian::read_u16(&data_u8[i..i+2]) as i32;
to
let int = LittleEndian::read_u32(&data_u8[i..i+4]) as i32;
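The reason for this change can be seen in a small Python sketch that mirrors the little-endian encoding with the standard `struct` module. The helper names (`index_format`, `encode_ids`, `decode_ids`) are hypothetical and not part of DraftRetriever; the threshold follows the note above.

```python
import struct
from typing import List, Sequence

U16_MAX = 65535  # largest value an unsigned 16-bit integer can hold


def index_format(vocab_size: int) -> str:
    """Pick a little-endian format: 2 bytes per ID, or 4 for large vocabularies."""
    return "<H" if vocab_size <= U16_MAX else "<I"


def encode_ids(ids: Sequence[int], vocab_size: int) -> bytes:
    """Serialize token IDs back to back, as the writer does for the index file."""
    fmt = index_format(vocab_size)
    return b"".join(struct.pack(fmt, i) for i in ids)


def decode_ids(data: bytes, vocab_size: int) -> List[int]:
    """Read token IDs back, stepping by the same width used when writing."""
    fmt = index_format(vocab_size)
    width = struct.calcsize(fmt)  # 2 for "<H", 4 for "<I"
    return [struct.unpack_from(fmt, data, off)[0]
            for off in range(0, len(data), width)]
```

A token ID such as 100000 cannot be packed with `"<H"` at all, which is the Python analogue of `item as u16` silently truncating in Rust. The writer and the reader must agree on the width, which is why both lines need to change together.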
@misc{he2023rest,
title={REST: Retrieval-Based Speculative Decoding},
author={Zhenyu He and Zexuan Zhong and Tianle Cai and Jason D Lee and Di He},
year={2023},
eprint={2311.08252},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
The codebase is derived from Medusa and influenced by remarkable projects from the LLM community, including FastChat, TinyChat, vLLM, and many others.