Skip to content

Commit 09125bb

Browse files
committed
WIP: Initial tested metric support for Elkan
1 parent 4fd5f95 commit 09125bb

File tree

2 files changed

+46
-17
lines changed

2 files changed

+46
-17
lines changed

src/elkan.jl

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,18 @@ kmeans(Elkan(), X, 3) # 3 clusters, Elkan algorithm
1818
"""
1919
struct Elkan <: AbstractKMeansAlg end
2020

21+
2122
function kmeans!(alg::Elkan, containers, X, k, weights;
2223
n_threads = Threads.nthreads(),
2324
k_init = "k-means++", max_iters = 300,
2425
tol = eltype(X)(1e-6), verbose = false,
25-
init = nothing, rng = Random.GLOBAL_RNG)
26+
init = nothing, rng = Random.GLOBAL_RNG, metric=Euclidean())
27+
2628
nrow, ncol = size(X)
2729
centroids = init == nothing ? smart_init(X, k, n_threads, weights, rng, init=k_init).centroids : deepcopy(init)
2830

29-
update_containers(alg, containers, centroids, n_threads)
30-
@parallelize n_threads ncol chunk_initialize(alg, containers, centroids, X, weights)
31+
update_containers(alg, containers, centroids, n_threads, metric)
32+
@parallelize n_threads ncol chunk_initialize(alg, containers, centroids, X, weights, metric)
3133

3234
T = eltype(X)
3335
converged = false
@@ -38,7 +40,7 @@ function kmeans!(alg::Elkan, containers, X, k, weights;
3840
while niters < max_iters
3941
niters += 1
4042
# Core iteration
41-
@parallelize n_threads ncol chunk_update_centroids(alg, containers, centroids, X, weights)
43+
@parallelize n_threads ncol chunk_update_centroids(alg, containers, centroids, X, weights, metric)
4244

4345
# Collect distributed containers (such as centroids_new, centroids_cnt)
4446
# in paper it is step 4
@@ -47,7 +49,7 @@ function kmeans!(alg::Elkan, containers, X, k, weights;
4749
J = sum(containers.ub)
4850

4951
# auxiliary calculation, in paper it's d(c, m(c))
50-
calculate_centroids_movement(alg, containers, centroids)
52+
calculate_centroids_movement(alg, containers, centroids, metric)
5153

5254
# lower and ounds update, in paper it's steps 5 and 6
5355
@parallelize n_threads ncol chunk_update_bounds(alg, containers, centroids)
@@ -67,11 +69,11 @@ function kmeans!(alg::Elkan, containers, X, k, weights;
6769
end
6870

6971
# Step 1 in original paper, calulation of distance d(c, c')
70-
update_containers(alg, containers, centroids, n_threads)
72+
update_containers(alg, containers, centroids, n_threads, metric)
7173
J_previous = J
7274
end
7375

74-
@parallelize n_threads ncol sum_of_squares(containers, X, containers.labels, centroids, weights)
76+
@parallelize n_threads ncol sum_of_squares(containers, X, containers.labels, centroids, weights, metric)
7577
totalcost = sum(containers.sum_of_squares)
7678

7779
# Terminate algorithm with the assumption that K-means has converged
@@ -85,6 +87,7 @@ function kmeans!(alg::Elkan, containers, X, k, weights;
8587
return KmeansResult(centroids, containers.labels, T[], Int[], T[], totalcost, niters, converged)
8688
end
8789

90+
8891
function create_containers(alg::Elkan, X, k, nrow, ncol, n_threads)
8992
T = eltype(X)
9093
lng = n_threads + 1
@@ -128,7 +131,8 @@ function create_containers(alg::Elkan, X, k, nrow, ncol, n_threads)
128131
)
129132
end
130133

131-
function chunk_initialize(::Elkan, containers, centroids, X, weights, r, idx)
134+
135+
function chunk_initialize(::Elkan, containers, centroids, X, weights, metric, r, idx)
132136
ub = containers.ub
133137
lb = containers.lb
134138
centroids_dist = containers.centroids_dist
@@ -138,15 +142,15 @@ function chunk_initialize(::Elkan, containers, centroids, X, weights, r, idx)
138142
T = eltype(X)
139143

140144
@inbounds for i in r
141-
min_dist = distance(X, centroids, i, 1)
145+
min_dist = distance(metric, X, centroids, i, 1)
142146
label = 1
143147
lb[label, i] = min_dist
144148
for j in 2:size(centroids, 2)
145149
# triangular inequality
146150
if centroids_dist[j, label] > min_dist
147151
lb[j, i] = min_dist
148152
else
149-
dist = distance(X, centroids, i, j)
153+
dist = distance(metric, X, centroids, i, j)
150154
label = dist < min_dist ? j : label
151155
min_dist = dist < min_dist ? dist : min_dist
152156
lb[j, i] = dist
@@ -161,7 +165,8 @@ function chunk_initialize(::Elkan, containers, centroids, X, weights, r, idx)
161165
end
162166
end
163167

164-
function update_containers(::Elkan, containers, centroids, n_threads)
168+
169+
function update_containers(::Elkan, containers, centroids, n_threads, metric)
165170
# unpack containers for easier manipulations
166171
centroids_dist = containers.centroids_dist
167172
T = eltype(centroids)
@@ -170,7 +175,7 @@ function update_containers(::Elkan, containers, centroids, n_threads)
170175
@inbounds for j in axes(centroids_dist, 2)
171176
min_dist = T(Inf)
172177
for i in j + 1:k
173-
d = distance(centroids, centroids, i, j)
178+
d = distance(metric, centroids, centroids, i, j)
174179
centroids_dist[i, j] = d
175180
centroids_dist[j, i] = d
176181
min_dist = min_dist < d ? min_dist : d
@@ -189,7 +194,8 @@ function update_containers(::Elkan, containers, centroids, n_threads)
189194
return centroids_dist
190195
end
191196

192-
function chunk_update_centroids(::Elkan, containers, centroids, X, weights, r, idx)
197+
198+
function chunk_update_centroids(::Elkan, containers, centroids, X, weights, metric, r, idx)
193199
# unpack
194200
ub = containers.ub
195201
lb = containers.lb
@@ -214,14 +220,14 @@ function chunk_update_centroids(::Elkan, containers, centroids, X, weights, r, i
214220

215221
# one calculation per iteration is enough
216222
if stale[i]
217-
min_dist = distance(X, centroids, i, label)
223+
min_dist = distance(metric, X, centroids, i, label)
218224
lb[label, i] = min_dist
219225
ub[i] = min_dist
220226
stale[i] = false
221227
end
222228

223229
if (min_dist > lb[j, i]) | (min_dist > centroids_dist[j, label])
224-
dist = distance(X, centroids, i, j)
230+
dist = distance(metric, X, centroids, i, j)
225231
lb[j, i] = dist
226232
if dist < min_dist
227233
min_dist = dist
@@ -242,15 +248,17 @@ function chunk_update_centroids(::Elkan, containers, centroids, X, weights, r, i
242248
end
243249
end
244250

245-
function calculate_centroids_movement(alg::Elkan, containers, centroids)
251+
252+
function calculate_centroids_movement(alg::Elkan, containers, centroids, metric)
246253
p = containers.p
247254
centroids_new = containers.centroids_new[end]
248255

249256
for i in axes(centroids, 2)
250-
p[i] = distance(centroids, centroids_new, i, i)
257+
p[i] = distance(metric, centroids, centroids_new, i, i)
251258
end
252259
end
253260

261+
254262
function chunk_update_bounds(alg, containers, centroids, r, idx)
255263
p = containers.p
256264
lb = containers.lb

test/test04_elkan.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module TestElkan
22

33
using ParallelKMeans
4+
using Distances
45
using Test
56
using StableRNGs
67

@@ -89,4 +90,24 @@ end
8990
@test res.iterations == baseline.iterations
9091
end
9192

93+
94+
@testset "Elkan metric support" begin
95+
Random.seed!(2020)
96+
X = [1. 2. 4.;]
97+
98+
res = kmeans(Elkan(), X, 2; tol = 1e-16, metric=Cityblock())
99+
100+
@test res.assignments == [1, 1, 2]
101+
@test res.centers == [1.5 4.0]
102+
@test res.totalcost == 1.0
103+
@test res.converged
104+
105+
Random.seed!(2020)
106+
X = rand(3, 100)
107+
108+
res = kmeans(Elkan(), X, 2, tol = 1e-16, metric=Cityblock())
109+
@test res.totalcost 62.04045252895372
110+
@test res.converged
111+
end
112+
92113
end # module

0 commit comments

Comments
 (0)