Skip to content

Commit c4087ea

Browse files
committed
save and load configurations
1 parent 37e70eb commit c4087ea

File tree

9 files changed

+141
-6
lines changed

9 files changed

+141
-6
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
/Manifest.toml
22
*.swp
3-
_*
3+
_*
4+
_test.bin
5+
_test.txt

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "0.1.0"
66
[deps]
77
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
88
Compose = "a81c6b42-2e10-5240-aca2-a61377ecd94b"
9+
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
910
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
1011
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
1112
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/arithematics.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,4 +197,4 @@ function onehotv(::Type{CountingTropical{TV,BS}}, x, v) where {TV,BS}
197197
CountingTropical{TV,BS}(one(TV), onehotv(BS, x, v))
198198
end
199199
onehotv(::Type{ConfigEnumerator{N,S,C}}, i::Integer, v) where {N,S,C} = ConfigEnumerator([onehotv(StaticElementVector{N,S,C}, i, v)])
200-
onehotv(::Type{ConfigSampler{N,S,C}}, i::Integer, v) where {N,S,C} = ConfigSampler(onehotv(StaticElementVector{N,S,C}, i, v))
200+
onehotv(::Type{ConfigSampler{N,S,C}}, i::Integer, v) where {N,S,C} = ConfigSampler(onehotv(StaticElementVector{N,S,C}, i, v))

src/bitvector.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,10 @@ end
6565
##### BitVectors
6666
const StaticBitVector{N,C} = StaticElementVector{N,1,C}
6767
@inline function Base.getindex(x::StaticBitVector{N,C}, i::Integer) where {N,C}
68-
@boundscheck i <= N || throw(BoundsError(x, i)) # TODO: make this @boundscheck work.
68+
@boundscheck (i <= N || throw(BoundsError(x, i))) # NOTE: still checks bounds in global scope, why?
6969
i -= 1
7070
ii = i ÷ 64
71-
@inbounds (x.data[ii+1] >> (i-ii*64)) & 1
71+
return @inbounds (x.data[ii+1] >> (i-ii*64)) & 1
7272
end
7373

7474
function StaticBitVector(x::AbstractVector)

src/configurations.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,5 +89,4 @@ end
8989
for GP in [:Independence, :Matching, :MaximalIndependence, :Coloring]
9090
@eval symbols(gp::$GP) = labels(gp.code)
9191
end
92-
symbols(gp::MaxCut) = collect(OMEinsum.getixs(OMEinsum.flatten(gp.code)))
93-
# TODO: coloring
92+
symbols(gp::MaxCut) = collect(OMEinsum.getixs(OMEinsum.flatten(gp.code)))

src/cuda.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,4 @@ for TT in [:(Tropical{<:NativeTypes}), :TropicalTypes]
2323
end
2424
end
2525
end
26+

src/interfaces.jl

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,75 @@ function solve(gp::GraphProblem, task; usecuda=false, kwargs...)
5151
error("unknown task $task.")
5252
end
5353
end
54+
55+
export save_configs, load_configs
56+
using DelimitedFiles
57+
function save_configs(filename, data::ConfigEnumerator{N,S,C}; format::Symbol=:binary) where {N,S,C}
58+
if format == :binary
59+
write(filename, raw_matrix(data))
60+
elseif format == :text
61+
writedlm(filename, plain_matrix(data))
62+
else
63+
error("format must be `:binary` or `:text`, got `:$format`")
64+
end
65+
end
66+
function load_configs(filename; len=nothing, format::Symbol=:binary, nflavors=2)
67+
if format == :binary
68+
len === nothing && error("you need to specify `len` for reading configurations from binary files.")
69+
S = ceil(Int, log2(nflavors))
70+
C = _nints(len, S)
71+
return _from_raw_matrix(StaticElementVector{len,S,C}, reshape(reinterpret(UInt64, read(filename)),C,:))
72+
elseif format == :text
73+
return from_plain_matrix(readdlm(filename); nflavors=nflavors)
74+
else
75+
error("format must be `:binary` or `:text`, got `:$format`")
76+
end
77+
end
78+
79+
function raw_matrix(x::ConfigEnumerator{N,S,C}) where {N,S,C}
80+
m = zeros(UInt64, C, length(x))
81+
@inbounds for i=1:length(x), j=1:C
82+
m[j,i] = x.data[i].data[j]
83+
end
84+
return m
85+
end
86+
function plain_matrix(x::ConfigEnumerator{N,S,C}) where {N,S,C}
87+
m = zeros(UInt8, N, length(x))
88+
@inbounds for i=1:length(x), j=1:N
89+
m[j,i] = x.data[i][j]
90+
end
91+
return m
92+
end
93+
94+
function from_raw_matrix(m; len, nflavors=2)
95+
S = ceil(Int,log2(nflavors))
96+
C = size(m, 1)
97+
T = StaticElementVector{len,S,C}
98+
@assert len*S <= C*64
99+
_from_raw_matrix(T, m)
100+
end
101+
function _from_raw_matrix(::Type{StaticElementVector{N,S,C}}, m::AbstractMatrix) where {N,S,C}
102+
data = zeros(StaticElementVector{N,S,C}, size(m, 2))
103+
@inbounds for i=1:size(m, 2)
104+
data[i] = StaticElementVector{N,S,C}(NTuple{C,UInt64}(view(m,:,i)))
105+
end
106+
return ConfigEnumerator(data)
107+
end
108+
function from_plain_matrix(m::Matrix; nflavors=2)
109+
S = ceil(Int,log2(nflavors))
110+
N = size(m, 1)
111+
C = _nints(N, S)
112+
T = StaticElementVector{N,S,C}
113+
_from_plain_matrix(T, m)
114+
end
115+
function _from_plain_matrix(::Type{StaticElementVector{N,S,C}}, m::AbstractMatrix) where {N,S,C}
116+
data = zeros(StaticElementVector{N,S,C}, size(m, 2))
117+
@inbounds for i=1:size(m, 2)
118+
data[i] = convert(StaticElementVector{N,S,C}, view(m, :, i))
119+
end
120+
return ConfigEnumerator(data)
121+
end
122+
123+
# convert to Matrix
124+
Base.Matrix(ce::ConfigEnumerator) = plain_matrix(ce)
125+
Base.Vector(ce::StaticElementVector) = collect(ce)

