Skip to content

Commit 26e8d98

Browse files
committed
Added random generator to Elkan, Hamerly & Lloyd
1 parent b643b4c commit 26e8d98

File tree

7 files changed

+34
-48
lines changed

7 files changed

+34
-48
lines changed

src/elkan.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ function kmeans!(alg::Elkan, containers, X, k, weights=nothing, metric=Euclidean
2323
n_threads = Threads.nthreads(),
2424
k_init = "k-means++", max_iters = 300,
2525
tol = eltype(X)(1e-6), verbose = false,
26-
init = nothing, rng = Random.GLOBAL_RNG, metric=Euclidean())
26+
init = nothing, rng = Random.GLOBAL_RNG)
2727

2828
nrow, ncol = size(X)
2929
centroids = init == nothing ? smart_init(X, k, n_threads, weights, rng, init=k_init).centroids : deepcopy(init)

src/hamerly.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function kmeans!(alg::Hamerly, containers, X, k, weights=nothing, metric=Euclide
2222
n_threads = Threads.nthreads(),
2323
k_init = "k-means++", max_iters = 300,
2424
tol = eltype(X)(1e-6), verbose = false,
25-
init = nothing, rng = Random.GLOBAL_RNG, metric=Euclidean())
25+
init = nothing, rng = Random.GLOBAL_RNG)
2626

2727
nrow, ncol = size(X)
2828
centroids = init == nothing ? smart_init(X, k, n_threads, weights, rng, init=k_init).centroids : deepcopy(init)

src/kmeans.jl

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -168,36 +168,20 @@ alternatively one can use `rand` to choose random points for init.
168168
169169
A `KmeansResult` structure representing labels, centroids, and sum_squares is returned.
170170
"""
171-
<<<<<<< HEAD
172-
function kmeans(alg::AbstractKMeansAlg, design_matrix, k;
173-
weights = nothing,
171+
function kmeans(alg::AbstractKMeansAlg, design_matrix, k; weights = nothing,
174172
n_threads = Threads.nthreads(),
175173
k_init = "k-means++", max_iters = 300,
176174
tol = eltype(design_matrix)(1e-6), verbose = false,
177175
init = nothing, rng = Random.GLOBAL_RNG, metric = Euclidean())
178176

179-
=======
180-
function kmeans(alg::AbstractKMeansAlg, design_matrix, k, weights = nothing;
181-
n_threads = Threads.nthreads(), k_init = "k-means++", max_iters = 300,
182-
tol = eltype(design_matrix)(1e-6), verbose = false, init = nothing, metric=Euclidean())
183-
184-
# Get dimensions of the input data
185-
>>>>>>> Moved metric internally as a positional arg
186177
nrow, ncol = size(design_matrix)
187178

188179
# Create containers based on the dimensions and specifications
189180
containers = create_containers(alg, design_matrix, k, nrow, ncol, n_threads)
190181

191-
<<<<<<< HEAD
192-
return kmeans!(alg, containers, design_matrix, k, weights, n_threads = n_threads,
193-
k_init = k_init, max_iters = max_iters, tol = tol,
194-
verbose = verbose, init = init, rng = rng, metric = metric)
195-
=======
196-
# Dispatch based on the specified algorithm
197182
return kmeans!(alg, containers, design_matrix, k, weights, metric;
198183
n_threads = n_threads, k_init = k_init, max_iters = max_iters,
199-
tol = tol, verbose = verbose, init = init)
200-
>>>>>>> Moved metric internally as a positional arg
184+
tol = tol, verbose = verbose, init = init, rng = rng)
201185

202186
end
203187

src/lloyd.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ function kmeans!(alg::Lloyd, containers, X, k, weights=nothing, metric=Euclidean
1818
n_threads = Threads.nthreads(),
1919
k_init = "k-means++", max_iters = 300,
2020
tol = eltype(design_matrix)(1e-6), verbose = false,
21-
init = nothing, rng = Random.GLOBAL_RNG, metric=Euclidean())
21+
init = nothing, rng = Random.GLOBAL_RNG)
2222

2323
# Get dimensions of the input data
2424
nrow, ncol = size(X)
@@ -69,8 +69,9 @@ kmeans(design_matrix, k;
6969
n_threads = Threads.nthreads(),
7070
k_init = "k-means++", max_iters = 300, tol = 1e-6,
7171
verbose = false, init = nothing, rng = Random.GLOBAL_RNG, metric = Euclidean()) =
72-
kmeans(Lloyd(), design_matrix, k; weights = weights, n_threads = n_threads, k_init = k_init, max_iters = max_iters, tol = tol,
73-
verbose = verbose, init = init, rng = rng, metric = metric)
72+
kmeans(Lloyd(), design_matrix, k; weights = weights,
73+
n_threads = n_threads, k_init = k_init, max_iters = max_iters, tol = tol,
74+
verbose = verbose, init = init, rng = rng, metric = metric)
7475

7576

7677
"""

