@@ -22,11 +22,12 @@ function kmeans!(alg::Hamerly, containers, X, k, weights;
2222 n_threads = Threads. nthreads (),
2323 k_init = " k-means++" , max_iters = 300 ,
2424 tol = eltype (X)(1e-6 ), verbose = false ,
25- init = nothing , rng = Random. GLOBAL_RNG)
25+ init = nothing , rng = Random. GLOBAL_RNG, metric= Euclidean ())
26+
2627 nrow, ncol = size (X)
2728 centroids = init == nothing ? smart_init (X, k, n_threads, weights, rng, init= k_init). centroids : deepcopy (init)
2829
29- @parallelize n_threads ncol chunk_initialize (alg, containers, centroids, X, weights)
30+ @parallelize n_threads ncol chunk_initialize (alg, containers, centroids, X, weights, metric )
3031
3132 T = eltype (X)
3233 converged = false
@@ -37,8 +38,8 @@ function kmeans!(alg::Hamerly, containers, X, k, weights;
3738 # Update centroids & labels with closest members until convergence
3839 while niters < max_iters
3940 niters += 1
40- update_containers (alg, containers, centroids, n_threads)
41- @parallelize n_threads ncol chunk_update_centroids (alg, containers, centroids, X, weights)
41+ update_containers (alg, containers, centroids, n_threads, metric )
42+ @parallelize n_threads ncol chunk_update_centroids (alg, containers, centroids, X, weights, metric )
4243 collect_containers (alg, containers, n_threads)
4344
4445 J = sum (containers. ub)
@@ -61,7 +62,7 @@ function kmeans!(alg::Hamerly, containers, X, k, weights;
6162 J_previous = J
6263 end
6364
64- @parallelize n_threads ncol sum_of_squares (containers, X, containers. labels, centroids, weights)
65+ @parallelize n_threads ncol sum_of_squares (containers, X, containers. labels, centroids, weights, metric )
6566 totalcost = sum (containers. sum_of_squares)
6667
6768 # Terminate algorithm with the assumption that K-means has converged
@@ -75,6 +76,7 @@ function kmeans!(alg::Hamerly, containers, X, k, weights;
7576 return KmeansResult (centroids, containers. labels, T[], Int[], T[], totalcost, niters, converged)
7677end
7778
79+
7880function create_containers (alg:: Hamerly , X, k, nrow, ncol, n_threads)
7981 T = eltype (X)
8082 lng = n_threads + 1
@@ -115,52 +117,55 @@ function create_containers(alg::Hamerly, X, k, nrow, ncol, n_threads)
115117 )
116118end
117119
120+
118121"""
119122 chunk_initialize(alg::Hamerly, containers, centroids, design_matrix, r, idx)
120123
121124Initial calulation of all bounds and points labeling.
122125"""
123- function chunk_initialize (alg:: Hamerly , containers, centroids, X, weights, r, idx)
126+ function chunk_initialize (alg:: Hamerly , containers, centroids, X, weights, metric, r, idx)
124127 T = eltype (X)
125128 centroids_cnt = containers. centroids_cnt[idx]
126129 centroids_new = containers. centroids_new[idx]
127130
128131 @inbounds for i in r
129- label = point_all_centers! (containers, centroids, X, i)
132+ label = point_all_centers! (containers, centroids, X, i, metric )
130133 centroids_cnt[label] += isnothing (weights) ? one (T) : weights[i]
131134 for j in axes (X, 1 )
132135 centroids_new[j, label] += isnothing (weights) ? X[j, i] : weights[i] * X[j, i]
133136 end
134137 end
135138end
136139
140+
137141"""
138142 update_containers(::Hamerly, containers, centroids, n_threads)
139143
140144Calculates minimum distances from centers to each other.
141145"""
142- function update_containers (:: Hamerly , containers, centroids, n_threads)
146+ function update_containers (:: Hamerly , containers, centroids, n_threads, metric )
143147 T = eltype (centroids)
144148 s = containers. s
145149 s .= T (Inf )
146150 @inbounds for i in axes (centroids, 2 )
147151 for j in i+ 1 : size (centroids, 2 )
148- d = distance (centroids, centroids, i, j)
152+ d = distance (metric, centroids, centroids, i, j)
149153 d = T (0.25 )* d
150154 s[i] = s[i] > d ? d : s[i]
151155 s[j] = s[j] > d ? d : s[j]
152156 end
153157 end
154158end
155159
160+
156161"""
157162 chunk_update_centroids(::Hamerly, containers, centroids, X, r, idx)
158163
159164Detailed description of this function can be found in the original paper. It iterates through
160165all points and tries to skip some calculation using known upper and lower bounds of distances
161166from point to centers. If it fails to skip than it fall back to generic `point_all_centers!` function.
162167"""
163- function chunk_update_centroids (alg:: Hamerly , containers, centroids, X, weights, r, idx)
168+ function chunk_update_centroids (alg:: Hamerly , containers, centroids, X, weights, metric, r, idx)
164169
165170 # unpack containers for easier manipulations
166171 centroids_new = containers. centroids_new[idx]
@@ -178,10 +183,10 @@ function chunk_update_centroids(alg::Hamerly, containers, centroids, X, weights,
178183 if ub[i] > m
179184 # tighten upper bound
180185 label = labels[i]
181- ub[i] = distance (X, centroids, i, label)
186+ ub[i] = distance (metric, X, centroids, i, label)
182187 # second bound test
183188 if ub[i] > m
184- label_new = point_all_centers! (containers, centroids, X, i)
189+ label_new = point_all_centers! (containers, centroids, X, i, metric )
185190 if label != label_new
186191 labels[i] = label_new
187192 centroids_cnt[label_new] += isnothing (weights) ? one (T) : weights[i]
@@ -196,12 +201,13 @@ function chunk_update_centroids(alg::Hamerly, containers, centroids, X, weights,
196201 end
197202end
198203
204+
199205"""
200206 point_all_centers!(containers, centroids, X, i)
201207
202208Calculates new labels and upper and lower bounds for all points.
203209"""
204- function point_all_centers! (containers, centroids, X, i)
210+ function point_all_centers! (containers, centroids, X, i, metric )
205211 ub = containers. ub
206212 lb = containers. lb
207213 labels = containers. labels
@@ -211,7 +217,7 @@ function point_all_centers!(containers, centroids, X, i)
211217 min_distance2 = T (Inf )
212218 label = 1
213219 @inbounds for k in axes (centroids, 2 )
214- dist = distance (X, centroids, i, k)
220+ dist = distance (metric, X, centroids, i, k)
215221 if min_distance > dist
216222 label = k
217223 min_distance2 = min_distance
@@ -228,6 +234,7 @@ function point_all_centers!(containers, centroids, X, i)
228234 return label
229235end
230236
237+
231238"""
232239 move_centers(::Hamerly, containers, centroids)
233240
@@ -249,6 +256,7 @@ function move_centers(::Hamerly, containers, centroids)
249256 end
250257end
251258
259+
252260"""
253261 chunk_update_bounds(alg::Hamerly, containers, r1, r2, pr1, pr2, r, idx)
254262
@@ -261,7 +269,7 @@ function chunk_update_bounds(alg::Hamerly, containers, r1, r2, pr1, pr2, r, idx)
261269 labels = containers. labels
262270 T = eltype (containers. ub)
263271
264- # Since bounds are squred distance, `sqrt` is used to make corresponding estimation, unlike
272+ # Since bounds are squared distance, `sqrt` is used to make corresponding estimation, unlike
265273 # the original paper, where usual metric is used.
266274 #
267275 # Using notation from original paper, `u` is upper bound and `a` is `labels`, so
@@ -288,6 +296,7 @@ function chunk_update_bounds(alg::Hamerly, containers, r1, r2, pr1, pr2, r, idx)
288296 end
289297end
290298
299+
291300"""
292301 double_argmax(p)
293302
0 commit comments