-
Notifications
You must be signed in to change notification settings - Fork 6
Add Mooncake and Enzyme tests for Diagonal #179
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -15,19 +15,6 @@ using LinearAlgebra | |||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| Mooncake.tangent_type(::Type{<:MatrixAlgebraKit.AbstractAlgorithm}) = Mooncake.NoTangent | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(copy_input), Any, Any} | ||||||||||||||||||||||||||||||||||||||||||||||||
| function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDual) | ||||||||||||||||||||||||||||||||||||||||||||||||
| Ac = copy_input(Mooncake.primal(f_df), Mooncake.primal(A_dA)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| Ac_dAc = Mooncake.zero_fcodual(Ac) | ||||||||||||||||||||||||||||||||||||||||||||||||
| dAc = Mooncake.tangent(Ac_dAc) | ||||||||||||||||||||||||||||||||||||||||||||||||
| function copy_input_pb(::NoRData) | ||||||||||||||||||||||||||||||||||||||||||||||||
| Mooncake.increment!!(Mooncake.tangent(A_dA), dAc) | ||||||||||||||||||||||||||||||||||||||||||||||||
| return NoRData(), NoRData(), NoRData() | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
| return Ac_dAc, copy_input_pb | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(initialize_output), Any, Any, Any} | ||||||||||||||||||||||||||||||||||||||||||||||||
| # two-argument in-place factorizations like LQ, QR, EIG | ||||||||||||||||||||||||||||||||||||||||||||||||
| for (f!, f, pb, adj) in ( | ||||||||||||||||||||||||||||||||||||||||||||||||
| (:qr_full!, :qr_full, :qr_pullback!, :qr_adjoint), | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -53,13 +40,29 @@ for (f!, f, pb, adj) in ( | |||||||||||||||||||||||||||||||||||||||||||||||
| arg2c = copy(arg2) | ||||||||||||||||||||||||||||||||||||||||||||||||
| $f!(A, args, Mooncake.primal(alg_dalg)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| function $adj(::NoRData) | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(A, Ac) | ||||||||||||||||||||||||||||||||||||||||||||||||
| $pb(dA, A, (arg1, arg2), (darg1, darg2)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(arg1, arg1c) | ||||||||||||||||||||||||||||||||||||||||||||||||
| # DON'T copy Ac to A if A === one | ||||||||||||||||||||||||||||||||||||||||||||||||
| # of the output args -- this can | ||||||||||||||||||||||||||||||||||||||||||||||||
| # mess up the pullback because | ||||||||||||||||||||||||||||||||||||||||||||||||
| # generally the args are used there | ||||||||||||||||||||||||||||||||||||||||||||||||
| if !(A === arg1 || A === arg2) | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(A, Ac) | ||||||||||||||||||||||||||||||||||||||||||||||||
| $pb(dA, A, (arg1, arg2), (darg1, darg2)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||
| ΔA = zero(A) | ||||||||||||||||||||||||||||||||||||||||||||||||
| $pb(ΔA, A, (arg1, arg2), (darg1, darg2)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| dA .= ΔA | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+47
to
+54
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Again mostly a readability suggestion. (Note that I might again be missing something about the use of
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See comment above about
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion for the bottom part still stands though :)
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. NB that if
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Pushed a small cleanup here at least... |
||||||||||||||||||||||||||||||||||||||||||||||||
| if A === arg1 | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!(darg2) | ||||||||||||||||||||||||||||||||||||||||||||||||
| elseif A === arg2 | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!(darg1) | ||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!(darg1) | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!(darg2) | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(arg2, arg2c) | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!(darg1) | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!(darg2) | ||||||||||||||||||||||||||||||||||||||||||||||||
| return NoRData(), NoRData(), NoRData(), NoRData() | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(arg1, arg1c) | ||||||||||||||||||||||||||||||||||||||||||||||||
| return ntuple(Returns(NoRData()), 4) | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
| return args_dargs, $adj | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -140,9 +143,19 @@ for (f!, f, f_full, pb, adj) in ( | |||||||||||||||||||||||||||||||||||||||||||||||
| copy!(D, diagview(DV[1])) | ||||||||||||||||||||||||||||||||||||||||||||||||
| V = DV[2] | ||||||||||||||||||||||||||||||||||||||||||||||||
| function $adj(::NoRData) | ||||||||||||||||||||||||||||||||||||||||||||||||
| $pb(dA, A, DV, dD) | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(D, Dc) | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!(dD) | ||||||||||||||||||||||||||||||||||||||||||||||||
| if A !== D | ||||||||||||||||||||||||||||||||||||||||||||||||
| $pb(dA, A, DV, dD) | ||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||
| ΔA = zero(A) | ||||||||||||||||||||||||||||||||||||||||||||||||
| $pb(ΔA, A, DV, dD) | ||||||||||||||||||||||||||||||||||||||||||||||||
| dA .= A | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
| if A !== D | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!(dD) | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(D, Dc) | ||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(A, Ac) | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
| return NoRData(), NoRData(), NoRData(), NoRData() | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
| return D_dD, $adj | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -199,15 +212,27 @@ for f in (:eig, :eigh) | |||||||||||||||||||||||||||||||||||||||||||||||
| # not for nested structs with various fields (like Diagonal{Complex}) | ||||||||||||||||||||||||||||||||||||||||||||||||
| output_codual = Mooncake.zero_fcodual(output) | ||||||||||||||||||||||||||||||||||||||||||||||||
| function $f_adjoint!(dy::Tuple{NoRData, NoRData, <:Real}) | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(A, Ac) | ||||||||||||||||||||||||||||||||||||||||||||||||
| Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual) | ||||||||||||||||||||||||||||||||||||||||||||||||
| dDtrunc_, dVtrunc_, dϵ = Mooncake.tangent(output_codual) | ||||||||||||||||||||||||||||||||||||||||||||||||
| _warn_pullback_truncerror(dy[3]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| D′, dD′ = arrayify(Dtrunc, dDtrunc_) | ||||||||||||||||||||||||||||||||||||||||||||||||
| V′, dV′ = arrayify(Vtrunc, dVtrunc_) | ||||||||||||||||||||||||||||||||||||||||||||||||
| $f_trunc_pullback!(dA, A, (D′, V′), (dD′, dV′)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(DV[1], DVc[1]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(DV[2], DVc[2]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| D, dD = arrayify(DV[1], dDV[1]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| V, dV = arrayify(DV[2], dDV[2]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(A, Ac) | ||||||||||||||||||||||||||||||||||||||||||||||||
| if !(A === D || A === V) | ||||||||||||||||||||||||||||||||||||||||||||||||
| $f_trunc_pullback!(dA, A, (D′, V′), (dD′, dV′)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||
| ΔA = zero(A) | ||||||||||||||||||||||||||||||||||||||||||||||||
| $f_trunc_pullback!(ΔA, A, (D′, V′), (dD′, dV′)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| dA .= ΔA | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
| if A === D | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(DV[2], DVc[2]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(DV[1], DVc[1]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(DV[2], DVc[2]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!(dD′) | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!(dV′) | ||||||||||||||||||||||||||||||||||||||||||||||||
| return NoRData(), NoRData(), NoRData(), NoRData() | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -239,12 +264,22 @@ for f in (:eig, :eigh) | |||||||||||||||||||||||||||||||||||||||||||||||
| _warn_pullback_truncerror(dϵ) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # compute pullbacks | ||||||||||||||||||||||||||||||||||||||||||||||||
| $f_pullback!(dA, Ac, DV, dDVtrunc, ind) | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!.(dDVtrunc) # since this is allocated in this function this is probably not required | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| if !(A === DV[1] || A === DV[2]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| $f_pullback!(dA, Ac, DV, dDVtrunc, ind) | ||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||
| ΔA = zero(A) | ||||||||||||||||||||||||||||||||||||||||||||||||
| $f_pullback!(ΔA, Ac, DV, dDVtrunc, ind) | ||||||||||||||||||||||||||||||||||||||||||||||||
| dA .= ΔA | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
| # restore state | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(A, Ac) | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!.(DV, DVc) | ||||||||||||||||||||||||||||||||||||||||||||||||
| if A === DV[1] | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(DV[2], DVc[2]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!(dDV[2]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!.(DV, DVc) | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!.(dDV) | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| return ntuple(Returns(NoRData()), 4) | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -351,12 +386,23 @@ for f in (:eig, :eigh) | |||||||||||||||||||||||||||||||||||||||||||||||
| dDVtrunc = last.(arrayify.(DVtrunc, Mooncake.tangent(DVtrunc_dDVtrunc))) | ||||||||||||||||||||||||||||||||||||||||||||||||
| function $f_adjoint!(::NoRData) | ||||||||||||||||||||||||||||||||||||||||||||||||
| # compute pullbacks | ||||||||||||||||||||||||||||||||||||||||||||||||
| $f_pullback!(dA, Ac, DV, dDVtrunc, ind) | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!.(dDV) | ||||||||||||||||||||||||||||||||||||||||||||||||
| if !(A === DV[1] || A === DV[2]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| $f_pullback!(dA, Ac, DV, dDVtrunc, ind) | ||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||
| ΔA = zero(A) | ||||||||||||||||||||||||||||||||||||||||||||||||
| $f_pullback!(ΔA, Ac, DV, dDVtrunc, ind) | ||||||||||||||||||||||||||||||||||||||||||||||||
| dA .= ΔA | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| # restore state | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(A, Ac) | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!.(DV, DVc) | ||||||||||||||||||||||||||||||||||||||||||||||||
| if A === DV[1] | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!(DV[2], DVc[2]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!(dDV[2]) | ||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||
| copy!.(DV, DVc) | ||||||||||||||||||||||||||||||||||||||||||||||||
| zero!.(dDV) | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| return ntuple(Returns(NoRData()), 4) | ||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code-organization-wise, I am not a huge fan of this hardcoded special case for different types.
Why do you need the specific specialization here? Could it not just be
A.dval === D.dval?That being said, we should probably just add an implementation of
isalias(A, arg).Maybe something like
Base.mightaliascould be a more generic solution, or even a fallback definition, but that function is technically internal, and similar problems hold for checkingBase.dataids.I might actually be okay with depending on that though, that has been quite stable and does seem to me like a good way of going at this. ( and we already secretly use this in TensorOperations! )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For the
_valsmethods we can't doA.dval === D.dvalbecauseDis aVector, andAis aDiagonal, whosediagfield points toDThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah you are right, my bad! I think then I would even more think an
isalias(A, arg) = Base.mightalias(A, arg)is the right abstraction (I would leave in the hook sinceBase.mightaliasis internal, and this allows us to overload without piracy)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, that sounds good to me. We could do it as part of this PR or separately?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are a lot of goofy special cases running around here haha
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be reasonable to include it in this PR, as this does seem to be the primary change that is needed to solve the issue? Happy to defer it if you prefer though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TBH this PR is complex enough, I might do the
isaliasas a followup?