test/cuda.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using CUDA, Random
2+
using LinearAlgebra: mul!
3+
using GraphTensorNetworks, Test
4+
5+
@testset "cuda patch" begin
6+
for T in [Tropical{Float64}, CountingTropical{Float64,Float64}]
7+
a = T.(CUDA.randn(4, 4))
8+
b = T.(CUDA.randn(4))
9+
for A in [transpose(a), a, transpose(b)]
10+
for B in [transpose(a), a, b]
11+
if !(size(A) == (1,4) && size(B) == (4,))
12+
res0 = Array(A) * Array(B)
13+
res1 = A * B
14+
res2 = mul!(CUDA.zeros(T, size(res0)...), A, B, true, false)
15+
@test Array(res1) res0
16+
@test Array(res2) res0
17+
end
18+
end
19+
end
20+
end
21+
end

test/interfaces.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,43 @@ using LightGraphs, Test
3030
@test res11 == res5
3131
@test res12.c.data res13.c.data
3232
@test res13.c.data == res7.c.data
33+
end
34+
35+
@testset "save load" begin
36+
M = 10
37+
m = ConfigEnumerator([StaticBitVector(rand(Bool, 300)) for i=1:M])
38+
bm = GraphTensorNetworks.plain_matrix(m)
39+
rm = GraphTensorNetworks.raw_matrix(m)
40+
m1 = GraphTensorNetworks.from_raw_matrix(rm; len=300, nflavors=2)
41+
m2 = GraphTensorNetworks.from_plain_matrix(bm; nflavors=2)
42+
@test m1 == m
43+
@test m2 == m
44+
save_configs("_test.bin", m; format=:binary)
45+
@test_throws ErrorException load_configs("_test.bin"; format=:binary)
46+
ma = load_configs("_test.bin"; format=:binary, len=300, nflavors=2)
47+
@test ma == m
48+
49+
save_configs("_test.txt", m; format=:text)
50+
mb = load_configs("_test.txt"; format=:text, nflavors=2)
51+
@test mb == m
52+
53+
M = 10
54+
m = ConfigEnumerator([StaticElementVector(3, rand(1:3, 300)) for i=1:M])
55+
bm = GraphTensorNetworks.plain_matrix(m)
56+
rm = GraphTensorNetworks.raw_matrix(m)
57+
m1 = GraphTensorNetworks.from_raw_matrix(rm; len=300, nflavors=3)
58+
m2 = GraphTensorNetworks.from_plain_matrix(bm; nflavors=3)
59+
@test m1 == m
60+
@test m2 == m
61+
@test Matrix(m) == bm
62+
@test Vector(m.data[1]) == bm[:,1]
63+
64+
save_configs("_test.bin", m; format=:binary)
65+
@test_throws ErrorException load_configs("_test.bin"; format=:binary)
66+
ma = load_configs("_test.bin"; format=:binary, len=300, nflavors=3)
67+
@test ma == m
68+
69+
save_configs("_test.txt", m; format=:text)
70+
mb = load_configs("_test.txt"; format=:text, nflavors=3)
71+
@test mb == m
3372
end

0 commit comments

Comments
 (0)