@@ -11,11 +11,14 @@ function onehotmask(A::CuArray{T}, X::CuArray{T}) where T
1111end
1212
1313# fix the matrix multiplication ambiguity
14- const CTranspose{T} = Transpose{T, <: StridedCuVecOrMat }
15- for (TA, CTA) in [(:AbstractMatrix , :CuMatrix ), (:XTranspose , :CTranspose )]
16- for (TB, CTB) in [(:AbstractMatrix , :CuMatrix ), (:XTranspose , :CTranspose )]
17- @eval function LinearAlgebra. mul! (o:: CuMatrix{T} , a:: $TA{T} , b:: $TB{T} , α:: Number , β:: Number ) where {T<: Tropical{<:NativeTypes} }
18- invoke (LinearAlgebra. mul!, Tuple{CuMatrix, $ CTA, $ CTB, Number, Number}, o, a, b, α, β)
14+ const CTranspose{T} = Transpose{T, <: StridedCuVecOrMat{T} }
15+ for RT in [:Tropical , :Real ]
16+ for (TA, CTA) in [(:CuMatrix , :CuMatrix ), (:CTranspose , :(Transpose{<: Any , <: StridedCuVecOrMat }))]
17+ for (TB, CTB) in [(:CuMatrix , :CuMatrix ), (:CTranspose , :(Transpose{<: Any , <: StridedCuVecOrMat }))]
18+ @eval function LinearAlgebra. mul! (o:: CuMatrix{T} , a:: $TA{T} , b:: $TB{T} , α:: $RT , β:: $RT ) where {T<: Tropical{<:NativeTypes} }
19+ # invoke(LinearAlgebra.mul!, Tuple{CuMatrix, $CTA, $CTB, Number, Number}, o, a, b, α, β)
20+ CUDA. CUBLAS. gemm_dispatch! (o, a, b, α, β)
21+ end
1922 end
2023 end
2124end
0 commit comments