# predict.py
# Load a trained RelationNetwork checkpoint and print its predictions on the
# GANTripletDataset test split.
import torch
from torch import nn
from torch.utils.data import DataLoader
from pathlib import Path
import os
from tqdm import tqdm
from statistics import mean
from utils import evaluate
from datasets import GANTripletDataset
from pipelines.relation_network import RelationNetwork
from blocks import *
# Device string used for the model and every input batch: 'cuda' whenever
# torch reports an available CUDA device, otherwise 'cpu'.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
def predict_setup_relation_net(
    *,
    batch_size=8,
    shuffle=True,
    pin_memory=True,
    dataset_mode='test',
):
    """Build the RelationNetwork model and the DataLoader used for prediction.

    All arguments are keyword-only and default to the values this function
    previously hard-coded, so ``predict_setup_relation_net()`` behaves as before.

    Args:
        batch_size: Samples per batch yielded by the loader.
        shuffle: Whether the loader shuffles the dataset. Defaults to True for
            backward compatibility; pass False for a stable prediction order.
        pin_memory: Forwarded to ``DataLoader``.
        dataset_mode: Value passed as ``mode`` to ``GANTripletDataset``.

    Returns:
        ``(clf_model, test_loader)``, where ``clf_model`` is a
        ``RelationNetwork`` already moved to ``DEVICE``.
    """
    test_loader = DataLoader(
        GANTripletDataset(mode=dataset_mode),
        batch_size=batch_size,
        pin_memory=pin_memory,
        shuffle=shuffle,
    )

    def _make_backbone():
        # Each branch gets its own, independently constructed ResNet-101;
        # the two backbones are separate instances, not shared weights.
        return resnet101(
            big_kernel=True,
            use_fc=False,
            use_pooling=False,
            zero_init_residual=True,
        )

    # Keyword arguments are evaluated left to right, so backbone_one is built
    # before backbone_two, matching the original construction order.
    clf_model = RelationNetwork(
        backbone_one=_make_backbone(),
        backbone_two=_make_backbone(),
        use_softmax=True,
    ).to(DEVICE)
    return clf_model, test_loader
def predict_loop(
        model,
        test_loader,
        checkpoint_file):
    """Load a checkpoint into ``model`` and print its prediction for every batch.

    Args:
        model: Module whose ``load_state_dict`` accepts the checkpoint's
            ``"model_state_dict"`` entry. It is moved to ``DEVICE`` and set
            to eval mode.
        test_loader: Iterable supporting ``len()`` that yields
            ``(imgs_real, imgs_generated)`` pairs; both items must support
            ``.to(DEVICE)``.
        checkpoint_file: Path to a ``torch.save``-d dict containing a
            ``"model_state_dict"`` key.
    """
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
    # host (and vice versa).
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(f=checkpoint_file, map_location=DEVICE)
    model.load_state_dict(state_dict=checkpoint["model_state_dict"])
    # Move the model once up front instead of on every batch.
    model.to(DEVICE)
    model.eval()
    with torch.no_grad():
        with tqdm(
                test_loader,
                total=len(test_loader),
                desc="Prediction"
        ) as tqdm_eval:
            for imgs_real, imgs_generated in tqdm_eval:
                # Under no_grad the output carries no autograd history, so the
                # former .detach().data was redundant.
                predictions = model(
                    imgs_real.to(DEVICE),
                    imgs_generated.to(DEVICE),
                )
                print(predictions)
def run_predict(checkpoint_file=None):
    """Run RelationNetwork prediction over the test split and print the results.

    Args:
        checkpoint_file: Checkpoint to load. When None (the default), uses
            ``../logs/relation-net-1/weights/weights-1_0.pth`` relative to the
            directory containing this file, as before.
    """
    if checkpoint_file is None:
        checkpoint_file = (
            Path(__file__).parent
            / '..' / 'logs' / 'relation-net-1' / 'weights' / 'weights-1_0.pth'
        )
    # Relation Net
    model, test_loader = predict_setup_relation_net()
    predict_loop(model, test_loader, checkpoint_file=checkpoint_file)
# Only run prediction when executed as a script, not when this module is imported.
if __name__ == "__main__":
    run_predict()