Support vllm quantization #7297

Open · wants to merge 4 commits into base: main
27 changes: 22 additions & 5 deletions litellm/llms/vllm/completion/handler.py
@@ -1,5 +1,6 @@
import time # type: ignore
from typing import Callable
from typing import Callable, Optional
import litellm

import httpx

@@ -24,17 +25,31 @@ def __init__(self, status_code, message):


# check that vllm is installed and lazily construct the shared LLM engine with the given vllm_params
def validate_environment(model: str):
def validate_environment(model: str, vllm_params: dict):
global llm
try:
from vllm import LLM, SamplingParams # type: ignore

if llm is None:
llm = LLM(model=model)
llm = LLM(model=model, **vllm_params)
return llm, SamplingParams
except Exception as e:
raise VLLMError(status_code=0, message=str(e))

# split vLLM engine params out of optional_params, falling back to the VLLMConfig defaults
def handle_vllm_params(optional_params: Optional[dict]):
vllm_params = litellm.VLLMConfig.get_config()
if optional_params is None:
optional_params = {}

for k, v in optional_params.items():
if k in vllm_params:
vllm_params[k] = v

optional_params = {k: v for k, v in optional_params.items() if k not in vllm_params}

return vllm_params, optional_params


def completion(
model: str,
@@ -49,8 +64,9 @@ def completion(
logger_fn=None,
):
global llm
vllm_params, optional_params = handle_vllm_params(optional_params)
try:
llm, SamplingParams = validate_environment(model=model)
llm, SamplingParams = validate_environment(model=model, vllm_params=vllm_params)
except Exception as e:
raise VLLMError(status_code=0, message=str(e))
sampling_params = SamplingParams(**optional_params)
@@ -138,8 +154,9 @@ def batch_completions(
]
)
"""
vllm_params, optional_params = handle_vllm_params(optional_params)
try:
llm, SamplingParams = validate_environment(model=model)
llm, SamplingParams = validate_environment(model=model, vllm_params=vllm_params)
except Exception as e:
error_str = str(e)
raise VLLMError(status_code=0, message=error_str)
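For context, here is a minimal usage sketch of the new code path: the quantization-related kwargs are pulled into vllm_params and forwarded to the LLM(...) constructor, while decoding params still go to SamplingParams. This mirrors the tests added below and assumes a local vllm install with bitsandbytes support; the model name is only a placeholder.

import litellm

response = litellm.completion(
    model="vllm/facebook/opt-125m",  # placeholder model
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    # engine-level params: picked up by handle_vllm_params and passed to LLM(...)
    quantization="bitsandbytes",
    load_format="bitsandbytes",
    dtype="auto",
    # decoding params: stay in optional_params and go to SamplingParams(...)
    temperature=0.2,
)
print(response.choices[0].message.content)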
80 changes: 78 additions & 2 deletions litellm/llms/vllm/completion/transformation.py
@@ -4,12 +4,88 @@
NOT RECOMMENDED FOR PRODUCTION USE. Use `hosted_vllm/` instead.
"""

from typing import Optional, Dict, Any, Union
import types

from ...hosted_vllm.chat.transformation import HostedVLLMChatConfig


class VLLMConfig(HostedVLLMChatConfig):
"""
VLLM SDK supports the same OpenAI params as hosted_vllm.
"""

pass
model: str
tokenizer: Optional[str] = None
tokenizer_mode: str = "auto"
skip_tokenizer_init: bool = False
trust_remote_code: bool = False
allowed_local_media_path: str = ""
tensor_parallel_size: int = 1
dtype: str = "auto"
quantization: Optional[str] = None
load_format: str = "auto"
revision: Optional[str] = None
tokenizer_revision: Optional[str] = None
seed: int = 0
gpu_memory_utilization: float = 0.9
swap_space: float = 4
cpu_offload_gb: float = 0
enforce_eager: Optional[bool] = None
max_seq_len_to_capture: int = 8192
disable_custom_all_reduce: bool = False
disable_async_output_proc: bool = False
hf_overrides: Optional[Any] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None
task: str = "auto"
override_pooler_config: Optional[Any] = None
compilation_config: Optional[Union[int, Dict[str, Any]]] = None

def __init__(
self,
tokenizer: Optional[str] = None,
tokenizer_mode: str = "auto",
skip_tokenizer_init: bool = False,
trust_remote_code: bool = False,
allowed_local_media_path: str = "",
tensor_parallel_size: int = 1,
dtype: str = "auto",
quantization: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None,
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: float = 4,
cpu_offload_gb: float = 0,
enforce_eager: Optional[bool] = None,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False,
hf_overrides: Optional[Any] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
task: str = "auto",
override_pooler_config: Optional[Any] = None,
compilation_config: Optional[Union[int, Dict[str, Any]]] = None,
):
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self":
setattr(self.__class__, key, value)

@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not k.startswith("_abc")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
}
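For reference, a short sketch of how this config pattern behaves, assuming the class is exported as litellm.VLLMConfig (as the handle_vllm_params call in the handler implies); the values shown are the defaults declared above.

import litellm

# get_config() collects the class attributes that hold the engine defaults
defaults = litellm.VLLMConfig.get_config()
print(defaults["quantization"])             # None
print(defaults["gpu_memory_utilization"])   # 0.9

# Instantiating the config writes the supplied values back onto the class,
# so later completion() calls pick them up through get_config()
litellm.VLLMConfig(quantization="awq", gpu_memory_utilization=0.8)
print(litellm.VLLMConfig.get_config()["quantization"])   # "awq"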
60 changes: 60 additions & 0 deletions tests/llm_translation/test_vllm.py
@@ -0,0 +1,60 @@
import pytest
from unittest.mock import MagicMock, patch

import litellm

def test_vllm():
litellm.set_verbose = True

with patch("litellm.llms.vllm.completion.handler.validate_environment") as mock_client:
mock_client.return_value = MagicMock(), MagicMock()
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is the capital of France?"},
]

response = litellm.completion(
model="vllm/facebook/opt-125m",
messages=messages
)

# Verify the request was made
mock_client.assert_called_once()

# Check the request body
request_body = mock_client.call_args.kwargs

assert request_body["model"] == "facebook/opt-125m"
assert request_body["vllm_params"] is not None
assert request_body["vllm_params"]["quantization"] is None


def test_vllm_quantized():
litellm.set_verbose = True

with patch("litellm.llms.vllm.completion.handler.validate_environment") as mock_client:
mock_client.return_value = MagicMock(), MagicMock()
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is the capital of France?"},
]

response = litellm.completion(
model="vllm/facebook/opt-125m",
messages=messages,
dtype="auto",
quantization="bitsandbytes",
load_format="bitsandbytes"
)

# Verify the request was made
mock_client.assert_called_once()

# Check the request body
request_body = mock_client.call_args.kwargs

assert request_body["model"] == "facebook/opt-125m"
assert request_body["vllm_params"] is not None
assert request_body["vllm_params"]["quantization"] == "bitsandbytes"
assert request_body["vllm_params"]["dtype"] == "auto"
assert request_body["vllm_params"]["load_format"] == "bitsandbytes"
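A small sketch of the split these tests assert on, using hypothetical values and assuming handle_vllm_params is importable from the handler module shown earlier: params known to VLLMConfig end up in vllm_params (passed to LLM(...)), while everything else stays in optional_params (passed to SamplingParams(...)).

from litellm.llms.vllm.completion.handler import handle_vllm_params

vllm_params, optional_params = handle_vllm_params(
    {"quantization": "bitsandbytes", "load_format": "bitsandbytes", "temperature": 0.2}
)
print(vllm_params["quantization"])   # "bitsandbytes"
print(vllm_params["load_format"])    # "bitsandbytes"
print(optional_params)               # {'temperature': 0.2}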