Skip to content

Commit 233a18f

Browse files
committed
Metric support for Coreset & Yinyang
1 parent 26e8d98 commit 233a18f

File tree

5 files changed

+85
-28
lines changed

5 files changed

+85
-28
lines changed

src/coreset.jl

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,18 +35,20 @@ Coreset(; m = 100, alg = Lloyd()) = Coreset(m, alg)
3535
Coreset(m::Int) = Coreset(m, Lloyd())
3636
Coreset(alg::AbstractKMeansAlg) = Coreset(100, alg)
3737

38-
function kmeans!(alg::Coreset, containers, X, k, weights;
38+
function kmeans!(alg::Coreset, containers, X, k, weights, metric=Euclidean();
3939
n_threads = Threads.nthreads(),
4040
k_init = "k-means++", max_iters = 300,
4141
tol = eltype(design_matrix)(1e-6), verbose = false,
4242
init = nothing, rng = Random.GLOBAL_RNG)
43+
4344
nrow, ncol = size(X)
45+
4446
centroids = isnothing(init) ? smart_init(X, k, n_threads, weights, rng, init=k_init).centroids : deepcopy(init)
4547

4648
T = eltype(X)
4749
# Steps 2-4 of the paper's algorithm 3
4850
# We distribute points over the centers and calculate weights of each cluster
49-
@parallelize n_threads ncol chunk_fit(alg, containers, centroids, X, weights)
51+
@parallelize n_threads ncol chunk_fit(alg, containers, centroids, X, weights, metric)
5052

5153
# after this step, containers.centroids_new
5254
collect_containers(alg, containers, n_threads)
@@ -62,15 +64,16 @@ function kmeans!(alg::Coreset, containers, X, k, weights;
6264

6365
# run usual kmeans for new set with new weights.
6466
res = kmeans(alg.alg, coreset, k, weights = coreset_weights, tol = tol, max_iters = max_iters,
65-
verbose = verbose, init = centroids, n_threads = n_threads, rng = rng)
67+
verbose = verbose, init = centroids, n_threads = n_threads, rng = rng, metric = metric)
6668

67-
@parallelize n_threads ncol chunk_apply(alg, containers, res.centers, X, weights)
69+
@parallelize n_threads ncol chunk_apply(alg, containers, res.centers, X, weights, metric)
6870

6971
totalcost = sum(containers.totalcost)
7072

7173
return KmeansResult(res.centers, containers.labels, T[], Int[], T[], totalcost, res.iterations, res.converged)
7274
end
7375

76+
7477
function create_containers(alg::Coreset, X, k, nrow, ncol, n_threads)
7578
T = eltype(X)
7679

@@ -109,7 +112,8 @@ function create_containers(alg::Coreset, X, k, nrow, ncol, n_threads)
109112
)
110113
end
111114

112-
function chunk_fit(alg::Coreset, containers, centroids, X, weights, r, idx)
115+
116+
function chunk_fit(alg::Coreset, containers, centroids, X, weights, metric, r, idx)
113117
centroids_cnt = containers.centroids_cnt[idx]
114118
centroids_dist = containers.centroids_dist[idx]
115119
labels = containers.labels
@@ -118,10 +122,10 @@ function chunk_fit(alg::Coreset, containers, centroids, X, weights, r, idx)
118122

119123
J = zero(T)
120124
for i in r
121-
dist = distance(X, centroids, i, 1)
125+
dist = distance(metric, X, centroids, i, 1)
122126
label = 1
123127
for j in 2:size(centroids, 2)
124-
new_dist = distance(X, centroids, i, j)
128+
new_dist = distance(metric, X, centroids, i, j)
125129

126130
# calculation of the closest center (steps 2-3 of the paper's algorithm 3)
127131
label = new_dist < dist ? j : label
@@ -144,6 +148,7 @@ function chunk_fit(alg::Coreset, containers, centroids, X, weights, r, idx)
144148
containers.J[idx] = J
145149
end
146150

151+
147152
function collect_containers(::Coreset, containers, n_threads)
148153
# Here we transform formula of the step 6
149154
# By multiplying both sides of equation on $c_\phi / \alpha$ we obtain
@@ -172,6 +177,7 @@ function collect_containers(::Coreset, containers, n_threads)
172177
end
173178
end
174179

