Skip to content

Commit e573abe

Browse files
committed
fix cuda dispatch
1 parent 57621dc commit e573abe

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

src/cuda.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,14 @@ function onehotmask(A::CuArray{T}, X::CuArray{T}) where T
1111
end
1212

1313
# fix the matrix multiplication ambiguity
14-
const CTranspose{T} = Transpose{T, <:StridedCuVecOrMat}
15-
for (TA, CTA) in [(:AbstractMatrix, :CuMatrix), (:XTranspose, :CTranspose)]
16-
for (TB, CTB) in [(:AbstractMatrix, :CuMatrix), (:XTranspose, :CTranspose)]
17-
@eval function LinearAlgebra.mul!(o::CuMatrix{T}, a::$TA{T}, b::$TB{T}, α::Number, β::Number) where {T<:Tropical{<:NativeTypes}}
18-
invoke(LinearAlgebra.mul!, Tuple{CuMatrix, $CTA, $CTB, Number, Number}, o, a, b, α, β)
14+
const CTranspose{T} = Transpose{T, <:StridedCuVecOrMat{T}}
15+
for RT in [:Tropical, :Real]
16+
for (TA, CTA) in [(:CuMatrix, :CuMatrix), (:CTranspose, :(Transpose{<:Any, <:StridedCuVecOrMat}))]
17+
for (TB, CTB) in [(:CuMatrix, :CuMatrix), (:CTranspose, :(Transpose{<:Any, <:StridedCuVecOrMat}))]
18+
@eval function LinearAlgebra.mul!(o::CuMatrix{T}, a::$TA{T}, b::$TB{T}, α::$RT, β::$RT) where {T<:Tropical{<:NativeTypes}}
19+
#invoke(LinearAlgebra.mul!, Tuple{CuMatrix, $CTA, $CTB, Number, Number}, o, a, b, α, β)
20+
CUDA.CUBLAS.gemm_dispatch!(o, a, b, α, β)
21+
end
1922
end
2023
end
2124
end

0 commit comments

Comments
 (0)