Skip to content

Commit

Permalink
add cache accessors
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierlabayle committed Feb 11, 2025
1 parent d63f3b5 commit a667f7f
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 21 deletions.
6 changes: 1 addition & 5 deletions docs/src/user_guide/estimation.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,7 @@ result₁
nothing # hide
```

The `cache` (see below) contains estimates for the nuisance functions that were necessary to estimate the ATE. For instance, we can see what is the value of ``\epsilon`` corresponding to the clever covariate.

```@example estimation
ϵ = last_fluctuation_epsilon(cache)
```
The `cache` (see below) contains estimates for the nuisance functions that were necessary to estimate the ATE.

The `result₁` structure corresponds to the estimation result and will display the result of a T-Test including:

Expand Down
4 changes: 3 additions & 1 deletion src/TMLE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ export default_models, TreatmentTransformer, with_encoder, encoder
export BackdoorAdjustment, identify
export Configuration
export brute_force_ordering, groups_ordering

export gradients, epsilons, estimates

# #############################################################################
# INCLUDES
# #############################################################################
Expand All @@ -63,5 +64,6 @@ include("counterfactual_mean_based/gradient.jl")

include("configuration.jl")
include("testing.jl")
include("cache.jl")

end
10 changes: 10 additions & 0 deletions src/cache.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@

MLJBase.report(factors::MLCMRelevantFactors) = MLJBase.report(factors.outcome_mean.machine)

MLJBase.report(cache) = MLJBase.report(cache[:targeted_factors])

gradients(cache) = MLJBase.report(cache[:targeted_factors]).gradients

estimates(cache) = MLJBase.report(cache[:targeted_factors]).estimates

epsilons(cache) = MLJBase.report(cache[:targeted_factors]).epsilons
18 changes: 5 additions & 13 deletions src/counterfactual_mean_based/estimators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,21 +143,11 @@ function (estimator::TargetedCMRelevantFactorsEstimator)(estimand, dataset; cach
# Build estimate
estimate = MLCMRelevantFactors(estimand, fluctuated_outcome_mean, fluctuated_propensity_score)
# Update cache
cache[:last_fluctuation] = estimate
cache[:targeted_factors] = estimate

return estimate
end

gradients(factors::MLCMRelevantFactors) = MLJBase.report(factors.outcome_mean.machine).gradients

estimates(factors::MLCMRelevantFactors) = MLJBase.report(factors.outcome_mean.machine).estimates

epsilons(factors::MLCMRelevantFactors) = MLJBase.report(factors.outcome_mean.machine).epsilons

gradient(factors::MLCMRelevantFactors) = last(gradients(factors))

Distributions.estimate(factors::MLCMRelevantFactors) = last(estimates(factors))

#####################################################################
### TMLE ###
#####################################################################
Expand Down Expand Up @@ -235,8 +225,10 @@ function (tmle::TMLEE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict()
machine_cache=tmle.machine_cache
)
# Estimation results after TMLE
IC = gradient(targeted_factors_estimate)
Ψ̂ = estimate(targeted_factors_estimate)
estimation_report = report(targeted_factors_estimate)

IC = last(estimation_report.gradients)
Ψ̂ = last(estimation_report.estimates)
σ̂ = std(IC)
n = size(IC, 1)
verbosity >= 1 && @info "Done."
Expand Down
8 changes: 7 additions & 1 deletion test/counterfactual_mean_based/3points_interactions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ end
:T₃ => LogisticClassifier(lambda=0)
)

tmle = TMLEE(models=models, machine_cache=true)
tmle = TMLEE(models=models, machine_cache=true, max_iter=3, tol=0)
result, cache = tmle(Ψ, dataset, verbosity=0);
test_coverage(result, Ψ₀)
test_fluct_decreases_risk(cache)
Expand All @@ -54,6 +54,12 @@ end
test_coverage(result, Ψ₀)
test_mean_inf_curve_almost_zero(result; atol=1e-10)

# CHecking cache accessors
@test length(gradients(cache)) == 3
@test length(estimates(cache)) == 3
@test length(epsilons(cache)) == 3
@test report(cache) isa NamedTuple

end

end
Expand Down
2 changes: 1 addition & 1 deletion test/helper_fns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ It seems that sometimes this is not entirely true in practice, so the test actua
increase risk more than tol
"""
function test_fluct_decreases_risk(cache; atol=1e-6)
fluctuated_mean_machine = cache[:last_fluctuation].outcome_mean.machine
fluctuated_mean_machine = cache[:targeted_factors].outcome_mean.machine
initial_mean_machine = fluctuated_mean_machine.model.initial_factors.outcome_mean.machine
y = fluctuated_mean_machine.data[2]
initial_risk = risk(MLJBase.predict(initial_mean_machine), y)
Expand Down

0 comments on commit a667f7f

Please sign in to comment.