interactive.py
from __future__ import absolute_import, division, print_function, unicode_literals

import argparse
import pickle

import jieba
import numpy as np

from nlp_architect.models.intent_extraction import MultiTaskIntentModel, Seq2SeqIntentModel
from nlp_architect.utils.generic import pad_sentences
from nlp_architect.utils.io import validate_existing_filepath


def read_input_args():
    # Parse command-line arguments: paths to the saved model weights and topology info
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=validate_existing_filepath, required=True,
                        help='Path of model weights')
    parser.add_argument('--model_info_path', type=validate_existing_filepath, required=True,
                        help='Path of model topology')
    input_args = parser.parse_args()
    return input_args


def load_saved_model():
    # Instantiate the model class matching the saved topology and load its weights.
    # Relies on the module-level `model_type` and `args` set in the __main__ block below.
    if model_type == 'seq2seq':
        model = Seq2SeqIntentModel()
    else:
        model = MultiTaskIntentModel()
    model.load(args.model_path)
    return model


def process_text(text):
    # Segment the input sentence with jieba (Chinese word segmentation),
    # keeping only the token strings from jieba's (token, start, end) tuples
    return [t for (t, _, _) in jieba.tokenize(text)]


def vectorize(doc, vocab, char_vocab=None):
    # Map each token to its word-vocabulary id (1 = out-of-vocabulary), shaped (1, seq_len)
    words = np.asarray([vocab[w.lower()] if w.lower() in vocab else 1 for w in doc])\
        .reshape(1, -1)
    if char_vocab is not None:
        # Build per-word character id sequences, padded to the model's word length
        sentence_chars = []
        for w in doc:
            word_chars = []
            for c in w:
                if c in char_vocab:
                    _cid = char_vocab[c]
                else:
                    _cid = 1
                word_chars.append(_cid)
            sentence_chars.append(word_chars)
        sentence_chars = np.expand_dims(pad_sentences(sentence_chars, model.word_length), axis=0)
        return [words, sentence_chars]
    else:
        return words
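# Illustrative note (not part of the original script): for a two-token input, the
# multi-task ('mtl') path of vectorize() returns [word_ids of shape (1, 2),
# char_ids of shape (1, 2, model.word_length)], while the seq2seq path returns
# only the (1, 2) word id array.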


if __name__ == '__main__':
    args = read_input_args()

    # Load the pickled model topology info (model type and vocabularies)
    with open(args.model_info_path, 'rb') as fp:
        model_info = pickle.load(fp)
    assert model_info is not None, 'No model topology information loaded'
    model_type = model_info['type']
    model = load_saved_model()

    # Build lookup tables: words -> ids for input, ids -> labels for decoding predictions
    word_vocab = model_info['word_vocab']
    tags_vocab = {v: k for k, v in model_info['tags_vocab'].items()}
    if model_type == 'mtl':
        char_vocab = model_info['char_vocab']
        intent_vocab = {v: k for k, v in model_info['intent_vocab'].items()}

    # Interactive loop: read a sentence, run inference, print the intent and slot tags
    while True:
        text = input('Enter sentence >> ')
        text_arr = process_text(text)
        if model_type == 'mtl':
            doc_vec = vectorize(text_arr, word_vocab, char_vocab)
            intent, tags = model.predict(doc_vec, batch_size=1)
            intent = int(intent.argmax(1).flatten())
            print('Detected intent type: {}'.format(intent_vocab.get(intent, None)))
        else:
            doc_vec = vectorize(text_arr, word_vocab, None)
            tags = model.predict(doc_vec, batch_size=1)
        tags = tags.argmax(2).flatten()
        tag_str = [tags_vocab.get(n, None) for n in tags]
        for t, n in zip(text_arr, tag_str):
            print('{}\t{}\t'.format(t, n))
        print()
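# Example invocation (hypothetical file names; the actual paths depend on how the model
# was trained and saved):
#   python interactive.py --model_path model.h5 --model_info_path model_info.dat
# Then type a sentence at the prompt to see the detected intent (for the 'mtl' model)
# and a per-token slot tag listing.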