Skip to content

Commit 0344d23

Browse files
committed
mlj interface unittests ready for refactoring
1 parent 1359798 commit 0344d23

File tree

2 files changed

+28
-5
lines changed

2 files changed

+28
-5
lines changed

src/mlj_interface.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@ import Distances
1515
k::Int = 3::(_ > 0)
1616
tol::Float64 = 1e-6::(_ < 1)
1717
max_iters::Int = 300::(_ > 0)
18-
copy::Bool = true::(_ in (true, false))
18+
copy::Bool = true
1919
threads::Int = Threads.nthreads()::(_ > 0)
2020
verbosity::Int = 0::(_ in (0, 1)) # Temp fix. Do we need to follow mlj verbosity style?
2121
init = nothing
2222
end
2323

2424

2525
# Expose all instances of user specified structs and package artifcats.
26-
const KMeansModel = Union{KMeans}
26+
#const KMeansModel = Union{KMeans}
2727
const ParallelKMeans_Desc = "Parallel & lightning fast implementation of all variants of the KMeans clustering algorithm in native Julia."
2828

2929

@@ -72,7 +72,7 @@ end
7272
"""
7373
TODO 3.2: Docs
7474
"""
75-
function MLJModelInterface.fitted_params(model::KMeansModel, fitresult)
75+
function MLJModelInterface.fitted_params(model::KMeans, fitresult)
7676
# extract what's relevant from `fitresult`
7777
results, _, _ = fitresult # unpack fitresult
7878
centers = results.centers
@@ -111,7 +111,7 @@ end
111111
####
112112

113113
# TODO 4: metadata for the package and for each of the model interfaces
114-
metadata_pkg.(KMeansModel,
114+
metadata_pkg.(KMeans,
115115
name = "ParallelKMeans",
116116
uuid = "42b8e9d4-006b-409a-8472-7f34b3fb58af", # see your Project.toml
117117
url = "https://github.com/PyDataBlog/ParallelKMeans.jl", # URL to your package repo

test/test07_mlj_interface.jl

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using MLJBase
1111
@testset "Test struct construction" begin
1212
model = ParallelKMeans.KMeans()
1313

14-
@test model.algo == Lloyd()
14+
@test model.algo == :Lloyd
1515
@test model.init == nothing
1616
@test model.k == 3
1717
@test model.k_init == "k-means++"
@@ -24,17 +24,40 @@ end
2424

2525

2626
@testset "Test model fitting" begin
27+
X = table([1 2; 1 4; 1 0; 10 2; 10 4; 10 0])
28+
model = ParallelKMeans.KMeans(k=2)
29+
results = fit(model, 0, X)
2730

31+
@test results[2] == nothing
32+
@test results[end].converged == true
33+
@test results[end].totalcost == 16
2834
end
2935

3036

3137
@testset "Test fitted params" begin
38+
X = table([1 2; 1 4; 1 0; 10 2; 10 4; 10 0])
39+
model = ParallelKMeans.KMeans(k=2)
40+
results = fit(model, 0, X)
41+
42+
params = fitted_params(model, results)
43+
@test params.converged == true
44+
@test params.totalcost == 16
3245

3346
end
3447

3548

3649
@testset "Test transform" begin
50+
Random.seed!(2020)
51+
X = table([1 2; 1 4; 1 0; 10 2; 10 4; 10 0])
52+
X_test = table([10 1])
3753

54+
# Train model using training data X
55+
model = ParallelKMeans.KMeans(k=2)
56+
results = fit(model, 0, X)
57+
58+
# Use trained model to cluster new data X_test
59+
preds = transform(model, results, X_test)
60+
@test preds[:x1][1] == 2
3861
end
3962

4063

0 commit comments

Comments
 (0)