Skip to content

Commit

Permalink
rebase fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mohamed82008 committed Mar 10, 2020
1 parent 784d055 commit 75d3e66
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 21 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Pkg", "TerminalLoggers", "Test", "UnicodePlots", "StatsBase", "SpecialFunctions", "FiniteDifferences", "DynamicHMC", "CmdStan", "BenchmarkTools", "Zygote"]
test = ["Pkg", "TerminalLoggers", "Test", "UnicodePlots", "StatsBase", "FiniteDifferences", "DynamicHMC", "CmdStan", "BenchmarkTools", "Zygote"]
2 changes: 1 addition & 1 deletion src/inference/AdvancedSMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ function assume(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, vn::VarName,
r = rand(dist)
push!(vi, vn, r, dist, Selector(:invalid))
end
acclogp!(vi, invlink_logpdf_trans(spl, dist, r, istrans(vi, vn)))
acclogp!(vi, logpdf_with_trans(dist, r, istrans(vi, vn)))
end
return r, 0
end
Expand Down
20 changes: 4 additions & 16 deletions src/inference/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ function AbstractMCMC.bundle_samples(
model::Model,
spl::Sampler,
N::Integer,
ts::Vector{<:AbstractTransition},
ts::Vector,
chain_type::Type{Chains};
raw_output::Bool=false,
discard_adapt::Bool=true,
Expand Down Expand Up @@ -613,7 +613,7 @@ function assume(
# r is genereated from some uniform distribution which is different from the prior
# acclogp!(vi, logpdf_with_trans(dist, r, istrans(vi, vn)))

return r, invlink_logpdf_trans(spl, dist, r, istrans(vi, vn))
return r, logpdf_with_trans(dist, r, istrans(vi, vn))
end

function observe(
Expand Down Expand Up @@ -727,7 +727,7 @@ function dot_assume(
)
@assert length(dist) == size(var, 1)
r = get_and_set_val!(vi, vns, dist, spl)
lp = sum(invlink_logpdf_trans(spl, dist, r, istrans(vi, vns[1])))
lp = sum(logpdf_with_trans(dist, r, istrans(vi, vns[1])))
var .= r
return var, lp
end
Expand All @@ -740,7 +740,7 @@ function dot_assume(
)
r = get_and_set_val!(vi, vns, dists, spl)
# Make sure `r` is not a matrix for multivariate distributions
lp = sum(invlink_logpdf_trans.(Ref(spl), dists, r, istrans(vi, vns[1])))
lp = sum(logpdf_with_trans.(dists, r, istrans(vi, vns[1])))
var .= r
return var, lp
end
Expand Down Expand Up @@ -835,18 +835,6 @@ function set_val!(
return val
end

function invlink_logpdf_trans(spl, dist, x, trans)
#if dist isa Dirichlet || dist isa VectorOfMultivariate{<:Any, <:Dirichlet}
if trans
return logpdf_with_trans(dist, invlink(dist, x), true)
else
return logpdf_with_trans(dist, x, false)
end
#else
# return logpdf_with_trans(dist, x, trans)
#end
end

# observe
function dot_tilde(ctx::DefaultContext, sampler, right, left, vi)
return _dot_tilde(sampler, right, left, vi)
Expand Down
6 changes: 3 additions & 3 deletions src/inference/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ function assume(
Turing.DEBUG && _debug("dist = $dist")
Turing.DEBUG && _debug("vn = $vn")
Turing.DEBUG && _debug("r = $r, typeof(r)=$(typeof(r))")
return r, invlink_logpdf_trans(spl, dist, r, istrans(vi, vn))
return r, logpdf_with_trans(dist, r, istrans(vi, vn))
end

function dot_assume(
Expand All @@ -461,7 +461,7 @@ function dot_assume(
updategid!.(Ref(vi), vns, Ref(spl))
r = vi[vns]
var .= r
return var, sum(invlink_logpdf_trans(spl, dist, r, istrans(vi, vns[1])))
return var, sum(logpdf_with_trans(dist, r, istrans(vi, vns[1])))
end
function dot_assume(
spl::Sampler{<:Hamiltonian},
Expand All @@ -473,7 +473,7 @@ function dot_assume(
updategid!.(Ref(vi), vns, Ref(spl))
r = reshape(vi[vec(vns)], size(var))
var .= r
return var, sum(invlink_logpdf_trans.(Ref(spl), dists, r, istrans(vi, vns[1])))
return var, sum(logpdf_with_trans.(dists, r, istrans(vi, vns[1])))
end

function observe(
Expand Down

0 comments on commit 75d3e66

Please sign in to comment.