Simple-minded and (hopefully) easy-to-understand implementation of automatic differentiation (AD). The implementation is written in SML for two reasons:
- To make sure that no "advanced" language features were used by accident (looking at you, Haskell). Thus, everything should be easy to port to other languages.
- To have relatively easy access to imperative features, while being explicit about when we use them: imperative features are used only where they are essential for the algorithm, never by accident. Still, using imperative features shouldn't be a big deal.
Both reasons serve to keep the focus on what is state of the art.
The code was mainly used for illustration in a number of talks introducing AD.
- Simple.sml: simple forward mode explained as the fusion of eval and diff. Doesn't depend on any other files.
- Expr.sml: utility modules declaring the AST data type for expressions, with parameterized types for the labels at variables and (sub-)expressions, plus helper functions for generating test expressions of varying size and for printing expressions. Used by Forward and Reverse.
- Forward.sml: declares functions for symbolic differentiation and for forward-mode AD of multivariate functions.
- Reverse.sml: declares three different implementations of reverse-mode AD of multivariate functions:
  - reverse: a fully functional implementation showing the three essential phases of reverse mode: evaluate and decorate the sub-expressions, push derivatives to the leaves, and sum all leaves per variable.
  - reverse_fused: fuses the last two phases and keeps derivatives in an array rather than in the labels at the leaves.
  - reverse_imp: doesn't rebuild the AST during the first phase, but updates the imperative references in the labels at the sub-expressions. Avoiding rebuilding the AST is essential, because rebuilding destroys sharing and can thereby change the asymptotic complexity.
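The idea behind Simple.sml, forward mode as the fusion of eval and diff, can be sketched as follows. This is a minimal illustration, not the code in Simple.sml: the single-variable expression type (Const, X, Add, Mul) and the names eval, diff and evalDiff are assumptions. diff builds the symbolic derivative as a new expression, while evalDiff computes the value and the derivative together in one traversal without ever building that expression.

```sml
(* Hypothetical single-variable expression type (not the one in Simple.sml). *)
datatype expr = Const of real
              | X
              | Add of expr * expr
              | Mul of expr * expr

(* eval: the value of the expression at x. *)
fun eval x (Const c) = c
  | eval x X = x
  | eval x (Add (a, b)) = eval x a + eval x b
  | eval x (Mul (a, b)) = eval x a * eval x b

(* diff: the symbolic derivative with respect to X, as a new expression. *)
fun diff (Const _) = Const 0.0
  | diff X = Const 1.0
  | diff (Add (a, b)) = Add (diff a, diff b)
  | diff (Mul (a, b)) = Add (Mul (diff a, b), Mul (a, diff b))

(* evalDiff fuses eval and diff: one traversal returns the pair
   (value, derivative) at x, without building the derivative expression. *)
fun evalDiff x (Const c) = (c, 0.0)
  | evalDiff x X = (x, 1.0)
  | evalDiff x (Add (a, b)) =
      let val (va, da) = evalDiff x a
          val (vb, db) = evalDiff x b
      in (va + vb, da + db) end
  | evalDiff x (Mul (a, b)) =
      let val (va, da) = evalDiff x a
          val (vb, db) = evalDiff x b
      in (va * vb, da * vb + va * db) end

(* f(x) = x*x + 3*x, so f(2) = 10 and f'(2) = 2*2 + 3 = 7 *)
val e = Add (Mul (X, X), Mul (Const 3.0, X))
val (v, d) = evalDiff 2.0 e          (* v = 10.0, d = 7.0 *)
val d' = eval 2.0 (diff e)           (* also 7.0 *)
```

The fused version never allocates the (possibly much larger) derivative expression that diff produces; it only carries a pair of reals up the tree.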
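The three phases that reverse is described as implementing can be sketched in a purely functional style. All names here (expr, dexpr, decorate, push, sumLeaves, gradient) are hypothetical; instead of the parameterized labels of Expr.sml, the sketch uses a separate decorated tree type, and variables are indexed by int and looked up through an environment function.

```sml
(* Hypothetical multivariate AST; variables are indexed by int. *)
datatype expr = Const of real
              | Var of int
              | Add of expr * expr
              | Mul of expr * expr

(* Decorated tree: every node carries the value computed in phase 1. *)
datatype dexpr = DConst of real
               | DVar of int * real
               | DAdd of real * dexpr * dexpr
               | DMul of real * dexpr * dexpr

fun value (DConst c) = c
  | value (DVar (_, v)) = v
  | value (DAdd (v, _, _)) = v
  | value (DMul (v, _, _)) = v

(* Phase 1: evaluate and decorate every sub-expression (builds a new tree). *)
fun decorate env (Const c) = DConst c
  | decorate env (Var i) = DVar (i, env i)
  | decorate env (Add (a, b)) =
      let val (da, db) = (decorate env a, decorate env b)
      in DAdd (value da + value db, da, db) end
  | decorate env (Mul (a, b)) =
      let val (da, db) = (decorate env a, decorate env b)
      in DMul (value da * value db, da, db) end

(* Phase 2: push the adjoint from the root to the leaves, yielding one
   (variable, partial derivative) pair per leaf occurrence. *)
fun push adj (DConst _) acc = acc
  | push adj (DVar (i, _)) acc = (i, adj) :: acc
  | push adj (DAdd (_, a, b)) acc = push adj a (push adj b acc)
  | push adj (DMul (_, a, b)) acc =
      push (adj * value b) a (push (adj * value a) b acc)

(* Phase 3: sum the contributions of all leaves per variable 0 .. n-1. *)
fun sumLeaves n leaves =
  List.tabulate (n, fn j =>
    List.foldl (fn ((i, d), s) => if i = j then s + d else s) 0.0 leaves)

fun gradient n env e = sumLeaves n (push 1.0 (decorate env e) [])

(* f(x0, x1) = x0 * x1 + x0 at (3, 5): df/dx0 = x1 + 1, df/dx1 = x0 *)
val xs = Vector.fromList [3.0, 5.0]
val e = Add (Mul (Var 0, Var 1), Var 0)
val g = gradient 2 (fn i => Vector.sub (xs, i)) e   (* [6.0, 3.0] *)
```

Phase 3 in this sketch scans the leaf list once per variable; accumulating into an array while pushing, as reverse_fused is described as doing, makes each leaf cost constant time instead.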
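An example session in Moscow ML 2.10, compiling the modules and timing the forward-mode implementation and the three reverse-mode implementations with Expr.fibTime: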
$ mosmlc -c -toplevel Expr.sml Forward.sml Reverse.sml
$ mosml
Moscow ML version 2.10
Enter `quit();' to quit.
- load"Reverse";
> val it = () : unit
- Expr.fibTime Forward.forward 28 100;
User: 6.582 System: 0.042 GC: 0.536 Real: 6.629
> val it = 196418.0 : real
- Expr.fibTime Reverse.reverse 28 100;
User: 3.495 System: 0.016 GC: 1.588 Real: 3.516
> val it = 196418.0 : real
- Expr.fibTime Reverse.reverse_fused 28 100;
User: 1.242 System: 0.004 GC: 0.763 Real: 1.247
> val it = 196418.0 : real
- Expr.fibTime Reverse.reverse_imp 28 100;
User: 0.411 System: 0.001 GC: 0.006 Real: 0.412
> val it = 196418.0 : real
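In the session above, reverse_imp is the fastest of the four implementations. The sketch below illustrates the idea it is described with: the labels at sub-expressions are references updated in place, so no new tree is built and a shared sub-expression stays a single node. It is a hypothetical variant, not the code in Reverse.sml: it additionally uses a seen flag so that each shared node is evaluated once, records the nodes in reverse post-order, and then propagates adjoints in that order, accumulating variable derivatives in an array. The names (node, oper, mk, forward, backward, gradient, fibExpr) are assumptions, and the labels are assumed fresh, since the sketch does not reset them.

```sml
(* Hypothetical node type: every sub-expression owns mutable labels. *)
datatype node = Node of {kind : oper, value : real ref,
                         adj : real ref, seen : bool ref}
and oper = Const of real
         | Var of int
         | Add of node * node
         | Mul of node * node

(* Fresh labels: value and adjoint 0.0, not yet visited. *)
fun mk k = Node {kind = k, value = ref 0.0, adj = ref 0.0, seen = ref false}
fun getValue (Node {value, ...}) = !value
fun addAdj (Node {adj, ...}) d = adj := !adj + d

(* Forward pass: store each node's value in its ref, visiting a shared node
   only once; nodes are consed on when finished, so the resulting list has
   every node before its children (reverse post-order). *)
fun forward env (n as Node {kind, value, seen, ...}) tape =
  if !seen then tape
  else
    let val tape' =
          case kind of
            Const c => (value := c; tape)
          | Var i => (value := env i; tape)
          | Add (a, b) =>
              let val t = forward env b (forward env a tape)
              in value := getValue a + getValue b; t end
          | Mul (a, b) =>
              let val t = forward env b (forward env a tape)
              in value := getValue a * getValue b; t end
    in seen := true; n :: tape' end

(* Backward pass: each node's adjoint is complete when it is reached, so it
   is pushed to the children once; variable adjoints go into grads. *)
fun backward grads tape =
  List.app
    (fn Node {kind, adj, ...} =>
       (case kind of
          Const _ => ()
        | Var i => Array.update (grads, i, Array.sub (grads, i) + !adj)
        | Add (a, b) => (addAdj a (!adj); addAdj b (!adj))
        | Mul (a, b) => (addAdj a (!adj * getValue b);
                         addAdj b (!adj * getValue a))))
    tape

(* Assumes fresh labels; nothing is reset between calls. *)
fun gradient nvars env root =
  let val tape = forward env root []
      val grads = Array.array (nvars, 0.0)
  in addAdj root 1.0; backward grads tape; grads end

(* Hypothetical test DAG: e0 = e1 = x, e(k) = e(k-2) + e(k-1). *)
fun fibExpr n =
  let val x = mk (Var 0)
      fun go (k, a, b) = if k = 0 then a else go (k - 1, b, mk (Add (a, b)))
  in go (n, x, x) end

val root = fibExpr 60
val g = gradient 1 (fn _ => 1.0) root  (* Array.sub (g, 0) = 2504730781961.0 *)
```

Here fibExpr 60 builds a DAG of 61 nodes, and the gradient is computed by visiting each of them once, whereas traversing the same expression as a tree would reach roughly 2.5 × 10^12 leaves.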