Skip to content

Commit 73fa639

Browse files
committed
MLJ port WIP v2
1 parent 2767ec7 commit 73fa639

File tree

2 files changed

+56
-18
lines changed

2 files changed

+56
-18
lines changed

src/kmeans.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ end
173173

174174

175175
"""
176-
Kmeans!(alg::AbstractKMeansAlg, containers, design_matrix, k; n_threads = nthreads(), k_init="k-means++", max_iters=300, tol=1e-6, verbose=true)
176+
Kmeans!(alg::AbstractKMeansAlg, containers, design_matrix, k; n_threads = nthreads(), k_init="k-means++", max_iters=300, tol=1e-6, verbose=false)
177177
178178
Mutable version of `kmeans` function. Definition of arguments and results can be
179179
found in `kmeans`.

src/mlj_interface.jl

Lines changed: 55 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,39 +6,79 @@ using ParallelKMeans
66
####
77
#### MODEL DEFINITION
88
####
9-
# TODO 2: MLJ-compatible model types and constructors,
9+
# TODO 2: MLJ-compatible model types and constructors
1010
@mlj_model mutable struct ParaKMeans <: MLJModelInterface.Unsupervised
1111
# Hyperparameters of the model
12-
algo::AbstractKMeansAlg = Lloyd::(_ in (Lloyd, Hamerly, Elkan))
13-
k_init::String = "k-means++"::(_ in ("k-means++", String))
14-
k::Int = 3::(_ > 0)
15-
tol::Float = 1e-6::(_ < 1)
16-
max_iters::Int = 300::(_ > 0)
12+
algo::ParallelKMeans.AbstractKMeansAlg = Lloyd()::(_ in (Lloyd(), Hamerly(), LightElkan()))
13+
k_init::String = "k-means++"::(_ in ("k-means++", String)) # allow user seeding?
14+
k::Int = 3::(_ > 0)
15+
tol::Float64 = 1e-6::(_ < 1)
16+
max_iters::Int = 300::(_ > 0)
17+
#transpose_type::String = "permute"::(_ in ("permute", "transpose"))
18+
threads::Int = Threads.nthreads()::(_ > 0)
19+
verbosity::Int = 0::(_ in (0, 1))
20+
init = nothing
1721
end
1822

1923

24+
# Expose all instances of user specified structs and package artifcats.
25+
const KMeansModel = Union{ParaKMeans}
26+
const ParallelKMeans_Desc = "Parallel & lightning fast implementation of all variants of the KMeans clustering algorithm in native Julia."
27+
28+
2029
# TODO 3: implementation of fit, predict, and fitted_params of the model
2130
####
2231
#### FIT FUNCTION
2332
####
33+
"""
34+
TODO 3.1: Docs
35+
"""
36+
function MLJModelInterface.fit(m::ParaKMeans, verbosity::Int, X)
37+
# fit the specified struct as a ParaKMeans model
38+
39+
# assumes user supplied table with columns as features
40+
DMatrix = MLJModelInterface.matrix(X; transpose=true)
41+
42+
# fit model and get results
43+
if m.verbosity > 0
44+
fitresult = ParallelKMeans.kmeans(m.algo, DMatrix, m.k;
45+
n_threads = m.threads, k_init=m.k_init,
46+
max_iters=m.max_iters, tol=m.tol, init=m.init,
47+
verbose=true)
48+
else
49+
fitresult = ParallelKMeans.kmeans(m.algo, DMatrix, m.k;
50+
n_threads = m.threads, k_init=m.k_init,
51+
max_iters=m.max_iters, tol=m.tol, init=m.init,
52+
verbose=false)
53+
end
54+
55+
cache = nothing
56+
report = NamedTuple{}()
2457

25-
function MLJModelInterface.fit(m::ParaKMeans, verbosity::Int, X, y, w=nothing)
26-
# body ...
2758
return (fitresult, cache, report)
2859
end
2960

3061

31-
function MLJModelInterface.fitted_params(model::ParaKMeans, fitresult)
62+
"""
63+
TODO 3.2: Docs
64+
"""
65+
function MLJModelInterface.fitted_params(model::KMeansModel, fitresult)
3266
# extract what's relevant from `fitresult`
33-
# ...
67+
centres = fitresult.centres
68+
converged = fitresult.converged
69+
iters = fitresult.iterations
70+
totalcost = fitresult.totalcost
3471
# then return as a NamedTuple
35-
return (learned_param1 = ..., learned_param2 = ...)
72+
return (centres = centres, totalcost = totalcost, iterations = iters, converged = converged)
3673
end
3774

3875

3976
####
4077
#### PREDICT FUNCTION
4178
####
79+
"""
80+
TODO 3.3: Docs
81+
"""
4282
function MLJModelInterface.predict(m::ParaKMeans, fitresult, Xnew)
4383
# ...
4484
end
@@ -48,10 +88,8 @@ end
4888
#### METADATA
4989
####
5090

51-
# TODO 4: metadata for the package and for each of your models
52-
const PARAKMEANS_MODELS = Union{ParaKMeans}
53-
54-
metadata_pkg.(PARAKMEANS_MODELS,
91+
# TODO 4: metadata for the package and for each of the model interfaces
92+
metadata_pkg.(KMeansModel,
5593
name = "ParallelKMeans",
5694
uuid = "42b8e9d4-006b-409a-8472-7f34b3fb58af", # see your Project.toml
5795
url = "https://github.com/PyDataBlog/ParallelKMeans.jl", # URL to your package repo
@@ -61,12 +99,12 @@ metadata_pkg.(PARAKMEANS_MODELS,
6199
)
62100

63101

64-
# Metadata for ParaKMeans model
102+
# Metadata for ParaKMeans model interface
65103
metadata_model(ParaKMeans,
66104
input = MLJModelInterface.Table(MLJModelInterface.Continuous), # what input data is supported? # for a supervised model, what target?
67105
output = MLJModelInterface.Table(MLJModelInterface.Count), # for an unsupervised, what output?
68106
weights = false, # does the model support sample weights?
69-
descr = "Parallel & lightning fast implementation of all variants of the KMeans clustering algorithm in native Julia.",
107+
descr = ParallelKMeans_Desc,
70108
path = "ParallelKMeans.src.mlj_interface.ParaKMeans"
71109
#path = "YourPackage.SubModuleContainingModelStructDefinition.YourModel1"
72110
)

0 commit comments

Comments
 (0)