11using . CUDA
2- using TropicalGEMM: XTranspose, NativeTypes, Tropical
2+ using TropicalGEMM: XTranspose, NativeTypes, Tropical, TropicalTypes
33using LinearAlgebra
44
55function onehotmask (A:: CuArray{T} , X:: CuArray{T} ) where T
1212
1313# fix the matrix multiplication ambiguity
1414const CTranspose{T} = Transpose{T, <: StridedCuVecOrMat{T} }
15- for RT in [:Tropical , :Real ]
16- for (TA, CTA) in [(:CuMatrix , :CuMatrix ), (:CTranspose , :(Transpose{<: Any , <: StridedCuVecOrMat }))]
17- for (TB, CTB) in [(:CuMatrix , :CuMatrix ), (:CTranspose , :(Transpose{<: Any , <: StridedCuVecOrMat }))]
18- @eval function LinearAlgebra. mul! (o:: CuMatrix{T} , a:: $TA{T} , b:: $TB{T} , α:: $RT , β:: $RT ) where {T<: Tropical{<:NativeTypes} }
19- # invoke(LinearAlgebra.mul!, Tuple{CuMatrix, $CTA, $CTB, Number, Number}, o, a, b, α, β)
20- CUDA. CUBLAS. gemm_dispatch! (o, a, b, α, β)
21- end
22- end
23- end
24- end
15+ for TT in [:(Tropical{<: NativeTypes }), :TropicalTypes ]
16+ for RT in [TT, :Real ]
17+ for (TA, CTA) in [(:CuMatrix , :CuMatrix ), (:CTranspose , :(Transpose{<: Any , <: StridedCuVecOrMat }))]
18+ for (TB, CTB) in [(:CuMatrix , :CuMatrix ), (:CTranspose , :(Transpose{<: Any , <: StridedCuVecOrMat }))]
19+ @eval function LinearAlgebra. mul! (o:: CuMatrix{T} , a:: $TA{T} , b:: $TB{T} , α:: $RT , β:: $RT ) where {T<: $TT }
20+ CUDA. CUBLAS. gemm_dispatch! (o, a, b, α, β)
21+ end
22+ end
23+ end
24+ end
25+ end
0 commit comments