Skip to content

Commit

Permalink
Handling new o1 models (params, system messages, json mode)
Browse files Browse the repository at this point in the history
  • Loading branch information
peterbanda committed Jan 3, 2025
1 parent 1249852 commit 81cf970
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,25 @@ private[service] trait OpenAIChatCompletionServiceImpl

trait ChatCompletionBodyMaker {

private val o1Models = Set(
private val noSystemMessageModels = Set(
ModelId.o1_preview,
ModelId.o1_preview_2024_09_12,
ModelId.o1_mini,
ModelId.o1_mini_2024_09_12
)

private val o1PreviewModels = Set(
ModelId.o1_preview,
ModelId.o1_preview_2024_09_12,
ModelId.o1_mini,
ModelId.o1_mini_2024_09_12
)

private val o1Models = Set(
ModelId.o1,
ModelId.o1_2024_12_17
)

protected def createBodyParamsForChatCompletion(
messagesAux: Seq[BaseMessage],
settings: CreateChatCompletionSettings,
Expand All @@ -63,7 +75,7 @@ trait ChatCompletionBodyMaker {

// O1 models needs some special treatment... revisit this later
val messagesFinal =
if (o1Models.contains(settings.model))
if (noSystemMessageModels.contains(settings.model))
MessageConversions.systemToUserMessages(messagesAux)
else
messagesAux
Expand All @@ -72,8 +84,10 @@ trait ChatCompletionBodyMaker {

// O1 models needs some special treatment... revisit this later
val settingsFinal =
if (o1Models.contains(settings.model))
ChatCompletionSettingsConversions.o1Specific(settings)
if (o1PreviewModels.contains(settings.model))
ChatCompletionSettingsConversions.o1Preview(settings)
else if (o1Models.contains(settings.model))
ChatCompletionSettingsConversions.o1(settings)
else
settings

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,10 @@ object OpenAIChatCompletionExtra {
}

private val defaultJsonSchemaModels = Seq(
"openai-" + ModelId.gpt_4o_2024_08_06,
ModelId.gpt_4o_2024_08_06
)
ModelId.gpt_4o_2024_08_06,
ModelId.o1,
ModelId.o1_2024_12_17
).flatMap(id => Seq(id, "openai-" + id))

def handleOutputJsonSchema(
messages: Seq[BaseMessage],
Expand All @@ -144,7 +145,7 @@ object OpenAIChatCompletionExtra {

val (settingsFinal, addJsonToPrompt) =
if (jsonSchemaModels.contains(settings.model)) {
logger.debug(
logger.info(
s"Using OpenAI json schema mode for ${taskNameForLogging} and the model '${settings.model}' - name: ${jsonSchemaDef.name}, strict: ${jsonSchemaDef.strict}, structure:\n${jsonSchemaString}"
)

Expand All @@ -157,7 +158,7 @@ object OpenAIChatCompletionExtra {
} else {
// otherwise we failover to json object format and pass json schema to the user prompt

logger.debug(
logger.info(
s"Using JSON object mode for ${taskNameForLogging} and the model '${settings.model}'. Also passing a JSON schema as part of a user prompt."
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ object ChatCompletionSettingsConversions {
} else acc
}

private val o1Conversions = Seq(
private val o1BaseConversions = Seq(
// max tokens
FieldConversionDef(
_.max_tokens.isDefined,
Expand Down Expand Up @@ -79,18 +79,23 @@ object ChatCompletionSettingsConversions {
"O1 models don't support frequency penalty values other than the default of 0, converting to 0."
),
warning = true
),
// frequency_penalty
FieldConversionDef(
settings =>
settings.response_format_type.isDefined && settings.response_format_type.get != ChatCompletionResponseFormatType.text,
_.copy(response_format_type = None),
Some(
"O1 models don't support json object/schema response format, converting to None."
),
warning = true
)
)

val o1Specific: SettingsConversion = generic(o1Conversions)
private val o1PreviewConversions =
o1BaseConversions :+
// response format type
FieldConversionDef(
settings =>
settings.response_format_type.isDefined && settings.response_format_type.get != ChatCompletionResponseFormatType.text,
_.copy(response_format_type = None),
Some(
"O1 models don't support json object/schema response format, converting to None."
),
warning = true
)

val o1: SettingsConversion = generic(o1BaseConversions)

val o1Preview: SettingsConversion = generic(o1PreviewConversions)
}

0 comments on commit 81cf970

Please sign in to comment.