Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
Manifest.toml
*.swp
docs/build/
17 changes: 2 additions & 15 deletions ext/DistributionsADLazyArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ module DistributionsADLazyArraysExt
if isdefined(Base, :get_extension)
using DistributionsAD
using LazyArrays
using DistributionsAD: Distributions, ValueSupport, UnivariateDistribution, VectorOfUnivariate, MatrixOfUnivariate
using DistributionsAD: Distributions, ValueSupport, UnivariateDistribution, VectorOfUnivariate
using LazyArrays: BroadcastArray, BroadcastVector, LazyArray
else
using ..DistributionsAD
using ..LazyArrays
using ..DistributionsAD: Distributions, ValueSupport, UnivariateDistribution, VectorOfUnivariate, MatrixOfUnivariate
using ..DistributionsAD: Distributions, ValueSupport, UnivariateDistribution, VectorOfUnivariate
using ..LazyArrays: BroadcastArray, BroadcastVector, LazyArray
end

Expand All @@ -34,19 +34,6 @@ function Distributions.logpdf(
return vec(sum(copy(Distributions.logpdf.(dists, x)), dims = 1))
end

const LazyMatrixOfUnivariate{
S<:ValueSupport,
T<:UnivariateDistribution{S},
Tdists<:BroadcastArray{T,2},
} = MatrixOfUnivariate{S,T,Tdists}

function Distributions._logpdf(
dist::LazyMatrixOfUnivariate,
x::AbstractMatrix{<:Real},
)
return sum(copy(Distributions.logpdf.(dist.dists, x)))
end

DistributionsAD.lazyarray(f, x...) = LazyArray(Base.broadcasted(f, x...))

end # module
7 changes: 2 additions & 5 deletions src/DistributionsAD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,10 @@ export TuringScalMvNormal,
TuringMvLogNormal,
TuringPoissonBinomial,
TuringWishart,
TuringInverseWishart,
arraydist,
filldist
TuringInverseWishart

include("common.jl")
include("arraydist.jl")
include("filldist.jl")
include("product_distribution_compat.jl")
include("univariate.jl")
include("multivariate.jl")
include("matrixvariate.jl")
Expand Down
108 changes: 0 additions & 108 deletions src/arraydist.jl

This file was deleted.

123 changes: 0 additions & 123 deletions src/filldist.jl

This file was deleted.

4 changes: 4 additions & 0 deletions src/product_distribution_compat.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Compatibility aliases for product distributions
# These are maintained for compatibility with extensions

const VectorOfUnivariate = Distributions.Product
8 changes: 0 additions & 8 deletions src/zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,6 @@ ZygoteRules.@adjoint function Distributions._logpdf(d::Product, x::AbstractVecto
sum(map(logpdf, d.v, x))
end
end
ZygoteRules.@adjoint function Distributions._logpdf(
d::FillVectorOfUnivariate,
x::AbstractVector{<:Real},
)
return ZygoteRules.pullback(d, x) do d, x
_flat_logpdf(d.v.value, x)
end
end

# Loglikelihood of multi- and matrixvariate distributions: multiple samples
# workaround for Zygote issues discussed in
Expand Down
Loading
Loading