Skip to content

Commit a65192c

Browse files
committed
Added metric supported distance functions
1 parent 7cc6969 commit a65192c

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

src/ParallelKMeans.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using Random
55
import MLJModelInterface
66
import Base.Threads: @spawn
77
import Distances
8+
import Distances: Euclidean, evaluate
89

910
include("kmeans.jl")
1011
include("seeding.jl")

src/kmeans.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ macro parallelize(n_threads, ncol, f)
9696
end
9797
end
9898

99+
99100
"""
100101
distance(X1, X2, i1, i2)
101102
@@ -107,10 +108,35 @@ function distance(X1, X2, i1, i2)
107108
@inbounds @simd for i in axes(X1, 1)
108109
d += (X1[i, i1] - X2[i, i2])^2
109110
end
111+
return d
112+
end
113+
114+
115+
"""
116+
distance(metric, X1, X2, i1, i2)
117+
118+
Allocationless calculation of distance between vectors X1[:, i1] and X2[:, i2] defined by the supplied distance metric.
119+
"""
120+
distance(metric, X1, X2, i1, i2) = evaluate(metric, X1[:, i1], X2[:, i2])
121+
122+
123+
"""
124+
distance(X1, X2, i1, i2)
125+
126+
Allocationless calculation of square eucledean distance between vectors X1[:, i1] and X2[:, i2]
127+
"""
128+
function distance(metric::Euclidean, X1, X2, i1, i2)
129+
# here goes my definition
130+
d = zero(eltype(X1))
131+
# TODO: break of the loop if d is larger than threshold (known minimum disatnce)
132+
@inbounds @simd for i in axes(X1, 1)
133+
d += (X1[i, i1] - X2[i, i2])^2
134+
end
110135

111136
return d
112137
end
113138

139+
114140
"""
115141
sum_of_squares(x, labels, centre, k)
116142
@@ -129,6 +155,7 @@ function sum_of_squares(containers, x, labels, centre, weights, r, idx)
129155
containers.sum_of_squares[idx] = s
130156
end
131157

158+
132159
"""
133160
kmeans([alg::AbstractKMeansAlg,] design_matrix, k; n_threads = nthreads(),
134161
k_init="k-means++", max_iters=300, tol=1e-6, verbose=true, rng = Random.GLOBAL_RNG)

0 commit comments

Comments
 (0)