From 56ca45b5cd0ead062192d4cd69d919679d83e6a9 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 24 May 2024 12:38:20 +0100 Subject: [PATCH] update factorialEstimand generation --- src/counterfactual_mean_based/estimands.jl | 211 +++++++++++--------- src/utils.jl | 11 +- test/counterfactual_mean_based/estimands.jl | 36 ++-- test/utils.jl | 4 +- 4 files changed, 147 insertions(+), 115 deletions(-) diff --git a/src/counterfactual_mean_based/estimands.jl b/src/counterfactual_mean_based/estimands.jl index 884cd2d..86c9804 100644 --- a/src/counterfactual_mean_based/estimands.jl +++ b/src/counterfactual_mean_based/estimands.jl @@ -245,22 +245,88 @@ unique_treatment_values(dataset, colnames) = """ Generated from transitive treatment switches to create independent estimands. """ -get_treatment_settings(::Union{typeof(ATE), typeof(IATE)}, treatments_unique_values) = - [collect(zip(vals[1:end-1], vals[2:end])) for vals in values(treatments_unique_values)] +get_treatment_settings(::Union{typeof(ATE), typeof(IATE)}, treatments_unique_values::NamedTuple{names}) where names = + NamedTuple{names}([collect(zip(vals[1:end-1], vals[2:end])) for vals in values(treatments_unique_values)]) -get_treatment_settings(::typeof(CM), treatments_unique_values) = - values(treatments_unique_values) +get_treatment_settings(::typeof(CM), treatments_unique_values) = treatments_unique_values get_treatment_setting(combo::Tuple{Vararg{Tuple}}) = [NamedTuple{(:control, :case)}(treatment_control_case) for treatment_control_case ∈ combo] get_treatment_setting(combo) = collect(combo) -FACTORIAL_DOCSTRING = """ -The components of this estimand are generated from the treatment variables contrasts. +""" +If there is no dataset and the treatments_levels are a NamedTuple, then they are assumed correct. 
+""" +make_or_check_treatment_levels(treatments_levels::NamedTuple, dataset::Nothing) = treatments_levels + +""" +If no dataset is provided, then a NamedTuple specifying treatment levels is expected +""" +make_or_check_treatment_levels(treatments, dataset::Nothing) = + throw(ArgumentError("No dataset from which to infer treatment levels was provided. Either provide a `dataset` or a NamedTuple `treatments` e.g. (T=[0, 1, 2],)")) + +""" +If a list of treatments is provided as well as a dataset then the treatment_levels are inferred from it. +""" +make_or_check_treatment_levels(treatments, dataset) = unique_treatment_values(dataset, treatments) + +""" +If a NamedTuple of treatments_levels is provided as well as a dataset then the treatment_levels are checked from the dataset. +""" +function make_or_check_treatment_levels(treatments_levels::NamedTuple, dataset) + for (treatment, treatment_levels) in zip(keys(treatments_levels), values(treatments_levels)) + dataset_treatment_levels = Set(skipmissing(Tables.getcolumn(dataset, treatment))) + missing_levels = setdiff(treatment_levels, dataset_treatment_levels) + length(missing_levels) == 0 || + throw(ArgumentError(string("Not all levels provided for treatment ", treatment, " were found in the dataset: ", missing_levels))) + end + return treatments_levels +end + +function _factorialEstimand( + constructor, + treatments_settings::NamedTuple{names}, outcome; + confounders=nothing, + outcome_extra_covariates=nothing, + freq_table=nothing, + positivity_constraint=nothing, + verbosity=1 + ) where names + components = [] + for combo ∈ Iterators.product(values(treatments_settings)...) 
+ Ψ = constructor( + outcome=outcome, + treatment_values=NamedTuple{names}(get_treatment_setting(combo)), + treatment_confounders = confounders, + outcome_extra_covariates=outcome_extra_covariates + ) + if satisfies_positivity(Ψ, freq_table; positivity_constraint=positivity_constraint) + push!(components, Ψ) + else + verbosity > 0 && @warn("Sub estimand", Ψ, " did not pass the positivity constraint, skipped.") + end + end + return ComposedEstimand(joint_estimand, Tuple(components)) +end + +""" + factorialEstimand( + constructor::Union{typeof(CM), typeof(ATE), typeof(IATE)}, + treatments, outcome; + confounders=nothing, + dataset=nothing, + outcome_extra_covariates=(), + positivity_constraint=nothing, + freq_table=nothing, + verbosity=1 + ) + +Generates a factorial `ComposedEstimand` with components of type `constructor` (CM, ATE, IATE). + +For the ATE and the IATE, the generated components are restricted to the Cartesian Product of single treatment levels transitions. For example, consider two treatment variables T₁ and T₂ each taking three possible values (0, 1, 2). -For each treatment variable, the marginal transitive contrasts are defined by (0 → 1, 1 → 2). Note that (0 → 2) or (1 → 0) need not -be considered because they are linearly dependent on the other contrasts. Then, the cartesian product of treatment contrasts is taken, -resulting in a 2 x 2 = 4 dimensional joint estimand: +For each treatment variable, the single treatment levels transitions are defined by (0 → 1, 1 → 2). +Then, the Cartesian Product of these transitions is taken, resulting in a 2 x 2 = 4 dimensional joint estimand: - (T₁: 0 → 1, T₂: 0 → 1) - (T₁: 0 → 1, T₂: 1 → 2) @@ -273,111 +339,63 @@ A `ComposedEstimand` with causal or statistical components. # Args -- `treatments_levels`: A NamedTuple providing the unique levels each treatment variable can take. +- `constructor`: CM, ATE or IATE. +- `treatments`: A NamedTuple of treatment levels (e.g. 
`(T=(0, 1, 2),)`) or a treatment iterator, in which case a dataset must be provided to infer the levels from it. - `outcome`: The outcome variable. -- `confounders=nothing`: The generated components will inherit these confounding variables. -If `nothing`, causal estimands are generated. +- `confounders=nothing`: The generated components will inherit these confounding variables. If `nothing`, causal estimands are generated. - `outcome_extra_covariates=()`: The generated components will inherit these `outcome_extra_covariates`. -- `positivity_constraint=nothing`: Only components that pass the positivity constraint are added to the `ComposedEstimand` +- `dataset`: An optional dataset to enforce a positivity constraint and infer treatment levels. +- `positivity_constraint=nothing`: Only components that pass the positivity constraint are added to the `ComposedEstimand`. A `dataset` must then be provided. +- `freq_table`: This is only to be used by `factorialEstimands` to avoid unnecessary computations. - `verbosity=1`: Verbosity level. -""" - -""" - factorialEstimand( - constructor::Union{typeof(ATE), typeof(IATE)}, - treatments_levels::NamedTuple{names}, - outcome; - confounders=nothing, - outcome_extra_covariates=(), - freq_table=nothing, - positivity_constraint=nothing, - verbosity=1 - ) where names - -Generate a `ComposedEstimand` from `treatments_levels`. 
$FACTORIAL_DOCSTRING # Examples: -Average Treatment Effects: +- An Average Treatment Effect with causal components: ```@example factorialEstimand(ATE, (T₁ = (0, 1), T₂=(0, 1, 2)), :Y₁) ``` +- An Average Interaction Effect with statistical components: + ```@example -factorial(ATE, (T₁ = (0, 1, 2), T₂=(0, 1, 2)), :Y₁, confounders=[:W₁, :W₂]) +factorialEstimand(IATE, (T₁ = (0, 1, 2), T₂=(0, 1, 2)), :Y₁, confounders=[:W₁, :W₂]) ``` +- With a dataset, the treatment levels can be inferred and a positivity constraint enforced: Interactions: ```@example -factorialEstimand(IATE, (T₁ = (0, 1), T₂=(0, 1, 2)), :Y₁) +factorialEstimand(ATE, [:T₁, :T₂], :Y₁, + confounders=[:W₁, :W₂], + dataset=dataset, + positivity_constraint=0.1 +) ``` -```@example -factorialEstimand(IATE, (T₁ = (0, 1, 2), T₂=(0, 1, 2)), :Y₁, confounders=[:W₁, :W₂]) -""" -function factorialEstimand( - constructor::Union{typeof(CM), typeof(ATE), typeof(IATE)}, - treatments_levels::NamedTuple{names}, - outcome; - confounders=nothing, - outcome_extra_covariates=(), - freq_table=nothing, - positivity_constraint=nothing, - verbosity=1 - ) where names - treatments_settings = get_treatment_settings(constructor, treatments_levels) - components = [] - for combo ∈ Iterators.product(treatments_settings...) 
- Ψ = constructor( - outcome=outcome, - treatment_values=NamedTuple{names}(get_treatment_setting(combo)), - treatment_confounders = confounders, - outcome_extra_covariates=outcome_extra_covariates - ) - if satisfies_positivity(Ψ, freq_table; positivity_constraint=positivity_constraint) - push!(components, Ψ) - else - verbosity > 0 && @warn("Sub estimand", Ψ, " did not pass the positivity constraint, skipped.") - end - end - return ComposedEstimand(joint_estimand, Tuple(components)) -end - -""" -factorialEstimand( - constructor::Union{typeof(ATE), typeof(IATE)}, - dataset, treatments, outcome; - confounders=nothing, - outcome_extra_covariates=(), - positivity_constraint=nothing, - verbosity=1 - ) - -Identifies `treatment_levels` from `dataset` and construct the -factorialEstimand from it. """ function factorialEstimand( constructor::Union{typeof(CM), typeof(ATE), typeof(IATE)}, - dataset, treatments, outcome; - confounders=nothing, + treatments, outcome; + confounders=nothing, + dataset=nothing, outcome_extra_covariates=(), positivity_constraint=nothing, + freq_table=nothing, verbosity=1 ) - treatments_levels = unique_treatment_values(dataset, treatments) - freq_table = positivity_constraint !== nothing ? frequency_table(dataset, keys(treatments_levels)) : nothing - return factorialEstimand( - constructor, - treatments_levels, - outcome; - confounders=confounders, - outcome_extra_covariates=outcome_extra_covariates, + treatments_levels = make_or_check_treatment_levels(treatments, dataset) + freq_table = freq_table !== nothing ? 
freq_table : get_frequency_table(positivity_constraint, dataset, keys(treatments_levels)) + treatments_settings = get_treatment_settings(constructor, treatments_levels) + return _factorialEstimand( + constructor, treatments_settings, outcome; + confounders=confounders, + outcome_extra_covariates=outcome_extra_covariates, freq_table=freq_table, positivity_constraint=positivity_constraint, verbosity=verbosity - ) + ) end """ @@ -390,31 +408,30 @@ factorialEstimands( verbosity=1 ) -Identifies `treatment_levels` from `dataset` and a factorialEstimand -for each outcome in `outcomes`. +Generates a `ComposedEstimand` for each outcome in `outcomes`. See `factorialEstimand`. """ function factorialEstimands( constructor::Union{typeof(CM), typeof(ATE), typeof(IATE)}, - dataset, treatments, outcomes; + treatments, outcomes; + dataset=nothing, confounders=nothing, outcome_extra_covariates=(), positivity_constraint=nothing, verbosity=1 ) + treatments_levels = make_or_check_treatment_levels(treatments, dataset) + freq_table = get_frequency_table(positivity_constraint, dataset, keys(treatments_levels)) + treatments_settings = get_treatment_settings(constructor, treatments_levels) estimands = [] - treatments_levels = unique_treatment_values(dataset, treatments) - freq_table = positivity_constraint !== nothing ? 
frequency_table(dataset, keys(treatments_levels)) : nothing for outcome in outcomes - Ψ = factorialEstimand( - constructor, - treatments_levels, - outcome; - confounders=confounders, - outcome_extra_covariates=outcome_extra_covariates, + Ψ = _factorialEstimand( + constructor, treatments_settings, outcome; + confounders=confounders, + outcome_extra_covariates=outcome_extra_covariates, freq_table=freq_table, positivity_constraint=positivity_constraint, verbosity=verbosity-1 - ) + ) if length(Ψ.args) > 0 push!(estimands, Ψ) else diff --git a/src/utils.jl b/src/utils.jl index c23d9aa..339f4a2 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -122,7 +122,16 @@ end satisfies_positivity(Ψ, freq_table::Nothing; positivity_constraint=nothing) = true -function frequency_table(dataset, colnames) +get_frequency_table(positivity_constraint::Nothing, dataset::Nothing, colnames) = nothing + +get_frequency_table(positivity_constraint::Nothing, dataset, colnames) = nothing + +get_frequency_table(positivity_constraint, dataset::Nothing, colnames) = + throw(ArgumentError("A dataset should be provided to enforce a positivity constraint.")) + +get_frequency_table(positivity_constraint, dataset, colnames) = get_frequency_table(dataset, colnames) + +function get_frequency_table(dataset, colnames) iterator = zip((Tables.getcolumn(dataset, colname) for colname in sort(collect(colnames)))...) 
counts = groupcount(x -> x, iterator) for key in keys(counts) diff --git a/test/counterfactual_mean_based/estimands.jl b/test/counterfactual_mean_based/estimands.jl index bbabe48..9422051 100644 --- a/test/counterfactual_mean_based/estimands.jl +++ b/test/counterfactual_mean_based/estimands.jl @@ -203,13 +203,13 @@ end @testset "Test control_case_settings" begin treatments_unique_values = (T₁=(1, 0, 2),) - @test TMLE.get_treatment_settings(ATE, treatments_unique_values) == [[(1, 0), (0, 2)]] - @test TMLE.get_treatment_settings(IATE, treatments_unique_values) == [[(1, 0), (0, 2)]] - @test TMLE.get_treatment_settings(CM, treatments_unique_values) == ((1, 0, 2), ) + @test TMLE.get_treatment_settings(ATE, treatments_unique_values) == (T₁=[(1, 0), (0, 2)],) + @test TMLE.get_treatment_settings(IATE, treatments_unique_values) == (T₁=[(1, 0), (0, 2)],) + @test TMLE.get_treatment_settings(CM, treatments_unique_values) == (T₁=(1, 0, 2), ) treatments_unique_values = (T₁=(1, 0, 2), T₂=["AC", "CC"]) - @test TMLE.get_treatment_settings(ATE, treatments_unique_values) == [[(1, 0), (0, 2)], [("AC", "CC")]] - @test TMLE.get_treatment_settings(IATE, treatments_unique_values) == [[(1, 0), (0, 2)], [("AC", "CC")]] - @test TMLE.get_treatment_settings(CM, treatments_unique_values) == ((1, 0, 2), ["AC", "CC"]) + @test TMLE.get_treatment_settings(ATE, treatments_unique_values) == (T₁=[(1, 0), (0, 2)], T₂=[("AC", "CC")]) + @test TMLE.get_treatment_settings(IATE, treatments_unique_values) == (T₁=[(1, 0), (0, 2)], T₂=[("AC", "CC")]) + @test TMLE.get_treatment_settings(CM, treatments_unique_values) == (T₁=(1, 0, 2), T₂=["AC", "CC"]) end @testset "Test unique_treatment_values" begin @@ -234,7 +234,7 @@ end Y₁ = [1, 2, 3, 4], Y₂ = [1, 2, 3, 4] ) - composedCM = factorialEstimand(CM, dataset, [:T₁], :Y₁, verbosity=0) + composedCM = factorialEstimand(CM, [:T₁], :Y₁, dataset=dataset, verbosity=0) @test composedCM == TMLE.ComposedEstimand( TMLE.joint_estimand, ( @@ -244,7 +244,7 @@ end ) ) - 
composedCM = factorialEstimand(CM, dataset, [:T₁, :T₂], :Y₁, verbosity=0) + composedCM = factorialEstimand(CM, [:T₁, :T₂], :Y₁, dataset=dataset, verbosity=0) @test composedCM == TMLE.ComposedEstimand( TMLE.joint_estimand, ( @@ -272,7 +272,7 @@ end Y₂ = [1, 2, 3, 4] ) # No confounders, 1 treatment, no extra covariate: 3 causal ATEs - composedATE = factorialEstimand(ATE, dataset, [:T₁], :Y₁, verbosity=0) + composedATE = factorialEstimand(ATE, [:T₁], :Y₁, dataset=dataset, verbosity=0) @test composedATE == ComposedEstimand( TMLE.joint_estimand, ( @@ -281,7 +281,8 @@ end ) ) # 2 treatments - composedATE = factorialEstimand(ATE, dataset, [:T₁, :T₂], :Y₁; + composedATE = factorialEstimand(ATE, [:T₁, :T₂], :Y₁; + dataset=dataset, confounders=[:W₁, :W₂], outcome_extra_covariates=[:C], verbosity=0 @@ -317,7 +318,8 @@ end ) ) # positivity constraint - composedATE = factorialEstimand(ATE, dataset, [:T₁, :T₂], :Y₁; + composedATE = factorialEstimand(ATE, [:T₁, :T₂], :Y₁; + dataset=dataset, confounders=[:W₁, :W₂], outcome_extra_covariates=[:C], positivity_constraint=0.1, @@ -337,7 +339,8 @@ end Y₂ = [1, 2, 3, 4] ) # From dataset - composedIATE = factorialEstimand(IATE, dataset, [:T₁, :T₂], :Y₁, + composedIATE = factorialEstimand(IATE, [:T₁, :T₂], :Y₁; + dataset=dataset, confounders=[:W₁], outcome_extra_covariates=[:C], verbosity=0 @@ -396,7 +399,8 @@ end ) # positivity constraint - composedIATE = factorialEstimand(IATE, dataset, [:T₁, :T₂], :Y₁, + composedIATE = factorialEstimand(IATE, [:T₁, :T₂], :Y₁; + dataset=dataset, confounders=[:W₁], outcome_extra_covariates=[:C], positivity_constraint=0.1, @@ -415,7 +419,8 @@ end Y₁ = [1, 2, 3, 4], Y₂ = [1, 2, 3, 4] ) - factorial_ates = factorialEstimands(ATE, dataset, [:T₁, :T₂], [:Y₁, :Y₂], + factorial_ates = factorialEstimands(ATE, [:T₁, :T₂], [:Y₁, :Y₂], + dataset=dataset, confounders=[:W₁, :W₂], outcome_extra_covariates=[:C], positivity_constraint=0.1, @@ -423,7 +428,8 @@ end ) @test length(factorial_ates) == 2 # Nothing passes the 
threshold - factorial_ates = factorialEstimands(ATE, dataset, [:T₁, :T₂], [:Y₁, :Y₂], + factorial_ates = factorialEstimands(ATE, [:T₁, :T₂], [:Y₁, :Y₂], + dataset=dataset, confounders=[:W₁, :W₂], outcome_extra_covariates=[:C], positivity_constraint=0.3, diff --git a/test/utils.jl b/test/utils.jl index db4caa2..151c792 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -57,7 +57,7 @@ end B = ["AC", "CC", "AA", "AA", "AA", "AA", "AA", "AA"] ) # One variable - frequency_table = TMLE.frequency_table(dataset, [:A]) + frequency_table = TMLE.get_frequency_table(dataset, [:A]) @test frequency_table[(0,)] == 0.25 @test frequency_table[(1,)] == 0.5 @test frequency_table[(2,)] == 0.25 @@ -82,7 +82,7 @@ end # Two variables ## Treatments are sorted: [:B, :A] -> [:A, :B] - frequency_table = TMLE.frequency_table(dataset, [:B, :A]) + frequency_table = TMLE.get_frequency_table(dataset, [:B, :A]) @test frequency_table[(1, "CC")] == 0.125 @test frequency_table[(1, "AA")] == 0.25 @test frequency_table[(0, "AA")] == 0.25