Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion ext/TensorKitCUDAExt/TensorKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using TensorKit.Factorizations
using TensorKit.Strided
using TensorKit.Factorizations: AbstractAlgorithm
using TensorKit: SectorDict, tensormaptype, scalar, similarstoragetype, AdjointTensorMap, scalartype, project_symmetric_and_check
import TensorKit: randisometry, rand, randn
import TensorKit: randisometry, rand, randn, similarmatrixtype

using TensorKit: MatrixAlgebraKit

Expand All @@ -19,4 +19,6 @@ using Random
include("cutensormap.jl")
include("truncation.jl")

TensorKit.similarmatrixtype(::Type{A}) where {T <: Number, M, A <: CuVector{T, M}} = CuMatrix{T, M}

end
13 changes: 13 additions & 0 deletions src/tensors/abstracttensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,19 @@ similarstoragetype(::Type{D}, ::Type{T}) where {D <: AbstractDict{<:Sector, <:Ab
# default storage type for numbers
similarstoragetype(::Type{T}) where {T <: Number} = Vector{T}

@doc """
similarmatrixtype(T::Type{<:Number}) -> Matrix{T}
similarmatrixtype(A::Type{T, <:DenseVector{T}}) -> Matrix{T}

For a given dense vector type `A` or number type `T`, compute an appropriate
**matrix** storage type for tensors. This function is used internally for
[`BraidingTensor`](@ref) to determine the output storage format for indexing
and other operations with other tensor types.
""" similarmatrixtype

similarmatrixtype(::Type{T}) where {T <: Number} = Matrix{T}
similarmatrixtype(::Type{A}) where {T <: Number, A <: DenseVector{T}} = Matrix{T}

@doc """
promote_storagetype([T], A, B, C...)
promote_storagetype([T], TA, TB, TC...)
Expand Down
68 changes: 46 additions & 22 deletions src/tensors/braidingtensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,59 +2,80 @@
# special (2,2) tensor that implements a standard braiding operation
#====================================================================#
"""
struct BraidingTensor{T,S<:IndexSpace} <: AbstractTensorMap{T, S, 2, 2}
BraidingTensor(V1::S, V2::S, adjoint::Bool=false) where {S<:IndexSpace}
struct BraidingTensor{T,S<:IndexSpace,A<:DenseVector{T}} <: AbstractTensorMap{T, S, 2, 2}
BraidingTensor(V1::S, V2::S, ::Type{A}, adjoint::Bool=false) where {S<:IndexSpace, A <: DenseVector{<:Number}}

Specific subtype of [`AbstractTensorMap`](@ref) for representing the braiding tensor that
braids the first input over the second input; its inverse can be obtained as the adjoint.

It holds that `domain(BraidingTensor(V1, V2)) == V1 ⊗ V2` and
`codomain(BraidingTensor(V1, V2)) == V2 ⊗ V1`.
`codomain(BraidingTensor(V1, V2)) == V2 ⊗ V1`. The storage type `TA`
controls the array type of the braiding tensor used when indexing
and multiplying with other tensors.
"""
struct BraidingTensor{T, S} <: AbstractTensorMap{T, S, 2, 2}
struct BraidingTensor{T, S, A} <: AbstractTensorMap{T, S, 2, 2}
V1::S
V2::S
adjoint::Bool
function BraidingTensor{T, S}(V1::S, V2::S, adjoint::Bool = false) where {T, S <: IndexSpace}
for a in sectors(V1)
for b in sectors(V2)
for c in (a ⊗ b)
Nsymbol(a, b, c) == Nsymbol(b, a, c) ||
throw(ArgumentError("Cannot define a braiding between $a and $b"))
end
end
function BraidingTensor{T, S}(V1::S, V2::S, ::Type{A}, adjoint::Bool = false) where {T, S <: IndexSpace, A <: DenseVector{T}}
for a in sectors(V1), b in sectors(V2), c in (a ⊗ b)
Nsymbol(a, b, c) == Nsymbol(b, a, c) ||
throw(ArgumentError("Cannot define a braiding between $a and $b"))
end
return new{T, S}(V1, V2, adjoint)
return new{T, S, A}(V1, V2, adjoint)
# partial construction: only construct rowr and colr when needed
end
end
function BraidingTensor{T}(V1::S, V2::S, A, adjoint::Bool = false) where {T, S <: IndexSpace}
return BraidingTensor{T, S}(V1, V2, A, adjoint)
end
function BraidingTensor{T}(V1::S, V2::S, adjoint::Bool = false) where {T, S <: IndexSpace}
return BraidingTensor{T, S}(V1, V2, adjoint)
return BraidingTensor{T, S}(V1, V2, Vector{T}, adjoint)
end
function BraidingTensor{T}(V1::IndexSpace, V2::IndexSpace, A, adjoint::Bool = false) where {T}
return BraidingTensor{T}(promote(V1, V2)..., A, adjoint)
end
function BraidingTensor{T}(V1::IndexSpace, V2::IndexSpace, adjoint::Bool = false) where {T}
return BraidingTensor{T}(promote(V1, V2)..., adjoint)
return BraidingTensor{T}(V1, V2, Vector{T}, adjoint)
end
function BraidingTensor(V1::IndexSpace, V2::IndexSpace, ::Type{A}, adjoint::Bool = false) where {T, A <: DenseVector{T}}
return BraidingTensor{T}(promote(V1, V2)..., A, adjoint)
end
function BraidingTensor(V1::IndexSpace, V2::IndexSpace, ::Type{T}, adjoint::Bool = false) where {T}
return BraidingTensor{T}(promote(V1, V2)..., Vector{T}, adjoint)
end
function BraidingTensor(V1::IndexSpace, V2::IndexSpace, adjoint::Bool = false)
return BraidingTensor(promote(V1, V2)..., adjoint)
end
function BraidingTensor(V1::S, V2::S, adjoint::Bool = false) where {S <: IndexSpace}
T = BraidingStyle(sectortype(S)) isa SymmetricBraiding ? Float64 : ComplexF64
return BraidingTensor{T, S}(V1, V2, adjoint)
return BraidingTensor{T, S}(V1, V2, Vector{T}, adjoint)
end
function BraidingTensor(V1::S, V2::S, ::Type{A}, adjoint::Bool = false) where {S <: IndexSpace, A <: AbstractArray}
T = BraidingStyle(sectortype(S)) isa SymmetricBraiding ? Float64 : ComplexF64
A′ = similarstoragetype(A, T)
return BraidingTensor{T, S}(V1, V2, A′, adjoint)
end
function BraidingTensor(V::HomSpace, adjoint::Bool = false)
domain(V) == reverse(codomain(V)) ||
throw(SpaceMismatch("Cannot define a braiding on $V"))
return BraidingTensor(V[2], V[1], adjoint)
end
function BraidingTensor(V::HomSpace, ::Type{A}, adjoint::Bool = false) where {A}
domain(V) == reverse(codomain(V)) ||
throw(SpaceMismatch("Cannot define a braiding on $V"))
return BraidingTensor(V[2], V[1], A, adjoint)
end
function BraidingTensor{T}(V::HomSpace, adjoint::Bool = false) where {T}
domain(V) == reverse(codomain(V)) ||
throw(SpaceMismatch("Cannot define a braiding on $V"))
return BraidingTensor{T}(V[2], V[1], adjoint)
end
function Base.adjoint(b::BraidingTensor{T, S}) where {T, S}
return BraidingTensor{T, S}(b.V1, b.V2, !b.adjoint)
function Base.adjoint(b::BraidingTensor{T, S, A}) where {T, S, A}
return BraidingTensor{T, S, A}(b.V1, b.V2, !b.adjoint)
end

storagetype(b::BraidingTensor{T, S, A}) where {T, S, A} = A
space(b::BraidingTensor) = b.adjoint ? b.V1 ⊗ b.V2 ← b.V2 ⊗ b.V1 : b.V2 ⊗ b.V1 ← b.V1 ⊗ b.V2

# specializations to ignore the storagetype of BraidingTensor
Expand Down Expand Up @@ -115,7 +136,8 @@ end
d = (dims(codomain(b), f₁.uncoupled)..., dims(domain(b), f₂.uncoupled)...)
n1 = d[1] * d[2]
n2 = d[3] * d[4]
data = sreshape(StridedView(Matrix{eltype(b)}(undef, n1, n2)), d)
data_t = similarmatrixtype(storagetype(b))(undef, (n1, n2))
data = sreshape(StridedView(data_t), d)
fill!(data, zero(eltype(b)))

r = _braiding_factor(f₁, f₂, b.adjoint)
Expand All @@ -134,8 +156,10 @@ TensorMap(b::BraidingTensor) = copy!(similar(b), b)
Base.convert(::Type{TensorMap}, b::BraidingTensor) = TensorMap(b)

Base.complex(b::BraidingTensor{<:Complex}) = b
function Base.complex(b::BraidingTensor)
return BraidingTensor{complex(scalartype(b))}(space(b), b.adjoint)
function Base.complex(b::BraidingTensor{T, S, A}) where {T, S, A}
Tc = complex(T)
Ac = similarstoragetype(Tc, A)
return BraidingTensor{Tc, S, Ac}(space(b), b.adjoint)
end

function block(b::BraidingTensor, s::Sector)
Expand All @@ -145,7 +169,7 @@ function block(b::BraidingTensor, s::Sector)
# TODO: probably always square?
m = blockdim(codomain(b), s)
n = blockdim(domain(b), s)
data = Matrix{eltype(b)}(undef, (m, n))
data = similarmatrixtype(storagetype(b))(undef, (m, n))

length(data) == 0 && return data # s ∉ blocksectors(b)

Expand Down
Loading
Loading