Skip to content

Commit 7cd46e6

Browse files
committed
tested draft of mlj interface ready for feedback
1 parent 98c2689 commit 7cd46e6

File tree

2 files changed

+35
-9
lines changed

2 files changed

+35
-9
lines changed

src/mlj_interface.jl

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,18 +38,34 @@ end
3838

3939
function MLJModelInterface.clean!(m::KMeans)
4040
warning = ""
41+
4142
if !(m.algo keys(MLJDICT))
42-
warning *= "Unsuppored algorithm supplied. Please check documentation for supported Kmeans algorithms."
43+
warning *= "Unsuppored algorithm supplied. Defauting to KMeans++ seeding algorithm."
44+
m.algo = :Lloyd
45+
46+
elseif m.k_init != "k-means++"
47+
warning *= "Only `k-means++` or random seeding algorithms are supported. Defaulting to random seeding."
48+
m.k_init = "random"
49+
4350
elseif m.k < 1
44-
warning *= "Number of clusters must be greater than 0."
51+
warning *= "Number of clusters must be greater than 0. Defaulting to 3 clusters."
52+
m.k = 3
53+
4554
elseif !(m.tol < 1.0)
46-
warning *= "Tolerance level must be less than 1."
55+
warning *= "Tolerance level must be less than 1. Defaulting to tol of 1e-6."
56+
m.tol = 1e-6
57+
4758
elseif !(m.max_iters > 0)
48-
warning *= "Number of permitted iterations must be greater than 0."
59+
warning *= "Number of permitted iterations must be greater than 0. Defaulting to 300 iterations."
60+
m.max_iters = 300
61+
4962
elseif !(m.threads > 0)
50-
warning *= "Number of threads must be at least 1."
63+
warning *= "Number of threads must be at least 1. Defaulting to all threads available."
64+
m.threads = Threads.nthreads()
65+
5166
elseif !(m.verbosity (0, 1))
52-
warning *= "Verbosity must be either 0 (no info) or 1 (info requested)"
67+
warning *= "Verbosity must be either 0 (no info) or 1 (info requested). Defaulting to 0."
68+
m.verbosity = 0
5369
end
5470
return warning
5571
end
@@ -61,12 +77,11 @@ end
6177
####
6278
"""
6379
TODO 3.1: Docs
80+
# fit the specified struct as a ParaKMeans model
6481
6582
See also the [package documentation](https://pydatablog.github.io/ParallelKMeans.jl/stable).
6683
"""
6784
function MLJModelInterface.fit(m::KMeans, X)
68-
# fit the specified struct as a ParaKMeans model
69-
7085
# convert tabular input data into the matrix model expects. Column assumed as features so input data is permuted
7186
if !m.copy
7287
# transpose input table without copying and pass to model
@@ -152,4 +167,4 @@ metadata_model(KMeans,
152167
output = MLJModelInterface.Table(MLJModelInterface.Count),
153168
weights = false,
154169
descr = ParallelKMeans_Desc,
155-
path = "ParallelKMeans.src.mlj_interface.KMeans")
170+
path = "ParallelKMeans.KMeans")

test/test07_mlj_interface.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,17 @@ using MLJBase
2323
end
2424

2525

26+
@testset "Test bad struct warings" begin
27+
@test_logs (:warn, "Unsuppored algorithm supplied. Defauting to KMeans++ seeding algorithm.") ParallelKMeans.KMeans(algo=:Fake)
28+
@test_logs (:warn, "Only `k-means++` or random seeding algorithms are supported. Defaulting to random seeding.") ParallelKMeans.KMeans(k_init="abc")
29+
@test_logs (:warn, "Number of clusters must be greater than 0. Defaulting to 3 clusters.") ParallelKMeans.KMeans(k=0)
30+
@test_logs (:warn, "Tolerance level must be less than 1. Defaulting to tol of 1e-6.") ParallelKMeans.KMeans(tol=2)
31+
@test_logs (:warn, "Number of permitted iterations must be greater than 0. Defaulting to 300 iterations.") ParallelKMeans.KMeans(max_iters=0)
32+
@test_logs (:warn, "Number of threads must be at least 1. Defaulting to all threads available.") ParallelKMeans.KMeans(threads=0)
33+
@test_logs (:warn, "Verbosity must be either 0 (no info) or 1 (info requested). Defaulting to 0.") ParallelKMeans.KMeans(verbosity=100)
34+
end
35+
36+
2637
@testset "Test model fitting verbosity" begin
2738
Random.seed!(2020)
2839
X = table([1 2; 1 4; 1 0; 10 2; 10 4; 10 0])

0 commit comments

Comments
 (0)