-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
49 lines (37 loc) · 1.37 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# -*- coding: utf-8 -*-
# author: Victor H. Wirz
# discipline: UNIRIO-tin0145
# prof.: Pedro Moura
import rntn
from sentiment_tensor import SentimentTree
import pickle
import torch
from dataset import SSTDataset
from torch.utils.data import DataLoader
# Evaluate a trained RNTN on the SST test split and report the mean per-tree
# loss and mean per-tree accuracy.

# Vocabulary (token -> index) mapping saved alongside the model.
# NOTE(review): pickle.load can execute arbitrary code; this is only acceptable
# because the asset is assumed to be produced locally by the training pipeline.
with open('./assets/stoi.pkl', 'rb') as stoi_file:
    stoi = pickle.load(stoi_file)
lexis_size = len(stoi)

BATCH_SIZE = 128
PARAMETERS = "./assets/batch_parameters/net_parameters_6.pth"

test = SSTDataset("./sst/test.txt", stoi)
testloader = DataLoader(test, batch_size=BATCH_SIZE)
N = len(test)
if N == 0:
    # Fail with a clear message instead of a ZeroDivisionError when averaging.
    raise SystemExit("Test set is empty; nothing to evaluate.")

# SentimentTree has no support for GPU allocation, so trees cannot be fed to
# the model on a CUDA device. The network is also never moved off the CPU, so
# evaluation is pinned to the CPU to keep inputs and parameters on one device.
device = torch.device("cpu")

net = rntn.RNTensorN(lexis_size)
# map_location lets a checkpoint saved from a GPU run load on a CPU-only host.
net.load_state_dict(torch.load(PARAMETERS, map_location=device))
# Switch modules such as dropout (if the model has any) to inference behaviour.
net.eval()

test_loss = acc = 0
with torch.no_grad():
    for trees in testloader:
        for tree in trees:
            sentiment_tree = SentimentTree(tree, stoi, device)
            logits = net(sentiment_tree.root)
            lb = sentiment_tree.get_labels()
            ground_truth = torch.tensor([lb])[0]
            loss = net.get_loss(logits, ground_truth)
            # NOTE(review): tree_accuracy is read right after get_loss, so it is
            # presumably updated by get_loss as a side effect — confirm in rntn.
            acc += net.tree_accuracy.item()
            test_loss += loss.item()

# Both metrics are averaged over the number of trees in the test set.
acc /= N
test_loss /= N
print("Test loss: %.6f, Test accuracy: %.6f" % (test_loss, acc))