From 589cc3e66473d709096ff6705cfc7542419f5b9b Mon Sep 17 00:00:00 2001
From: szymon
Date: Thu, 1 Feb 2024 17:41:56 +0100
Subject: [PATCH] Rewriting local example in llm-app to use xpack (#5541)

GitOrigin-RevId: 9b44e727a2d28a56ee49be2a0cda9862af7cb66e
---
 examples/pipelines/local/app.py | 43 ++++++++++++++++++++++-----------
 poetry.lock                     | 29 ++++++++++++++++++++------
 pyproject.toml                  |  2 +-
 3 files changed, 52 insertions(+), 22 deletions(-)

diff --git a/examples/pipelines/local/app.py b/examples/pipelines/local/app.py
index 6a62bae..8176b7c 100644
--- a/examples/pipelines/local/app.py
+++ b/examples/pipelines/local/app.py
@@ -13,6 +13,13 @@ for documents in the corpus. A prompt is build from the relevant documentations
 pages and run through a local LLM downloaded form the HuggingFace repository.
 
+Because of model restrictions, you need to be careful about the length of the prompt
+with the embedded documents. In this example this is solved by cropping the prompt to
+a set length - the query is at the beginning of the prompt, so it won't be removed,
+but parts of the documents may be omitted from the prompt.
+Depending on the length of the documents and the model you use, this may not be
+necessary, or you can use a more refined method of shortening your prompts.
+
 Usage:
 In the root of this repository run:
 `poetry run ./run_examples.py local`
 
@@ -28,8 +35,8 @@
 
 import pathway as pw
 from pathway.stdlib.ml.index import KNNIndex
-
-from llm_app.model_wrappers import HFTextGenerationTask, SentenceTransformerTask
+from pathway.xpacks.llm.embedders import SentenceTransformerEmbedder
+from pathway.xpacks.llm.llms import HFPipelineChat, prompt_chat_single_qa
 
 
 class DocumentInputSchema(pw.Schema):
@@ -50,13 +57,12 @@ def run(
     port: int = 8080,
     model_locator: str = os.environ.get("MODEL", "gpt2"),
     embedder_locator: str = os.environ.get("EMBEDDER", "intfloat/e5-large-v2"),
-    embedding_dimension: int = 1024,
     max_tokens: int = 60,
     device: str = "cpu",
     **kwargs,
 ):
-    embedder = SentenceTransformerTask(model=embedder_locator, device=device)
-    embedding_dimension = len(embedder(""))
+    embedder = SentenceTransformerEmbedder(model=embedder_locator, device=device)
+    embedding_dimension = len(embedder.__wrapped__(""))
 
     documents = pw.io.jsonlines.read(
         data_dir,
@@ -65,9 +71,7 @@ def run(
         autocommit_duration_ms=50,
     )
 
-    enriched_documents = documents + documents.select(
-        vector=embedder.apply(text=pw.this.doc)
-    )
+    enriched_documents = documents + documents.select(vector=embedder(text=pw.this.doc))
 
     index = KNNIndex(
         enriched_documents.vector, enriched_documents, n_dimensions=embedding_dimension
@@ -82,7 +86,7 @@ def run(
     )
 
     query += query.select(
-        vector=embedder.apply(text=pw.this.query),
+        vector=embedder(text=pw.this.query),
     )
 
     query_context = query + index.get_nearest_items(
@@ -92,20 +96,31 @@ def run(
     @pw.udf
     def build_prompt(documents, query):
         docs_str = "\n".join(documents)
-        prompt = f"Given the following documents : \n {docs_str} \nanswer this query: {query}"
+        prompt = f"You are given a query: {query}\n Answer this query based on the following documents: \n {docs_str}"
         return prompt
 
     prompt = query_context.select(
         prompt=build_prompt(pw.this.documents_list, pw.this.query)
     )
 
-    model = HFTextGenerationTask(model=model_locator, device=device)
+    model = HFPipelineChat(
+        model=model_locator,
+        device=device,
+        return_full_text=False,
+        max_new_tokens=max_tokens,
+    )
+
+    # Crop the prompt so that it is short enough for the model. Depending on the
+    # input documents and the chosen model, this may not be necessary.
+    prompt = prompt.select(
+        prompt=model.crop_to_max_length(
+            input_string=pw.this.prompt, max_prompt_length=500
+        )
+    )
 
     responses = prompt.select(
         query_id=pw.this.id,
-        result=model.apply(
-            pw.this.prompt, return_full_text=False, max_new_tokens=max_tokens
-        ),
+        result=model(prompt_chat_single_qa(pw.this.prompt)),
     )
 
     response_writer(responses)
diff --git a/poetry.lock b/poetry.lock
index 787b3e6..13a198c 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -96,6 +96,20 @@ yarl = ">=1.0,<2.0"
 [package.extras]
 speedups = ["Brotli", "aiodns", "brotlicffi"]
 
+[[package]]
+name = "aiohttp-cors"
+version = "0.7.0"
+description = "CORS support for aiohttp"
+optional = false
+python-versions = "*"
+files = [
+    {file = "aiohttp-cors-0.7.0.tar.gz", hash = "sha256:4d39c6d7100fd9764ed1caf8cebf0eb01bf5e3f24e2e073fda6234bc48b19f5d"},
+    {file = "aiohttp_cors-0.7.0-py3-none-any.whl", hash = "sha256:0451ba59fdf6909d0e2cd21e4c0a43752bc0703d33fc78ae94d9d9321710193e"},
+]
+
+[package.dependencies]
+aiohttp = ">=1.1"
+
 [[package]]
 name = "aiosignal"
 version = "1.3.1"
@@ -2783,9 +2797,9 @@ files = [
 
 [package.dependencies]
 numpy = [
-    {version = ">=1.23.5", markers = "python_version >= \"3.11\""},
     {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""},
     {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""},
+    {version = ">=1.23.5", markers = "python_version >= \"3.11\""},
 ]
 
 [[package]]
@@ -2849,9 +2863,9 @@ files = [
 
 [package.dependencies]
 numpy = [
+    {version = ">=1.22.4,<2", markers = "python_version < \"3.11\""},
     {version = ">=1.23.2,<2", markers = "python_version == \"3.11\""},
     {version = ">=1.26.0,<2", markers = "python_version >= \"3.12\""},
-    {version = ">=1.22.4,<2", markers = "python_version < \"3.11\""},
 ]
 python-dateutil = ">=2.8.2"
 pytz = ">=2020.1"
@@ -2956,18 +2970,19 @@ testing = ["docopt", "pytest (<6.0.0)"]
 
 [[package]]
 name = "pathway"
-version = "0.7.10"
+version = "0.8.0"
 description = "Pathway is a data processing framework which takes care of streaming data updates for you."
 optional = false
 python-versions = ">=3.10"
 files = [
-    {file = "pathway-0.7.10-cp310-abi3-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:3c4d63202f805cda154a936766c755385e57e83c0b53e09224a6bc2c03b11cff"},
-    {file = "pathway-0.7.10-cp310-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4d28d1a8ef4bf75b2dd762e146313a11b190d425dd6da376ebf06b4241fcf0f5"},
-    {file = "pathway-0.7.10-cp310-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd1c96cd3ecb5857aceb0d902b488b3b12e2785f8d3491b46d4e3ec1a809ad6e"},
+    {file = "pathway-0.8.0-cp310-abi3-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:8e4274470279a80d5cc0646284e2475bbafb7e66bad922c3340cfb4a06aeaffd"},
+    {file = "pathway-0.8.0-cp310-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:35d9370b2b4ffe2410362e514ee4d5e5309088814d7802e13cc43e22cb421d23"},
+    {file = "pathway-0.8.0-cp310-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d6991a9b6eb75b4738ea4b08b1dd9b837075aa11d9b89506514c97ade7b77f1"},
 ]
 
 [package.dependencies]
 aiohttp = ">=3.8.4"
+aiohttp_cors = ">=0.7.0"
 beartype = ">=0.14.0,<0.16.0"
 boto3 = ">=1.26.76"
 click = ">=8.1"
@@ -5764,4 +5779,4 @@ unstructured-to-sql = ["psycopg", "tiktoken", "unstructured"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<3.13"
-content-hash = "78417b5694e6167e1c2493ed5f75935e6345e021b7af84ef67193d2d34ab0ca2"
+content-hash = "87fd30947e499f0e0267cc5e85d0d1acddf3f368d8a82c7c079e0c5cd10e9ecd"
diff --git a/pyproject.toml b/pyproject.toml
index 6544b0f..c524386 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,7 +30,7 @@ classifiers = [
 
 [tool.poetry.dependencies]
 python = ">=3.10,<3.13"
-pathway = "=0.7.10"
+pathway = "=0.8.0"
 openai = ">=1.2.4"
 requests = "^2.31.0"
 diskcache = "^5.6.1"
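
Note on the cropping step: the prompt is truncated to a fixed budget, and because
the query sits at the start of the prompt, truncation drops document context
rather than the question. For intuition, here is a minimal standalone sketch of
the same idea written against the HuggingFace tokenizer directly. The
crop_prompt helper is illustrative only (not part of Pathway or this repo), and
it assumes the limit is counted in tokens; the patch does not show whether
crop_to_max_length counts tokens or characters.

    from transformers import AutoTokenizer

    def crop_prompt(prompt: str, model_name: str = "gpt2", max_prompt_length: int = 500) -> str:
        # Tokenize, keep only the first max_prompt_length tokens, decode back.
        # The query comes first in the prompt, so cutting the tail removes
        # document context, never the question itself.
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        token_ids = tokenizer.encode(prompt)
        if len(token_ids) <= max_prompt_length:
            return prompt
        return tokenizer.decode(token_ids[:max_prompt_length])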
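
Note on the generation parameters: return_full_text=False and max_new_tokens
move from the per-call model.apply(...) into the HFPipelineChat constructor.
Assuming the wrapper forwards them to a HuggingFace text-generation pipeline
(a sketch of roughly equivalent raw transformers usage, not the wrapper's
actual internals):

    from transformers import pipeline

    # device=-1 runs on CPU, mirroring the example's default device="cpu".
    generator = pipeline("text-generation", model="gpt2", device=-1)

    output = generator(
        "You are given a query: What is Pathway?\n Answer this query based on the following documents: \n ...",
        return_full_text=False,  # return only the continuation, not the echoed prompt
        max_new_tokens=60,       # mirrors the example's max_tokens default
    )
    print(output[0]["generated_text"])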