Skip to content

Commit 37e70eb

Browse files
committed
fix type ambiguity for counting tropical
1 parent e573abe commit 37e70eb

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

src/cuda.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using .CUDA
2-
using TropicalGEMM: XTranspose, NativeTypes, Tropical
2+
using TropicalGEMM: XTranspose, NativeTypes, Tropical, TropicalTypes
33
using LinearAlgebra
44

55
function onehotmask(A::CuArray{T}, X::CuArray{T}) where T
@@ -12,13 +12,14 @@ end
1212

1313
# fix the matrix multiplication ambiguity
1414
const CTranspose{T} = Transpose{T, <:StridedCuVecOrMat{T}}
15-
for RT in [:Tropical, :Real]
16-
for (TA, CTA) in [(:CuMatrix, :CuMatrix), (:CTranspose, :(Transpose{<:Any, <:StridedCuVecOrMat}))]
17-
for (TB, CTB) in [(:CuMatrix, :CuMatrix), (:CTranspose, :(Transpose{<:Any, <:StridedCuVecOrMat}))]
18-
@eval function LinearAlgebra.mul!(o::CuMatrix{T}, a::$TA{T}, b::$TB{T}, α::$RT, β::$RT) where {T<:Tropical{<:NativeTypes}}
19-
#invoke(LinearAlgebra.mul!, Tuple{CuMatrix, $CTA, $CTB, Number, Number}, o, a, b, α, β)
20-
CUDA.CUBLAS.gemm_dispatch!(o, a, b, α, β)
21-
end
22-
end
23-
end
24-
end
15+
for TT in [:(Tropical{<:NativeTypes}), :TropicalTypes]
16+
for RT in [TT, :Real]
17+
for (TA, CTA) in [(:CuMatrix, :CuMatrix), (:CTranspose, :(Transpose{<:Any, <:StridedCuVecOrMat}))]
18+
for (TB, CTB) in [(:CuMatrix, :CuMatrix), (:CTranspose, :(Transpose{<:Any, <:StridedCuVecOrMat}))]
19+
@eval function LinearAlgebra.mul!(o::CuMatrix{T}, a::$TA{T}, b::$TB{T}, α::$RT, β::$RT) where {T<:$TT}
20+
CUDA.CUBLAS.gemm_dispatch!(o, a, b, α, β)
21+
end
22+
end
23+
end
24+
end
25+
end

0 commit comments

Comments
 (0)