Skip to content

Commit 1359798

Browse files
committed
initiated unittesting of mlj interface
1 parent eb6852a commit 1359798

File tree

4 files changed

+71
-24
lines changed

4 files changed

+71
-24
lines changed

src/ParallelKMeans.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ include("kmeans.jl")
88
include("lloyd.jl")
99
include("light_elkan.jl")
1010
include("hamerly.jl")
11+
include("mlj_interface.jl")
1112

1213
export kmeans
1314
export Lloyd, LightElkan, Hamerly

src/mlj_interface.jl

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,29 @@
11
# TODO 1: a using MLJModelInterface or import MLJModelInterface statement
22
using MLJModelInterface
33
using ParallelKMeans
4-
using Distances
4+
import Distances
55

66

77
####
88
#### MODEL DEFINITION
99
####
1010
# TODO 2: MLJ-compatible model types and constructors
11-
@mlj_model mutable struct ParaKMeans <: MLJModelInterface.Unsupervised
11+
@mlj_model mutable struct KMeans <: MLJModelInterface.Unsupervised
1212
# Hyperparameters of the model
13-
algo::ParallelKMeans.AbstractKMeansAlg = Lloyd()::(_ in (Lloyd(), Hamerly(), LightElkan()))
13+
algo::Symbol = :Lloyd::(_ in (:Lloyd, :Hamerly, :LightElkan))
1414
k_init::String = "k-means++"::(_ in ("k-means++", String)) # allow user seeding?
1515
k::Int = 3::(_ > 0)
1616
tol::Float64 = 1e-6::(_ < 1)
1717
max_iters::Int = 300::(_ > 0)
18-
#transpose_type::String = "permute"::(_ in ("permute", "transpose"))
18+
copy::Bool = true::(_ in (true, false))
1919
threads::Int = Threads.nthreads()::(_ > 0)
2020
verbosity::Int = 0::(_ in (0, 1)) # Temp fix. Do we need to follow mlj verbosity style?
2121
init = nothing
2222
end
2323

2424

2525
# Expose all instances of user-specified structs and package artifacts.
26-
const KMeansModel = Union{ParaKMeans}
26+
const KMeansModel = Union{KMeans}
2727
const ParallelKMeans_Desc = "Parallel & lightning fast implementation of all variants of the KMeans clustering algorithm in native Julia."
2828

2929

@@ -36,25 +36,30 @@ const ParallelKMeans_Desc = "Parallel & lightning fast implementation of all var
3636
3737
See also the [package documentation](https://pydatablog.github.io/ParallelKMeans.jl/stable).
3838
"""
39-
function MLJModelInterface.fit(m::ParaKMeans, verbosity::Int, X)
39+
function MLJModelInterface.fit(m::KMeans, verbosity::Int, X)
4040
# fit the specified struct as a ParaKMeans model
4141

4242
# convert tabular input data into the matrix the model expects. Columns are assumed to be features, so the input data is permuted
43-
DMatrix = MLJModelInterface.matrix(X; transpose=true)
44-
45-
# fit model and get results
46-
if m.verbosity > 0
47-
fitresult = ParallelKMeans.kmeans(m.algo, DMatrix, m.k;
48-
n_threads = m.threads, k_init=m.k_init,
49-
max_iters=m.max_iters, tol=m.tol, init=m.init,
50-
verbose=true)
43+
if !m.copy
44+
# transpose input table without copying and pass to model
45+
DMatrix = convert(Array{Float64, 2}, X)'
5146
else
52-
fitresult = ParallelKMeans.kmeans(m.algo, DMatrix, m.k;
53-
n_threads = m.threads, k_init=m.k_init,
54-
max_iters=m.max_iters, tol=m.tol, init=m.init,
55-
verbose=false)
47+
# transposes the input table into a column-major matrix after making a copy of the data
48+
DMatrix = MLJModelInterface.matrix(X; transpose=true)
5649
end
50+
51+
# lookup available algorithms
52+
algos = Dict(:Lloyd => Lloyd(),
53+
:Hamerly => Hamerly(),
54+
:LightElkan => LightElkan())
55+
algo = algos[m.algo] # select algo
5756

57+
# fit model and get results
58+
verbose = m.verbosity != 0
59+
fitresult = ParallelKMeans.kmeans(algo, DMatrix, m.k;
60+
n_threads = m.threads, k_init=m.k_init,
61+
max_iters=m.max_iters, tol=m.tol, init=m.init,
62+
verbose=verbose)
5863
cache = nothing
5964
report = (cluster_centers=fitresult.centers, iterations=fitresult.iterations,
6065
converged=fitresult.converged, totalcost=fitresult.totalcost,
@@ -87,15 +92,15 @@ end
8792
"""
8893
TODO 3.3: Docs
8994
"""
90-
function MLJModelInterface.transform(m::ParaKMeans, fitresult, Xnew)
95+
function MLJModelInterface.transform(m::KMeans, fitresult, Xnew)
9196
# make predictions/assignments using the learned centroids
9297
results = fitresult[1]
9398
DMatrix = MLJModelInterface.matrix(Xnew, transpose=true)
9499

95100
# TODO 3.3.1: Warn users if fitresult is from a `non-converged` fit.
96101
# use centroid matrix to assign clusters for new data
97102
centroids = results.centers
98-
distances = pairwise(SqEuclidean(), DMatrix, centroids; dims=2)
103+
distances = Distances.pairwise(Distances.SqEuclidean(), DMatrix, centroids; dims=2)
99104
preds = argmin.(eachrow(distances))
100105
return MLJModelInterface.table(reshape(preds, :, 1), prototype=Xnew)
101106
end
@@ -117,11 +122,11 @@ metadata_pkg.(KMeansModel,
117122

118123

119124
# Metadata for ParaKMeans model interface
120-
metadata_model(ParaKMeans,
121-
input = MLJModelInterface.Table(MLJModelInterface.Continuous), # what input data is supported? # for a supervised model, what target?
125+
metadata_model(KMeans,
126+
input = MLJModelInterface.Table(MLJModelInterface.Continuous), # what input data is supported?
122127
output = MLJModelInterface.Table(MLJModelInterface.Count), # for an unsupervised, what output?
123-
weights = false, # does the model support sample weights?
128+
weights = false,
124129
descr = ParallelKMeans_Desc,
125-
path = "ParallelKMeans.src.mlj_interface.ParaKMeans"
130+
path = "ParallelKMeans.src.mlj_interface.KMeans"
126131
#path = "YourPackage.SubModuleContainingModelStructDefinition.YourModel1"
127132
)

test/test07_mlj_interface.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
module TestMLJInterface
2+
3+
using MLJModelInterface
4+
using ParallelKMeans
5+
using Random
6+
using Test
7+
using Suppressor
8+
using MLJBase
9+
10+
11+
@testset "Test struct construction" begin
12+
model = ParallelKMeans.KMeans()
13+
14+
@test model.algo == Lloyd()
15+
@test model.init == nothing
16+
@test model.k == 3
17+
@test model.k_init == "k-means++"
18+
@test model.max_iters == 300
19+
@test model.copy == true
20+
@test model.threads == Threads.nthreads()
21+
@test model.tol == 1.0e-6
22+
@test model.verbosity == 0
23+
end
24+
25+
26+
@testset "Test model fitting" begin
27+
28+
end
29+
30+
31+
@testset "Test fitted params" begin
32+
33+
end
34+
35+
36+
@testset "Test transform" begin
37+
38+
end
39+
40+
41+
end # end module

0 commit comments

Comments
 (0)