A few others have released amazing related work that helped inspire and improve my own implementation. It goes without saying that this release would not be nearly as good without all of the following:
- @avivt (Paper Author, MATLAB implementation)
- @zuoxingdong (TensorFlow implementation, PyTorch implementation)
- @TheAbhiKumar (TensorFlow implementation)
- @onlytailei (PyTorch implementation)
- The PyTorch VIN model in this repository is, in my opinion, more readable and closer to the original Theano implementation than the others I have found (both TensorFlow and PyTorch); a schematic of its core VI module is sketched just below this list.
- This is not simply an implementation of the VIN model in PyTorch; it is also a full Python implementation of the gridworld environments used in the original MATLAB implementation.
- It provides a more extensible research base for others to build on without having to get past the potential MATLAB paywall.
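For context, the heart of a VIN is a small convolutional module that runs k steps of value iteration on a predicted reward image. Below is a minimal sketch of that module, with channel sizes matching the flag defaults described later (`l_i`, `l_h`, `l_q`); it is illustrative only, not this repository's exact model code:

```python
import torch
import torch.nn as nn

class VIModule(nn.Module):
    """Illustrative sketch of the VIN value-iteration module (Tamar et al., 2016)."""

    def __init__(self, l_i=2, l_h=150, l_q=10):
        super(VIModule, self).__init__()
        self.h = nn.Conv2d(l_i, l_h, kernel_size=3, padding=1)  # feature layer
        self.r = nn.Conv2d(l_h, 1, kernel_size=1)               # reward map R
        self.q = nn.Conv2d(2, l_q, kernel_size=3, padding=1)    # Q from [R; V]

    def forward(self, x, k=10):
        r = self.r(self.h(x))                         # reward image from obstacle/goal maps
        v = torch.zeros_like(r)                       # initial value function V_0 = 0
        for _ in range(k):                            # k steps of value iteration
            q = self.q(torch.cat([r, v], dim=1))      # Q(s, a): one channel per action
            v, _ = torch.max(q, dim=1, keepdim=True)  # V(s) = max_a Q(s, a)
        return v
```

The full network then adds a policy head that reads out the Q values at the agent's current (x, y) position to produce action logits.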
This repository requires the following packages:
- SciPy >= 0.19.0
- Python >= 2.7 (if using Python 3.x: python3-tk should be installed)
- NumPy >= 1.12.1
- Matplotlib >= 2.0.0
- PyTorch >= 0.1.11
Use `pip` to install the necessary dependencies:

```
pip install -U -r requirements.txt
```
Note that PyTorch cannot be installed directly from PyPI; refer to http://pytorch.org/ for custom installation instructions specific to your needs.
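For reference, the contents of requirements.txt are presumably just the pip-installable dependencies from the list above (an assumption based on that list; PyTorch is excluded because it is installed separately, as noted):

```
scipy>=0.19.0
numpy>=1.12.1
matplotlib>=2.0.0
```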
To train the network on each of the included gridworld sizes:

```
python train.py --datafile dataset/gridworld_8x8.npz --imsize 8 --lr 0.005 --epochs 30 --k 10 --batch_size 128
python train.py --datafile dataset/gridworld_16x16.npz --imsize 16 --lr 0.002 --epochs 30 --k 20 --batch_size 128
python train.py --datafile dataset/gridworld_28x28.npz --imsize 28 --lr 0.002 --epochs 30 --k 36 --batch_size 128
```
Flags:
- `datafile`: The path to the data file.
- `imsize`: The size of the input images. One of: [8, 16, 28]
- `lr`: Learning rate with the RMSprop optimizer. Recommended: [0.01, 0.005, 0.002, 0.001]
- `epochs`: Number of epochs to train. Default: 30
- `k`: Number of value iterations. Recommended: [10 for 8x8, 20 for 16x16, 36 for 28x28]
- `l_i`: Number of channels in the input layer. Default: 2, i.e. the obstacle image and the goal image.
- `l_h`: Number of channels in the first convolutional layer. Default: 150, as described in the paper.
- `l_q`: Number of channels in the q layer (~actions) in the VI module. Default: 10, as described in the paper.
- `batch_size`: Batch size. Default: 128
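If you want to script sweeps over these flags, they map onto a standard argparse setup. A minimal sketch of how train.py likely wires them up, using the defaults and recommendations above (the actual parser may differ in details):

```python
import argparse

# Sketch of an argument parser matching the training flags above.
parser = argparse.ArgumentParser(description='Train a VIN on gridworld data')
parser.add_argument('--datafile', type=str, default='dataset/gridworld_8x8.npz')
parser.add_argument('--imsize', type=int, default=8, choices=[8, 16, 28])
parser.add_argument('--lr', type=float, default=0.005)      # RMSprop learning rate
parser.add_argument('--epochs', type=int, default=30)
parser.add_argument('--k', type=int, default=10)            # value-iteration steps
parser.add_argument('--l_i', type=int, default=2)           # input channels
parser.add_argument('--l_h', type=int, default=150)         # first conv channels
parser.add_argument('--l_q', type=int, default=10)          # q-layer channels
parser.add_argument('--batch_size', type=int, default=128)
args = parser.parse_args()
```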
To evaluate a trained network:

```
python test.py --weights trained/vin_8x8.pth --imsize 8 --k 10
python test.py --weights trained/vin_16x16.pth --imsize 16 --k 20
python test.py --weights trained/vin_28x28.pth --imsize 28 --k 36
```
To visualize the optimal and predicted paths, simply pass `--plot`.
Flags:
- `weights`: Path to the trained weights.
- `imsize`: The size of the input images. One of: [8, 16, 28]
- `plot`: If supplied, the optimal and predicted paths will be plotted.
- `k`: Number of value iterations. Recommended: [10 for 8x8, 20 for 16x16, 36 for 28x28]
- `l_i`: Number of channels in the input layer. Default: 2, i.e. the obstacle image and the goal image.
- `l_h`: Number of channels in the first convolutional layer. Default: 150, as described in the paper.
- `l_q`: Number of channels in the q layer (~actions) in the VI module. Default: 10, as described in the paper.
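Programmatically, loading a checkpoint for evaluation follows the usual PyTorch pattern. A minimal sketch, assuming a hypothetical `VIN` model class exposed by the repository with the constructor arguments above:

```python
import torch
from model import VIN  # hypothetical import path; adjust to the repository layout

# Rebuild the network with the same channel sizes used at training time,
# load the trained weights, and switch to evaluation mode.
model = VIN(l_i=2, l_h=150, l_q=10)
model.load_state_dict(torch.load('trained/vin_8x8.pth'))
model.eval()
```

(If the checkpoint was saved as a whole pickled model rather than a state dict, `torch.load` alone returns the model object.)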
(Sample visualizations: two example rollouts each for the 8x8, 16x16, and 28x28 gridworlds, showing the optimal and predicted paths.)
Each data sample consists of an obstacle image and a goal image, followed by the (x, y) coordinates of the current state in the gridworld.
Dataset size | 8x8 | 16x16 | 28x28 |
---|---|---|---|
Train set | 81337 | 456309 | 1529584 |
Test set | 13846 | 77203 | 251755 |
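To inspect the bundled data directly, the .npz archives can be opened with NumPy. Since the array names inside the archive are not documented here, this sketch lists them rather than assuming them:

```python
import numpy as np

# Open one of the bundled gridworld datasets and list its contents.
data = np.load('dataset/gridworld_8x8.npz')
for name in data.files:
    print(name, data[name].shape, data[name].dtype)
```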
The datasets (8x8, 16x16, and 28x28) included in this repository can be reproduced using the `dataset/make_training_data.py` script. Note that this script is not optimized, runs rather slowly, and uses a lot of memory :D
This is the success rate from rollouts of the learned policy in the environment (taken over 5000 randomly generated domains).
Success Rate | 8x8 | 16x16 | 28x28 |
---|---|---|---|
PyTorch | 99.69% | 96.99% | 91.07% |
NOTE: This is the accuracy on the test set. It differs from the table in the paper, which reports the success rate from rollouts of the learned policy in the environment.
Test Accuracy | 8x8 | 16x16 | 28x28 |
---|---|---|---|
PyTorch | 99.83% | 94.84% | 88.54% |