Merge pull request #9 from cloudera/max-tokens-fix
DSE-41570 - Fix MAX_TOKENS error for AI Inference models
cl-gavan authored Jan 8, 2025
2 parents f9926c0 + 94b9fae commit 8b8dfbd
Showing 1 changed file with 4 additions and 4 deletions.
@@ -95,7 +95,7 @@ def _stream(
# OpenAI Chat completions API
request_messages = self.BuildChatCompletionMessage(messages)

request = {"messages": request_messages, "model": self.model, "temperature": 1, "max_tokens": MAX_TOKENS, "stream": True}
request = {"messages": request_messages, "model": self.model, "temperature": 1, "max_tokens": self.MAX_TOKENS, "stream": True}
logging.info(f"request: {request}")
try:
r = requests.post(
@@ -124,7 +124,7 @@ def _stream(
prompt = self.BuildCompletionPrompt(messages)
req_data = '{"prompt": "' + prompt.encode('unicode_escape').decode("utf-8")

-my_req_data = req_data + '","model":"' + self.model + '","temperature":1,"max_tokens":' + str(MAX_TOKENS) + ',"stream":true}'
+my_req_data = req_data + '","model":"' + self.model + '","temperature":1,"max_tokens":' + str(self.MAX_TOKENS) + ',"stream":true}'
logging.info('req:')
logging.info(my_req_data)

@@ -172,7 +172,7 @@ def _call(
if inference_endpoint.find("chat/completions") != -1:
# OpenAI Chat completions API
request_messages = self.BuildChatCompletionMessage(messages)
request = {"messages": request_messages, "model": self.model, "temperature": 1, "max_tokens": 1024, "stream": False}
request = {"messages": request_messages, "model": self.model, "temperature": 1, "max_tokens": self.MAX_TOKENS, "stream": False}
logging.info(json.dumps(request))
try:
r = requests.post(inference_endpoint,
@@ -192,7 +192,7 @@ def _call(
# OpenAI Completions API
prompt = self.BuildCompletionPrompt(messages)
logging.info(f"prompt: {prompt}")
request = {"prompt": prompt, "model": self.model, "temperature": 1, "max_tokens": 1024, "stream": False}
request = {"prompt": prompt, "model": self.model, "temperature": 1, "max_tokens": self.MAX_TOKENS, "stream": False}
logging.info(json.dumps(request))

try:
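
The change is the same in all four hunks: the bare name MAX_TOKENS (and the hardcoded 1024 in _call) is replaced with the attribute lookup self.MAX_TOKENS, so both the streaming and non-streaming code paths read one shared limit defined on the model wrapper. Below is a minimal sketch of why the attribute access matters, assuming MAX_TOKENS is defined as a class attribute rather than a module-level constant; the class name InferenceModel, the build_request helper, and the example model string are hypothetical and not taken from the repository.

# Minimal sketch, not the repository's actual code.
# Assumption: MAX_TOKENS lives on the class, so it must be reached through `self`
# (or the class); a bare MAX_TOKENS would raise NameError unless a module-level
# constant with that name also existed.

class InferenceModel:
    MAX_TOKENS = 1024  # class attribute; hypothetical value

    def __init__(self, model: str):
        self.model = model

    def build_request(self, request_messages: list, stream: bool) -> dict:
        # self.MAX_TOKENS resolves through the instance to the class attribute,
        # mirroring the fixed lines in _stream and _call.
        return {
            "messages": request_messages,
            "model": self.model,
            "temperature": 1,
            "max_tokens": self.MAX_TOKENS,
            "stream": stream,
        }


if __name__ == "__main__":
    llm = InferenceModel(model="example-model")  # placeholder model name
    print(llm.build_request([{"role": "user", "content": "hi"}], stream=False))

Centralizing the limit this way also removes the duplicated 1024 literal in _call, so changing the token budget only requires touching the class attribute.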