test/test03_lloyd.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -104,21 +104,22 @@ end
104104
end
105105

106106
@testset "Lloyd metric support" begin
107-
Random.seed!(2020)
107+
rng = StableRNG(2020)
108108
X = [1. 2. 4.;]
109109

110-
res = kmeans(Lloyd(), X, 2; tol = 1e-16, metric=Cityblock())
110+
res = kmeans(Lloyd(), X, 2; tol = 1e-16, metric=Cityblock(), rng = rng)
111111

112-
@test res.assignments == [1, 1, 2]
113-
@test res.centers == [1.5 4.0]
112+
@test res.assignments == [2, 2, 1]
113+
@test res.centers == [4.0 1.5]
114114
@test res.totalcost == 1.0
115115
@test res.converged
116116

117-
Random.seed!(2020)
118-
X = rand(3, 100)
117+
rng = StableRNG(2020)
118+
X = rand(rng, 3, 100)
119119

120-
res = kmeans(X, 2, tol = 1e-16, metric=Cityblock())
121-
@test res.totalcost 62.04045252895372
120+
res = kmeans(X, 2; tol = 1e-16, metric = Cityblock(), rng = rng)
121+
@test res.totalcost 60.893492629945044
122+
@test res.iterations == 6
122123
@test res.converged
123124
end
124125

test/test04_elkan.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,21 +92,21 @@ end
9292

9393

9494
@testset "Elkan metric support" begin
95-
Random.seed!(2020)
95+
rng = StableRNG(2020)
9696
X = [1. 2. 4.;]
9797

98-
res = kmeans(Elkan(), X, 2; tol = 1e-16, metric=Cityblock())
98+
res = kmeans(Lloyd(), X, 2; tol = 1e-16, metric=Cityblock(), rng = rng)
9999

100-
@test res.assignments == [1, 1, 2]
101-
@test res.centers == [1.5 4.0]
100+
@test res.assignments == [2, 2, 1]
101+
@test res.centers == [4.0 1.5]
102102
@test res.totalcost == 1.0
103103
@test res.converged
104104

105-
Random.seed!(2020)
106-
X = rand(3, 100)
105+
rng = StableRNG(2020)
106+
X = rand(rng, 3, 100)
107107

108-
res = kmeans(Elkan(), X, 2, tol = 1e-16, metric=Cityblock())
109-
@test res.totalcost 62.04045252895372
108+
res = kmeans(Elkan(), X, 2; tol = 1e-16, metric = Cityblock(), rng = rng)
109+
@test res.totalcost 60.893492629945044
110110
@test res.converged
111111
end
112112

test/test05_hamerly.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -105,25 +105,25 @@ end
105105

106106

107107
@testset "Hamerly metric support" begin
108-
Random.seed!(2020)
108+
rng = StableRNG(2020)
109109
X = [1. 2. 4.;]
110110

111-
res = kmeans(Hamerly(), X, 2; tol = 1e-16, metric=Cityblock())
111+
res = kmeans(Hamerly(), X, 2; tol = 1e-16, metric=Cityblock(), rng = rng)
112112

113-
@test res.assignments == [1, 1, 2]
114-
@test res.centers == [1.5 4.0]
113+
@test res.assignments == [2, 2, 1]
114+
@test res.centers == [4.0 1.5]
115115
@test res.totalcost == 1.0
116116
@test res.converged
117117

118-
Random.seed!(2020)
118+
rng = StableRNG(2020)
119119
X = rand(3, 100)
120+
rng_orig = deepcopy(rng)
120121

121-
baseline = kmeans(Lloyd(), X, 2, tol = 1e-16, metric=Cityblock())
122+
baseline = kmeans(Lloyd(), X, 2, tol = 1e-16, metric=Cityblock(), rng = rng)
122123

123-
Random.seed!(2020)
124-
X = rand(3, 100)
124+
rng = deepcopy(rng_orig)
125+
res = kmeans(Hamerly(), X, 2; tol = 1e-16, metric=Cityblock(), rng = rng)
125126

126-
res = kmeans(Hamerly(), X, 2; tol = 1e-16, metric=Cityblock())
127127
@test res.totalcost baseline.totalcost
128128
@test res.converged == baseline.converged
129129
@test res.iterations == baseline.iterations

0 commit comments

Comments
 (0)