Skip to content

Commit aef8f69

Browse files
committed
WIP: Initial unverified metric support for Hamerly
1 parent 09125bb commit aef8f69

File tree

2 files changed

+47
-16
lines changed

2 files changed

+47
-16
lines changed

src/hamerly.jl

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@ function kmeans!(alg::Hamerly, containers, X, k, weights;
2222
n_threads = Threads.nthreads(),
2323
k_init = "k-means++", max_iters = 300,
2424
tol = eltype(X)(1e-6), verbose = false,
25-
init = nothing, rng = Random.GLOBAL_RNG)
25+
init = nothing, rng = Random.GLOBAL_RNG, metric=Euclidean())
26+
2627
nrow, ncol = size(X)
2728
centroids = init == nothing ? smart_init(X, k, n_threads, weights, rng, init=k_init).centroids : deepcopy(init)
2829

29-
@parallelize n_threads ncol chunk_initialize(alg, containers, centroids, X, weights)
30+
@parallelize n_threads ncol chunk_initialize(alg, containers, centroids, X, weights, metric)
3031

3132
T = eltype(X)
3233
converged = false
@@ -37,8 +38,8 @@ function kmeans!(alg::Hamerly, containers, X, k, weights;
3738
# Update centroids & labels with closest members until convergence
3839
while niters < max_iters
3940
niters += 1
40-
update_containers(alg, containers, centroids, n_threads)
41-
@parallelize n_threads ncol chunk_update_centroids(alg, containers, centroids, X, weights)
41+
update_containers(alg, containers, centroids, n_threads, metric)
42+
@parallelize n_threads ncol chunk_update_centroids(alg, containers, centroids, X, weights, metric)
4243
collect_containers(alg, containers, n_threads)
4344

4445
J = sum(containers.ub)
@@ -61,7 +62,7 @@ function kmeans!(alg::Hamerly, containers, X, k, weights;
6162
J_previous = J
6263
end
6364

64-
@parallelize n_threads ncol sum_of_squares(containers, X, containers.labels, centroids, weights)
65+
@parallelize n_threads ncol sum_of_squares(containers, X, containers.labels, centroids, weights, metric)
6566
totalcost = sum(containers.sum_of_squares)
6667

6768
# Terminate algorithm with the assumption that K-means has converged
@@ -75,6 +76,7 @@ function kmeans!(alg::Hamerly, containers, X, k, weights;
7576
return KmeansResult(centroids, containers.labels, T[], Int[], T[], totalcost, niters, converged)
7677
end
7778

79+
7880
function create_containers(alg::Hamerly, X, k, nrow, ncol, n_threads)
7981
T = eltype(X)
8082
lng = n_threads + 1
@@ -115,52 +117,55 @@ function create_containers(alg::Hamerly, X, k, nrow, ncol, n_threads)
115117
)
116118
end
117119

120+
118121
"""
119122
chunk_initialize(alg::Hamerly, containers, centroids, design_matrix, r, idx)
120123
121124
Initial calulation of all bounds and points labeling.
122125
"""
123-
function chunk_initialize(alg::Hamerly, containers, centroids, X, weights, r, idx)
126+
function chunk_initialize(alg::Hamerly, containers, centroids, X, weights, metric, r, idx)
124127
T = eltype(X)
125128
centroids_cnt = containers.centroids_cnt[idx]
126129
centroids_new = containers.centroids_new[idx]
127130

128131
@inbounds for i in r
129-
label = point_all_centers!(containers, centroids, X, i)
132+
label = point_all_centers!(containers, centroids, X, i, metric)
130133
centroids_cnt[label] += isnothing(weights) ? one(T) : weights[i]
131134
for j in axes(X, 1)
132135
centroids_new[j, label] += isnothing(weights) ? X[j, i] : weights[i] * X[j, i]
133136
end
134137
end
135138
end
136139

140+
137141
"""
138142
update_containers(::Hamerly, containers, centroids, n_threads)
139143
140144
Calculates minimum distances from centers to each other.
141145
"""
142-
function update_containers(::Hamerly, containers, centroids, n_threads)
146+
function update_containers(::Hamerly, containers, centroids, n_threads, metric)
143147
T = eltype(centroids)
144148
s = containers.s
145149
s .= T(Inf)
146150
@inbounds for i in axes(centroids, 2)
147151
for j in i+1:size(centroids, 2)
148-
d = distance(centroids, centroids, i, j)
152+
d = distance(metric, centroids, centroids, i, j)
149153
d = T(0.25)*d
150154
s[i] = s[i] > d ? d : s[i]
151155
s[j] = s[j] > d ? d : s[j]
152156
end
153157
end
154158
end
155159

160+
156161
"""
157162
chunk_update_centroids(::Hamerly, containers, centroids, X, r, idx)
158163
159164
Detailed description of this function can be found in the original paper. It iterates through
160165
all points and tries to skip some calculation using known upper and lower bounds of distances
161166
from point to centers. If it fails to skip than it fall back to generic `point_all_centers!` function.
162167
"""
163-
function chunk_update_centroids(alg::Hamerly, containers, centroids, X, weights, r, idx)
168+
function chunk_update_centroids(alg::Hamerly, containers, centroids, X, weights, metric, r, idx)
164169

165170
# unpack containers for easier manipulations
166171
centroids_new = containers.centroids_new[idx]
@@ -178,10 +183,10 @@ function chunk_update_centroids(alg::Hamerly, containers, centroids, X, weights,
178183
if ub[i] > m
179184
# tighten upper bound
180185
label = labels[i]
181-
ub[i] = distance(X, centroids, i, label)
186+
ub[i] = distance(metric, X, centroids, i, label)
182187
# second bound test
183188
if ub[i] > m
184-
label_new = point_all_centers!(containers, centroids, X, i)
189+
label_new = point_all_centers!(containers, centroids, X, i, metric)
185190
if label != label_new
186191
labels[i] = label_new
187192
centroids_cnt[label_new] += isnothing(weights) ? one(T) : weights[i]
@@ -196,12 +201,13 @@ function chunk_update_centroids(alg::Hamerly, containers, centroids, X, weights,
196201
end
197202
end
198203

204+
199205
"""
200206
point_all_centers!(containers, centroids, X, i)
201207
202208
Calculates new labels and upper and lower bounds for all points.
203209
"""
204-
function point_all_centers!(containers, centroids, X, i)
210+
function point_all_centers!(containers, centroids, X, i, metric)
205211
ub = containers.ub
206212
lb = containers.lb
207213
labels = containers.labels
@@ -211,7 +217,7 @@ function point_all_centers!(containers, centroids, X, i)
211217
min_distance2 = T(Inf)
212218
label = 1
213219
@inbounds for k in axes(centroids, 2)
214-
dist = distance(X, centroids, i, k)
220+
dist = distance(metric, X, centroids, i, k)
215221
if min_distance > dist
216222
label = k
217223
min_distance2 = min_distance
@@ -228,6 +234,7 @@ function point_all_centers!(containers, centroids, X, i)
228234
return label
229235
end
230236

237+
231238
"""
232239
move_centers(::Hamerly, containers, centroids)
233240
@@ -249,6 +256,7 @@ function move_centers(::Hamerly, containers, centroids)
249256
end
250257
end
251258

259+
252260
"""
253261
chunk_update_bounds(alg::Hamerly, containers, r1, r2, pr1, pr2, r, idx)
254262
@@ -261,7 +269,7 @@ function chunk_update_bounds(alg::Hamerly, containers, r1, r2, pr1, pr2, r, idx)
261269
labels = containers.labels
262270
T = eltype(containers.ub)
263271

264-
# Since bounds are squred distance, `sqrt` is used to make corresponding estimation, unlike
272+
# Since bounds are squared distance, `sqrt` is used to make corresponding estimation, unlike
265273
# the original paper, where usual metric is used.
266274
#
267275
# Using notation from original paper, `u` is upper bound and `a` is `labels`, so
@@ -288,6 +296,7 @@ function chunk_update_bounds(alg::Hamerly, containers, r1, r2, pr1, pr2, r, idx)
288296
end
289297
end
290298

299+
291300
"""
292301
double_argmax(p)
293302

test/test05_hamerly.jl

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@ using ParallelKMeans
44
using ParallelKMeans: chunk_initialize, double_argmax
55
using Test
66
using StableRNGs
7+
using Random
8+
using Distances
79

810
@testset "initialize" begin
911
X = permutedims([1.0 2; 2 1; 4 5; 6 6])
1012
centroids = permutedims([1.0 2; 4 5; 6 6])
1113
nrow, ncol = size(X)
1214
containers = ParallelKMeans.create_containers(Hamerly(), X, 3, nrow, ncol, 1)
1315

14-
ParallelKMeans.chunk_initialize(Hamerly(), containers, centroids, X, nothing, 1:ncol, 1)
16+
ParallelKMeans.chunk_initialize(Hamerly(), containers, centroids, X, nothing, Euclidean(), 1:ncol, 1)
1517
@test containers.lb == [18.0, 20.0, 5.0, 5.0]
1618
@test containers.ub == [0.0, 2.0, 0.0, 0.0]
1719
end
@@ -101,4 +103,24 @@ end
101103
@test res.iterations == baseline.iterations
102104
end
103105

106+
107+
@testset "Hamerly metric support" begin
108+
Random.seed!(2020)
109+
X = [1. 2. 4.;]
110+
111+
res = kmeans(Hamerly(), X, 2; tol = 1e-16, metric=Cityblock())
112+
113+
@test res.assignments == [1, 1, 2]
114+
@test res.centers == [1.5 4.0]
115+
@test res.totalcost == 1.0
116+
@test res.converged
117+
118+
Random.seed!(2020)
119+
X = rand(3, 100)
120+
121+
res = kmeans(Hamerly(), X, 2, tol = 1e-16, metric=Cityblock())
122+
@test res.totalcost 62.04045252895372
123+
@test res.converged
124+
end
125+
104126
end # module

0 commit comments

Comments
 (0)