Skip to content

Commit f6fd893

Browse files
committed
added more tests for coverage
1 parent 21c0a96 commit f6fd893

File tree

3 files changed

+61
-15
lines changed

3 files changed

+61
-15
lines changed

src/mlj_interface.jl

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ const ParallelKMeans_Desc = "Parallel & lightning fast implementation of all var
2727
# availalbe variants for reference
2828
const MLJDICT = Dict(:Lloyd => Lloyd(),
2929
:Hamerly => Hamerly(),
30-
:LightElkan => LightElkan())
30+
:LightElkan => LightElkan())
3131

3232
# TODO 3: implementation of fit, predict, and fitted_params of the model
3333
####
@@ -49,7 +49,7 @@ function MLJModelInterface.fit(m::KMeans, X)
4949
# tranposes input table as a column major matrix after making a copy of the data
5050
DMatrix = MLJModelInterface.matrix(X; transpose=true)
5151
end
52-
52+
5353
# lookup available algorithms
5454
algo = MLJDICT[m.algo] # select algo
5555

@@ -60,7 +60,7 @@ function MLJModelInterface.fit(m::KMeans, X)
6060
max_iters=m.max_iters, tol=m.tol, init=m.init,
6161
verbose=verbose)
6262
cache = nothing
63-
report = (cluster_centers=fitresult.centers, iterations=fitresult.iterations,
63+
report = (cluster_centers=fitresult.centers, iterations=fitresult.iterations,
6464
converged=fitresult.converged, totalcost=fitresult.totalcost,
6565
labels=fitresult.assignments)
6666

@@ -122,10 +122,8 @@ metadata_pkg.(KMeans,
122122

123123
# Metadata for ParaKMeans model interface
124124
metadata_model(KMeans,
125-
input = MLJModelInterface.Table(MLJModelInterface.Continuous), # what input data is supported?
126-
output = MLJModelInterface.Table(MLJModelInterface.Count), # for an unsupervised, what output?
127-
weights = false,
125+
input = MLJModelInterface.Table(MLJModelInterface.Continuous),
126+
output = MLJModelInterface.Table(MLJModelInterface.Count),
127+
weights = false,
128128
descr = ParallelKMeans_Desc,
129-
path = "ParallelKMeans.src.mlj_interface.KMeans"
130-
#path = "YourPackage.SubModuleContainingModelStructDefinition.YourModel1"
131-
)
129+
path = "ParallelKMeans.src.mlj_interface.KMeans")

test/test06_verbose.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ using Suppressor
1313
# Capture output and compare
1414
r = @capture_out kmeans(Lloyd(), X, 3; n_threads=1, max_iters=1, verbose=true)
1515
@test r == "Iteration 1: Jclust = 46.534795844478815\n"
16-
1716
end
1817

1918
end # module

test/test07_mlj_interface.jl

Lines changed: 54 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,18 @@ using MLJBase
2323
end
2424

2525

26-
@testset "Test model fitting" begin
26+
@testset "Test model fitting verbosity" begin
27+
Random.seed!(2020)
28+
X = table([1 2; 1 4; 1 0; 10 2; 10 4; 10 0])
29+
model = ParallelKMeans.KMeans(k=2, max_iters=1, verbosity=1)
30+
results = @capture_out fit(model, X)
31+
32+
@test results == "Iteration 1: Jclust = 28.0\n"
33+
end
34+
35+
36+
@testset "Test Lloyd model fitting" begin
37+
Random.seed!(2020)
2738
X = table([1 2; 1 4; 1 0; 10 2; 10 4; 10 0])
2839
model = ParallelKMeans.KMeans(k=2)
2940
results = fit(model, X)
@@ -34,23 +45,47 @@ end
3445
end
3546

3647

37-
@testset "Test fitted params" begin
48+
@testset "Test Hamerly model fitting" begin
49+
Random.seed!(2020)
50+
X = table([1 2; 1 4; 1 0; 10 2; 10 4; 10 0])
51+
model = ParallelKMeans.KMeans(algo=:Hamerly, k=2)
52+
results = fit(model, X)
53+
54+
@test results[2] == nothing
55+
@test results[end].converged == true
56+
@test results[end].totalcost == 16
57+
end
58+
59+
60+
@testset "Test Lloyd fitted params" begin
61+
Random.seed!(2020)
3862
X = table([1 2; 1 4; 1 0; 10 2; 10 4; 10 0])
3963
model = ParallelKMeans.KMeans(k=2)
4064
results = fit(model, X)
4165

4266
params = fitted_params(model, results)
4367
@test params.converged == true
4468
@test params.totalcost == 16
45-
4669
end
4770

4871

49-
@testset "Test transform" begin
72+
@testset "Test Hamerly fitted params" begin
73+
Random.seed!(2020)
74+
X = table([1 2; 1 4; 1 0; 10 2; 10 4; 10 0])
75+
model = ParallelKMeans.KMeans(algo=:Hamerly, k=2)
76+
results = fit(model, X)
77+
78+
params = fitted_params(model, results)
79+
@test params.converged == true
80+
@test params.totalcost == 16
81+
end
82+
83+
84+
@testset "Test Lloyd transform" begin
5085
Random.seed!(2020)
5186
X = table([1 2; 1 4; 1 0; 10 2; 10 4; 10 0])
5287
X_test = table([10 1])
53-
88+
5489
# Train model using training data X
5590
model = ParallelKMeans.KMeans(k=2)
5691
results = fit(model, X)
@@ -61,4 +96,18 @@ end
6196
end
6297

6398

99+
@testset "Test Hamerly transform" begin
100+
Random.seed!(2020)
101+
X = table([1 2; 1 4; 1 0; 10 2; 10 4; 10 0])
102+
X_test = table([10 1])
103+
104+
# Train model using training data X
105+
model = ParallelKMeans.KMeans(algo=:Hamerly, k=2)
106+
results = fit(model, X)
107+
108+
# Use trained model to cluster new data X_test
109+
preds = transform(model, results, X_test)
110+
@test preds[:x1][1] == 2
111+
end
112+
64113
end # end module

0 commit comments

Comments
 (0)