This package implements Generative Adversarial Networks (GANs) and Denoising Diffusion Probabilistic Models (DDPMs) for generative tasks such as conditional molecular structure generation.
To use molgen
, you will need an environment with the following packages:
- Python 3.7+
- PyTorch
- PyTorch Lightning
- Einops
For running and visualizing examples:
Once you have these packages installed, you can install molgen
in the same environment using
$ pip install -e .
Once installed, you can use the package. This example trains a WGANGP to reproduce the alanine dipeptide backbone atoms conditioned on the backbone diedral angles (examples
directory.
from molgen.models import WGANGP
from pathlib import Path
import mdtaj as md
import torch
import numpy as np
# load data
pdb_fname = 'examples/data/alanine-dipeptide-nowater.pdb'
trj_fnames = [str(i) for i in Path('examples/data/').glob('alanine-dipeptide-*-250ns-nowater.xtc')]
trjs = [md.load(t, top=pdb_fname).center_coordinates() for t in trj_fnames]
# process xyz coordinates and conditioning variables
xyz = list()
phi_psi = list()
for trj in trjs:
t_backbone = trj.atom_slice(trj.top.select('backbone')).center_coordinates()
n = trj.xyz.shape[0]
_, phi = md.compute_phi(trj)
_, psi = md.compute_psi(trj)
xyz.append(torch.tensor(t_backbone.xyz.reshape(n, -1)).float())
phi_psi.append(torch.tensor(np.concatenate((phi, psi), -1)).float())
# ininstantiate the model
model = WGANGP(xyz[0].shape[1], phi_psi[0].shape[1])
# fit the model
model.fit(xyz, phi_psi, max_epochs=25)
# Generate synthetic configurations
xyz_gen = model.generate(torch.cat(phi_psi))
xyz_gen = xyz_gen.reshape(xyz_gen.size(0), -1, 3)
# Save model checkpoint
model.save('ADP.ckpt')
# Load from checkpoint
model = WGANGP.load_from_checkpoint('ADP.ckpt')
Supports both generators based on both Generative Adversarial Networks (GANs) and Denoising Diffusion Probabilistic Models (DDPMs). The example above uses GANs, DDPMs support an equivalent API -- for example,
from molgen.models import DDPM
model = DDPM(....)
Code for the DDPM models are taken from: https://github.com/lucidrains/denoising-diffusion-pytorch (version 1.0.5)
Copyright (c) 2023, Kirill Shmilovich
Project based on the Computational Molecular Science Python Cookiecutter version 1.1.