180+
175181
function chunk_update_sensitivity(alg::Coreset, containers, r, idx)
176182
labels = containers.labels
177183
centroids_const = containers.centroids_const
@@ -182,18 +188,19 @@ function chunk_update_sensitivity(alg::Coreset, containers, r, idx)
182188
end
183189
end
184190

185-
function chunk_apply(alg::Coreset, containers, centroids, X, weights, r, idx)
191+
192+
function chunk_apply(alg::Coreset, containers, centroids, X, weights, metric, r, idx)
186193
centroids_cnt = containers.centroids_cnt[idx]
187194
centroids_dist = containers.centroids_dist[idx]
188195
labels = containers.labels
189196
T = eltype(X)
190197

191198
J = zero(T)
192199
for i in r
193-
dist = distance(X, centroids, i, 1)
200+
dist = distance(metric, X, centroids, i, 1)
194201
label = 1
195202
for j in 2:size(centroids, 2)
196-
new_dist = distance(X, centroids, i, j)
203+
new_dist = distance(metric, X, centroids, i, j)
197204

198205
# calculation of the closest center (steps 2-3 of the paper's algorithm 3)
199206
label = new_dist < dist ? j : label

src/yinyang.jl

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -47,22 +47,24 @@ Yinyang(; group_size = 7, auto = true) = Yinyang(auto, group_size)
4747
阴阳(group_size::Int) = Yinyang(true, group_size)
4848
阴阳(; group_size = 7, auto = true) = Yinyang(auto, group_size)
4949

