forked from alexcg1/easy_text_generator
-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
69 lines (53 loc) · 1.67 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import streamlit as st
from utils import *
import json
# Set to True to hard-wrap the generated text at 80 columns before display.
line_wrap = False

st.title("Easy Text Generator")
st.write("Use language models with just a few clicks")

# `models` comes from `utils` (star import). The code below treats it as an
# iterable of dicts mapping a display name to a dict with a "path" entry.
# NOTE(review): confirm that shape against utils.
# Build the name -> path mapping once; later entries with a duplicate name win,
# matching the previous "last match" lookup behaviour.
model_paths = {
    name: data["path"]
    for model_dict in models
    for name, data in model_dict.items()
}
model_names = list(model_paths)

st.sidebar.title("Options")

# Setup sidebar
max_length = st.sidebar.slider(
    """ Max Text Length
(Longer length, slower generation)""",
    50,
    1000,
    value=100
)
model_selectbox = st.sidebar.selectbox("Model", model_names)
# None when nothing is selectable (e.g. no models configured); handled below
# instead of failing with a NameError.
model_select = model_paths.get(model_selectbox)

context = st.sidebar.text_area("Starting text")

advanced = st.sidebar.checkbox("Advanced options", False, "advanced")
if advanced:
    top_k = st.sidebar.slider("Words to consider (top_k)", 1, 100, value=50)
    top_p = st.sidebar.slider("Creativity (top_p)", 0.0, 1.0, value=0.95)
    custom_model = st.sidebar.text_input(label="Model from transformers")
    # A non-blank custom model name overrides the preset chosen above;
    # surrounding whitespace is ignored so "  " does not count as a model.
    if custom_model.strip():
        model_select = custom_model.strip()
else:
    top_k = 50
    top_p = 0.95

# Show the model that will be used. This must come after model_select is fully
# resolved (preset + custom override): Streamlit re-executes the script from
# the top on each interaction, so checking for model_select earlier in the
# script never finds it.
if model_select:
    st.header(model_select)

sample = ['']
if st.sidebar.button("Generate"):
    if model_select:
        model, tokenizer = load_model(model_dir=model_select)
        generate_kwargs = {"max_length": max_length, "top_k": top_k, "top_p": top_p}
        # Only seed generation with a prompt when the user supplied one.
        if context:
            generate_kwargs["input_text"] = context
        sample = generate(model, tokenizer, **generate_kwargs)
        st.balloons()
    else:
        st.warning("No model selected. Pick one or enter a model under Advanced options.")

# Fix up line wrapping
if line_wrap:
    sample[0] = wrap_text(sample[0], length=80)
st.text(sample[0])