-
Notifications
You must be signed in to change notification settings - Fork 57
More tweaks #375
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
More tweaks #375
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,6 +6,9 @@ const AdjointCuTensorMap{T, S, N₁, N₂} = AdjointTensorMap{T, S, N₁, N₂, | |
| function CuTensorMap(t::TensorMap{T, S, N₁, N₂, A}) where {T, S, N₁, N₂, A} | ||
| return CuTensorMap{T, S, N₁, N₂}(CuArray{T}(t.data), space(t)) | ||
| end | ||
| function TensorMap{T, S, N₁, N₂, DA}(t::TensorMap{T, S, N₁, N₂, HA}) where {T, S, N₁, N₂, DA <: CuArray{T}, HA <: Array{T}} | ||
| return CuTensorMap{T, S, N₁, N₂}(CuArray{T}(t.data), space(t)) | ||
| end | ||
|
|
||
| # project_symmetric! doesn't yet work for GPU types, so do this on the host, then copy | ||
| function TensorKit.project_symmetric_and_check(::Type{T}, ::Type{A}, data::AbstractArray, V::TensorMapSpace; tol = sqrt(eps(real(float(eltype(data)))))) where {T, A <: CuVector{T}} | ||
|
|
@@ -17,6 +20,10 @@ function TensorKit.project_symmetric_and_check(::Type{T}, ::Type{A}, data::Abstr | |
| return TensorKit.TensorMapWithStorage{T, A}(A(h_t.data), V) | ||
| end | ||
|
|
||
| function TensorKit.blocktype(::Type{<:CuTensorMap{T, S}}) where {T, S} | ||
| return CuMatrix{T, CUDA.DeviceMemory} | ||
| end | ||
|
|
||
| for (fname, felt) in ((:zeros, :zero), (:ones, :one)) | ||
| @eval begin | ||
| function CUDA.$fname( | ||
|
|
@@ -101,18 +108,6 @@ function TensorKit.scalar(t::CuTensorMap{T, S, 0, 0}) where {T, S} | |
| return isempty(inds) ? zero(scalartype(t)) : @allowscalar @inbounds t.data[only(inds)] | ||
| end | ||
|
|
||
| function Base.convert( | ||
| TT::Type{CuTensorMap{T, S, N₁, N₂}}, | ||
| t::AbstractTensorMap{<:Any, S, N₁, N₂} | ||
| ) where {T, S, N₁, N₂} | ||
| if typeof(t) === TT | ||
| return t | ||
| else | ||
| tnew = TT(undef, space(t)) | ||
| return copy!(tnew, t) | ||
| end | ||
| end | ||
|
|
||
| function LinearAlgebra.isposdef(t::CuTensorMap) | ||
| domain(t) == codomain(t) || | ||
| throw(SpaceMismatch("`isposdef` requires domain and codomain to be the same")) | ||
|
|
@@ -138,10 +133,9 @@ function Base.promote_rule( | |
| return CuTensorMap{T, S, N₁, N₂} | ||
| end | ||
|
|
||
| TensorKit.promote_storage_rule(::Type{CuArray{T, N}}, ::Type{<:CuArray{T, N}}) where {T, N} = | ||
| TensorKit.promote_storage_rule(::Type{<:CuArray{T, N}}, ::Type{<:CuArray{T, N}}) where {T, N} = | ||
| CuArray{T, N, CUDA.default_memory} | ||
|
|
||
|
|
||
| # CuTensorMap exponentation: | ||
| function TensorKit.exp!(t::CuTensorMap) | ||
| domain(t) == codomain(t) || | ||
|
|
@@ -168,3 +162,45 @@ for f in (:sqrt, :log, :asin, :acos, :acosh, :atanh, :acoth) | |
| return tf | ||
| end | ||
| end | ||
|
|
||
| function TensorKit._add_general_kernel_nonthreaded!( | ||
| tdst::CuTensorMap, tsrc::CuTensorMap, p, transformer::TensorKit.GenericTreeTransformer, α, β, backend... | ||
| ) | ||
| # preallocate buffers | ||
| buffers = TensorKit.allocate_buffers(tdst, tsrc, transformer) | ||
|
|
||
| for subtransformer in transformer.data | ||
| # Special case without intermediate buffers whenever there is only a single block | ||
| if length(subtransformer[1]) == 1 | ||
| TensorKit._add_transform_single!(tdst, tsrc, p, subtransformer, α, β, backend...) | ||
| else | ||
| cu_subtransformer = tuple(CUDA.adapt(CuArray, subtransformer[1]), subtransformer[2:end]...) | ||
| TensorKit._add_transform_multi!(tdst, tsrc, p, cu_subtransformer, buffers, α, β, backend...) | ||
| end | ||
| end | ||
| return nothing | ||
| end | ||
|
Comment on lines
+166
to
+182
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I guess the only change here is to promote the unitary basis transformation into a CuArray, which probably makes more sense to just support at the
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's remove it and see! 😈 |
||
|
|
||
| function TensorKit.allocate_buffers( | ||
| tdst::CuTensorMap, tsrc::CuTensorMap, transformer::TensorKit.GenericTreeTransformer | ||
| ) | ||
| sz = TensorKit.buffersize(transformer) | ||
| # force zeros to ensure the buffers are empty | ||
| # otherwise memory re-use can fill them with garbage data | ||
| return CUDA.zeros(eltype(tdst.data), sz), CUDA.zeros(eltype(tsrc.data), sz) | ||
| end | ||
|
Comment on lines
+184
to
+191
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is slightly confusing to me, the |
||
|
|
||
| function LinearAlgebra.mul!( | ||
| tC::CuTensorMap, tA::TensorKit.BraidingTensor, tB::CuTensorMap, α = true, β = false | ||
| ) | ||
| return mul!(tC, CUDA.adapt(CuArray, TensorMap(tA)), tB, α, β) | ||
| end | ||
|
|
||
| function LinearAlgebra.mul!( | ||
| tC::CuTensorMap, tA::CuTensorMap, tB::TensorKit.BraidingTensor, α = true, β = false | ||
| ) | ||
| return mul!(tC, tA, CUDA.adapt(CuArray, TensorMap(tB)), α, β) | ||
| end | ||
|
|
||
| @inline TensorKit.promote_storagetype(::Type{T}, A::CuTensorMap, B::TensorKit.BraidingTensor) where {T <: Number} = similarstoragetype(A, T) | ||
| @inline TensorKit.promote_storagetype(::Type{T}, A::TensorKit.BraidingTensor, B::CuTensorMap) where {T <: Number} = similarstoragetype(B, T) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -53,9 +53,11 @@ storagetype(t) = storagetype(typeof(t)) | |
| function storagetype(::Type{T}) where {T <: AbstractTensorMap} | ||
| if T isa Union | ||
| # attempt to be slightly more specific by promoting unions | ||
| Ma = storagetype(T.a) | ||
| Mb = storagetype(T.b) | ||
| return promote_storagetype(Ma, Mb) | ||
| return promote_storagetype(T.a, T.b) | ||
| elseif eltype(T) isa Union | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this to better support
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, it's for the block case. I don't think we can have scalar unions?
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's a bit weird to support that here, since for generic |
||
| # attempt to be slightly more specific by promoting unions | ||
| TU = eltype(T) | ||
| return promote_storagetype(TU.a, TU.b) | ||
| else | ||
| # fallback definition by using scalartype | ||
| return similarstoragetype(scalartype(T)) | ||
|
|
@@ -103,11 +105,19 @@ similarstoragetype(X::Type, ::Type{T}) where {T <: Number} = | |
|
|
||
| # implement on tensors | ||
| similarstoragetype(::Type{TT}) where {TT <: AbstractTensorMap} = similarstoragetype(storagetype(TT)) | ||
| similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: Number} = | ||
| similarstoragetype(storagetype(TT), T) | ||
| function similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: Number} | ||
| return similarstoragetype(storagetype(TT), T) | ||
| end | ||
|
Comment on lines
+108
to
+110
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is just a formatting change, right? |
||
| function similarstoragetype(::Type{<:AbstractTensorMap{T, S, N₁, N₂}}, ::Type{TA}) where {T <: Number, TA <: DenseVector, S, N₁, N₂} | ||
| return similarstoragetype(TA, T) | ||
| end | ||
| function similarstoragetype(t::AbstractTensorMap{T, S, N₁, N₂}, ::Type{TA}) where {T <: Number, TA <: DenseVector, S, N₁, N₂} | ||
| return similarstoragetype(typeof(t), TA) | ||
| end | ||
|
|
||
| # implement on arrays | ||
| similarstoragetype(::Type{A}) where {A <: DenseVector{<:Number}} = A | ||
| similarstoragetype(::Type{A}, ::Type{A}) where {A <: DenseVector{<:Number}} = A | ||
| Base.@assume_effects :foldable similarstoragetype(::Type{A}) where {A <: AbstractArray{<:Number}} = | ||
| Core.Compiler.return_type(similar, Tuple{A, Int}) | ||
| Base.@assume_effects :foldable similarstoragetype(::Type{A}, ::Type{T}) where {A <: AbstractArray, T <: Number} = | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -376,9 +376,10 @@ function blas_contract!( | |
| twistB = false | ||
| end | ||
|
|
||
| TTC = storagetype(C) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I guess this effectively means that we are deciding to promote inputs to the storagetype of the output. I'm not sure if I am fully convinced that we should solve this automatically at all, since I think that is also inconsistent with how regular matrices work (same for adding): julia> CUDA.rand(2, 2) * rand(Float32, 2, 2)
ERROR: Scalar indexing is disallowed. I do think that this might be the right approach, and requiring explicit conversions in the cases of mixed inputs seems like the right call to me. (Even though I can see how that is annoying for MPSKit 😉 ) |
||
| # Bring A in the correct form for BLAS contraction | ||
| if copyA | ||
| Anew = TO.tensoralloc_add(TC, A, pA, false, Val(true), allocator) | ||
| Anew = TO.tensoralloc_add(TTC, A, pA, false, Val(true), allocator) | ||
| Anew = TO.tensoradd!(Anew, A, pA, false, One(), Zero(), backend, allocator) | ||
| twistA && twist!(Anew, filter(!isdual ∘ Base.Fix1(space, Anew), domainind(Anew))) | ||
| else | ||
|
|
@@ -388,7 +389,7 @@ function blas_contract!( | |
|
|
||
| # Bring B in the correct form for BLAS contraction | ||
| if copyB | ||
| Bnew = TO.tensoralloc_add(TC, B, pB, false, Val(true), allocator) | ||
| Bnew = TO.tensoralloc_add(TTC, B, pB, false, Val(true), allocator) | ||
| Bnew = TO.tensoradd!(Bnew, B, pB, false, One(), Zero(), backend, allocator) | ||
| twistB && twist!(Bnew, filter(isdual ∘ Base.Fix1(space, Bnew), codomainind(Bnew))) | ||
| else | ||
|
|
@@ -401,7 +402,7 @@ function blas_contract!( | |
| copyC = !TO.isblasdestination(C, ipAB) | ||
|
|
||
| if copyC | ||
| Cnew = TO.tensoralloc_add(TC, C, ipAB, false, Val(true), allocator) | ||
| Cnew = TO.tensoralloc_add(TTC, C, ipAB, false, Val(true), allocator) | ||
| mul!(Cnew, Anew, Bnew) | ||
| TO.tensoradd!(C, Cnew, pAB, false, α, β, backend, allocator) | ||
| TO.tensorfree!(Cnew, allocator) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is now more properly addressed through type inference.