Commit

Add new features to support Claude 2.1 document Q&A.
yym68686 committed Nov 30, 2023
1 parent 906bba2 commit 413926e
Showing 2 changed files with 57 additions and 27 deletions.
72 changes: 48 additions & 24 deletions bot.py
@@ -42,6 +42,8 @@ async def command_bot(update, context, language=None, prompt=translator_prompt,
prompt = prompt.format(language)
message = prompt + message
if message:
if "claude" in config.GPT_ENGINE and config.ClaudeAPI:
robot = config.claudeBot
await context.bot.send_chat_action(chat_id=update.message.chat_id, action=ChatAction.TYPING)
await getChatGPT(update, context, title, robot, message, config.SEARCH_USE_GPT, has_command)
else:
@@ -337,8 +339,8 @@ async def button_press(update, context):
if config.API and "gpt-" in data:
config.ChatGPTbot = GPT(api_key=f"{config.API}", engine=config.GPT_ENGINE, system_prompt=config.systemprompt, temperature=config.temperature)
config.ChatGPTbot.reset(convo_id=str(update.effective_chat.id), system_prompt=config.systemprompt)
if config.ClaudeAPI and "claude" in data and not config.API:
config.ChatGPTbot = claudebot(api_key=f"{config.ClaudeAPI}", engine=config.GPT_ENGINE, system_prompt=config.systemprompt, temperature=config.temperature)
if config.ClaudeAPI and "claude" in data:
config.claudeBot = claudebot(api_key=f"{config.ClaudeAPI}", engine=config.GPT_ENGINE, system_prompt=config.systemprompt, temperature=config.temperature)
try:
info_message = (
f"`Hi, {update.effective_user.username}!`\n\n"
@@ -461,8 +463,8 @@ async def button_press(update, context):
if config.API:
config.ChatGPTbot = GPT(api_key=f"{config.API}", engine=config.GPT_ENGINE, system_prompt=config.systemprompt, temperature=config.temperature)
config.ChatGPTbot.reset(convo_id=str(update.effective_chat.id), system_prompt=config.systemprompt)
if config.ClaudeAPI and not config.API:
config.ChatGPTbot = claudebot(api_key=f"{config.ClaudeAPI}", engine=config.GPT_ENGINE, system_prompt=config.systemprompt, temperature=config.temperature)
if config.ClaudeAPI:
config.claudeBot = claudebot(api_key=f"{config.ClaudeAPI}", engine=config.GPT_ENGINE, system_prompt=config.systemprompt, temperature=config.temperature)

info_message = (
f"`Hi, {update.effective_user.username}!`\n\n"
@@ -513,36 +515,58 @@ async def info(update, context):
messageid = message.message_id
await context.bot.delete_message(chat_id=update.effective_chat.id, message_id=update.message.message_id)

from utils.agent import pdfQA, getmd5, persist_emdedding_pdf
from utils.agent import pdfQA, getmd5, persist_emdedding_pdf, get_doc_from_url
from pdfminer.high_level import extract_text
@decorators.Authorization
async def handle_pdf(update, context):
# Get the file received from the user
pdf_file = update.message.document
# Get the file's URL
file_name = pdf_file.file_name
docpath = os.getcwd() + "/" + file_name
persist_db_path = getmd5(docpath)
match_embedding = os.path.exists(persist_db_path)
# file_name = pdf_file.file_name
# docpath = os.getcwd() + "/" + file_name
file_id = pdf_file.file_id
new_file = await context.bot.get_file(file_id)
file_url = new_file.file_path

question = update.message.caption
if question is None:
if not match_embedding:
persist_emdedding_pdf(file_url, persist_db_path)
filename = get_doc_from_url(file_url)
docpath = os.getcwd() + "/" + filename
if config.ClaudeAPI:
text = extract_text(docpath)
prompt = (
"Here is the document, inside <document></document> XML tags:"
"<document>"
"{}"
"</document>"
)
# print(prompt.format(text))
config.claudeBot.add_to_conversation(prompt.format(text), "Human", str(update.effective_chat.id))
message = (
f"已成功解析文档!\n\n"
f"请输入 `要问的问题`\n\n"
f"例如已经上传某文档 ,问题是 蘑菇怎么分类?\n\n"
f"先左滑文档进入回复模式,并在聊天框里面输入 `蘑菇怎么分类?`\n\n"
f"文档上传成功!\n\n"
)
await context.bot.send_message(chat_id=update.effective_chat.id, text=escape(message), parse_mode='MarkdownV2', disable_web_page_preview=True)
return

result = await pdfQA(file_url, docpath, question)
print(result)
await context.bot.send_message(chat_id=update.message.chat_id, text=escape(result), parse_mode='MarkdownV2', disable_web_page_preview=True)
os.remove(docpath)
await context.bot.send_message(chat_id=update.message.chat_id, text=escape(message), parse_mode='MarkdownV2', disable_web_page_preview=True)

# persist_db_path = getmd5(docpath)
# match_embedding = os.path.exists(persist_db_path)
# file_id = pdf_file.file_id
# new_file = await context.bot.get_file(file_id)
# file_url = new_file.file_path

# question = update.message.caption
# if question is None:
# if not match_embedding:
# persist_emdedding_pdf(file_url, persist_db_path)
# message = (
# f"已成功解析文档!\n\n"
# f"请输入 `要问的问题`\n\n"
# f"例如已经上传某文档 ,问题是 蘑菇怎么分类?\n\n"
# f"先左滑文档进入回复模式,并在聊天框里面输入 `蘑菇怎么分类?`\n\n"
# )
# await context.bot.send_message(chat_id=update.effective_chat.id, text=escape(message), parse_mode='MarkdownV2', disable_web_page_preview=True)
# return

# result = await pdfQA(file_url, docpath, question)
# print(result)
# await context.bot.send_message(chat_id=update.message.chat_id, text=escape(result), parse_mode='MarkdownV2', disable_web_page_preview=True)

@decorators.Authorization
async def qa(update, context):
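For context, the new handle_pdf branch implements a simple pattern: download the file, extract its text with pdfminer, wrap it in <document></document> XML tags, and hand it to Claude 2.1 ahead of the user's question. Below is a minimal, self-contained sketch of that pattern; the ask_claude_about_pdf helper, its parameters, and the direct call to the https://api.anthropic.com/v1/complete endpoint (the same URL the bot's claudebot class targets) are illustrative stand-ins, not code from this repository.

import requests
from pdfminer.high_level import extract_text

def ask_claude_about_pdf(pdf_path: str, question: str, api_key: str) -> str:
    # Pull the raw text out of the PDF, as handle_pdf does with extract_text().
    text = extract_text(pdf_path)
    # Wrap the document in XML tags so the model can tell document from question.
    prompt = (
        "\n\nHuman: Here is the document, inside <document></document> XML tags:\n"
        f"<document>\n{text}\n</document>\n\n"
        f"{question}"
        "\n\nAssistant:"
    )
    response = requests.post(
        "https://api.anthropic.com/v1/complete",
        headers={
            "x-api-key": api_key,
            "anthropic-version": "2023-06-01",
            "content-type": "application/json",
        },
        json={"model": "claude-2.1", "prompt": prompt, "max_tokens_to_sample": 1024},
        timeout=120,
    )
    response.raise_for_status()
    return response.json()["completion"]

# Example (hypothetical file and key):
# answer = ask_claude_about_pdf("paper.pdf", "How are mushrooms classified?", "sk-ant-...")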
12 changes: 9 additions & 3 deletions utils/chatgpt2api.py
@@ -50,7 +50,7 @@ def get_filtered_keys_from_object(obj: object, *keys: str) -> Set[str]:
"gpt-4-32k-0613",
"gpt-4-1106-preview",
"claude-2-web",
"claude-2",
"claude-2.1",
]

class claudeConversation(dict):
@@ -68,6 +68,7 @@ def __init__(
top_p: float = 0.7,
chat_url: str = "https://api.anthropic.com/v1/complete",
timeout: float = None,
system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
**kwargs,
):
self.api_key: str = api_key
@@ -78,17 +79,21 @@ def __init__(
self.timeout = timeout
self.session = requests.Session()
self.conversation = claudeConversation()
self.system_prompt = system_prompt

def add_to_conversation(
self,
message: str,
role: str,
convo_id: str = "default",

pass_history: bool = True,
) -> None:
"""
Add a message to the conversation
"""

if convo_id not in self.conversation or pass_history == False:
self.reset(convo_id=convo_id)
self.conversation[convo_id].append({"role": role, "content": message})

def reset(self, convo_id: str = "default") -> None:
@@ -147,6 +152,7 @@ def ask_stream(
model_max_tokens: int = 4096,
**kwargs,
):
pass_history = True
if convo_id not in self.conversation or pass_history == False:
self.reset(convo_id=convo_id)
self.add_to_conversation(prompt, role, convo_id=convo_id)
@@ -191,7 +197,7 @@ def ask_stream(
full_response += content
yield content
self.add_to_conversation(full_response, response_role, convo_id=convo_id)
print(repr(self.conversation.Conversation(convo_id)))
# print(repr(self.conversation.Conversation(convo_id)))
# print("total tokens:", self.get_token_count(convo_id))


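The utils/chatgpt2api.py changes give the Claude client a system_prompt and a pass_history switch on add_to_conversation, alongside per-chat history keyed by convo_id. A stripped-down sketch of that bookkeeping, using only the standard library (ConversationStore is an illustrative name, not the repository's claudeConversation class):

class ConversationStore:
    # Per-chat message history keyed by convo_id, in the spirit of claudeConversation.
    def __init__(self, system_prompt: str = "You are a helpful assistant."):
        self.system_prompt = system_prompt
        self.conversation = {}

    def reset(self, convo_id: str = "default") -> None:
        # Start the conversation over, keeping only the system prompt.
        self.conversation[convo_id] = [{"role": "system", "content": self.system_prompt}]

    def add_to_conversation(self, message: str, role: str,
                            convo_id: str = "default", pass_history: bool = True) -> None:
        # pass_history=False wipes the stored history before the new turn,
        # mirroring the flag added to add_to_conversation in this commit.
        if convo_id not in self.conversation or not pass_history:
            self.reset(convo_id=convo_id)
        self.conversation[convo_id].append({"role": role, "content": message})

# Usage: seed the chat with a document turn, then ask questions under the same convo_id.
store = ConversationStore()
store.add_to_conversation("Here is the document ...", "Human", convo_id="12345")
store.add_to_conversation("How are mushrooms classified?", "Human", convo_id="12345")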
