Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
TensorOperationsBumperExt = "Bumper"
TensorOperationsChainRulesCoreExt = "ChainRulesCore"
TensorOperationsMooncakeExt = "Mooncake"
TensorOperationsEnzymeExt = ["Enzyme", "ChainRulesCore"]
TensorOperationsEnzymeExt = "Enzyme"
TensorOperationscuTENSORExt = ["cuTENSOR", "CUDA"]

[compat]
Expand Down
127 changes: 108 additions & 19 deletions ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,44 @@ using TensorOperations
using TensorOperations: AbstractBackend, DefaultAllocator, CUDAAllocator, ManualAllocator
using VectorInterface
using TupleTools
using Enzyme, ChainRulesCore
using Enzyme
using Enzyme.EnzymeCore
using Enzyme.EnzymeCore: EnzymeRules

# `tensorfree!` only releases memory; it has no differentiable effect, so Enzyme
# may skip it entirely during AD.
@inline EnzymeRules.inactive(::typeof(TensorOperations.tensorfree!), ::Any) = true
# NOTE(review): this imports the ChainRules `rrule` for `tensoralloc` as an
# Enzyme rule, which requires ChainRulesCore to be loaded — but ChainRulesCore
# was removed from this extension's deps and dedicated
# `augmented_primal`/`reverse` rules for `tensoralloc` are defined below.
# Confirm this line is still intended to remain.
Enzyme.@import_rrule(typeof(TensorOperations.tensoralloc), Any, Any, Any, Any)

# Mark purely-structural argument types as inactive for Enzyme: backends,
# allocators, and index tuples carry no differentiable data.
@inline EnzymeRules.inactive_type(v::Type{<:AbstractBackend}) = true
@inline EnzymeRules.inactive_type(v::Type{DefaultAllocator}) = true
@inline EnzymeRules.inactive_type(v::Type{<:CUDAAllocator}) = true
@inline EnzymeRules.inactive_type(v::Type{ManualAllocator}) = true
@inline EnzymeRules.inactive_type(v::Type{<:Index2Tuple}) = true

"""
Custom Enzyme forward (augmented primal) rule for `TensorOperations.tensoralloc`.

The primal simply performs the allocation; the shadow is a second allocation of
the same type and structure that will receive the accumulated adjoint in the
reverse sweep. `Val(false)` deliberately overrides `istemp`: both primal and
shadow must outlive the augmented forward pass, so they cannot be temporary
buffers. Nothing needs to be cached (`tape === nothing`).
"""
function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(TensorOperations.tensoralloc)},
        ::Type{RT},
        ttype::Const,
        structure::Const,
        istemp::Const{Bool},
        allocator::Const
    ) where {RT}
    primal = EnzymeRules.needs_primal(config) ?
        TensorOperations.tensoralloc(ttype.val, structure.val, Val(false), allocator.val) :
        nothing
    # The shadow must be zero-initialized: `tensoralloc` may hand back
    # uninitialized memory, and Enzyme accumulates derivatives into the shadow
    # additively, so any garbage content would corrupt the gradients.
    shadow = if EnzymeRules.needs_shadow(config)
        zerovector!!(
            TensorOperations.tensoralloc(ttype.val, structure.val, Val(false), allocator.val)
        )
    else
        nothing
    end
    return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
end

"""
Custom Enzyme reverse rule for `TensorOperations.tensoralloc`.

All four arguments (`ttype`, `structure`, `istemp`, `allocator`) are `Const`
and carry no differentiable data, so the reverse pass has nothing to
propagate: one `nothing` tangent is returned per argument.
"""
function EnzymeRules.reverse(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(TensorOperations.tensoralloc)},
        ::Type{RT},
        cache,
        ttype::Const,
        structure::Const,
        istemp::Const{Bool},
        allocator::Const,
    ) where {RT}
    return ntuple(_ -> nothing, Val(4))
end

