use_model.py
#!/usr/bin/env python3
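"""Interactive front-end for querying a trained PhraseModel checkpoint.

Usage sketch (the checkpoint path below is only a placeholder; point it at
whatever model file your training run produced):

    ./use_model.py -m saves/xe/epoch_090.dat -s "how are you?"  # single phrase, argmax decoding
    ./use_model.py -m saves/xe/epoch_090.dat --sample           # interactive loop, sampled decoding
    ./use_model.py -m saves/xe/epoch_090.dat --self 3           # feed the reply back to the model 3 times
"""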
import os
import argparse
import logging

from libbots import data, model, utils

import torch

log = logging.getLogger("use")


def words_to_words(words, emb_dict, rev_emb_dict, net, use_sampling=False):
    # Encode the input words into token ids and pack them for the encoder
    tokens = data.encode_words(words, emb_dict)
    input_seq = model.pack_input(tokens, net.emb)
    enc = net.encode(input_seq)
    end_token = emb_dict[data.END_TOKEN]
    # Decode either by sampling from the output distribution or greedily (argmax)
    if use_sampling:
        _, out_tokens = net.decode_chain_sampling(enc, input_seq.data[0:1], seq_len=data.MAX_TOKENS,
                                                  stop_at_token=end_token)
    else:
        _, out_tokens = net.decode_chain_argmax(enc, input_seq.data[0:1], seq_len=data.MAX_TOKENS,
                                                stop_at_token=end_token)
    # Strip the trailing end-of-sequence token, if present
    if out_tokens[-1] == end_token:
        out_tokens = out_tokens[:-1]
    out_words = data.decode_words(out_tokens, rev_emb_dict)
    return out_words


def process_string(s, emb_dict, rev_emb_dict, net, use_sampling=False):
    # Tokenize the raw string before passing it to the model
    words = utils.tokenize(s)
    out_words = words_to_words(words, emb_dict, rev_emb_dict, net, use_sampling=use_sampling)
    print(" ".join(out_words))


if __name__ == "__main__":
    logging.basicConfig(format="%(asctime)-15s %(levelname)s %(message)s", level=logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", required=True, help="Model name to load")
    parser.add_argument("-s", "--string", help="String to process, otherwise will loop")
    parser.add_argument("--sample", default=False, action="store_true",
                        help="Enable sampling generation instead of argmax")
    parser.add_argument("--self", type=int, default=1,
                        help="Enable self-loop mode with given amount of phrases.")
    args = parser.parse_args()

    # Restore the embeddings dictionary saved next to the model and rebuild the network
    emb_dict = data.load_emb_dict(os.path.dirname(args.model))
    net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict),
                            hid_size=model.HIDDEN_STATE_SIZE)
    net.load_state_dict(torch.load(args.model))
    rev_emb_dict = {idx: word for word, idx in emb_dict.items()}

    while True:
        if args.string:
            input_string = args.string
        else:
            input_string = input(">>> ")
        if not input_string:
            break

        # In self-loop mode, feed the model's reply back as the next input
        words = utils.tokenize(input_string)
        for _ in range(args.self):
            words = words_to_words(words, emb_dict, rev_emb_dict, net, use_sampling=args.sample)
            print(utils.untokenize(words))

        if args.string:
            break