11# TODO 1: a using MLJModelInterface or import MLJModelInterface statement
22using MLJModelInterface
33using ParallelKMeans
4- using Distances
4+ import Distances
55
66
77# ###
88# ### MODEL DEFINITION
99# ###
1010# TODO 2: MLJ-compatible model types and constructors
11- @mlj_model mutable struct ParaKMeans <: MLJModelInterface.Unsupervised
11+ @mlj_model mutable struct KMeans <: MLJModelInterface.Unsupervised
1212 # Hyperparameters of the model. Each field is written as `name::Type = default::(constraint)`, where `_` in the constraint stands for the supplied value.
13- algo:: ParallelKMeans.AbstractKMeansAlg = Lloyd () :: (_ in (Lloyd() , Hamerly (), LightElkan () ))
13+ algo:: Symbol = : Lloyd:: (_ in (: Lloyd, : Hamerly, : LightElkan) )
1414 k_init:: String = " k-means++" :: (_ in ("k-means++", String) ) # TODO: allow user seeding? NOTE(review): the tuple holds the type `String` itself, so membership does not accept arbitrary strings - confirm intent
1515 k:: Int = 3 :: (_ > 0)
1616 tol:: Float64 = 1e-6 :: (_ < 1) # NOTE(review): only an upper bound is checked; zero or negative tolerances pass - confirm
1717 max_iters:: Int = 300 :: (_ > 0)
18- # transpose_type::String = "permute" ::(_ in ("permute", "transpose" ))
18+ copy :: Bool = true :: (_ in (true, false ) ) # when false, fit uses `convert(Array{Float64, 2}, X)'` instead of `MLJModelInterface.matrix(X; transpose=true)`
1919 threads:: Int = Threads. nthreads ():: (_ > 0)
2020 verbosity:: Int = 0 :: (_ in (0, 1) ) # Temp fix: fit derives its `verbose` flag from this field. Do we need to follow MLJ's verbosity style instead?
2121 init = nothing # forwarded to ParallelKMeans.kmeans as the `init` keyword
2222end
2323
2424
2525# Expose all instances of user-specified structs and package artifacts.
26- const KMeansModel = Union{ParaKMeans }
26+ const KMeansModel = Union{KMeans }
2727const ParallelKMeans_Desc = " Parallel & lightning fast implementation of all variants of the KMeans clustering algorithm in native Julia."
2828
2929
@@ -36,25 +36,30 @@ const ParallelKMeans_Desc = "Parallel & lightning fast implementation of all var
3636
3737 See also the [package documentation](https://pydatablog.github.io/ParallelKMeans.jl/stable).
3838"""
39- function MLJModelInterface. fit (m:: ParaKMeans , verbosity:: Int , X)
39+ function MLJModelInterface. fit (m:: KMeans , verbosity:: Int , X)
4040 # fit the specified struct as a KMeans (formerly ParaKMeans) model
4141
4242 # Convert the tabular input into the matrix layout the model expects. Table columns are assumed to be features, so the data is transposed.
43- DMatrix = MLJModelInterface. matrix (X; transpose= true )
44-
45- # fit model and get results
46- if m. verbosity > 0
47- fitresult = ParallelKMeans. kmeans (m. algo, DMatrix, m. k;
48- n_threads = m. threads, k_init= m. k_init,
49- max_iters= m. max_iters, tol= m. tol, init= m. init,
50- verbose= true )
43+ if ! m. copy
44+ # transpose input table without copying and pass to model
45+ DMatrix = convert (Array{Float64, 2 }, X)'
5146 else
52- fitresult = ParallelKMeans. kmeans (m. algo, DMatrix, m. k;
53- n_threads = m. threads, k_init= m. k_init,
54- max_iters= m. max_iters, tol= m. tol, init= m. init,
55- verbose= false )
47+ # transposes the input table into a column-major matrix after making a copy of the data
48+ DMatrix = MLJModelInterface. matrix (X; transpose= true )
5649 end
50+
51+ # lookup available algorithms
52+ algos = Dict (:Lloyd => Lloyd (),
53+ :Hamerly => Hamerly (),
54+ :LightElkan => LightElkan ())
55+ algo = algos[m. algo] # select algo
5756
57+ # fit model and get results
58+ verbose = m. verbosity != 0
59+ fitresult = ParallelKMeans. kmeans (algo, DMatrix, m. k;
60+ n_threads = m. threads, k_init= m. k_init,
61+ max_iters= m. max_iters, tol= m. tol, init= m. init,
62+ verbose= verbose)
5863 cache = nothing
5964 report = (cluster_centers= fitresult. centers, iterations= fitresult. iterations,
6065 converged= fitresult. converged, totalcost= fitresult. totalcost,
8792"""
8893 TODO 3.3: Docs
8994"""
90- function MLJModelInterface. transform (m:: ParaKMeans , fitresult, Xnew)
95+ function MLJModelInterface. transform (m:: KMeans , fitresult, Xnew)
9196 # Assign each observation in Xnew to a cluster using the centroids stored in fitresult[1]
9297 results = fitresult[1 ]
9398 DMatrix = MLJModelInterface. matrix (Xnew, transpose= true ) # transposed matrix form of Xnew
9499
95100 # TODO 3.3.1: Warn users if fitresult is from a `non-converged` fit.
96101 # Compute squared Euclidean distances between the new data and the learned centroid matrix;
96101 # `dims=2` compares columns of DMatrix with columns of centroids.
97102 centroids = results. centers
98- distances = pairwise (SqEuclidean (), DMatrix, centroids; dims= 2 )
103+ distances = Distances . pairwise (Distances . SqEuclidean (), DMatrix, centroids; dims= 2 )
99104 preds = argmin .(eachrow (distances)) # for each row of `distances`, the index of its smallest entry
100105 return MLJModelInterface. table (reshape (preds, :, 1 ), prototype= Xnew) # single-column table built with Xnew as prototype
101106end
@@ -117,11 +122,11 @@ metadata_pkg.(KMeansModel,
117122
118123
119124# Metadata for the KMeans (formerly ParaKMeans) model interface
120- metadata_model (ParaKMeans ,
121- input = MLJModelInterface. Table (MLJModelInterface. Continuous), # what input data is supported? # for a supervised model, what target?
125+ metadata_model (KMeans ,
126+ input = MLJModelInterface. Table (MLJModelInterface. Continuous), # supported input: tables of Continuous columns
122127 output = MLJModelInterface. Table (MLJModelInterface. Count), # for an unsupervised model, the output: tables of Count columns
123- weights = false , # does the model support sample weights?
128+ weights = false , # sample weights are not supported
124129 descr = ParallelKMeans_Desc,
125- path = " ParallelKMeans.src.mlj_interface.ParaKMeans "
130+ path = " ParallelKMeans.src.mlj_interface.KMeans "
126131 # path format: "YourPackage.SubModuleContainingModelStructDefinition.YourModel1"
127132 )
0 commit comments