Mercurial > repos > public > sbplib_julia
diff src/LazyTensors/lazy_tensor_operations.jl @ 489:1a7d6da3cc45
Merge feature/compose_identity_mappings
author | Vidar Stiernström <vidar.stiernstrom@it.uu.se> |
---|---|
date | Thu, 05 Nov 2020 11:32:49 +0100 |
parents | 6a6b7eaf9edf |
children | df566372bb4f |
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl Mon Nov 02 21:33:35 2020 +0100 +++ b/src/LazyTensors/lazy_tensor_operations.jl Thu Nov 05 11:32:49 2020 +0100 @@ -86,12 +86,9 @@ t2::TM2 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D} - @boundscheck if domain_size(t1) != range_size(t2) - throw(DimensionMismatch("the first argument has domain size $(domain_size(t1)) while the second has range size $(range_size(t2)) ")) - end + @boundscheck check_domain_size(t1, range_size(t2)) return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2) end - # Add check for matching sizes as a boundscheck end export TensorMappingComposition @@ -170,6 +167,27 @@ apply(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...] apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...] +""" +Base.:∘(tm, tmi) +Base.:∘(tmi, tm) + +Composes a `Tensormapping` `tm` with an `IdentityMapping` `tmi`, by returning `tm` +""" +@inline function Base.:∘(tm::TensorMapping{T,R,D}, tmi::IdentityMapping{T,D}) where {T,R,D} + @boundscheck check_domain_size(tm, range_size(tmi)) + return tm +end + +@inline function Base.:∘(tmi::IdentityMapping{T,R}, tm::TensorMapping{T,R,D}) where {T,R,D} + @boundscheck check_domain_size(tmi, range_size(tm)) + return tm +end +# Specialization for the case where tm is an IdentityMapping. Required to resolve ambiguity. +@inline function Base.:∘(tm::IdentityMapping{T,D}, tmi::IdentityMapping{T,D}) where {T,D} + @boundscheck check_domain_size(tm, range_size(tmi)) + return tmi +end + """ InflatedTensorMapping{T,R,D} <: TensorMapping{T,R,D} @@ -279,3 +297,20 @@ flatten_tuple(t::NTuple{N, Number} where N) = t flatten_tuple(t::Tuple) = ((flatten_tuple.(t)...)...,) # simplify? flatten_tuple(ts::Vararg) = flatten_tuple(ts) + +function check_domain_size(tm::TensorMapping, sz) + if domain_size(tm) != sz + throw(SizeMismatch(tm,sz)) + end +end + +struct SizeMismatch <: Exception + tm::TensorMapping + sz +end +export SizeMismatch + +function Base.showerror(io::IO, err::SizeMismatch) + print(io, "SizeMismatch: ") + print(io, "domain size $(domain_size(err.tm)) of TensorMapping not matching size $(err.sz)") +end