@@ -47,22 +47,24 @@ Yinyang(; group_size = 7, auto = true) = Yinyang(auto, group_size)
4747阴阳 (group_size:: Int ) = Yinyang (true , group_size)
4848阴阳 (; group_size = 7 , auto = true ) = Yinyang (auto, group_size)
4949
50- function kmeans! (alg:: Yinyang , containers, X, k, weights;
50+ function kmeans! (alg:: Yinyang , containers, X, k, weights, metric = Euclidean () ;
5151 n_threads = Threads. nthreads (),
5252 k_init = " k-means++" , max_iters = 300 ,
5353 tol = 1e-6 , verbose = false ,
5454 init = nothing , rng = Random. GLOBAL_RNG)
55+
5556 nrow, ncol = size (X)
57+
5658 centroids = init == nothing ? smart_init (X, k, n_threads, weights, rng, init= k_init). centroids : deepcopy (init)
5759
5860 # create initial groups of centers, step 1 in original paper
5961 initialize (alg, containers, centroids, rng, n_threads)
6062 # construct initial bounds, step 2
61- @parallelize n_threads ncol chunk_initialize (alg, containers, centroids, X, weights)
63+ @parallelize n_threads ncol chunk_initialize (alg, containers, centroids, X, weights, metric )
6264 collect_containers (alg, containers, n_threads)
6365
6466 # update centers and calculate drifts. Step 3.1 of the original paper.
65- calculate_centroids_movement (alg, containers, centroids)
67+ calculate_centroids_movement (alg, containers, centroids, metric )
6668
6769 T = eltype (X)
6870 converged = false
@@ -87,14 +89,14 @@ function kmeans!(alg::Yinyang, containers, X, k, weights;
8789
8890 # push!(containers.debug, [0, 0, 0])
8991 # Core calculation of the Yinyang, 3.2-3.3 steps of the original paper
90- @parallelize n_threads ncol chunk_update_centroids (alg, containers, centroids, X, weights)
92+ @parallelize n_threads ncol chunk_update_centroids (alg, containers, centroids, X, weights, metric )
9193 collect_containers (alg, containers, n_threads)
9294
9395 # update centers and calculate drifts. Step 3.1 of the original paper.
94- calculate_centroids_movement (alg, containers, centroids)
96+ calculate_centroids_movement (alg, containers, centroids, metric )
9597 end
9698
97- @parallelize n_threads ncol sum_of_squares (containers, X, containers. labels, centroids, weights)
99+ @parallelize n_threads ncol sum_of_squares (containers, X, containers. labels, centroids, weights, metric )
98100 totalcost = sum (containers. sum_of_squares)
99101
100102 # Terminate algorithm with the assumption that K-means has converged
@@ -170,6 +172,7 @@ function create_containers(alg::Yinyang, X, k, nrow, ncol, n_threads)
170172 )
171173end
172174
175+
173176function initialize (alg:: Yinyang , containers, centroids, rng, n_threads)
174177 groups = containers. groups
175178 indices = containers. indices
@@ -186,21 +189,23 @@ function initialize(alg::Yinyang, containers, centroids, rng, n_threads)
186189 end
187190end
188191
189- function chunk_initialize (alg:: Yinyang , containers, centroids, X, weights, r, idx)
192+
193+ function chunk_initialize (alg:: Yinyang , containers, centroids, X, weights, metric, r, idx)
190194 T = eltype (X)
191195 centroids_cnt = containers. centroids_cnt[idx]
192196 centroids_new = containers. centroids_new[idx]
193197
194198 @inbounds for i in r
195- label = point_all_centers! (alg, containers, centroids, X, i)
199+ label = point_all_centers! (alg, containers, centroids, X, i, metric )
196200 centroids_cnt[label] += isnothing (weights) ? one (T) : weights[i]
197201 for j in axes (X, 1 )
198202 centroids_new[j, label] += isnothing (weights) ? X[j, i] : weights[i] * X[j, i]
199203 end
200204 end
201205end
202206
203- function calculate_centroids_movement (alg:: Yinyang , containers, centroids)
207+
208+ function calculate_centroids_movement (alg:: Yinyang , containers, centroids, metric)
204209 p = containers. p
205210 groups = containers. groups
206211 gd = containers. gd
@@ -210,7 +215,7 @@ function calculate_centroids_movement(alg::Yinyang, containers, centroids)
210215 @inbounds for (gi, ri) in enumerate (groups)
211216 max_drift = T (- Inf )
212217 for i in ri
213- p[i] = sqrt (distance (centroids, centroids_new, i, i))
218+ p[i] = sqrt (distance (metric, centroids, centroids_new, i, i))
214219 max_drift = p[i] > max_drift ? p[i] : max_drift
215220
216221 # Should do it more elegantly
@@ -222,7 +227,8 @@ function calculate_centroids_movement(alg::Yinyang, containers, centroids)
222227 end
223228end
224229
225- function chunk_update_centroids (alg:: Yinyang , containers, centroids, X, weights, r, idx)
230+
231+ function chunk_update_centroids (alg:: Yinyang , containers, centroids, X, weights, metric, r, idx)
226232 # unpack containers for easier manipulations
227233 centroids_new = containers. centroids_new[idx]
228234 centroids_cnt = containers. centroids_cnt[idx]
@@ -256,7 +262,7 @@ function chunk_update_centroids(alg::Yinyang, containers, centroids, X, weights,
256262
257263 # tighten upper bound
258264 label = labels[i]
259- ubx = sqrt (distance (X, centroids, i, label))
265+ ubx = sqrt (distance (metric, X, centroids, i, label))
260266 ub[i] = ubx
261267 ubx <= lbx && continue
262268
@@ -275,7 +281,7 @@ function chunk_update_centroids(alg::Yinyang, containers, centroids, X, weights,
275281 ((c == old_label) | (ubx < old_lb - p[c])) && continue
276282 mask[c] = true
277283 # containers.debug[end][2] += 1 # local filter update
278- dist = distance (X, centroids, i, c)
284+ dist = distance (metric, X, centroids, i, c)
279285 if dist < ubx2
280286 new_lb2 = ubx2
281287 ubx2 = dist
@@ -290,7 +296,7 @@ function chunk_update_centroids(alg::Yinyang, containers, centroids, X, weights,
290296 mask[c] && continue
291297 new_lb < old_lb - p[c] && continue
292298 # containers.debug[end][3] += 1 # lower bound update
293- dist = distance (X, centroids, i, c)
299+ dist = distance (metric, X, centroids, i, c)
294300 if dist < new_lb2
295301 new_lb2 = dist
296302 new_lb = sqrt (new_lb2)
@@ -314,7 +320,7 @@ function chunk_update_centroids(alg::Yinyang, containers, centroids, X, weights,
314320 ubx < old_lb - p[c] && continue
315321 # containers.debug[end][2] += 1 # local filter update
316322 mask[c] = true
317- dist = distance (X, centroids, i, c)
323+ dist = distance (metric, X, centroids, i, c)
318324 if dist < ubx2
319325 # closest center was in previous cluster
320326 if indices[label] != gi
@@ -336,7 +342,7 @@ function chunk_update_centroids(alg::Yinyang, containers, centroids, X, weights,
336342 mask[c] && continue
337343 new_lb < old_lb - p[c] && continue
338344 # containers.debug[end][3] += 1 # lower bound update
339- dist = distance (X, centroids, i, c)
345+ dist = distance (metric, X, centroids, i, c)
340346 if dist < new_lb2
341347 new_lb2 = dist
342348 new_lb = sqrt (new_lb2)
@@ -360,12 +366,13 @@ function chunk_update_centroids(alg::Yinyang, containers, centroids, X, weights,
360366 end
361367end
362368
369+
363370"""
364371 point_all_centers!(containers, centroids, X, i)
365372
366373Calculates new labels and upper and lower bounds for all points.
367374"""
368- function point_all_centers! (alg:: Yinyang , containers, centroids, X, i)
375+ function point_all_centers! (alg:: Yinyang , containers, centroids, X, i, metric )
369376 ub = containers. ub
370377 lb = containers. lb
371378 labels = containers. labels
@@ -381,7 +388,7 @@ function point_all_centers!(alg::Yinyang, containers, centroids, X, i)
381388 group_min_distance2 = T (Inf )
382389 group_label = ri[1 ]
383390 for k in ri
384- dist = distance (X, centroids, i, k)
391+ dist = distance (metric, X, centroids, i, k)
385392 if group_min_distance > dist
386393 group_label = k
387394 group_min_distance2 = group_min_distance
@@ -407,6 +414,7 @@ function point_all_centers!(alg::Yinyang, containers, centroids, X, i)
407414 return label
408415end
409416
417+
410418# I believe there should be oneliner for it
411419function rangify (x)
412420 res = UnitRange{Int}[]
0 commit comments