-
Notifications
You must be signed in to change notification settings - Fork 38
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make autodiff insertion more versatile #181
base: master
Are you sure you want to change the base?
Conversation
Codecov ReportBase: 97.85% // Head: 97.86% // Increases project coverage by
Additional details and impacted files@@ Coverage Diff @@
## master #181 +/- ##
==========================================
+ Coverage 97.85% 97.86% +0.01%
==========================================
Files 16 16
Lines 1211 1219 +8
==========================================
+ Hits 1185 1193 +8
Misses 26 26
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. ☔ View full report at Codecov. |
Ready for review, @fredrikekre or @KristofferC |
Found macro implement_gradient(f, f_dfdx)
return :($(esc(f))(x :: Union{AllTensors{<:Any, <:Dual}, Dual}) = _propagate_gradient($(esc(f_dfdx)), x))
end Still less specific than |
@@ -250,11 +250,26 @@ be of symmetric type | |||
|
|||
""" | |||
macro implement_gradient(f, f_dfdx) | |||
return :($(esc(f))(x :: Union{AbstractTensor{<:Any, <:Any, <:Dual}, Dual}) = _propagate_gradient($(esc(f_dfdx)), x)) | |||
return :($(esc(f))(x :: Union{AbstractTensor{<:Any, <:Any, <:Dual}, Dual}) = propagate_gradient($(esc(f_dfdx)), x)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return :($(esc(f))(x :: Union{AbstractTensor{<:Any, <:Any, <:Dual}, Dual}) = propagate_gradient($(esc(f_dfdx)), x)) | |
return :($(esc(f))(x :: Union{AbstractTensor{<:Any, <:Any, <:Dual}, Dual}, args...) = propagate_gradient($(esc(f_dfdx)), x, args...)) |
Note from #197 , requires corresponding update to propagate_gradient
Two problems are clear from #179
propagate_gradient
function to give the user full control over type specification for dispatching with dual numbers. This makes it possible to solve point 1 above, and this is also documented with an example.