matt19234/cs_886_project

You will need a recent version of Python with PyTorch and scikit-learn installed.

Three scripts are provided: one to generate data, one to train a model, and one to test it.

> python make_data.py
> n: number of training sequences
> k: number of points in DSU graph
> l: length of training sequence (number of dsu-union(u, v) calls)
> save to: file.npz where the data should be saved.
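
For readers unfamiliar with the term, the sketch below shows what a sequence of l union(u, v) calls over k points looks like. It is only an illustration: the `DSU` class, the `random_union_sequence` helper, the uniform sampling of (u, v), and the seed are all assumptions, and it does not reproduce the contents or layout of the .npz files that make_data.py writes.

```python
import random


class DSU:
    """Disjoint-set union over points 0..k-1 (path halving, union by size)."""

    def __init__(self, k):
        self.parent = list(range(k))
        self.size = [1] * k

    def find(self, u):
        # Walk up to the root, shortening the path along the way.
        while self.parent[u] != u:
            self.parent[u] = self.parent[self.parent[u]]
            u = self.parent[u]
        return u

    def union(self, u, v):
        # Merge the sets containing u and v; return False if already merged.
        ru, rv = self.find(u), self.find(v)
        if ru == rv:
            return False
        if self.size[ru] < self.size[rv]:
            ru, rv = rv, ru
        self.parent[rv] = ru
        self.size[ru] += self.size[rv]
        return True


def random_union_sequence(k, l, seed=0):
    """Hypothetical helper: apply l uniformly random union(u, v) calls."""
    rng = random.Random(seed)
    dsu = DSU(k)
    calls = []
    for _ in range(l):
        u, v = rng.randrange(k), rng.randrange(k)
        dsu.union(u, v)
        calls.append((u, v))
    return calls, dsu


calls, dsu = random_union_sequence(k=20, l=30)
components = len({dsu.find(i) for i in range(20)})
print(len(calls), "union calls;", components, "components remain")
```
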
A training dataset (tr.npz) and validation dataset (val.npz) are already provided. Regrettably, I could not include the test datasets (tst1-6.npz) in the repo because they are prohibitively large (~7GB). They can be recreated using the following parameter settings for (n, k, l):
- (100, 20, 30)
- (100, 50, 75)
- (100, 100, 150)
- (100, 200, 300)
- (100, 300, 450)
- (100, 400, 600)
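
Every setting above uses l = 1.5k. Typing six sets of answers by hand is tedious, so a small driver could feed them to the script. The sketch below rests on two assumptions: that make_data.py reads its four prompts from standard input in the order shown (n, k, l, save to), and that the settings map to tst1.npz through tst6.npz in the order listed. The helper names are hypothetical, and by default it only prints what it would send.

```python
import subprocess
import sys

# (n, k, l) settings for the six test datasets, in the order listed above.
TEST_SETTINGS = [
    (100, 20, 30),
    (100, 50, 75),
    (100, 100, 150),
    (100, 200, 300),
    (100, 300, 450),
    (100, 400, 600),
]


def prompt_answers(n, k, l, save_to):
    """Build the stdin text answering the four prompts, one answer per line."""
    return f"{n}\n{k}\n{l}\n{save_to}\n"


def regenerate_test_sets(dry_run=True):
    """Print (or, with dry_run=False, send) the answers for each test set."""
    for i, (n, k, l) in enumerate(TEST_SETTINGS, start=1):
        answers = prompt_answers(n, k, l, f"tst{i}.npz")
        if dry_run:
            print(f"tst{i}.npz <- {answers!r}")
        else:
            # Assumption: make_data.py sits in the current directory
            # and reads its prompts from stdin.
            subprocess.run(
                [sys.executable, "make_data.py"],
                input=answers,
                text=True,
                check=True,
            )


regenerate_test_sets(dry_run=True)
```

Run it with `dry_run=False` from the repository root only after confirming that the prompts really are read in that order.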

> python train.py
> is pgn (y/n): y if training the pgn model, otherwise gnn
> train data: file.npz containing training data (e.g. tr.npz)
> model: input.pt containing starting model parameters (leave blank to train from scratch)
> save to: output.pt path to save model params to
> epochs: number of epochs to train for
On my laptop, this takes around 5s/epoch. The best models pgn5.pt and gnn3.pt were both trained for 500 epochs.

> python test.py
> is pgn (y/n): y if pgn, otherwise gnn
> model: input.pt trained model params
> test data: file.npz containing test dataset (e.g. tst1.npz)
This outputs predictions vs. ground truth on one sequence. It also outputs the F1 score for each sequence in the test dataset, as well as their mean ± standard deviation.
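
As a rough illustration of that summary, the sketch below computes one F1 score per sequence and then the mean and standard deviation across sequences. It makes several assumptions: labels are binary, the standard deviation is the population one, and the toy data is invented for the example. It uses plain Python rather than scikit-learn and is not the code in test.py.

```python
from statistics import mean, pstdev


def f1_binary(y_true, y_pred):
    """F1 score for binary labels; returns 0.0 when there are no positives."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def summarize(sequences):
    """Return per-sequence F1 scores with their mean and population std."""
    scores = [f1_binary(y_true, y_pred) for y_true, y_pred in sequences]
    return scores, mean(scores), pstdev(scores)


# Toy (ground truth, prediction) pairs, invented for illustration only.
toy = [
    ([1, 0, 1, 1], [1, 0, 0, 1]),
    ([0, 1, 1, 0], [0, 1, 1, 0]),
]
scores, mu, sigma = summarize(toy)
print(scores, f"{mu:.3f} ± {sigma:.3f}")
```
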
