Skip to content

Commit 8168acb

Browse files
committed
WIP: Unverified non-euclidean metric support for Elkan
1 parent 964fe70 commit 8168acb

File tree

1 file changed

+21
-3
lines changed

1 file changed

+21
-3
lines changed

src/elkan.jl

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ function kmeans!(alg::Elkan, containers, X, k, weights=nothing, metric=Euclidean
5252
calculate_centroids_movement(alg, containers, centroids, metric)
5353

5454
# lower and ounds update, in paper it's steps 5 and 6
55-
@parallelize n_threads ncol chunk_update_bounds(alg, containers, centroids)
55+
@parallelize n_threads ncol chunk_update_bounds(alg, containers, centroids, metric)
5656

5757
# Step 7, final assignment of new centroids
5858
centroids .= containers.centroids_new[end]
@@ -259,14 +259,32 @@ function calculate_centroids_movement(alg::Elkan, containers, centroids, metric)
259259
end
260260

261261

262-
function chunk_update_bounds(alg, containers, centroids, r, idx)
262+
function chunk_update_bounds(alg, containers, centroids, metric::Euclidean, r, idx)
263263
p = containers.p
264264
lb = containers.lb
265265
ub = containers.ub
266266
stale = containers.stale
267267
labels = containers.labels
268268
T = eltype(centroids)
269-
# TODO: Add metric support with multiple dispatch
269+
270+
@inbounds for i in r
271+
for j in axes(centroids, 2)
272+
lb[j, i] = lb[j, i] > p[j] ? lb[j, i] + p[j] - T(2)*sqrt(abs(lb[j, i]*p[j])) : zero(T)
273+
end
274+
stale[i] = true
275+
ub[i] += p[labels[i]] + T(2)*sqrt(abs(ub[i]*p[labels[i]]))
276+
end
277+
end
278+
279+
280+
function chunk_update_bounds(alg, containers, centroids, metric::Metric, r, idx)
281+
p = containers.p
282+
lb = containers.lb
283+
ub = containers.ub
284+
stale = containers.stale
285+
labels = containers.labels
286+
T = eltype(centroids)
287+
# TODO: Update the metric support for non eucledian metric
270288
@inbounds for i in r
271289
for j in axes(centroids, 2)
272290
lb[j, i] = lb[j, i] > p[j] ? lb[j, i] + p[j] - T(2)*sqrt(abs(lb[j, i]*p[j])) : zero(T)

0 commit comments

Comments
 (0)