From db128adafd9e3caf3253a63328e61c3dc026a2c9 Mon Sep 17 00:00:00 2001 From: linxiaodong Date: Mon, 23 Sep 2024 13:15:32 +0800 Subject: [PATCH] feat: support open ai --- README.md | 1 + main/helpers/storeManager.ts | 15 ++- main/helpers/translate.ts | 35 ++++--- main/service/openai.ts | 45 +++++++++ package.json | 1 + renderer/components/TaskConfigForm.tsx | 28 +++++- renderer/pages/translateControl.tsx | 127 ++++++++++++++++++++++++- yarn.lock | 107 +++++++++++++++++++++ 8 files changed, 340 insertions(+), 19 deletions(-) create mode 100644 main/service/openai.ts diff --git a/README.md b/README.md index ef786d3..7dd87e2 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ - 支持百度翻译 - 支持 deeplx 翻译 (批量翻译容易存在被限流的情况) - 支持本地模型 ollama 翻译 +- 支持 OpenAI 风格 API 翻译 如 deepspeed 等 - 自定义字幕文件名,方便兼容不同的播放器挂载字幕识别 - 自定义翻译后的字幕文件内容,纯翻译结果,原字幕+翻译结果 - 项目集成 `whisper.cpp`, 它对 apple silicon 进行了优化,有较快的生成速度 diff --git a/main/helpers/storeManager.ts b/main/helpers/storeManager.ts index 4e676c4..ff6ba1e 100644 --- a/main/helpers/storeManager.ts +++ b/main/helpers/storeManager.ts @@ -18,7 +18,8 @@ const defaultTranslationProviders = [ type: 'local', apiUrl: 'http://localhost:11434', modelName: 'llama2', - prompt: 'Please translate the following content from ${sourceLanguage} to ${targetLanguage}, only return the translation result can be. \n ${content}' }, + prompt: 'Please translate the following content from ${sourceLanguage} to ${targetLanguage}, only return the translation result can be. \n ${content}' + }, ]; export const store = new Store({ @@ -47,10 +48,18 @@ export function setupStoreHandlers() { return defaultProvider; }); + // 添加用户自定义的提供商(如OpenAI风格的API) + const customProviders = storedProviders.filter(provider => + !defaultTranslationProviders.some(defaultProvider => defaultProvider.id === provider.id) + ); + + const allProviders = [...mergedProviders, ...customProviders]; + // 更新存储 - store.set('translationProviders', mergedProviders); + store.set('translationProviders', allProviders); - return mergedProviders; + console.log(allProviders, 'translationProviders'); + return allProviders; }); ipcMain.on('setUserConfig', async (event, config) => { diff --git a/main/helpers/translate.ts b/main/helpers/translate.ts index 4d14040..68a9900 100644 --- a/main/helpers/translate.ts +++ b/main/helpers/translate.ts @@ -5,6 +5,7 @@ import volcTranslator from '../service/volc'; import baiduTranslator from '../service/baidu'; import deeplxTranslator from '../service/deeplx'; import ollamaTranslator from '../service/ollama'; +import openaiTranslator from '../service/openai'; const contentTemplate = { onlyTranslate: '${targetContent}\n\n', @@ -35,22 +36,34 @@ export default async function translate( const data = result.split('\n'); const items = []; let translator; - switch (translateProvider) { - case 'volc': - translator = volcTranslator; - break; - case 'baidu': - translator = baiduTranslator; - break; - case 'deeplx': - translator = deeplxTranslator; + + // 根据提供商类型选择翻译器 + switch (proof.type) { + case 'api': + switch (proof.id) { + case 'volc': + translator = volcTranslator; + break; + case 'baidu': + translator = baiduTranslator; + break; + case 'deeplx': + translator = deeplxTranslator; + break; + default: + throw new Error(`未知的API翻译提供商: ${proof.id}`); + } break; - case 'ollama': + case 'local': translator = (text) => ollamaTranslator(text, proof, sourceLanguage, targetLanguage); break; + case 'openai': + translator = (text) => openaiTranslator(text, proof, sourceLanguage, targetLanguage); + break; default: - translator = (val) => val; + throw new Error(`未知的翻译提供商类型: ${proof.type}`); } + for (var i = 0; i < data.length; i += 4) { const sourceContent = data[i + 2]; if (!sourceContent) continue; diff --git a/main/service/openai.ts b/main/service/openai.ts new file mode 100644 index 0000000..f4cd5b9 --- /dev/null +++ b/main/service/openai.ts @@ -0,0 +1,45 @@ +import OpenAI from "openai"; +import { renderTemplate } from '../helpers/utils'; + +type OpenAIProvider = { + apiUrl: string; + apiKey: string; + modelName?: string; + prompt?: string; +}; + +export async function translateWithOpenAI( + text: string, + provider: OpenAIProvider, + sourceLanguage: string, + targetLanguage: string +) { + const openai = new OpenAI({ + baseURL: provider.apiUrl, + apiKey: provider.apiKey, + }); + + try { + const systemPrompt = provider.prompt + ? renderTemplate(provider.prompt, { sourceLanguage, targetLanguage, content: text }) + : `You are a helpful assistant that translates text from ${sourceLanguage} to ${targetLanguage}.`; + + const userPrompt = `Translate the following text from ${sourceLanguage} to ${targetLanguage}: "${text}"`; + + const completion = await openai.chat.completions.create({ + model: provider.modelName || "gpt-3.5-turbo", + messages: [ + { role: "system", content: systemPrompt }, + { role: "user", content: userPrompt } + ], + temperature: 0.3, + }); + + return completion?.choices?.[0]?.message?.content?.trim(); + } catch (error) { + console.error('OpenAI translation error:', error); + throw new Error(`OpenAI translation failed: ${error.message}`); + } +} + +export default translateWithOpenAI; \ No newline at end of file diff --git a/package.json b/package.json index df1dddc..4159877 100644 --- a/package.json +++ b/package.json @@ -38,6 +38,7 @@ "lodash": "^4.17.21", "lucide-react": "^0.378.0", "next-themes": "^0.3.0", + "openai": "^4.0.0", "react-hook-form": "^7.51.4", "regenerator-runtime": "^0.14.1", "sonner": "^1.4.41", diff --git a/renderer/components/TaskConfigForm.tsx b/renderer/components/TaskConfigForm.tsx index 962d431..c380136 100644 --- a/renderer/components/TaskConfigForm.tsx +++ b/renderer/components/TaskConfigForm.tsx @@ -1,3 +1,4 @@ +import React, { useEffect, useState } from 'react'; import { SelectValue, SelectTrigger, @@ -21,6 +22,13 @@ import { FormLabel, } from '@/components/ui/form'; +// 定义 Provider 类型 +type Provider = { + id: string; + name: string; + type: 'api' | 'local' | 'openai'; +}; + const TaskConfigForm = ({ form, formData, @@ -28,6 +36,17 @@ const TaskConfigForm = ({ updateSystemInfo, isInstalledModel, }) => { + const [providers, setProviders] = useState([]); + + useEffect(() => { + loadProviders(); + }, []); + + const loadProviders = async () => { + const storedProviders = await window.ipc.invoke('getTranslationProviders'); + setProviders(storedProviders); + }; + if(!providers.length) return null; return (
@@ -146,10 +165,11 @@ const TaskConfigForm = ({ 不翻译 - 百度 - 火山 - deepLx - ollama + {providers.map((provider) => ( + + {provider.name} + + ))} diff --git a/renderer/pages/translateControl.tsx b/renderer/pages/translateControl.tsx index aef66b0..9689bdc 100644 --- a/renderer/pages/translateControl.tsx +++ b/renderer/pages/translateControl.tsx @@ -11,12 +11,13 @@ import { Input } from '@/components/ui/input'; import { Button } from '@/components/ui/button'; import { Eye, EyeOff } from 'lucide-react'; import { Textarea } from '@/components/ui/textarea'; +import { Plus, Trash2 } from 'lucide-react'; // 定义统一的服务提供商类型 type Provider = { id: string; name: string; - type: 'api' | 'local'; + type: 'api' | 'local' | 'openai'; apiKey?: string; apiSecret?: string; apiUrl?: string; @@ -27,6 +28,13 @@ type Provider = { const TranslateControl: React.FC = () => { const [providers, setProviders] = useState([]); const [showPassword, setShowPassword] = useState<{ [key: string]: boolean }>({}); + const [newOpenAIProvider, setNewOpenAIProvider] = useState>({ + name: '', + apiUrl: '', + apiKey: '', + modelName: '', + prompt: '', + }); useEffect(() => { loadProviders(); @@ -58,6 +66,25 @@ const TranslateControl: React.FC = () => { const apiProviders = providers.filter(p => p.type === 'api'); const localProviders = providers.filter(p => p.type === 'local'); + const openAIProviders = providers.filter(p => p.type === 'openai'); + + const addOpenAIProvider = () => { + const newProvider: Provider = { + ...newOpenAIProvider, + id: Date.now().toString(), + type: 'openai', + }; + const updatedProviders = [...providers, newProvider]; + setProviders(updatedProviders); + window?.ipc?.send('setTranslationProviders', updatedProviders); + setNewOpenAIProvider({ name: '', apiUrl: '', apiKey: '', modelName: '', prompt: '' }); + }; + + const removeOpenAIProvider = (id: string) => { + const updatedProviders = providers.filter(provider => provider.id !== id); + setProviders(updatedProviders); + window?.ipc?.send('setTranslationProviders', updatedProviders); + }; return (
@@ -152,6 +179,104 @@ const TranslateControl: React.FC = () => { ))} + +

OpenAI风格API服务配置

+ + + + 服务名称 + API地址 + API密钥 + 模型名称 + Prompt + 操作 + + + + {openAIProviders.map((provider) => ( + + {provider.name} + + handleInputChange(provider.id, 'apiUrl', e.target.value)} + /> + + +
+ handleInputChange(provider.id, 'apiKey', e.target.value)} + className="mr-2" + /> + +
+
+ + handleInputChange(provider.id, 'modelName', e.target.value)} + /> + + +