diff --git a/.gitignore b/.gitignore index df02284..9c9e5e5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ Manifest.toml *.swp +docs/build/ diff --git a/ext/DistributionsADLazyArraysExt.jl b/ext/DistributionsADLazyArraysExt.jl index a030ca6..e63c3b5 100644 --- a/ext/DistributionsADLazyArraysExt.jl +++ b/ext/DistributionsADLazyArraysExt.jl @@ -3,12 +3,12 @@ module DistributionsADLazyArraysExt if isdefined(Base, :get_extension) using DistributionsAD using LazyArrays - using DistributionsAD: Distributions, ValueSupport, UnivariateDistribution, VectorOfUnivariate, MatrixOfUnivariate + using DistributionsAD: Distributions, ValueSupport, UnivariateDistribution, VectorOfUnivariate using LazyArrays: BroadcastArray, BroadcastVector, LazyArray else using ..DistributionsAD using ..LazyArrays - using ..DistributionsAD: Distributions, ValueSupport, UnivariateDistribution, VectorOfUnivariate, MatrixOfUnivariate + using ..DistributionsAD: Distributions, ValueSupport, UnivariateDistribution, VectorOfUnivariate using ..LazyArrays: BroadcastArray, BroadcastVector, LazyArray end @@ -34,19 +34,6 @@ function Distributions.logpdf( return vec(sum(copy(Distributions.logpdf.(dists, x)), dims = 1)) end -const LazyMatrixOfUnivariate{ - S<:ValueSupport, - T<:UnivariateDistribution{S}, - Tdists<:BroadcastArray{T,2}, -} = MatrixOfUnivariate{S,T,Tdists} - -function Distributions._logpdf( - dist::LazyMatrixOfUnivariate, - x::AbstractMatrix{<:Real}, -) - return sum(copy(Distributions.logpdf.(dist.dists, x))) -end - DistributionsAD.lazyarray(f, x...) = LazyArray(Base.broadcasted(f, x...)) end # module diff --git a/src/DistributionsAD.jl b/src/DistributionsAD.jl index c6203f6..0b4caaf 100644 --- a/src/DistributionsAD.jl +++ b/src/DistributionsAD.jl @@ -19,13 +19,10 @@ export TuringScalMvNormal, TuringMvLogNormal, TuringPoissonBinomial, TuringWishart, - TuringInverseWishart, - arraydist, - filldist + TuringInverseWishart include("common.jl") -include("arraydist.jl") -include("filldist.jl") +include("product_distribution_compat.jl") include("univariate.jl") include("multivariate.jl") include("matrixvariate.jl") diff --git a/src/arraydist.jl b/src/arraydist.jl deleted file mode 100644 index 062bab0..0000000 --- a/src/arraydist.jl +++ /dev/null @@ -1,108 +0,0 @@ -""" - arraydist(dists::AbstractArray{<:Distribution}) - -Create a product distribution from an array of sub-distributions. Each element -of `dists` should have the same size. If the size of each element is `(d1, d2, -...)`, and `size(dists)` is `(n1, n2, ...)`, then the resulting distribution -will have size `(d1, d2, ..., n1, n2, ...)`. - -The default behaviour is to directly use -[`Distributions.product_distribution`](https://juliastats.org/Distributions.jl/stable/multivariate/#Distributions.product_distribution), -although this can sometimes be specialised. - -# Examples - -```jldoctest; setup=:(using Distributions, Random) -julia> d1 = arraydist([Normal(0, 1), Normal(10, 1)]) -Product{Continuous, Normal{Float64}, Vector{Normal{Float64}}}(v=Normal{Float64}[Normal{Float64}(μ=0.0, σ=1.0), Normal{Float64}(μ=10.0, σ=1.0)]) - -julia> size(d1) -(2,) - -julia> Random.seed!(42); rand(d1) -2-element Vector{Float64}: - 0.7883556016042917 - 9.1201414040456 - -julia> d2 = arraydist([Normal(0, 1) Normal(5, 1); Normal(10, 1) Normal(15, 1)]) -DistributionsAD.MatrixOfUnivariate{Continuous, Normal{Float64}, Matrix{Normal{Float64}}}( -dists: Normal{Float64}[Normal{Float64}(μ=0.0, σ=1.0) Normal{Float64}(μ=5.0, σ=1.0); Normal{Float64}(μ=10.0, σ=1.0) Normal{Float64}(μ=15.0, σ=1.0)] -) - -julia> size(d2) -(2, 2) - -julia> Random.seed!(42); rand(d2) -2×2 Matrix{Float64}: - 0.788356 4.12621 - 9.12014 14.2667 -``` -""" -arraydist(dists::AbstractArray{<:Distribution}) = product_distribution(dists) - -# Univariate - -const VectorOfUnivariate = Distributions.Product - -function arraydist(dists::AbstractVector{<:UnivariateDistribution}) - V = typeof(dists) - T = eltype(dists) - S = Distributions.value_support(T) - return Product{S,T,V}(dists) -end - -struct MatrixOfUnivariate{ - S <: ValueSupport, - Tdist <: UnivariateDistribution{S}, - Tdists <: AbstractMatrix{Tdist}, -} <: MatrixDistribution{S} - dists::Tdists -end -Base.size(dist::MatrixOfUnivariate) = size(dist.dists) -function arraydist(dists::AbstractMatrix{<:UnivariateDistribution}) - return MatrixOfUnivariate(dists) -end -function Distributions._logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Real}) - # Lazy broadcast to avoid allocations and use pairwise summation - return sum(Broadcast.instantiate(Broadcast.broadcasted(logpdf, dist.dists, x))) -end -function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:AbstractMatrix{<:Real}}) - return map(Base.Fix1(logpdf, dist), x) -end -function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:Matrix{<:Real}}) - return map(Base.Fix1(logpdf, dist), x) -end - -function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixOfUnivariate) - return rand.(Ref(rng), dist.dists) -end - -# Multivariate - -struct VectorOfMultivariate{ - S <: ValueSupport, - Tdist <: MultivariateDistribution{S}, - Tdists <: AbstractVector{Tdist}, -} <: MatrixDistribution{S} - dists::Tdists -end -Base.size(dist::VectorOfMultivariate) = (length(dist.dists[1]), length(dist)) -Base.length(dist::VectorOfMultivariate) = length(dist.dists) -function arraydist(dists::AbstractVector{<:MultivariateDistribution}) - return VectorOfMultivariate(dists) -end - -function Distributions._logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real}) - return sum(Broadcast.instantiate(Broadcast.broadcasted(logpdf, dist.dists, eachcol(x)))) -end -function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:AbstractMatrix{<:Real}}) - return map(Base.Fix1(logpdf, dist), x) -end -function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:Matrix{<:Real}}) - return map(Base.Fix1(logpdf, dist), x) -end - -function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate) - init = reshape(rand(rng, dist.dists[1]), :, 1) - return mapreduce(Base.Fix1(rand, rng), hcat, view(dist.dists, 2:length(dist)); init = init) -end diff --git a/src/filldist.jl b/src/filldist.jl deleted file mode 100644 index d958361..0000000 --- a/src/filldist.jl +++ /dev/null @@ -1,123 +0,0 @@ -""" - filldist(d::Distribution, ns...) - -Create a product distribution from a single distribution and a list of -dimension sizes. If `size(d)` is `(d1, d2, ...)` and `ns` is `(n1, n2, ...)`, -then the resulting distribution will have size `(d1, d2, ..., n1, n2, ...)`. - -The default behaviour is to use -[`Distributions.product_distribution`](https://juliastats.org/Distributions.jl/stable/multivariate/#Distributions.product_distribution), -with `FillArrays.Fill` supplied as the array argument. However, this behaviour -is specialised in some instances, such as the one shown below. - -When sampling from the resulting distribution, the output will be an array where -each element is sampled from the original distribution `d`. - -# Examples - -```jldoctest; setup=:(using Distributions, Random) -julia> d = filldist(Normal(0, 1), 4, 5); - -julia> size(d) -(4, 5) - -julia> rand(d) isa Matrix{Float64} -true -``` -""" -filldist(d::Distribution, n1::Int, ns::Int...) = product_distribution(Fill(d, n1, ns...)) - -# Univariate - -# TODO: Do we even need these? Probably should benchmark to be sure. -const FillVectorOfUnivariate{ - S <: ValueSupport, - T <: UnivariateDistribution{S}, - Tdists <: Fill{T, 1}, -} = VectorOfUnivariate{S, T, Tdists} - -function filldist(dist::UnivariateDistribution, N::Int) - return product_distribution(Fill(dist, N)) -end -filldist(d::Normal, N::Int) = TuringMvNormal(fill(d.μ, N), d.σ) - -function Distributions._logpdf( - dist::FillVectorOfUnivariate, - x::AbstractVector{<:Real}, -) - return _flat_logpdf(dist.v.value, x) -end - -function Distributions.logpdf( - dist::FillVectorOfUnivariate, - x::AbstractMatrix{<:Real}, -) - size(x, 1) == length(dist) || - throw(DimensionMismatch("Inconsistent array dimensions.")) - return _flat_logpdf_mat(dist.v.value, x) -end - -function _flat_logpdf(dist, x) - if toflatten(dist) - f, args = flatten(dist) - # Lazy broadcast to avoid allocations and use pairwise summation - return sum(Broadcast.instantiate(Broadcast.broadcasted(xi -> f(args..., xi), x))) - else - return sum(Broadcast.instantiate(Broadcast.broadcasted(Base.Fix1(logpdf, dist), x))) - end -end - -function _flat_logpdf_mat(dist, x) - if toflatten(dist) - f, args = flatten(dist) - return vec(mapreduce(xi -> f(args..., xi), +, x, dims = 1)) - else - return vec(mapreduce(Base.Fix1(logpdf, dist), +, x; dims = 1)) - end -end - -function Distributions.rand(rng::Random.AbstractRNG, d::FillVectorOfUnivariate) - return rand(rng, d.v.value, length(d)) -end -function Distributions.rand(rng::Random.AbstractRNG, d::FillVectorOfUnivariate, n::Int) - return rand(rng, d.v.value, length(d), n) -end - -const FillMatrixOfUnivariate{ - S <: ValueSupport, - T <: UnivariateDistribution{S}, - Tdists <: Fill{T, 2}, -} = MatrixOfUnivariate{S, T, Tdists} - -function filldist(dist::UnivariateDistribution, N1::Int, N2::Int) - return MatrixOfUnivariate(Fill(dist, N1, N2)) -end -function Distributions._logpdf(dist::FillMatrixOfUnivariate, x::AbstractMatrix{<:Real}) - # return loglikelihood(dist.dists.value, x) - return _flat_logpdf(dist.dists.value, x) -end -function Distributions.rand(rng::Random.AbstractRNG, dist::FillMatrixOfUnivariate) - return rand(rng, dist.dists.value, length.(dist.dists.axes)...,) -end - -# Multivariate - -const FillVectorOfMultivariate{ - S <: ValueSupport, - T <: MultivariateDistribution{S}, - Tdists <: Fill{T, 1}, -} = VectorOfMultivariate{S, T, Tdists} - -function filldist(dist::MultivariateDistribution, N::Int) - return VectorOfMultivariate(Fill(dist, N)) -end -function Distributions._logpdf( - dist::FillVectorOfMultivariate, - x::AbstractMatrix{<:Real}, -) - return loglikelihood(dist.dists.value, x) -end - -function Distributions.rand(rng::Random.AbstractRNG, dist::FillVectorOfMultivariate) - return rand(rng, dist.dists.value, length.(dist.dists.axes)...,) -end diff --git a/src/product_distribution_compat.jl b/src/product_distribution_compat.jl new file mode 100644 index 0000000..2297589 --- /dev/null +++ b/src/product_distribution_compat.jl @@ -0,0 +1,4 @@ +# Compatibility aliases for product distributions +# These are maintained for compatibility with extensions + +const VectorOfUnivariate = Distributions.Product \ No newline at end of file diff --git a/src/zygote.jl b/src/zygote.jl index 86598b6..ff6e3a6 100644 --- a/src/zygote.jl +++ b/src/zygote.jl @@ -6,14 +6,6 @@ ZygoteRules.@adjoint function Distributions._logpdf(d::Product, x::AbstractVecto sum(map(logpdf, d.v, x)) end end -ZygoteRules.@adjoint function Distributions._logpdf( - d::FillVectorOfUnivariate, - x::AbstractVector{<:Real}, -) - return ZygoteRules.pullback(d, x) do d, x - _flat_logpdf(d.v.value, x) - end -end # Loglikelihood of multi- and matrixvariate distributions: multiple samples # workaround for Zygote issues discussed in diff --git a/test/ad/distributions.jl b/test/ad/distributions.jl index 8baa50c..d601e30 100644 --- a/test/ad/distributions.jl +++ b/test/ad/distributions.jl @@ -395,7 +395,7 @@ test_ad(d) end - # Test `filldist` and `arraydist` distributions of univariate distributions + # Test `product_distribution` distributions of univariate distributions n = 2 # always use two distributions for d in univariate_distributions d.x isa Number || continue @@ -409,28 +409,28 @@ # PoissonBinomial fails with Zygote # Matrix case does not work with Skellam: # https://github.com/TuringLang/DistributionsAD.jl/pull/172#issuecomment-853721493 - filldist_broken = if D <: PoissonBinomial + fill_broken = if D <: PoissonBinomial ((d.broken..., :Zygote), (d.broken..., :Zygote)) elseif D <: Chernoff - # Zygote is not broken with `filldist` + # Zygote is not broken with Fill ((), ()) else (d.broken, d.broken) end - arraydist_broken = if D <: PoissonBinomial + array_broken = if D <: PoissonBinomial ((d.broken..., :Zygote), (d.broken..., :Zygote)) else (d.broken, d.broken) end - # Create `filldist` distribution + # Create `product_distribution` with Fill f = d.f - f_filldist = (θ...,) -> filldist(f(θ...), n) - d_filldist = f_filldist(d.θ...) + f_fill = (θ...,) -> product_distribution(Fill(f(θ...), n)) + d_fill = f_fill(d.θ...) - # Create `arraydist` distribution - f_arraydist = (θ...,) -> arraydist([f(θ...) for _ in 1:n]) - d_arraydist = f_arraydist(d.θ...) + # Create `product_distribution` with vector + f_array = (θ...,) -> product_distribution([f(θ...) for _ in 1:n]) + d_array = f_array(d.θ...) for (i, sz) in enumerate(((n,), (n, 2))) # Matrix case doesn't work for continuous distributions for some reason @@ -443,25 +443,25 @@ x = fill(d.x, sz) # Test AD - @info "Testing: filldist($(nameof(D)), $sz)" + @info "Testing: product_distribution(Fill($(nameof(D)), $sz))" test_ad( DistSpec( - f_filldist, + f_fill, d.θ, x, d.xtrans; - broken=filldist_broken[i], + broken=fill_broken[i], ) ) - @info "Testing: arraydist($(nameof(D)), $sz)" + @info "Testing: product_distribution([$(nameof(D)), ...], $sz)" test_ad( DistSpec( - f_arraydist, + f_array, d.θ, x, d.xtrans; - broken=arraydist_broken[i], + broken=array_broken[i], ) ) end @@ -476,7 +476,7 @@ test_ad(d) end - # Test `filldist` and `arraydist` distributions of univariate distributions + # Test `product_distribution` distributions of univariate distributions (2D) n = (2, 2) # always use 2 x 2 distributions for d in univariate_distributions d.x isa Number || continue @@ -486,35 +486,35 @@ # Broken distributions D <: Union{VonMises,TriangularDist} && continue - # Create `filldist` distribution + # Create `product_distribution` with Fill f = d.f - f_filldist = (θ...,) -> filldist(f(θ...), n...) + f_fill = (θ...,) -> product_distribution(Fill(f(θ...), n...)) - # Create `arraydist` distribution + # Create `product_distribution` with matrix # Zygote's fill definition does not like non-numbers, so we use a workaround - f_arraydist = (θ...,) -> arraydist(reshape([f(θ...) for _ in 1:prod(n)], n)) + f_array = (θ...,) -> product_distribution(reshape([f(θ...) for _ in 1:prod(n)], n)) # Matrix `x` x_mat = fill(d.x, n) - # Zygote is not broken with `filldist` + Chernoff - filldist_broken = D <: Chernoff ? () : d.broken + # Zygote is not broken with Fill + Chernoff + fill_broken = D <: Chernoff ? () : d.broken # Test AD - @info "Testing: filldist($(nameof(D)), $n)" + @info "Testing: product_distribution(Fill($(nameof(D)), $n))" test_ad( DistSpec( - f_filldist, + f_fill, d.θ, x_mat, d.xtrans; - broken=filldist_broken, + broken=fill_broken, ) ) - @info "Testing: arraydist($(nameof(D)), $n)" + @info "Testing: product_distribution(matrix of $(nameof(D)), $n)" test_ad( DistSpec( - f_arraydist, + f_array, d.θ, x_mat, d.xtrans; @@ -526,20 +526,20 @@ x_vec_of_mat = [fill(d.x, n) for _ in 1:2] # Test AD - @info "Testing: filldist($(nameof(D)), $n, 2)" + @info "Testing: product_distribution(Fill($(nameof(D)), $n)) with vector of matrices" test_ad( DistSpec( - f_filldist, + f_fill, d.θ, x_vec_of_mat, d.xtrans; - broken=filldist_broken, + broken=fill_broken, ) ) - @info "Testing: arraydist($(nameof(D)), $n, 2)" + @info "Testing: product_distribution(matrix of $(nameof(D)), $n) with vector of matrices" test_ad( DistSpec( - f_arraydist, + f_array, d.θ, x_vec_of_mat, d.xtrans; @@ -548,7 +548,7 @@ ) end - # test `filldist` and `arraydist` distributions of multivariate distributions + # test `product_distribution` distributions of multivariate distributions n = 2 # always use two distributions for d in multivariate_distributions d.x isa AbstractVector || continue @@ -566,31 +566,31 @@ any(x isa Matrix for x in d.θ) && continue end - # Create `filldist` distribution + # Create `product_distribution` with Fill f = d.f - f_filldist = (θ...,) -> filldist(f(θ...), n) + f_fill = (θ...,) -> product_distribution(Fill(f(θ...), n)) - # Create `arraydist` distribution - f_arraydist = (θ...,) -> arraydist([f(θ...) for _ in 1:n]) + # Create `product_distribution` with vector + f_array = (θ...,) -> product_distribution([f(θ...) for _ in 1:n]) # Matrix `x` x_mat = repeat(d.x, 1, n) # Test AD - @info "Testing: filldist($(nameof(D)), $n)" + @info "Testing: product_distribution(Fill($(nameof(D)), $n))" test_ad( DistSpec( - f_filldist, + f_fill, d.θ, x_mat, d.xtrans; broken=d.broken, ) ) - @info "Testing: arraydist($(nameof(D)), $n)" + @info "Testing: product_distribution([$(nameof(D)), ...])" test_ad( DistSpec( - f_arraydist, + f_array, d.θ, x_mat, d.xtrans; @@ -602,20 +602,20 @@ x_vec_of_mat = [repeat(d.x, 1, n) for _ in 1:2] # Test AD - @info "Testing: filldist($(nameof(D)), $n, 2)" + @info "Testing: product_distribution(Fill($(nameof(D)), $n)) with vector of matrices" test_ad( DistSpec( - f_filldist, + f_fill, d.θ, x_vec_of_mat, d.xtrans; broken=d.broken, ) ) - @info "Testing: arraydist($(nameof(D)), $n, 2)" + @info "Testing: product_distribution([$(nameof(D)), ...]) with vector of matrices" test_ad( DistSpec( - f_arraydist, + f_array, d.θ, x_vec_of_mat, d.xtrans; diff --git a/test/runtests.jl b/test/runtests.jl index c25d19e..9773ee3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using DistributionsAD using Combinatorics using Distributions using Documenter +using FillArrays using PDMats import LazyArrays