diff --git a/src/cloudflare/internal/ai-api.ts b/src/cloudflare/internal/ai-api.ts index 7077e6765e8..67a458cf9a6 100644 --- a/src/cloudflare/internal/ai-api.ts +++ b/src/cloudflare/internal/ai-api.ts @@ -96,10 +96,12 @@ export class Ai { }, }; - const res = await this.fetcher.fetch( - 'https://workers-binding.ai/run?version=3', - fetchOptions - ); + let endpointUrl = 'https://workers-binding.ai/run?version=3'; + if (options?.gateway?.id) { + endpointUrl = 'https://workers-binding.ai/ai-gateway/run?version=3'; + } + + const res = await this.fetcher.fetch(endpointUrl, fetchOptions); this.lastRequestId = res.headers.get('cf-ai-req-id'); this.aiGatewayLogId = res.headers.get('cf-aig-log-id'); diff --git a/src/cloudflare/internal/test/ai/ai-api-test.js b/src/cloudflare/internal/test/ai/ai-api-test.js index 49956a4521e..80f791814cd 100644 --- a/src/cloudflare/internal/test/ai/ai-api-test.js +++ b/src/cloudflare/internal/test/ai/ai-api-test.js @@ -91,7 +91,11 @@ export const tests = { // Test raw input const resp = await env.ai.run('rawInputs', { prompt: 'test' }); - assert.deepStrictEqual(resp, { inputs: { prompt: 'test' }, options: {} }); + assert.deepStrictEqual(resp, { + inputs: { prompt: 'test' }, + options: {}, + requestUrl: 'https://workers-binding.ai/run?version=3', + }); } { @@ -105,6 +109,7 @@ export const tests = { assert.deepStrictEqual(resp, { inputs: { prompt: 'test' }, options: { gateway: { id: 'my-gateway', skipCache: true } }, + requestUrl: 'https://workers-binding.ai/ai-gateway/run?version=3', }); } @@ -126,6 +131,7 @@ export const tests = { example: 123, gateway: { id: 'my-gateway', metadata: { employee: 1233 } }, }, + requestUrl: 'https://workers-binding.ai/ai-gateway/run?version=3', }); } }, diff --git a/src/cloudflare/internal/test/ai/ai-mock.js b/src/cloudflare/internal/test/ai/ai-mock.js index da07f77e752..c85c6f23017 100644 --- a/src/cloudflare/internal/test/ai/ai-mock.js +++ b/src/cloudflare/internal/test/ai/ai-mock.js @@ -22,9 +22,15 @@ export default { } if (modelName === 'rawInputs') { - return Response.json(data, { - headers: respHeaders, - }); + return Response.json( + { + ...data, + requestUrl: request.url, + }, + { + headers: respHeaders, + } + ); } if (modelName === 'inputErrorModel') {