Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

share an expander? #182

Open
github-actions bot opened this issue Mar 13, 2024 · 0 comments
Open

share an expander? #182

github-actions bot opened this issue Mar 13, 2024 · 0 comments
Labels

Comments

@github-actions
Copy link

# TODO: share an expander?

    println_flush(rs.io)
end

"""
    SimpleLossMgr(loss::ADNode)

Loss manager that holds one pre-expanded loss expression.

The constructor rewrites `loss` with `Dice.expand_logprs`. It uses a
`Dice.LogPrExpander` built from a `WMC` over a `BDDCompiler` of the loss's boolean
roots (`Dice.bool_roots`). Only the rewritten expression is stored.
"""
struct SimpleLossMgr <: LossMgr
    loss::ADNode

    function SimpleLossMgr(loss::ADNode)
        # TODO: share an expander? Each manager currently builds its own
        # compiler/WMC/expander chain.
        roots = Dice.bool_roots([loss])
        expander = Dice.LogPrExpander(WMC(BDDCompiler(roots)))
        expanded = Dice.expand_logprs(expander, loss)
        return new(expanded)
    end
end
"""
    produce_loss(rs::RunState, m::SimpleLossMgr, epoch::Integer)

Return the loss stored in `m`. `rs` and `epoch` are not used, so every epoch
gets the same loss.
"""
function produce_loss(rs::RunState, m::SimpleLossMgr, epoch::Integer)
    return m.loss
end

"""
    SamplingEntropy{T}(resampling_frequency, samples_per_batch)

Loss configuration for a loss estimated from batches of samples.

# Arguments
- `resampling_frequency`: a new batch is drawn on epochs where
  `(epoch - 1) % resampling_frequency == 0`.
- `samples_per_batch`: number of samples drawn per batch.

# Throws
- `ArgumentError` if either argument is not positive. A zero
  `resampling_frequency` would otherwise surface later as a `DivideError` from
  `%`. A zero `samples_per_batch` would produce an empty batch and a division by 0.
"""
struct SamplingEntropy{T} <: LossConfig{T}
    # Concrete field types (instead of abstract `Integer`) avoid boxing; the
    # arguments are converted with `convert(Int, x)` by `new`.
    resampling_frequency::Int
    samples_per_batch::Int

    function SamplingEntropy{T}(resampling_frequency, samples_per_batch) where {T}
        resampling_frequency > 0 || throw(ArgumentError(
            "resampling_frequency must be positive, got $(resampling_frequency)"))
        samples_per_batch > 0 || throw(ArgumentError(
            "samples_per_batch must be positive, got $(samples_per_batch)"))
        return new{T}(resampling_frequency, samples_per_batch)
    end
end

"""
    SamplingEntropyLossMgr(p, val, consider, ignore)

Mutable loss manager for a `SamplingEntropy` configuration.

# Fields
- `p`: the `SamplingEntropy` configuration (resampling frequency, batch size).
- `val`: the `Dist` that samples are drawn from.
- `consider`: callable applied to each drawn sample. Samples for which it
  returns `true` contribute to the loss.
- `ignore`: callable applied to each drawn sample, used to check that every
  sample is either considered or ignored.
- `current_loss`: the most recently built loss, or `nothing` before one exists.
"""
mutable struct SamplingEntropyLossMgr <: LossMgr
    p::SamplingEntropy
    val::Dist
    consider
    ignore
    current_loss::Union{Nothing,ADNode}

    function SamplingEntropyLossMgr(p, val, consider, ignore)
        # No loss has been built yet; it is filled in on the first resampling.
        return new(p, val, consider, ignore, nothing)
    end
end
"""
    produce_loss(rs::RunState, m::SamplingEntropyLossMgr, epoch::Integer)

Return the sampling-based loss for `epoch`, rebuilding it when a resample is due.

A resample is due when `(epoch - 1) % m.p.resampling_frequency == 0`, i.e.
epochs 1, 1 + f, 1 + 2f, … for positive epochs. On those epochs this function:

1. Draws `m.p.samples_per_batch` samples of `m.val` inside
   `with_concrete_ad_flips(rs.var_vals, m.val)`, using `rs.rng`.
2. Sums `LogPr(prob_equals(m.val, sample))` over the samples accepted by
   `m.consider`.
3. Expands that sum with a fresh `Dice.LogPrExpander`.
4. Divides by the full batch size (not the number of considered samples).
5. Caches the result in `m.current_loss`.

On other epochs the cached loss is returned unchanged.

Every sample is expected to satisfy exactly one of `m.consider` / `m.ignore`;
this is checked with `@assert`.

# Throws
- `ArgumentError` if a freshly drawn batch contains no considered sample.
"""
function produce_loss(rs::RunState, m::SamplingEntropyLossMgr, epoch::Integer)
    if (epoch - 1) % m.p.resampling_frequency == 0
        println_flush(rs.io, "Sampling...")
        time_sample = @elapsed samples = with_concrete_ad_flips(rs.var_vals, m.val) do
            [sample_as_dist(rs.rng, Valuation(), m.val) for _ in 1:m.p.samples_per_batch]
        end
        println_flush(rs.io, "  $(time_sample) seconds")

        # Each sample must be either considered or ignored, never both or
        # neither. Use `xor`: on `Bool`s `^` is exponentiation (`x | !y`), which
        # rejected exactly the legitimate "ignored, not considered" case.
        for sample in samples
            @assert xor(m.consider(sample), m.ignore(sample))
        end

        considered = filter(m.consider, samples)
        # `sum` over an empty collection fails with an opaque reduction error;
        # report the actual cause instead.
        isempty(considered) && throw(ArgumentError(
            "no sample in the batch of $(length(samples)) was accepted by `consider`"))

        loss = sum(LogPr(prob_equals(m.val, sample)) for sample in considered)
        expander = Dice.LogPrExpander(WMC(BDDCompiler(Dice.bool_roots([loss]))))
        m.current_loss = Dice.expand_logprs(expander, loss) / m.p.samples_per_batch
    end

    @assert !isnothing(m.current_loss)
    return m.current_loss
end

"""
    save_learning_curve(out_dir, learning_curve, name)

Write `learning_curve` to `joinpath(out_dir, "\$(name).csv")` and plot it to
`joinpath(out_dir, "\$(name).svg")`.

Each line of the `.csv` file is `epoch<TAB>value`, where `epoch` is the 0-based
index of the value. Note that the file is tab-separated despite its extension.
The plot uses the same 0-based epochs on the x-axis.

Returns whatever `savefig` returns.
"""
function save_learning_curve(out_dir, learning_curve, name)
    xs = 0:(length(learning_curve) - 1)
    open(joinpath(out_dir, "$(name).csv"), "w") do file
        for (epoch, logpr) in zip(xs, learning_curve)
            println(file, "$(epoch)\t$(logpr)")
        end
    end
    # Plot only after the CSV handle is closed. Save the plot object explicitly
    # rather than relying on the global "current plot".
    fig = plot(xs, learning_curve)
    return savefig(fig, joinpath(out_dir, "$(name).svg"))
end

##################################
@github-actions github-actions bot added the todo label Mar 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

No branches or pull requests

0 participants