Skip to content

Commit 964fe70

Browse files
committed
Moved metric internally as a positional arg
1 parent c2edd87 commit 964fe70

File tree

4 files changed

+22
-3
lines changed

4 files changed

+22
-3
lines changed

src/elkan.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ kmeans(Elkan(), X, 3) # 3 clusters, Elkan algorithm
1919
struct Elkan <: AbstractKMeansAlg end
2020

2121

22-
function kmeans!(alg::Elkan, containers, X, k, weights;
22+
function kmeans!(alg::Elkan, containers, X, k, weights=nothing, metric=Euclidean();
2323
n_threads = Threads.nthreads(),
2424
k_init = "k-means++", max_iters = 300,
2525
tol = eltype(X)(1e-6), verbose = false,

src/hamerly.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ kmeans(Hamerly(), X, 3) # 3 clusters, Hamerly algorithm
1818
struct Hamerly <: AbstractKMeansAlg end
1919

2020

21-
function kmeans!(alg::Hamerly, containers, X, k, weights;
21+
function kmeans!(alg::Hamerly, containers, X, k, weights=nothing, metric=Euclidean();
2222
n_threads = Threads.nthreads(),
2323
k_init = "k-means++", max_iters = 300,
2424
tol = eltype(X)(1e-6), verbose = false,

src/kmeans.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,19 +168,36 @@ alternatively one can use `rand` to choose random points for init.
168168
169169
A `KmeansResult` structure representing labels, centroids, and sum_squares is returned.
170170
"""
171+
function kmeans(alg::AbstractKMeansAlg, design_matrix, k;
                weights = nothing,
                n_threads = Threads.nthreads(),
                k_init = "k-means++", max_iters = 300,
                tol = eltype(design_matrix)(1e-6), verbose = false,
                init = nothing, rng = Random.GLOBAL_RNG, metric = Euclidean())
    # NOTE: this definition resolves a merge conflict that had been committed with its
    # `<<<<<<<` / `=======` / `>>>>>>>` markers intact. The public signature is the HEAD
    # one (`weights` and `rng` stay keyword arguments, so existing callers keep working).
    # The internal call is the incoming one, which hands `metric` to `kmeans!` as a
    # positional argument.

    # Get dimensions of the input data
    nrow, ncol = size(design_matrix)

    # Create containers based on the dimensions and specifications
    containers = create_containers(alg, design_matrix, k, nrow, ncol, n_threads)

    # Dispatch based on the specified algorithm. `weights` and `metric` are positional
    # in `kmeans!`; everything else is forwarded by keyword.
    # NOTE(review): `rng` is forwarded as on the HEAD side; confirm that every
    # `kmeans!` method accepts an `rng` keyword.
    return kmeans!(alg, containers, design_matrix, k, weights, metric;
                   n_threads = n_threads, k_init = k_init, max_iters = max_iters,
                   tol = tol, verbose = verbose, init = init, rng = rng)
end
186203

src/lloyd.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@ found in `kmeans`.
1414
Argument `containers` represents algorithm-specific containers, such as labels, intermediate
1515
centroids and so on, which are used during calculations.
1616
"""
17-
function kmeans!(alg::Lloyd, containers, X, k, weights;
17+
function kmeans!(alg::Lloyd, containers, X, k, weights=nothing, metric=Euclidean();
1818
n_threads = Threads.nthreads(),
1919
k_init = "k-means++", max_iters = 300,
2020
tol = eltype(design_matrix)(1e-6), verbose = false,
2121
init = nothing, rng = Random.GLOBAL_RNG, metric=Euclidean())
2222

23+
# Get dimensions of the input data
2324
nrow, ncol = size(X)
2425
centroids = isnothing(init) ? smart_init(X, k, n_threads, weights, rng, init=k_init).centroids : deepcopy(init)
2526

@@ -48,6 +49,7 @@ function kmeans!(alg::Lloyd, containers, X, k, weights;
4849
J_previous = J
4950
niters += 1 # TODO: Investigate the potential bug in number of iterations
5051
end
52+
5153
@parallelize n_threads ncol sum_of_squares(containers, X, containers.labels, centroids, weights, metric)
5254
totalcost = sum(containers.sum_of_squares)
5355

0 commit comments

Comments
 (0)