model.py
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from generic_model import generic_model


# generic_model contains generic methods for loading and storing a model
class RNN(generic_model):
    def __init__(self, config):
        super(RNN, self).__init__(config)

        # store important parameters
        self.rnn_name = config['rnn']
        self.input_dim = config['vocab_size'] + 1
        self.hidden_dim = config['hidden_dim']
        self.num_layers = config['num_layers']
        self.embed_dim = config['embedding_dim']
        self.output_dim = config['vocab_size']

        # whether to use character embeddings
        if config['use_embedding']:
            self.use_embedding = True
            self.embedding = nn.Embedding(self.input_dim, self.embed_dim)
        else:
            self.use_embedding = False

        # linear layers applied after the RNN output (to the concatenated features)
        in_features = config['miss_linear_dim'] + self.hidden_dim * 2
        mid_features = config['output_mid_features']
        self.linear1_out = nn.Linear(in_features, mid_features)
        self.relu = nn.ReLU()
        self.linear2_out = nn.Linear(mid_features, self.output_dim)

        # linear layer that projects the missed-characters vector
        self.miss_linear = nn.Linear(config['vocab_size'], config['miss_linear_dim'])

        # declare RNN (bidirectional LSTM or GRU)
        if self.rnn_name == 'LSTM':
            self.rnn = nn.LSTM(input_size=self.embed_dim if self.use_embedding else self.input_dim,
                               hidden_size=self.hidden_dim, num_layers=self.num_layers,
                               dropout=config['dropout'],
                               bidirectional=True, batch_first=True)
        else:
            self.rnn = nn.GRU(input_size=self.embed_dim if self.use_embedding else self.input_dim,
                              hidden_size=self.hidden_dim, num_layers=self.num_layers,
                              dropout=config['dropout'],
                              bidirectional=True, batch_first=True)

        # optimizer
        self.optimizer = optim.Adam(self.parameters(), lr=config['lr'])

    def forward(self, x, x_lens, miss_chars):
        """
        Forward pass through RNN
        :param x: input tensor of shape (batch size, max sequence length, input_dim),
                  or (batch size, max sequence length) of character indices when use_embedding is True
        :param x_lens: actual length of each sequence <= max sequence length (sequences are zero-padded)
        :param miss_chars: tensor of shape (batch size, vocab size). 1 at index i indicates that the ith character is NOT present
        :return: tensor of shape (batch size, output dim)
        """
        if self.use_embedding:
            x = self.embedding(x)

        batch_size, seq_len, _ = x.size()
        x = torch.nn.utils.rnn.pack_padded_sequence(x, x_lens, batch_first=True, enforce_sorted=False)

        # now run through RNN
        output, hidden = self.rnn(x)
        if self.rnn_name == 'LSTM':
            # LSTM returns (h_n, c_n); keep only the hidden state
            hidden = hidden[0]

        # hidden has shape (num_layers * 2, batch size, hidden_dim); keep the last layer's
        # forward and backward states and flatten them to (batch size, hidden_dim * 2)
        hidden = hidden.view(self.num_layers, 2, -1, self.hidden_dim)
        hidden = hidden[-1]
        hidden = hidden.permute(1, 0, 2)
        hidden = hidden.contiguous().view(hidden.shape[0], -1)

        # project miss_chars onto a higher dimension
        miss_chars = self.miss_linear(miss_chars)
        # concatenate RNN hidden state and missed-character features
        concatenated = torch.cat((hidden, miss_chars), dim=1)
        # predict
        return self.linear2_out(self.relu(self.linear1_out(concatenated)))

    def calculate_loss(self, model_out, labels, input_lens, miss_chars, use_cuda):
        """
        :param model_out: tensor of shape (batch size, output dim) from the forward pass
        :param labels: tensor of shape (batch size, vocab_size). 1 at index i indicates that the ith character should be predicted
        :param input_lens: actual length of each input sequence
        :param miss_chars: tensor of shape (batch size, vocab size). 1 at index i indicates that the ith character is NOT present
                           (passed here to check whether the model's output probability for missed characters is decreasing)
        :param use_cuda: whether to move the loss weights to the GPU
        """
        outputs = nn.functional.log_softmax(model_out, dim=1)
        # penalty on the log-probability mass the model assigns to characters known to be absent
        miss_penalty = torch.sum(outputs * miss_chars, dim=(0, 1)) / outputs.shape[0]

        input_lens = input_lens.float()
        # weight per example is inversely proportional to the length of the word,
        # because shorter words are harder to predict (higher chance of missing a character)
        weights_orig = (1 / input_lens) / torch.sum(1 / input_lens).unsqueeze(-1)
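        # worked example (illustrative): for input_lens = [3, 6], 1/lens = [0.3333, 0.1667],
        # which normalises to weights_orig = [0.6667, 0.3333], i.e. the 3-letter word gets
        # twice the weight of the 6-letter word in the loss below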
        weights = torch.zeros((weights_orig.shape[0], 1))
        # reshape to (batch size, 1) so that BCEWithLogitsLoss can broadcast the weights correctly
        weights[:, 0] = weights_orig

        if use_cuda:
            weights = weights.cuda()

        # actual loss
        loss_func = nn.BCEWithLogitsLoss(weight=weights, reduction='sum')
        actual_penalty = loss_func(model_out, labels)
        return actual_penalty, miss_penalty
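

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module: builds the model with an
    # illustrative config and runs one forward/loss step on random data. All values below
    # (vocab size 26, hidden_dim 128, etc.) are assumptions for demonstration only, covering
    # just the config keys read in this file; generic_model may require additional keys.
    config = {
        'rnn': 'GRU',
        'vocab_size': 26,
        'hidden_dim': 128,
        'num_layers': 2,
        'embedding_dim': 64,
        'use_embedding': True,
        'miss_linear_dim': 32,
        'output_mid_features': 64,
        'dropout': 0.3,
        'lr': 1e-3,
    }
    model = RNN(config)

    batch_size, max_len = 4, 10
    # with use_embedding=True the input is a batch of integer-encoded words (0 reserved for padding)
    x = torch.randint(1, config['vocab_size'] + 1, (batch_size, max_len))
    x_lens = torch.tensor([10, 8, 6, 3])
    miss_chars = torch.zeros(batch_size, config['vocab_size'])
    labels = torch.zeros(batch_size, config['vocab_size'])

    out = model(x, x_lens, miss_chars)  # (batch_size, vocab_size)
    loss, miss_penalty = model.calculate_loss(out, labels, x_lens, miss_chars, use_cuda=False)
    print(out.shape, loss.item(), miss_penalty.item())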