You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
I posted this question on the Julia Discourse under the title "Extracting each iteration result in KissABC.jl". I tried editing the `smc` source code to get the desired feature, and it seems to work.
Below I have shared my edited code for `smc`. Please have a look and let me know what you think. This feature could be really beneficial for other people working with the KissABC.jl package, as intermediate results help you judge how a run is going before it finishes.
This code logs information for each iteration, as shown in the image below, and also saves a CSV file of all parameter values per iteration, which can be used later to plot trajectory fits, contour plots, or density plots to check the status of your run :)
Julia version: 1.10
KissABC version: 3.0.1
# Convert the alive subset of the internal particle states `θs` into
# `Particles` objects, one per parameter of `prior` (a single `Particles`
# when the prior has exactly one parameter).
# NOTE(review): `push_p` and `Particles` are not defined in this snippet;
# presumably they come from KissABC.jl / MonteCarloMeasurements.jl — confirm.
function _alive_particles(prior, θs, alive)
    pts = [push_p(prior, θ.x) for θ in θs][alive]
    P = map(x -> Particles(x), getindex.(pts, i) for i = 1:length(prior))
    return length(P) == 1 ? first(P) : P
end

"""
    smc_edited(prior, cost, param_name; kwargs...)

Adaptive SMC-style ABC sampler that, in addition to returning the final
population, logs and saves the alive population of every iteration.

Each iteration:

1. lowers the threshold `ϵ` to the `alpha`-quantile of the costs of the alive
   particles and marks particles above it as dead;
2. resamples from the alive particles when `alpha * ESS <= nparticles * min_r_ess`;
3. proposes a stretch move for every alive particle (up to `1 + mcmc_retrys`
   sweeps) and accepts it if the prior ratio test passes and the new cost is
   below `ϵ`.

The loop stops when the relative change in `ϵ` drops below `r_epstol`, when
`ϵ <= epstol`, or when fewer than `mcmc_tol * nparticles` proposals were accepted.

After every non-terminating iteration, and once more at the end, the alive
particles and their costs are written to `params_<iteration>.csv` via
`save_param!` and summarised with `@info`. This happens regardless of `verbose`,
which only controls the per-iteration `@show` line.

# Arguments
- `prior`: prior distribution; `length(prior)` is used as the number of parameters.
- `cost`: callable mapping a parameter sample to a scalar cost.
- `param_name`: column names for the saved CSV files, one per parameter.

# Keywords
- `rng`, `nparticles`, `alpha`, `mcmc_retrys`, `mcmc_tol`, `epstol`, `r_epstol`,
  `min_r_ess`, `max_stretch`, `verbose`, `parallel`: see the argument checks and
  the algorithm outline above.

# Returns
A named tuple `(P, C, ϵ)`: `P` the final alive particles, `C` the cost vector
`Xs`, and `ϵ` the final threshold.

NOTE(review): `C` has length `nparticles` while `P` only contains alive
particles, so the two can differ in length when the last iteration did not
resample. The return value is left as in the original to keep the interface.
"""
function smc_edited(
    prior::Tprior,
    cost,
    param_name::Vector{String};
    rng::AbstractRNG = Random.GLOBAL_RNG,
    nparticles::Int = 100,
    alpha = 0.95,
    mcmc_retrys::Int = 0,
    mcmc_tol = 0.015,
    epstol = 0.0,
    r_epstol = (1 - alpha)^1.5 / 50,
    min_r_ess = alpha^2,
    max_stretch = 2.0,
    verbose::Bool = false,
    parallel::Bool = false,
) where {Tprior<:Distribution}
    min_r_ess > 0 || error("min_r_ess must be > 0.")
    mcmc_retrys >= 0 || error("mcmc_retrys must be >= 0.")
    alpha > 0 || error("alpha must be > 0.")
    r_epstol >= 0 || error("r_epstol must be >= 0")
    mcmc_tol >= 0 || error("mcmc_tol must be >= 0")
    max_stretch > 1 || error("max_stretch must be > 1")
    Np = length(prior)
    # Fail fast: a wrong number of column names would otherwise only surface
    # when the first CSV is built, after the initial cost evaluations.
    length(param_name) == Np ||
        error("param_name must have $Np entries (one per parameter).")
    min_nparticles = ceil(
        Int,
        3 * Np / (min(alpha, min_r_ess)),
    )
    nparticles >= min_nparticles || error("nparticles must be >= $min_nparticles.")

    # Initial population: prior draws, their costs and their log prior densities.
    θs = [op(float, Particle(rand(rng, prior))) for i = 1:nparticles]
    Xs = parallel ?
        fetch.([
            Threads.@spawn cost(push_p(prior, θs[$i].x)) for i = 1:nparticles]) :
        [cost(push_p(prior, θs[i].x)) for i = 1:nparticles]
    lπs = [logpdf(prior, push_p(prior, θs[i].x)) for i = 1:nparticles]
    α = alpha
    ϵ = Inf
    alive = fill(true, nparticles)
    iteration = 0
    # Step 1 - adaptive threshold
    while true
        iteration += 1
        ϵv = ϵ
        ϵ = quantile(Xs[alive], α)
        # `flag` records that the quantile equals the minimum alive cost, in
        # which case ties at ϵ are kept alive (<=) instead of being dropped (<).
        flag = false
        if ϵ > minimum(Xs[alive])
            alive = Xs .< ϵ
        else
            alive = Xs .<= ϵ
            flag = true
        end
        ESS = sum(alive)
        verbose && @show iteration, ϵ, ESS
        # Step 2 - Resampling
        if α*ESS <= nparticles * min_r_ess
            idxalive = (1:nparticles)[alive]
            # Cycle through the alive indices until nparticles slots are filled.
            idx = repeat(idxalive, ceil(Int, nparticles/length(idxalive)))[1:nparticles]
            θs = θs[idx]
            Xs = Xs[idx]
            lπs = lπs[idx]
            ESS = nparticles
            alive .= true
        end
        # Step 3 - MCMC
        accepted = parallel ? Threads.Atomic{Int}(0) : 0
        retry_N = 1 + mcmc_retrys
        for r = 1:retry_N
            # Proposals are generated serially so that `rng` is only used from
            # one task; each entry is (log-uniform draw, proposal, log correction).
            new_p = map(1:nparticles) do i
                a = b = i
                alive[i] || return (nothing, nothing, nothing)
                while a==i; a = rand(rng, 1:nparticles); end
                while b==i || b==a; b = rand(rng, 1:nparticles); end
                W = op(*, op(-, θs[b], θs[a]), max_stretch*randn(rng)/sqrt(Np))
                (log(rand(rng)), op(+, θs[i], W), 0.0)
            end
            @cthreads parallel for i = 1:nparticles # non-ideal parallelism
                alive[i] || continue
                lprob, θp, logcorr = new_p[i]
                isnothing(lprob) && continue
                lπp = logpdf(prior, push_p(prior, θp.x))
                # Skip proposals with -Inf log prior density.
                lπp < 0 && (!isfinite(lπp)) && continue
                lM = min(lπp - lπs[i] + logcorr, 0.0)
                if lprob < lM
                    Xp = cost(push_p(prior, θp.x))
                    if flag
                        Xp > ϵ && continue
                    else
                        Xp >= ϵ && continue
                    end
                    θs[i] = θp
                    Xs[i] = Xp
                    lπs[i] = lπp
                    if parallel
                        Threads.atomic_add!(accepted, 1)
                    else
                        accepted += 1
                    end
                end
            end
            accepted[] >= mcmc_tol * nparticles && break
        end
        if 2*abs(ϵv - ϵ) < r_epstol * (abs(ϵv)+abs(ϵ)) ||
           ϵ <= epstol ||
           accepted[] < mcmc_tol * nparticles
            break
        end
        Q = _alive_particles(prior, θs, alive)
        @info "Saving Population $(iteration)"
        # Only alive particles are saved, so pass the matching costs; passing the
        # full `Xs` mismatches the row count whenever resampling was skipped.
        save_param!(DataFrame(Array(Q), param_name), Xs[alive], iteration)
        @info "Current Particles info - $(Q)"
        @info "Done"
    end
    P = _alive_particles(prior, θs, alive)
    @info "Saving Population $(iteration) - Final"
    save_param!(DataFrame(Array(P), param_name), Xs[alive], iteration)
    @info "Final Particles info after $(iteration) - $(P)"
    @info "Process Finished"
    return (P = P, C = Xs, ϵ = ϵ)
end
"""
    save_param!(xx::DataFrame, C, iter)

Add `C` to `xx` as a column named `ϵ` (mutating `xx`) and write the resulting
table to `params_<iter>.csv` in the current working directory.

Returns whatever `CSV.write` returns.
"""
function save_param!(xx::DataFrame, C, iter)
    # Attach the per-row values first so they end up in the CSV as well.
    xx[!, :ϵ] = C
    outfile = "params_$(iter).csv"
    return CSV.write(outfile, xx)
end
Output during run:
The text was updated successfully, but these errors were encountered:
Hello,
I posted this question on the Julia Discourse under the title "Extracting each iteration result in KissABC.jl". I tried editing the `smc` source code to get the desired feature, and it seems to work.
Below I have shared my edited code for `smc`. Please have a look and let me know what you think. This feature could be really beneficial for other people working with the KissABC.jl package, as intermediate results help you judge how a run is going before it finishes.
This code logs information for each iteration, as shown in the image below, and also saves a CSV file of all parameter values per iteration, which can be used later to plot trajectory fits, contour plots, or density plots to check the status of your run :)
Julia version: 1.10
KissABC version: 3.0.1
Output during run:
The text was updated successfully, but these errors were encountered: