Fairness through Aleatoric Uncertainty (JAX Bayes by Backprop) #17157
aniquetahir asked this question in Show and tell
I would like to introduce our paper on fairness through aleatoric uncertainty, which was just accepted at CIKM:
https://arxiv.org/abs/2304.03646
JAX played a vital role in making this possible. Its speed let us prototype faster, and it made handling the BNN parameters very intuitive. Our JAX/Haiku implementation of Bayes by Backprop can be found here:
https://github.com/aniquetahir/GAIA/blob/master/utils/jax/models/bnn.py
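For readers new to Bayes by Backprop, the core idea fits in a few dozen lines of plain JAX. The sketch below is a simplified illustration, not the code from the GAIA repository. It assumes a mean-field Gaussian posterior over every weight (a mean `mu` and an unconstrained `rho`, with `sigma = softplus(rho)`), samples weights with the reparameterization trick, and minimizes the negative ELBO. As a simplification, it uses a single `N(0, 1)` prior with a closed-form KL term instead of the scale-mixture prior and Monte Carlo KL from the original Bayes by Backprop paper (Blundell et al., 2015). All function names, layer sizes, and hyperparameters are hypothetical.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical value, plain SGD


def init_variational_params(key, sizes):
    """Mean-field Gaussian posterior: a mean and an unconstrained rho per weight."""
    params = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        key, k_w = jax.random.split(key)
        params.append({
            "w_mu": jax.random.normal(k_w, (d_in, d_out)) / jnp.sqrt(d_in),
            "w_rho": jnp.full((d_in, d_out), -5.0),  # small initial std
            "b_mu": jnp.zeros(d_out),
            "b_rho": jnp.full((d_out,), -5.0),
        })
    return params


def sigma(rho):
    return jax.nn.softplus(rho)  # log(1 + exp(rho)) keeps the std positive


def sample_weights(key, params):
    """Reparameterization trick: w = mu + sigma * eps with eps ~ N(0, 1)."""
    sampled = []
    for layer in params:
        key, k_w, k_b = jax.random.split(key, 3)
        w = layer["w_mu"] + sigma(layer["w_rho"]) * jax.random.normal(k_w, layer["w_mu"].shape)
        b = layer["b_mu"] + sigma(layer["b_rho"]) * jax.random.normal(k_b, layer["b_mu"].shape)
        sampled.append((w, b))
    return sampled


def forward(weights, x):
    for w, b in weights[:-1]:
        x = jnp.tanh(x @ w + b)
    w, b = weights[-1]
    return (x @ w + b).squeeze(-1)  # logits for binary classification


def kl_divergence(params, prior_std=1.0):
    """Closed-form KL(q || p) between diagonal Gaussians q and p = N(0, prior_std^2)."""
    total = 0.0
    for layer in params:
        for mu, rho in ((layer["w_mu"], layer["w_rho"]), (layer["b_mu"], layer["b_rho"])):
            s = sigma(rho)
            total += jnp.sum(jnp.log(prior_std / s) + (s**2 + mu**2) / (2 * prior_std**2) - 0.5)
    return total


def neg_elbo(params, key, x, y, n_samples=4):
    """KL term plus a Monte Carlo estimate of the expected negative log-likelihood."""
    keys = jax.random.split(key, n_samples)

    def nll(k):
        logits = forward(sample_weights(k, params), x)
        return jnp.sum(jnp.logaddexp(0.0, logits) - y * logits)  # Bernoulli NLL from logits

    expected_nll = jnp.mean(jax.vmap(nll)(keys))
    # Divide by the dataset size so the learning rate does not depend on N.
    return (expected_nll + kl_divergence(params)) / x.shape[0]


@jax.jit
def train_step(params, key, x, y):
    loss, grads = jax.value_and_grad(neg_elbo)(params, key, x, y)
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss


def predict_proba(params, key, x, n_samples=32):
    """Posterior predictive: average sigmoid outputs over sampled networks."""
    keys = jax.random.split(key, n_samples)
    probs = jax.vmap(lambda k: jax.nn.sigmoid(forward(sample_weights(k, params), x)))(keys)
    return probs.mean(axis=0)


# Toy usage on linearly separable 2-D data.
key = jax.random.PRNGKey(0)
k_data, k_init, key = jax.random.split(key, 3)
x = jax.random.normal(k_data, (200, 2))
y = (x[:, 0] + x[:, 1] > 0).astype(jnp.float32)
params = init_variational_params(k_init, [2, 16, 1])
losses = []
for _ in range(1000):
    key, k_step = jax.random.split(key)
    params, loss = train_step(params, k_step, x, y)
    losses.append(float(loss))
probs = predict_proba(params, key, x)
accuracy = float(jnp.mean((probs > 0.5) == (y > 0.5)))
```

With minibatches, the KL term would normally be scaled by the fraction of the dataset each batch represents, so that it is counted once per epoch rather than once per batch.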
To the best of my knowledge, our implementation is the fastest one available for this approach. It could also serve as a template for a more generic version, since some fairness-specific components are built in. I embedded the JIT compilation in a class method, so it's easy to swap out different components by subclassing.
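The idea of building the `jax.jit`-compiled training step inside a class method, so that subclasses can replace individual pieces, can be sketched roughly as follows. This is a minimal illustration under stated assumptions and is not the structure of the GAIA code: the `Trainer` and `L2Trainer` classes, the `penalty` hook, and the linear model are hypothetical stand-ins chosen to keep the example small.

```python
import jax
import jax.numpy as jnp


class Trainer:
    """Hypothetical base trainer: the jitted step is built in a method, so subclasses can swap pieces."""

    def __init__(self, lr=0.1):
        self.lr = lr
        # Built after subclass attributes are set, so overridden methods are traced into the step.
        self.step = self._make_step()

    def predict(self, params, x):
        return x @ params["w"] + params["b"]

    def data_loss(self, params, x, y):
        return jnp.mean((self.predict(params, x) - y) ** 2)

    def penalty(self, params, x, y):
        return 0.0  # no extra term in the base class

    def loss(self, params, x, y):
        return self.data_loss(params, x, y) + self.penalty(params, x, y)

    def _make_step(self):
        lr = self.lr

        @jax.jit
        def step(params, x, y):
            # self.loss dispatches to whatever the subclass overrides at trace time.
            loss, grads = jax.value_and_grad(self.loss)(params, x, y)
            new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
            return new_params, loss

        return step


class L2Trainer(Trainer):
    """Swaps in a penalty term without touching the compiled training loop."""

    def __init__(self, lr=0.1, weight=1.0):
        self.weight = weight  # must be set before super().__init__ builds the jitted step
        super().__init__(lr)

    def penalty(self, params, x, y):
        return self.weight * jnp.sum(params["w"] ** 2)


def fit(trainer, params, x, y, steps=500):
    for _ in range(steps):
        params, loss = trainer.step(params, x, y)
    return params, loss


# Toy usage: noiseless linear regression.
x = jax.random.normal(jax.random.PRNGKey(0), (100, 2))
y = x @ jnp.array([2.0, -1.0]) + 0.5
init = {"w": jnp.zeros(2), "b": jnp.array(0.0)}
base_params, base_loss = fit(Trainer(), init, x, y)
l2_params, l2_loss = fit(L2Trainer(weight=1.0), init, x, y)
```

One design consequence worth noting: because `jax.jit` traces the methods once, changing an attribute such as `self.weight` after the step has been compiled does not change its behavior. Creating a new instance (or rebuilding the step) is the safe way to change a component.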
Hopefully the JAX folks don't forget us non-TPU academic users (at least not until they offer me a job at Google 😅).