refactor LSTM
hstern2 committed Sep 5, 2024
1 parent bac2d64 · commit ad6f12a
Showing 2 changed files with 23 additions and 33 deletions.
amsr/lstm.py (53 changes: 22 additions & 31 deletions)
@@ -14,30 +14,12 @@ def _device():


 class LSTMModel(nn.Module):
-    def __init__(
-        self,
-        d_model=128,
-        nhid=256,
-        nlayers=2,
-        dropout=0.3,
-        learning_rate=0.001,
-        num_epochs=50,
-        batch_size=64,
-        weight_decay=1e-5,
-        validation_split=0.1,
-        patience=5,
-    ):
+    def __init__(self, d_model=128, nhid=256, nlayers=2, dropout=0.3):
         super(LSTMModel, self).__init__()
         self.d_model = d_model
         self.nhid = nhid
         self.nlayers = nlayers
         self.dropout = dropout
-        self.learning_rate = learning_rate
-        self.num_epochs = num_epochs
-        self.batch_size = batch_size
-        self.weight_decay = weight_decay
-        self.validation_split = validation_split
-        self.patience = patience

     def _init_model(self):
         self.embedding = nn.Embedding(self.vocab_size, self.d_model)
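
For context on the API change: after this commit the constructor keeps only
architecture hyperparameters (d_model, nhid, nlayers, dropout), while the
training hyperparameters move to train_and_save_model (next hunk). A minimal
sketch of constructing the refactored model; the argument values are just the
defaults from the diff:

    from amsr import LSTMModel  # import path as consolidated in train.py below

    # Architecture-only constructor after the refactor.
    model = LSTMModel(d_model=128, nhid=256, nlayers=2, dropout=0.3)
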
@@ -68,7 +50,17 @@ def _get_vocab(self, seqs):
             self.token_for_index[self.vocab_size] = t
             self.vocab_size += 1

-    def train_and_save_model(self, seqs, model_path):
+    def train_and_save_model(
+        self,
+        seqs,
+        model_path,
+        num_epochs=200,
+        patience=10,
+        learning_rate=0.001,
+        weight_decay=1e-5,
+        batch_size=64,
+        validation_split=0.2,
+    ):
         self._get_vocab(seqs)
         self._init_model()
         seqs_as_tt = [
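
Continuing the sketch above, a hedged example of calling the new signature;
"lstm.pt" is an illustrative path, and seqs stands for the token sequences the
method expects (see _get_vocab):

    # Training hyperparameters now live on the call, not on the object.
    model.train_and_save_model(
        seqs,
        "lstm.pt",             # illustrative checkpoint path
        num_epochs=200,        # default raised from the old constructor's 50
        patience=10,           # was 5 on the constructor
        validation_split=0.2,  # was 0.1 on the constructor
    )
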
@@ -87,27 +79,25 @@ def train_and_save_model(self, seqs, model_path):
         padded_targets = torch.cat((targets, pad), dim=1)

         dataset = TensorDataset(padded_sequences, padded_targets)
-        train_size = int((1 - self.validation_split) * len(dataset))
+        train_size = int((1 - validation_split) * len(dataset))
         val_size = len(dataset) - train_size
         train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

         train_data_loader = DataLoader(
-            train_dataset, batch_size=self.batch_size, shuffle=True
-        )
-        val_data_loader = DataLoader(
-            val_dataset, batch_size=self.batch_size, shuffle=False
+            train_dataset, batch_size=batch_size, shuffle=True
         )
+        val_data_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

         self.train()
         criterion = nn.CrossEntropyLoss(ignore_index=PAD_TOK)
         optimizer = optim.Adam(
-            self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
+            self.parameters(), lr=learning_rate, weight_decay=weight_decay
         )

         best_val_loss = float("inf")
         epochs_no_improve = 0

-        for epoch in range(self.num_epochs):
+        for epoch in range(num_epochs):
             self.train()  # training phase
             total_train_loss = 0
             for src, tgt in train_data_loader:
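
The split-and-load logic above is the stock torch.utils.data pattern; a
self-contained sketch with toy tensors (shapes and sizes are made up for
illustration):

    import torch
    from torch.utils.data import DataLoader, TensorDataset, random_split

    x = torch.randn(100, 8)          # 100 toy examples, 8 features each
    y = torch.randint(0, 2, (100,))  # toy binary targets
    dataset = TensorDataset(x, y)

    validation_split = 0.2
    train_size = int((1 - validation_split) * len(dataset))  # 80
    val_size = len(dataset) - train_size                     # 20
    train_ds, val_ds = random_split(dataset, [train_size, val_size])

    # Shuffle only the training loader; validation order is irrelevant.
    train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=64, shuffle=False)
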
@@ -131,7 +121,8 @@ def train_and_save_model(self, seqs, model_path):
                 total_val_loss += loss.item()
             avg_val_loss = total_val_loss / len(val_data_loader)
             print(
-                f"Epoch {epoch}, Train Loss: {avg_train_loss}, Validation Loss: {avg_val_loss}"
+                f"Epoch {epoch}, Train Loss: {avg_train_loss}, Validation Loss: {avg_val_loss}",
+                flush=True,
             )

             if avg_val_loss < best_val_loss:
@@ -140,8 +131,8 @@ def train_and_save_model(self, seqs, model_path):
             else:
                 epochs_no_improve += 1

-            if epochs_no_improve >= self.patience:
-                print(f"Early stopping at epoch {epoch}")
+            if epochs_no_improve >= patience:
+                print(f"Early stopping at epoch {epoch}", flush=True)
                 break

         torch.save(
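
The patience check is plain counter-based early stopping; a standalone sketch
of the same idea. The counter reset on improvement is an assumption (that line
sits in a collapsed part of the diff), and the loss values are fabricated:

    val_losses = [0.90, 0.75, 0.74, 0.76, 0.76, 0.75]  # fabricated
    best_val_loss = float("inf")
    epochs_no_improve = 0
    patience = 2  # small so the example actually stops

    for epoch, avg_val_loss in enumerate(val_losses):
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss  # improvement: remember and reset
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping at epoch {epoch}", flush=True)  # epoch 4 here
            break
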
@@ -156,7 +147,7 @@ def train_and_save_model(self, seqs, model_path):
             },
             model_path,
         )
-        print("Model and configuration saved to", model_path)
+        print("Model and configuration saved to", model_path, flush=True)

     @classmethod
     def from_saved_model(cls, model_path):
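
A hedged round trip, assuming from_saved_model (shown truncated above) restores
the configuration that torch.save wrote:

    # Rebuild the trained model from the checkpoint written above.
    model = LSTMModel.from_saved_model("lstm.pt")  # path is illustrative
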
train.py (3 changes: 1 addition & 2 deletions)
@@ -1,6 +1,5 @@
 import pandas
-from amsr import FromMolToTokens
-from lstm import LSTMModel
+from amsr import FromMolToTokens, LSTMModel
 from rdkit.Chem import MolFromSmiles, RenumberAtoms
 from random import shuffle

