diff --git a/python/packages/autogen-ext/pyproject.toml b/python/packages/autogen-ext/pyproject.toml index d79fbc8f1b62..d96089fe85f1 100644 --- a/python/packages/autogen-ext/pyproject.toml +++ b/python/packages/autogen-ext/pyproject.toml @@ -28,7 +28,7 @@ langchain = ["langchain_core~= 0.3.3"] azure = ["azure-core", "azure-identity"] docker = ["docker~=7.0"] openai = ["openai>=1.3"] -chromadb = ["chromadb~=0.4.15"] +chromadb = ["chromadb~=0.5.15", "sentence-transformers"] [tool.hatch.build.targets.wheel] packages = ["src/autogen_ext"] diff --git a/python/packages/autogen-ext/src/autogen_ext/storage/__init__.py b/python/packages/autogen-ext/src/autogen_ext/storage/__init__.py index e69de29bb2d1..523de3f21c4e 100644 --- a/python/packages/autogen-ext/src/autogen_ext/storage/__init__.py +++ b/python/packages/autogen-ext/src/autogen_ext/storage/__init__.py @@ -0,0 +1,4 @@ +from ._chromadb import ChromaVectorDB +from ._factory import VectorDBFactory + +__all__ = ["ChromaVectorDB", "VectorDBFactory"] diff --git a/python/packages/autogen-ext/src/autogen_ext/storage/_base.py b/python/packages/autogen-ext/src/autogen_ext/storage/_base.py index f47732aff97e..2e73c61e5155 100644 --- a/python/packages/autogen-ext/src/autogen_ext/storage/_base.py +++ b/python/packages/autogen-ext/src/autogen_ext/storage/_base.py @@ -1,4 +1,3 @@ -from pydantic import BaseModel from typing import ( Any, Callable, @@ -12,19 +11,23 @@ runtime_checkable, ) +from pydantic import BaseModel Metadata = Union[Mapping[str, Any], None] Vector = Union[Sequence[float], Sequence[int]] ItemID = Union[str, int] - class Document(BaseModel): """Define Document according to autogen 0.4 specifications.""" + id: ItemID - content: str - metadata: Optional[Metadata] - embedding: Optional[Vector] + content: Optional[str] = None + metadata: Optional[Metadata] = None + embedding: Optional[Vector] = None + + model_config = {"arbitrary_types_allowed": True} + """QueryResults is the response from the vector database for a query/queries. A query is a list containing one string while queries is a list containing multiple strings. @@ -33,10 +36,178 @@ class Document(BaseModel): QueryResults = List[List[Tuple[Document, float]]] +@runtime_checkable +class AsyncVectorDB(Protocol): + """ + Abstract class for async vector database. A vector database is responsible for storing and retrieving documents. + + Attributes: + active_collection: Any | The active collection in the vector database. Make get_collection faster. Default is None. + type: str | The type of the vector database, chroma, pgvector, etc. Default is "". + + Methods: + create_collection: Callable[[str, bool, bool], Awaitable[Any]] | Create a collection in the vector database. + get_collection: Callable[[str], Awaitable[Any]] | Get the collection from the vector database. + delete_collection: Callable[[str], Awaitable[Any]] | Delete the collection from the vector database. + insert_docs: Callable[[List[Document], str, bool], Awaitable[None]] | Insert documents into the collection of the vector database. + update_docs: Callable[[List[Document], str], Awaitable[None]] | Update documents in the collection of the vector database. + delete_docs: Callable[[List[ItemID], str], Awaitable[None]] | Delete documents from the collection of the vector database. + retrieve_docs: Callable[[List[str], str, int, float], Awaitable[QueryResults]] | Retrieve documents from the collection of the vector database based on the queries. + get_docs_by_ids: Callable[[List[ItemID], str], Awaitable[List[Document]]] | Retrieve documents from the collection of the vector database based on the ids. + """ + + active_collection: Any = None + type: str = "" + embedding_function: Optional[Callable[[List[str]], List[List[float]]]] = ( + None # embeddings = embedding_function(sentences) + ) + + async def create_collection(self, collection_name: str, overwrite: bool = False, get_or_create: bool = True) -> Any: + """ + Create a collection in the vector database. + Case 1. if the collection does not exist, create the collection. + Case 2. the collection exists, if overwrite is True, it will overwrite the collection. + Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection, + otherwise it raise a ValueError. + + Args: + collection_name: str | The name of the collection. + overwrite: bool | Whether to overwrite the collection if it exists. Default is False. + get_or_create: bool | Whether to get the collection if it exists. Default is True. + + Returns: + Any | The collection object. + """ + ... + + async def get_collection(self, collection_name: Optional[str] = None) -> Any: + """ + Get the collection from the vector database. + + Args: + collection_name: Optional[str] | The name of the collection. Default is None. + If None, return the current active collection. + + Returns: + Any | The collection object. + """ + ... + + async def delete_collection(self, collection_name: str) -> Any: + """ + Delete the collection from the vector database. + + Args: + collection_name: str | The name of the collection. + + Returns: + Any + """ + ... + + async def insert_docs( + self, + docs: Sequence[Document], + collection_name: Optional[str] = None, + upsert: bool = False, + **kwargs: Any, + ) -> None: + """ + Insert documents into the collection of the vector database. + + Args: + docs: List[Document] | A list of documents. Each document is a Pydantic Document model. + collection_name: Optional[str] | The name of the collection. Default is None. + upsert: bool | Whether to update the document if it exists. Default is False. + kwargs: Dict[str, Any] | Additional keyword arguments. + + Returns: + None + """ + ... + + async def update_docs(self, docs: Sequence[Document], collection_name: Optional[str] = None, **kwargs: Any) -> None: + """ + Update documents in the collection of the vector database. + + Args: + docs: List[Document] | A list of documents. + collection_name: Optional[str] | The name of the collection. Default is None. + kwargs: Dict[str, Any] | Additional keyword arguments. + + Returns: + None + """ + ... + + async def delete_docs(self, ids: Sequence[ItemID], collection_name: Optional[str] = None, **kwargs: Any) -> None: + """ + Delete documents from the collection of the vector database. + + Args: + ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`. + collection_name: Optional[str] | The name of the collection. Default is None. + kwargs: Dict[str, Any] | Additional keyword arguments. + + Returns: + None + """ + ... + + async def retrieve_docs( + self, + queries: Sequence[str], + collection_name: Optional[str] = None, + n_results: int = 10, + distance_threshold: float = -1, + **kwargs: Any, + ) -> QueryResults: + """ + Retrieve documents from the collection of the vector database based on the queries. + + Args: + queries: List[str] | A list of queries. Each query is a string. + collection_name: Optional[str] | The name of the collection. Default is None. + n_results: int | The number of relevant documents to return. Default is 10. + distance_threshold: float | The threshold for the distance score, only distance smaller than it will be + returned. Don't filter with it if < 0. Default is -1. + kwargs: Dict[str, Any] | Additional keyword arguments. + + Returns: + QueryResults | The query results. Each query result is a list of list of tuples containing the document and + the distance. + """ + ... + + async def get_docs_by_ids( + self, + ids: Optional[Sequence[ItemID]] = None, + collection_name: Optional[str] = None, + include: Optional[Sequence[str]] = None, + **kwargs: Any, + ) -> List[Document]: + """ + Retrieve documents from the collection of the vector database based on the ids. + + Args: + ids: Optional[List[ItemID]] | A list of document ids. If None, will return all the documents. Default is None. + collection_name: Optional[str] | The name of the collection. Default is None. + include: Optional[List[str]] | The fields to include. Default is None. + If None, will include ["metadatas", "documents"], ids will always be included. This may differ + depending on the implementation. + kwargs: Dict[str, Any] | Additional keyword arguments. + + Returns: + List[Document] | The results. + """ + ... + + @runtime_checkable class VectorDB(Protocol): """ - Abstract class for vector database. A vector database is responsible for storing and retrieving documents. + Abstract class for synchronous vector database. A vector database is responsible for storing and retrieving documents. + For async support, use AsyncVectorDB instead. Attributes: active_collection: Any | The active collection in the vector database. Make get_collection faster. Default is None. @@ -77,13 +248,13 @@ def create_collection(self, collection_name: str, overwrite: bool = False, get_o """ ... - def get_collection(self, collection_name: str = None) -> Any: + def get_collection(self, collection_name: Optional[str] = None) -> Any: """ Get the collection from the vector database. Args: - collection_name: str | The name of the collection. Default is None. If None, return the - current active collection. + collection_name: Optional[str] | The name of the collection. Default is None. + If None, return the current active collection. Returns: Any | The collection object. @@ -102,43 +273,49 @@ def delete_collection(self, collection_name: str) -> Any: """ ... - def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False, **kwargs) -> None: + def insert_docs( + self, + docs: Sequence[Document], + collection_name: Optional[str] = None, + upsert: bool = False, + **kwargs: Any, + ) -> None: """ Insert documents into the collection of the vector database. Args: - docs: List[Document] | A list of documents. Each document is a TypedDict `Document`. - collection_name: str | The name of the collection. Default is None. + docs: List[Document] | A list of documents. Each document is a Pydantic Document model. + collection_name: Optional[str] | The name of the collection. Default is None. upsert: bool | Whether to update the document if it exists. Default is False. - kwargs: Dict | Additional keyword arguments. + kwargs: Dict[str, Any] | Additional keyword arguments. Returns: None """ ... - def update_docs(self, docs: List[Document], collection_name: str = None, **kwargs) -> None: + def update_docs(self, docs: Sequence[Document], collection_name: Optional[str] = None, **kwargs: Any) -> None: """ Update documents in the collection of the vector database. Args: docs: List[Document] | A list of documents. - collection_name: str | The name of the collection. Default is None. - kwargs: Dict | Additional keyword arguments. + collection_name: Optional[str] | The name of the collection. Default is None. + kwargs: Dict[str, Any] | Additional keyword arguments. Returns: None """ ... - def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) -> None: + def delete_docs(self, ids: Sequence[ItemID], collection_name: Optional[str] = None, **kwargs: Any) -> None: """ Delete documents from the collection of the vector database. Args: ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`. - collection_name: str | The name of the collection. Default is None. - kwargs: Dict | Additional keyword arguments. + collection_name: Optional[str] | The name of the collection. Default is None. + kwargs: Dict[str, Any] | Additional keyword arguments. Returns: None @@ -147,22 +324,22 @@ def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) def retrieve_docs( self, - queries: List[str], - collection_name: str = None, + queries: Sequence[str], + collection_name: Optional[str] = None, n_results: int = 10, distance_threshold: float = -1, - **kwargs, + **kwargs: Any, ) -> QueryResults: """ Retrieve documents from the collection of the vector database based on the queries. Args: queries: List[str] | A list of queries. Each query is a string. - collection_name: str | The name of the collection. Default is None. + collection_name: Optional[str] | The name of the collection. Default is None. n_results: int | The number of relevant documents to return. Default is 10. distance_threshold: float | The threshold for the distance score, only distance smaller than it will be returned. Don't filter with it if < 0. Default is -1. - kwargs: Dict | Additional keyword arguments. + kwargs: Dict[str, Any] | Additional keyword arguments. Returns: QueryResults | The query results. Each query result is a list of list of tuples containing the document and @@ -171,21 +348,24 @@ def retrieve_docs( ... def get_docs_by_ids( - self, ids: List[ItemID] = None, collection_name: str = None, include=None, **kwargs + self, + ids: Optional[Sequence[ItemID]] = None, + collection_name: Optional[str] = None, + include: Optional[List[str]] = None, + **kwargs: Any, ) -> List[Document]: """ Retrieve documents from the collection of the vector database based on the ids. Args: - ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None. - collection_name: str | The name of the collection. Default is None. - include: List[str] | The fields to include. Default is None. + ids: Optional[List[ItemID]] | A list of document ids. If None, will return all the documents. Default is None. + collection_name: Optional[str] | The name of the collection. Default is None. + include: Optional[List[str]] | The fields to include. Default is None. If None, will include ["metadatas", "documents"], ids will always be included. This may differ depending on the implementation. - kwargs: dict | Additional keyword arguments. + kwargs: Dict[str, Any] | Additional keyword arguments. Returns: List[Document] | The results. """ ... - diff --git a/python/packages/autogen-ext/src/autogen_ext/storage/_chromadb.py b/python/packages/autogen-ext/src/autogen_ext/storage/_chromadb.py index dca623bfd18e..618728859175 100644 --- a/python/packages/autogen-ext/src/autogen_ext/storage/_chromadb.py +++ b/python/packages/autogen-ext/src/autogen_ext/storage/_chromadb.py @@ -1,16 +1,21 @@ +# python\packages\autogen-ext\src\autogen_ext\storage\_chromadb.py + +import logging import os -from typing import Callable, List, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Union + +from autogen_core.application.logging import TRACE_LOGGER_NAME if TYPE_CHECKING: - import chromadb.utils.embedding_functions as ef + from chromadb.api import AsyncClientAPI, Client from chromadb.api.models.Collection import Collection - from chromadb.errors import ChromaError + from chromadb.config import Settings -from ._base import Document, ItemID, QueryResults, VectorDB -from ._utils import chroma_results_to_query_results, filter_results_by_distance, get_logger +from ._base import AsyncVectorDB, Document, ItemID, QueryResults, VectorDB +from ._utils import chroma_results_to_query_results, filter_results_by_distance -CHROMADB_MAX_BATCH_SIZE = os.environ.get("CHROMADB_MAX_BATCH_SIZE", 40000) -logger = get_logger(__name__) +CHROMADB_MAX_BATCH_SIZE = int(os.environ.get("CHROMADB_MAX_BATCH_SIZE", 40000)) +logger = logging.getLogger(f"{TRACE_LOGGER_NAME}.{__name__}") class ChromaVectorDB(VectorDB): @@ -22,10 +27,19 @@ class ChromaVectorDB(VectorDB): This class requires the :code:`chromadb` extra for the :code:`autogen-ext` package. """ - ChromaError = Exception # Default to Exception if chromadb is not installed + ChromaError = Exception # Default to Exception if chromadb is not installed def __init__( - self, *, client=None, path: str = "tmp/db", embedding_function: Callable = None, metadata: dict = None, **kwargs + self, + *, + client: Optional["Client"] = None, + path: Optional[str] = None, + embedding_function: Optional[Callable[[List[str]], List[List[float]]]] = None, + metadata: Optional[Dict[str, Any]] = None, + client_type: str = "persistent", + host: str = "localhost", + port: int = 8000, + **kwargs: Any, ) -> None: """ Initialize the vector database. @@ -33,14 +47,13 @@ def __init__( Args: client: chromadb.Client | The client object of the vector database. Default is None. If provided, it will use the client object directly and ignore other arguments. - path: str | The path to the vector database. Default is `tmp/db`. The default was `None` for version <=0.2.24. + path: Optional[str] | The path to the vector database. Required if client_type is 'persistent'. embedding_function: Callable | The embedding function used to generate the vector representation of the documents. Default is None, SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2") will be used. - metadata: dict | The metadata of the vector database. Default is None. If None, it will use this - setting: {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32}. For more details of - the metadata, please refer to [distances](https://github.com/nmslib/hnswlib#supported-distances), - [hnsw](https://github.com/chroma-core/chroma/blob/566bc80f6c8ee29f7d99b6322654f32183c368c4/chromadb/segment/impl/vector/local_hnsw.py#L184), - and [ALGO_PARAMS](https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md). + metadata: dict | The metadata of the vector database. Default is None. + client_type: str | The type of client to use. Can be 'persistent' or 'http'. Default is 'persistent'. + host: str | The host of the HTTP server. Default is 'localhost'. + port: int | The port of the HTTP server. Default is 8000. kwargs: dict | Additional keyword arguments. Returns: @@ -49,41 +62,45 @@ def __init__( try: import chromadb - if chromadb.__version__ < "0.4.15": - raise ImportError("Please upgrade chromadb to version 0.4.15 or later.") + if chromadb.__version__ < "0.5.0": + raise ImportError("Please upgrade chromadb to version 0.5.0 or later.") import chromadb.utils.embedding_functions as ef from chromadb.errors import ChromaError + ChromaVectorDB.ChromaError = ChromaError # Set the class attribute except ImportError as e: raise RuntimeError( - "Missing dependecies for ChromaVectorDB. Please ensure the autogen-ext package was installed with the 'chromadb' extra." + "Missing dependencies for ChromaVectorDB. Please ensure the autogen-ext package was installed with the 'chromadb' extra." ) from e - self.client = client - self.path = path + self.client: "Client" = client self.embedding_function = ( ef.SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2") if embedding_function is None else embedding_function ) - self.metadata = metadata if metadata else {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 32} + self.metadata = metadata if metadata else {} + self.type = "chroma" if not self.client: - if self.path is not None: - self.client = chromadb.PersistentClient(path=self.path, **kwargs) + if client_type == "persistent": + if path is None: + raise ValueError("Persistent client requires a 'path' to save the database.") + self.client = chromadb.PersistentClient(path=path, **kwargs) + elif client_type == "http": + self.client = chromadb.HttpClient(host=host, port=port, **kwargs) else: - self.client = chromadb.Client(**kwargs) - self.active_collection = None - self.type = "chroma" + raise ValueError(f"Invalid client_type: {client_type}") + self.active_collection: Optional["Collection"] = None def create_collection( self, collection_name: str, overwrite: bool = False, get_or_create: bool = True - ) -> Collection: + ) -> "Collection": """ Create a collection in the vector database. Case 1. if the collection does not exist, create the collection. Case 2. the collection exists, if overwrite is True, it will overwrite the collection. Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection, - otherwise it raise a ValueError. + otherwise it raises a ValueError. Args: collection_name: str | The name of the collection. @@ -97,22 +114,22 @@ def create_collection( if self.active_collection and self.active_collection.name == collection_name: collection = self.active_collection else: - collection = self.client.get_collection(collection_name, embedding_function=self.embedding_function) - except (ValueError, ChromaError): + collection = self.client.get_collection( + name=collection_name, embedding_function=self.embedding_function + ) + except (ValueError, ChromaVectorDB.ChromaError): collection = None if collection is None: return self.client.create_collection( - collection_name, + name=collection_name, embedding_function=self.embedding_function, - get_or_create=get_or_create, metadata=self.metadata, ) elif overwrite: - self.client.delete_collection(collection_name) + self.client.delete_collection(name=collection_name) return self.client.create_collection( - collection_name, + name=collection_name, embedding_function=self.embedding_function, - get_or_create=get_or_create, metadata=self.metadata, ) elif get_or_create: @@ -120,13 +137,13 @@ def create_collection( else: raise ValueError(f"Collection {collection_name} already exists.") - def get_collection(self, collection_name: str = None) -> Collection: + def get_collection(self, collection_name: Optional[str] = None) -> "Collection": """ Get the collection from the vector database. Args: - collection_name: str | The name of the collection. Default is None. If None, return the - current active collection. + collection_name: Optional[str] | The name of the collection. Default is None. + If None, return the current active collection. Returns: Collection | The collection object. @@ -141,7 +158,7 @@ def get_collection(self, collection_name: str = None) -> Collection: else: if not (self.active_collection and self.active_collection.name == collection_name): self.active_collection = self.client.get_collection( - collection_name, embedding_function=self.embedding_function + name=collection_name, embedding_function=self.embedding_function ) return self.active_collection @@ -155,108 +172,121 @@ def delete_collection(self, collection_name: str) -> None: Returns: None """ - self.client.delete_collection(collection_name) + self.client.delete_collection(name=collection_name) if self.active_collection and self.active_collection.name == collection_name: self.active_collection = None def _batch_insert( - self, collection: Collection, embeddings=None, ids=None, metadatas=None, documents=None, upsert=False + self, + collection: "Collection", + embeddings: Optional[List[Any]] = None, + ids: Optional[List[str]] = None, + metadatas: Optional[List[Dict[str, Any]]] = None, + documents: Optional[List[str]] = None, + upsert: bool = False, ) -> None: - batch_size = int(CHROMADB_MAX_BATCH_SIZE) - for i in range(0, len(documents), min(batch_size, len(documents))): - end_idx = i + min(batch_size, len(documents) - i) + batch_size = CHROMADB_MAX_BATCH_SIZE + for i in range(0, len(ids or []), batch_size): + end_idx = i + batch_size collection_kwargs = { - "documents": documents[i:end_idx], - "ids": ids[i:end_idx], + "documents": documents[i:end_idx] if documents else None, + "ids": ids[i:end_idx] if ids else None, "metadatas": metadatas[i:end_idx] if metadatas else None, "embeddings": embeddings[i:end_idx] if embeddings else None, } if upsert: - collection.upsert(**collection_kwargs) + collection.upsert(**collection_kwargs) # type: ignore else: - collection.add(**collection_kwargs) + collection.add(**collection_kwargs) # type: ignore - def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False) -> None: + def insert_docs( + self, + docs: List[Document], + collection_name: Optional[str] = None, + upsert: bool = False, + **kwargs: Any, + ) -> None: """ Insert documents into the collection of the vector database. Args: - docs: List[Document] | A list of documents. Each document is a TypedDict `Document`. - collection_name: str | The name of the collection. Default is None. + docs: List[Document] | A list of documents. Each document is a Pydantic Document model. + collection_name: Optional[str] | The name of the collection. Default is None. upsert: bool | Whether to update the document if it exists. Default is False. - kwargs: Dict | Additional keyword arguments. + kwargs: Dict[str, Any] | Additional keyword arguments. Returns: None """ if not docs: return - if docs[0].get("content") is None: - raise ValueError("The document content is required.") - if docs[0].get("id") is None: + if docs[0].content is None and docs[0].embedding is None: + raise ValueError("Either document content or embedding is required.") + if docs[0].id is None: raise ValueError("The document id is required.") - documents = [doc.get("content") for doc in docs] - ids = [doc.get("id") for doc in docs] + documents = [doc.content for doc in docs] if docs[0].content else None + ids = [str(doc.id) for doc in docs] collection = self.get_collection(collection_name) - if docs[0].get("embedding") is None: - logger.info( - "No content embedding is provided. Will use the VectorDB's embedding function to generate the content embedding." - ) - embeddings = None - else: - embeddings = [doc.get("embedding") for doc in docs] - if docs[0].get("metadata") is None: - metadatas = None - else: - metadatas = [doc.get("metadata") for doc in docs] - self._batch_insert(collection, embeddings, ids, metadatas, documents, upsert) + embeddings = [doc.embedding for doc in docs] if docs[0].embedding else None + if not embeddings and not documents: + raise ValueError("Either documents or embeddings must be provided.") + metadatas = [doc.metadata for doc in docs] if docs[0].metadata else None + self._batch_insert( + collection, + embeddings=embeddings, + ids=ids, + metadatas=metadatas, + documents=documents, + upsert=upsert, + ) - def update_docs(self, docs: List[Document], collection_name: str = None) -> None: + def update_docs(self, docs: Sequence[Document], collection_name: Optional[str] = None, **kwargs: Any) -> None: """ Update documents in the collection of the vector database. Args: docs: List[Document] | A list of documents. - collection_name: str | The name of the collection. Default is None. + collection_name: Optional[str] | The name of the collection. Default is None. + kwargs: Dict[str, Any] | Additional keyword arguments. Returns: None """ - self.insert_docs(docs, collection_name, upsert=True) + self.insert_docs(docs, collection_name=collection_name, upsert=True, **kwargs) - def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) -> None: + def delete_docs(self, ids: Sequence[ItemID], collection_name: Optional[str] = None, **kwargs: Any) -> None: """ Delete documents from the collection of the vector database. Args: ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`. - collection_name: str | The name of the collection. Default is None. - kwargs: Dict | Additional keyword arguments. + collection_name: Optional[str] | The name of the collection. Default is None. + kwargs: Dict[str, Any] | Additional keyword arguments. Returns: None """ collection = self.get_collection(collection_name) - collection.delete(ids, **kwargs) + collection.delete(ids=ids) def retrieve_docs( self, queries: List[str], - collection_name: str = None, + collection_name: Optional[str] = None, n_results: int = 10, distance_threshold: float = -1, - **kwargs, + **kwargs: Any, ) -> QueryResults: """ Retrieve documents from the collection of the vector database based on the queries. Args: queries: List[str] | A list of queries. Each query is a string. - collection_name: str | The name of the collection. Default is None. + collection_name: Optional[str] | The name of the collection. Default is None. n_results: int | The number of relevant documents to return. Default is 10. distance_threshold: float | The threshold for the distance score, only distance smaller than it will be returned. Don't filter with it if < 0. Default is -1. - kwargs: Dict | Additional keyword arguments. + kwargs: Dict[str, Any] | Additional keyword arguments. Returns: QueryResults | The query results. Each query result is a list of list of tuples containing the document and @@ -268,7 +298,6 @@ def retrieve_docs( results = collection.query( query_texts=queries, n_results=n_results, - **kwargs, ) results["contents"] = results.pop("documents") results = chroma_results_to_query_results(results) @@ -276,7 +305,7 @@ def retrieve_docs( return results @staticmethod - def _chroma_get_results_to_list_documents(data_dict) -> List[Document]: + def _chroma_get_results_to_list_documents(data_dict: Dict[str, Any]) -> List[Document]: """Converts a dictionary with list values to a list of Document. Args: @@ -284,51 +313,376 @@ def _chroma_get_results_to_list_documents(data_dict) -> List[Document]: Returns: List[Document] | The list of Document. + """ + results: List[Document] = [] + keys = [key for key in data_dict if data_dict[key] is not None] + + for i in range(len(data_dict[keys[0]])): + doc_dict = {} + for key in data_dict.keys(): + if data_dict[key] is not None and len(data_dict[key]) > i: + doc_dict[key[:-1]] = data_dict[key][i] + results.append(Document(**doc_dict)) # type: ignore + return results + + def get_docs_by_ids( + self, + ids: Optional[List[ItemID]] = None, + collection_name: Optional[str] = None, + include: Optional[List[str]] = None, + **kwargs: Any, + ) -> List[Document]: + """ + Retrieve documents from the collection of the vector database based on the ids. + + Args: + ids: Optional[List[ItemID]] | A list of document ids. If None, will return all the documents. Default is None. + collection_name: Optional[str] | The name of the collection. Default is None. + include: Optional[List[str]] | The fields to include. Default is None. + If None, will include ["metadatas", "documents"]. IDs are always included. + kwargs: Dict[str, Any] | Additional keyword arguments. + + Returns: + List[Document] | The results. + """ + collection = self.get_collection(collection_name) + if include is None: + include = ["metadatas", "documents"] + results = collection.get(ids=ids, include=include) + results = self._chroma_get_results_to_list_documents(results) + return results + + +class AsyncChromaVectorDB(AsyncVectorDB): + """ + An asynchronous vector database that uses ChromaDB as the backend. + + .. note:: + + This class requires the :code:`chromadb` extra for the :code:`autogen-ext` package. + """ + + ChromaError = Exception # Default to Exception if chromadb is not installed + + def __init__( + self, + *, + client: "AsyncClientAPI", + embedding_function: Optional[Callable[[List[str]], List[List[float]]]] = None, + host: str = "localhost", + port: int = 8000, + ssl: bool = False, + headers: Optional[Dict[str, str]] = None, + settings: Optional["Settings"] = None, + tenant: str = "default_tenant", + database: str = "default_database", + **kwargs: Any, + ) -> None: + """ + Initialize the async vector database. + + Args: + client: chromadb.AsyncClientAPI | The client object of the vector database. Default is None. + If provided, it will use the client object directly and ignore other arguments. + embedding_function: Callable | The embedding function used to generate the vector representation + of the documents. Default is None. Must be provided for async client. + host: str | The host of the HTTP server. Default is 'localhost'. + port: int | The port of the HTTP server. Default is 8000. + ssl: bool | Whether to use SSL to connect to the Chroma server. Defaults to False. + headers: Optional[Dict[str, str]] | A dictionary of headers to send to the Chroma server. Defaults to None. + settings: Optional[Settings] | A dictionary of settings to communicate with the chroma server. + tenant: str | The tenant to use for this client. Defaults to "default_tenant". + database: str | The database to use for this client. Defaults to "default_database". + kwargs: dict | Additional keyword arguments. + + Returns: + None + """ + try: + import chromadb + + if chromadb.__version__ < "0.5.0": + raise ImportError("Please upgrade chromadb to version 0.5.0 or later.") + from chromadb.errors import ChromaError + + AsyncChromaVectorDB.ChromaError = ChromaError # Set the class attribute + except ImportError as e: + raise RuntimeError( + "Missing dependencies for AsyncChromaVectorDB. Please ensure the autogen-ext package was installed with the 'chromadb' extra." + ) from e + + self.client: "AsyncClientAPI" = client + self.embedding_function = embedding_function + if self.embedding_function is None: + raise ValueError("An embedding function must be provided for AsyncChromaVectorDB.") + self.type = "chroma" + if not self.client: + self.client = chromadb.AsyncHttpClient( + host=host, + port=port, + ssl=ssl, + headers=headers, + settings=settings, + tenant=tenant, + database=database, + **kwargs, + ) + self.active_collection: Optional["Collection"] = None + + async def create_collection(self, collection_name: str, overwrite: bool = False, get_or_create: bool = True) -> Any: + """ + Create a collection in the vector database. + Case 1. if the collection does not exist, create the collection. + Case 2. the collection exists, if overwrite is True, it will overwrite the collection. + Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection, + otherwise it raises a ValueError. + + Args: + collection_name: str | The name of the collection. + overwrite: bool | Whether to overwrite the collection if it exists. Default is False. + get_or_create: bool | Whether to get the collection if it exists. Default is True. + + Returns: + Any | The collection object. + """ + try: + if self.active_collection and self.active_collection.name == collection_name: + collection = self.active_collection + else: + collection = await self.client.get_collection( + name=collection_name, embedding_function=self.embedding_function + ) + except (ValueError, AsyncChromaVectorDB.ChromaError): + collection = None + if collection is None: + return await self.client.create_collection( + name=collection_name, + embedding_function=self.embedding_function, + metadata={}, + ) + elif overwrite: + await self.client.delete_collection(name=collection_name) + return await self.client.create_collection( + name=collection_name, + embedding_function=self.embedding_function, + metadata={}, + ) + elif get_or_create: + return collection + else: + raise ValueError(f"Collection {collection_name} already exists.") + + async def get_collection(self, collection_name: Optional[str] = None) -> Any: + """ + Get the collection from the vector database. + + Args: + collection_name: Optional[str] | The name of the collection. Default is None. + If None, return the current active collection. + + Returns: + Any | The collection object. + """ + if collection_name is None: + if self.active_collection is None: + raise ValueError("No collection is specified.") + else: + logger.info( + f"No collection is specified. Using current active collection {self.active_collection.name}." + ) + else: + if not (self.active_collection and self.active_collection.name == collection_name): + self.active_collection = await self.client.get_collection( + name=collection_name, embedding_function=self.embedding_function + ) + return self.active_collection - Example: - data_dict = { - "key1s": [1, 2, 3], - "key2s": ["a", "b", "c"], - "key3s": None, - "key4s": ["x", "y", "z"], + async def delete_collection(self, collection_name: str) -> Any: + """ + Delete the collection from the vector database. + + Args: + collection_name: str | The name of the collection. + + Returns: + Any + """ + await self.client.delete_collection(name=collection_name) + if self.active_collection and self.active_collection.name == collection_name: + self.active_collection = None + + async def _batch_insert( + self, + collection: Any, + embeddings: Optional[List[Any]] = None, + ids: Optional[List[str]] = None, + metadatas: Optional[List[Dict[str, Any]]] = None, + documents: Optional[List[str]] = None, + upsert: bool = False, + ) -> None: + batch_size = CHROMADB_MAX_BATCH_SIZE + for i in range(0, len(ids or []), batch_size): + end_idx = i + batch_size + collection_kwargs = { + "documents": documents[i:end_idx] if documents else None, + "ids": ids[i:end_idx] if ids else None, + "metadatas": metadatas[i:end_idx] if metadatas else None, + "embeddings": embeddings[i:end_idx] if embeddings else None, } + if upsert: + await collection.upsert(**collection_kwargs) + else: + await collection.add(**collection_kwargs) - results = [ - {"key1": 1, "key2": "a", "key4": "x"}, - {"key1": 2, "key2": "b", "key4": "y"}, - {"key1": 3, "key2": "c", "key4": "z"}, - ] + async def insert_docs( + self, + docs: List[Document], + collection_name: Optional[str] = None, + upsert: bool = False, + **kwargs: Any, + ) -> None: """ + Insert documents into the collection of the vector database. + Args: + docs: List[Document] | A list of documents. Each document is a Pydantic Document model. + collection_name: Optional[str] | The name of the collection. Default is None. + upsert: bool | Whether to update the document if it exists. Default is False. + kwargs: Dict[str, Any] | Additional keyword arguments. + + Returns: + None + """ + if not docs: + return + if docs[0].content is None and docs[0].embedding is None: + raise ValueError("Either document content or embedding is required.") + if docs[0].id is None: + raise ValueError("The document id is required.") + documents = [doc.content for doc in docs] if docs[0].content else None + ids = [str(doc.id) for doc in docs] + collection = await self.get_collection(collection_name) + embeddings = [doc.embedding for doc in docs] if docs[0].embedding else None + if not embeddings and not documents: + raise ValueError("Either documents or embeddings must be provided.") + metadatas = [doc.metadata for doc in docs] if docs[0].metadata else None + await self._batch_insert( + collection, + embeddings=embeddings, + ids=ids, + metadatas=metadatas, + documents=documents, + upsert=upsert, + ) + + async def update_docs(self, docs: List[Document], collection_name: Optional[str] = None, **kwargs: Any) -> None: + """ + Update documents in the collection of the vector database. + + Args: + docs: List[Document] | A list of documents. + collection_name: Optional[str] | The name of the collection. Default is None. + kwargs: Dict[str, Any] | Additional keyword arguments. + + Returns: + None + """ + await self.insert_docs(docs, collection_name=collection_name, upsert=True, **kwargs) + + async def delete_docs(self, ids: List[ItemID], collection_name: Optional[str] = None, **kwargs: Any) -> None: + """ + Delete documents from the collection of the vector database. + + Args: + ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`. + collection_name: Optional[str] | The name of the collection. Default is None. + kwargs: Dict[str, Any] | Additional keyword arguments. + + Returns: + None + """ + collection = await self.get_collection(collection_name) + await collection.delete(ids=ids) + + async def retrieve_docs( + self, + queries: List[str], + collection_name: Optional[str] = None, + n_results: int = 10, + distance_threshold: float = -1, + **kwargs: Any, + ) -> QueryResults: + """ + Retrieve documents from the collection of the vector database based on the queries. + + Args: + queries: List[str] | A list of queries. Each query is a string. + collection_name: Optional[str] | The name of the collection. Default is None. + n_results: int | The number of relevant documents to return. Default is 10. + distance_threshold: float | The threshold for the distance score, only distance smaller than it will be + returned. Don't filter with it if < 0. Default is -1. + kwargs: Dict[str, Any] | Additional keyword arguments. + + Returns: + QueryResults | The query results. Each query result is a list of list of tuples containing the document and + the distance. + """ + collection = await self.get_collection(collection_name) + if isinstance(queries, str): + queries = [queries] + results = await collection.query( + query_texts=queries, + n_results=n_results, + ) + results["contents"] = results.pop("documents") + results = chroma_results_to_query_results(results) + results = filter_results_by_distance(results, distance_threshold) + return results + + @staticmethod + def _chroma_get_results_to_list_documents(data_dict: Dict[str, Any]) -> List[Document]: + """Converts a dictionary with list values to a list of Document. + + Args: + data_dict: A dictionary where keys map to lists or None. + + Returns: + List[Document] | The list of Document. + """ results = [] keys = [key for key in data_dict if data_dict[key] is not None] for i in range(len(data_dict[keys[0]])): - sub_dict = {} + doc_dict = {} for key in data_dict.keys(): if data_dict[key] is not None and len(data_dict[key]) > i: - sub_dict[key[:-1]] = data_dict[key][i] - results.append(sub_dict) + doc_dict[key[:-1]] = data_dict[key][i] + results.append(Document(**doc_dict)) # type: ignore return results - def get_docs_by_ids( - self, ids: List[ItemID] = None, collection_name: str = None, include=None, **kwargs + async def get_docs_by_ids( + self, + ids: Optional[List[ItemID]] = None, + collection_name: Optional[str] = None, + include: Optional[List[str]] = None, + **kwargs: Any, ) -> List[Document]: """ Retrieve documents from the collection of the vector database based on the ids. Args: - ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None. - collection_name: str | The name of the collection. Default is None. - include: List[str] | The fields to include. Default is None. - If None, will include ["metadatas", "documents"], ids will always be included. - kwargs: dict | Additional keyword arguments. + ids: Optional[List[ItemID]] | A list of document ids. If None, will return all the documents. Default is None. + collection_name: Optional[str] | The name of the collection. Default is None. + include: Optional[List[str]] | The fields to include. Default is None. + If None, will include ["metadatas", "documents"]. IDs are always included. + kwargs: Dict[str, Any] | Additional keyword arguments. Returns: List[Document] | The results. """ - collection = self.get_collection(collection_name) - include = include if include else ["metadatas", "documents"] - results = collection.get(ids, include=include, **kwargs) + collection = await self.get_collection(collection_name) + if include is None: + include = ["metadatas", "documents"] + results = await collection.get(ids=ids, include=include) results = self._chroma_get_results_to_list_documents(results) - return results \ No newline at end of file + return results diff --git a/python/packages/autogen-ext/src/autogen_ext/storage/_factory.py b/python/packages/autogen-ext/src/autogen_ext/storage/_factory.py index e69de29bb2d1..ba9f2d047494 100644 --- a/python/packages/autogen-ext/src/autogen_ext/storage/_factory.py +++ b/python/packages/autogen-ext/src/autogen_ext/storage/_factory.py @@ -0,0 +1,33 @@ +from typing import Literal + +from ._base import VectorDB + + +class VectorDBFactory: + """ + Factory class for creating vector databases. + """ + + PREDEFINED_VECTOR_DB = ["chromadb"] + + @staticmethod + def create_vector_db(db_type: Literal["chromadb"], **kwargs) -> VectorDB: + """ + Create a vector database. + + Args: + db_type: Literal["chroma", "chromadb"] | The type of the vector database. + kwargs: Dict | The keyword arguments for initializing the vector database. + + Returns: + VectorDB | The vector database. + """ + if db_type.lower() == "chromadb": + from ._chromadb import ChromaVectorDB + + return ChromaVectorDB(**kwargs) + + else: + raise ValueError( + f"Unsupported vector database type: {db_type}. Valid types are {VectorDBFactory.PREDEFINED_VECTOR_DB}." + ) diff --git a/python/packages/autogen-ext/src/autogen_ext/storage/_utils.py b/python/packages/autogen-ext/src/autogen_ext/storage/_utils.py index e96076c7d84a..fe4c08e9edf6 100644 --- a/python/packages/autogen-ext/src/autogen_ext/storage/_utils.py +++ b/python/packages/autogen-ext/src/autogen_ext/storage/_utils.py @@ -1,44 +1,6 @@ -import logging -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Tuple -from termcolor import colored - -from .base import QueryResults - - -class ColoredLogger(logging.Logger): - def __init__(self, name, level=logging.NOTSET): - super().__init__(name, level) - - def debug(self, msg, *args, color=None, **kwargs): - super().debug(colored(msg, color), *args, **kwargs) - - def info(self, msg, *args, color=None, **kwargs): - super().info(colored(msg, color), *args, **kwargs) - - def warning(self, msg, *args, color="yellow", **kwargs): - super().warning(colored(msg, color), *args, **kwargs) - - def error(self, msg, *args, color="light_red", **kwargs): - super().error(colored(msg, color), *args, **kwargs) - - def critical(self, msg, *args, color="red", **kwargs): - super().critical(colored(msg, color), *args, **kwargs) - - def fatal(self, msg, *args, color="red", **kwargs): - super().fatal(colored(msg, color), *args, **kwargs) - - -def get_logger(name: str, level: int = logging.INFO) -> ColoredLogger: - logger = ColoredLogger(name, level) - console_handler = logging.StreamHandler() - logger.addHandler(console_handler) - formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") - logger.handlers[0].setFormatter(formatter) - return logger - - -logger = get_logger(__name__) +from ._base import QueryResults def filter_results_by_distance(results: QueryResults, distance_threshold: float = -1) -> QueryResults: @@ -58,17 +20,19 @@ def filter_results_by_distance(results: QueryResults, distance_threshold: float return results -def chroma_results_to_query_results(data_dict: Dict[str, List[List[Any]]], special_key="distances") -> QueryResults: +def chroma_results_to_query_results( + data_dict: Dict[str, Optional[List[List[Any]]]], special_key: str = "distances" +) -> List[List[Tuple[Dict[str, Any], float]]]: """Converts a dictionary with list-of-list values to a list of tuples. Args: data_dict: A dictionary where keys map to lists of lists or None. - special_key: The key in the dictionary containing the special values - for each tuple. + special_key: str | The key in the dictionary containing the special values + for each tuple. Returns: - A list of tuples, where each tuple contains a sub-dictionary with - some keys from the original dictionary and the value from the + List[List[Tuple[Dict[str, Any], float]]] | A list of tuples, where each tuple contains + a sub-dictionary with some keys from the original dictionary and the value from the special_key. Example: @@ -99,22 +63,31 @@ def chroma_results_to_query_results(data_dict: Dict[str, List[List[Any]]], speci ] """ - keys = [ + if not data_dict or special_key not in data_dict or not data_dict[special_key]: + return [] + + keys: List[str] = [ key for key in data_dict - if key != special_key and data_dict[key] is not None and isinstance(data_dict[key][0], list) + if key != special_key + and data_dict[key] is not None + and isinstance(data_dict[key], list) + and len(data_dict[key]) > 0 + and isinstance(data_dict[key][0], list) ] - result = [] + result: List[List[Tuple[Dict[str, Any], float]]] = [] data_special_key = data_dict[special_key] + assert data_special_key is not None + for i in range(len(data_special_key)): - sub_result = [] + sub_result: List[Tuple[Dict[str, Any], float]] = [] for j, distance in enumerate(data_special_key[i]): - sub_dict = {} + sub_dict: Dict[str, Any] = {} for key in keys: - if len(data_dict[key]) > i: - sub_dict[key[:-1]] = data_dict[key][i][j] # remove 's' in the end from key + if len(data_dict[key]) > i and len(data_dict[key][i]) > j: + sub_dict[key[:-1]] = data_dict[key][i][j] # remove 's' at the end from key sub_result.append((sub_dict, distance)) result.append(sub_result) - return result \ No newline at end of file + return result diff --git a/python/packages/autogen-ext/tests/storage/test_chroma_db.py b/python/packages/autogen-ext/tests/storage/test_chroma_db.py new file mode 100644 index 000000000000..579f1db6e596 --- /dev/null +++ b/python/packages/autogen-ext/tests/storage/test_chroma_db.py @@ -0,0 +1,79 @@ +import pytest +from autogen_ext.storage import ChromaVectorDB +from chromadb.errors import ChromaError + + +# @pytest.mark.skipif(skip, reason="dependency is not installed") +def test_chromadb(): + # test create collection + db = ChromaVectorDB(path=".db") + collection_name = "test_collection" + collection = db.create_collection(collection_name, overwrite=True, get_or_create=True) + assert collection.name == collection_name + + # test_delete_collection + db.delete_collection(collection_name) + pytest.raises((ValueError, ChromaError), db.get_collection, collection_name) + + # test more create collection + collection = db.create_collection(collection_name, overwrite=False, get_or_create=False) + assert collection.name == collection_name + pytest.raises( + (ValueError, ChromaError), db.create_collection, collection_name, overwrite=False, get_or_create=False + ) + collection = db.create_collection(collection_name, overwrite=True, get_or_create=False) + assert collection.name == collection_name + collection = db.create_collection(collection_name, overwrite=False, get_or_create=True) + assert collection.name == collection_name + + # test_get_collection + collection = db.get_collection(collection_name) + assert collection.name == collection_name + + # test_insert_docs + docs = [{"content": "doc1", "id": "1"}, {"content": "doc2", "id": "2"}, {"content": "doc3", "id": "3"}] + db.insert_docs(docs, collection_name, upsert=False) + res = db.get_collection(collection_name).get(["1", "2"]) + assert res["documents"] == ["doc1", "doc2"] + + # test_update_docs + docs = [{"content": "doc11", "id": "1"}, {"content": "doc2", "id": "2"}, {"content": "doc3", "id": "3"}] + db.update_docs(docs, collection_name) + res = db.get_collection(collection_name).get(["1", "2"]) + assert res["documents"] == ["doc11", "doc2"] + + # test_delete_docs + ids = ["1"] + collection_name = "test_collection" + db.delete_docs(ids, collection_name) + res = db.get_collection(collection_name).get(ids) + assert res["documents"] == [] + + # test_retrieve_docs + queries = ["doc2", "doc3"] + collection_name = "test_collection" + res = db.retrieve_docs(queries, collection_name) + assert [[r[0].id for r in rr] for rr in res] == [["2", "3"], ["3", "2"]] + res = db.retrieve_docs(queries, collection_name, distance_threshold=0.1) + assert [[r[0].id for r in rr] for rr in res] == [["2"], ["3"]] + + # test_get_docs_by_ids + res = db.get_docs_by_ids(["1", "2"], collection_name) + assert [r.id for r in res] == ["2"] # "1" has been deleted + res = db.get_docs_by_ids(collection_name=collection_name) + assert [r.id for r in res] == ["2", "3"] + + # test _chroma_get_results_to_list_documents + data_dict = { + "key1s": [1, 2, 3], + "key2s": ["a", "b", "c"], + "key3s": None, + "key4s": ["x", "y", "z"], + } + + results = [ + {"key1": 1, "key2": "a", "key4": "x"}, + {"key1": 2, "key2": "b", "key4": "y"}, + {"key1": 3, "key2": "c", "key4": "z"}, + ] + assert db._chroma_get_results_to_list_documents(data_dict) == results # type: ignore diff --git a/python/uv.lock b/python/uv.lock index 0da8b9d65ad9..e9d6707ae580 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -238,7 +238,7 @@ wheels = [ [[package]] name = "anthropic" -version = "0.36.2" +version = "0.37.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -250,9 +250,9 @@ dependencies = [ { name = "tokenizers" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2b/ee/53cadd3262cbd3f1e6feee9640bd74e59bc33a053f214631eb2fe85eac2e/anthropic-0.36.2.tar.gz", hash = "sha256:d5a3fa56d1c82a51944f9dc7b0dc72048deb89f8df5ebfd09e2d1b59c62eb8eb", size = 928435 } +sdist = { url = "https://files.pythonhosted.org/packages/bb/7c/4b4cc70a82b18ecbd69b13c4707281850bd9575b6c1fc74b06df231b17ca/anthropic-0.37.1.tar.gz", hash = "sha256:99f688265795daa7ba9256ee68eaf2f05d53cd99d7417f4a0c2dc292c106d00a", size = 931431 } wheels = [ - { url = "https://files.pythonhosted.org/packages/cf/46/e564db65755947c66b5c23fae53886db41aea28b4132a50311c92c739f5f/anthropic-0.36.2-py3-none-any.whl", hash = "sha256:308ddc6c538de03c081552e456bc0b387b6f7c7d1dea0c20122cc11c7cdbaf6a", size = 939627 }, + { url = "https://files.pythonhosted.org/packages/4e/40/bbb252b77f7a0345aa8c759bab8280d97eab5a9acf4df49fa2251f4a3a58/anthropic-0.37.1-py3-none-any.whl", hash = "sha256:8f550f88906823752e2abf99fbe491fbc8d40bce4cb26b9663abdf7be990d721", size = 945950 }, ] [[package]] @@ -515,6 +515,10 @@ docker = [ { name = "docker" }, chromadb = [ { name = "chromadb" }, + { name = "sentence-transformers" }, +] +docker = [ + { name = "docker" }, ] docker-code-executor = [ { name = "docker" }, @@ -536,12 +540,13 @@ requires-dist = [ { name = "azure-core", marker = "extra == 'azure-code-executor'" }, { name = "chromadb", marker = "extra == 'chromadb'", specifier = "~=0.4.15" }, { name = "azure-identity", marker = "extra == 'azure'" }, + { name = "chromadb", marker = "extra == 'chromadb'", specifier = "~=0.5.15" }, { name = "docker", marker = "extra == 'docker'", specifier = "~=7.0" }, - { name = "chromadb", marker = "extra == 'chromadb'", specifier = "~=0.4.15" }, { name = "docker", marker = "extra == 'docker-code-executor'", specifier = "~=7.0" }, { name = "langchain-core", marker = "extra == 'langchain'", specifier = "~=0.3.3" }, { name = "langchain-core", marker = "extra == 'langchain-tools'", specifier = "~=0.3.3" }, { name = "openai", marker = "extra == 'openai'", specifier = ">=1.3" }, + { name = "sentence-transformers", marker = "extra == 'chromadb'" }, ] [[package]] @@ -956,28 +961,32 @@ sdist = { url = "https://files.pythonhosted.org/packages/74/16/53b895bb4fccede8e [[package]] name = "chroma-hnswlib" -version = "0.7.3" +version = "0.7.6" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c0/59/1224cbae62c7b84c84088cdf6c106b9b2b893783c000d22c442a1672bc75/chroma-hnswlib-0.7.3.tar.gz", hash = "sha256:b6137bedde49fffda6af93b0297fe00429fc61e5a072b1ed9377f909ed95a932", size = 31876 } +sdist = { url = "https://files.pythonhosted.org/packages/73/09/10d57569e399ce9cbc5eee2134996581c957f63a9addfa6ca657daf006b8/chroma_hnswlib-0.7.6.tar.gz", hash = "sha256:4dce282543039681160259d29fcde6151cc9106c6461e0485f57cdccd83059b7", size = 32256 } wheels = [ - { url = "https://files.pythonhosted.org/packages/1a/36/d1069ffa520efcf93f6d81b527e3c7311e12363742fdc786cbdaea3ab02e/chroma_hnswlib-0.7.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:59d6a7c6f863c67aeb23e79a64001d537060b6995c3eca9a06e349ff7b0998ca", size = 219588 }, - { url = "https://files.pythonhosted.org/packages/c3/e8/263d331f5ce29367f6f8854cd7fa1f54fce72ab4f92ab957525ef9165a9c/chroma_hnswlib-0.7.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d71a3f4f232f537b6152947006bd32bc1629a8686df22fd97777b70f416c127a", size = 197094 }, - { url = "https://files.pythonhosted.org/packages/a9/72/a9b61ae00d490c26359a8e10f3974c0d38065b894e6a2573ec6a7597f8e3/chroma_hnswlib-0.7.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c92dc1ebe062188e53970ba13f6b07e0ae32e64c9770eb7f7ffa83f149d4210", size = 2315620 }, - { url = "https://files.pythonhosted.org/packages/2f/48/f7609a3cb15a24c5d8ec18911ce10ac94144e9a89584f0a86bf9871b024c/chroma_hnswlib-0.7.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49da700a6656fed8753f68d44b8cc8ae46efc99fc8a22a6d970dc1697f49b403", size = 2350956 }, - { url = "https://files.pythonhosted.org/packages/cc/3d/ca311b8f79744db3f4faad8fd9140af80d34c94829d3ed1726c98cf4a611/chroma_hnswlib-0.7.3-cp310-cp310-win_amd64.whl", hash = "sha256:108bc4c293d819b56476d8f7865803cb03afd6ca128a2a04d678fffc139af029", size = 150598 }, - { url = "https://files.pythonhosted.org/packages/94/3f/844393b0d2ea1072b7704d6eff5c595e05ae8b831b96340cdb76b2fe995c/chroma_hnswlib-0.7.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:11e7ca93fb8192214ac2b9c0943641ac0daf8f9d4591bb7b73be808a83835667", size = 221219 }, - { url = "https://files.pythonhosted.org/packages/11/7a/673ccb9bb2faf9cf655d9040e970c02a96645966e06837fde7d10edf242a/chroma_hnswlib-0.7.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6f552e4d23edc06cdeb553cdc757d2fe190cdeb10d43093d6a3319f8d4bf1c6b", size = 198652 }, - { url = "https://files.pythonhosted.org/packages/ba/f4/c81a40da5473d5d80fc9d0c5bd5b1cb64e530a6ea941c69f195fe81c488c/chroma_hnswlib-0.7.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f96f4d5699e486eb1fb95849fe35ab79ab0901265805be7e60f4eaa83ce263ec", size = 2332260 }, - { url = "https://files.pythonhosted.org/packages/48/0e/068b658a547d6090b969014146321e28dae1411da54b76d081e51a2af22b/chroma_hnswlib-0.7.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:368e57fe9ebae05ee5844840fa588028a023d1182b0cfdb1d13f607c9ea05756", size = 2367211 }, - { url = "https://files.pythonhosted.org/packages/d2/32/a91850c7aa8a34f61838913155103808fe90da6f1ea4302731b59e9ba6f2/chroma_hnswlib-0.7.3-cp311-cp311-win_amd64.whl", hash = "sha256:b7dca27b8896b494456db0fd705b689ac6b73af78e186eb6a42fea2de4f71c6f", size = 151574 }, + { url = "https://files.pythonhosted.org/packages/a8/74/b9dde05ea8685d2f8c4681b517e61c7887e974f6272bb24ebc8f2105875b/chroma_hnswlib-0.7.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f35192fbbeadc8c0633f0a69c3d3e9f1a4eab3a46b65458bbcbcabdd9e895c36", size = 195821 }, + { url = "https://files.pythonhosted.org/packages/fd/58/101bfa6bc41bc6cc55fbb5103c75462a7bf882e1704256eb4934df85b6a8/chroma_hnswlib-0.7.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6f007b608c96362b8f0c8b6b2ac94f67f83fcbabd857c378ae82007ec92f4d82", size = 183854 }, + { url = "https://files.pythonhosted.org/packages/17/ff/95d49bb5ce134f10d6aa08d5f3bec624eaff945f0b17d8c3fce888b9a54a/chroma_hnswlib-0.7.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:456fd88fa0d14e6b385358515aef69fc89b3c2191706fd9aee62087b62aad09c", size = 2358774 }, + { url = "https://files.pythonhosted.org/packages/3a/6d/27826180a54df80dbba8a4f338b022ba21c0c8af96fd08ff8510626dee8f/chroma_hnswlib-0.7.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5dfaae825499c2beaa3b75a12d7ec713b64226df72a5c4097203e3ed532680da", size = 2392739 }, + { url = "https://files.pythonhosted.org/packages/d6/63/ee3e8b7a8f931918755faacf783093b61f32f59042769d9db615999c3de0/chroma_hnswlib-0.7.6-cp310-cp310-win_amd64.whl", hash = "sha256:2487201982241fb1581be26524145092c95902cb09fc2646ccfbc407de3328ec", size = 150955 }, + { url = "https://files.pythonhosted.org/packages/f5/af/d15fdfed2a204c0f9467ad35084fbac894c755820b203e62f5dcba2d41f1/chroma_hnswlib-0.7.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:81181d54a2b1e4727369486a631f977ffc53c5533d26e3d366dda243fb0998ca", size = 196911 }, + { url = "https://files.pythonhosted.org/packages/0d/19/aa6f2139f1ff7ad23a690ebf2a511b2594ab359915d7979f76f3213e46c4/chroma_hnswlib-0.7.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4b4ab4e11f1083dd0a11ee4f0e0b183ca9f0f2ed63ededba1935b13ce2b3606f", size = 185000 }, + { url = "https://files.pythonhosted.org/packages/79/b1/1b269c750e985ec7d40b9bbe7d66d0a890e420525187786718e7f6b07913/chroma_hnswlib-0.7.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:53db45cd9173d95b4b0bdccb4dbff4c54a42b51420599c32267f3abbeb795170", size = 2377289 }, + { url = "https://files.pythonhosted.org/packages/c7/2d/d5663e134436e5933bc63516a20b5edc08b4c1b1588b9680908a5f1afd04/chroma_hnswlib-0.7.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c093f07a010b499c00a15bc9376036ee4800d335360570b14f7fe92badcdcf9", size = 2411755 }, + { url = "https://files.pythonhosted.org/packages/3e/79/1bce519cf186112d6d5ce2985392a89528c6e1e9332d680bf752694a4cdf/chroma_hnswlib-0.7.6-cp311-cp311-win_amd64.whl", hash = "sha256:0540b0ac96e47d0aa39e88ea4714358ae05d64bbe6bf33c52f316c664190a6a3", size = 151888 }, + { url = "https://files.pythonhosted.org/packages/93/ac/782b8d72de1c57b64fdf5cb94711540db99a92768d93d973174c62d45eb8/chroma_hnswlib-0.7.6-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e87e9b616c281bfbe748d01705817c71211613c3b063021f7ed5e47173556cb7", size = 197804 }, + { url = "https://files.pythonhosted.org/packages/32/4e/fd9ce0764228e9a98f6ff46af05e92804090b5557035968c5b4198bc7af9/chroma_hnswlib-0.7.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ec5ca25bc7b66d2ecbf14502b5729cde25f70945d22f2aaf523c2d747ea68912", size = 185421 }, + { url = "https://files.pythonhosted.org/packages/d9/3d/b59a8dedebd82545d873235ef2d06f95be244dfece7ee4a1a6044f080b18/chroma_hnswlib-0.7.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:305ae491de9d5f3c51e8bd52d84fdf2545a4a2bc7af49765cda286b7bb30b1d4", size = 2389672 }, + { url = "https://files.pythonhosted.org/packages/74/1e/80a033ea4466338824974a34f418e7b034a7748bf906f56466f5caa434b0/chroma_hnswlib-0.7.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:822ede968d25a2c88823ca078a58f92c9b5c4142e38c7c8b4c48178894a0a3c5", size = 2436986 }, ] [[package]] name = "chromadb" -version = "0.4.24" +version = "0.5.15" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "bcrypt" }, @@ -985,6 +994,7 @@ dependencies = [ { name = "chroma-hnswlib" }, { name = "fastapi" }, { name = "grpcio" }, + { name = "httpx" }, { name = "importlib-resources" }, { name = "kubernetes" }, { name = "mmh3" }, @@ -997,11 +1007,10 @@ dependencies = [ { name = "orjson" }, { name = "overrides" }, { name = "posthog" }, - { name = "pulsar-client" }, { name = "pydantic" }, { name = "pypika" }, { name = "pyyaml" }, - { name = "requests" }, + { name = "rich" }, { name = "tenacity" }, { name = "tokenizers" }, { name = "tqdm" }, @@ -1009,9 +1018,9 @@ dependencies = [ { name = "typing-extensions" }, { name = "uvicorn", extra = ["standard"] }, ] -sdist = { url = "https://files.pythonhosted.org/packages/47/6b/a5465827d8017b658d18ad1e63d2dc31109dec717c6bd068e82485186f4b/chromadb-0.4.24.tar.gz", hash = "sha256:a5c80b4e4ad9b236ed2d4899a5b9e8002b489293f2881cb2cadab5b199ee1c72", size = 13667084 } +sdist = { url = "https://files.pythonhosted.org/packages/34/ae/1ec964744b2e8d26db386617c63bd18ff6fdacba854b699b2d07cc8811f5/chromadb-0.5.15.tar.gz", hash = "sha256:9314a1904418dafbc4d7ed47d88b8c9d0cf51f5ca6e9377e668367ef3c46ee75", size = 33609544 } wheels = [ - { url = "https://files.pythonhosted.org/packages/cc/63/b7d76109331318423f9cfb89bd89c99e19f5d0b47a5105439a629224d297/chromadb-0.4.24-py3-none-any.whl", hash = "sha256:3a08e237a4ad28b5d176685bd22429a03717fe09d35022fb230d516108da01da", size = 525452 }, + { url = "https://files.pythonhosted.org/packages/89/43/7295465181c22b0e84162c6b647859f691c69bde5bd8a8f30b320ccc2e3c/chromadb-0.5.15-py3-none-any.whl", hash = "sha256:df8ccc3a36798e14d6e173261aabcdb88021d8ad7550ab2a6acbd79f5ab5ef4f", size = 607020 }, ] [[package]] @@ -1315,16 +1324,16 @@ wheels = [ [[package]] name = "fastapi" -version = "0.115.2" +version = "0.115.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pydantic" }, { name = "starlette" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/22/fa/19e3c7c9b31ac291987c82e959f36f88840bea183fa3dc3bb654669f19c1/fastapi-0.115.2.tar.gz", hash = "sha256:3995739e0b09fa12f984bce8fa9ae197b35d433750d3d312422d846e283697ee", size = 299968 } +sdist = { url = "https://files.pythonhosted.org/packages/a9/ce/b64ce344d7b13c0768dc5b131a69d52f57202eb85839408a7637ca0dd7e2/fastapi-0.115.3.tar.gz", hash = "sha256:c091c6a35599c036d676fa24bd4a6e19fa30058d93d950216cdc672881f6f7db", size = 300453 } wheels = [ - { url = "https://files.pythonhosted.org/packages/c9/14/bbe7776356ef01f830f8085ca3ac2aea59c73727b6ffaa757abeb7d2900b/fastapi-0.115.2-py3-none-any.whl", hash = "sha256:61704c71286579cc5a598763905928f24ee98bfcc07aabe84cfefb98812bbc86", size = 94650 }, + { url = "https://files.pythonhosted.org/packages/57/95/4c5b79e7ca1f7b372d16a32cad7c9cc6c3c899200bed8f45739f4415cfae/fastapi-0.115.3-py3-none-any.whl", hash = "sha256:8035e8f9a2b0aa89cea03b6c77721178ed5358e1aea4cd8570d9466895c0638c", size = 94647 }, ] [[package]] @@ -2365,7 +2374,7 @@ wheels = [ [[package]] name = "langsmith" -version = "0.1.136" +version = "0.1.137" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "httpx" }, @@ -2374,9 +2383,9 @@ dependencies = [ { name = "requests" }, { name = "requests-toolbelt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/1c/fe/7de2ef64464819bfae186831e82be1c24d009a90fda0130acd6269ecc48e/langsmith-0.1.136.tar.gz", hash = "sha256:5c0de01a313db70dd9a85845c0f416a69b5b653b3e98ba413d7d41e8851315b1", size = 287752 } +sdist = { url = "https://files.pythonhosted.org/packages/95/b0/b6c112e5080765ad31272b92f16478d2d38c54727e00cc8bbc9a66bbaa44/langsmith-0.1.137.tar.gz", hash = "sha256:56cdfcc6c74cb20a3f437d5bd144feb5bf93f54c5a2918d1e568cbd084a372d4", size = 287888 } wheels = [ - { url = "https://files.pythonhosted.org/packages/20/cc/e559a369fd3811534563e568ee7077fdc1698d86d99afe28d5167e692787/langsmith-0.1.136-py3-none-any.whl", hash = "sha256:cad2215eb7a754ee259878e19c558f4f8d3795aa1b699f087d4500e640f80d0a", size = 296724 }, + { url = "https://files.pythonhosted.org/packages/71/fd/7713b0e737f4e171112e44134790823ccec4aabe31f07d6e836fcbeb3b8a/langsmith-0.1.137-py3-none-any.whl", hash = "sha256:4256d5c61133749890f7b5c88321dbb133ce0f440c621ea28e76513285859b81", size = 296895 }, ] [[package]] @@ -2574,28 +2583,28 @@ wheels = [ [[package]] name = "llama-index-llms-openai" -version = "0.2.15" +version = "0.2.16" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "llama-index-core" }, { name = "openai" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c9/77/f1dfa05ad53d31cf0025716c909878cba7e2d035135150c2ec611030a9bd/llama_index_llms_openai-0.2.15.tar.gz", hash = "sha256:f13655535e8966f5ccf0214c7360e86ef8fc718678557ef248d7fe13f6fde8d0", size = 13429 } +sdist = { url = "https://files.pythonhosted.org/packages/ba/e7/46f16e0f3ad25f49a050f1421a20b738ec312a5003bd07d749095eedb235/llama_index_llms_openai-0.2.16.tar.gz", hash = "sha256:7c666dd27056c278a079ff45d53f1fbfc8ed363764aa7baeee2e03df47f9072a", size = 13437 } wheels = [ - { url = "https://files.pythonhosted.org/packages/06/70/952e4a291cc510186baff14655beb9fedae4541b827c69b5e001bb77269c/llama_index_llms_openai-0.2.15-py3-none-any.whl", hash = "sha256:a906669397c4c0c3ee55b241dcc22bf0129b3391a8d6ae681a2579affbc5ed48", size = 13614 }, + { url = "https://files.pythonhosted.org/packages/3b/49/bae3a019eba473a0b9bf21ad911786f86941e86dd0dac3c3e909352eaf54/llama_index_llms_openai-0.2.16-py3-none-any.whl", hash = "sha256:413466acbb894bd81f8dab2037f595e92392d869eec6d8274a16d43123cac8b6", size = 13623 }, ] [[package]] name = "llama-index-multi-modal-llms-openai" -version = "0.2.2" +version = "0.2.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "llama-index-core" }, { name = "llama-index-llms-openai" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/54/ee/c1712e870853d881a1168e592f9d915fc7e757710d4fdcdda9e9e8bd4ac3/llama_index_multi_modal_llms_openai-0.2.2.tar.gz", hash = "sha256:c7205cfd9a23e2201db527ca3f8fa5ef4fb260ab6c9b15e79163630a916ee159", size = 5178 } +sdist = { url = "https://files.pythonhosted.org/packages/03/26/298362f1c9531c637b46466847d8aad967aac3b8561c8a0dc859921f6feb/llama_index_multi_modal_llms_openai-0.2.3.tar.gz", hash = "sha256:8eb9b7f1ff3956ef0979e21bc83e6a885e40987b7199f195e46525d06e3ae402", size = 5098 } wheels = [ - { url = "https://files.pythonhosted.org/packages/12/be/971b8a51813e3613e1e9b1df57036796797df899b2a42400f5b042d7b2b6/llama_index_multi_modal_llms_openai-0.2.2-py3-none-any.whl", hash = "sha256:81813c66c133aab0554b3bee60fe9673e84403dcc57c9fa95fb8be2d7c4c4cee", size = 5869 }, + { url = "https://files.pythonhosted.org/packages/c6/e2/3e2b639880baf5fd5ca0f88abd68719d2ed7af4d5076698cb5aff612505c/llama_index_multi_modal_llms_openai-0.2.3-py3-none-any.whl", hash = "sha256:96b36beb2c3fca4faca80c59ecf7c6c6629ecdb96c288ef89777b592ec43f872", size = 5886 }, ] [[package]] @@ -2704,15 +2713,15 @@ wheels = [ [[package]] name = "llama-parse" -version = "0.5.10" +version = "0.5.11" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, { name = "llama-index-core" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/fa/f9/395badd98d2ff53ad34dc898cc723b3098c3e543317dd2e027a4bf8342ad/llama_parse-0.5.10.tar.gz", hash = "sha256:1b301fe5b4a806225cfaaf4b7899d20c340a523cd3e60cbcf1f38356cf4fffa4", size = 13507 } +sdist = { url = "https://files.pythonhosted.org/packages/ec/ad/66f8cb49f5a0a60b01035dd1ee8b9bdc7c34946888b0a794374dd32c7eeb/llama_parse-0.5.11.tar.gz", hash = "sha256:4ba5c7bc8be12f63c8edf7fdf6b0bb814ab6f91658a7e2d7c00d99f5d428dd1f", size = 13569 } wheels = [ - { url = "https://files.pythonhosted.org/packages/c1/53/c9ab9e4eb4c9116852a9e5005664c6e94adf941c693667677738743ecde1/llama_parse-0.5.10-py3-none-any.whl", hash = "sha256:cd6225553c8761e6eadb086eb174df0364a426aad4722fa42118d57ee3d5eea0", size = 12983 }, + { url = "https://files.pythonhosted.org/packages/fe/43/bd4ebdb9adb030e14e22c2a47148a348ed9c4b7f04f8a4da672a97d56ced/llama_parse-0.5.11-py3-none-any.whl", hash = "sha256:c81ea82aa7543a288352aa9cd46476b148e0b8d32b0b57df8a0c192059c21e8d", size = 13051 }, ] [[package]] @@ -3460,6 +3469,126 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/16/2e/86f24451c2d530c88daf997cb8d6ac622c1d40d19f5a031ed68a4b73a374/numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818", size = 15517754 }, ] +[[package]] +name = "nvidia-cublas-cu12" +version = "12.4.5.8" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7f/7f/7fbae15a3982dc9595e49ce0f19332423b260045d0a6afe93cdbe2f1f624/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3", size = 363333771 }, + { url = "https://files.pythonhosted.org/packages/ae/71/1c91302526c45ab494c23f61c7a84aa568b8c1f9d196efa5993957faf906/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b", size = 363438805 }, +] + +[[package]] +name = "nvidia-cuda-cupti-cu12" +version = "12.4.127" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/93/b5/9fb3d00386d3361b03874246190dfec7b206fd74e6e287b26a8fcb359d95/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a", size = 12354556 }, + { url = "https://files.pythonhosted.org/packages/67/42/f4f60238e8194a3106d06a058d494b18e006c10bb2b915655bd9f6ea4cb1/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb", size = 13813957 }, +] + +[[package]] +name = "nvidia-cuda-nvrtc-cu12" +version = "12.4.127" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/77/aa/083b01c427e963ad0b314040565ea396f914349914c298556484f799e61b/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0eedf14185e04b76aa05b1fea04133e59f465b6f960c0cbf4e37c3cb6b0ea198", size = 24133372 }, + { url = "https://files.pythonhosted.org/packages/2c/14/91ae57cd4db3f9ef7aa99f4019cfa8d54cb4caa7e00975df6467e9725a9f/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338", size = 24640306 }, +] + +[[package]] +name = "nvidia-cuda-runtime-cu12" +version = "12.4.127" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a1/aa/b656d755f474e2084971e9a297def515938d56b466ab39624012070cb773/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3", size = 894177 }, + { url = "https://files.pythonhosted.org/packages/ea/27/1795d86fe88ef397885f2e580ac37628ed058a92ed2c39dc8eac3adf0619/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5", size = 883737 }, +] + +[[package]] +name = "nvidia-cudnn-cu12" +version = "9.1.0.70" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas-cu12" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 }, +] + +[[package]] +name = "nvidia-cufft-cu12" +version = "11.2.1.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink-cu12" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/8a/0e728f749baca3fbeffad762738276e5df60851958be7783af121a7221e7/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5dad8008fc7f92f5ddfa2101430917ce2ffacd86824914c82e28990ad7f00399", size = 211422548 }, + { url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117 }, +] + +[[package]] +name = "nvidia-curand-cu12" +version = "10.3.5.147" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/80/9c/a79180e4d70995fdf030c6946991d0171555c6edf95c265c6b2bf7011112/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1f173f09e3e3c76ab084aba0de819c49e56614feae5c12f69883f4ae9bb5fad9", size = 56314811 }, + { url = "https://files.pythonhosted.org/packages/8a/6d/44ad094874c6f1b9c654f8ed939590bdc408349f137f9b98a3a23ccec411/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b", size = 56305206 }, +] + +[[package]] +name = "nvidia-cusolver-cu12" +version = "11.6.1.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas-cu12" }, + { name = "nvidia-cusparse-cu12" }, + { name = "nvidia-nvjitlink-cu12" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/46/6b/a5c33cf16af09166845345275c34ad2190944bcc6026797a39f8e0a282e0/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d338f155f174f90724bbde3758b7ac375a70ce8e706d70b018dd3375545fc84e", size = 127634111 }, + { url = "https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057 }, +] + +[[package]] +name = "nvidia-cusparse-cu12" +version = "12.3.1.170" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink-cu12" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/96/a9/c0d2f83a53d40a4a41be14cea6a0bf9e668ffcf8b004bd65633f433050c0/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9d32f62896231ebe0480efd8a7f702e143c98cfaa0e8a76df3386c1ba2b54df3", size = 207381987 }, + { url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763 }, +] + +[[package]] +name = "nvidia-nccl-cu12" +version = "2.21.5" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/df/99/12cd266d6233f47d00daf3a72739872bdc10267d0383508b0b9c84a18bb6/nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0", size = 188654414 }, +] + +[[package]] +name = "nvidia-nvjitlink-cu12" +version = "12.4.127" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/02/45/239d52c05074898a80a900f49b1615d81c07fceadd5ad6c4f86a987c0bc4/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83", size = 20552510 }, + { url = "https://files.pythonhosted.org/packages/ff/ff/847841bacfbefc97a00036e0fce5a0f086b640756dc38caea5e1bb002655/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57", size = 21066810 }, +] + +[[package]] +name = "nvidia-nvtx-cu12" +version = "12.4.127" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/06/39/471f581edbb7804b39e8063d92fc8305bdc7a80ae5c07dbe6ea5c50d14a5/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7959ad635db13edf4fc65c06a6e9f9e55fc2f92596db928d169c0bb031e88ef3", size = 100417 }, + { url = "https://files.pythonhosted.org/packages/87/20/199b8713428322a2f22b722c62b8cc278cc53dffa9705d744484b5035ee9/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a", size = 99144 }, +] + [[package]] name = "oauthlib" version = "3.2.2" @@ -3501,7 +3630,7 @@ wheels = [ [[package]] name = "openai" -version = "1.52.0" +version = "1.52.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -3513,9 +3642,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/12/00/0983e56ca4535394a34f3ce25429ce6710878f2f8d7931973d04364ca922/openai-1.52.0.tar.gz", hash = "sha256:95c65a5f77559641ab8f3e4c3a050804f7b51d278870e2ec1f7444080bfe565a", size = 309426 } +sdist = { url = "https://files.pythonhosted.org/packages/80/ac/54c76352d493866637756b7c0ecec44f0b5bafb8fe753d98472cf6cfe4ce/openai-1.52.1.tar.gz", hash = "sha256:383b96c7e937cbec23cad5bf5718085381e4313ca33c5c5896b54f8e1b19d144", size = 310069 } wheels = [ - { url = "https://files.pythonhosted.org/packages/39/1e/9dc3ccee95d0e16e54e353d3c355bb7cc506d56a2dbb0a07bc739cc48eac/openai-1.52.0-py3-none-any.whl", hash = "sha256:0c249f20920183b0a2ca4f7dba7b0452df3ecd0fa7985eb1d91ad884bc3ced9c", size = 386947 }, + { url = "https://files.pythonhosted.org/packages/ad/31/28a83e124e9f9dd04c83b5aeb6f8b1770f45addde4dd3d34d9a9091590ad/openai-1.52.1-py3-none-any.whl", hash = "sha256:f23e83df5ba04ee0e82c8562571e8cb596cd88f9a84ab783e6c6259e5ffbfb4a", size = 386945 }, ] [[package]] @@ -3669,46 +3798,47 @@ wheels = [ [[package]] name = "orjson" -version = "3.10.9" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ae/9f/645b533f73bd49adaa4cab46921276c8ad8b3aff44959f2e717ac7533e92/orjson-3.10.9.tar.gz", hash = "sha256:c378074e0c46035dc66e57006993233ec66bf8487d501bab41649b4b7289ed4d", size = 5399823 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/61/da/b8d2ba56a677cb4d3da96c60978d460edbfc6335d1be5a0aafac3d5340da/orjson-3.10.9-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:a377186a11b48c55969e34f0aa414c2826a234f212d6f2b312ba512e3cdb2c6f", size = 270561 }, - { url = "https://files.pythonhosted.org/packages/e1/60/06bc5723e995b844614b419af2ba63c2dab8cf313095ac5ffa2f9416f58f/orjson-3.10.9-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0bf37bf0ca538065c34efe1803378b2dadd7e05b06610a086c2857f15ee59e12", size = 153319 }, - { url = "https://files.pythonhosted.org/packages/5d/a2/ee32afde6f6574f0175e3a3e0d1aaf1fc0039018bc03bbd4a7c1d3d35115/orjson-3.10.9-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7d9d83a91168aa48309acba804e393b7d9216b66f15e38f339b9fbb00db8986d", size = 168580 }, - { url = "https://files.pythonhosted.org/packages/dc/1f/b21d8f6d57bc9e345ad199db7ba8de6a243c367661c8ee89021560747209/orjson-3.10.9-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e0014038a17a1fe273da0a5489787677ef5a64566ab383ad6d929e44ed5683f4", size = 155823 }, - { url = "https://files.pythonhosted.org/packages/02/3c/963c6cdf7cdd531c2a25cd694d7d70150348cb7707c457e3784303654fd1/orjson-3.10.9-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d6ae1b1733e4528e45675ed09a732b6ac37d716bce2facaf467f84ce774adecd", size = 166398 }, - { url = "https://files.pythonhosted.org/packages/fd/4a/eaaa6f41d80ff41d5187b4bd9d035485479091f603c483bb3a1dcadb1c11/orjson-3.10.9-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe91c2259c4a859356b6db1c6e649b40577492f66d483da8b8af6da0f87c00e3", size = 144509 }, - { url = "https://files.pythonhosted.org/packages/09/4d/6509ba86a73b323bfa7291d69484f10d4ab821c4842f8832da8663129514/orjson-3.10.9-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a04f912c32463386ba117591c99a3d9e40b3b69bed9c5123d89dff06f0f5a4b0", size = 172192 }, - { url = "https://files.pythonhosted.org/packages/1c/98/d79008d516878528502f6a8bb15b54a306b75a385ffe021456c133d184f4/orjson-3.10.9-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ae82ca347829ca47431767b079f96bb977f592189250ccdede676339a80c8982", size = 170176 }, - { url = "https://files.pythonhosted.org/packages/ab/07/821b5e9647885e429e7950eab6e4425068d2e4e0537aae076682a748c117/orjson-3.10.9-cp310-none-win32.whl", hash = "sha256:fd5083906825d7f5d23089425ce5424d783d6294020bcabb8518a3e1f97833e5", size = 145103 }, - { url = "https://files.pythonhosted.org/packages/7e/4c/fbe27448ae7c6c9f40602ed12068486a26ed43d03d426bc01af6cc5c83e3/orjson-3.10.9-cp310-none-win_amd64.whl", hash = "sha256:e9ff9521b5be0340c8e686bcfe2619777fd7583f71e7b494601cc91ad3919d2e", size = 139386 }, - { url = "https://files.pythonhosted.org/packages/0c/5b/b19a919c4402c11afb584d5ce2764218e8790378f4ddb5d3c7a2631fb331/orjson-3.10.9-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:f3bd9df47385b8fabb3b2ee1e83f9960b8accc1905be971a1c257f16c32b491e", size = 270558 }, - { url = "https://files.pythonhosted.org/packages/79/36/82461a50f02fd2f1e4da1d5e4ca2f0faa4960e0088390f79a15e152b218d/orjson-3.10.9-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a4948961b6bce1e2086b2cf0b56cc454cdab589d40c7f85be71fb5a5556c51d3", size = 153320 }, - { url = "https://files.pythonhosted.org/packages/1c/ac/a449f890218aaef5463e82c2370a5d6054254291fa69934c2e38ace3d738/orjson-3.10.9-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0a9fc7a6cf2b229ddc323e136df13b3fb4466c50d84ed600cd0898223dd2fea3", size = 168580 }, - { url = "https://files.pythonhosted.org/packages/1e/8e/4677be5d2ef754baa5988798f2a7c052bb02eb4ebdfc5d596c49b49fd3ce/orjson-3.10.9-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2314846e1029a2d2b899140f350eaaf3a73281df43ba84ac44d94ca861b5b269", size = 155824 }, - { url = "https://files.pythonhosted.org/packages/fb/51/c0d12b4708fd61cfa253ab48fade4c1c1efc06d3a2115b5d8ca7d69f04e9/orjson-3.10.9-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f52d993504827503411df2d60e60acf52885561458d6273f99ecd172f31c4352", size = 166397 }, - { url = "https://files.pythonhosted.org/packages/4c/f4/2490eee8b31756e45c10c8b774c7ab3a670f71ae7be5675dba2c275911d1/orjson-3.10.9-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e29bbf08d907756c145a3a3a1f7ce2f11f15e3edbd3342842589d6030981b76f", size = 144508 }, - { url = "https://files.pythonhosted.org/packages/d4/d5/f86ed324c3fb98874d069128a4677320ccede2268b81ce76c99b43f378d7/orjson-3.10.9-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7ae82992c00b480c3cc7dac6739324554be8c5d8e858a90044928506a3333ef4", size = 172195 }, - { url = "https://files.pythonhosted.org/packages/4c/99/6193781ab43ad4a9c030031dc053697fe33bc0a9426023ccc992d2cf5b03/orjson-3.10.9-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:6fdf8d32b6d94019dc15163542d345e9ce4c4661f56b318608aa3088a1a3a23b", size = 170182 }, - { url = "https://files.pythonhosted.org/packages/fa/5f/e9d892834f716bab10730d55ceb3c19e4401d832a0fda9ffec8c702cd4f8/orjson-3.10.9-cp311-none-win32.whl", hash = "sha256:01f5fef452b4d7615f2e94153479370a4b59e0c964efb32dd902978f807a45cd", size = 145106 }, - { url = "https://files.pythonhosted.org/packages/46/b7/47b7eca49540e2f0674739a41f9b851f4fb5a342448722cbd27f7aeb5406/orjson-3.10.9-cp311-none-win_amd64.whl", hash = "sha256:95361c4197c7ce9afdf56255de6f4e2474c39d16a277cce31d1b99a2520486d8", size = 139383 }, - { url = "https://files.pythonhosted.org/packages/f2/5f/0db41961aa7215d12568bc1fe1f0001723c9717056a74a55ce5a0d55f40b/orjson-3.10.9-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:43ad5560db54331c007dc38be5ba7706cb72974a29ae8227019d89305d750a6f", size = 270645 }, - { url = "https://files.pythonhosted.org/packages/ef/ba/dd11ae9bdc9930c616500d2b0fc49777c515e369046bb539e547d7b8b239/orjson-3.10.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1471c3274b1a4a9b8f4b9ed6effaea9ad885796373797515c44b365b375c256d", size = 153254 }, - { url = "https://files.pythonhosted.org/packages/b5/f5/1132a62fcc7c108d1e7b0977f8768ff5f01615d646764c3bffec5f393de1/orjson-3.10.9-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:41d8cac575acd15918903d74cfaabb5dbe57b357b93341332f647d1013928dcc", size = 168517 }, - { url = "https://files.pythonhosted.org/packages/ab/4c/29829607902192ad7f8a2195c373e7cb0896f5fbdef492460e4c57324dab/orjson-3.10.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2920c8754f1aedc98bd357ec172af18ce48f5f1017a92244c85fe41d16d3c6e0", size = 156063 }, - { url = "https://files.pythonhosted.org/packages/a9/79/fa7d2c65e10e9a3bf0314ce623ffcf42b2333f823a2c67f5a17664e8eed9/orjson-3.10.9-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c7fa3ff6a0d9d15a0d0d2254cca16cd919156a18423654ce5574591392fe9914", size = 166577 }, - { url = "https://files.pythonhosted.org/packages/bb/f0/1d89c199aca0b00a35a5bd55f892093f25acc8c5a0334096d77a91c4d6a2/orjson-3.10.9-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e1e91b90c0c26bd79593967c1adef421bcff88c9e723d49c93bb7ad8af80bc6b", size = 144833 }, - { url = "https://files.pythonhosted.org/packages/40/49/29a0e095a1b417c4a815c21a7f1768324e9d34eba8461b6633e222ef4c32/orjson-3.10.9-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f11949024f785ace1a516db32fa6255f6227226b2c988abf66f5aee61d43d8f7", size = 172077 }, - { url = "https://files.pythonhosted.org/packages/5e/75/1ba1c28c329c90b38aefc1a91b411bb7d4288c30d22dc268306926c89287/orjson-3.10.9-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:060e020d85d0ec145bc1b536b1fd9c10a0519c91991ead9724d6f759ebe26b9a", size = 170476 }, - { url = "https://files.pythonhosted.org/packages/31/ca/29a9c943f18d11e81cd3997140aeae0a6919fc03037705b584e0337b3131/orjson-3.10.9-cp312-none-win32.whl", hash = "sha256:71f73439999fe662843da3607cdf6e75b1551c330f487e5801d463d969091c63", size = 145133 }, - { url = "https://files.pythonhosted.org/packages/f2/0e/5d1f9aa51361f592a5a2e13ac4405eb9c7068a6f3d79b369cca97fbf6536/orjson-3.10.9-cp312-none-win_amd64.whl", hash = "sha256:12e2efe81356b8448f1cd130f8d75d3718de583112d71f2e2f8baa81bd835bb9", size = 139465 }, - { url = "https://files.pythonhosted.org/packages/73/c1/9a181520a9967cb5c2dc40ca9a067e7ea49154d886a7aadbd1088efc0c0c/orjson-3.10.9-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:0ab6e3ad10e964392f0e838751bcce2ef9c8fa8be7deddffff83088e5791566d", size = 270592 }, - { url = "https://files.pythonhosted.org/packages/c0/c5/575448b3e22062880cc9d1490d329b75897407b6bab5a1e50bcc5848925d/orjson-3.10.9-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:68ef65223baab00f469c8698f771ab3e6ccf6af2a987e77de5b566b4ec651150", size = 144805 }, - { url = "https://files.pythonhosted.org/packages/98/77/7c503ac20e044e9bd3ee8c9a8e6cc5cd2900138a5f52f873e951f819d211/orjson-3.10.9-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6f130848205fea90a2cb9fa2b11cafff9a9f31f4efad225800bc8b9e4a702f24", size = 171992 }, - { url = "https://files.pythonhosted.org/packages/fc/76/3bd20cae3a2a16c83137446dda1dcd4b16ef74626f625777844e09e3c7a0/orjson-3.10.9-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:2ea7a98f3295ed8adb6730a5788cc78dafea28300d19932a1d2143457f7db802", size = 170460 }, - { url = "https://files.pythonhosted.org/packages/6a/62/6a8f2a64be762f40acac779be37a3edef7ea0bd4809d314e8b4d7ccc4b35/orjson-3.10.9-cp313-none-win32.whl", hash = "sha256:bdce39f96149a74fddeb2674c54f1da5e57724d32952eb6df2ac719b66d453cc", size = 145089 }, - { url = "https://files.pythonhosted.org/packages/1a/55/3b41acf8c5f96e8fbde86fb56c7d80fa359de0020937618f558d81785ae9/orjson-3.10.9-cp313-none-win_amd64.whl", hash = "sha256:d11383701d4b58e795039b662ada46987744293d57bfa2719e7379b8d67bc796", size = 139223 }, +version = "3.10.10" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/80/44/d36e86b33fc84f224b5f2cdf525adf3b8f9f475753e721c402b1ddef731e/orjson-3.10.10.tar.gz", hash = "sha256:37949383c4df7b4337ce82ee35b6d7471e55195efa7dcb45ab8226ceadb0fe3b", size = 5404170 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/98/c7/07ca73c32d49550490572235e5000aa0d75e333997cbb3a221890ef0fa04/orjson-3.10.10-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:b788a579b113acf1c57e0a68e558be71d5d09aa67f62ca1f68e01117e550a998", size = 270718 }, + { url = "https://files.pythonhosted.org/packages/4d/6e/eaefdfe4b11fd64b38f6663c71a3c9063054c8c643a52555c5b6d4350446/orjson-3.10.10-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:804b18e2b88022c8905bb79bd2cbe59c0cd014b9328f43da8d3b28441995cda4", size = 153292 }, + { url = "https://files.pythonhosted.org/packages/cf/87/94474cbf63306f196a0a85a2f3ea6cea261328b4141a260b7ec5e7145bc5/orjson-3.10.10-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9972572a1d042ec9ee421b6da69f7cc823da5962237563fa548ab17f152f0b9b", size = 168625 }, + { url = "https://files.pythonhosted.org/packages/0a/67/1a6bd763282bc89fcc0269e3a44a8ecbb236a1e4b6f5a6320301726b36a1/orjson-3.10.10-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dc6993ab1c2ae7dd0711161e303f1db69062955ac2668181bfdf2dd410e65258", size = 155845 }, + { url = "https://files.pythonhosted.org/packages/ae/28/bb2dd7a988159896be9fa59ef4c991dca8cca9af85ebdc27164234929008/orjson-3.10.10-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d78e4cacced5781b01d9bc0f0cd8b70b906a0e109825cb41c1b03f9c41e4ce86", size = 166406 }, + { url = "https://files.pythonhosted.org/packages/e3/88/42199849c791b4b5b92fcace0e8ef178d5ae1ea9865dfd4d5809e650d652/orjson-3.10.10-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e6eb2598df518281ba0cbc30d24c5b06124ccf7e19169e883c14e0831217a0bc", size = 144518 }, + { url = "https://files.pythonhosted.org/packages/c7/77/e684fe4ed34e73149bc0e7320b91a459386693279cd62efab6e82da072a3/orjson-3.10.10-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:23776265c5215ec532de6238a52707048401a568f0fa0d938008e92a147fe2c7", size = 172184 }, + { url = "https://files.pythonhosted.org/packages/fa/b2/9dc2ed13121b27b9f99acba077c821ad2c0deff9feeb617efef4699fad35/orjson-3.10.10-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8cc2a654c08755cef90b468ff17c102e2def0edd62898b2486767204a7f5cc9c", size = 170148 }, + { url = "https://files.pythonhosted.org/packages/86/0a/b06967f9374856f491f297a914c588eae97ef9490a77ec0e146a2e4bfe7f/orjson-3.10.10-cp310-none-win32.whl", hash = "sha256:081b3fc6a86d72efeb67c13d0ea7c030017bd95f9868b1e329a376edc456153b", size = 145116 }, + { url = "https://files.pythonhosted.org/packages/1f/c7/1aecf5e320828261ece0683e472ee77c520f4e6c52c468486862e2257962/orjson-3.10.10-cp310-none-win_amd64.whl", hash = "sha256:ff38c5fb749347768a603be1fb8a31856458af839f31f064c5aa74aca5be9efe", size = 139307 }, + { url = "https://files.pythonhosted.org/packages/79/bc/2a0eb0029729f1e466d5a595261446e5c5b6ed9213759ee56b6202f99417/orjson-3.10.10-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:879e99486c0fbb256266c7c6a67ff84f46035e4f8749ac6317cc83dacd7f993a", size = 270717 }, + { url = "https://files.pythonhosted.org/packages/3d/2b/5af226f183ce264bf64f15afe58647b09263dc1bde06aaadae6bbeca17f1/orjson-3.10.10-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:019481fa9ea5ff13b5d5d95e6fd5ab25ded0810c80b150c2c7b1cc8660b662a7", size = 153294 }, + { url = "https://files.pythonhosted.org/packages/1d/95/d6a68ab51ed76e3794669dabb51bf7fa6ec2f4745f66e4af4518aeab4b73/orjson-3.10.10-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0dd57eff09894938b4c86d4b871a479260f9e156fa7f12f8cad4b39ea8028bb5", size = 168628 }, + { url = "https://files.pythonhosted.org/packages/c0/c9/1bbe5262f5e9df3e1aeec44ca8cc86846c7afb2746fa76bf668a7d0979e9/orjson-3.10.10-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dbde6d70cd95ab4d11ea8ac5e738e30764e510fc54d777336eec09bb93b8576c", size = 155845 }, + { url = "https://files.pythonhosted.org/packages/bf/22/e17b14ff74646e6c080dccb2859686a820bc6468f6b62ea3fe29a8bd3b05/orjson-3.10.10-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3b2625cb37b8fb42e2147404e5ff7ef08712099197a9cd38895006d7053e69d6", size = 166406 }, + { url = "https://files.pythonhosted.org/packages/8a/1e/b3abbe352f648f96a418acd1e602b1c77ffcc60cf801a57033da990b2c49/orjson-3.10.10-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dbf3c20c6a7db69df58672a0d5815647ecf78c8e62a4d9bd284e8621c1fe5ccb", size = 144518 }, + { url = "https://files.pythonhosted.org/packages/0e/5e/28f521ee0950d279489db1522e7a2460d0596df7c5ca452e242ff1509cfe/orjson-3.10.10-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:75c38f5647e02d423807d252ce4528bf6a95bd776af999cb1fb48867ed01d1f6", size = 172187 }, + { url = "https://files.pythonhosted.org/packages/04/b4/538bf6f42eb0fd5a485abbe61e488d401a23fd6d6a758daefcf7811b6807/orjson-3.10.10-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:23458d31fa50ec18e0ec4b0b4343730928296b11111df5f547c75913714116b2", size = 170152 }, + { url = "https://files.pythonhosted.org/packages/94/5c/a1a326a58452f9261972ad326ae3bb46d7945681239b7062a1b85d8811e2/orjson-3.10.10-cp311-none-win32.whl", hash = "sha256:2787cd9dedc591c989f3facd7e3e86508eafdc9536a26ec277699c0aa63c685b", size = 145116 }, + { url = "https://files.pythonhosted.org/packages/df/12/a02965df75f5a247091306d6cf40a77d20bf6c0490d0a5cb8719551ee815/orjson-3.10.10-cp311-none-win_amd64.whl", hash = "sha256:6514449d2c202a75183f807bc755167713297c69f1db57a89a1ef4a0170ee269", size = 139307 }, + { url = "https://files.pythonhosted.org/packages/21/c6/f1d2ec3ffe9d6a23a62af0477cd11dd2926762e0186a1fad8658a4f48117/orjson-3.10.10-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:8564f48f3620861f5ef1e080ce7cd122ee89d7d6dacf25fcae675ff63b4d6e05", size = 270801 }, + { url = "https://files.pythonhosted.org/packages/52/01/eba0226efaa4d4be8e44d9685750428503a3803648878fa5607100a74f81/orjson-3.10.10-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5bf161a32b479034098c5b81f2608f09167ad2fa1c06abd4e527ea6bf4837a9", size = 153221 }, + { url = "https://files.pythonhosted.org/packages/da/4b/a705f9d3ae4786955ee0ac840b20960add357e612f1b0a54883d1811fe1a/orjson-3.10.10-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:68b65c93617bcafa7f04b74ae8bc2cc214bd5cb45168a953256ff83015c6747d", size = 168590 }, + { url = "https://files.pythonhosted.org/packages/de/6c/eb405252e7d9ae9905a12bad582cfe37ef8ef18fdfee941549cb5834c7b2/orjson-3.10.10-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e8e28406f97fc2ea0c6150f4c1b6e8261453318930b334abc419214c82314f85", size = 156052 }, + { url = "https://files.pythonhosted.org/packages/9f/e7/65a0461574078a38f204575153524876350f0865162faa6e6e300ecaa199/orjson-3.10.10-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e4d0d9fe174cc7a5bdce2e6c378bcdb4c49b2bf522a8f996aa586020e1b96cee", size = 166562 }, + { url = "https://files.pythonhosted.org/packages/dd/99/85780be173e7014428859ba0211e6f2a8f8038ea6ebabe344b42d5daa277/orjson-3.10.10-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3be81c42f1242cbed03cbb3973501fcaa2675a0af638f8be494eaf37143d999", size = 144892 }, + { url = "https://files.pythonhosted.org/packages/ed/c0/c7c42a2daeb262da417f70064746b700786ee0811b9a5821d9d37543b29d/orjson-3.10.10-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:65f9886d3bae65be026219c0a5f32dbbe91a9e6272f56d092ab22561ad0ea33b", size = 172093 }, + { url = "https://files.pythonhosted.org/packages/ad/9b/be8b3d3aec42aa47f6058482ace0d2ca3023477a46643d766e96281d5d31/orjson-3.10.10-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:730ed5350147db7beb23ddaf072f490329e90a1d059711d364b49fe352ec987b", size = 170424 }, + { url = "https://files.pythonhosted.org/packages/1b/15/a4cc61e23c39b9dec4620cb95817c83c84078be1771d602f6d03f0e5c696/orjson-3.10.10-cp312-none-win32.whl", hash = "sha256:a8f4bf5f1c85bea2170800020d53a8877812892697f9c2de73d576c9307a8a5f", size = 145132 }, + { url = "https://files.pythonhosted.org/packages/9f/8a/ce7c28e4ea337f6d95261345d7c61322f8561c52f57b263a3ad7025984f4/orjson-3.10.10-cp312-none-win_amd64.whl", hash = "sha256:384cd13579a1b4cd689d218e329f459eb9ddc504fa48c5a83ef4889db7fd7a4f", size = 139389 }, + { url = "https://files.pythonhosted.org/packages/0c/69/f1c4382cd44bdaf10006c4e82cb85d2bcae735369f84031e203c4e5d87de/orjson-3.10.10-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:44bffae68c291f94ff5a9b4149fe9d1bdd4cd0ff0fb575bcea8351d48db629a1", size = 270695 }, + { url = "https://files.pythonhosted.org/packages/61/29/aeb5153271d4953872b06ed239eb54993a5f344353727c42d3aabb2046f6/orjson-3.10.10-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e27b4c6437315df3024f0835887127dac2a0a3ff643500ec27088d2588fa5ae1", size = 141632 }, + { url = "https://files.pythonhosted.org/packages/bc/a2/c8ac38d8fb461a9b717c766fbe1f7d3acf9bde2f12488eb13194960782e4/orjson-3.10.10-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bca84df16d6b49325a4084fd8b2fe2229cb415e15c46c529f868c3387bb1339d", size = 144854 }, + { url = "https://files.pythonhosted.org/packages/79/51/e7698fdb28bdec633888cc667edc29fd5376fce9ade0a5b3e22f5ebe0343/orjson-3.10.10-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:c14ce70e8f39bd71f9f80423801b5d10bf93d1dceffdecd04df0f64d2c69bc01", size = 172023 }, + { url = "https://files.pythonhosted.org/packages/02/2d/0d99c20878658c7e33b90e6a4bb75cf2924d6ff29c2365262cff3c26589a/orjson-3.10.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:24ac62336da9bda1bd93c0491eff0613003b48d3cb5d01470842e7b52a40d5b4", size = 170429 }, + { url = "https://files.pythonhosted.org/packages/cd/45/6a4a446f4fb29bb4703c3537d5c6a2bf7fed768cb4d7b7dce9d71b72fc93/orjson-3.10.10-cp313-none-win32.whl", hash = "sha256:eb0a42831372ec2b05acc9ee45af77bcaccbd91257345f93780a8e654efc75db", size = 145099 }, + { url = "https://files.pythonhosted.org/packages/72/6e/4631fe219a4203aa111e9bb763ad2e2e0cdd1a03805029e4da124d96863f/orjson-3.10.10-cp313-none-win_amd64.whl", hash = "sha256:f0c4f37f8bf3f1075c6cc8dd8a9f843689a4b618628f8812d0a71e6968b95ffd", size = 139176 }, ] [[package]] @@ -4167,34 +4297,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35", size = 13993 }, ] -[[package]] -name = "pulsar-client" -version = "3.5.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "certifi" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/e0/aa/eb3b04be87b961324e49748f3a715a12127d45d76258150bfa61b2a002d8/pulsar_client-3.5.0-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:c18552edb2f785de85280fe624bc507467152bff810fc81d7660fa2dfa861f38", size = 10953552 }, - { url = "https://files.pythonhosted.org/packages/cc/20/d59bf89ccdda45edd89f5b54bd1e93605ebe5ad3cb73f4f4f5e8eca8f9e6/pulsar_client-3.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18d438e456c146f01be41ef146f649dedc8f7bc714d9eaef94cff2e34099812b", size = 5190714 }, - { url = "https://files.pythonhosted.org/packages/1a/02/ca7e96b97d564d0375b8e3de65f95ac86c8502c40f6ff750e9d145709d9a/pulsar_client-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18a26a0719841103c7a89eb1492c4a8fedf89adaa386375baecbb4fa2707e88f", size = 5429820 }, - { url = "https://files.pythonhosted.org/packages/47/f3/682670cdc951b830cd3d8d1287521997327254e59508772664aaa656e246/pulsar_client-3.5.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:ab0e1605dc5f44a126163fd06cd0a768494ad05123f6e0de89a2c71d6e2d2319", size = 5710427 }, - { url = "https://files.pythonhosted.org/packages/bc/00/119cd039286dfc1c91a5580963e9ba79204cd4717b16b7a6fdc57d1c1673/pulsar_client-3.5.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:cdef720891b97656fdce3bf5913ea7729b2156b84ba64314f432c1e72c6117fa", size = 5916490 }, - { url = "https://files.pythonhosted.org/packages/0a/cc/d606b483dbb263cbaf7fc7c3d2ec4032628cf3324266cf9a4ccdb2a73076/pulsar_client-3.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:a42544e38773191fe550644a90e8050579476bb2dcf17ac69a4aed62a6cb70e7", size = 3305387 }, - { url = "https://files.pythonhosted.org/packages/0d/2e/aec6886a6d67f09230476182399b7fad694fbcbbaf004ce914725d4eddd9/pulsar_client-3.5.0-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:fd94432ea5d398ea78f8f2e09a217ec5058d26330c137a22690478c031e116da", size = 10954116 }, - { url = "https://files.pythonhosted.org/packages/43/06/b98df9300f60e5fad3396f843dd633c31176a495a2d60ba111c99511658a/pulsar_client-3.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d6252ae462e07ece4071213fdd9c76eab82ca522a749f2dc678037d4cbacd40b", size = 5189618 }, - { url = "https://files.pythonhosted.org/packages/72/05/c9aef7da7802a03c0b65ffe8f00a24289ff992f99ed5d5d1fd0ed63d9cf6/pulsar_client-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:03b4d440b2d74323784328b082872ee2f206c440b5d224d7941eb3c083ec06c6", size = 5429329 }, - { url = "https://files.pythonhosted.org/packages/06/96/9acfe6f1d827cdd53b8460b04c63b4081333ef64a49a2f425419f1eb6b6b/pulsar_client-3.5.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f60af840b8d64a2fac5a0c1ce6ae0ddffec5f42267c6ded2c5e74bad8345f2a1", size = 5710106 }, - { url = "https://files.pythonhosted.org/packages/e1/7b/877a06eff5c9ac828cdb75e378ee29b0adac9328da9ee173eaf7076d8c56/pulsar_client-3.5.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:2277a447c3b7f6571cb1eb9fc5c25da3fdd43d0b2fb91cf52054adfadc7d6842", size = 5916541 }, - { url = "https://files.pythonhosted.org/packages/fb/62/ed1da1ef72c95ba6a830e43995550ed0a1d26c223fb4b036ac6cd028c2ed/pulsar_client-3.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:f20f3e9dd50db2a37059abccad42078b7a4754b8bc1d3ae6502e71c1ad2209f0", size = 3305485 }, - { url = "https://files.pythonhosted.org/packages/81/19/4b145766df706aa5e09f60bbf5f87b934e6ac950fddd18f4acd520c465b9/pulsar_client-3.5.0-cp312-cp312-macosx_10_15_universal2.whl", hash = "sha256:d61f663d85308e12f44033ba95af88730f581a7e8da44f7a5c080a3aaea4878d", size = 10967548 }, - { url = "https://files.pythonhosted.org/packages/bf/bd/9bc05ee861b46884554a4c61f96edb9602de131dd07982c27920e554ab5b/pulsar_client-3.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2a1ba0be25b6f747bcb28102b7d906ec1de48dc9f1a2d9eacdcc6f44ab2c9e17", size = 5189598 }, - { url = "https://files.pythonhosted.org/packages/76/00/379bedfa6f1c810553996a4cb0984fa2e2c89afc5953df0936e1c9636003/pulsar_client-3.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a181e3e60ac39df72ccb3c415d7aeac61ad0286497a6e02739a560d5af28393a", size = 5430145 }, - { url = "https://files.pythonhosted.org/packages/88/c8/8a37d75aa9132a69a28061c9e5f4b516328a1968b58bbae018f431c6d3d4/pulsar_client-3.5.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3c72895ff7f51347e4f78b0375b2213fa70dd4790bbb78177b4002846f1fd290", size = 5708960 }, - { url = "https://files.pythonhosted.org/packages/6e/9a/abd98661e3f7ae3a8e1d3fb0fc7eba1a30005391ebd575ab06a66021256c/pulsar_client-3.5.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:547dba1b185a17eba915e51d0a3aca27c80747b6187e5cd7a71a3ca33921decc", size = 5915227 }, - { url = "https://files.pythonhosted.org/packages/a2/51/db376181d05716de595515fac736e3d06e96d3345ba0e31c0a90c352eae1/pulsar_client-3.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:443b786eed96bc86d2297a6a42e79f39d1abf217ec603e0bd303f3488c0234af", size = 3306515 }, -] - [[package]] name = "pure-eval" version = "0.2.3" @@ -4894,16 +4996,16 @@ wheels = [ [[package]] name = "rich" -version = "13.9.2" +version = "13.9.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "markdown-it-py" }, { name = "pygments" }, { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/aa/9e/1784d15b057b0075e5136445aaea92d23955aad2c93eaede673718a40d95/rich-13.9.2.tar.gz", hash = "sha256:51a2c62057461aaf7152b4d611168f93a9fc73068f8ded2790f29fe2b5366d0c", size = 222843 } +sdist = { url = "https://files.pythonhosted.org/packages/d9/e9/cf9ef5245d835065e6673781dbd4b8911d352fb770d56cf0879cf11b7ee1/rich-13.9.3.tar.gz", hash = "sha256:bc1e01b899537598cf02579d2b9f4a415104d3fc439313a7a2c165d76557a08e", size = 222889 } wheels = [ - { url = "https://files.pythonhosted.org/packages/67/91/5474b84e505a6ccc295b2d322d90ff6aa0746745717839ee0c5fb4fdcceb/rich-13.9.2-py3-none-any.whl", hash = "sha256:8c82a3d3f8dcfe9e734771313e606b39d8247bb6b826e196f4914b333b743cf1", size = 242117 }, + { url = "https://files.pythonhosted.org/packages/9a/e2/10e9819cf4a20bd8ea2f5dabafc2e6bf4a78d6a0965daeb60a4b34d1c11f/rich-13.9.3-py3-none-any.whl", hash = "sha256:9836f5096eb2172c9e77df411c1b009bace4193d6a481d534fea75ebba758283", size = 242157 }, ] [[package]] @@ -5014,6 +5116,101 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fe/f1/3db1590be946c14d86ac0cc8422e5808500903592b7ca09a097e425b1dba/ruff-0.4.8-py3-none-win_arm64.whl", hash = "sha256:14019a06dbe29b608f6b7cbcec300e3170a8d86efaddb7b23405cb7f7dcaf780", size = 7944828 }, ] +[[package]] +name = "safetensors" +version = "0.4.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cb/46/a1c56ed856c6ac3b1a8b37abe5be0cac53219367af1331e721b04d122577/safetensors-0.4.5.tar.gz", hash = "sha256:d73de19682deabb02524b3d5d1f8b3aaba94c72f1bbfc7911b9b9d5d391c0310", size = 65702 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/38/10/0798ec2c8704c2d172620d8a3725bed92cdd75516357b1a3e64d4229ea4e/safetensors-0.4.5-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:a63eaccd22243c67e4f2b1c3e258b257effc4acd78f3b9d397edc8cf8f1298a7", size = 392312 }, + { url = "https://files.pythonhosted.org/packages/2b/9e/9648d8dbb485c40a4a0212b7537626ae440b48156cc74601ca0b7a7615e0/safetensors-0.4.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:23fc9b4ec7b602915cbb4ec1a7c1ad96d2743c322f20ab709e2c35d1b66dad27", size = 381858 }, + { url = "https://files.pythonhosted.org/packages/8b/67/49556aeacc00df353767ed31d68b492fecf38c3f664c52692e4d92aa0032/safetensors-0.4.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6885016f34bef80ea1085b7e99b3c1f92cb1be78a49839203060f67b40aee761", size = 441382 }, + { url = "https://files.pythonhosted.org/packages/5d/ce/e9f4869a37bb11229e6cdb4e73a6ef23b4f360eee9dca5f7e40982779704/safetensors-0.4.5-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:133620f443450429322f238fda74d512c4008621227fccf2f8cf4a76206fea7c", size = 439001 }, + { url = "https://files.pythonhosted.org/packages/a0/27/aee8cf031b89c34caf83194ec6b7f2eed28d053fff8b6da6d00c85c56035/safetensors-0.4.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4fb3e0609ec12d2a77e882f07cced530b8262027f64b75d399f1504ffec0ba56", size = 478026 }, + { url = "https://files.pythonhosted.org/packages/da/33/1d9fc4805c623636e7d460f28eec92ebd1856f7a552df8eb78398a1ef4de/safetensors-0.4.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d0f1dd769f064adc33831f5e97ad07babbd728427f98e3e1db6902e369122737", size = 495545 }, + { url = "https://files.pythonhosted.org/packages/b9/df/6f766b56690709d22e83836e4067a1109a7d84ea152a6deb5692743a2805/safetensors-0.4.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c6d156bdb26732feada84f9388a9f135528c1ef5b05fae153da365ad4319c4c5", size = 435016 }, + { url = "https://files.pythonhosted.org/packages/90/fa/7bc3f18086201b1e55a42c88b822ae197d0158e12c54cd45c887305f1b7e/safetensors-0.4.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9e347d77e2c77eb7624400ccd09bed69d35c0332f417ce8c048d404a096c593b", size = 456273 }, + { url = "https://files.pythonhosted.org/packages/3e/59/2ae50150d37a65c1c5f01aec74dc737707b8bbecdc76307e5a1a12c8a376/safetensors-0.4.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9f556eea3aec1d3d955403159fe2123ddd68e880f83954ee9b4a3f2e15e716b6", size = 619669 }, + { url = "https://files.pythonhosted.org/packages/fe/43/10f0bb597aef62c9c154152e265057089f3c729bdd980e6c32c3ec2407a4/safetensors-0.4.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9483f42be3b6bc8ff77dd67302de8ae411c4db39f7224dec66b0eb95822e4163", size = 605212 }, + { url = "https://files.pythonhosted.org/packages/7c/75/ede6887ea0ceaba55730988bfc7668dc147a8758f907fa6db26fbb681b8e/safetensors-0.4.5-cp310-none-win32.whl", hash = "sha256:7389129c03fadd1ccc37fd1ebbc773f2b031483b04700923c3511d2a939252cc", size = 272652 }, + { url = "https://files.pythonhosted.org/packages/ba/f0/919c72a9eef843781e652d0650f2819039943e69b69d5af2d0451a23edc3/safetensors-0.4.5-cp310-none-win_amd64.whl", hash = "sha256:e98ef5524f8b6620c8cdef97220c0b6a5c1cef69852fcd2f174bb96c2bb316b1", size = 285879 }, + { url = "https://files.pythonhosted.org/packages/9a/a5/25bcf75e373412daf1fd88045ab3aa8140a0d804ef0e70712c4f2c5b94d8/safetensors-0.4.5-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:21f848d7aebd5954f92538552d6d75f7c1b4500f51664078b5b49720d180e47c", size = 392256 }, + { url = "https://files.pythonhosted.org/packages/08/8c/ece3bf8756506a890bd980eca02f47f9d98dfbf5ce16eda1368f53560f67/safetensors-0.4.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bb07000b19d41e35eecef9a454f31a8b4718a185293f0d0b1c4b61d6e4487971", size = 381490 }, + { url = "https://files.pythonhosted.org/packages/39/83/c4a7ce01d626e46ea2b45887f2e59b16441408031e2ce2f9fe01860c6946/safetensors-0.4.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09dedf7c2fda934ee68143202acff6e9e8eb0ddeeb4cfc24182bef999efa9f42", size = 441093 }, + { url = "https://files.pythonhosted.org/packages/47/26/cc52de647e71bd9a0b0d78ead0d31d9c462b35550a817aa9e0cab51d6db4/safetensors-0.4.5-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:59b77e4b7a708988d84f26de3ebead61ef1659c73dcbc9946c18f3b1786d2688", size = 438960 }, + { url = "https://files.pythonhosted.org/packages/06/78/332538546775ee97e749867df2d58f2282d9c48a1681e4891eed8b94ec94/safetensors-0.4.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5d3bc83e14d67adc2e9387e511097f254bd1b43c3020440e708858c684cbac68", size = 478031 }, + { url = "https://files.pythonhosted.org/packages/d9/03/a3c8663f1ddda54e624ecf43fce651659b49e8e1603c52c3e464b442acfa/safetensors-0.4.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:39371fc551c1072976073ab258c3119395294cf49cdc1f8476794627de3130df", size = 494754 }, + { url = "https://files.pythonhosted.org/packages/e6/ee/69e498a892f208bd1da4104d4b9be887f8611bf4942144718b6738482250/safetensors-0.4.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a6c19feda32b931cae0acd42748a670bdf56bee6476a046af20181ad3fee4090", size = 435013 }, + { url = "https://files.pythonhosted.org/packages/a2/61/f0cfce984515b86d1260f556ba3b782158e2855e6a318446ac2613786fa9/safetensors-0.4.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a659467495de201e2f282063808a41170448c78bada1e62707b07a27b05e6943", size = 455984 }, + { url = "https://files.pythonhosted.org/packages/e7/a9/3e3b48fcaade3eb4e347d39ebf0bd44291db21a3e4507854b42a7cb910ac/safetensors-0.4.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:bad5e4b2476949bcd638a89f71b6916fa9a5cae5c1ae7eede337aca2100435c0", size = 619513 }, + { url = "https://files.pythonhosted.org/packages/80/23/2a7a1be24258c0e44c1d356896fd63dc0545a98d2d0184925fa09cd3ec76/safetensors-0.4.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a3a315a6d0054bc6889a17f5668a73f94f7fe55121ff59e0a199e3519c08565f", size = 604841 }, + { url = "https://files.pythonhosted.org/packages/b4/5c/34d082ff1fffffd8545fb22cbae3285ab4236f1f0cfc64b7e58261c2363b/safetensors-0.4.5-cp311-none-win32.whl", hash = "sha256:a01e232e6d3d5cf8b1667bc3b657a77bdab73f0743c26c1d3c5dd7ce86bd3a92", size = 272602 }, + { url = "https://files.pythonhosted.org/packages/6d/41/948c96c8a7e9fef57c2e051f1871c108a6dbbc6d285598bdb1d89b98617c/safetensors-0.4.5-cp311-none-win_amd64.whl", hash = "sha256:cbd39cae1ad3e3ef6f63a6f07296b080c951f24cec60188378e43d3713000c04", size = 285973 }, + { url = "https://files.pythonhosted.org/packages/bf/ac/5a63082f931e99200db95fd46fb6734f050bb6e96bf02521904c6518b7aa/safetensors-0.4.5-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:473300314e026bd1043cef391bb16a8689453363381561b8a3e443870937cc1e", size = 392015 }, + { url = "https://files.pythonhosted.org/packages/73/95/ab32aa6e9bdc832ff87784cdf9da26192b93de3ef82b8d1ada8f345c5044/safetensors-0.4.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:801183a0f76dc647f51a2d9141ad341f9665602a7899a693207a82fb102cc53e", size = 381774 }, + { url = "https://files.pythonhosted.org/packages/d6/6c/7e04b7626809fc63f3698f4c50e43aff2864b40089aa4506c918a75b8eed/safetensors-0.4.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1524b54246e422ad6fb6aea1ac71edeeb77666efa67230e1faf6999df9b2e27f", size = 441134 }, + { url = "https://files.pythonhosted.org/packages/58/2b/ffe7c86a277e6c1595fbdf415cfe2903f253f574a5405e93fda8baaa582c/safetensors-0.4.5-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b3139098e3e8b2ad7afbca96d30ad29157b50c90861084e69fcb80dec7430461", size = 438467 }, + { url = "https://files.pythonhosted.org/packages/67/9c/f271bd804e08c7fda954d17b70ff281228a88077337a9e70feace4f4cc93/safetensors-0.4.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:65573dc35be9059770808e276b017256fa30058802c29e1038eb1c00028502ea", size = 476566 }, + { url = "https://files.pythonhosted.org/packages/4c/ad/4cf76a3e430a8a26108407fa6cb93e6f80d996a5cb75d9540c8fe3862990/safetensors-0.4.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fd33da8e9407559f8779c82a0448e2133737f922d71f884da27184549416bfed", size = 492253 }, + { url = "https://files.pythonhosted.org/packages/d9/40/a6f75ea449a9647423ec8b6f72c16998d35aa4b43cb38536ac060c5c7bf5/safetensors-0.4.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3685ce7ed036f916316b567152482b7e959dc754fcc4a8342333d222e05f407c", size = 434769 }, + { url = "https://files.pythonhosted.org/packages/52/47/d4b49b1231abf3131f7bb0bc60ebb94b27ee33e0a1f9569da05f8ac65dee/safetensors-0.4.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:dde2bf390d25f67908278d6f5d59e46211ef98e44108727084d4637ee70ab4f1", size = 457166 }, + { url = "https://files.pythonhosted.org/packages/c3/cd/006468b03b0fa42ff82d795d47c4193e99001e96c3f08bd62ef1b5cab586/safetensors-0.4.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7469d70d3de970b1698d47c11ebbf296a308702cbaae7fcb993944751cf985f4", size = 619280 }, + { url = "https://files.pythonhosted.org/packages/22/4d/b6208d918e83daa84b424c0ac3191ae61b44b3191613a3a5a7b38f94b8ad/safetensors-0.4.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3a6ba28118636a130ccbb968bc33d4684c48678695dba2590169d5ab03a45646", size = 605390 }, + { url = "https://files.pythonhosted.org/packages/e8/20/bf0e01825dc01ed75538021a98b9a046e60ead63c6c6700764c821a8c873/safetensors-0.4.5-cp312-none-win32.whl", hash = "sha256:c859c7ed90b0047f58ee27751c8e56951452ed36a67afee1b0a87847d065eec6", size = 273250 }, + { url = "https://files.pythonhosted.org/packages/f1/5f/ab6b6cec85b40789801f35b7d2fb579ae242d8193929974a106d5ff5c835/safetensors-0.4.5-cp312-none-win_amd64.whl", hash = "sha256:b5a8810ad6a6f933fff6c276eae92c1da217b39b4d8b1bc1c0b8af2d270dc532", size = 286307 }, + { url = "https://files.pythonhosted.org/packages/90/61/0e27b1403e311cba0be20026bee4ee822d90eda7dad372179e7f18bb99f3/safetensors-0.4.5-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:25e5f8e2e92a74f05b4ca55686234c32aac19927903792b30ee6d7bd5653d54e", size = 392062 }, + { url = "https://files.pythonhosted.org/packages/b1/9f/cc31fafc9f5d79da10a83a820ca37f069bab0717895ad8cbcacf629dd1c5/safetensors-0.4.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:81efb124b58af39fcd684254c645e35692fea81c51627259cdf6d67ff4458916", size = 382517 }, + { url = "https://files.pythonhosted.org/packages/a4/c7/4fda8a0ebb96662550433378f4a74c677fa5fc4d0a43a7ec287d1df254a9/safetensors-0.4.5-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:585f1703a518b437f5103aa9cf70e9bd437cb78eea9c51024329e4fb8a3e3679", size = 441378 }, + { url = "https://files.pythonhosted.org/packages/14/31/9abb431f6209de9c80dab83e1112ebd769f1e32e7ab7ab228a02424a4693/safetensors-0.4.5-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4b99fbf72e3faf0b2f5f16e5e3458b93b7d0a83984fe8d5364c60aa169f2da89", size = 438831 }, + { url = "https://files.pythonhosted.org/packages/37/37/99bfb195578a808b8d045159ee9264f8da58d017ac0701853dcacda14d4e/safetensors-0.4.5-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b17b299ca9966ca983ecda1c0791a3f07f9ca6ab5ded8ef3d283fff45f6bcd5f", size = 477112 }, + { url = "https://files.pythonhosted.org/packages/7d/05/fac3ef107e60d2a78532bed171a91669d4bb259e1236f5ea8c67a6976c75/safetensors-0.4.5-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:76ded72f69209c9780fdb23ea89e56d35c54ae6abcdec67ccb22af8e696e449a", size = 493373 }, + { url = "https://files.pythonhosted.org/packages/cf/7a/825800ee8c68214b4fd3506d5e19209338c69b41e01c6e14dd13969cc8b9/safetensors-0.4.5-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2783956926303dcfeb1de91a4d1204cd4089ab441e622e7caee0642281109db3", size = 435422 }, + { url = "https://files.pythonhosted.org/packages/5e/6c/7a3233c08bde558d6c33a41219119866cb596139a4673cc6c24024710ffd/safetensors-0.4.5-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d94581aab8c6b204def4d7320f07534d6ee34cd4855688004a4354e63b639a35", size = 457382 }, + { url = "https://files.pythonhosted.org/packages/a0/58/0b7bcba3788ff503990cf9278d611b56c029400612ba93e772c987b5aa03/safetensors-0.4.5-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:67e1e7cb8678bb1b37ac48ec0df04faf689e2f4e9e81e566b5c63d9f23748523", size = 619301 }, + { url = "https://files.pythonhosted.org/packages/82/cc/9c2cf58611daf1c83ce5d37f9de66353e23fcda36008b13fd3409a760aa3/safetensors-0.4.5-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:dbd280b07e6054ea68b0cb4b16ad9703e7d63cd6890f577cb98acc5354780142", size = 605580 }, + { url = "https://files.pythonhosted.org/packages/cf/ff/037ae4c0ee32db496669365e66079b6329906c6814722b159aa700e67208/safetensors-0.4.5-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:fdadf66b5a22ceb645d5435a0be7a0292ce59648ca1d46b352f13cff3ea80410", size = 392951 }, + { url = "https://files.pythonhosted.org/packages/f1/d6/6621e16b35bf83ae099eaab07338f04991a26c9aa43879d05f19f35e149c/safetensors-0.4.5-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d42ffd4c2259f31832cb17ff866c111684c87bd930892a1ba53fed28370c918c", size = 383417 }, + { url = "https://files.pythonhosted.org/packages/ae/88/3068e1bb16f5e9f9068901de3cf7b3db270b9bfe6e7d51d4b55c1da0425d/safetensors-0.4.5-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dd8a1f6d2063a92cd04145c7fd9e31a1c7d85fbec20113a14b487563fdbc0597", size = 442311 }, + { url = "https://files.pythonhosted.org/packages/f7/15/a2bb77ebbaa76b61ec2e9f731fe4db7f9473fd855d881957c51b3a168892/safetensors-0.4.5-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:951d2fcf1817f4fb0ef0b48f6696688a4e852a95922a042b3f96aaa67eedc920", size = 436678 }, + { url = "https://files.pythonhosted.org/packages/ec/79/9608c4546cdbfe3860dd7aa59e3562c9289113398b1a0bd89b68ce0a9d41/safetensors-0.4.5-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6ac85d9a8c1af0e3132371d9f2d134695a06a96993c2e2f0bbe25debb9e3f67a", size = 457316 }, + { url = "https://files.pythonhosted.org/packages/0f/23/b17b483f2857835962ad33e38014efd4911791187e177bc23b057d35bee8/safetensors-0.4.5-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:e3cec4a29eb7fe8da0b1c7988bc3828183080439dd559f720414450de076fcab", size = 620565 }, + { url = "https://files.pythonhosted.org/packages/19/46/5d11dc300feaad285c2f1bd784ff3f689f5e0ab6be49aaf568f3a77019eb/safetensors-0.4.5-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:21742b391b859e67b26c0b2ac37f52c9c0944a879a25ad2f9f9f3cd61e7fda8f", size = 606660 }, +] + +[[package]] +name = "scikit-learn" +version = "1.5.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "joblib" }, + { name = "numpy" }, + { name = "scipy" }, + { name = "threadpoolctl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/37/59/44985a2bdc95c74e34fef3d10cb5d93ce13b0e2a7baefffe1b53853b502d/scikit_learn-1.5.2.tar.gz", hash = "sha256:b4237ed7b3fdd0a4882792e68ef2545d5baa50aca3bb45aa7df468138ad8f94d", size = 7001680 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/98/89/be41419b4bec629a4691183a5eb1796f91252a13a5ffa243fd958cad7e91/scikit_learn-1.5.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:299406827fb9a4f862626d0fe6c122f5f87f8910b86fe5daa4c32dcd742139b6", size = 12106070 }, + { url = "https://files.pythonhosted.org/packages/bf/e0/3b6d777d375f3b685f433c93384cdb724fb078e1dc8f8ff0950467e56c30/scikit_learn-1.5.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:2d4cad1119c77930b235579ad0dc25e65c917e756fe80cab96aa3b9428bd3fb0", size = 10971758 }, + { url = "https://files.pythonhosted.org/packages/7b/31/eb7dd56c371640753953277de11356c46a3149bfeebb3d7dcd90b993715a/scikit_learn-1.5.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c412ccc2ad9bf3755915e3908e677b367ebc8d010acbb3f182814524f2e5540", size = 12500080 }, + { url = "https://files.pythonhosted.org/packages/4c/1e/a7c7357e704459c7d56a18df4a0bf08669442d1f8878cc0864beccd6306a/scikit_learn-1.5.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a686885a4b3818d9e62904d91b57fa757fc2bed3e465c8b177be652f4dd37c8", size = 13347241 }, + { url = "https://files.pythonhosted.org/packages/48/76/154ebda6794faf0b0f3ccb1b5cd9a19f0a63cb9e1f3d2c61b6114002677b/scikit_learn-1.5.2-cp310-cp310-win_amd64.whl", hash = "sha256:c15b1ca23d7c5f33cc2cb0a0d6aaacf893792271cddff0edbd6a40e8319bc113", size = 11000477 }, + { url = "https://files.pythonhosted.org/packages/ff/91/609961972f694cb9520c4c3d201e377a26583e1eb83bc5a334c893729214/scikit_learn-1.5.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:03b6158efa3faaf1feea3faa884c840ebd61b6484167c711548fce208ea09445", size = 12088580 }, + { url = "https://files.pythonhosted.org/packages/cd/7a/19fe32c810c5ceddafcfda16276d98df299c8649e24e84d4f00df4a91e01/scikit_learn-1.5.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:1ff45e26928d3b4eb767a8f14a9a6efbf1cbff7c05d1fb0f95f211a89fd4f5de", size = 10975994 }, + { url = "https://files.pythonhosted.org/packages/4c/75/62e49f8a62bf3c60b0e64d0fce540578ee4f0e752765beb2e1dc7c6d6098/scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f763897fe92d0e903aa4847b0aec0e68cadfff77e8a0687cabd946c89d17e675", size = 12465782 }, + { url = "https://files.pythonhosted.org/packages/49/21/3723de321531c9745e40f1badafd821e029d346155b6c79704e0b7197552/scikit_learn-1.5.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8b0ccd4a902836493e026c03256e8b206656f91fbcc4fde28c57a5b752561f1", size = 13322034 }, + { url = "https://files.pythonhosted.org/packages/17/1c/ccdd103cfcc9435a18819856fbbe0c20b8fa60bfc3343580de4be13f0668/scikit_learn-1.5.2-cp311-cp311-win_amd64.whl", hash = "sha256:6c16d84a0d45e4894832b3c4d0bf73050939e21b99b01b6fd59cbb0cf39163b6", size = 11015224 }, + { url = "https://files.pythonhosted.org/packages/a4/db/b485c1ac54ff3bd9e7e6b39d3cc6609c4c76a65f52ab0a7b22b6c3ab0e9d/scikit_learn-1.5.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f932a02c3f4956dfb981391ab24bda1dbd90fe3d628e4b42caef3e041c67707a", size = 12110344 }, + { url = "https://files.pythonhosted.org/packages/54/1a/7deb52fa23aebb855431ad659b3c6a2e1709ece582cb3a63d66905e735fe/scikit_learn-1.5.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:3b923d119d65b7bd555c73be5423bf06c0105678ce7e1f558cb4b40b0a5502b1", size = 11033502 }, + { url = "https://files.pythonhosted.org/packages/a1/32/4a7a205b14c11225609b75b28402c196e4396ac754dab6a81971b811781c/scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd", size = 12085794 }, + { url = "https://files.pythonhosted.org/packages/c6/29/044048c5e911373827c0e1d3051321b9183b2a4f8d4e2f11c08fcff83f13/scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6", size = 12945797 }, + { url = "https://files.pythonhosted.org/packages/aa/ce/c0b912f2f31aeb1b756a6ba56bcd84dd1f8a148470526a48515a3f4d48cd/scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1", size = 10985467 }, + { url = "https://files.pythonhosted.org/packages/a4/50/8891028437858cc510e13578fe7046574a60c2aaaa92b02d64aac5b1b412/scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5", size = 12025584 }, + { url = "https://files.pythonhosted.org/packages/d2/79/17feef8a1c14149436083bec0e61d7befb4812e272d5b20f9d79ea3e9ab1/scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908", size = 10959795 }, + { url = "https://files.pythonhosted.org/packages/b1/c8/f08313f9e2e656bd0905930ae8bf99a573ea21c34666a813b749c338202f/scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3", size = 12077302 }, + { url = "https://files.pythonhosted.org/packages/a7/48/fbfb4dc72bed0fe31fe045fb30e924909ad03f717c36694351612973b1a9/scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12", size = 13002811 }, + { url = "https://files.pythonhosted.org/packages/a5/e7/0c869f9e60d225a77af90d2aefa7a4a4c0e745b149325d1450f0f0ce5399/scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f", size = 10951354 }, +] + [[package]] name = "scipy" version = "1.14.1" @@ -5074,6 +5271,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/aa/85/fa44f23dd5d5066a72f7c4304cce4b5ff9a6e7fd92431a48b2c63fbf63ec/selenium-4.25.0-py3-none-any.whl", hash = "sha256:3798d2d12b4a570bc5790163ba57fef10b2afee958bf1d80f2a3cf07c4141f33", size = 9693127 }, ] +[[package]] +name = "sentence-transformers" +version = "3.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "huggingface-hub" }, + { name = "pillow" }, + { name = "scikit-learn" }, + { name = "scipy" }, + { name = "torch" }, + { name = "tqdm" }, + { name = "transformers" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/de/61/708b20dedf26c460b416beb0acd5474c190dbca13e93b40858e99f17ac46/sentence_transformers-3.2.1.tar.gz", hash = "sha256:9fc38e620e5e1beba31d538a451778c9ccdbad77119d90f59f5bce49c4148e79", size = 202527 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/45/18/1ec591befcbdb2c97192a40fbe7c43a8b8a8b3c89b1fa101d3eeed4d79a4/sentence_transformers-3.2.1-py3-none-any.whl", hash = "sha256:c507e069eea33d15f1f2c72f74d7ea93abef298152cc235ab5af5e3a7584f738", size = 255758 }, +] + [[package]] name = "setuptools" version = "75.2.0" @@ -5415,14 +5630,14 @@ wheels = [ [[package]] name = "starlette" -version = "0.40.0" +version = "0.41.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/4b/cb/244daf0d7be4508099ad5bca3cdfe8b8b5538acd719c5f397f614e569fff/starlette-0.40.0.tar.gz", hash = "sha256:1a3139688fb298ce5e2d661d37046a66ad996ce94be4d4983be019a23a04ea35", size = 2573611 } +sdist = { url = "https://files.pythonhosted.org/packages/78/53/c3a36690a923706e7ac841f649c64f5108889ab1ec44218dac45771f252a/starlette-0.41.0.tar.gz", hash = "sha256:39cbd8768b107d68bfe1ff1672b38a2c38b49777de46d2a592841d58e3bf7c2a", size = 2573755 } wheels = [ - { url = "https://files.pythonhosted.org/packages/0a/0f/64baf7a06492e8c12f5c4b49db286787a7255195df496fc21f5fd9eecffa/starlette-0.40.0-py3-none-any.whl", hash = "sha256:c494a22fae73805376ea6bf88439783ecfba9aac88a43911b48c653437e784c4", size = 73303 }, + { url = "https://files.pythonhosted.org/packages/35/c6/a4443bfabf5629129512ca0e07866c4c3c094079ba4e9b2551006927253c/starlette-0.41.0-py3-none-any.whl", hash = "sha256:a0193a3c413ebc9c78bff1c3546a45bb8c8bcb4a84cae8747d650a65bd37210a", size = 73216 }, ] [[package]] @@ -5436,14 +5651,14 @@ wheels = [ [[package]] name = "sympy" -version = "1.13.3" +version = "1.13.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "mpmath" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/11/8a/5a7fd6284fa8caac23a26c9ddf9c30485a48169344b4bd3b0f02fef1890f/sympy-1.13.3.tar.gz", hash = "sha256:b27fd2c6530e0ab39e275fc9b683895367e51d5da91baa8d3d64db2565fec4d9", size = 7533196 } +sdist = { url = "https://files.pythonhosted.org/packages/ca/99/5a5b6f19ff9f083671ddf7b9632028436167cd3d33e11015754e41b249a4/sympy-1.13.1.tar.gz", hash = "sha256:9cebf7e04ff162015ce31c9c6c9144daa34a93bd082f54fd8f12deca4f47515f", size = 7533040 } wheels = [ - { url = "https://files.pythonhosted.org/packages/99/ff/c87e0622b1dadea79d2fb0b25ade9ed98954c9033722eb707053d310d4f3/sympy-1.13.3-py3-none-any.whl", hash = "sha256:54612cf55a62755ee71824ce692986f23c88ffa77207b30c1368eda4a7060f73", size = 6189483 }, + { url = "https://files.pythonhosted.org/packages/b2/fe/81695a1aa331a842b582453b605175f419fe8540355886031328089d840a/sympy-1.13.1-py3-none-any.whl", hash = "sha256:db36cdc64bf61b9b24578b6f7bab1ecdd2452cf008f34faa33776680c26d66f8", size = 6189177 }, ] [[package]] @@ -5558,6 +5773,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/07/a9/01d35770fde8d889e1fe28b726188cf28801e57afd369c614cd2bc100ee4/textual_serve-1.1.1-py3-none-any.whl", hash = "sha256:568782f1c0e60e3f7039d9121e1cb5c2f4ca1aaf6d6bd7aeb833d5763a534cb2", size = 445034 }, ] +[[package]] +name = "threadpoolctl" +version = "3.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bd/55/b5148dcbf72f5cde221f8bfe3b6a540da7aa1842f6b491ad979a6c8b84af/threadpoolctl-3.5.0.tar.gz", hash = "sha256:082433502dd922bf738de0d8bcc4fdcbf0979ff44c42bd40f5af8a282f6fa107", size = 41936 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4b/2c/ffbf7a134b9ab11a67b0cf0726453cedd9c5043a4fe7a35d1cefa9a1bcfb/threadpoolctl-3.5.0-py3-none-any.whl", hash = "sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467", size = 18414 }, +] + [[package]] name = "tiktoken" version = "0.8.0" @@ -5687,6 +5911,48 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cf/db/ce8eda256fa131af12e0a76d481711abe4681b6923c27efb9a255c9e4594/tomli-2.0.2-py3-none-any.whl", hash = "sha256:2ebe24485c53d303f690b0ec092806a085f07af5a5aa1464f3931eec36caaa38", size = 13237 }, ] +[[package]] +name = "torch" +version = "2.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "filelock" }, + { name = "fsspec" }, + { name = "jinja2" }, + { name = "networkx" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "setuptools", marker = "python_full_version >= '3.12'" }, + { name = "sympy" }, + { name = "triton", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "typing-extensions" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/f1/82/adc3a77b9fbbcb79d398d565d39dc0e09f43fff088599d15da81e6cfaaec/torch-2.5.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:7f179373a047b947dec448243f4e6598a1c960fa3bb978a9a7eecd529fbc363f", size = 906443143 }, + { url = "https://files.pythonhosted.org/packages/64/b0/0d2056c8d379a3f7f0c9fa9adece180f64fd6c339e2007a4fffbea7ecaa0/torch-2.5.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:15fbc95e38d330e5b0ef1593b7bc0a19f30e5bdad76895a5cffa1a6a044235e9", size = 91839507 }, + { url = "https://files.pythonhosted.org/packages/60/41/073193dd2566012eaeae44d6c5e55ba6a9b1d5687a251f12e1804a9e2968/torch-2.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:f499212f1cffea5d587e5f06144630ed9aa9c399bba12ec8905798d833bd1404", size = 203108822 }, + { url = "https://files.pythonhosted.org/packages/93/d4/6e7bda4e52c37a78b5066e407baff2426fd4543356ead3419383a0bf4011/torch-2.5.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:c54db1fade17287aabbeed685d8e8ab3a56fea9dd8d46e71ced2da367f09a49f", size = 64283014 }, + { url = "https://files.pythonhosted.org/packages/75/9f/cde8b71ccca65d68a3733c5c9decef9adefcfaa692f8ab03afbb5de09daa/torch-2.5.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:499a68a756d3b30d10f7e0f6214dc3767b130b797265db3b1c02e9094e2a07be", size = 906478039 }, + { url = "https://files.pythonhosted.org/packages/58/27/5bacfb6600209bf7e77ba115656cf7aca5b6ab1e0dc95551eefac2d6e7ec/torch-2.5.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:9f3df8138a1126a851440b7d5a4869bfb7c9cc43563d64fd9d96d0465b581024", size = 91843630 }, + { url = "https://files.pythonhosted.org/packages/78/18/7a2e56e2dc45a433dea9e1bf46a65e234294c9c470ccb4d4b53025f57b23/torch-2.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:b81da3bdb58c9de29d0e1361e52f12fcf10a89673f17a11a5c6c7da1cb1a8376", size = 203117099 }, + { url = "https://files.pythonhosted.org/packages/47/1b/3dfcc84b383f7b27a41de3251753db077b1e23d3f89a3b294cdd2d86fb7b/torch-2.5.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:ba135923295d564355326dc409b6b7f5bd6edc80f764cdaef1fb0a1b23ff2f9c", size = 64288133 }, + { url = "https://files.pythonhosted.org/packages/ac/72/d610029ef5cdde3f3aa216e8e75c233b1a91b34af0fc47392b3aa928563a/torch-2.5.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:2dd40c885a05ef7fe29356cca81be1435a893096ceb984441d6e2c27aff8c6f4", size = 906389657 }, + { url = "https://files.pythonhosted.org/packages/22/c2/d1759641eafdf59cb3a339909e96c842fc0c3579681bb7422acaf4a2c179/torch-2.5.0-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:bc52d603d87fe1da24439c0d5fdbbb14e0ae4874451d53f0120ffb1f6c192727", size = 91823361 }, + { url = "https://files.pythonhosted.org/packages/2b/e3/0f2698930d944087c3ef585b71a1a72aa51929877c1ccf35d625bec9bd78/torch-2.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:ea718746469246cc63b3353afd75698a288344adb55e29b7f814a5d3c0a7c78d", size = 203064894 }, + { url = "https://files.pythonhosted.org/packages/56/88/f1ddffd642cf71777dca43621b170d50f13175cdd0b4179e04d6e025b5fb/torch-2.5.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:6de1fd253e27e7f01f05cd7c37929ae521ca23ca4620cfc7c485299941679112", size = 64261171 }, + { url = "https://files.pythonhosted.org/packages/b4/b1/f06261814df00eee07ac8cf697a6f5d79231d9894c996d5985243343518a/torch-2.5.0-cp313-cp313-manylinux1_x86_64.whl", hash = "sha256:83dcf518685db20912b71fc49cbddcc8849438cdb0e9dcc919b02a849e2cd9e8", size = 906416128 }, +] + [[package]] name = "tornado" version = "6.4.1" @@ -5726,6 +5992,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/00/c0/8f5d070730d7836adc9c9b6408dec68c6ced86b304a9b26a14df072a6e8c/traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f", size = 85359 }, ] +[[package]] +name = "transformers" +version = "4.45.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "filelock" }, + { name = "huggingface-hub" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "pyyaml" }, + { name = "regex" }, + { name = "requests" }, + { name = "safetensors" }, + { name = "tokenizers" }, + { name = "tqdm" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4b/4c/3862b2dd6cdf83b187897bd351da0f7fb74d0df642b03c6f5d06353a3ca0/transformers-4.45.2.tar.gz", hash = "sha256:72bc390f6b203892561f05f86bbfaa0e234aab8e927a83e62b9d92ea7e3ae101", size = 8478357 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f9/9d/030cc1b3e88172967e22ee1d012e0d5e0384eb70d2a098d1669d549aea29/transformers-4.45.2-py3-none-any.whl", hash = "sha256:c551b33660cfc815bae1f9f097ecfd1e65be623f13c6ee0dda372bd881460210", size = 9881312 }, +] + [[package]] name = "trio" version = "0.27.0" @@ -5758,6 +6045,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/be/a9ae5f50cad5b6f85bd2574c2c923730098530096e170c1ce7452394d7aa/trio_websocket-0.11.1-py3-none-any.whl", hash = "sha256:520d046b0d030cf970b8b2b2e00c4c2245b3807853ecd44214acd33d74581638", size = 17408 }, ] +[[package]] +name = "triton" +version = "3.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "filelock" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/98/29/69aa56dc0b2eb2602b553881e34243475ea2afd9699be042316842788ff5/triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b0dd10a925263abbe9fa37dcde67a5e9b2383fc269fdf59f5657cac38c5d1d8", size = 209460013 }, + { url = "https://files.pythonhosted.org/packages/86/17/d9a5cf4fcf46291856d1e90762e36cbabd2a56c7265da0d1d9508c8e3943/triton-3.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f34f6e7885d1bf0eaaf7ba875a5f0ce6f3c13ba98f9503651c1e6dc6757ed5c", size = 209506424 }, + { url = "https://files.pythonhosted.org/packages/78/eb/65f5ba83c2a123f6498a3097746607e5b2f16add29e36765305e4ac7fdd8/triton-3.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8182f42fd8080a7d39d666814fa36c5e30cc00ea7eeeb1a2983dbb4c99a0fdc", size = 209551444 }, +] + [[package]] name = "typer" version = "0.12.5"