-
Notifications
You must be signed in to change notification settings - Fork 35
/
generation.py
30 lines (23 loc) · 1.01 KB
/
generation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import paddle
import paddlenlp
from conf import MODELNAME
# Select the GPU as the device for everything that follows.
paddle.set_device('gpu')

# Load the pretrained GPT backbone from 'models' (presumably a local
# checkpoint directory - confirm) and wrap it in GPTForPretraining in a
# single step, then switch the wrapped model to evaluation mode.
gptModel = paddlenlp.transformers.GPTForPretraining(
    paddlenlp.transformers.GPTModel.from_pretrained('models')
)
gptModel.eval()

# Tokenizer matching the pretrained model named by conf.MODELNAME.
tokenizer = paddlenlp.transformers.GPTChineseTokenizer.from_pretrained(
    MODELNAME
)
def getPredictText(text: str, length: int = 200) -> str:
    """
    Generate text in the "Banfo" (半佛) writing style from a leading prompt.

    :param text: the leading part of the text to generate from
    :param length: length of the generated text; forwarded to the model's
        ``generate`` call as ``max_length``
    :return: the generated text
    """
    # Tokenize the prompt; only the input ids are used.
    prompt_ids = tokenizer(text=text, return_token_type_ids=False)['input_ids']

    # Build an int64 tensor and add a leading dimension of size 1.
    prompt_tensor = paddle.to_tensor(prompt_ids, dtype='int64')
    batched_prompt = prompt_tensor.unsqueeze(0)

    # generate() yields two values; only the first (the token ids) is used.
    # NOTE(review): min_length is fixed at 32 regardless of `length`.
    sampled, _ = gptModel.generate(
        input_ids=batched_prompt,
        max_length=length,
        min_length=32,
        decode_strategy='sampling',
    )

    # Take the first sequence and let the tokenizer turn its ids into text.
    return tokenizer.convert_ids_to_string(sampled[0].numpy().tolist())
# Run one generation as soon as the module is loaded and discard the result.
# NOTE(review): judging by the prompt text, this looks like a warm-up so the
# one-off initialization cost of the first prediction is paid here rather
# than on the first real request - confirm.
getPredictText('开始预测模型会先初始化一下, 抵消掉这个时间')