Skip to content

Commit 98c2689

Browse files
committed
draft of interface manual struct conversion
1 parent 2a564dc commit 98c2689

File tree

2 files changed

+47
-17
lines changed

2 files changed

+47
-17
lines changed

src/mlj_interface.jl

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,59 @@
11
# TODO 1: a using MLJModelInterface or import MLJModelInterface statement
2+
# Expose all instances of user specified structs and package artifcats.
3+
const ParallelKMeans_Desc = "Parallel & lightning fast implementation of all variants of the KMeans clustering algorithm in native Julia."
4+
5+
# availalbe variants for reference
6+
const MLJDICT = Dict(:Lloyd => Lloyd(),
7+
:Hamerly => Hamerly(),
8+
:LightElkan => LightElkan())
29

310
####
411
#### MODEL DEFINITION
512
####
613
# TODO 2: MLJ-compatible model types and constructors
7-
@mlj_model mutable struct KMeans <: MLJModelInterface.Unsupervised
8-
# Hyperparameters of the model
9-
algo::Symbol = :Lloyd::(_ in (:Lloyd, :Hamerly, :LightElkan))
10-
k_init::String = "k-means++"::(_ in ("k-means++", String)) # allow user seeding?
11-
k::Int = 3::(_ > 0)
12-
tol::Float64 = 1e-6::(_ < 1)
13-
max_iters::Int = 300::(_ > 0)
14-
copy::Bool = true
15-
threads::Int = Threads.nthreads()::(_ > 0)
16-
verbosity::Int = 0::(_ in (0, 1)) # Temp fix. Do we need to follow mlj verbosity style?
17-
init = nothing
14+
15+
mutable struct KMeans <: MLJModelInterface.Unsupervised
16+
algo::Symbol
17+
k_init::String
18+
k::Int
19+
tol::Float64
20+
max_iters::Int
21+
copy::Bool
22+
threads::Int
23+
verbosity::Int
24+
init
1825
end
1926

2027

21-
# Expose all instances of user specified structs and package artifcats.
22-
const ParallelKMeans_Desc = "Parallel & lightning fast implementation of all variants of the KMeans clustering algorithm in native Julia."
28+
function KMeans(; algo=:Lloyd, k_init="k-means++",
29+
k=3, tol=1e-6, max_iters=300, copy=true,
30+
threads=Threads.nthreads(), verbosity=0, init=nothing)
31+
32+
model = KMeans(algo, k_init, k, tol, max_iters, copy, threads, verbosity, init)
33+
message = MLJModelInterface.clean!(model)
34+
isempty(message) || @warn message
35+
return model
36+
end
37+
38+
39+
function MLJModelInterface.clean!(m::KMeans)
40+
warning = ""
41+
if !(m.algo keys(MLJDICT))
42+
warning *= "Unsuppored algorithm supplied. Please check documentation for supported Kmeans algorithms."
43+
elseif m.k < 1
44+
warning *= "Number of clusters must be greater than 0."
45+
elseif !(m.tol < 1.0)
46+
warning *= "Tolerance level must be less than 1."
47+
elseif !(m.max_iters > 0)
48+
warning *= "Number of permitted iterations must be greater than 0."
49+
elseif !(m.threads > 0)
50+
warning *= "Number of threads must be at least 1."
51+
elseif !(m.verbosity (0, 1))
52+
warning *= "Verbosity must be either 0 (no info) or 1 (info requested)"
53+
end
54+
return warning
55+
end
2356

24-
# availalbe variants for reference
25-
const MLJDICT = Dict(:Lloyd => Lloyd(),
26-
:Hamerly => Hamerly(),
27-
:LightElkan => LightElkan())
2857

2958
# TODO 3: implementation of fit, predict, and fitted_params of the model
3059
####

test/test07_mlj_interface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module TestMLJInterface
22

33
using ParallelKMeans
4+
using ParallelKMeans: KMeans
45
using Random
56
using Test
67
using Suppressor

0 commit comments

Comments
 (0)