Skip to content

Commit 57621dc

Browse files
committed
update cuda patch
1 parent 2774cce commit 57621dc

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

src/cuda.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,21 @@
11
using .CUDA
2+
using TropicalGEMM: XTranspose, NativeTypes, Tropical
3+
using LinearAlgebra
24

35
# Build a Boolean mask that is `true` at exactly one position.
#
# The candidate positions are those where `X` equals the element-wise inverse of `A`
# (`X .== inv.(A)`). `argmax` selects a single index from that comparison, the mask is
# then cleared in place, and only the selected index is set back to `true`.
#
# Arguments: `A` and `X` are CUDA arrays sharing the element type `T`.
# Returns: the Boolean mask produced by the comparison, with a single `true` entry.
#
# NOTE(review): the selected index is wrapped in a one-element `CuArray` before indexing,
# presumably to avoid scalar indexing on the device -- confirm against CUDA.jl behaviour.
function onehotmask(A::CuArray{T}, X::CuArray{T}) where T
46
mask = X .== inv.(A)
57
ci = argmax(mask)
68
mask .= false
79
mask[CuArray([ci])] = true
810
return mask
11+
end
12+
13+
# fix the matrix multiplication ambiguity
#
# `CTranspose{T}` is an alias for a lazy `Transpose` wrapping a strided CUDA vector or matrix.
# The nested loops below generate one 5-argument `LinearAlgebra.mul!` method for every
# combination of (`AbstractMatrix`, `XTranspose`) operand types, restricted to a `CuMatrix`
# output whose element type is `Tropical` over one of the `NativeTypes`. Each generated method
# forwards, via `invoke`, to the `mul!` method matching the corresponding CUDA signature
# (`CuMatrix` / `CTranspose` operands).
# NOTE(review): presumably the TropicalGEMM methods and the CUDA methods are otherwise
# ambiguous for these argument types -- confirm which pair of methods collides.
14+
const CTranspose{T} = Transpose{T, <:StridedCuVecOrMat}
15+
for (TA, CTA) in [(:AbstractMatrix, :CuMatrix), (:XTranspose, :CTranspose)]
16+
for (TB, CTB) in [(:AbstractMatrix, :CuMatrix), (:XTranspose, :CTranspose)]
17+
@eval function LinearAlgebra.mul!(o::CuMatrix{T}, a::$TA{T}, b::$TB{T}, α::Number, β::Number) where {T<:Tropical{<:NativeTypes}}
18+
invoke(LinearAlgebra.mul!, Tuple{CuMatrix, $CTA, $CTB, Number, Number}, o, a, b, α, β)
19+
end
20+
end
921
end

src/networks.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ Optimize the contraction order.
109109
Check `optimize_kahypar` method in package `OMEinsumContractionOrders`.
110110
* `:auto`, also the kahypar + greedy approach, but determines `sc_target` automatically. It is slower!
111111
* `:greedy`, the greedy approach. Check `optimize_greedy` in package `OMEinsum`.
112+
* `:tree`, the approach of running simulated annealing on an expression tree. Check `optimize_tree` in package `OMEinsumContractionOrders`.
112113
* `:sa`, the simulated annealing approach. Check `optimize_sa` in package `OMEinsumContractionOrders`.
113114
* `:raw`, do nothing and return the raw EinCode.
114115
"""
@@ -162,4 +163,4 @@ function set_packing(sets; kwargs...)
162163
n = length(sets)
163164
code = EinCode(([(i,) for i=1:n]..., [(i,j) for i=1:n,j=1:n if j>i && !isempty(sets[i] ∩ sets[j])]...), ())
164165
Independence(optimize_code(code; kwargs...))
165-
end
166+
end

0 commit comments

Comments
 (0)