Skip to content

Commit e3f8a67

Browse files
torfjeldedevmotion
andauthored
Added default impls for filldist and arraydist (#264)
* added default impls for `filldist` and `arraydist` * Apply suggestions from code review Co-authored-by: David Widmann <devmotion@users.noreply.github.com> * Update filldist.jl (#265) --------- Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
1 parent b8c3d82 commit e3f8a67

File tree

3 files changed

+18
-2
lines changed

3 files changed

+18
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DistributionsAD"
22
uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
3-
version = "0.6.54"
3+
version = "0.6.55"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/arraydist.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
"""
2+
arraydist(dists)
3+
4+
Create a distribution from an array of distributions.
5+
"""
6+
arraydist(dists::AbstractArray{<:Distribution}) = product_distribution(dists)
7+
18
# Univariate
29

310
const VectorOfUnivariate = Distributions.Product

src/filldist.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
1+
# Default implementation just defers to Distributions.jl.
2+
"""
3+
filldist(d::Distribution, ns...)
4+
5+
Create a product distribution using `FillArrays.Fill` as the array type.
6+
"""
7+
filldist(d::Distribution, n1::Int, ns::Int...) = product_distribution(Fill(d, n1, ns...))
8+
19
# Univariate
210

11+
# TODO: Do we even need these? Probably should benchmark to be sure.
312
const FillVectorOfUnivariate{
413
S <: ValueSupport,
514
T <: UnivariateDistribution{S},
@@ -59,7 +68,7 @@ const FillMatrixOfUnivariate{
5968
Tdists <: Fill{T, 2},
6069
} = MatrixOfUnivariate{S, T, Tdists}
6170

62-
function filldist(dist::UnivariateDistribution, N1::Integer, N2::Integer)
71+
function filldist(dist::UnivariateDistribution, N1::Int, N2::Int)
6372
return MatrixOfUnivariate(Fill(dist, N1, N2))
6473
end
6574
function Distributions._logpdf(dist::FillMatrixOfUnivariate, x::AbstractMatrix{<:Real})

0 commit comments

Comments
 (0)