Misc. updates and bugfixes (#25)
* add scores to sources

* add in toggle for querying knowledge base

* wip for excluding knowledge base

* wip on llm_talk

* make direct calls to the CAII embedding model, rather than using the OpenAI library

* init the context in the super constructor call

* Fix check for CAII domain

* fix broken test compilation

* Update release version to dev-testing

* hide the knowledge base switch for now

---------

Co-authored-by: Elijah Williams <[email protected]>
Co-authored-by: Michael Liu <[email protected]>
Co-authored-by: actions-user <[email protected]>
4 people authored Nov 20, 2024
1 parent b92350c commit 83122e9
Showing 15 changed files with 166 additions and 31 deletions.
5 changes: 2 additions & 3 deletions llm-service/app/rag_types.py
@@ -36,14 +36,13 @@
# DATA.
# ##############################################################################

from typing import List, Optional
from typing import Optional

from pydantic import BaseModel

from .services.chat_store import RagContext, RagPredictSourceNode


class RagPredictConfiguration(BaseModel):
    top_k: int = 5
    chunk_size: int = 512
    model_name: str = "meta.llama3-1-8b-instruct-v1:0"
    exclude_knowledge_base: Optional[bool] = False
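
For reference, a minimal usage sketch of the new flag (not part of this commit; the import path is an assumption based on the package layout above):

# Sketch only: assumes the llm-service app package is importable as `app`.
from app.rag_types import RagPredictConfiguration

default_config = RagPredictConfiguration()
assert default_config.exclude_knowledge_base is False  # retrieval stays on by default

no_kb_config = RagPredictConfiguration(exclude_knowledge_base=True)
# A request carrying this configuration is expected to skip the knowledge base entirely.
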
24 changes: 23 additions & 1 deletion llm-service/app/routers/index/sessions/__init__.py
@@ -35,14 +35,16 @@
# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
# DATA.
# ##############################################################################
import time
import uuid

from fastapi import APIRouter

from pydantic import BaseModel

from .... import exceptions
from ....services.chat_store import RagStudioChatMessage, chat_store
from ....services import qdrant
from ....services import qdrant, llm_completion
from ....services.chat import (v2_chat, generate_suggested_questions)

router = APIRouter(prefix="/sessions/{session_id}", tags=["Sessions"])
@@ -77,8 +79,28 @@ def chat(
    session_id: int,
    request: RagStudioChatRequest,
) -> RagStudioChatMessage:
    if request.configuration.exclude_knowledge_base:
        return llm_talk(session_id, request)
    return v2_chat(session_id, request.data_source_id, request.query, request.configuration)

def llm_talk(
    session_id: int,
    request: RagStudioChatRequest,
) -> RagStudioChatMessage:
    chat_response = llm_completion.completion(session_id, request.query, request.configuration)
    new_chat_message = RagStudioChatMessage(
        id=str(uuid.uuid4()),
        source_nodes=[],
        evaluations=[],
        rag_message={
            "user": request.query,
            "assistant": chat_response.message.content,
        },
        timestamp=time.time()
    )
    chat_store.append_to_history(session_id, [new_chat_message])
    return new_chat_message


class SuggestQuestionsRequest(BaseModel):
    data_source_id: int
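
For review context: when exclude_knowledge_base is set, the endpoint above bypasses retrieval and llm_talk stores a message with no source nodes and no evaluations. A fabricated example of the resulting message shape (field names mirror the constructor call above; values are illustrative only):

import time
import uuid

# Illustrative only: mirrors the RagStudioChatMessage built by llm_talk above.
example_message = {
    "id": str(uuid.uuid4()),
    "source_nodes": [],   # no retrieval happened, so there is nothing to cite
    "evaluations": [],    # evaluations are skipped as well
    "rag_message": {
        "user": "What models are supported?",
        "assistant": "(answer generated without consulting the knowledge base)",
    },
    "timestamp": time.time(),
}

Both branches of the endpoint return a RagStudioChatMessage, so downstream consumers can render either path the same way.
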
34 changes: 25 additions & 9 deletions llm-service/app/services/CaiiEmbeddingModel.py
@@ -35,8 +35,10 @@
# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
# DATA.
#

import http.client as http_client
import json
import os
from typing import List

from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding
from openai import OpenAI
@@ -60,19 +62,33 @@ def _get_query_embedding(self, query: str) -> Embedding:
        return self._get_embedding(query, "query")

    def _get_embedding(self, query: str, input_type: str) -> Embedding:
        client, model = self._get_client()
        query = query.replace("\n", " ")
        return (
            client.embeddings.create(input=[query], extra_body={ "input_type": input_type, "truncate": "END"}, model=model).data[0].embedding
        )
        model = self.endpoint["endpointmetadata"]["model_name"]
        domain = os.environ['CAII_DOMAIN']

        connection = http_client.HTTPSConnection(domain, 443)
        headers = self.build_auth_headers()
        headers["Content-Type"] = "application/json"
        body = json.dumps({
            "input": query,
            "input_type": input_type,
            "truncate": "END",
            "model": model
        })
        connection.request("POST", self.endpoint["url"], body=body, headers=headers)
        res = connection.getresponse()
        data = res.read()
        json_response = data.decode("utf-8")
        structured_response = json.loads(json_response)
        embedding = structured_response["data"][0]["embedding"]

        return embedding

    def _get_client(self) -> (OpenAI, any):
        api_base = self.endpoint["url"].removesuffix("/embeddings")

    def build_auth_headers(self) -> dict:
        with open('/tmp/jwt', 'r') as file:
            jwt_contents = json.load(file)
        access_token = jwt_contents['access_token']
        headers = {
            "Authorization": f"Bearer {access_token}"
        }
        return OpenAI(base_url=api_base, default_headers=headers, api_key="api_key"), self.endpoint["endpointmetadata"]["model_name"]
        return headers
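
For reviewers unfamiliar with the CAII embeddings endpoint: _get_embedding above appears to assume an OpenAI-compatible response body and pulls the vector out of data[0].embedding. A minimal sketch of that parsing step with a fabricated payload (field values and model name are illustrative only):

import json

# Fabricated response payload; a real call returns the vector for the submitted text.
raw = json.dumps({
    "data": [{"index": 0, "embedding": [0.012, -0.034, 0.056]}],
    "model": "example-embedding-model",  # hypothetical model name
})

structured_response = json.loads(raw)
embedding = structured_response["data"][0]["embedding"]  # same access pattern as above
assert isinstance(embedding, list) and len(embedding) == 3

Compared with the removed OpenAI-client path, the direct http.client call keeps input_type and truncate as explicit request-body fields rather than tunnelling them through extra_body.
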
4 changes: 2 additions & 2 deletions llm-service/app/services/CaiiModel.py
@@ -35,7 +35,6 @@
# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
# DATA.
#
from llama_index.core.base.llms import generic_utils
from llama_index.core.base.llms.types import LLMMetadata
from llama_index.core.bridge.pydantic import Field
from llama_index.llms.mistralai.base import MistralAI
@@ -60,7 +59,8 @@ def __init__(
            api_base=api_base,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            default_headers=default_headers)
            default_headers=default_headers,
            context=context)
        self.context = context

    @property
2 changes: 1 addition & 1 deletion llm-service/app/services/chat.py
@@ -96,7 +96,7 @@ def v2_chat(
    return new_chat_message


def retrieve_chat_history(session_id):
def retrieve_chat_history(session_id) -> list[RagContext]:
    chat_history = chat_store.retrieve_chat_history(session_id)[:10]
    history: [RagContext] = list()
    for message in chat_history:
59 changes: 59 additions & 0 deletions llm-service/app/services/llm_completion.py
@@ -0,0 +1,59 @@
#
# CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP)
# (C) Cloudera, Inc. 2024
# All rights reserved.
#
# Applicable Open Source License: Apache 2.0
#
# NOTE: Cloudera open source products are modular software products
# made up of hundreds of individual components, each of which was
# individually copyrighted. Each Cloudera open source product is a
# collective work under U.S. Copyright Law. Your license to use the
# collective work is as provided in your written agreement with
# Cloudera. Used apart from the collective work, this file is
# licensed for your use pursuant to the open source license
# identified above.
#
# This code is provided to you pursuant a written agreement with
# (i) Cloudera, Inc. or (ii) a third-party authorized to distribute
# this code. If you do not have a written agreement with Cloudera nor
# with an authorized and properly licensed third party, you do not
# have any rights to access nor to use this code.
#
# Absent a written agreement with Cloudera, Inc. ("Cloudera") to the
# contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY
# KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED
# WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO
# IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND
# FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU,
# AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS
# ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE
# OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR
# CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES
# RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF
# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
# DATA.
#
import itertools

from llama_index.core.base.llms.types import ChatMessage, ChatResponse

from .chat_store import chat_store, RagStudioChatMessage
from .qdrant import RagPredictConfiguration
from .models import get_llm


def make_chat_messages(x: RagStudioChatMessage) -> list[ChatMessage]:
    user = ChatMessage.from_str(x.rag_message['user'], role="user")
    assistant = ChatMessage.from_str(x.rag_message['assistant'], role="assistant")
    return [user, assistant]


def completion(session_id: int, question: str, configuration: RagPredictConfiguration) -> ChatResponse:
    model = get_llm(configuration.model_name)
    chat_history = chat_store.retrieve_chat_history(session_id)[:10]
    messages = list(itertools.chain.from_iterable(map(lambda x: make_chat_messages(x), chat_history)))
    messages.append(ChatMessage.from_str(question, role="user"))
    return model.chat(messages)
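
A hedged usage sketch for the new module (not part of the commit; assumes the llm-service app package is importable as `app` and that the session already has stored history):

# Sketch only: the session id and question are illustrative.
from app.rag_types import RagPredictConfiguration
from app.services import llm_completion

config = RagPredictConfiguration(exclude_knowledge_base=True)
response = llm_completion.completion(
    session_id=42,  # hypothetical existing session
    question="Summarize our discussion so far.",
    configuration=config,
)
print(response.message.content)  # assistant text, as consumed by llm_talk above
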

14 changes: 8 additions & 6 deletions llm-service/app/services/models.py
@@ -85,8 +85,9 @@ def get_available_llm_models():
    return _get_bedrock_llm_models()


def is_caii_enabled():
    return "CAII_DOMAIN" in os.environ
def is_caii_enabled() -> bool:
    domain: str = os.environ.get("CAII_DOMAIN", "")
    return len(domain) > 0


def _get_bedrock_llm_models():
@@ -130,22 +131,23 @@ def test_llm_model(model_name: str) -> Literal["ok"]:
    models = get_available_llm_models()
    for model in models:
        if model["model_id"] == model_name:
            if not is_caii_enabled() or model['available']:
            if not is_caii_enabled() or model["available"]:
                get_llm(model_name).complete("Are you available to answer questions?")
                return "ok"
            else:
                raise HTTPException(status_code=503, detail="Model not ready")

    raise HTTPException(status_code=404, detail="Model not found")


def test_embedding_model(model_name: str) -> str:
    models = get_available_embedding_models()
    for model in models:
        if model["model_id"] == model_name:
            if not is_caii_enabled() or model['available']:
            if not is_caii_enabled() or model["available"]:
                # TODO: Update to pass embedding model in the future when multiple are supported
                get_embedding_model().get_text_embedding('test')
                return 'ok'
                get_embedding_model().get_text_embedding("test")
                return "ok"
            else:
                raise HTTPException(status_code=503, detail="Model not ready")

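
The revised is_caii_enabled treats an empty CAII_DOMAIN the same as an unset one, which the previous membership test did not (the "Fix check for CAII domain" item in the commit message). A small standalone sketch of the difference, not taken from the project:

import os

def is_caii_enabled() -> bool:
    # Mirrors the updated check above: an empty string counts as "not configured".
    domain: str = os.environ.get("CAII_DOMAIN", "")
    return len(domain) > 0

os.environ["CAII_DOMAIN"] = ""        # present but blank, e.g. from an empty env entry
assert "CAII_DOMAIN" in os.environ    # the old membership check would report enabled here
assert is_caii_enabled() is False     # the new check correctly reports disabled
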
3 changes: 1 addition & 2 deletions llm-service/app/services/qdrant.py
@@ -168,8 +168,7 @@ def query(

    response_synthesizer = get_response_synthesizer(llm=llm)
    query_engine = RetrieverQueryEngine(
        retriever=retriever, response_synthesizer=response_synthesizer
    )
        retriever=retriever, response_synthesizer=response_synthesizer)
    chat_engine = CondenseQuestionChatEngine.from_defaults(
        query_engine=query_engine,
        llm=llm,
1 change: 1 addition & 0 deletions ui/src/api/chatApi.ts
@@ -67,6 +67,7 @@ export interface RagMessageV2 {
export interface QueryConfiguration {
  top_k: number;
  model_name: string;
  exclude_knowledge_base: boolean;
}

export interface ChatMutationRequest {
@@ -96,7 +96,7 @@ describe("ChatBodyController", () => {
chatHistory: [],
dataSourceId: undefined,
dataSourcesStatus: undefined,
queryConfiguration: { top_k: 5, model_name: "" },
queryConfiguration: {
top_k: 5,
model_name: "",
exclude_knowledge_base: false,
},
setQueryConfiguration: () => null,
setCurrentQuestion: () => null,
dataSourceSize: null,
@@ -81,7 +81,7 @@ const ChatMessage = ({
</div>
<Flex vertical gap={12}>
<SourceNodes data={data} />
<Typography.Text style={{ fontSize: 16 }}>
<Typography.Text style={{ fontSize: 16, whiteSpace: "pre-wrap" }}>
{data.rag_message.assistant}
</Typography.Text>
<Evaluations evaluations={data.evaluations} />
@@ -42,7 +42,9 @@ import Images from "src/components/images/Images.ts";

const UserQuestion = (props: { question: string }) => (
<Flex justify="end" gap={8}>
<Typography.Text style={{ fontSize: 16 }}>{props.question}</Typography.Text>
<Typography.Text style={{ fontSize: 16, whiteSpace: "pre-wrap" }}>
{props.question}
</Typography.Text>
<Images.User
style={{
padding: 4,
14 changes: 12 additions & 2 deletions ui/src/pages/RagChatTab/ChatOutput/Sources/SourceCard.tsx
@@ -44,6 +44,7 @@ import { SourceNode } from "src/api/chatApi.ts";
import { useGetDocumentSummary } from "src/api/summaryApi.ts";
import DocumentationIcon from "src/cuix/icons/DocumentationIcon";
import Icon from "@ant-design/icons";
import { cdlGray600 } from "src/cuix/variables.ts";

export const SourceCard = ({ source }: { source: SourceNode }) => {
const { dataSourceId } = useContext(RagChatContext);
@@ -71,7 +72,14 @@ export const SourceCard = ({ source }: { source: SourceNode }) => {
onOpenChange={handleGetChunkContents}
content={
<Card
title={source.source_file_name}
title={
<Flex justify="space-between">
{source.source_file_name}
<Typography.Text style={{ color: cdlGray600 }}>
Score: {source.score}
</Typography.Text>
</Flex>
}
bordered={false}
style={{ width: 600, height: 300, overflowY: "auto" }}
>
@@ -105,7 +113,9 @@ export const SourceCard = ({ source }: { source: SourceNode }) => {
<Typography.Title level={5} style={{ marginTop: 10 }}>
Extracted reference content
</Typography.Title>
<Typography.Paragraph style={{ textAlign: "left" }}>
<Typography.Paragraph
style={{ textAlign: "left", whiteSpace: "pre-wrap" }}
>
{chunkContents.data}
</Typography.Paragraph>
</Flex>
24 changes: 22 additions & 2 deletions ui/src/pages/RagChatTab/FooterComponents/RagChatQueryInput.tsx
@@ -36,9 +36,9 @@
* DATA.
******************************************************************************/

import { Button, Flex, Input } from "antd";
import { Button, Flex, Input, Switch, Tooltip } from "antd";
import SuggestedQuestionsFooter from "pages/RagChatTab/FooterComponents/SuggestedQuestionsFooter.tsx";
import { SendOutlined } from "@ant-design/icons";
import { DatabaseFilled, SendOutlined } from "@ant-design/icons";
import { useContext, useState } from "react";
import { RagChatContext } from "pages/RagChatTab/State/RagChatContext.tsx";
import messageQueue from "src/utils/messageQueue.ts";
@@ -47,6 +47,8 @@ import { useSuggestQuestions } from "src/api/ragQueryApi.ts";
import { useParams } from "@tanstack/react-router";
import { cdlBlue600 } from "src/cuix/variables.ts";

import type { SwitchChangeEventHandler } from "antd/lib/switch";

const RagChatQueryInput = () => {
const {
dataSourceId,
@@ -55,6 +57,7 @@ const RagChatQueryInput = () => {
chatHistory,
dataSourceSize,
dataSourcesStatus,
setQueryConfiguration,
} = useContext(RagChatContext);

const [userInput, setUserInput] = useState("");
@@ -92,6 +95,13 @@
}
};

const handleExcludeKnowledgeBase: SwitchChangeEventHandler = (checked) => {
setQueryConfiguration((prev) => ({
...prev,
exclude_knowledge_base: !checked,
}));
};

return (
<div>
<Flex vertical align="center" gap={10}>
@@ -119,6 +129,16 @@
handleChat(userInput);
}
}}
suffix={
<Tooltip title="Whether to query against the knowledge base. Disabling will query only against the model's training data.">
<Switch
checkedChildren={<DatabaseFilled />}
value={!queryConfiguration.exclude_knowledge_base}
onChange={handleExcludeKnowledgeBase}
style={{ display: "none" }} // note: disabled for now, until UX is ready
/>
</Tooltip>
}
disabled={!dataSourceSize || chatMutation.isPending}
/>
<Button
1 change: 1 addition & 0 deletions ui/src/pages/RagChatTab/State/RagChatContext.tsx
@@ -58,6 +58,7 @@ export interface RagChatContextType {
export const defaultQueryConfig = {
  top_k: 5,
  model_name: "",
  exclude_knowledge_base: false,
};

export const RagChatContext = createContext<RagChatContextType>({
