Energy Transformer (ET) is a novel Transformer variant that is both an associative memory model and a continuous dynamical system with a tractable energy function, whose dynamics are guaranteed to converge to a fixed point. See our paper for full details. Other official implementations of our work are also available: ET for Graph Anomaly Detection (PyTorch), ET for Image (PyTorch), and ET for Image (Jax).
pip install -r requirements.txt
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Read the official Jax installation guide to enable GPU support properly and for further details. See also Jax Versions for installing a specific Jax CUDA version. Credits to Ben Hoover for the diagrams.
Test the installation by starting Python and running the following code to check whether the GPU is enabled for Jax:
import jax
print(jax.local_devices())
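For scripts that should fail early when no GPU is visible, the same check can be done programmatically. The following is a minimal sketch, not part of the repository: `gpu_enabled` is a hypothetical helper, and it assumes Jax reports CUDA devices under the platform name "gpu".

```python
def gpu_enabled(devices=None):
    """Return True if any of the given devices reports the 'gpu' platform.

    When `devices` is None, the devices visible to Jax are queried.
    """
    if devices is None:
        # Imported lazily so the helper itself does not need Jax
        # when an explicit device list is passed in.
        import jax
        devices = jax.local_devices()
    return any(d.platform == "gpu" for d in devices)
```

Calling `gpu_enabled()` with no arguments inspects the local Jax devices. Passing a list explicitly makes it possible to exercise the helper on a machine without a GPU.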
PyTorch Geometric provides datasets and dataloaders that automatically download the data when the code is run. To use a different dataset, change the dataset name passed to TUDataset or GNNBenchmarkDataset.
from torch_geometric.datasets import GNNBenchmarkDataset

model_name = data_name = 'CIFAR10'
train_data = GNNBenchmarkDataset(root='../data/', name=data_name, split='train')
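Switching between the two dataset classes can be wrapped in a small helper. The sketch below is illustrative only: `dataset_family` and `load_dataset` are hypothetical names, and the set of GNNBenchmarkDataset names is an assumption based on the PyTorch Geometric documentation, so check it against the installed version. Any other name is treated as a TUDataset name, and TUDataset is loaded whole because it takes no `split` argument.

```python
# Names assumed to be accepted by GNNBenchmarkDataset (verify in the PyG docs).
GNN_BENCHMARK_NAMES = {"PATTERN", "CLUSTER", "MNIST", "CIFAR10", "TSP", "CSL"}


def dataset_family(name: str) -> str:
    """Return the PyG dataset class name to use for a dataset name."""
    if name in GNN_BENCHMARK_NAMES:
        return "GNNBenchmarkDataset"
    return "TUDataset"


def load_dataset(name: str, root: str = "../data/", split: str = "train"):
    """Load a dataset by name, downloading it on first use."""
    # Imported here so dataset_family works without PyTorch Geometric.
    from torch_geometric.datasets import GNNBenchmarkDataset, TUDataset

    if dataset_family(name) == "GNNBenchmarkDataset":
        return GNNBenchmarkDataset(root=root, name=name, split=split)
    # TUDataset has no predefined splits; `split` is ignored here.
    return TUDataset(root=root, name=name)
```

For example, `load_dataset('CIFAR10')` mirrors the snippet above, while `load_dataset('MUTAG')` would fetch a TUDataset collection.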
The Jupyter notebooks for running the experiments are in the nbs folder.
./run_nb_inplace nbs/eval_cifar10.ipynb
Because a number of pretrained models are provided, remove those files or move them to a different folder if you do not want them to be reloaded.
./run_nb_inplace nbs/cifar10.ipynb
Some pretrained models are provided in the saved_models folder. To download the remaining pretrained models, see the Google Drive link.
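One way to keep the provided checkpoints from being reloaded is to move them aside before running a notebook. The sketch below is an illustration under assumptions: `stash_checkpoints` and the `saved_models_backup` folder are hypothetical names that do not exist in the repository, and it assumes the checkpoints sit directly inside saved_models.

```python
import shutil
from pathlib import Path


def stash_checkpoints(src="saved_models", dst="saved_models_backup"):
    """Move every entry in `src` into `dst` and return the moved names."""
    src_dir, dst_dir = Path(src), Path(dst)
    if not src_dir.is_dir():
        # Nothing to move if the source folder does not exist.
        return []
    dst_dir.mkdir(parents=True, exist_ok=True)
    moved = []
    for entry in sorted(src_dir.iterdir()):
        shutil.move(str(entry), str(dst_dir / entry.name))
        moved.append(entry.name)
    return moved
```

Running `stash_checkpoints()` from the repository root would leave saved_models empty. Moving the files back restores the pretrained models for evaluation.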
If you find the code or the work useful, please cite our work!
@article{hoover2023energy,
title={Energy Transformer},
author={Hoover, Benjamin and Liang, Yuchen and Pham, Bao and Panda, Rameswar and Strobelt, Hendrik and Chau, Duen Horng and Zaki, Mohammed J and Krotov, Dmitry},
journal={arXiv preprint arXiv:2302.07253},
year={2023}
}