This repository contains PyTorch-based recurrent neural networks trained on the TIMIT dataset. The model outputs frame-level phone predictions and is trained with CTC loss. The following architectures are implemented: LSTM, GRU, TCNN (and their bidirectional versions). A custom implementation of the recurrent cells is also provided, which makes it easy to modify the core equations.
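For orientation, here is a minimal sketch of this kind of setup: a bidirectional GRU producing frame-level phone log-probabilities, trained with CTC loss. All sizes and the blank-index convention are illustrative assumptions, not the repository's actual values.

```python
import torch
import torch.nn as nn

# Sketch of a bidirectional GRU emitting frame-level phone log-probabilities.
# Sizes (79 features, 384 hidden units, 39 phones) are illustrative only.
class PhoneRNN(nn.Module):
    def __init__(self, num_features=79, hidden=384, layers=5, num_phones=39):
        super().__init__()
        self.rnn = nn.GRU(num_features, hidden, num_layers=layers,
                          batch_first=True, bidirectional=True)
        # +1 output class for the CTC blank symbol (index 0 in this sketch)
        self.fc = nn.Linear(2 * hidden, num_phones + 1)

    def forward(self, x):                         # x: (batch, time, features)
        out, _ = self.rnn(x)
        return self.fc(out).log_softmax(dim=-1)   # (batch, time, phones + 1)

model = PhoneRNN()
log_probs = model(torch.randn(2, 100, 79)).transpose(0, 1)  # CTC wants (T, N, C)
targets = torch.randint(1, 40, (2, 30))           # phone labels 1..39; 0 = blank
loss = nn.CTCLoss(blank=0)(log_probs, targets,
                           torch.full((2,), 100), torch.full((2,), 30))
```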
Folders:
- architectures/: various RNN architectures.
- models/: the dumped model checkpoints used during training and inference.
- pickle/: the insertion, deletion and substitution probabilities.
- plots/: the train/test loss vs. epoch figures and the phone error rate (PER) on the TIMIT test set as a function of epoch.
Files:
- config.yaml: contains all hyperparameters used for training/inference.
- beam_search.py: implementation of beam search.
- metadata.py: reads the TIMIT dataset and returns a list of feature vectors and ground truth phones for each recording.
- dataloader.py: pads the sequences and returns the batches used for training/testing (see the padding sketch after this list).
- dl_model.py: contains the actual train/infer functions; the entry point for training the model.
- utils.py: some common functions which are used throughout the project.
- hypo_search.py: contains functions that traverse the lattice generated by the RNN and pick out the best subsequence given a target sequence. Currently only the top-5 lattices are checked.
- extract_q_values.py: computes the per-phone Q-values specified in the paper, which are used as thresholds during inference.
- infer.py: specify keywords here and carry out a grid search over the hyperparameters.
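As a rough illustration of the padding step performed by dataloader.py, here is a minimal collate-function sketch; the assumption that each dataset item is a (features, labels) pair of variable-length tensors is mine, not the repository's:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Pad a batch of variable-length (features, phone_labels) pairs so they can
# be stacked into fixed-size tensors; true lengths are kept for CTC loss.
def collate(batch):
    feats, labels = zip(*batch)
    feat_lens = torch.tensor([f.shape[0] for f in feats])
    label_lens = torch.tensor([lab.shape[0] for lab in labels])
    feats = pad_sequence(feats, batch_first=True)    # (batch, max_T, features)
    labels = pad_sequence(labels, batch_first=True)  # (batch, max_L)
    return feats, labels, feat_lens, label_lens
```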
For training:
- Set the path to the parent directory of the TIMIT dataset in config.yaml (config['dir']['dataset']).
- Customise the hyperparameters in config.yaml.
- Run the function train in dl_model.py.
- Models are periodically dumped in the following folder: models/<name_of_model>_<number_of_layers>_<number_of_hidden_units>_<number_of_audio_features> e.g. GRU_5_384_79 for a 5-layer GRU with 384 hidden units and 79 audio features.
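A hypothetical helper mirroring that naming convention, purely for illustration (the repository builds this path internally):

```python
# Build the dump directory name described above from the model settings.
def model_dir(name: str, num_layers: int, hidden_units: int, num_features: int) -> str:
    return f"models/{name}_{num_layers}_{hidden_units}_{num_features}"

model_dir("GRU", 5, 384, 79)  # -> 'models/GRU_5_384_79'
```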
For inference:
- Ensure that the models/<model_name> folder (named as described above) contains the pre-trained model.
- Run the infer function in dl_model.py, passing a list of file paths to the .wav files to be run through the model. The function returns a list of tensors, each of shape time_steps x number_of_phones.
Check the commented code at the very end of dl_model.py for an illustration of how the results are generated.
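A minimal usage sketch of such a call, assuming infer is importable from dl_model; the .wav paths are placeholders:

```python
from dl_model import infer  # assumed to be exposed at module level

# Pass the recordings through the pre-trained model.
outputs = infer(['recording_1.wav', 'recording_2.wav'])
for probs in outputs:
    print(probs.shape)  # (time_steps, number_of_phones)
```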