diff --git a/docs/public/schemas/llms.json b/docs/public/schemas/llms.json index bc004791db..5a3147842c 100644 --- a/docs/public/schemas/llms.json +++ b/docs/public/schemas/llms.json @@ -47,6 +47,15 @@ "type": "boolean", "description": "Indicates if tools are supported" }, + "listModels": { + "type": "boolean", + "default": true, + "description": "Indicates if listing models is supported" + }, + "pullModel": { + "type": "boolean", + "description": "Indicates if pulling models is supported" + }, "openaiCompatibility": { "type": "string", "description": "Uses OpenAI API compatibility layer documentation URL" diff --git a/packages/cli/src/info.ts b/packages/cli/src/info.ts index 7054076593..a197cd8806 100644 --- a/packages/cli/src/info.ts +++ b/packages/cli/src/info.ts @@ -18,6 +18,7 @@ import { ModelConnectionInfo, resolveModelConnectionInfo, } from "../../core/src/models" +import { deleteEmptyValues } from "../../core/src/util" import { CORE_VERSION } from "../../core/src/version" import { YAMLStringify } from "../../core/src/yaml" import { buildProject } from "./build" @@ -64,11 +65,19 @@ export async function envInfo( if (models) { const lm = await resolveLanguageModel(modelProvider.id) if (lm.listModels) { - const ms = await lm.listModels(conn) + const ms = await lm.listModels(conn, {}) if (ms?.length) conn.models = ms } } - res.providers.push(conn) + res.providers.push( + deleteEmptyValues({ + provider: conn.provider, + source: conn.source, + base: conn.base, + type: conn.type, + models: conn.models, + }) + ) } } catch (e) { if (error) diff --git a/packages/core/src/aici.ts b/packages/core/src/aici.ts index cad0e9c1f4..bbe4878535 100644 --- a/packages/core/src/aici.ts +++ b/packages/core/src/aici.ts @@ -15,6 +15,8 @@ import { ChatCompletionContentPartText, ChatCompletionResponse, } from "./chattypes" +import { TraceOptions } from "./trace" +import { CancellationOptions } from "./cancellation" /** * Renders an AICI node into a string representation. @@ -404,7 +406,10 @@ const AICIChatCompletion: ChatCompletionHandler = async ( * @param cfg - The configuration for the language model. * @returns A list of language model information. */ -async function listModels(cfg: LanguageModelConfiguration) { +async function listModels( + cfg: LanguageModelConfiguration, + options?: TraceOptions & CancellationOptions +) { const { token, base, version } = cfg const url = `${base}/proxy/info` const fetch = await createFetch() diff --git a/packages/core/src/chat.ts b/packages/core/src/chat.ts index 203aa65c15..0f69d938ba 100644 --- a/packages/core/src/chat.ts +++ b/packages/core/src/chat.ts @@ -127,7 +127,8 @@ export interface LanguageModelInfo { } export type ListModelsFunction = ( - cfg: LanguageModelConfiguration + cfg: LanguageModelConfiguration, + options: TraceOptions & CancellationOptions ) => Promise export type PullModelFunction = ( diff --git a/packages/core/src/constants.ts b/packages/core/src/constants.ts index feef81cf3e..0fd0e26e6c 100644 --- a/packages/core/src/constants.ts +++ b/packages/core/src/constants.ts @@ -108,10 +108,13 @@ export const MARKDOWN_PROMPT_FENCE = "`````" export const OPENAI_API_BASE = "https://api.openai.com/v1" export const OLLAMA_DEFAUT_PORT = 11434 -export const OLLAMA_API_BASE = "http://localhost:11434/v1" -export const LLAMAFILE_API_BASE = "http://localhost:8080/v1" -export const LOCALAI_API_BASE = "http://localhost:8080/v1" -export const LITELLM_API_BASE = "http://localhost:4000" +export const OLLAMA_API_BASE = "http://127.0.0.1:11434/v1" +export const LLAMAFILE_API_BASE = "http://127.0.0.1:8080/v1" +export const LOCALAI_API_BASE = "http://127.0.0.1:8080/v1" +export const LITELLM_API_BASE = "http://127.0.0.1:4000" +export const LMSTUDIO_API_BASE = "http://127.0.0.1:1234/v1" +export const JAN_API_BASE = "http://127.0.0.1:1337/v1" + export const ANTHROPIC_API_BASE = "https://api.anthropic.com" export const HUGGINGFACE_API_BASE = "https://api-inference.huggingface.co/v1" export const GOOGLE_API_BASE = @@ -119,8 +122,6 @@ export const GOOGLE_API_BASE = export const ALIBABA_BASE = "https://dashscope-intl.aliyuncs.com/compatible-mode/v1" export const MISTRAL_API_BASE = "https://api.mistral.ai/v1" -export const LMSTUDIO_API_BASE = "http://localhost:1234/v1" -export const JAN_API_BASE = "http://localhost:1337/v1" export const PROMPTFOO_CACHE_PATH = ".genaiscript/cache/tests" export const PROMPTFOO_CONFIG_DIR = ".genaiscript/config/tests" @@ -190,6 +191,8 @@ export const MODEL_PROVIDERS = Object.freeze< topP?: boolean prediction?: boolean bearerToken?: boolean + listModels?: boolean + pullModel?: boolean aliases?: Record }[] >(CONFIGURATION_DATA.providers) diff --git a/packages/core/src/llms.json b/packages/core/src/llms.json index 46b17b071d..14629c89a5 100644 --- a/packages/core/src/llms.json +++ b/packages/core/src/llms.json @@ -17,11 +17,13 @@ { "id": "azure", "detail": "Azure OpenAI deployment", + "listModels": false, "bearerToken": false }, { "id": "azure_serverless", "detail": "Azure AI OpenAI (serverless deployments)", + "listModels": false, "bearerToken": false, "aliases": { "large": "gpt-4o", @@ -34,6 +36,7 @@ { "id": "azure_serverless_models", "detail": "Azure AI Models (serverless deployments, not OpenAI)", + "listModels": false, "prediction": false, "bearerToken": true }, @@ -43,6 +46,7 @@ "logprobs": false, "topLogprobs": false, "prediction": false, + "listModels": false, "aliases": { "large": "claude-3-5-sonnet-latest", "small": "claude-3-5-haiku-latest", @@ -70,6 +74,7 @@ "openaiCompatibility": "https://ai.google.dev/gemini-api/docs/openai", "prediction": false, "bearerToken": true, + "listModels": false, "aliases": { "large": "gemini-1.5-flash-latest", "small": "gemini-1.5-flash-latest", @@ -84,6 +89,7 @@ "id": "huggingface", "detail": "Hugging Face models", "prediction": false, + "listModels": false, "aliases": { "large": "Qwen/Qwen2.5-72B-Instruct", "small": "Qwen/Qwen2.5-Coder-32B-Instruct", @@ -110,6 +116,7 @@ "openaiCompatibility": "https://www.alibabacloud.com/help/en/model-studio/developer-reference/compatibility-of-openai-with-dashscope", "tools": false, "prediction": false, + "listModels": false, "bearerToken": true, "aliases": { "large": "qwen-max", @@ -125,6 +132,7 @@ "topLogprobs": false, "limitations": "Smaller context windows, and rate limiting", "prediction": false, + "listModels": false, "bearerToken": true, "aliases": { "large": "gpt-4o", @@ -145,6 +153,7 @@ "detail": "Ollama local model", "logitBias": false, "openaiCompatibility": "https://github.com/ollama/ollama/blob/main/docs/openai.md", + "pullModel": true, "prediction": false }, { @@ -155,7 +164,9 @@ { "id": "jan", "detail": "Jan local server", - "prediction": false + "prediction": false, + "listModels": true, + "top_p": false }, { "id": "llamafile", diff --git a/packages/core/src/lm.ts b/packages/core/src/lm.ts index 78546a62a6..d7d1b0e57d 100644 --- a/packages/core/src/lm.ts +++ b/packages/core/src/lm.ts @@ -7,12 +7,14 @@ import { MODEL_PROVIDER_ANTHROPIC, MODEL_PROVIDER_ANTHROPIC_BEDROCK, MODEL_PROVIDER_CLIENT, + MODEL_PROVIDER_JAN, MODEL_PROVIDER_OLLAMA, MODEL_PROVIDER_TRANSFORMERS, + MODEL_PROVIDERS, } from "./constants" import { host } from "./host" import { OllamaModel } from "./ollama" -import { OpenAIModel } from "./openai" +import { LocalOpenAICompatibleModel } from "./openai" import { TransformersModel } from "./transformers" export function resolveLanguageModel(provider: string): LanguageModel { @@ -27,5 +29,10 @@ export function resolveLanguageModel(provider: string): LanguageModel { if (provider === MODEL_PROVIDER_ANTHROPIC_BEDROCK) return AnthropicBedrockModel if (provider === MODEL_PROVIDER_TRANSFORMERS) return TransformersModel - return OpenAIModel + + const features = MODEL_PROVIDERS.find((p) => p.id === provider) + return LocalOpenAICompatibleModel(provider, { + listModels: features?.listModels !== false, + pullModel: features?.pullModel, + }) } diff --git a/packages/core/src/ollama.ts b/packages/core/src/ollama.ts index e0da923751..939c8d0714 100644 --- a/packages/core/src/ollama.ts +++ b/packages/core/src/ollama.ts @@ -8,6 +8,8 @@ import { OpenAIChatCompletion } from "./openai" import { LanguageModelConfiguration } from "./host" import { host } from "./host" import { logError, logVerbose } from "./util" +import { TraceOptions } from "./trace" +import { CancellationOptions } from "./cancellation" /** * Lists available models for the Ollama language model configuration. @@ -17,10 +19,11 @@ import { logError, logVerbose } from "./util" * @returns A promise that resolves to an array of LanguageModelInfo objects. */ async function listModels( - cfg: LanguageModelConfiguration + cfg: LanguageModelConfiguration, + options: TraceOptions & CancellationOptions ): Promise { // Create a fetch instance to make HTTP requests - const fetch = await createFetch({ retries: 0 }) + const fetch = await createFetch({ retries: 0, ...options }) // Fetch the list of models from the remote API const res = await fetch(cfg.base.replace("/v1", "/api/tags"), { method: "GET", diff --git a/packages/core/src/openai.ts b/packages/core/src/openai.ts index 73b324cdc6..5008e735a8 100644 --- a/packages/core/src/openai.ts +++ b/packages/core/src/openai.ts @@ -1,5 +1,6 @@ import { deleteUndefinedValues, + logError, logVerbose, normalizeInt, trimTrailingSlash, @@ -18,9 +19,14 @@ import { TOOL_URL, } from "./constants" import { estimateTokens } from "./tokens" -import { ChatCompletionHandler, LanguageModel, LanguageModelInfo } from "./chat" +import { + ChatCompletionHandler, + LanguageModel, + LanguageModelInfo, + PullModelFunction, +} from "./chat" import { RequestError, errorMessage, serializeError } from "./error" -import { createFetch, traceFetchPost } from "./fetch" +import { createFetch, iterateBody, traceFetchPost } from "./fetch" import { parseModelIdentifier } from "./models" import { JSON5TryParse } from "./json5" import { @@ -39,9 +45,10 @@ import { ChatCompletionTokenLogprob, } from "./chattypes" import { resolveTokenEncoder } from "./encoders" -import { toSignal } from "./cancellation" +import { CancellationOptions, toSignal } from "./cancellation" import { INITryParse } from "./ini" import { serializeChunkChoiceToLogProbs } from "./logprob" +import { TraceOptions } from "./trace" export function getConfigHeaders(cfg: LanguageModelConfiguration) { let { token, type, base, provider } = cfg @@ -420,9 +427,10 @@ export const OpenAIChatCompletion: ChatCompletionHandler = async ( } async function listModels( - cfg: LanguageModelConfiguration + cfg: LanguageModelConfiguration, + options: TraceOptions & CancellationOptions ): Promise { - const fetch = await createFetch({ retries: 0 }) + const fetch = await createFetch({ retries: 0, ...(options || {}) }) const res = await fetch(cfg.base + "/models", { method: "GET", headers: { @@ -449,8 +457,68 @@ async function listModels( ) } -export const OpenAIModel = Object.freeze({ - completer: OpenAIChatCompletion, - id: MODEL_PROVIDER_OPENAI, - listModels, -}) +const pullModel: PullModelFunction = async (modelId, options) => { + const { trace, cancellationToken } = options || {} + const { provider, model } = parseModelIdentifier(modelId) + const fetch = await createFetch({ retries: 0, ...options }) + const conn = await host.getLanguageModelConfiguration(modelId, { + token: true, + cancellationToken, + trace, + }) + try { + // test if model is present + const resTags = await fetch(`${conn.base}/models`, { + retries: 0, + method: "GET", + headers: { + "User-Agent": TOOL_ID, + "Content-Type": "application/json", + }, + }) + if (resTags.ok) { + const { data: models }: { data: { id: string }[] } = + await resTags.json() + if (models.find((m) => m.id === model)) return { ok: true } + } + + // pull + logVerbose(`${provider}: pull ${model}`) + const resPull = await fetch(`${conn.base}/models/pull`, { + method: "POST", + headers: { + "User-Agent": TOOL_ID, + "Content-Type": "application/json", + }, + body: JSON.stringify({ model }), + }) + if (!resPull.ok) { + logError(`${provider}: failed to pull model ${model}`) + logVerbose(resPull.statusText) + return { ok: false, status: resPull.status } + } + 0 + for await (const chunk of iterateBody(resPull, { cancellationToken })) + process.stderr.write(".") + process.stderr.write("\n") + return { ok: true } + } catch (e) { + logError(e) + trace?.error(e) + return { ok: false, error: serializeError(e) } + } +} + +export function LocalOpenAICompatibleModel( + providerId: string, + options: { listModels?: boolean; pullModel?: boolean } +) { + return Object.freeze( + deleteUndefinedValues({ + completer: OpenAIChatCompletion, + id: providerId, + listModels: options?.listModels ? listModels : undefined, + pullModel: options?.pullModel ? pullModel : undefined, + }) + ) +} diff --git a/packages/vscode/src/servermanager.ts b/packages/vscode/src/servermanager.ts index 8e80fffa41..5b232f5f1d 100644 --- a/packages/vscode/src/servermanager.ts +++ b/packages/vscode/src/servermanager.ts @@ -83,7 +83,7 @@ export class TerminalServerManager implements ServerManager { private async startClient(): Promise { assert(!this._client) this._port = await findRandomOpenPort() - const url = `http://localhost:${this._port}?api-key=${encodeURIComponent(this.state.sessionApiKey)}` + const url = `http://127.0.0.1:${this._port}?api-key=${encodeURIComponent(this.state.sessionApiKey)}` logInfo(`client url: ${url}`) const client = (this._client = new WebSocketClient(url)) client.chatRequest = createChatModelRunner(this.state)