LSTM
hstern2 committed Jun 23, 2024
1 parent 5097b59 commit e37e7f2
Showing 7 changed files with 305 additions and 15 deletions.
2 changes: 2 additions & 0 deletions amsr/__init__.py
@@ -15,6 +15,7 @@
"Markov",
"Modifier",
"GetConformerAndEnergy",
"LSTMModel",
]

from .version import __version__
@@ -27,3 +28,4 @@
from .markov import Markov
from .modifier import Modifier
from .conf import GetConformerAndEnergy
from .lstm import LSTMModel
216 changes: 216 additions & 0 deletions amsr/lstm.py
@@ -0,0 +1,216 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch.nn.utils.rnn import pad_sequence

# Define special tokens
PAD_TOK = 0
EOS_TOK = 1


def _device():
return torch.device("cuda" if torch.cuda.is_available() else "cpu")


class LSTMModel(nn.Module):
def __init__(
self,
d_model=128,
nhid=256,
nlayers=2,
dropout=0.3,
learning_rate=0.001,
num_epochs=50,
batch_size=64,
weight_decay=1e-5,
validation_split=0.1,
patience=5,
):
super(LSTMModel, self).__init__()
self.d_model = d_model
self.nhid = nhid
self.nlayers = nlayers
self.dropout = dropout
self.learning_rate = learning_rate
self.num_epochs = num_epochs
self.batch_size = batch_size
self.weight_decay = weight_decay
self.validation_split = validation_split
self.patience = patience

def _init_model(self):
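# Layers are built here rather than in __init__ because vocab_size is only known after _get_vocab() or from_saved_model()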
self.embedding = nn.Embedding(self.vocab_size, self.d_model)
self.lstm = nn.LSTM(
input_size=self.d_model,
hidden_size=self.nhid,
num_layers=self.nlayers,
dropout=self.dropout,
batch_first=True,
)
self.fc_out = nn.Linear(self.nhid, self.vocab_size)
self.to(_device())

def forward(self, src):
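# src: (batch, seq_len) token indices -> logits of shape (batch, seq_len, vocab_size)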
embedded = self.embedding(src)
lstm_out, _ = self.lstm(embedded)
output = self.fc_out(lstm_out)
return output

def _get_vocab(self, seqs):
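# Token indices start at 2; 0 and 1 are reserved for PAD_TOK and EOS_TOK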
self.vocab_size = 2
self.index_for_token = {}
self.token_for_index = {}
for s in seqs:
for t in s:
if t not in self.index_for_token:
self.index_for_token[t] = self.vocab_size
self.token_for_index[self.vocab_size] = t
self.vocab_size += 1

def train_and_save_model(self, seqs, model_path):
self._get_vocab(seqs)
self._init_model()
seqs_as_tt = [
torch.tensor(
[self.index_for_token[t] for t in s] + [EOS_TOK], device=_device()
)
for s in seqs
]
padded_sequences = pad_sequence(
seqs_as_tt, batch_first=True, padding_value=PAD_TOK
)
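# Targets for next-token prediction: the input shifted left by one, with PAD appended at the final position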
targets = padded_sequences[:, 1:]
pad = torch.full(
(targets.shape[0], 1), PAD_TOK, dtype=targets.dtype, device=_device()
)
padded_targets = torch.cat((targets, pad), dim=1)

dataset = TensorDataset(padded_sequences, padded_targets)
train_size = int((1 - self.validation_split) * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_data_loader = DataLoader(
train_dataset, batch_size=self.batch_size, shuffle=True
)
val_data_loader = DataLoader(
val_dataset, batch_size=self.batch_size, shuffle=False
)

self.train()
criterion = nn.CrossEntropyLoss(ignore_index=PAD_TOK)
optimizer = optim.Adam(
self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
)

best_val_loss = float("inf")
epochs_no_improve = 0
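# Early stopping: halt when validation loss has not improved for `patience` consecutive epochs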

for epoch in range(self.num_epochs):
self.train() # training phase
total_train_loss = 0
for src, tgt in train_data_loader:
src, tgt = src.to(_device()), tgt.to(_device())
optimizer.zero_grad()
output = self(src)
loss = criterion(output.view(-1, self.vocab_size), tgt.view(-1))
loss.backward()
torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
optimizer.step()
total_train_loss += loss.item()
avg_train_loss = total_train_loss / len(train_data_loader)

self.eval() # validation phase
total_val_loss = 0
with torch.no_grad():
for src, tgt in val_data_loader:
src, tgt = src.to(_device()), tgt.to(_device())
output = self(src)
loss = criterion(output.view(-1, self.vocab_size), tgt.view(-1))
total_val_loss += loss.item()
avg_val_loss = total_val_loss / len(val_data_loader)
print(
f"Epoch {epoch}, Train Loss: {avg_train_loss}, Validation Loss: {avg_val_loss}"
)

if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
epochs_no_improve = 0
else:
epochs_no_improve += 1

if epochs_no_improve >= self.patience:
print(f"Early stopping at epoch {epoch}")
break

torch.save(
{
"token_for_index": self.token_for_index,
"index_for_token": self.index_for_token,
"model_state_dict": self.state_dict(),
"d_model": self.d_model,
"nhid": self.nhid,
"nlayers": self.nlayers,
"vocab_size": self.vocab_size,
},
model_path,
)
print("Model and configuration saved to", model_path)

@classmethod
def from_saved_model(cls, model_path):
checkpoint = torch.load(model_path, map_location=_device())
model = cls(
d_model=checkpoint["d_model"],
nhid=checkpoint["nhid"],
nlayers=checkpoint["nlayers"],
)
model.vocab_size = checkpoint["vocab_size"]
model.token_for_index = checkpoint["token_for_index"]
model.index_for_token = checkpoint["index_for_token"]
model._init_model()
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
model.to(_device())
return model

def generate_tokens(self, start_input, max_length=20, temperature=1.0):

assert temperature > 0

self.eval()
generated_sequence = torch.tensor(
[self.index_for_token[t] for t in start_input], device=_device()
).unsqueeze(0)

for _ in range(max_length - 1):
with torch.no_grad():
output = self(generated_sequence)
logits = (
output[:, -1, :] / temperature
) # Scale the logits by the temperature
probabilities = torch.softmax(logits, dim=-1)
next_token = torch.multinomial(
probabilities, num_samples=1
) # Sample from the probability distribution
t = next_token.item()
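# Never emit padding; sampling the end-of-sequence token terminates generation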
if t == PAD_TOK:
continue
elif t == EOS_TOK:
break
generated_sequence = torch.cat((generated_sequence, next_token), dim=1)

return [self.token_for_index[i] for i in generated_sequence.squeeze(0).tolist()]

def generate(self, start_input, max_length=20, temperature=1.0):
return "".join(self.generate_tokens(start_input, max_length, temperature))


if __name__ == "__main__":
seqs = [["A", "B", "C"], ["B", "C", "D", "E"], ["C", "A", "D"]]
pth = "tmp_model.pth"
model = LSTMModel()
model.train_and_save_model(seqs, pth)
loaded_model = LSTMModel.from_saved_model(pth)
print(loaded_model.generate(["A", "B"]))
10 changes: 3 additions & 7 deletions app/app.py
@@ -52,11 +52,8 @@ def mol_isOK(mol):

app = Flask(__name__)
methods = ["GET", "POST"]
base_url = "https://raw.githubusercontent.com/hstern2/amsr/main/data/"
ertl_nps = pandas.read_csv(base_url + "some_ertl_npsubs.csv")
fda = pandas.read_csv(base_url + "some_FDA_approved_structures.csv")
corpus = pandas.concat([ertl_nps, fda])
markov = amsr.Markov((Chem.MolFromSmiles(s) for s in corpus["SMILES"]))
model_path = os.path.join(os.path.dirname(__file__), "..", "models", "model.pth")
lstm = amsr.LSTMModel.from_saved_model(model_path)


@app.route("/", methods=methods)
@@ -66,8 +63,7 @@ def index():

@app.route("/random_mol", methods=methods)
def random_mol():
k = max(round(expovariate(1 / 20)), 1)
return json.dumps({"amsr": markov.generate(nmax=25)})
return json.dumps({"amsr": lstm.generate(["C"])})


@app.route("/mol_changed", methods=methods)
52 changes: 52 additions & 0 deletions gen.py
@@ -0,0 +1,52 @@
from amsr import ToTokens, ToSmiles, LSTMModel
from argparse import ArgumentParser
import os

DEFAULT_T = 1.0
DEFAULT_MAX_LENGTH = 20
DEFAULT_N = 10
DEFAULT_MODEL = os.path.join(os.path.dirname(__file__), "models", "model.pth")

ap = ArgumentParser(description="generate AMSR sequences")
ap.add_argument("s", help="initial AMSR sequence")
ap.add_argument(
"-t",
"--temperature",
type=float,
help=f"temperature. Default: {DEFAULT_T}",
default=DEFAULT_T,
)
ap.add_argument(
"-l",
"--max_length",
type=int,
help=f"maximum length. Default: {DEFAULT_MAX_LENGTH}",
default=DEFAULT_MAX_LENGTH,
)
ap.add_argument(
"-n",
"--num_seqs",
type=int,
help=f"number of sequences. Default: {DEFAULT_N}",
default=DEFAULT_N,
)
ap.add_argument(
"-m",
"--model",
help=f"path to model file. Default: {DEFAULT_MODEL}",
default=DEFAULT_MODEL,
)
a = ap.parse_args()

loaded_model = LSTMModel.from_saved_model(a.model)
toks = ToTokens(a.s)
print(f"Initial tokens: {toks}")
print("Generated sequences:")
for _ in range(a.num_seqs):
print(
ToSmiles(
loaded_model.generate(
toks, temperature=a.temperature, max_length=a.max_length
)
)
)
Binary file added models/model.pth
Binary file not shown.
20 changes: 12 additions & 8 deletions tests/test_amsr.py
@@ -8,18 +8,14 @@ def test_version():
assert amsr.__version__


def test_methane():
assert amsr.CheckSmiles("C")
def test_caffeine():
assert amsr.CheckSmiles("Cn1cnc2c1c(=O)n(C)c(=O)n2C")


def test_cage():
assert amsr.CheckAMSR("CCccCccc6oC..CCCC6C6.6")


def test_caffeine():
assert amsr.CheckSmiles("Cn1cnc2c1c(=O)n(C)c(=O)n2C")


def _read_csv(csv_file):
return pandas.read_csv(
os.path.join(os.path.dirname(__file__), "..", "data", csv_file)
@@ -54,11 +50,19 @@ def test_DEL():
_test_csv("DEL_compounds.csv")


def test_lstm():
lstm = amsr.LSTMModel.from_saved_model(
os.path.join(os.path.dirname(__file__), "..", "models", "model.pth")
)
for _ in range(20):
assert amsr.CheckAMSR(lstm.generate(["C"]))


def test_markov():
seed(0)
fda = _read_csv("some_FDA_approved_structures.csv")
markov = amsr.Markov([Chem.MolFromSmiles(s) for s in fda["SMILES"]])
for _ in range(100):
for _ in range(20):
assert amsr.CheckAMSR(markov.generate())


@@ -75,6 +79,6 @@ def test_modify():
def test_morph():
seed(0)
np = _read_csv("natural_products.csv")
for s, t in combinations(np["SMILES"], 2):
for s, t in combinations(np["SMILES"][:10], 2):
for a in amsr.morph.Morph.fromSmiles(s, t).amsr:
assert amsr.CheckAMSR(a)
20 changes: 20 additions & 0 deletions train.py
@@ -0,0 +1,20 @@
import pandas
from amsr import FromMolToTokens
from amsr import LSTMModel
from rdkit.Chem import MolFromSmiles, RenumberAtoms
from random import shuffle

model = LSTMModel()
df = pandas.read_csv("chembl_33_filtered.csv")
smi = df.SMILES # .sample(1000)
a = []
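# Data augmentation: tokenize each molecule 10 times with a randomly permuted atom order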
for s in smi:
m = MolFromSmiles(s)
k = list(range(m.GetNumAtoms()))
for _ in range(10):
shuffle(k)
a.append(FromMolToTokens(RenumberAtoms(m, k)))
print(f"Training on {len(a)} AMSR strings")
pth = "model.pth"
model.train_and_save_model(a, pth)
print("Done.")
