DSE-40313 - Increase max tokens for CAII models.
cl-gavan committed Dec 18, 2024
1 parent 30f8da7 · commit b05e0b9
Showing 1 changed file with 4 additions and 2 deletions.
@@ -31,6 +31,8 @@ class ClouderaAIInferenceLanguageModelProvider(BaseProvider, SimpleChatModel, LL
     ai_inference_models, models = getCopilotModels(copilot_config_dir, model_type="inference")
     jwt_path = '/tmp/jwt'
 
+    MAX_TOKENS = 2048
+
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.model = kwargs.get("model_id")
@@ -93,7 +95,7 @@ def _stream(
         # OpenAI Chat completions API
         request_messages = self.BuildChatCompletionMessage(messages)
 
-        request = {"messages": request_messages, "model": self.model, "temperature": 1, "max_tokens": 256, "stream": True}
+        request = {"messages": request_messages, "model": self.model, "temperature": 1, "max_tokens": MAX_TOKENS, "stream": True}
         logging.info(f"request: {request}")
         try:
             r = requests.post(
@@ -122,7 +124,7 @@ def _stream(
         prompt = self.BuildCompletionPrompt(messages)
         req_data = '{"prompt": "' + prompt.encode('unicode_escape').decode("utf-8")
 
-        my_req_data = req_data + '","model":"' + self.model + '","temperature":1,"max_tokens":256,"stream":true}'
+        my_req_data = req_data + '","model":"' + self.model + '","temperature":1,"max_tokens":' + str(MAX_TOKENS) + ',"stream":true}'
         logging.info('req:')
         logging.info(my_req_data)
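For context, both request paths now take their limit from the single MAX_TOKENS class constant (2048) rather than a hard-coded 256. The sketch below shows the shape of the resulting chat-completions request after this change; it is an illustration only, and the function name stream_chat_completion, the endpoint_url and jwt_token parameters, and the error handling are assumptions, not code from this commit.

import json
import logging

import requests

MAX_TOKENS = 2048  # the limit this commit raises from the previous hard-coded 256


def stream_chat_completion(endpoint_url, jwt_token, model, request_messages):
    # Same payload shape as the chat-completions hunk above, with the
    # raised token limit applied.
    request = {
        "messages": request_messages,
        "model": model,
        "temperature": 1,
        "max_tokens": MAX_TOKENS,
        "stream": True,
    }
    logging.info(f"request: {request}")
    # endpoint_url and jwt_token are hypothetical stand-ins for the
    # CAII inference endpoint and the token read from jwt_path.
    r = requests.post(
        endpoint_url,
        data=json.dumps(request),
        headers={
            "Authorization": f"Bearer {jwt_token}",
            "Content-Type": "application/json",
        },
        stream=True,
    )
    r.raise_for_status()
    for line in r.iter_lines():
        if line:
            yield line.decode("utf-8")

Serializing the completions payload with json.dumps in the same way would also sidestep the quoting pitfalls of the manual string concatenation that the third hunk keeps.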
