Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: add Session Config #117

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/components/ModelTag/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import ModelIcon from './ModelIcon';
interface ModelTagProps {
model?: string;
}

const ModelTag = memo<ModelTagProps>(({ model }) => {
const { t } = useTranslation('common');
const selectedModel = OPENAI_MODEL_LIST.find(({ id }) => id === model);
Expand Down
10 changes: 10 additions & 0 deletions src/constants/session.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import { SessionConfig } from '@/types/session';

/**
* 会话默认的配置
*/
export const DEFAULT_SESSION_CONFIG: SessionConfig = {
displayMode: 'chat',
enableHistoryCount: true,
historyCount: 1,
};
6 changes: 5 additions & 1 deletion src/constants/tts.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { TTS } from '@/types/tts';
import { SessionTTS, TTS } from '@/types/tts';

export const DEFAULT_TTS_PITCH = 1;
export const DEFAULT_TTS_SPEED = 1;
Expand All @@ -18,6 +18,10 @@ export const DEFAULT_TTS_CONFIG_FEMALE: TTS = {
voice: 'zh-CN-XiaoxiaoNeural',
};

export const DEFAULT_SESSION_TTS_CONFIG: SessionTTS = {
sttLocale: 'auto',
};

export const DEFAULT_TTS_CONFIG_MALE: TTS = {
engine: 'edge',
locale: 'zh-CN',
Expand Down
2 changes: 1 addition & 1 deletion src/features/Actions/ModelSelect.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ const ModelSelect = memo(() => {

const { updateAgentConfig } = useAgentStore();

const { model, agentId } = useSessionContext()?.sessionAgent || {};
const { model = OPENAI_MODEL_LIST[0].id, agentId } = useSessionContext()?.sessionAgent || {};

const items = OPENAI_MODEL_LIST.map((item) => {
return {
Expand Down
4 changes: 4 additions & 0 deletions src/features/ChatHeader/index.tsx
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
'use client';

import { Space } from 'antd';
import classNames from 'classnames';
import React from 'react';
Expand All @@ -9,6 +11,7 @@ import ToggleChatSideBar from '@/features/Actions/ToggleChatSideBar';
import ToggleSessionList from '@/features/Actions/ToggleSessionList';
import { sessionSelectors, useSessionStore } from '@/store/session';

import ChatSetting from '../ChatSetting';
import { useStyles } from './style';

interface Props {
Expand All @@ -33,6 +36,7 @@ export default (props: Props) => {
<Space>
<ShareButton key={'share'} />
<ToggleChatSideBar key={'sidebar'} />
<ChatSetting />
</Space>
</Flexbox>
);
Expand Down
120 changes: 120 additions & 0 deletions src/features/ChatSetting/SettingModal.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import { VoiceList } from '@lobehub/tts';
import { Form, ItemGroup, Modal, SelectWithImg, SliderWithInput } from '@lobehub/ui';
import { Select, Switch } from 'antd';
import { useThemeMode } from 'antd-style';
import { LayoutList, MessagesSquare } from 'lucide-react';
import React, { useImperativeHandle, useState } from 'react';

import { FORM_STYLE } from '@/constants/token';
import { imageUrl } from '@/constants/url';
import { sessionSelectors, useSessionStore } from '@/store/session';

export interface ChatSettingModalActionType {
config: () => void;
}

export interface ChatSettingModalProps {
actionRef?: React.MutableRefObject<ChatSettingModalActionType | undefined>;
}

const ChatSettingModal = ({ actionRef }: ChatSettingModalProps) => {
// 弹窗可见状态
const [open, setOpen] = useState<boolean>(false);

const [form] = Form.useForm();

const { isDarkMode } = useThemeMode();

const sessionConfig = sessionSelectors.currentSession(useSessionStore());

const updateSessionConfig = useSessionStore((state) => state.updateSessionConfig);

// 拓展 ref
useImperativeHandle(actionRef, () => ({
config: () => {
setOpen(true);

form.setFieldsValue(sessionConfig);
},
}));

const chat: ItemGroup = {
children: [
{
children: (
<SelectWithImg
height={86}
options={[
{
icon: MessagesSquare,
img: imageUrl(`chatmode_chat_${isDarkMode ? 'dark' : 'light'}.webp`),
label: '对话模式',
value: 'chat',
},
{
icon: LayoutList,
img: imageUrl(`chatmode_docs_${isDarkMode ? 'dark' : 'light'}.webp`),
label: '文档模式',
value: 'docs',
},
]}
unoptimized={false}
width={144}
/>
),
label: '聊天窗口样式',
name: ['sessionConfig', 'displayMode'],
},
{
children: <Switch />,
label: '限制历史信息数',
minWidth: undefined,
name: ['sessionConfig', 'enableHistoryCount'],
valuePropName: 'checked',
},
{
children: <SliderWithInput max={32} min={1} />,
desc: '每次请求携带的消息数(包括最新编写的提问。每个提问和回答都计算1)',
divider: false,
hidden: !sessionConfig?.sessionConfig?.enableHistoryCount,
label: '附带消息数',
name: ['sessionConfig', 'historyCount'],
},
],
title: '聊天设置',
};

const tts: ItemGroup = {
children: [
{
children: (
<Select
placeholder="请输入"
defaultValue={'auto'}
options={[{ label: '跟随系统', value: 'auto' }, ...(VoiceList.localeOptions || [])]}
/>
),
label: '语音识别语种',
desc: '语音输入的语种,此选项可提高语音识别准确率',
name: ['tts', 'sttLocale'],
},
],
title: '语音服务',
};

return (
<Modal title={'偏好设置'} open={open} footer={false} onCancel={() => setOpen(false)}>
<Form
form={form}
items={[chat, tts]}
onValuesChange={(_, values) => updateSessionConfig(values)}
itemsType={'group'}
variant="pure"
{...FORM_STYLE}
itemMinWidth={'max(30%,304px)'}
/>
</Modal>
);
};

export default ChatSettingModal;
18 changes: 18 additions & 0 deletions src/features/ChatSetting/index.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import { ActionIcon } from '@lobehub/ui';
import { Settings } from 'lucide-react';
import React, { useRef } from 'react';

import ChatSettingModal, { ChatSettingModalActionType } from './SettingModal';

const ChatSetting = () => {
const actionRef = useRef<ChatSettingModalActionType>();

return (
<>
<ActionIcon icon={Settings} title="偏好设置" onClick={() => actionRef.current?.config()} />
<ChatSettingModal actionRef={actionRef} />
</>
);
};

export default ChatSetting;
4 changes: 1 addition & 3 deletions src/layout/StoreHydration.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@ const StoreHydration = () => {

useEffect(() => {
// refs: https://github.com/pmndrs/zustand/blob/main/docs/integrations/persisting-store-data.md#hashydrated
migrate().then(() => {
// useAgentStore.persist.rehydrate();
});
migrate();
}, []);

useEffect(() => {
Expand Down
6 changes: 5 additions & 1 deletion src/panels/RolePanel/RoleEdit/LangModel/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { Form, FormProps, SliderWithInput } from '@lobehub/ui';
import React, { memo, useEffect } from 'react';
import { useTranslation } from 'react-i18next';

import { DEFAULT_LLM_CONFIG } from '@/constants/openai';
import { FORM_STYLE } from '@/constants/token';
import { agentSelectors, useAgentStore } from '@/store/agent';

Expand All @@ -19,7 +20,10 @@ const LangModel = memo(() => {
const agent = agentSelectors.currentAgentItem(useAgentStore());

useEffect(() => {
form.setFieldsValue(agent);
form.setFieldsValue({
...DEFAULT_LLM_CONFIG,
...agent,
});
}, [agent]);

const model: FormProps['items'] = [
Expand Down
1 change: 0 additions & 1 deletion src/store/agent/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,6 @@ const persistOptions: PersistOptions<AgentStore> = {
name: AGENT_STORAGE_KEY, // name of the item in the storage (must be unique)
storage: createJSONStorage(() => storage),
version: 0,
// skipHydration: true,
};

export const useAgentStore = createWithEqualityFn<AgentStore>()(
Expand Down
33 changes: 32 additions & 1 deletion src/store/session/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import { StateCreator } from 'zustand/vanilla';
import { LOBE_VIDOL_DEFAULT_AGENT_ID } from '@/constants/agent';
import { DEFAULT_USER_AVATAR_URL, LOADING_FLAG } from '@/constants/common';
import { DEFAULT_LLM_CONFIG } from '@/constants/openai';
import { DEFAULT_SESSION_CONFIG } from '@/constants/session';
import { DEFAULT_SESSION_TTS_CONFIG } from '@/constants/tts';
import { chatCompletion, handleSpeakAi } from '@/services/chat';
import { shareService } from '@/services/share';
import { Agent } from '@/types/agent';
Expand Down Expand Up @@ -139,6 +141,10 @@ export interface SessionStore {
* @returns
*/
updateMessage: (id: string, content: string) => void;
/**
* 更新会话配置
*/
updateSessionConfig: (session: Session) => void;
/**
* 更新会话消息
* @param messages
Expand Down Expand Up @@ -180,6 +186,8 @@ export const createSessionStore: StateCreator<SessionStore, [['zustand/devtools'
draft.push({
agentId: agent.agentId,
messages: [],
sessionConfig: DEFAULT_SESSION_CONFIG,
tts: DEFAULT_SESSION_TTS_CONFIG,
});
}
});
Expand Down Expand Up @@ -457,9 +465,11 @@ export const createSessionStore: StateCreator<SessionStore, [['zustand/devtools'
}
const targetSession = sessionList.find((session) => session.agentId === agentId);
if (!targetSession) {
const session = {
const session: Session = {
agentId: agentId,
messages: [],
sessionConfig: DEFAULT_SESSION_CONFIG,
tts: DEFAULT_SESSION_TTS_CONFIG,
};
set({ sessionList: [...sessionList, session] });
}
Expand Down Expand Up @@ -496,6 +506,27 @@ export const createSessionStore: StateCreator<SessionStore, [['zustand/devtools'
set({ sessionList: sessions });
}
},
updateSessionConfig: (session: Session) => {
const { sessionList, activeId, defaultSession } = get();
if (activeId === LOBE_VIDOL_DEFAULT_AGENT_ID) {
const mergeSession = produce(defaultSession, (draft) => {
Object.entries(session).forEach(([key, value]) => {
draft[key as keyof Session] = value;
});
});
set({ defaultSession: mergeSession });
} else {
const sessions = produce(sessionList, (draft) => {
const index = draft.findIndex((session) => session.agentId === activeId);
if (index === -1) return;

Object.entries(session).forEach(([key, value]) => {
draft[index][key as keyof Session] = value;
});
});
set({ sessionList: sessions });
}
},
});

const persistOptions: PersistOptions<SessionStore> = {
Expand Down
5 changes: 5 additions & 0 deletions src/store/session/initialState.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import { LOBE_VIDOL_DEFAULT_AGENT_ID } from '@/constants/agent';
import { DEFAULT_SESSION_CONFIG } from '@/constants/session';
import { Session } from '@/types/session';

const defaultSession: Session = {
agentId: LOBE_VIDOL_DEFAULT_AGENT_ID,
messages: [],
sessionConfig: DEFAULT_SESSION_CONFIG,
tts: {
sttLocale: 'auto',
},
};

const initialState = {
Expand Down
24 changes: 24 additions & 0 deletions src/types/session.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,20 @@
import { ChatMessage } from './chat';
import { SessionTTS } from './tts';

export interface SessionConfig {
/**
* 聊天窗口样式
*/
displayMode?: 'chat' | 'docs';
/**
* 开启历史记录条数
*/
enableHistoryCount?: boolean;
/**
* 历史消息条数
*/
historyCount?: number;
}

export interface Session {
/**
Expand All @@ -9,4 +25,12 @@ export interface Session {
* 会话消息列表
*/
messages: ChatMessage[];
/**
* 会话配置
*/
sessionConfig: SessionConfig;
/**
* tts配置
*/
tts: SessionTTS;
}
4 changes: 4 additions & 0 deletions src/types/tts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,7 @@ export interface Voice {
locale: string;
localeZH: string;
}

export interface SessionTTS {
sttLocale: string;
}
Loading