fix(hosted_vllm/transformation.py): return fake api key, if none give… (#7301)

* fix(hosted_vllm/transformation.py): return fake api key, if none given. Prevents httpx error

Fixes #7291

* test: fix test

* fix(main.py): add hosted_vllm/ support for embeddings endpoint

Closes #7290

* docs(vllm.md): add docs on vllm embeddings usage

* fix(__init__.py): fix sambanova model test

* fix(base_llm_unit_tests.py): skip pydantic obj test if model takes >5s to respond
krrishdholakia authored Dec 19, 2024
1 parent 246e3ba commit 6a45ee1
Showing 9 changed files with 189 additions and 6 deletions.
66 changes: 63 additions & 3 deletions docs/my-website/docs/providers/vllm.md
@@ -5,9 +5,17 @@ import TabItem from '@theme/TabItem';

LiteLLM supports all models on vLLM.

| Property | Details |
|-------|-------|
| Description | vLLM is a fast and easy-to-use library for LLM inference and serving. [Docs](https://docs.vllm.ai/en/latest/index.html) |
| Provider Route on LiteLLM | `hosted_vllm/` (for OpenAI compatible server), `vllm/` (for vLLM sdk usage) |
| Provider Doc | [vLLM ↗](https://docs.vllm.ai/en/latest/index.html) |
| Supported Endpoints | `/chat/completions`, `/embeddings`, `/completions` |


# Quick Start

## Usage - litellm.completion (calling vLLM endpoint)
## Usage - litellm.completion (calling OpenAI compatible endpoint)
vLLM provides an OpenAI-compatible endpoint - here's how to call it with LiteLLM.

In order to use LiteLLM to call a hosted vLLM server, add the following to your completion call:
@@ -29,7 +37,7 @@ print(response)
```
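The collapsed hunk above hides the completion example itself; for reference, here is a minimal sketch of that call (the model name and `api_base` are assumptions, taken from the config example further down):

```python
import litellm

response = litellm.completion(
    model="hosted_vllm/facebook/opt-125m",   # hosted_vllm/ prefix routes to the OpenAI-compatible server
    messages=[{"role": "user", "content": "Hello world"}],
    api_base="https://hosted-vllm-api.co",   # assumed: your vLLM server's base URL
)
print(response)
```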


## Usage - LiteLLM Proxy Server (calling vLLM endpoint)
## Usage - LiteLLM Proxy Server (calling OpenAI compatible endpoint)

Here's how to call an OpenAI-Compatible Endpoint with the LiteLLM Proxy Server

@@ -97,7 +105,59 @@ Here's how to call an OpenAI-Compatible Endpoint with the LiteLLM Proxy Server
</Tabs>


## Extras - for `vllm pip package`
## Embeddings

<Tabs>
<TabItem value="sdk" label="SDK">

```python
from litellm import embedding
import os

os.environ["HOSTED_VLLM_API_BASE"] = "http://localhost:8000"


response = embedding(model="hosted_vllm/facebook/opt-125m", input=["Hello world"])

print(response)
```

</TabItem>
<TabItem value="proxy" label="PROXY">

1. Setup config.yaml

```yaml
model_list:
- model_name: my-model
litellm_params:
model: hosted_vllm/facebook/opt-125m # add hosted_vllm/ prefix to route as OpenAI provider
api_base: https://hosted-vllm-api.co # add api base for OpenAI compatible provider
```
2. Start the proxy
```bash
$ litellm --config /path/to/config.yaml

# RUNNING on http://0.0.0.0:4000
```

3. Test it!

```bash
curl -L -X POST 'http://0.0.0.0:4000/embeddings' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{"input": ["hello world"], "model": "my-model"}'
```

[See OpenAI SDK/Langchain/etc. examples](../proxy/user_keys.md#embeddings)
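For example, the same request through the OpenAI Python SDK pointed at the proxy might look like the sketch below (assumes the proxy key `sk-1234` and the `my-model` name from the steps above):

```python
from openai import OpenAI

# point the OpenAI client at the LiteLLM proxy instead of api.openai.com
client = OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

response = client.embeddings.create(
    model="my-model",
    input=["hello world"],
)
print(response.data[0].embedding[:5])
```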

</TabItem>
</Tabs>

## (Deprecated) for `vllm pip package`
### Using - `litellm.completion`

5 changes: 5 additions & 0 deletions litellm/__init__.py
@@ -470,6 +470,7 @@ def identify(event_details):
anyscale_models: List = []
cerebras_models: List = []
galadriel_models: List = []
sambanova_models: List = []


def add_known_models():
@@ -578,6 +579,8 @@ def add_known_models():
cerebras_models.append(key)
elif value.get("litellm_provider") == "galadriel":
galadriel_models.append(key)
elif value.get("litellm_provider") == "sambanova":
sambanova_models.append(key)


add_known_models()
@@ -841,6 +844,7 @@ def add_known_models():
+ anyscale_models
+ cerebras_models
+ galadriel_models
+ sambanova_models
)


@@ -891,6 +895,7 @@ def add_known_models():
"anyscale": anyscale_models,
"cerebras": cerebras_models,
"galadriel": galadriel_models,
"sambanova": sambanova_models,
}

# mapping for those models which have larger equivalents
2 changes: 1 addition & 1 deletion litellm/llms/hosted_vllm/chat/transformation.py
@@ -40,6 +40,6 @@ def _get_openai_compatible_provider_info(
) -> Tuple[Optional[str], Optional[str]]:
api_base = api_base or get_secret_str("HOSTED_VLLM_API_BASE") # type: ignore
dynamic_api_key = (
api_key or get_secret_str("HOSTED_VLLM_API_KEY") or ""
api_key or get_secret_str("HOSTED_VLLM_API_KEY") or "fake-api-key"
) # vllm does not require an api key
return api_base, dynamic_api_key
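For context, a small sketch (not part of the diff) of how this change surfaces through the public API, mirroring the updated `test_get_llm_provider_hosted_vllm` test further down:

```python
import litellm

# with no HOSTED_VLLM_API_KEY set, the resolved key is now a placeholder
model, provider, dynamic_api_key, api_base = litellm.get_llm_provider(
    model="hosted_vllm/llama-3.1-70b-instruct",
)
assert provider == "hosted_vllm"
assert model == "llama-3.1-70b-instruct"
assert dynamic_api_key == "fake-api-key"  # previously "", which triggered the httpx error reported in #7291
```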
5 changes: 5 additions & 0 deletions litellm/llms/hosted_vllm/embedding/README.md
@@ -0,0 +1,5 @@
No transformation is required for hosted_vllm embeddings.

vLLM's `embedding` endpoint is a superset of OpenAI's, so requests are passed through unchanged.

To pass provider-specific parameters, see [this guide](https://docs.litellm.ai/docs/completion/provider_specific_params).
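A hedged sketch of what that pass-through could look like (the `truncate_prompt_tokens` argument is a hypothetical vLLM-side parameter, used only to illustrate forwarding an extra field):

```python
import litellm

response = litellm.embedding(
    model="hosted_vllm/jina-embeddings-v3",
    input=["Hello world"],
    api_base="http://localhost:8000",  # assumed local vLLM server
    truncate_prompt_tokens=1,          # assumption: forwarded to vLLM unchanged, not interpreted by LiteLLM
)
print(response)
```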
6 changes: 5 additions & 1 deletion litellm/main.py
@@ -3362,7 +3362,11 @@ def embedding( # noqa: PLR0915
client=client,
aembedding=aembedding,
)
elif custom_llm_provider == "openai_like" or custom_llm_provider == "jina_ai":
elif (
custom_llm_provider == "openai_like"
or custom_llm_provider == "jina_ai"
or custom_llm_provider == "hosted_vllm"
):
api_base = (
api_base or litellm.api_base or get_secret_str("OPENAI_LIKE_API_BASE")
)
70 changes: 70 additions & 0 deletions litellm/model_prices_and_context_window_backup.json
@@ -15,6 +15,76 @@
"supports_prompt_caching": true,
"supports_response_schema": true
},
"sambanova/Meta-Llama-3.1-8B-Instruct": {
"max_tokens": 16000,
"max_input_tokens": 16000,
"max_output_tokens": 16000,
"input_cost_per_token": 0.0000001,
"output_cost_per_token": 0.0000002,
"litellm_provider": "sambanova",
"supports_function_calling": true,
"mode": "chat"
},
"sambanova/Meta-Llama-3.1-70B-Instruct": {
"max_tokens": 128000,
"max_input_tokens": 128000,
"max_output_tokens": 128000,
"input_cost_per_token": 0.0000006,
"output_cost_per_token": 0.0000012,
"litellm_provider": "sambanova",
"supports_function_calling": true,
"mode": "chat"
},
"sambanova/Meta-Llama-3.1-405B-Instruct": {
"max_tokens": 16000,
"max_input_tokens": 16000,
"max_output_tokens": 16000,
"input_cost_per_token": 0.000005,
"output_cost_per_token": 0.000010,
"litellm_provider": "sambanova",
"supports_function_calling": true,
"mode": "chat"
},
"sambanova/Meta-Llama-3.2-1B-Instruct": {
"max_tokens": 16000,
"max_input_tokens": 16000,
"max_output_tokens": 16000,
"input_cost_per_token": 0.0000004,
"output_cost_per_token": 0.0000008,
"litellm_provider": "sambanova",
"supports_function_calling": true,
"mode": "chat"
},
"sambanova/Meta-Llama-3.2-3B-Instruct": {
"max_tokens": 4000,
"max_input_tokens": 4000,
"max_output_tokens": 4000,
"input_cost_per_token": 0.0000008,
"output_cost_per_token": 0.0000016,
"litellm_provider": "sambanova",
"supports_function_calling": true,
"mode": "chat"
},
"sambanova/Qwen2.5-Coder-32B-Instruct": {
"max_tokens": 8000,
"max_input_tokens": 8000,
"max_output_tokens": 8000,
"input_cost_per_token": 0.0000015,
"output_cost_per_token": 0.000003,
"litellm_provider": "sambanova",
"supports_function_calling": true,
"mode": "chat"
},
"sambanova/Qwen2.5-72B-Instruct": {
"max_tokens": 8000,
"max_input_tokens": 8000,
"max_output_tokens": 8000,
"input_cost_per_token": 0.000002,
"output_cost_per_token": 0.000004,
"litellm_provider": "sambanova",
"supports_function_calling": true,
"mode": "chat"
},
"gpt-4": {
"max_tokens": 4096,
"max_input_tokens": 8192,
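For reference, a short sketch (not part of the diff) of how these per-token prices are consumed, using LiteLLM's cost helper and assumed token counts:

```python
import litellm

# 1,000 prompt tokens at $0.0000001/token and 500 completion tokens at $0.0000002/token
prompt_cost, completion_cost = litellm.cost_per_token(
    model="sambanova/Meta-Llama-3.1-8B-Instruct",
    prompt_tokens=1000,
    completion_tokens=500,
)
print(prompt_cost, completion_cost)  # expected: 0.0001 0.0001
```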
3 changes: 3 additions & 0 deletions tests/llm_translation/base_llm_unit_tests.py
@@ -219,13 +219,16 @@ class TestModel(BaseModel):
},
],
response_format=TestModel,
timeout=5,
)
assert res is not None

print(res.choices[0].message)

assert res.choices[0].message.content is not None
assert res.choices[0].message.tool_calls is None
except litellm.Timeout:
pytest.skip("Model took too long to respond")
except litellm.InternalServerError:
pytest.skip("Model is overloaded")

22 changes: 22 additions & 0 deletions tests/local_testing/test_embedding.py
@@ -1004,6 +1004,28 @@ async def test_hf_embedddings_with_optional_params(sync_mode):
assert json_data["parameters"]["top_k"] == 10


def test_hosted_vllm_embedding(monkeypatch):
monkeypatch.setenv("HOSTED_VLLM_API_BASE", "http://localhost:8000")
from litellm.llms.custom_httpx.http_handler import HTTPHandler

client = HTTPHandler()
with patch.object(client, "post") as mock_post:
try:
embedding(
model="hosted_vllm/jina-embeddings-v3",
input=["Hello world"],
client=client,
)
except Exception as e:
print(e)

mock_post.assert_called_once()

json_data = json.loads(mock_post.call_args.kwargs["data"])
assert json_data["input"] == ["Hello world"]
assert json_data["model"] == "jina-embeddings-v3"


@pytest.mark.parametrize(
"model",
[
16 changes: 15 additions & 1 deletion tests/local_testing/test_get_llm_provider.py
@@ -153,6 +153,20 @@ def test_default_api_base():
assert other_provider.value not in api_base.replace("/openai", "")


def test_hosted_vllm_default_api_key():
from litellm.litellm_core_utils.get_llm_provider_logic import (
_get_openai_compatible_provider_info,
)

_, _, dynamic_api_key, _ = _get_openai_compatible_provider_info(
model="hosted_vllm/llama-3.1-70b-instruct",
api_base=None,
api_key=None,
dynamic_api_key=None,
)
assert dynamic_api_key == "fake-api-key"


def test_get_llm_provider_jina_ai():
model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
model="jina_ai/jina-embeddings-v3",
@@ -168,7 +182,7 @@ def test_get_llm_provider_hosted_vllm():
)
assert custom_llm_provider == "hosted_vllm"
assert model == "llama-3.1-70b-instruct"
assert dynamic_api_key == ""
assert dynamic_api_key == "fake-api-key"


def test_get_llm_provider_watson_text():
