Skip to content

Commit 1367a04

Browse files
PyDataBlogPyDataBlog
authored andcommitted
WIP: Refactor metric & broken lloyd
1 parent d08932b commit 1367a04

File tree

4 files changed

+11
-10
lines changed

4 files changed

+11
-10
lines changed

src/kmeans.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ design matrix(x), centroids (centre), and the number of desired groups (k).
130130
131131
A Float type representing the computed metric is returned.
132132
"""
133-
function sum_of_squares(metric, containers, x, labels, centre, weights, r, idx)
133+
function sum_of_squares(containers, x, labels, centre, weights, metric, r, idx)
134134
s = zero(eltype(x))
135135

136136
@inbounds for i in r

src/lloyd.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function kmeans!(alg::Lloyd, containers, X, k, weights;
3030

3131
# Update centroids & labels with closest members until convergence
3232
while niters <= max_iters
33-
@parallelize n_threads ncol chunk_update_centroids(alg, metric, containers, centroids, X, weights)
33+
@parallelize n_threads ncol chunk_update_centroids(alg, containers, centroids, X, weights, metric)
3434
collect_containers(alg, containers, centroids, n_threads)
3535
J = sum(containers.J)
3636

@@ -49,7 +49,7 @@ function kmeans!(alg::Lloyd, containers, X, k, weights;
4949
niters += 1
5050
end
5151

52-
@parallelize n_threads ncol sum_of_squares(metric, containers, X, containers.labels, centroids, weights)
52+
@parallelize n_threads ncol sum_of_squares(containers, X, containers.labels, centroids, weights, metric)
5353
totalcost = sum(containers.sum_of_squares)
5454

5555
# Terminate algorithm with the assumption that K-means has converged
@@ -104,7 +104,7 @@ function create_containers(::Lloyd, X, k, nrow, ncol, n_threads)
104104
end
105105

106106

107-
function chunk_update_centroids(::Lloyd, metric, containers, centroids, X, weights, r, idx)
107+
function chunk_update_centroids(::Lloyd, containers, centroids, X, weights, metric, r, idx)
108108
# unpack containers for easier manipulations
109109
centroids_new = containers.centroids_new[idx]
110110
centroids_cnt = containers.centroids_cnt[idx]
@@ -118,7 +118,7 @@ function chunk_update_centroids(::Lloyd, metric, containers, centroids, X, weigh
118118
min_dist = distance(metric, X, centroids, i, 1)
119119
label = 1
120120
for j in 2:size(centroids, 2)
121-
dist = distance(metric, X, centroids, i, j)
121+
dist = distance(metric, X, centroids, i, 1)
122122
label = dist < min_dist ? j : label
123123
min_dist = dist < min_dist ? dist : min_dist
124124
end

src/seeding.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ centroid vector `y[:, i]`.
99
UnitRange argument `r` select subarray of original design matrix `x` that is going
1010
to be processed.
1111
"""
12-
function chunk_colwise(metric, target, x, y, i, weights, r, idx)
12+
function chunk_colwise(target, x, y, i, weights, metric, r, idx)
1313
T = eltype(x)
1414
@inbounds for j in r
1515
dist = distance(metric, x, y, j, i)
@@ -18,6 +18,7 @@ function chunk_colwise(metric, target, x, y, i, weights, r, idx)
1818
end
1919
end
2020

21+
2122
"""
2223
smart_init(X, k; init="k-means++")
2324
@@ -56,7 +57,7 @@ function smart_init(X, k, n_threads = Threads.nthreads(), weights = nothing,
5657
distances = fill(T(Inf), ncol)
5758

5859
# compute distances from the first centroid chosen to all the other data points
59-
@parallelize n_threads ncol chunk_colwise(metric, distances, X, centroids, 1, weights)
60+
@parallelize n_threads ncol chunk_colwise(distances, X, centroids, 1, weights, metric)
6061
distances[rand_idx] = zero(T)
6162

6263
for i = 2:k
@@ -72,7 +73,7 @@ function smart_init(X, k, n_threads = Threads.nthreads(), weights = nothing,
7273
i == k && break
7374

7475
# compute distances from the centroids to all data points
75-
@parallelize n_threads ncol chunk_colwise(metric, distances, X, centroids, i, weights)
76+
@parallelize n_threads ncol chunk_colwise(distances, X, centroids, i, weights, metric)
7677

7778
distances[r_idx] = zero(T)
7879
end

test/test01_distance.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using Test
1010
r = fill(Inf, ncol)
1111
n_threads = 1
1212

13-
@parallelize n_threads ncol chunk_colwise(Euclidean(), r, X, y, 1, nothing)
13+
@parallelize n_threads ncol chunk_colwise(r, X, y, 1, nothing, Euclidean())
1414
@test all(r .≈ [0.0, 13.0, 25.0])
1515
end
1616

@@ -21,7 +21,7 @@ end
2121
r = fill(Inf, ncol)
2222
n_threads = 2
2323

24-
@parallelize n_threads ncol chunk_colwise(Euclidean(), r, X, y, 1, nothing)
24+
@parallelize n_threads ncol chunk_colwise(r, X, y, 1, nothing, Euclidean())
2525

2626
@test all(r .≈ [0.0, 13.0, 25.0])
2727
end

0 commit comments

Comments
 (0)