diff --git a/Project.toml b/Project.toml index 71501ad..a31c31c 100644 --- a/Project.toml +++ b/Project.toml @@ -30,7 +30,7 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1" TensorOperationsBumperExt = "Bumper" TensorOperationsChainRulesCoreExt = "ChainRulesCore" TensorOperationsMooncakeExt = "Mooncake" -TensorOperationsEnzymeExt = ["Enzyme", "ChainRulesCore"] +TensorOperationsEnzymeExt = "Enzyme" TensorOperationscuTENSORExt = ["cuTENSOR", "CUDA"] [compat] diff --git a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl index 8b5106f..9b40eb7 100644 --- a/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl +++ b/ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl @@ -4,19 +4,44 @@ using TensorOperations using TensorOperations: AbstractBackend, DefaultAllocator, CUDAAllocator, ManualAllocator using VectorInterface using TupleTools -using Enzyme, ChainRulesCore +using Enzyme using Enzyme.EnzymeCore using Enzyme.EnzymeCore: EnzymeRules @inline EnzymeRules.inactive(::typeof(TensorOperations.tensorfree!), ::Any) = true -Enzyme.@import_rrule(typeof(TensorOperations.tensoralloc), Any, Any, Any, Any) - @inline EnzymeRules.inactive_type(v::Type{<:AbstractBackend}) = true @inline EnzymeRules.inactive_type(v::Type{DefaultAllocator}) = true @inline EnzymeRules.inactive_type(v::Type{<:CUDAAllocator}) = true @inline EnzymeRules.inactive_type(v::Type{ManualAllocator}) = true @inline EnzymeRules.inactive_type(v::Type{<:Index2Tuple}) = true +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TensorOperations.tensoralloc)}, + ::Type{RT}, + ttype::Const, + structure::Const, + istemp::Const{Bool}, + allocator::Const + ) where {RT} + primal = EnzymeRules.needs_primal(config) ? TensorOperations.tensoralloc(ttype.val, structure.val, Val(false), allocator.val) : nothing + shadow = EnzymeRules.needs_shadow(config) ? 
TensorOperations.tensoralloc(ttype.val, structure.val, Val(false), allocator.val) : nothing + return EnzymeRules.AugmentedReturn(primal, shadow, nothing) +end + +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(TensorOperations.tensoralloc)}, + ::Type{RT}, + cache, + ttype::Const, + structure::Const, + istemp::Const{Bool}, + allocator::Const, + ) where {RT} + return nothing, nothing, nothing, nothing +end + function EnzymeRules.augmented_primal( config::EnzymeRules.RevConfigWidth{1}, func::Const{typeof(TensorOperations.tensorcontract!)}, @@ -36,7 +61,7 @@ function EnzymeRules.augmented_primal( # form caches if needed cache_A = EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing cache_B = EnzymeRules.overwritten(config)[6] ? copy(B_dB.val) : nothing - cache_C = !iszero(β_dβ.val) ? copy(C_dC.val) : C_dC.val + cache_C = !isa(β_dβ, Const) ? copy(C_dC.val) : C_dC.val ba = map(ba_ -> getfield(ba_, :val), ba_dba) TensorOperations.tensorcontract!(C_dC.val, A_dA.val, pA_dpA.val, conjA_dconjA.val, B_dB.val, pB_dpB.val, conjB_dconjB.val, pAB_dpAB.val, α_dα.val, β_dβ.val, ba...) primal = EnzymeRules.needs_primal(config) ? C_dC.val : nothing @@ -66,15 +91,39 @@ function EnzymeRules.reverse( Bval = something(cache_B, B_dB.val) Cval = cache_C # good way to check that we don't use it accidentally when we should not be needing it? - dC = C_dC.dval - dA = A_dA.dval - dB = B_dB.dval ba = map(ba_ -> getfield(ba_, :val), ba_dba) α = α_dα.val β = β_dβ.val pA, pB, pAB, conjA, conjB = getfield.((pA_dpA, pB_dpB, pAB_dpAB, conjA_dconjA, conjB_dconjB), :val) - dC, dA, dB, dα, dβ = TensorOperations.tensorcontract_pullback!(dC, dA, dB, Cval, Aval, pA, conjA, Bval, pB, conjB, pAB, α, β, ba...) - return nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)... 
+ + if !isa(A_dA, Const) && !isa(C_dC, Const) + ΔC = C_dC.dval + ΔA = A_dA.dval + TensorOperations.tensorcontract_pullback_dA!(ΔA, ΔC, Cval, Aval, pA, conjA, Bval, pB, conjB, pAB, α, ba...) + end + if !isa(B_dB, Const) && !isa(C_dC, Const) + ΔC = C_dC.dval + ΔB = B_dB.dval + TensorOperations.tensorcontract_pullback_dB!(ΔB, ΔC, Cval, Aval, pA, conjA, Bval, pB, conjB, pAB, α, ba...) + end + Δα = if !isa(α_dα, Const) && !isa(C_dC, Const) + ΔC = C_dC.dval + TensorOperations.tensorcontract_pullback_dα(ΔC, Cval, Aval, pA, conjA, Bval, pB, conjB, pAB, α, ba...) + elseif !isa(α_dα, Const) + zero(α_dα.val) + else + nothing + end + Δβ = if !isa(β_dβ, Const) && !isa(C_dC, Const) + ΔC = C_dC.dval + TensorOperations.pullback_dβ(ΔC, Cval, β) + elseif !isa(β_dβ, Const) + zero(β_dβ.val) + else + nothing + end + !isa(C_dC, Const) && TensorOperations.pullback_dC!(C_dC.dval, β) + return nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, Δα, Δβ, map(ba_ -> nothing, ba)... end function EnzymeRules.augmented_primal( @@ -123,10 +172,30 @@ function EnzymeRules.reverse( α = α_dα.val β = β_dβ.val ba = map(ba_ -> getfield(ba_, :val), ba_dba) - dC = C_dC.dval - dA = A_dA.dval - dC, dA, dα, dβ = TensorOperations.tensoradd_pullback!(dC, dA, Cval, Aval, pA, conjA, α, β, ba...) - return nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)... + + if !isa(A_dA, Const) && !isa(C_dC, Const) + ΔC = C_dC.dval + ΔA = A_dA.dval + TensorOperations.tensoradd_pullback_dA!(ΔA, ΔC, Cval, Aval, pA, conjA, α, ba...) + end + Δα = if !isa(α_dα, Const) && !isa(C_dC, Const) + ΔC = C_dC.dval + TensorOperations.tensoradd_pullback_dα(ΔC, Cval, Aval, pA, conjA, α, ba...) 
+ elseif !isa(α_dα, Const) + zero(α_dα.val) + else + nothing + end + Δβ = if !isa(β_dβ, Const) && !isa(C_dC, Const) + ΔC = C_dC.dval + TensorOperations.pullback_dβ(ΔC, Cval, β) + elseif !isa(β_dβ, Const) + zero(β_dβ.val) + else + nothing + end + !isa(C_dC, Const) && TensorOperations.pullback_dC!(C_dC.dval, β) + return nothing, nothing, nothing, nothing, Δα, Δβ, map(ba_ -> nothing, ba)... end function EnzymeRules.augmented_primal( @@ -144,7 +213,7 @@ function EnzymeRules.augmented_primal( ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number} # form caches if needed cache_A = EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing - cache_C = !iszero(β_dβ.val) ? copy(C_dC.val) : C_dC.val + cache_C = !isa(β_dβ, Const) ? copy(C_dC.val) : nothing ba = map(ba_ -> getfield(ba_, :val), ba_dba) α = α_dα.val β = β_dβ.val @@ -171,17 +240,37 @@ function EnzymeRules.reverse( ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number} cache_A, cache_C = cache Aval = something(cache_A, A_dA.val) - Cval = cache_C + Cval = something(cache_C, C_dC.val) p = p_dp.val q = q_dq.val conjA = conjA_dconjA.val α = α_dα.val β = β_dβ.val ba = map(ba_ -> getfield(ba_, :val), ba_dba) - dC = C_dC.dval - dA = A_dA.dval - dC, dA, dα, dβ = TensorOperations.tensortrace_pullback!(dC, dA, Cval, Aval, p, q, conjA, α, β, ba...) - return nothing, nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)... + + if !isa(A_dA, Const) && !isa(C_dC, Const) + ΔC = C_dC.dval + ΔA = A_dA.dval + TensorOperations.tensortrace_pullback_dA!(ΔA, ΔC, Cval, Aval, p, q, conjA, α, ba...) + end + Δα = if !isa(α_dα, Const) && !isa(C_dC, Const) + ΔC = C_dC.dval + TensorOperations.tensortrace_pullback_dα(ΔC, Cval, Aval, p, q, conjA, α, ba...) 
+ elseif !isa(α_dα, Const) + zero(α_dα.val) + else + nothing + end + Δβ = if !isa(β_dβ, Const) && !isa(C_dC, Const) + ΔC = C_dC.dval + TensorOperations.pullback_dβ(ΔC, Cval, β) + elseif !isa(β_dβ, Const) + zero(β_dβ.val) + else + nothing + end + !isa(C_dC, Const) && TensorOperations.pullback_dC!(C_dC.dval, β) + return nothing, nothing, nothing, nothing, nothing, Δα, Δβ, map(ba_ -> nothing, ba)... end end diff --git a/src/pullbacks/add.jl b/src/pullbacks/add.jl index a71cbfa..a3d1010 100644 --- a/src/pullbacks/add.jl +++ b/src/pullbacks/add.jl @@ -46,11 +46,12 @@ Compute the pullback for [`tensoradd!]`(ref) with respect to scaling coefficient """ function tensoradd_pullback_dα(ΔC, C, A, pA::Index2Tuple, conjA::Bool, α, ba...) _needs_tangent(α) || return nothing - return tensorscalar( + Δα = tensorscalar( tensorcontract( A, repartition(pA, 0), !conjA, ΔC, trivialpermutation(numind(pA), 0), false, ((), ()), One(), ba... ) ) + return project_scalar(α, Δα) end diff --git a/src/pullbacks/common.jl b/src/pullbacks/common.jl index c561c43..b0624a3 100644 --- a/src/pullbacks/common.jl +++ b/src/pullbacks/common.jl @@ -14,6 +14,15 @@ _needs_tangent(::Type{<:Integer}) = false _needs_tangent(::Type{<:Union{One, Zero}}) = false _needs_tangent(::Type{Complex{T}}) where {T} = _needs_tangent(T) +""" + project_scalar(x::Number, dx::Number) + +Project a computed tangent `dx` onto the correct tangent type for `x`. +For example, we might compute a complex `dx` but only require the real part. +""" +project_scalar(x::Number, dx::Number) = oftype(x, dx) +project_scalar(x::Real, dx::Complex) = project_scalar(x, real(dx)) + # (partial) pullbacks that are shared @doc """ pullback_dC(ΔC, β) @@ -31,4 +40,4 @@ pullback_dC(ΔC, β) = scale(ΔC, conj(β)) For functions of the form `f!(C, β, ...) = βC + ...`, compute the pullback with respect to `β`. """ pullback_dβ -pullback_dβ(ΔC, C, β) = _needs_tangent(β) ? inner(C, ΔC) : nothing +pullback_dβ(ΔC, C, β) = _needs_tangent(β) ? 
project_scalar(β, inner(C, ΔC)) : nothing diff --git a/src/pullbacks/contract.jl b/src/pullbacks/contract.jl index f8f4659..b22a232 100644 --- a/src/pullbacks/contract.jl +++ b/src/pullbacks/contract.jl @@ -122,5 +122,5 @@ function tensorcontract_pullback_dα( ) _needs_tangent(α) || return nothing C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) - return inner(C_αβ, ΔC) + return project_scalar(α, inner(C_αβ, ΔC)) end diff --git a/src/pullbacks/trace.jl b/src/pullbacks/trace.jl index f3135a2..1351d10 100644 --- a/src/pullbacks/trace.jl +++ b/src/pullbacks/trace.jl @@ -88,7 +88,7 @@ function tensortrace_pullback_dα( ) _needs_tangent(α) || return nothing C_αβ = tensortrace(A, p, q, false, One(), ba...) - return tensorscalar( + Δα = tensorscalar( tensorcontract( C_αβ, trivialpermutation(0, numind(p)), !conjA, @@ -96,4 +96,5 @@ function tensortrace_pullback_dα( ((), ()), One(), ba... ) ) + return project_scalar(α, Δα) end diff --git a/test/enzyme.jl b/test/enzyme.jl index e31926f..5a4659f 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -1,5 +1,9 @@ using TensorOperations, VectorInterface -using Enzyme, ChainRulesCore, EnzymeTestUtils +using Enzyme, EnzymeTestUtils + +# the full testsuite is really intensive +# let's not run all of it on GitHub +is_ci = get(ENV, "CI", "false") == "true" @testset "tensorcontract!" begin pAB = ((3, 2, 4, 1), ()) @@ -13,25 +17,45 @@ using Enzyme, ChainRulesCore, EnzymeTestUtils (ComplexF64, Float64), ) T = promote_type(T₁, T₂) - atol = max(precision(T₁), precision(T₂)) - rtol = max(precision(T₁), precision(T₂)) A = rand(T₁, (2, 3, 4, 2, 5)) B = rand(T₂, (4, 2, 3)) C = rand(T, (5, 2, 3, 3)) + + atol = length(C) * max(precision(T₁), precision(T₂)) + rtol = length(C) * max(precision(T₁), precision(T₂)) zero_αβs = ((Zero(), Zero()), (randn(T), Zero()), (Zero(), randn(T))) αβs = (T == T₁ == T₂ == Float64) ? 
vcat(zero_αβs..., (randn(T), randn(T))) : ((randn(T), randn(T)),) # test zeros only once to avoid wasteful tests @testset for (α, β) in αβs - Tα = α === Zero() ? Const : Active - Tβ = β === Zero() ? Const : Active - test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (false, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol) - test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol) - test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol) - - test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (false, Const), (pAB, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol) - test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + Tαs = if α === Zero() + (Const,) + elseif !is_ci + (Active, Const) + else + (Active,) + end + Tβs = if β === Zero() + (Const,) + elseif !is_ci + (Active, Const) + else + (Active,) + end + for (Tα, Tβ) in zip(Tαs, Tβs) + test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (false, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol) + test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol) + test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), 
(true, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol) + test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (false, Const), (pAB, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol) + test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + if !(T <: Real) && !(α === Zero()) && !(β === Zero()) + test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (real(α), Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (false, Const), (pAB, Const), (α, Tα), (real(β), Tβ), (StridedNative(), Const); atol, rtol) + test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (real(β), Tβ), (StridedNative(), Const); atol, rtol) + test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (real(α), Tα), (real(β), Tβ), (StridedNative(), Const); atol, rtol) + end + end end end end @@ -54,13 +78,31 @@ end αβs = (T == T₁ == T₂ == Float64) ? vcat(zero_αβs..., (randn(T), randn(T))) : ((randn(T), randn(T)),) # test zeros only once to avoid wasteful tests @testset for (α, β) in αβs - Tα = α === Zero() ? Const : Active - Tβ = β === Zero() ?
Const : Active - test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (α, Tα), (β, Tβ); atol, rtol) - test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (α, Tα), (β, Tβ); atol, rtol) + Tαs = if α === Zero() + (Const,) + elseif !is_ci + (Active, Const) + else + (Active,) + end + Tβs = if β === Zero() + (Const,) + elseif !is_ci + (Active, Const) + else + (Active,) + end + for (Tα, Tβ) in zip(Tαs, Tβs) + test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (α, Tα), (β, Tβ); atol, rtol) + test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (α, Tα), (β, Tβ); atol, rtol) - test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol) - test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol) + test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + if !(T <: Real) && !(α === Zero()) && !(β === Zero()) + test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (real(α), Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (α, Tα), (real(β), Tβ), (StridedNative(), Const); atol, rtol) + end + end end end end @@ -85,13 +127,31 @@ end αβs = (T == T₁ == T₂ == Float64) ? 
vcat(zero_αβs..., (randn(T), randn(T))) : ((randn(T), randn(T)),) # test zeros only once to avoid wasteful tests @testset for (α, β) in αβs - Tα = α === Zero() ? Const : Active - Tβ = β === Zero() ? Const : Active - test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (false, Const), (α, Tα), (β, Tβ); atol, rtol) - test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (α, Tα), (β, Tβ); atol, rtol) + Tαs = if α === Zero() + (Const,) + elseif !is_ci + (Active, Const) + else + (Active,) + end + Tβs = if β === Zero() + (Const,) + elseif !is_ci + (Active, Const) + else + (Active,) + end + for (Tα, Tβ) in zip(Tαs, Tβs) + test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (false, Const), (α, Tα), (β, Tβ); atol, rtol) + test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (α, Tα), (β, Tβ); atol, rtol) - test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol) - test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (false, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol) + test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (false, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + if !(T <: Real) && !(α === Zero()) && !(β === Zero()) + test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (real(α), Tα), (β, Tβ), (StridedNative(), Const); atol, rtol) + test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), 
(p, Const), (q, Const), (true, Const), (α, Tα), (real(β), Tβ), (StridedNative(), Const); atol, rtol) + end + end end end end