matt19234/cs_886_project
You will need a recent version of Python with PyTorch and scikit-learn installed. The following three scripts are provided.

> python make_data.py
> n: number of training sequences
> k: number of points in DSU graph
> l: length of training sequence (number of dsu-union(u, v) calls)
> save to: file.npz where the data should be saved

A training dataset (tr.npz) and a validation dataset (val.npz) are already provided. Regrettably, I could not include the test datasets (tst1-6.npz) in the repo because they are prohibitively large (~7GB). They can be recreated using the following parameter settings for (n, k, l):

- (100, 20, 30)
- (100, 50, 75)
- (100, 100, 150)
- (100, 200, 300)
- (100, 300, 450)
- (100, 400, 600)

> python train.py
> is pgn (y/n): y if training the pgn model, otherwise gnn
> train data: file.npz containing training data (e.g. tr.npz)
> model: input.pt containing starting model parameters (leave blank to train from scratch)
> save to: output.pt path to save model params to
> epochs: number of epochs to train for

On my laptop, this takes around 5s/epoch. The best models, pgn5.pt and gnn3.pt, were both trained for 500 epochs.

> python test.py
> is pgn (y/n): y if pgn, otherwise gnn
> model: input.pt trained model params
> test data: file.npz containing test dataset (e.g. tst1.npz)

This outputs predictions vs. ground truth on one sequence. It also outputs the F1 scores over all sequences in the test dataset, as well as their mean +- standard deviation.
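For readers who want to see how a per-sequence F1 summary of this kind can be computed, here is a minimal sketch. It is not taken from test.py: the function names `binary_f1` and `summarize_f1`, the toy labels, the assumption that each sequence yields binary labels, and the use of the population standard deviation are all illustrative choices. The F1 is written out by hand with the standard library so the sketch runs without extra packages.

```python
from statistics import mean, pstdev


def binary_f1(y_true, y_pred):
    """F1 for binary labels: 2*TP / (2*TP + FP + FN); 0.0 if undefined."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def summarize_f1(sequences):
    """Return per-sequence F1 scores plus their mean and std.

    `sequences` is a list of (ground_truth, prediction) pairs.
    Assumption: population std (pstdev); test.py may use a different one.
    """
    scores = [binary_f1(t, p) for t, p in sequences]
    return scores, mean(scores), pstdev(scores)


# Hypothetical toy data: two sequences of binary labels.
toy = [
    ([1, 0, 1, 1], [1, 0, 0, 1]),
    ([0, 1, 1, 0], [0, 1, 1, 0]),
]
scores, mu, sigma = summarize_f1(toy)
print(f"F1 per sequence: {scores}")
print(f"mean +- std: {mu:.3f} +- {sigma:.3f}")
```

Computing one score per sequence and then aggregating, rather than pooling all labels into a single F1, is what makes a standard deviation meaningful here.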