Skip to content

Commit 4f1e7ee

Browse files
committed
WIP: Verified metric support to lloyd
1 parent f646aae commit 4f1e7ee

File tree

3 files changed

+22
-3
lines changed

3 files changed

+22
-3
lines changed

src/kmeans.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,6 @@ A Float type representing the computed metric is returned.
132132
"""
133133
function sum_of_squares(containers, x, labels, centre, weights, metric, r, idx)
134134
s = zero(eltype(x))
135-
136135
@inbounds for i in r
137136
s += isnothing(weights) ? distance(metric, x, centre, i, labels[i]) : weights[i] * distance(metric, x, centre, i, labels[i])
138137
end
@@ -181,7 +180,7 @@ function kmeans(alg::AbstractKMeansAlg, design_matrix, k;
181180

182181
return kmeans!(alg, containers, design_matrix, k, weights, n_threads = n_threads,
183182
k_init = k_init, max_iters = max_iters, tol = tol,
184-
verbose = verbose, init = init, rng = rng, metric = Euclidean())
183+
verbose = verbose, init = init, rng = rng, metric = metric)
185184

186185
end
187186

src/lloyd.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ function kmeans!(alg::Lloyd, containers, X, k, weights;
4848
J_previous = J
4949
niters += 1
5050
end
51-
5251
@parallelize n_threads ncol sum_of_squares(containers, X, containers.labels, centroids, weights, metric)
5352
totalcost = sum(containers.sum_of_squares)
5453

test/test03_lloyd.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ using ParallelKMeans
44
using Test
55
using StableRNGs
66
using StatsBase
7+
using Distances
8+
79

810
@testset "basic kmeans" begin
911
X = [1. 2. 4.;]
@@ -101,4 +103,23 @@ end
101103
@test res.iterations == 6
102104
end
103105

106+
@testset "Lloyd metric support" begin
107+
Random.seed!(2020)
108+
X = [1. 2. 4.;]
109+
110+
res = kmeans(Lloyd(), X, 2; tol = 1e-16, metric=Cityblock())
111+
112+
@test res.assignments == [1, 1, 2]
113+
@test res.centers == [1.5 4.0]
114+
@test res.totalcost == 1.0
115+
@test res.converged
116+
117+
Random.seed!(2020)
118+
X = rand(3, 100)
119+
120+
res = kmeans(X, 2, tol = 1e-16, metric=Cityblock())
121+
@test res.totalcost 62.040452528953736
122+
@test res.converged
123+
end
124+
104125
end # module

0 commit comments

Comments
 (0)