Does CrabNet use the validation data to improve the model? #15

Open
sgbaird opened this issue Sep 23, 2021 · 4 comments

sgbaird (Collaborator) commented Sep 23, 2021

For example, as it progresses through the epochs, would the results change at all if dummy validation data were supplied instead of the "true" validation data?

anthony-wang (Owner) commented Sep 23, 2021 via email

sgbaird (Collaborator, Author) commented Sep 23, 2021

Good point. I think I've been a bit confused about whether people sometimes use "validation data" to adjust the model's hyperparameters (in some sense I think we all do this, just a bit more manually, as we run models while debugging/developing), hence the idea of a third "test" set. Just wanted to make sure. I agree with your definition. Thanks, Anthony!

sgbaird (Collaborator, Author) commented Feb 7, 2022

OK, I think my suspicion is confirmed: the validation data is used to improve CrabNet's training results. I had been wondering where the validation set might be used during the training process, and I found a spot in the training code that appears to use the validation data to improve generalization.

CrabNet/crabnet/model.py

Lines 115 to 119 in e884482

if learning_time:
    act_v, pred_v, _, _ = self.predict(self.data_loader)
    mae_v = mean_absolute_error(act_v, pred_v)
    self.optimizer.update_swa(mae_v)
    minima.append(self.optimizer.minimum_found)

Near the end of training (I'm not sure why this happens on the 2nd-to-last epoch instead of the last epoch), it goes into L273:

CrabNet/crabnet/model.py

Lines 272 to 273 in e884482

if not (self.optimizer.discard_count >= self.discard_n):
    self.optimizer.swap_swa_sgd()

See "PyTorch 1.6 now includes Stochastic Weight Averaging" (blog post).
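For context, here is a minimal sketch of the stock SWA pattern from that blog post, using torch.optim.swa_utils. This is the generic PyTorch API, not CrabNet's custom optimizer wrapper; the key difference is that CrabNet's update_swa(mae_v) gates the weight averaging on the validation MAE, which is exactly where the validation set enters the picture.

```python
# Minimal sketch of stock SWA (PyTorch >= 1.6); illustrative only.
# CrabNet's optimizer wrapper differs: update_swa(mae_v) decides whether to
# fold the current weights into the running average based on validation MAE.
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel, SWALR

model = nn.Linear(10, 1)                         # stand-in for CrabNet
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
swa_model = AveragedModel(model)                 # keeps a running average of the weights
swa_scheduler = SWALR(optimizer, swa_lr=0.05)
x, y = torch.randn(64, 10), torch.randn(64, 1)   # dummy data

for epoch in range(20):
    loss = nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch >= 10:                              # start averaging partway through training
        swa_model.update_parameters(model)       # unconditional in stock SWA (no val gate)
        swa_scheduler.step()

# swa_model now holds the averaged weights (roughly what swap_swa_sgd() swaps in for CrabNet)
```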

Recently, I ran into an error where pred_v was all np.nan for one of the crabnet-hyperparameter combinations, so I decided to check the size of pred_v using the original repo and example_materials_property. pred_v turned out to be the same size as the validation data (727 points, vs. a training batch_size of 256), and this was consistent across consecutive stops at the same breakpoint. The total number of training data points is 3433 in this case.
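In other words, the prediction array at that breakpoint has one entry per row of val.csv rather than per training batch. A quick way to see where the 727 comes from (a sketch, assuming the train.csv/val.csv layout of example_materials_property; paths are illustrative):

```python
# Sanity check: compare the sizes seen at the breakpoint against the CSVs.
import pandas as pd

n_train = len(pd.read_csv("train.csv"))  # 3433 rows in this example
n_val = len(pd.read_csv("val.csv"))      # 727 rows -> matches len(pred_v) at the breakpoint
print(n_train, n_val)
```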

@anthony-wang @Kaaiian

sgbaird (Collaborator, Author) commented Feb 7, 2022

Empirical tests

Empirical tests also seem to confirm this (in line with my original question about using "dummy" validation data). 42, 43, and 44 refer to the torch seeds used; the three numbers that follow the colon (:) are the train/val/test MAEs.

1. Using validation data to calculate mae for update_swa

seed: train MAE, val MAE, test MAE
42: 6.53, 10.7, 9.82
43: 6.38, 10.3, 9.63
44: 6.49, 10.5, 9.95

2. Shuffling validation predictions just prior to swa_update

Epoch 39 failed to improve.
Discarded: 1/3 weight updates ♻🗑️

42: 9.77, 11.9, 11.4
43: 9.19, 11.5, 10.7
44: 9.77, 11.9, 11.4
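For reference, a minimal sketch of the perturbation used in (2), written against the model.py snippet quoted earlier (the exact shuffling call I used may have differed slightly):

```python
# Sketch of experiment (2): shuffle the validation predictions before computing
# the MAE that gates the SWA update, so the metric carries no real signal.
import numpy as np
from sklearn.metrics import mean_absolute_error

act_v, pred_v, _, _ = self.predict(self.data_loader)  # as in the quoted snippet
pred_v = np.random.permutation(pred_v)                 # decouple predictions from targets
mae_v = mean_absolute_error(act_v, pred_v)
self.optimizer.update_swa(mae_v)                       # SWA now sees a meaningless MAE
```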

3. Using training data to calculate mae for update_swa

44: 6.35, 10.6, 9.85

4. Add val.csv data to train.csv and leave val.csv intact, use train.csv for update_swa

This is relevant to #19

42: 6.49, 6.96, 9.39
43: 6.58, 6.98, 9.38
44: 6.68, 6.89, 9.64

5. Add val.csv data to train.csv and drop all but 1 datapoint from val.csv, use train.csv for update_swa

This is just to make sure that val.csv isn't affecting the test score in ways other than what I've described here. Since the results match those immediately above, I think SWA is the only place where val.csv affects the training process.

42: 6.49, 27.3, 9.39
43: 6.58, 26, 9.38
44: 6.68, 22.1, 9.64
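And a rough sketch of the dataset edits behind (4) and (5), assuming the train.csv/val.csv layout used by example_materials_property (paths are illustrative):

```python
# Sketch of the dataset edits for experiments (4) and (5).
import pandas as pd

train = pd.read_csv("train.csv")
val = pd.read_csv("val.csv")

# (4): fold the validation rows into the training set, leave val.csv intact
pd.concat([train, val], ignore_index=True).to_csv("train.csv", index=False)

# (5): as above, but keep only a single row in val.csv so that the validation
# set cannot meaningfully steer the SWA updates (or anything else)
val.iloc[:1].to_csv("val.csv", index=False)
```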

Takeaways

  • The validation data affects the training process through SWA.
  • The average test MAE is somewhat lower for (4) than for (1) (9.47 vs. 9.80), but the difference is marginal enough that it might go the other way for other datasets. In general, I would think the current setup (1) is probably more robust to extrapolation. Adding the 25% validation data back into the training data in the matbench benchmark may not necessarily improve model performance (see "CrabNet matbench results - possibly neglecting 25% of the training data it could have used" #19).