function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfigWidth{1},
func::Const{typeof(TensorOperations.tensorcontract!)},
Expand All @@ -36,7 +61,7 @@ function EnzymeRules.augmented_primal(
# form caches if needed
cache_A = EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing
cache_B = EnzymeRules.overwritten(config)[6] ? copy(B_dB.val) : nothing
cache_C = !iszero(β_dβ.val) ? copy(C_dC.val) : C_dC.val
cache_C = !isa(β_dβ, Const) ? copy(C_dC.val) : C_dC.val
ba = map(ba_ -> getfield(ba_, :val), ba_dba)
TensorOperations.tensorcontract!(C_dC.val, A_dA.val, pA_dpA.val, conjA_dconjA.val, B_dB.val, pB_dpB.val, conjB_dconjB.val, pAB_dpAB.val, α_dα.val, β_dβ.val, ba...)
primal = EnzymeRules.needs_primal(config) ? C_dC.val : nothing
Expand Down Expand Up @@ -66,15 +91,39 @@ function EnzymeRules.reverse(
Bval = something(cache_B, B_dB.val)
Cval = cache_C
# good way to check that we don't use it accidentally when we should not be needing it?
dC = C_dC.dval
dA = A_dA.dval
dB = B_dB.dval
ba = map(ba_ -> getfield(ba_, :val), ba_dba)
α = α_dα.val
β = β_dβ.val
pA, pB, pAB, conjA, conjB = getfield.((pA_dpA, pB_dpB, pAB_dpAB, conjA_dconjA, conjB_dconjB), :val)
dC, dA, dB, dα, dβ = TensorOperations.tensorcontract_pullback!(dC, dA, dB, Cval, Aval, pA, conjA, Bval, pB, conjB, pAB, α, β, ba...)
return nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)...

if !isa(A_dA, Const) && !isa(C_dC, Const)
ΔC = C_dC.dval
ΔA = A_dA.dval
TensorOperations.tensorcontract_pullback_dA!(ΔA, ΔC, Cval, Aval, pA, conjA, Bval, pB, conjB, pAB, α, ba...)
end
if !isa(B_dB, Const) && !isa(C_dC, Const)
ΔC = C_dC.dval
ΔB = B_dB.dval
TensorOperations.tensorcontract_pullback_dB!(ΔB, ΔC, Cval, Aval, pA, conjA, Bval, pB, conjB, pAB, α, ba...)
end
Δα = if !isa(α_dα, Const) && !isa(C_dC, Const)
ΔC = C_dC.dval
TensorOperations.tensorcontract_pullback_dα(ΔC, Cval, Aval, pA, conjA, Bval, pB, conjB, pAB, α, ba...)
elseif !isa(α_dα, Const)
zero(α_dα.val)
else
nothing
end
Δβ = if !isa(β_dβ, Const) && !isa(C_dC, Const)
ΔC = C_dC.dval
TensorOperations.pullback_dβ(ΔC, Cval, β)
elseif !isa(β_dβ, Const)
zero(β_dβ.val)
else
nothing
end
!isa(C_dC, Const) && TensorOperations.pullback_dC!(C_dC.dval, β)
return nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, Δα, Δβ, map(ba_ -> nothing, ba)...
end

function EnzymeRules.augmented_primal(
Expand Down Expand Up @@ -123,10 +172,30 @@ function EnzymeRules.reverse(
α = α_dα.val
β = β_dβ.val
ba = map(ba_ -> getfield(ba_, :val), ba_dba)
dC = C_dC.dval
dA = A_dA.dval
dC, dA, dα, dβ = TensorOperations.tensoradd_pullback!(dC, dA, Cval, Aval, pA, conjA, α, β, ba...)
return nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)...

if !isa(A_dA, Const) && !isa(C_dC, Const)
ΔC = C_dC.dval
ΔA = A_dA.dval
TensorOperations.tensoradd_pullback_dA!(ΔA, ΔC, Cval, Aval, pA, conjA, α, ba...)
end
Δα = if !isa(α_dα, Const) && !isa(C_dC, Const)
ΔC = C_dC.dval
TensorOperations.tensoradd_pullback_dα(ΔC, Cval, Aval, pA, conjA, α, ba...)
elseif !isa(α_dα, Const)
zero(α_dα.val)
else
nothing
end
Δβ = if !isa(β_dβ, Const) && !isa(C_dC, Const)
ΔC = C_dC.dval
TensorOperations.pullback_dβ(ΔC, Cval, β)
elseif !isa(β_dβ, Const)
zero(β_dβ.val)
else
nothing
end
!isa(C_dC, Const) && TensorOperations.pullback_dC!(C_dC.dval, β)
return nothing, nothing, nothing, nothing, Δα, Δβ, map(ba_ -> nothing, ba)...
end

function EnzymeRules.augmented_primal(
Expand All @@ -144,7 +213,7 @@ function EnzymeRules.augmented_primal(
) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
# form caches if needed
cache_A = EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing
cache_C = !iszero(β_dβ.val) ? copy(C_dC.val) : C_dC.val
cache_C = !isa(β_dβ, Const) ? copy(C_dC.val) : nothing
ba = map(ba_ -> getfield(ba_, :val), ba_dba)
α = α_dα.val
β = β_dβ.val
Expand All @@ -171,17 +240,37 @@ function EnzymeRules.reverse(
) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
cache_A, cache_C = cache
Aval = something(cache_A, A_dA.val)
Cval = cache_C
Cval = something(cache_C, C_dC.val)
p = p_dp.val
q = q_dq.val
conjA = conjA_dconjA.val
α = α_dα.val
β = β_dβ.val
ba = map(ba_ -> getfield(ba_, :val), ba_dba)
dC = C_dC.dval
dA = A_dA.dval
dC, dA, dα, dβ = TensorOperations.tensortrace_pullback!(dC, dA, Cval, Aval, p, q, conjA, α, β, ba...)
return nothing, nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)...

if !isa(A_dA, Const) && !isa(C_dC, Const)
ΔC = C_dC.dval
ΔA = A_dA.dval
TensorOperations.tensortrace_pullback_dA!(ΔA, ΔC, Cval, Aval, p, q, conjA, α, ba...)
end
Δα = if !isa(α_dα, Const) && !isa(C_dC, Const)
ΔC = C_dC.dval
TensorOperations.tensortrace_pullback_dα(ΔC, Cval, Aval, p, q, conjA, α, ba...)
elseif !isa(α_dα, Const)
zero(α_dα.val)
else
nothing
end
Δβ = if !isa(β_dβ, Const) && !isa(C_dC, Const)
ΔC = C_dC.dval
TensorOperations.pullback_dβ(ΔC, Cval, β)
elseif !isa(β_dβ, Const)
zero(β_dβ.val)
else
nothing
end
!isa(C_dC, Const) && TensorOperations.pullback_dC!(C_dC.dval, β)
return nothing, nothing, nothing, nothing, nothing, Δα, Δβ, map(ba_ -> nothing, ba)...
end

end
3 changes: 2 additions & 1 deletion src/pullbacks/add.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,12 @@ Compute the pullback for [`tensoradd!]`(ref) with respect to scaling coefficient
"""
"""
Compute the pullback for [`tensoradd!`](@ref) with respect to the scaling
coefficient `α`, i.e. the (projected) inner product of `ΔC` with the
permuted/conjugated `A`. Returns `nothing` when `α` needs no tangent.
"""
function tensoradd_pullback_dα(ΔC, C, A, pA::Index2Tuple, conjA::Bool, α, ba...)
    _needs_tangent(α) || return nothing
    # Contract A (fully repartitioned to the "input" side) against ΔC to obtain
    # a rank-0 tensor, whose scalar entry is the raw tangent of α.
    pA_in = repartition(pA, 0)
    pΔC_out = trivialpermutation(numind(pA), 0)
    raw = tensorcontract(
        A, pA_in, !conjA,
        ΔC, pΔC_out, false,
        ((), ()), One(), ba...
    )
    return project_scalar(α, tensorscalar(raw))
end
11 changes: 10 additions & 1 deletion src/pullbacks/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@ _needs_tangent(::Type{<:Integer}) = false
# `One`/`Zero` (presumably the static scalar types from VectorInterface — they
# carry no runtime data) never require a tangent.
_needs_tangent(::Type{<:Union{One, Zero}}) = false
# A complex scalar needs a tangent exactly when its component type does.
_needs_tangent(::Type{Complex{T}}) where {T} = _needs_tangent(T)

"""
project_scalar(x::Number, dx::Number)

Project a computed tangent `dx` onto the correct tangent type for `x`.
For example, we might compute a complex `dx` but only require the real part.
"""
project_scalar(x::Number, dx::Number) = oftype(x, dx)
project_scalar(x::Real, dx::Complex) = project_scalar(x, real(dx))

# (partial) pullbacks that are shared
@doc """
pullback_dC(ΔC, β)
Expand All @@ -31,4 +40,4 @@ pullback_dC(ΔC, β) = scale(ΔC, conj(β))
For functions of the form `f!(C, β, ...) = βC + ...`, compute the pullback with respect to `β`.
""" pullback_dβ

pullback_dβ(ΔC, C, β) = _needs_tangent(β) ? inner(C, ΔC) : nothing
pullback_dβ(ΔC, C, β) = _needs_tangent(β) ? project_scalar(β, inner(C, ΔC)) : nothing
2 changes: 1 addition & 1 deletion src/pullbacks/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,5 +122,5 @@ function tensorcontract_pullback_dα(
)
_needs_tangent(α) || return nothing
C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
return inner(C_αβ, ΔC)
return project_scalar(α, inner(C_αβ, ΔC))
end
3 changes: 2 additions & 1 deletion src/pullbacks/trace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,13 @@ function tensortrace_pullback_dα(
)
_needs_tangent(α) || return nothing
C_αβ = tensortrace(A, p, q, false, One(), ba...)
return tensorscalar(
Δα = tensorscalar(
tensorcontract(
C_αβ, trivialpermutation(0, numind(p)),
!conjA,
ΔC, trivialpermutation(numind(p), 0), false,
((), ()), One(), ba...
)
)
return project_scalar(α, Δα)
end
Loading
Loading