Skip to content

Commit 2774cce

Browse files
committed
2 parents 6237178 + dd36155 commit 2774cce

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

src/arithematics.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,10 @@ function Base.show(io::IO, ::MIME"text/plain", x::Max2Poly)
9999
end
100100
end
101101

102+
# patch for CUDA matmul
103+
Base.:*(a::Bool, y::Max2Poly{T,TO}) where {T,TO} = a ? y : zero(y)
104+
Base.:*(y::Max2Poly{T,TO}, a::Bool) where {T,TO} = a ? y : zero(y)
105+
102106
struct ConfigEnumerator{N,S,C}
103107
data::Vector{StaticElementVector{N,S,C}}
104108
end

src/graph_polynomials.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,17 @@ function graph_polynomial end
3030

3131
function graph_polynomial(gp::GraphProblem, ::Val{:fft}; usecuda=false,
3232
maxorder=max_size(gp; usecuda=usecuda), r=1.0)
33-
ω = exp(-2im*π/(maxorder+1))
34-
xs = r .* collect.^ (0:maxorder))
35-
ys = [contractx(gp, x; usecuda=usecuda) for x in xs]
36-
map(ci->Polynomial(ifft(getindex.(ys, Ref(ci))) ./ (r .^ (0:maxorder))), CartesianIndices(ys[1]))
33+
ω = exp(-2im*π/(maxorder+1))
34+
xs = r .* collect.^ (0:maxorder))
35+
ys = [Array(contractx(gp, x; usecuda=usecuda)) for x in xs]
36+
map(ci->Polynomial(ifft(getindex.(ys, Ref(ci))) ./ (r .^ (0:maxorder))), CartesianIndices(ys[1]))
3737
end
3838

3939
function graph_polynomial(gp::GraphProblem, ::Val{:fitting}; usecuda=false,
4040
maxorder = max_size(gp; usecuda=usecuda))
41-
xs = (0:maxorder)
42-
ys = [contractx(gp, x; usecuda=usecuda) for x in xs]
43-
map(ci->fit(xs, getindex.(ys, Ref(ci))), CartesianIndices(ys[1]))
41+
xs = (0:maxorder)
42+
ys = [Array(contractx(gp, x; usecuda=usecuda)) for x in xs]
43+
map(ci->fit(xs, getindex.(ys, Ref(ci))), CartesianIndices(ys[1]))
4444
end
4545

4646
function graph_polynomial(gp::GraphProblem, ::Val{:polynomial}; usecuda=false)

0 commit comments

Comments
 (0)