@@ -6,39 +6,79 @@ using ParallelKMeans
66# ###
77# ### MODEL DEFINITION
88# ###
9- # TODO 2: MLJ-compatible model types and constructors,
9+ # TODO 2: MLJ-compatible model types and constructors
# MLJ-compatible K-means model backed by `ParallelKMeans.kmeans`.
# Each field is written as `name::Type = default::(constraint)`; `@mlj_model`
# derives the keyword constructor and the hyperparameter checks from that form.
@mlj_model mutable struct ParaKMeans <: MLJModelInterface.Unsupervised
    # K-means variant that performs the clustering.
    algo::ParallelKMeans.AbstractKMeansAlg = Lloyd()::(_ in (Lloyd(), Hamerly(), LightElkan()))
    # Seeding strategy forwarded to `kmeans` as `k_init`. The constraint lists the
    # accepted values explicitly (the previous tuple mixed a value with the type
    # `String`, which left "k-means++" as the only value that could ever pass).
    k_init::String = "k-means++"::(_ in ("k-means++", "random"))
    # Number of clusters.
    k::Int = 3::(_ > 0)
    # Convergence tolerance forwarded to `kmeans` as `tol`.
    tol::Float64 = 1e-6::(_ < 1)
    # Upper bound on the number of iterations.
    max_iters::Int = 300::(_ > 0)
    # transpose_type::String = "permute"::(_ in ("permute", "transpose"))
    # Number of threads forwarded to `kmeans` as `n_threads`.
    threads::Int = Threads.nthreads()::(_ > 0)
    # 0 = silent, 1 = ask `kmeans` for verbose output.
    verbosity::Int = 0::(_ in (0, 1))
    # Optional user-supplied seeding, forwarded to `kmeans` as `init`.
    init = nothing
end
1822
1923
# Expose all instances of user-specified structs and package artifacts.
# `KMeansModel` groups every model struct handled by this interface (currently
# only `ParaKMeans`); `ParallelKMeans_Desc` is the description reused in the
# model metadata.
const KMeansModel = Union{ParaKMeans}
const ParallelKMeans_Desc = string(
    "Parallel & lightning fast implementation of all variants of the ",
    "KMeans clustering algorithm in native Julia.",
)
27+
28+
2029# TODO 3: implementation of fit, predict, and fitted_params of the model
2130# ###
2231# ### FIT FUNCTION
2332# ###
"""
    MLJModelInterface.fit(m::ParaKMeans, verbosity::Int, X)

Cluster the table `X` using the hyperparameters stored in `m`.

`X` is converted with `MLJModelInterface.matrix(X; transpose=true)` and passed to
`ParallelKMeans.kmeans` together with `m.algo`, `m.k`, `m.threads`, `m.k_init`,
`m.max_iters`, `m.tol` and `m.init`.

Verbose output from `kmeans` is requested when either the `verbosity` argument or
the `m.verbosity` hyperparameter is positive.

# Returns
The tuple `(fitresult, cache, report)`, where `fitresult` is the value returned by
`kmeans`, `cache` is `nothing` and `report` is an empty `NamedTuple`.
"""
function MLJModelInterface.fit(m::ParaKMeans, verbosity::Int, X)
    # Assumes the user supplied a table with columns as features, so the matrix
    # is transposed before it is handed to `kmeans`.
    data = MLJModelInterface.matrix(X; transpose=true)

    # Honour the verbosity level passed in by the caller as well as the model's
    # own switch, instead of duplicating the whole call in two branches.
    verbose = verbosity > 0 || m.verbosity > 0

    fitresult = ParallelKMeans.kmeans(
        m.algo, data, m.k;
        n_threads=m.threads,
        k_init=m.k_init,
        max_iters=m.max_iters,
        tol=m.tol,
        init=m.init,
        verbose=verbose,
    )

    cache = nothing
    report = NamedTuple()

    return (fitresult, cache, report)
end
2960
3061
"""
    MLJModelInterface.fitted_params(model::KMeansModel, fitresult)

Extract the learned quantities from `fitresult`.

# Returns
A `NamedTuple` with the keys `centres`, `totalcost`, `iterations` and `converged`,
each read from the corresponding field of `fitresult`.
"""
function MLJModelInterface.fitted_params(model::KMeansModel, fitresult)
    return (
        # The result field is spelled `centers`; the returned key keeps the
        # `centres` spelling so existing callers are unaffected.
        centres = fitresult.centers,
        totalcost = fitresult.totalcost,
        iterations = fitresult.iterations,
        converged = fitresult.converged,
    )
end
3774
3875
3976# ###
4077# ### PREDICT FUNCTION
4178# ###
"""
    MLJModelInterface.predict(m::ParaKMeans, fitresult, Xnew)

Placeholder for prediction on new data `Xnew`.

Not implemented yet: the method body is empty and the call returns `nothing`.
"""
function MLJModelInterface.predict(m::ParaKMeans, fitresult, Xnew)
    # TODO: implement prediction for `Xnew` from `fitresult`.
    return nothing
end
4888# ### METADATA
4989# ###
5090
51- # TODO 4: metadata for the package and for each of your models
52- const PARAKMEANS_MODELS = Union{ParaKMeans}
53-
54- metadata_pkg .(PARAKMEANS_MODELS,
91+ # TODO 4: metadata for the package and for each of the model interfaces
92+ metadata_pkg .(KMeansModel,
5593 name = " ParallelKMeans" ,
5694 uuid = " 42b8e9d4-006b-409a-8472-7f34b3fb58af" , # see your Project.toml
5795 url = " https://github.com/PyDataBlog/ParallelKMeans.jl" , # URL to your package repo
@@ -61,12 +99,12 @@ metadata_pkg.(PARAKMEANS_MODELS,
6199)
62100
63101
64- # Metadata for ParaKMeans model
102+ # Metadata for ParaKMeans model interface
metadata_model(ParaKMeans,
    # Scientific type of the input data the model accepts.
    input = MLJModelInterface.Table(MLJModelInterface.Continuous),
    # Scientific type of the output (unsupervised model, so there is no target).
    output = MLJModelInterface.Table(MLJModelInterface.Count),
    # The model does not support per-sample weights.
    weights = false,
    descr = ParallelKMeans_Desc,
    # Load path is a dotted *module* path (Package.[SubModule.]Model), not the
    # location of the source file inside the repository.
    # NOTE(review): assumes `ParaKMeans` is defined directly in the
    # `ParallelKMeans` module; adjust if it lives in a submodule.
    path = "ParallelKMeans.ParaKMeans"
    )
0 commit comments