11# TODO 1: a using MLJModelInterface or import MLJModelInterface statement
22using MLJModelInterface
33using ParallelKMeans
4+ using Distances
45
56
67# ###
@@ -16,7 +17,7 @@ using ParallelKMeans
1617 max_iters:: Int = 300 :: (_ > 0)
1718 # transpose_type::String = "permute"::(_ in ("permute", "transpose"))
1819 threads:: Int = Threads. nthreads ():: (_ > 0)
19- verbosity:: Int = 0 :: (_ in (0, 1) )
20+ verbosity:: Int = 0 :: (_ in (0, 1) ) # Temp fix. Do we need to follow mlj verbosity style?
2021 init = nothing
2122end
2223
@@ -32,11 +33,13 @@ const ParallelKMeans_Desc = "Parallel & lightning fast implementation of all var
3233# ###
3334"""
3435 TODO 3.1: Docs
36+
37+ See also the [package documentation](https://pydatablog.github.io/ParallelKMeans.jl/stable).
3538"""
3639function MLJModelInterface. fit (m:: ParaKMeans , verbosity:: Int , X)
3740 # fit the specified struct as a ParaKMeans model
3841
39- # assumes user supplied table with columns as features
42+ # convert tabular input data into the matrix model expects. Column assumed as features so input data is permuted
4043 DMatrix = MLJModelInterface. matrix (X; transpose= true )
4144
4245 # fit model and get results
@@ -53,7 +56,9 @@ function MLJModelInterface.fit(m::ParaKMeans, verbosity::Int, X)
5356 end
5457
5558 cache = nothing
56- report = NamedTuple {} ()
59+ report = (cluster_centers= fitresult. centers, iterations= fitresult. iterations,
60+ converged= fitresult. converged, totalcost= fitresult. totalcost,
61+ labels= fitresult. assignments)
5762
5863 return (fitresult, cache, report)
5964end
6469"""
6570function MLJModelInterface. fitted_params (model:: KMeansModel , fitresult)
6671 # extract what's relevant from `fitresult`
67- centres = fitresult. centres
68- converged = fitresult. converged
69- iters = fitresult. iterations
70- totalcost = fitresult. totalcost
72+ results, _, _ = fitresult # unpack fitresult
73+ centers = results. centers
74+ converged = results. converged
75+ iters = results. iterations
76+ totalcost = results. totalcost
77+
7178 # then return as a NamedTuple
72- return (centres = centres, totalcost = totalcost, iterations = iters, converged = converged)
79+ return (cluster_centers = centers, totalcost = totalcost,
80+ iterations = iters, converged = converged)
7381end
7482
7583
7987"""
8088 TODO 3.3: Docs
8189"""
82- function MLJModelInterface. predict (m:: ParaKMeans , fitresult, Xnew)
83- # ...
90+ function MLJModelInterface. transform (m:: ParaKMeans , fitresult, Xnew)
91+ # make predictions/assignments using the learned centroids
92+ results = fitresult[1 ]
93+ DMatrix = MLJModelInterface. matrix (Xnew, transpose= true )
94+
95+ # TODO 3.3.1: Warn users if fitresult is from a `non-converged` fit.
96+ # use centroid matrix to assign clusters for new data
97+ centroids = results. centers
98+ distances = pairwise (SqEuclidean (), DMatrix, centroids; dims= 2 )
99+ preds = argmin .(eachrow (distances))
100+ return MLJModelInterface. table (reshape (preds, :, 1 ), prototype= Xnew)
84101end
85102
86103
0 commit comments