diff --git a/llm-service/app/routers/index/__init__.py b/llm-service/app/routers/index/__init__.py
index 873b2355..de4e7948 100644
--- a/llm-service/app/routers/index/__init__.py
+++ b/llm-service/app/routers/index/__init__.py
@@ -57,7 +57,6 @@
 router = APIRouter(
     prefix="/index",
-    tags=["index"],
 )
 router.include_router(data_source.router)
 router.include_router(sessions.router)
@@ -65,6 +64,14 @@ router.include_router(models.router)
 
+class SuggestQuestionsRequest(BaseModel):
+    data_source_id: int
+    chat_history: list[RagContext]
+    configuration: qdrant.RagPredictConfiguration = qdrant.RagPredictConfiguration()
+
+class RagSuggestedQuestionsResponse(BaseModel):
+    suggested_questions: list[str]
+
 class RagIndexDocumentRequest(BaseModel):
     data_source_id: int
     s3_bucket_name: str
@@ -81,7 +88,7 @@ class RagIndexDocumentRequest(BaseModel):
 )
 @exceptions.propagates
 def download_and_index(
-    request: RagIndexDocumentRequest,
+        request: RagIndexDocumentRequest,
 ) -> str:
     with tempfile.TemporaryDirectory() as tmpdirname:
         logger.debug("created temporary directory %s", tmpdirname)
@@ -94,16 +101,6 @@ def download_and_index(
         )
     return http.HTTPStatus.OK.phrase
 
-
-class SuggestQuestionsRequest(BaseModel):
-    data_source_id: int
-    chat_history: list[RagContext]
-    configuration: qdrant.RagPredictConfiguration = qdrant.RagPredictConfiguration()
-
-class RagSuggestedQuestionsResponse(BaseModel):
-    suggested_questions: list[str]
-
-
 @router.post("/suggest-questions", summary="Suggest questions with context")
 @exceptions.propagates
 def suggest_questions(
diff --git a/llm-service/app/routers/index/amp_update/__init__.py b/llm-service/app/routers/index/amp_update/__init__.py
index f5ec7852..bcec0f05 100644
--- a/llm-service/app/routers/index/amp_update/__init__.py
+++ b/llm-service/app/routers/index/amp_update/__init__.py
@@ -36,7 +36,6 @@
 #  DATA.
 #
 ##############################################################################
-import json
 import subprocess
 
 from fastapi import APIRouter
@@ -44,7 +43,7 @@
 from .... import exceptions
 from ....services.amp_update import check_amp_update_status
 
-router = APIRouter(prefix="/amp-update")
+router = APIRouter(prefix="/amp-update", tags=["AMP Update"])
 
 
 @router.get("", summary="Returns a boolean for whether AMP needs updating.")
 @exceptions.propagates
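The two new Pydantic models above define the wire contract for `POST /index/suggest-questions`. A minimal TypeScript sketch of a caller, assuming a `/llm-service` base path; `RagContext`'s fields are not visible in this diff, so it stays opaque here:

```typescript
// Mirrors SuggestQuestionsRequest / RagSuggestedQuestionsResponse above.
// RagContext's shape is not shown in this diff, so it is left opaque.
type RagContext = Record<string, unknown>;

interface SuggestQuestionsRequest {
  data_source_id: number;
  chat_history: RagContext[];
  configuration?: Record<string, unknown>; // RagPredictConfiguration defaults server-side
}

interface RagSuggestedQuestionsResponse {
  suggested_questions: string[];
}

export const suggestQuestions = async (
  body: SuggestQuestionsRequest,
): Promise<RagSuggestedQuestionsResponse> => {
  const res = await fetch("/llm-service/index/suggest-questions", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify(body),
  });
  if (!res.ok) {
    throw new Error(`suggest-questions failed with status ${res.status}`);
  }
  return (await res.json()) as RagSuggestedQuestionsResponse;
};
```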
diff --git a/llm-service/app/routers/index/data_source/__init__.py b/llm-service/app/routers/index/data_source/__init__.py
index 6ee513d2..ba8a5aaf 100644
--- a/llm-service/app/routers/index/data_source/__init__.py
+++ b/llm-service/app/routers/index/data_source/__init__.py
@@ -34,7 +34,7 @@
 from .... import exceptions
 from ....services import doc_summaries, qdrant
 
-router = APIRouter(prefix="/data_sources/{data_source_id}")
+router = APIRouter(prefix="/data_sources/{data_source_id}", tags=["Data Sources"])
 
 
 class SummarizeDocumentRequest(BaseModel):
diff --git a/llm-service/app/routers/index/models/__init__.py b/llm-service/app/routers/index/models/__init__.py
index 2682afd4..83ff3efa 100644
--- a/llm-service/app/routers/index/models/__init__.py
+++ b/llm-service/app/routers/index/models/__init__.py
@@ -49,7 +49,7 @@
     test_llm_model,
 )
 
-router = APIRouter(prefix="/models")
+router = APIRouter(prefix="/models", tags=["Models"])
 
 
 @router.get("/llm", summary="Get LLM Inference models.")
diff --git a/llm-service/app/routers/index/sessions/__init__.py b/llm-service/app/routers/index/sessions/__init__.py
index 3532e02e..05b17071 100644
--- a/llm-service/app/routers/index/sessions/__init__.py
+++ b/llm-service/app/routers/index/sessions/__init__.py
@@ -41,7 +41,7 @@
 from .... import exceptions
 from ....services.chat_store import RagStudioChatMessage, chat_store
 
-router = APIRouter(prefix="/sessions/{session_id}")
+router = APIRouter(prefix="/sessions/{session_id}", tags=["Sessions"])
 
 
 @router.get("/chat-history", summary="Returns an array of chat messages for the provided session.")
 @exceptions.propagates
diff --git a/llm-service/app/services/caii.py b/llm-service/app/services/caii.py
index aa497a82..8ae247f0 100644
--- a/llm-service/app/services/caii.py
+++ b/llm-service/app/services/caii.py
@@ -109,7 +109,16 @@ def get_embedding_model() -> BaseEmbedding:
 def get_caii_llm_models():
     domain = os.environ['CAII_DOMAIN']
     endpoint_name = os.environ['CAII_INFERENCE_ENDPOINT_NAME']
-    models = describe_endpoint(domain=domain, endpoint_name=endpoint_name)
+    try:
+        models = describe_endpoint(domain=domain, endpoint_name=endpoint_name)
+    except requests.exceptions.ConnectionError as e:
+        print(e)
+        raise HTTPException(status_code=421, detail=f"Unable to connect to host {domain}. Please check your CAII_DOMAIN env variable.")
+    except HTTPException as e:
+        if e.status_code == 404:
+            return [{"model_id": endpoint_name}]
+        else:
+            raise e
     return build_model_response(models)
 
 def get_caii_embedding_models():
@@ -120,6 +129,9 @@
     endpoint_name = os.environ['CAII_EMBEDDING_ENDPOINT_NAME']
     try:
         models = describe_endpoint(domain=domain, endpoint_name=endpoint_name)
+    except requests.exceptions.ConnectionError as e:
+        print(e)
+        raise HTTPException(status_code=421, detail=f"Unable to connect to host {domain}. Please check your CAII_DOMAIN env variable.")
     except HTTPException as e:
         if e.status_code == 404:
             return [{"model_id": endpoint_name}]
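On the consumer side of the new 421 handling: FastAPI serializes a raised `HTTPException` as a JSON body of the form `{"detail": "..."}`. A hedged sketch of how a client might surface that message — the endpoint path mirrors the `/index/models` routes above, but the helper itself is hypothetical:

```typescript
interface ErrorBody {
  detail?: string;
}

// Fetch the LLM model list and surface the server's `detail` message on
// failure (e.g. the 421 raised when the CAII_DOMAIN host is unreachable).
export const fetchLlmModels = async (): Promise<unknown> => {
  const res = await fetch("/llm-service/index/models/llm");
  if (!res.ok) {
    // FastAPI places HTTPException's `detail` in the JSON response body.
    const body = (await res.json().catch(() => ({}))) as ErrorBody;
    throw new Error(body.detail ?? `Request failed with status ${res.status}`);
  }
  return res.json();
};
```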
diff --git a/ui/src/api/modelsApi.ts b/ui/src/api/modelsApi.ts
index 3827509e..ad362c3b 100644
--- a/ui/src/api/modelsApi.ts
+++ b/ui/src/api/modelsApi.ts
@@ -35,13 +35,15 @@
  * BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
  * DATA.
  ******************************************************************************/
-import { queryOptions, useQuery } from "@tanstack/react-query";
+import { queryOptions, useMutation, useQuery } from "@tanstack/react-query";
 import {
   ApiError,
   CustomError,
   getRequest,
   llmServicePath,
+  MutationKeys,
   QueryKeys,
+  UseMutationType,
 } from "src/api/utils.ts";
 
 export interface Model {
@@ -97,14 +99,15 @@ const getModelSource = async (): Promise<string> => {
   return await getRequest(`${llmServicePath}/index/models/model_source`);
 };
 
-export const useTestLlmModel = (model_id: string) => {
-  return useQuery({
-    queryKey: [QueryKeys.testLlmModel, { model_id }],
-    queryFn: async () => {
-      return await testLlmModel(model_id);
-    },
-    enabled: !!model_id,
-    retry: false,
+export const useTestLlmModel = ({
+  onSuccess,
+  onError,
+}: UseMutationType<string>) => {
+  return useMutation({
+    mutationKey: [MutationKeys.testLlmModel],
+    mutationFn: testLlmModel,
+    onError,
+    onSuccess,
   });
 };
@@ -121,14 +124,15 @@ const testLlmModel = async (model_id: string): Promise<string> => {
   });
 };
 
-export const useTestEmbeddingModel = (model_id: string) => {
-  return useQuery({
-    queryKey: [QueryKeys.testEmbeddingModel, { model_id }],
-    queryFn: async () => {
-      return await testEmbeddingModel(model_id);
-    },
-    retry: false,
-    enabled: !!model_id,
+export const useTestEmbeddingModel = ({
+  onSuccess,
+  onError,
+}: UseMutationType<string>) => {
+  return useMutation({
+    mutationKey: [MutationKeys.testEmbeddingModel],
+    mutationFn: testEmbeddingModel,
+    onError,
+    onSuccess,
   });
 };
diff --git a/ui/src/api/utils.ts b/ui/src/api/utils.ts
index 37879789..40bbe4a5 100644
--- a/ui/src/api/utils.ts
+++ b/ui/src/api/utils.ts
@@ -65,6 +65,8 @@ export enum MutationKeys {
   "deleteChatHistory" = "deleteChatHistory",
   "deleteSession" = "deleteSession",
   "updateAmp" = "updateAmp",
+  "testLlmModel" = "testLlmModel",
+  "testEmbeddingModel" = "testEmbeddingModel",
 }
 
 export enum QueryKeys {
@@ -81,8 +83,6 @@
   "getLlmModels" = "getLlmModels",
   "getEmbeddingModels" = "getEmbeddingModels",
   "getModelSource" = "getModelSource",
-  "testLlmModel" = "testLlmModel",
-  "testEmbeddingModel" = "testEmbeddingModel",
 }
 
 export const commonHeaders = {
diff --git a/ui/src/pages/Models/EmbeddingModelTable.tsx b/ui/src/pages/Models/EmbeddingModelTable.tsx
index 8f266c48..ede90a1e 100644
--- a/ui/src/pages/Models/EmbeddingModelTable.tsx
+++ b/ui/src/pages/Models/EmbeddingModelTable.tsx
@@ -38,26 +38,27 @@
 import { Table, TableProps } from "antd";
 import { Model, useTestEmbeddingModel } from "src/api/modelsApi.ts";
-import { useState } from "react";
 import { modelColumns, TestCell } from "pages/Models/ModelTable.tsx";
 
 const EmbeddingModelTestCell = ({ model }: { model: Model }) => {
-  const [testModel, setTestModel] = useState("");
   const {
     data: testResult,
-    isLoading,
+    isPending,
     error,
-  } = useTestEmbeddingModel(testModel);
+    mutate,
+  } = useTestEmbeddingModel({
+    onError: () => undefined,
+  });
 
   const handleTestModel = () => {
-    setTestModel(model.model_id);
+    mutate(model.model_id);
   };
 
   return (
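The hook migration above swaps an `enabled`-gated `useQuery` for `useMutation`: a query re-runs whenever its key changes, while a mutation fires only when `mutate` is called, which matches a user-clicked model test. A minimal standalone sketch of the same pattern (TanStack Query v5; `runModelTest` and its endpoint are hypothetical stand-ins for the real `testLlmModel`/`testEmbeddingModel` helpers):

```typescript
import { useMutation } from "@tanstack/react-query";

// Hypothetical test call: POST the model id, resolve to a status string.
const runModelTest = async (modelId: string): Promise<string> => {
  const res = await fetch(`/llm-service/index/models/llm/${modelId}/test`, {
    method: "POST",
  });
  if (!res.ok) {
    throw new Error(`Model test failed with status ${res.status}`);
  }
  return res.text();
};

// The hook fires only on demand; the callbacks mirror UseMutationType.
export const useRunModelTest = (callbacks: {
  onSuccess?: (result: string) => void;
  onError?: (error: Error) => void;
}) =>
  useMutation({
    mutationKey: ["testModel"],
    mutationFn: runModelTest,
    ...callbacks,
  });
```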
diff --git a/ui/src/pages/Models/InferenceModelTable.tsx b/ui/src/pages/Models/InferenceModelTable.tsx
index 4abd0503..22a7f609 100644
--- a/ui/src/pages/Models/InferenceModelTable.tsx
+++ b/ui/src/pages/Models/InferenceModelTable.tsx
@@ -38,22 +38,27 @@
 import { Table, TableProps } from "antd";
 import { Model, useTestLlmModel } from "src/api/modelsApi.ts";
-import { useState } from "react";
 import { modelColumns, TestCell } from "pages/Models/ModelTable.tsx";
 
 const InferenceModelTestCell = ({ model }: { model: Model }) => {
-  const [testModel, setTestModel] = useState("");
-  const { data: testResult, isLoading, error } = useTestLlmModel(testModel);
+  const {
+    data: testResult,
+    isPending,
+    error,
+    mutate,
+  } = useTestLlmModel({
+    onError: () => undefined,
+  });
 
   const handleTestModel = () => {
-    setTestModel(model.model_id);
+    mutate(model.model_id);
   };
 
   return (
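For usage, both table cells now trigger the test imperatively via `mutate` and render `isPending` instead of `isLoading`. A stripped-down sketch of the same flow, using a plain button in place of the real `TestCell` (whose props are not shown in this diff) and the hypothetical hook sketched above:

```typescript
import { useRunModelTest } from "./useRunModelTest"; // the sketch above

const ModelTestButton = ({ modelId }: { modelId: string }) => {
  // `mutate` replaces the old setState-plus-enabled-query indirection.
  const { data, error, isPending, mutate } = useRunModelTest({
    onError: () => undefined, // errors are rendered below instead of thrown
  });

  return (
    <button disabled={isPending} onClick={() => mutate(modelId)}>
      {isPending ? "Testing…" : error ? "Failed" : data ?? "Test"}
    </button>
  );
};

export default ModelTestButton;
```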
diff --git a/ui/src/pages/Models/ModelPage.tsx b/ui/src/pages/Models/ModelPage.tsx
index 20ffe7e7..ca7a4d85 100644
--- a/ui/src/pages/Models/ModelPage.tsx
+++ b/ui/src/pages/Models/ModelPage.tsx
@@ -36,19 +36,41 @@
  * DATA.
  ******************************************************************************/
-import { Flex, Typography } from "antd";
+import { Alert, Flex, Typography } from "antd";
 import EmbeddingModelTable from "pages/Models/EmbeddingModelTable.tsx";
 import { useGetEmbeddingModels, useGetLlmModels } from "src/api/modelsApi.ts";
 import InferenceModelTable from "pages/Models/InferenceModelTable.tsx";
 
 const ModelPage = () => {
-  const { data: embeddingModels, isLoading: areEmbeddingModelsLoading } =
-    useGetEmbeddingModels();
-  const { data: inferenceModels, isLoading: areInferenceModelsLoading } =
-    useGetLlmModels();
+  const {
+    data: embeddingModels,
+    isLoading: areEmbeddingModelsLoading,
+    error: embeddingError,
+  } = useGetEmbeddingModels();
+  const {
+    data: inferenceModels,
+    isLoading: areInferenceModelsLoading,
+    error: inferenceError,
+  } = useGetLlmModels();
 
   return (
+      {inferenceError ? (
+        <Alert ... />
+      ) : null}
+      {embeddingError ? (
+        <Alert ... />
+      ) : null}
       Embedding Models
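The `<Alert/>` markup inside the two conditionals above did not survive extraction (marked `...`). A plausible reconstruction of one such banner, based only on the added `antd` `Alert` import and the error object it renders — the message text is assumed:

```typescript
import { Alert } from "antd";

// Error banner for a failed model-list fetch; rendered only when the
// corresponding query reports an error.
const ModelLoadError = ({ error }: { error: Error | null }) =>
  error ? (
    <Alert
      type="error"
      showIcon
      message="Unable to load models"
      description={error.message}
    />
  ) : null;

export default ModelLoadError;
```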
Embedding Models { - const queryClient = new QueryClient(); - console.log("RagChat.tsx: RagChat: queryClient: ", queryClient); const { dataSourceId, dataSources, activeSession } = useContext(RagChatContext);