forked from hellonlp/classifier-multi-label
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
75 lines (55 loc) · 2.17 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# -*- coding: utf-8 -*-
"""
Created on Thu May 30 17:12:37 2019
@author: cm

Inference script for the ALBERT + TextCNN multi-label classifier: restores the
trained checkpoint and exposes ``get_label(sentence)``.
"""
import os
import sys

# Absolute directory of this file; checkpoint paths are resolved against it.
pwd = os.path.dirname(os.path.abspath(__file__))
# Put the package's parent directory on sys.path so the absolute
# ``classifier_multi_label_textcnn`` imports below resolve. Derive it from the
# absolute ``pwd``: ``os.path.dirname(__file__)`` is '' when the script is
# launched via a relative path (e.g. ``python predict.py`` on Python < 3.9),
# which would append the wrong directory.
sys.path.append(os.path.dirname(pwd))

import numpy as np
import tensorflow as tf
from classifier_multi_label_textcnn.networks import NetworkAlbertTextCNN
from classifier_multi_label_textcnn.classifier_utils import get_feature_test, id2label
from classifier_multi_label_textcnn.hyperparameters import Hyperparamters as hp
class ModelAlbertTextCNN(object):
    """
    Load the NetworkAlbertTextCNN model for inference.

    Attributes:
        albert: the NetworkAlbertTextCNN built with ``is_training=False``.
        sess: the ``tf.Session`` holding the restored weights.
    """

    def __init__(self, checkpoint_dir=None):
        """
        Build the network and restore its trained weights.

        Args:
            checkpoint_dir: directory holding the TensorFlow checkpoint.
                Defaults to ``hp.file_load_model`` resolved against this
                file's directory (the original, fixed location).
        """
        self.albert, self.sess = self.load_model(checkpoint_dir)

    @staticmethod
    def load_model(checkpoint_dir=None):
        """
        Build the inference network in a fresh graph and restore its weights.

        Args:
            checkpoint_dir: checkpoint directory; ``None`` means
                ``os.path.join(pwd, hp.file_load_model)``.

        Returns:
            tuple: ``(albert, sess)`` -- the network object and the session
            (bound to the new graph) into which the checkpoint was restored.

        Raises:
            FileNotFoundError: if no checkpoint state is found in the directory.
        """
        with tf.Graph().as_default():
            sess = tf.Session()
            try:
                with sess.as_default():
                    albert = NetworkAlbertTextCNN(is_training=False)
                    saver = tf.train.Saver()
                    # Initialise everything first; restore() then overwrites
                    # the variables stored in the checkpoint.
                    sess.run(tf.global_variables_initializer())
                    if checkpoint_dir is None:
                        checkpoint_dir = os.path.join(pwd, hp.file_load_model)
                    checkpoint_dir = os.path.abspath(checkpoint_dir)
                    print(checkpoint_dir)
                    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
                    # get_checkpoint_state returns None when the directory has
                    # no checkpoint; fail with a clear message instead of an
                    # AttributeError on ``None.model_checkpoint_path``.
                    if ckpt is None or not ckpt.model_checkpoint_path:
                        raise FileNotFoundError(
                            'No TensorFlow checkpoint found in {}'.format(checkpoint_dir))
                    saver.restore(sess, ckpt.model_checkpoint_path)
            except Exception:
                # Don't leak the session if building or restoring fails.
                sess.close()
                raise
        return albert, sess
# Built eagerly when the module is imported; every prediction reuses this
# single instance (its network and session).
MODEL = ModelAlbertTextCNN()
print("Load model finished!")
def get_label(sentence):
    """
    Predict the labels of a single sentence.

    The sentence is featurised with ``get_feature_test`` and run through the
    shared ``MODEL`` as a batch of one.

    Returns:
        list: ``id2label(i)`` for every index ``i`` whose prediction equals 1,
        in index order, with index 0 skipped (presumably a reserved
        "no label" class -- confirm against ``id2label``).
    """
    features = get_feature_test(sentence)
    network = MODEL.albert
    # Wrap each feature in a list to form a batch of size one.
    feed = {
        network.input_ids: [features[0]],
        network.input_masks: [features[1]],
        network.segment_ids: [features[2]],
    }
    batch_predictions = MODEL.sess.run(network.predictions, feed_dict=feed)
    hit_ids = np.where(batch_predictions[0] == 1)[0]

    labels = []
    for label_id in hit_ids:
        if label_id == 0:
            continue
        labels.append(id2label(label_id))
    return labels
if __name__ == '__main__':
    # Manual smoke test: print each demo sentence next to its predicted labels.
    demo_sentences = (
        '耗电情况:整体来说耗电不是特别严重',
        '取暖效果:取暖效果好',
        '取暖效果:开到二挡很暖和',
        '一个小时房间仍然没暖和',
        '开着坐旁边才能暖和',
    )
    for text in demo_sentences:
        predicted = get_label(text)
        print(text, predicted)