-
Notifications
You must be signed in to change notification settings - Fork 144
/
query_rewriter.py
92 lines (87 loc) · 3.92 KB
/
query_rewriter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import json
from openai.types.chat import (
ChatCompletion,
ChatCompletionToolParam,
)
def build_search_function() -> list[ChatCompletionToolParam]:
    """Return the tool definitions offered to the chat model.

    The list contains exactly one function tool, ``search_database``. Its JSON
    Schema parameters are a required ``search_query`` string plus two optional
    filter objects, ``price_filter`` (numeric value) and ``brand_filter``
    (string value). Each filter object carries a ``comparison_operator`` and
    a ``value``.
    """

    def filter_schema(about, operator_help, value_type, value_help):
        # Both filters share one shape; only the descriptions and value type differ.
        return {
            "type": "object",
            "description": about,
            "properties": {
                "comparison_operator": {
                    "type": "string",
                    "description": operator_help,
                },
                "value": {
                    "type": value_type,
                    "description": value_help,
                },
            },
        }

    query_property = {
        "type": "string",
        "description": "Query string to use for full text search, e.g. 'red shoes'",
    }
    price_property = filter_schema(
        "Filter search results based on price of the product",
        "Operator to compare the column value, either '>', '<', '>=', '<=', '='",  # noqa
        "number",
        "Value to compare against, e.g. 30",
    )
    brand_property = filter_schema(
        "Filter search results based on brand of the product",
        "Operator to compare the column value, either '=' or '!='",
        "string",
        "Value to compare against, e.g. AirStrider",
    )

    parameters = {
        "type": "object",
        "properties": {
            "search_query": query_property,
            "price_filter": price_property,
            "brand_filter": brand_property,
        },
        "required": ["search_query"],
    }
    search_tool = {
        "type": "function",
        "function": {
            "name": "search_database",
            "description": "Search PostgreSQL database for relevant products based on user query",
            "parameters": parameters,
        },
    }
    return [search_tool]
# Comparison operators accepted for each filter. These mirror the operator
# lists in the search_database tool schema descriptions; any other operator
# the model produces is dropped.
_PRICE_OPERATORS = (">", "<", ">=", "<=", "=")
_BRAND_OPERATORS = ("=", "!=")


def _build_filter(column, filter_arg, allowed_operators):
    """Turn one ``*_filter`` tool argument into a filter dict, or return None.

    The model-produced object is accepted only if it is a non-empty dict whose
    ``comparison_operator`` is in *allowed_operators* and whose ``value`` is
    not null. Anything else is dropped instead of raising ``KeyError`` or
    passing an unexpected operator through.
    NOTE(review): confirm whether the consumer of these filters interpolates
    the operator into SQL; if so, this allowlist is also a safety check.
    """
    if not isinstance(filter_arg, dict) or not filter_arg:
        return None
    operator = filter_arg.get("comparison_operator")
    value = filter_arg.get("value")
    # The schema does not mark these keys as required, so the model may omit them.
    if operator not in allowed_operators or value is None:
        return None
    return {"column": column, "comparison_operator": operator, "value": value}


def extract_search_arguments(original_user_query: str, chat_completion: "ChatCompletion"):
    """Extract the search query and filters from a query-rewriting completion.

    Only ``chat_completion.choices[0].message`` is inspected.

    Args:
        original_user_query: The user's text. It is used as the search query
            when the model calls ``search_database`` without a usable
            (missing, null or empty) ``search_query``.
        chat_completion: The chat completion to inspect.

    Returns:
        A ``(search_query, filters)`` tuple.
        - ``search_query`` comes from the last ``search_database`` tool call.
          If the message has no tool calls, it is the stripped message
          content, or None when there is no content either.
        - ``filters`` is a list of ``{"column", "comparison_operator",
          "value"}`` dicts. It holds the valid price and brand filters,
          accumulated across all ``search_database`` calls (price before
          brand within each call).

    Raises:
        json.JSONDecodeError: If a tool call's ``arguments`` is not valid JSON.
    """
    response_message = chat_completion.choices[0].message
    search_query = None
    filters = []
    if response_message.tool_calls:
        for tool in response_message.tool_calls:
            if tool.type != "function" or tool.function.name != "search_database":
                continue
            arg = json.loads(tool.function.arguments)
            # Even though it's required, the model sometimes omits search_query
            # or sends null/""; fall back to the user's original text.
            search_query = arg.get("search_query") or original_user_query
            for column, key, allowed in (
                ("price", "price_filter", _PRICE_OPERATORS),
                ("brand", "brand_filter", _BRAND_OPERATORS),
            ):
                built = _build_filter(column, arg.get(key), allowed)
                if built is not None:
                    filters.append(built)
    elif query_text := response_message.content:
        # No tool call: the model answered with the rewritten query as plain text.
        search_query = query_text.strip()
    return search_query, filters