50-
function kmeans!(alg::Yinyang, containers, X, k, weights;
50+
function kmeans!(alg::Yinyang, containers, X, k, weights, metric=Euclidean();
5151
n_threads = Threads.nthreads(),
5252
k_init = "k-means++", max_iters = 300,
5353
tol = 1e-6, verbose = false,
5454
init = nothing, rng = Random.GLOBAL_RNG)
55+
5556
nrow, ncol = size(X)
57+
5658
centroids = init == nothing ? smart_init(X, k, n_threads, weights, rng, init=k_init).centroids : deepcopy(init)
5759

5860
# create initial groups of centers, step 1 in original paper
5961
initialize(alg, containers, centroids, rng, n_threads)
6062
# construct initial bounds, step 2
61-
@parallelize n_threads ncol chunk_initialize(alg, containers, centroids, X, weights)
63+
@parallelize n_threads ncol chunk_initialize(alg, containers, centroids, X, weights, metric)
6264
collect_containers(alg, containers, n_threads)
6365

6466
# update centers and calculate drifts. Step 3.1 of the original paper.
65-
calculate_centroids_movement(alg, containers, centroids)
67+
calculate_centroids_movement(alg, containers, centroids, metric)
6668

6769
T = eltype(X)
6870
converged = false
@@ -87,14 +89,14 @@ function kmeans!(alg::Yinyang, containers, X, k, weights;
8789

8890
# push!(containers.debug, [0, 0, 0])
8991
# Core calculation of the Yinyang, 3.2-3.3 steps of the original paper
90-
@parallelize n_threads ncol chunk_update_centroids(alg, containers, centroids, X, weights)
92+
@parallelize n_threads ncol chunk_update_centroids(alg, containers, centroids, X, weights, metric)
9193
collect_containers(alg, containers, n_threads)
9294

9395
# update centers and calculate drifts. Step 3.1 of the original paper.
94-
calculate_centroids_movement(alg, containers, centroids)
96+
calculate_centroids_movement(alg, containers, centroids, metric)
9597
end
9698

97-
@parallelize n_threads ncol sum_of_squares(containers, X, containers.labels, centroids, weights)
99+
@parallelize n_threads ncol sum_of_squares(containers, X, containers.labels, centroids, weights, metric)
98100
totalcost = sum(containers.sum_of_squares)
99101

100102
# Terminate algorithm with the assumption that K-means has converged
@@ -170,6 +172,7 @@ function create_containers(alg::Yinyang, X, k, nrow, ncol, n_threads)
170172
)
171173
end
172174

175+
173176
function initialize(alg::Yinyang, containers, centroids, rng, n_threads)
174177
groups = containers.groups
175178
indices = containers.indices
@@ -186,21 +189,23 @@ function initialize(alg::Yinyang, containers, centroids, rng, n_threads)
186189
end
187190
end
188191

189-
function chunk_initialize(alg::Yinyang, containers, centroids, X, weights, r, idx)
192+
193+
function chunk_initialize(alg::Yinyang, containers, centroids, X, weights, metric, r, idx)
190194
T = eltype(X)
191195
centroids_cnt = containers.centroids_cnt[idx]
192196
centroids_new = containers.centroids_new[idx]
193197

194198
@inbounds for i in r
195-
label = point_all_centers!(alg, containers, centroids, X, i)
199+
label = point_all_centers!(alg, containers, centroids, X, i, metric)
196200
centroids_cnt[label] += isnothing(weights) ? one(T) : weights[i]
197201
for j in axes(X, 1)
198202
centroids_new[j, label] += isnothing(weights) ? X[j, i] : weights[i] * X[j, i]
199203
end
200204
end
201205
end
202206

203-
function calculate_centroids_movement(alg::Yinyang, containers, centroids)
207+
208+
function calculate_centroids_movement(alg::Yinyang, containers, centroids, metric)
204209
p = containers.p
205210
groups = containers.groups
206211
gd = containers.gd
@@ -210,7 +215,7 @@ function calculate_centroids_movement(alg::Yinyang, containers, centroids)
210215
@inbounds for (gi, ri) in enumerate(groups)
211216
max_drift = T(-Inf)
212217
for i in ri
213-
p[i] = sqrt(distance(centroids, centroids_new, i, i))
218+
p[i] = sqrt(distance(metric, centroids, centroids_new, i, i))
214219
max_drift = p[i] > max_drift ? p[i] : max_drift
215220

216221
# Should do it more elegantly
@@ -222,7 +227,8 @@ function calculate_centroids_movement(alg::Yinyang, containers, centroids)
222227
end
223228
end
224229

225-
function chunk_update_centroids(alg::Yinyang, containers, centroids, X, weights, r, idx)
230+
231+
function chunk_update_centroids(alg::Yinyang, containers, centroids, X, weights, metric, r, idx)
226232
# unpack containers for easier manipulations
227233
centroids_new = containers.centroids_new[idx]
228234
centroids_cnt = containers.centroids_cnt[idx]
@@ -256,7 +262,7 @@ function chunk_update_centroids(alg::Yinyang, containers, centroids, X, weights,
256262

257263
# tighten upper bound
258264
label = labels[i]
259-
ubx = sqrt(distance(X, centroids, i, label))
265+
ubx = sqrt(distance(metric, X, centroids, i, label))
260266
ub[i] = ubx
261267
ubx <= lbx && continue
262268

@@ -275,7 +281,7 @@ function chunk_update_centroids(alg::Yinyang, containers, centroids, X, weights,
275281
((c == old_label) | (ubx < old_lb - p[c])) && continue
276282
mask[c] = true
277283
# containers.debug[end][2] += 1 # local filter update
278-
dist = distance(X, centroids, i, c)
284+
dist = distance(metric, X, centroids, i, c)
279285
if dist < ubx2
280286
new_lb2 = ubx2
281287
ubx2 = dist
@@ -290,7 +296,7 @@ function chunk_update_centroids(alg::Yinyang, containers, centroids, X, weights,
290296
mask[c] && continue
291297
new_lb < old_lb - p[c] && continue
292298
# containers.debug[end][3] += 1 # lower bound update
293-
dist = distance(X, centroids, i, c)
299+
dist = distance(metric, X, centroids, i, c)
294300
if dist < new_lb2
295301
new_lb2 = dist
296302
new_lb = sqrt(new_lb2)
@@ -314,7 +320,7 @@ function chunk_update_centroids(alg::Yinyang, containers, centroids, X, weights,
314320
ubx < old_lb - p[c] && continue
315321
# containers.debug[end][2] += 1 # local filter update
316322
mask[c] = true
317-
dist = distance(X, centroids, i, c)
323+
dist = distance(metric, X, centroids, i, c)
318324
if dist < ubx2
319325
# closest center was in previous cluster
320326
if indices[label] != gi
@@ -336,7 +342,7 @@ function chunk_update_centroids(alg::Yinyang, containers, centroids, X, weights,
336342
mask[c] && continue
337343
new_lb < old_lb - p[c] && continue
338344
# containers.debug[end][3] += 1 # lower bound update
339-
dist = distance(X, centroids, i, c)
345+
dist = distance(metric, X, centroids, i, c)
340346
if dist < new_lb2
341347
new_lb2 = dist
342348
new_lb = sqrt(new_lb2)
@@ -360,12 +366,13 @@ function chunk_update_centroids(alg::Yinyang, containers, centroids, X, weights,
360366
end
361367
end
362368

369+
363370
"""
364371
point_all_centers!(containers, centroids, X, i)
365372
366373
Calculates new labels and upper and lower bounds for all points.
367374
"""
368-
function point_all_centers!(alg::Yinyang, containers, centroids, X, i)
375+
function point_all_centers!(alg::Yinyang, containers, centroids, X, i, metric)
369376
ub = containers.ub
370377
lb = containers.lb
371378
labels = containers.labels
@@ -381,7 +388,7 @@ function point_all_centers!(alg::Yinyang, containers, centroids, X, i)
381388
group_min_distance2 = T(Inf)
382389
group_label = ri[1]
383390
for k in ri
384-
dist = distance(X, centroids, i, k)
391+
dist = distance(metric, X, centroids, i, k)
385392
if group_min_distance > dist
386393
group_label = k
387394
group_min_distance2 = group_min_distance
@@ -407,6 +414,7 @@ function point_all_centers!(alg::Yinyang, containers, centroids, X, i)
407414
return label
408415
end
409416

417+
410418
# I believe there should be oneliner for it
411419
function rangify(x)
412420
res = UnitRange{Int}[]

test/test05_hamerly.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using StableRNGs
77
using Random
88
using Distances
99

10+
1011
@testset "initialize" begin
1112
X = permutedims([1.0 2; 2 1; 4 5; 6 6])
1213
centroids = permutedims([1.0 2; 4 5; 6 6])

test/test06_yinyang.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ module TestYinyang
33
using ParallelKMeans
44
using Test
55
using StableRNGs
6+
using Distances
7+
68

79
@testset "basic kmeans Yinyang" begin
810
X = [1. 2. 4.;]
@@ -196,4 +198,29 @@ end
196198
@test !alg.auto
197199
end
198200

201+
@testset "Yinyang metric support" begin
202+
rng = StableRNG(2020)
203+
X = [1. 2. 4.;]
204+
205+
res = kmeans(Yinyang(), X, 2; tol = 1e-16, metric=Cityblock(), rng = rng)
206+
207+
@test res.assignments == [2, 2, 1]
208+
@test res.centers == [4.0 1.5]
209+
@test res.totalcost == 1.0
210+
@test res.converged
211+
212+
rng = StableRNG(2020)
213+
X = rand(3, 100)
214+
rng_orig = deepcopy(rng)
215+
216+
baseline = kmeans(Lloyd(), X, 2, tol = 1e-16, metric=Cityblock(), rng = rng)
217+
218+
rng = deepcopy(rng_orig)
219+
res = kmeans(Yinyang(), X, 2; tol = 1e-16, metric=Cityblock(), rng = rng)
220+
221+
@test res.totalcost ≈ baseline.totalcost
222+
@test res.converged == baseline.converged
223+
@test res.iterations == baseline.iterations
224+
end
225+
199226
end # module

test/test07_coreset.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ module TestCoreset
33
using ParallelKMeans
44
using Test
55
using StableRNGs
6+
using Distances
7+
68

79
@testset "basic coresets" begin
810
rng = StableRNG(2020)
@@ -45,4 +47,16 @@ end
4547
@test alg.alg == Hamerly()
4648
end
4749

50+
@testset "Coreset metric support" begin
51+
rng = StableRNG(2020)
52+
X = [1. 2. 4.;]
53+
54+
res = kmeans(Coreset(), X, 2; tol = 1e-16, metric=Cityblock(), rng = rng)
55+
56+
@test res.assignments == [2, 2, 1]
57+
@test res.centers == [4.0 1.4865168535972686]
58+
@test res.totalcost == 1.0
59+
@test res.converged
60+
end
61+
4862
end # module

0 commit comments

Comments
 (0)