Skip to content

Commit eb6852a

Browse files
committed
mlj interface draft done. untested wip
1 parent 73fa639 commit eb6852a

File tree

2 files changed

+30
-12
lines changed

2 files changed

+30
-12
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,19 @@ authors = ["Bernard Brenyah", "Andrey Oskin"]
44
version = "0.1.0"
55

66
[deps]
7+
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
78
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
89
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
910

1011
[compat]
1112
StatsBase = "0.32, 0.33"
12-
julia = "1.3"
13+
julia = "1.3, 1.4"
1314

1415
[extras]
16+
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
1517
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1618
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
1719
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
18-
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
1920

2021
[targets]
2122
test = ["Test", "Random", "Suppressor", "MLJBase"]

src/mlj_interface.jl

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# TODO 1: a using MLJModelInterface or import MLJModelInterface statement
22
using MLJModelInterface
33
using ParallelKMeans
4+
using Distances
45

56

67
####
@@ -16,7 +17,7 @@ using ParallelKMeans
1617
max_iters::Int = 300::(_ > 0)
1718
#transpose_type::String = "permute"::(_ in ("permute", "transpose"))
1819
threads::Int = Threads.nthreads()::(_ > 0)
19-
verbosity::Int = 0::(_ in (0, 1))
20+
verbosity::Int = 0::(_ in (0, 1)) # Temp fix. Do we need to follow mlj verbosity style?
2021
init = nothing
2122
end
2223

@@ -32,11 +33,13 @@ const ParallelKMeans_Desc = "Parallel & lightning fast implementation of all var
3233
####
3334
"""
3435
TODO 3.1: Docs
36+
37+
See also the [package documentation](https://pydatablog.github.io/ParallelKMeans.jl/stable).
3538
"""
3639
function MLJModelInterface.fit(m::ParaKMeans, verbosity::Int, X)
3740
# fit the specified struct as a ParaKMeans model
3841

39-
# assumes user supplied table with columns as features
42+
# convert tabular input data into the matrix model expects. Column assumed as features so input data is permuted
4043
DMatrix = MLJModelInterface.matrix(X; transpose=true)
4144

4245
# fit model and get results
@@ -53,7 +56,9 @@ function MLJModelInterface.fit(m::ParaKMeans, verbosity::Int, X)
5356
end
5457

5558
cache = nothing
56-
report = NamedTuple{}()
59+
report = (cluster_centers=fitresult.centers, iterations=fitresult.iterations,
60+
converged=fitresult.converged, totalcost=fitresult.totalcost,
61+
labels=fitresult.assignments)
5762

5863
return (fitresult, cache, report)
5964
end
@@ -64,12 +69,15 @@ end
6469
"""
6570
function MLJModelInterface.fitted_params(model::KMeansModel, fitresult)
6671
# extract what's relevant from `fitresult`
67-
centres = fitresult.centres
68-
converged = fitresult.converged
69-
iters = fitresult.iterations
70-
totalcost = fitresult.totalcost
72+
results, _, _ = fitresult # unpack fitresult
73+
centers = results.centers
74+
converged = results.converged
75+
iters = results.iterations
76+
totalcost = results.totalcost
77+
7178
# then return as a NamedTuple
72-
return (centres = centres, totalcost = totalcost, iterations = iters, converged = converged)
79+
return (cluster_centers = centers, totalcost = totalcost,
80+
iterations = iters, converged = converged)
7381
end
7482

7583

@@ -79,8 +87,17 @@ end
7987
"""
8088
TODO 3.3: Docs
8189
"""
82-
function MLJModelInterface.predict(m::ParaKMeans, fitresult, Xnew)
83-
# ...
90+
function MLJModelInterface.transform(m::ParaKMeans, fitresult, Xnew)
91+
# make predictions/assignments using the learned centroids
92+
results = fitresult[1]
93+
DMatrix = MLJModelInterface.matrix(Xnew, transpose=true)
94+
95+
# TODO 3.3.1: Warn users if fitresult is from a `non-converged` fit.
96+
# use centroid matrix to assign clusters for new data
97+
centroids = results.centers
98+
distances = pairwise(SqEuclidean(), DMatrix, centroids; dims=2)
99+
preds = argmin.(eachrow(distances))
100+
return MLJModelInterface.table(reshape(preds, :, 1), prototype=Xnew)
84101
end
85102

86103

0 commit comments

Comments
 (0)