Skip to content

Commit 2767ec7

Browse files
committed
MLJ Interface WIP 1
1 parent 2d54b7c commit 2767ec7

File tree

1 file changed

+61
-2
lines changed

1 file changed

+61
-2
lines changed

src/mlj_interface.jl

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,71 @@
22
using MLJModelInterface
33
using ParallelKMeans
44

5+
6+
####
7+
#### MODEL DEFINITION
8+
####
59
# TODO 2: MLJ-compatible model types and constructors,
10+
@mlj_model mutable struct ParaKMeans <: MLJModelInterface.Unsupervised
11+
# Hyperparameters of the model
12+
algo::AbstractKMeansAlg = Lloyd::(_ in (Lloyd, Hamerly, Elkan))
13+
k_init::String = "k-means++"::(_ in ("k-means++", String))
14+
k::Int = 3::(_ > 0)
15+
tol::Float = 1e-6::(_ < 1)
16+
max_iters::Int = 300::(_ > 0)
17+
end
18+
19+
20+
# TODO 3: implementation of fit, predict, and fitted_params of the model
21+
####
22+
#### FIT FUNCTION
23+
####
24+
25+
function MLJModelInterface.fit(m::ParaKMeans, verbosity::Int, X, y, w=nothing)
26+
# body ...
27+
return (fitresult, cache, report)
28+
end
29+
30+
31+
function MLJModelInterface.fitted_params(model::ParaKMeans, fitresult)
32+
# extract what's relevant from `fitresult`
33+
# ...
34+
# then return as a NamedTuple
35+
return (learned_param1 = ..., learned_param2 = ...)
36+
end
37+
38+
39+
####
40+
#### PREDICT FUNCTION
41+
####
42+
function MLJModelInterface.predict(m::ParaKMeans, fitresult, Xnew)
43+
# ...
44+
end
645

7-
# TODO 3: implementation of fit, predict/transform and optionally fitted_params for your models,
846

9-
# TODO 4: metadata for your package and for each of your models
47+
####
48+
#### METADATA
49+
####
1050

51+
# TODO 4: metadata for the package and for each of your models
52+
const PARAKMEANS_MODELS = Union{ParaKMeans}
1153

54+
metadata_pkg.(PARAKMEANS_MODELS,
55+
name = "ParallelKMeans",
56+
uuid = "42b8e9d4-006b-409a-8472-7f34b3fb58af", # see your Project.toml
57+
url = "https://github.com/PyDataBlog/ParallelKMeans.jl", # URL to your package repo
58+
julia = true, # is it written entirely in Julia?
59+
license = "MIT", # your package license
60+
is_wrapper = false, # does it wrap around some other package?
61+
)
1262

1363

64+
# Metadata for ParaKMeans model
65+
metadata_model(ParaKMeans,
66+
input = MLJModelInterface.Table(MLJModelInterface.Continuous), # what input data is supported? # for a supervised model, what target?
67+
output = MLJModelInterface.Table(MLJModelInterface.Count), # for an unsupervised, what output?
68+
weights = false, # does the model support sample weights?
69+
descr = "Parallel & lightning fast implementation of all variants of the KMeans clustering algorithm in native Julia.",
70+
path = "ParallelKMeans.src.mlj_interface.ParaKMeans"
71+
#path = "YourPackage.SubModuleContainingModelStructDefinition.YourModel1"
72+
)

0 commit comments

Comments
 (0)