Skip to content

Commit 4fd5f95

Browse files
committed
WIP: Faster metric support for Lloyd
1 parent 4f1e7ee commit 4fd5f95

File tree

4 files changed

+4
-2
lines changed

4 files changed

+4
-2
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
88
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
99
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1010
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
11+
UnsafeArrays = "c4a57d5a-5b31-53a6-b365-19f8c011fbd6"
1112

1213
[compat]
1314
Distances = "0.8.2"

src/ParallelKMeans.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module ParallelKMeans
22

33
using StatsBase
44
using Random
5+
using UnsafeArrays
56
import MLJModelInterface
67
import Base.Threads: @spawn
78
import Distances

src/kmeans.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ end
102102
103103
Allocationless calculation of distance between vectors X1[:, i1] and X2[:, i2] defined by the supplied distance metric.
104104
"""
105-
distance(metric, X1, X2, i1, i2) = evaluate(metric, X1[:, i1], X2[:, i2])
105+
@inline distance(metric, X1, X2, i1, i2) = evaluate(metric, uview(X1, :, i1), uview(X2, :, i2))
106106

107107

108108
"""

test/test03_lloyd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ end
118118
X = rand(3, 100)
119119

120120
res = kmeans(X, 2, tol = 1e-16, metric=Cityblock())
121-
@test res.totalcost 62.040452528953736
121+
@test res.totalcost 62.04045252895372
122122
@test res.converged
123123
end
124124

0 commit comments

Comments
 (0)