We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Dice.jl/examples/qc/benchmarks/benchmarks.jl
Line 89 in 1c0c443
# NOTE(review): the next two lines are the tail of a definition that begins before this
# excerpt; they are kept unchanged.
    println_flush(rs.io)
end

"""
    SimpleLossMgr(loss::ADNode)

Loss manager that holds a single, fixed loss.

The constructor builds a `Dice.LogPrExpander` over the boolean roots of `loss` and stores
the result of `Dice.expand_logprs`, so the expansion is performed once at construction.
"""
struct SimpleLossMgr <: LossMgr
    loss::ADNode
    function SimpleLossMgr(loss::ADNode)
        # TODO: share an expander?
        l = Dice.LogPrExpander(WMC(BDDCompiler(Dice.bool_roots([loss]))))
        loss = Dice.expand_logprs(l, loss)
        new(loss)
    end
end

"""
    produce_loss(rs::RunState, m::SimpleLossMgr, epoch::Integer)

Return the loss stored in `m`. `rs` and `epoch` are not used by this method.
"""
produce_loss(rs::RunState, m::SimpleLossMgr, epoch::Integer) = m.loss

"""
    SamplingEntropy{T} <: LossConfig{T}

Configuration for the sampling-based loss produced by the `SamplingEntropyLossMgr`
method of `produce_loss`.

# Fields
- `resampling_frequency`: a fresh batch of samples is drawn on every epoch for which
  `(epoch - 1) % resampling_frequency == 0`.
- `samples_per_batch`: number of samples drawn per batch; also the divisor applied to the
  summed loss.
"""
struct SamplingEntropy{T} <: LossConfig{T}
    resampling_frequency::Integer
    samples_per_batch::Integer
end

"""
    SamplingEntropyLossMgr(p, val, consider, ignore)

Mutable loss manager that caches the most recently built loss.

# Fields
- `p`: the `SamplingEntropy` configuration.
- `val`: the `Dist` that is sampled and compared against each sample.
- `consider`, `ignore`: predicates called on each sample. `produce_loss` asserts that
  exactly one of them holds for every sample.
- `current_loss`: the cached loss; `nothing` until the first batch has been processed.
"""
mutable struct SamplingEntropyLossMgr <: LossMgr
    p::SamplingEntropy
    val::Dist
    consider
    ignore
    current_loss::Union{Nothing,ADNode}
    # The cache starts out empty and is filled by `produce_loss`.
    SamplingEntropyLossMgr(p, val, consider, ignore) = new(p, val, consider, ignore, nothing)
end

"""
    produce_loss(rs::RunState, m::SamplingEntropyLossMgr, epoch::Integer)

Return the loss for `epoch`, rebuilding it from a fresh batch of samples whenever
`(epoch - 1) % m.p.resampling_frequency == 0` and reusing `m.current_loss` otherwise.

When rebuilding, `m.p.samples_per_batch` samples of `m.val` are drawn using `rs.rng`, the
loss is the sum of `LogPr(prob_equals(m.val, sample))` over the samples accepted by
`m.consider`, expanded with a `Dice.LogPrExpander` and divided by `m.p.samples_per_batch`.
Progress and sampling time are written to `rs.io`.
"""
function produce_loss(rs::RunState, m::SamplingEntropyLossMgr, epoch::Integer)
    if (epoch - 1) % m.p.resampling_frequency == 0
        println_flush(rs.io, "Sampling...")
        time_sample = @elapsed samples = with_concrete_ad_flips(rs.var_vals, m.val) do
            [sample_as_dist(rs.rng, Valuation(), m.val) for _ in 1:m.p.samples_per_batch]
        end
        println(rs.io, "  $(time_sample) seconds")

        # NOTE(review): if no sample satisfies `m.consider`, `sum` over an empty generator
        # will throw -- confirm whether that situation can occur in practice.
        loss = sum(
            LogPr(prob_equals(m.val, sample))
            for sample in samples
            if m.consider(sample)
        )

        # Every sample must be either considered or ignored, but not both. This has to be
        # `⊻` (xor): on `Bool`s, `a ^ b` is exponentiation and evaluates to `a | !b`, which
        # accepts "both" and "neither" and rejects a sample that is only ignored.
        for sample in samples
            @assert m.consider(sample) ⊻ m.ignore(sample)
        end

        l = Dice.LogPrExpander(WMC(BDDCompiler(Dice.bool_roots([loss]))))
        loss = Dice.expand_logprs(l, loss) / m.p.samples_per_batch
        m.current_loss = loss
    end

    @assert !isnothing(m.current_loss)
    m.current_loss
end

"""
    save_learning_curve(out_dir, learning_curve, name)

Write `learning_curve` to `<name>.csv` inside `out_dir` as tab-separated
`epoch<TAB>value` lines, with epochs numbered from 0, then plot the curve and save it as
`<name>.svg` in the same directory.
"""
function save_learning_curve(out_dir, learning_curve, name)
    open(joinpath(out_dir, "$(name).csv"), "w") do file
        xs = 0:length(learning_curve)-1
        for (epoch, logpr) in zip(xs, learning_curve)
            println(file, "$(epoch)\t$(logpr)")
        end
        # The plot is produced while the CSV handle is still open, as in the original.
        plot(xs, learning_curve)
        savefig(joinpath(out_dir, "$(name).svg"))
    end
end

##################################
The text was updated successfully, but these errors were encountered:
No branches or pull requests
Dice.jl/examples/qc/benchmarks/benchmarks.jl
Line 89 in 1c0c443
The text was updated successfully, but these errors were encountered: