Skip to content

Commit

Permalink
feat: support ollama
Browse files Browse the repository at this point in the history
  • Loading branch information
linxiaodong committed Sep 20, 2024
1 parent 17b3752 commit 2bd033a
Show file tree
Hide file tree
Showing 9 changed files with 150 additions and 70 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
- 支持火山引擎翻译
- 支持百度翻译
- 支持 deeplx 翻译 (批量翻译容易存在被限流的情况)
- 支持本地模型 ollama 翻译
- 自定义字幕文件名,方便兼容不同的播放器挂载字幕识别
- 自定义翻译后的字幕文件内容,纯翻译结果,原字幕+翻译结果
- 项目集成 `whisper.cpp`, 它对 apple silicon 进行了优化,有较快的生成速度
Expand Down
3 changes: 1 addition & 2 deletions main/helpers/fileProcessor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { extractAudio } from './ffmpeg';
import translate from './translate';
import { renderTemplate, isWin32, getExtraResourcesPath } from './utils';

export async function processFile(event, file, formData, hasOpenAiWhisper, translationProviders) {
export async function processFile(event, file, formData, hasOpenAiWhisper, provider) {
const {
model,
sourceLanguage,
Expand Down Expand Up @@ -81,7 +81,6 @@ export async function processFile(event, file, formData, hasOpenAiWhisper, trans
"translateSubtitle",
"loading",
);
const provider = translationProviders.find(p => p.id === translateProvider);
await translate(
event,
directory,
Expand Down
37 changes: 27 additions & 10 deletions main/helpers/storeManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,16 @@ type StoreType = {
}

const defaultTranslationProviders = [
{ id: 'baidu', name: '百度', apiKey: '', apiSecret: '' },
{ id: 'volc', name: '火山', apiKey: '', apiSecret: '' },
{ id: 'deeplx', name: 'DeepLX', apiKey: '', apiSecret: '' },
{ id: 'baidu', name: '百度', type: 'api', apiKey: '', apiSecret: '' },
{ id: 'volc', name: '火山', type: 'api', apiKey: '', apiSecret: '' },
{ id: 'deeplx', name: 'DeepLX', type: 'api', apiKey: '', apiSecret: '' },
{
id: 'ollama',
name: 'Ollama',
type: 'local',
apiUrl: 'http://localhost:11434',
modelName: 'llama2',
prompt: 'Please translate the following content from ${sourceLanguage} to ${targetLanguage}, only return the translation result can be. \n ${content}' },
];

export const store = new Store<StoreType>({
Expand All @@ -27,13 +34,23 @@ export function setupStoreHandlers() {
});

ipcMain.handle('getTranslationProviders', async () => {
let providers = store.get('translationProviders');
if (!providers || providers.length === 0) {
providers = defaultTranslationProviders;
store.set('translationProviders', providers);
}
console.log(providers, 'translationProviders');
return providers;
let storedProviders = store.get('translationProviders');

// 合并存储的提供商和默认提供商
const mergedProviders = defaultTranslationProviders.map(defaultProvider => {
const storedProvider = storedProviders.find(p => p.id === defaultProvider.id);
if (storedProvider) {
// 如果存储的提供商存在,合并默认值和存储的值
return { ...defaultProvider, ...storedProvider };
}
// 如果存储中不存在该提供商,使用默认值
return defaultProvider;
});

// 更新存储
store.set('translationProviders', mergedProviders);

return mergedProviders;
});

ipcMain.on('setUserConfig', async (event, config) => {
Expand Down
5 changes: 4 additions & 1 deletion main/helpers/taskProcessor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ async function processNextTasks(event) {
const translationProviders = store.get('translationProviders');

try {
await Promise.all(tasks.map(task => processFile(event, task.file, task.formData, hasOpenAiWhisper, translationProviders)));
await Promise.all(tasks.map(task => {
const provider = translationProviders.find(p => p.id === task.formData.translateProvider);
return processFile(event, task.file, task.formData, hasOpenAiWhisper, provider);
}));
} catch (error) {
event.sender.send("message", error);
}
Expand Down
13 changes: 10 additions & 3 deletions main/helpers/translate.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import path from 'path';
import fs from 'fs';
import { renderTemplate } from './utils';
import volcTranslator from '../service/volc';
import baiduTranslator from '../service/baidu';
import deeplxTranslator from '../service/deeplx';
import ollamaTranslator from '../service/ollama';

const contentTemplate = {
onlyTranslate: '${targetContent}\n\n',
Expand Down Expand Up @@ -33,13 +37,16 @@ export default async function translate(
let translator;
switch (translateProvider) {
case 'volc':
translator = (await import('../service/volc')).default;
translator = volcTranslator;
break;
case 'baidu':
translator = (await import('../service/baidu')).default;
translator = baiduTranslator;
break;
case 'deeplx':
translator = (await import('../service/deeplx')).default;
translator = deeplxTranslator;
break;
case 'ollama':
translator = (text) => ollamaTranslator(text, proof, sourceLanguage, targetLanguage);
break;
default:
translator = (val) => val;
Expand Down
39 changes: 39 additions & 0 deletions main/service/ollama.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import axios from 'axios';
import { renderTemplate } from '../helpers/utils';

interface OllamaConfig {
apiUrl: string;
modelName: string;
prompt: string;
}

export default async function translateWithOllama(
text: string,
config: OllamaConfig,
sourceLanguage: string,
targetLanguage: string
) {
const { apiUrl, modelName, prompt } = config;

const renderedPrompt = renderTemplate(prompt, {
sourceLanguage,
targetLanguage,
content: text
});

try {
const response = await axios.post(`${apiUrl}/api/generate`, {
model: modelName,
prompt: renderedPrompt,
stream: false
});

if (response.data && response.data.response) {
return response.data.response.trim();
} else {
throw new Error(response?.data?.error || 'Unexpected response from Ollama');
}
} catch (error) {
throw error;
}
}
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"private": true,
"name": "video-subtitle-master",
"description": "视频转字幕,字幕翻译软件",
"version": "1.0.17",
"version": "1.0.18",
"author": "buxuku <[email protected]>",
"main": "app/background.js",
"scripts": {
Expand Down
1 change: 1 addition & 0 deletions renderer/components/TaskConfigForm.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ const TaskConfigForm = ({
<SelectItem value="baidu">百度</SelectItem>
<SelectItem value="volc">火山</SelectItem>
<SelectItem value="deeplx">deepLx</SelectItem>
<SelectItem value="ollama">ollama</SelectItem>
</SelectContent>
</Select>
</FormControl>
Expand Down
119 changes: 66 additions & 53 deletions renderer/pages/translateControl.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,25 @@ import {
import { Input } from '@/components/ui/input';
import { Button } from '@/components/ui/button';
import { Eye, EyeOff } from 'lucide-react';
import { Textarea } from '@/components/ui/textarea';

// 定义翻译服务提供商类型
type TranslationProvider = {
// 定义统一的服务提供商类型
type Provider = {
id: string;
name: string;
apiKey: string;
apiSecret: string;
type: 'api' | 'local';
apiKey?: string;
apiSecret?: string;
apiUrl?: string;
modelName?: string;
prompt?: string;
};

const TranslateControl: React.FC = () => {
const [providers, setProviders] = useState<TranslationProvider[]>([]);
const [showPassword, setShowPassword] = useState<{ [key: string]: boolean }>(
{}
);
const [providers, setProviders] = useState<Provider[]>([]);
const [showPassword, setShowPassword] = useState<{ [key: string]: boolean }>({});

useEffect(() => {
// 组件加载时获取存储的配置
loadProviders();
}, []);

Expand All @@ -37,31 +39,32 @@ const TranslateControl: React.FC = () => {

const handleInputChange = async (
id: string,
field: 'apiKey' | 'apiSecret',
field: keyof Provider,
value: string
) => {
const updatedProviders = providers.map((provider) =>
provider.id === id ? { ...provider, [field]: value } : provider
);
setProviders(updatedProviders);
// 保存更新后的配置
window?.ipc?.send('setTranslationProviders', updatedProviders);
};

const togglePasswordVisibility = (
id: string,
field: 'apiKey' | 'apiSecret'
) => {
const togglePasswordVisibility = (id: string, field: 'apiKey' | 'apiSecret') => {
setShowPassword((prev) => ({
...prev,
[`${id}_${field}`]: !prev[`${id}_${field}`],
}));
};

const apiProviders = providers.filter(p => p.type === 'api');
const localProviders = providers.filter(p => p.type === 'local');

return (
<div className="container mx-auto p-4">
<h1 className="text-2xl font-bold mb-4">翻译服务管理</h1>
<Table>

<h2 className="text-xl font-bold mb-2">API 服务提供商</h2>
<Table className="mb-8">
<TableHeader>
<TableRow>
<TableHead>翻译服务提供商</TableHead>
Expand All @@ -70,75 +73,85 @@ const TranslateControl: React.FC = () => {
</TableRow>
</TableHeader>
<TableBody>
{providers.map((provider) => (
{apiProviders.map((provider) => (
<TableRow key={provider.id}>
<TableCell>{provider.name}</TableCell>
<TableCell>
<div className="flex items-center">
<Input
type={
showPassword[`${provider.id}_apiKey`]
? 'text'
: 'password'
}
type={showPassword[`${provider.id}_apiKey`] ? 'text' : 'password'}
value={provider.apiKey}
onChange={(e) =>
handleInputChange(provider.id, 'apiKey', e.target.value)
}
onChange={(e) => handleInputChange(provider.id, 'apiKey', e.target.value)}
className="mr-2"
/>
<Button
variant="ghost"
size="icon"
onClick={() =>
togglePasswordVisibility(provider.id, 'apiKey')
}
onClick={() => togglePasswordVisibility(provider.id, 'apiKey')}
>
{showPassword[`${provider.id}_apiKey`] ? (
<EyeOff size={16} />
) : (
<Eye size={16} />
)}
{showPassword[`${provider.id}_apiKey`] ? <EyeOff size={16} /> : <Eye size={16} />}
</Button>
</div>
</TableCell>
<TableCell>
<div className="flex items-center">
<Input
type={
showPassword[`${provider.id}_apiSecret`]
? 'text'
: 'password'
}
type={showPassword[`${provider.id}_apiSecret`] ? 'text' : 'password'}
value={provider.apiSecret}
onChange={(e) =>
handleInputChange(
provider.id,
'apiSecret',
e.target.value
)
}
onChange={(e) => handleInputChange(provider.id, 'apiSecret', e.target.value)}
className="mr-2"
/>
<Button
variant="ghost"
size="icon"
onClick={() =>
togglePasswordVisibility(provider.id, 'apiSecret')
}
onClick={() => togglePasswordVisibility(provider.id, 'apiSecret')}
>
{showPassword[`${provider.id}_apiSecret`] ? (
<EyeOff size={16} />
) : (
<Eye size={16} />
)}
{showPassword[`${provider.id}_apiSecret`] ? <EyeOff size={16} /> : <Eye size={16} />}
</Button>
</div>
</TableCell>
</TableRow>
))}
</TableBody>
</Table>

<h2 className="text-xl font-bold mb-2">本地模型配置</h2>
<Table>
<TableHeader>
<TableRow>
<TableHead>模型名称</TableHead>
<TableHead>API 地址</TableHead>
<TableHead>模型名</TableHead>
<TableHead>Prompt</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{localProviders.map((provider) => (
<TableRow key={provider.id}>
<TableCell>{provider.name}</TableCell>
<TableCell>
<Input
value={provider.apiUrl}
onChange={(e) => handleInputChange(provider.id, 'apiUrl', e.target.value)}
/>
</TableCell>
<TableCell>
<Input
value={provider.modelName}
onChange={(e) => handleInputChange(provider.id, 'modelName', e.target.value)}
/>
</TableCell>
<TableCell>
<Textarea
value={provider.prompt}
onChange={(e) => handleInputChange(provider.id, 'prompt', e.target.value)}
rows={3}
/>
</TableCell>
</TableRow>
))}
</TableBody>
</Table>
</div>
);
};
Expand Down

0 comments on commit 2bd033a

Please sign in to comment.