make this faster) #219

Open
github-actions bot opened this issue Nov 1, 2024 · 0 comments

# Sample some tree until it's valid (TODO: make this faster)
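The loop in the script below is plain rejection sampling. Each draw is kept only if it typechecks, so the number of draws is random: with acceptance probability `p`, collecting `n` samples takes `n / p` draws on average (a sum of `n` geometric waits, each with mean `1 / p`). The recorded run needed 30 draws for 10 samples, an empirical acceptance rate of about 1/3. The sketch isolates that pattern outside Dice. `rejection_sample`, `draw`, and `accept` are hypothetical names standing in for `sample_as_dist` and `wellTyped`, and the toy integer generator is an assumption chosen only so the example runs.

```julia
using Random

# Hypothetical helper: `draw` stands in for `sample_as_dist(rs.rng, a, g)`
# and `accept` for `wellTyped`; neither name comes from Dice.
"""
    rejection_sample(draw, accept, rng, n) -> (samples, attempts)

Call `draw(rng)` until `n` results satisfy `accept`, returning the
accepted results and the total number of draws made.
"""
function rejection_sample(draw, accept, rng::AbstractRNG, n::Integer)
    samples = Any[]
    attempts = 0
    while length(samples) < n
        attempts += 1
        s = draw(rng)
        accept(s) && push!(samples, s)
    end
    return samples, attempts
end

# Toy example: integers in 1:9, accepting only multiples of 3 (p = 1/3),
# so about 3n draws are expected for n accepted samples.
rng = MersenneTwister(0)
samples, attempts = rejection_sample(r -> rand(r, 1:9), x -> x % 3 == 0, rng, 10)
println("accepted $(length(samples)) of $attempts draws")
```

Since the expected cost is (cost per draw) × `n / p`, the TODO can be attacked either by making each draw and typecheck cheaper or by raising the acceptance rate; the sketch does neither and only makes the cost model explicit.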

using Revise
using Dice
using BenchmarkTools
using ProfileView
using Random  # MersenneTwister

include("benchmarks.jl")

generation_params = LangSiblingDerivedGenerator{STLC}(
    root_ty=Expr.t,
    ty_sizes=[Expr.t=>5, Typ.t=>2],
    stack_size=2,
    intwidth=3,
)

SEED = 0
out_dir = "/tmp"
log_path = "/dev/null"
rs = RunState(
    Valuation(),
    Dict{String,ADNode}(),
    open(log_path, "w"),
    out_dir,
    MersenneTwister(SEED),
    nothing,
    generation_params,
)

generation::Generation = generate(rs, generation_params)

g::Dist = generation.value

# Sample some tree until it's valid (TODO: make this faster)
a = ADComputer(rs.var_vals)

NUM_SAMPLES = 10

# A sampled term is kept only if it typechecks.
function wellTyped(e)
    @assert isdeterministic(e)
    @match typecheck(e) [
        Some(_) -> true,
        None() -> false,
    ]
end

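# Rejection sampling: draw terms until NUM_SAMPLES of them typecheck,
# counting every draw in `retries`.
retries = Ref(0)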
#== @benchmark ==# @time begin
    samples = []
    while length(samples) < NUM_SAMPLES
        retries[] += 1
        s = sample_as_dist(rs.rng, a, g)
        if wellTyped(s)
            push!(samples, s)
        end
    end
end
#  Single result which took 26.426 s (3.00% GC) to evaluate, (7s, 26s, 30s, 40s)
#  with a memory estimate of 388.02 MiB, over 8512429 allocations.
retries[] # 30 

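# Compile one BDD covering every `prob_equals(g, sample)` and wrap it
# so that log-probabilities can be expanded against it.
l = Dice.LogPrExpander(WMC(BDDCompiler([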
    prob_equals(g, sample)
    for sample in samples
])))
# loss sums lpr_eq * (current value of lpr_eq); actual_loss sums lpr_eq.
@time begin
    loss, actual_loss = sum(
        begin
            lpr_eq = Dice.expand_logprs(l, LogPr(prob_equals(g, sample)))
            [lpr_eq * compute(a, lpr_eq), lpr_eq]
        end
        for sample in samples
    )
end
# 5.3s first run, 1.4s rest

length(l.cache) # 331

@benchmark vals, derivs = differentiate(
    rs.var_vals,
    Derivs([loss => 1.])
)
# BenchmarkTools.Trial: 1060 samples with 1 evaluation.
#  Range (min … max):  2.029 ms … 137.030 ms  ┊ GC (min … max): 0.00% … 98.14%
#  Time  (median):     2.879 ms               ┊ GC (median):    0.00%
#  Time  (mean ± σ):   4.377 ms ±   6.119 ms  ┊ GC (mean ± σ):  4.36% ±  4.07%

#   ██▇▆▅▃▃▂▃▁▁▂▂  ▁                                             
#   ██████████████████▅▇▆▄▆▄▆▇▄▆▄▅▁▆▁▅▇▁▄▄▁▁▁▄▁▁▅▄▆▁▄▄▁▁▁▄▁▁▁▅▅ █
#   2.03 ms      Histogram: log(frequency) by time      22.6 ms <

#  Memory estimate: 292.17 KiB, allocs estimate: 8034.

# Count how many nodes Dice.foreach_down visits in the `loss` graph.
ct = Ref(0)
Dice.foreach_down(loss) do _
    ct[] += 1
end
ct[] # 350
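`Dice.foreach_down` reports 350 nodes for `loss`. Whether a traversal counts a shared subterm once or once per path that reaches it changes what such a number means. The generic sketch below contrasts the two. It uses a hypothetical `Node` type and hypothetical helpers (`foreach_unique_postorder`, `count_paths`) rather than Dice's API, and we have not checked which of the two behaviors `foreach_down` actually has.

```julia
# Hypothetical expression-graph node; not Dice's ADNode or Dist.
mutable struct Node
    name::Symbol
    children::Vector{Node}
end

# Call `f` once on every distinct node reachable from `root`,
# children before parents, deduplicating by object identity.
function foreach_unique_postorder(f, root::Node)
    seen = IdDict{Node,Nothing}()
    function visit(n::Node)
        haskey(seen, n) && return
        seen[n] = nothing
        foreach(visit, n.children)
        f(n)
    end
    visit(root)
    return nothing
end

# Naive count: a shared child is counted once per path that reaches it.
count_paths(n::Node) = 1 + sum(count_paths, n.children; init = 0)

# `leaf` and `shared` are each reused twice, so this is a DAG, not a tree.
leaf = Node(:x, Node[])
shared = Node(:add, [leaf, leaf])
root = Node(:mul, [shared, shared])

unique_count = Ref(0)
foreach_unique_postorder(root) do _
    unique_count[] += 1
end

println(unique_count[])     # 3 distinct nodes
println(count_paths(root))  # 7 visits when shared nodes are recounted
```

On graphs with heavy sharing the per-path count can grow exponentially, so a node count like the one above is only comparable across runs if the traversal's deduplication behavior is known.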

github-actions bot added the todo label Nov 1, 2024