Skip to content

Commit

Permalink
update factorialEstimand generation
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierlabayle committed May 24, 2024
1 parent 5d4a8e9 commit 56ca45b
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 115 deletions.
211 changes: 114 additions & 97 deletions src/counterfactual_mean_based/estimands.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,22 +245,88 @@ unique_treatment_values(dataset, colnames) =
"""
Generated from transitive treatment switches to create independent estimands.
"""
function get_treatment_settings(
    ::Union{typeof(ATE), typeof(IATE)},
    treatments_unique_values
    )
    # For each treatment, pair every level with the one that follows it:
    # (l₁, l₂), (l₂, l₃), ...
    return [
        collect(zip(levels[1:end-1], levels[2:end]))
        for levels in values(treatments_unique_values)
    ]
end

function get_treatment_settings(
    ::Union{typeof(ATE), typeof(IATE)},
    treatments_unique_values::NamedTuple{names}
    ) where names
    # Same consecutive-level pairs as the generic method, re-keyed by the
    # treatment names of the input NamedTuple.
    transitions = [
        collect(zip(levels[1:end-1], levels[2:end]))
        for levels in values(treatments_unique_values)
    ]
    return NamedTuple{names}(transitions)
end

"""
    get_treatment_settings(::typeof(CM), treatments_unique_values)

For `CM`, the treatment levels are returned unchanged: no control/case pairs are
built, and the container (and its keys, if any) is preserved.
"""
get_treatment_settings(::typeof(CM), treatments_unique_values) = treatments_unique_values

"""
    get_treatment_setting(combo)

Convert one combination of per-treatment settings into a collection with one
entry per treatment.

- When every entry of `combo` is itself a tuple, each entry is converted to a
  `(control=..., case=...)` NamedTuple.
- Otherwise the entries of `combo` are collected as they are.
"""
get_treatment_setting(combo::Tuple{Vararg{Tuple}}) =
    [NamedTuple{(:control, :case)}(control_case) for control_case in combo]

get_treatment_setting(combo) = collect(combo)

FACTORIAL_DOCSTRING = """
The components of this estimand are generated from the treatment variables contrasts.
"""
If there is no dataset and the treatments_levels are a NamedTuple, then they are assumed correct.
"""
function make_or_check_treatment_levels(treatments_levels::NamedTuple, dataset::Nothing)
    # No dataset to validate against: the provided levels are returned untouched.
    return treatments_levels
end

"""
If no dataset is provided, then a NamedTuple specifying the treatment levels is expected.
"""
function make_or_check_treatment_levels(treatments, dataset::Nothing)
    # With no dataset there is nothing to infer the treatment levels from, so
    # this method always raises an ArgumentError telling the caller what to provide.
    throw(ArgumentError("No dataset from which to infer treatment levels was provided. Either provide a `dataset` or a NamedTuple `treatments` e.g. (T=[0, 1, 2],)"))
end

"""
If a list of treatments is provided as well as a dataset, then the treatment_levels are inferred from it.
"""
function make_or_check_treatment_levels(treatments, dataset)
    # Delegates to `unique_treatment_values`; note the swapped argument order.
    return unique_treatment_values(dataset, treatments)
end

"""
If a NamedTuple of treatments_levels is provided as well as a dataset then the treatment_levels are checked from the dataset.
"""
function make_or_check_treatment_levels(treatments_levels::NamedTuple, dataset)
    for (name, requested_levels) in pairs(treatments_levels)
        # Levels actually observed in the dataset column, missing values ignored.
        observed_levels = Set(skipmissing(Tables.getcolumn(dataset, name)))
        # Requested levels that never occur in the column.
        absent_levels = setdiff(requested_levels, observed_levels)
        if length(absent_levels) != 0
            throw(ArgumentError(string("Not all levels provided for treatment ", name, " were found in the dataset: ", absent_levels)))
        end
    end
    # Every requested level was found: hand the levels back unchanged.
    return treatments_levels
end

"""
    _factorialEstimand(
        constructor, treatments_settings, outcome;
        confounders=nothing,
        outcome_extra_covariates=nothing,
        freq_table=nothing,
        positivity_constraint=nothing,
        verbosity=1
    )

Build a `ComposedEstimand` from already-computed `treatments_settings`.

One component is built with `constructor` for every element of the Cartesian
product of the per-treatment settings. A component is kept only if
`satisfies_positivity` accepts it for the given `freq_table` and
`positivity_constraint`; otherwise it is skipped, with a warning when
`verbosity > 0`.

# Arguments
- `constructor`: Called with the keywords `outcome`, `treatment_values`,
  `treatment_confounders` and `outcome_extra_covariates`.
- `treatments_settings`: NamedTuple whose keys name the treatments and whose
  values are the settings to combine.
- `outcome`: Forwarded to `constructor`.

# Returns
`ComposedEstimand(joint_estimand, Tuple(components))`.
"""
function _factorialEstimand(
    constructor,
    treatments_settings::NamedTuple{names}, outcome;
    confounders=nothing,
    outcome_extra_covariates=nothing,
    freq_table=nothing,
    positivity_constraint=nothing,
    verbosity=1
    ) where names
    components = []
    # One combination = one setting picked for each treatment.
    for combo in Iterators.product(values(treatments_settings)...)
        Ψ = constructor(
            outcome=outcome,
            treatment_values=NamedTuple{names}(get_treatment_setting(combo)),
            treatment_confounders = confounders,
            outcome_extra_covariates=outcome_extra_covariates
        )
        if satisfies_positivity(Ψ, freq_table; positivity_constraint=positivity_constraint)
            push!(components, Ψ)
        else
            verbosity > 0 && @warn("Sub estimand", Ψ, " did not pass the positivity constraint, skipped.")
        end
    end
    return ComposedEstimand(joint_estimand, Tuple(components))
end

"""
factorialEstimand(
constructor::Union{typeof(CM), typeof(ATE), typeof(IATE)},
treatments, outcome;
confounders=nothing,
dataset=nothing,
outcome_extra_covariates=(),
positivity_constraint=nothing,
freq_table=nothing,
verbosity=1
)
Generates a factorial `ComposedEstimand` with components of type `constructor` (CM, ATE, IATE).
For the ATE and the IATE, the generated components are restricted to the Cartesian Product of single treatment levels transitions.
For example, consider two treatment variables T₁ and T₂ each taking three possible values (0, 1, 2).
For each treatment variable, the marginal transitive contrasts are defined by (0 → 1, 1 → 2). Note that (0 → 2) or (1 → 0) need not
be considered because they are linearly dependent on the other contrasts. Then, the cartesian product of treatment contrasts is taken,
resulting in a 2 x 2 = 4 dimensional joint estimand:
For each treatment variable, the single treatment levels transitions are defined by (0 → 1, 1 → 2).
Then, the Cartesian Product of these transitions is taken, resulting in a 2 x 2 = 4 dimensional joint estimand:
- (T₁: 0 → 1, T₂: 0 → 1)
- (T₁: 0 → 1, T₂: 1 → 2)
Expand All @@ -273,111 +339,63 @@ A `ComposedEstimand` with causal or statistical components.
# Args
- `treatments_levels`: A NamedTuple providing the unique levels each treatment variable can take.
- `constructor`: CM, ATE or IATE.
- `treatments`: A NamedTuple of treatment levels (e.g. `(T=(0, 1, 2),)`) or a treatment iterator, then a dataset must be provided to infer the levels from it.
- `outcome`: The outcome variable.
- `confounders=nothing`: The generated components will inherit these confounding variables.
If `nothing`, causal estimands are generated.
- `confounders=nothing`: The generated components will inherit these confounding variables. If `nothing`, causal estimands are generated.
- `outcome_extra_covariates=()`: The generated components will inherit these `outcome_extra_covariates`.
- `positivity_constraint=nothing`: Only components that pass the positivity constraint are added to the `ComposedEstimand`
- `dataset`: An optional dataset to enforce a positivity constraint and infer treatment levels.
- `positivity_constraint=nothing`: Only components that pass the positivity constraint are added to the `ComposedEstimand`. A `dataset` must then be provided.
- `freq_table`: This is only to be used by `factorialEstimands` to avoid unnecessary computations.
- `verbosity=1`: Verbosity level.
"""

"""
factorialEstimand(
constructor::Union{typeof(ATE), typeof(IATE)},
treatments_levels::NamedTuple{names},
outcome;
confounders=nothing,
outcome_extra_covariates=(),
freq_table=nothing,
positivity_constraint=nothing,
verbosity=1
) where names
Generate a `ComposedEstimand` from `treatments_levels`. $FACTORIAL_DOCSTRING
# Examples:
Average Treatment Effects:
- An Average Treatment Effect with causal components:
```@example
factorialEstimand(ATE, (T₁ = (0, 1), T₂=(0, 1, 2)), :Y₁)
```
- An Average Interaction Effect with statistical components:
```@example
factorial(ATE, (T₁ = (0, 1, 2), T₂=(0, 1, 2)), :Y₁, confounders=[:W₁, :W₂])
factorial(IATE, (T₁ = (0, 1, 2), T₂=(0, 1, 2)), :Y₁, confounders=[:W₁, :W₂])
```
- With a dataset, the treatment levels can be inferred and a positivity constraint enforced:
Interactions:
```@example
factorialEstimand(IATE, (T₁ = (0, 1), T₂=(0, 1, 2)), :Y₁)
factorialEstimand(ATE, [:T₁, :T₂], :Y₁,
confounders=[:W₁, :W₂],
dataset=dataset,
positivity_constraint=0.1
)
```
```@example
factorialEstimand(IATE, (T₁ = (0, 1, 2), T₂=(0, 1, 2)), :Y₁, confounders=[:W₁, :W₂])
"""
function factorialEstimand(
    constructor::Union{typeof(CM), typeof(ATE), typeof(IATE)},
    treatments_levels::NamedTuple{names},
    outcome;
    confounders=nothing,
    outcome_extra_covariates=(),
    freq_table=nothing,
    positivity_constraint=nothing,
    verbosity=1
    ) where names
    # Turn the raw levels into the per-treatment settings expected by `constructor`.
    treatments_settings = get_treatment_settings(constructor, treatments_levels)
    components = []
    # One combination = one setting picked for each treatment.
    for combo in Iterators.product(treatments_settings...)
        Ψ = constructor(
            outcome=outcome,
            treatment_values=NamedTuple{names}(get_treatment_setting(combo)),
            treatment_confounders = confounders,
            outcome_extra_covariates=outcome_extra_covariates
        )
        # Keep only the components accepted by the positivity check.
        if satisfies_positivity(Ψ, freq_table; positivity_constraint=positivity_constraint)
            push!(components, Ψ)
        else
            verbosity > 0 && @warn("Sub estimand", Ψ, " did not pass the positivity constraint, skipped.")
        end
    end
    return ComposedEstimand(joint_estimand, Tuple(components))
end

"""
factorialEstimand(
constructor::Union{typeof(ATE), typeof(IATE)},
dataset, treatments, outcome;
confounders=nothing,
outcome_extra_covariates=(),
positivity_constraint=nothing,
verbosity=1
)
Identifies `treatment_levels` from `dataset` and construct the
factorialEstimand from it.
"""
function factorialEstimand(
constructor::Union{typeof(CM), typeof(ATE), typeof(IATE)},
dataset, treatments, outcome;
confounders=nothing,
treatments, outcome;
confounders=nothing,
dataset=nothing,
outcome_extra_covariates=(),
positivity_constraint=nothing,
freq_table=nothing,
verbosity=1
)
treatments_levels = unique_treatment_values(dataset, treatments)
freq_table = positivity_constraint !== nothing ? frequency_table(dataset, keys(treatments_levels)) : nothing
return factorialEstimand(
constructor,
treatments_levels,
outcome;
confounders=confounders,
outcome_extra_covariates=outcome_extra_covariates,
treatments_levels = make_or_check_treatment_levels(treatments, dataset)
freq_table = freq_table !== nothing ? freq_table : get_frequency_table(positivity_constraint, dataset, keys(treatments_levels))
treatments_settings = get_treatment_settings(constructor, treatments_levels)
return _factorialEstimand(
constructor, treatments_settings, outcome;
confounders=confounders,
outcome_extra_covariates=outcome_extra_covariates,
freq_table=freq_table,
positivity_constraint=positivity_constraint,
verbosity=verbosity
)
)
end

"""
Expand All @@ -390,31 +408,30 @@ factorialEstimands(
verbosity=1
)
Identifies `treatment_levels` from `dataset` and a factorialEstimand
for each outcome in `outcomes`.
Generates a `ComposedEstimand` for each outcome in `outcomes`. See `factorialEstimand`.
"""
function factorialEstimands(
constructor::Union{typeof(CM), typeof(ATE), typeof(IATE)},
dataset, treatments, outcomes;
treatments, outcomes;
dataset=nothing,
confounders=nothing,
outcome_extra_covariates=(),
positivity_constraint=nothing,
verbosity=1
)
treatments_levels = make_or_check_treatment_levels(treatments, dataset)
freq_table = get_frequency_table(positivity_constraint, dataset, keys(treatments_levels))
treatments_settings = get_treatment_settings(constructor, treatments_levels)
estimands = []
treatments_levels = unique_treatment_values(dataset, treatments)
freq_table = positivity_constraint !== nothing ? frequency_table(dataset, keys(treatments_levels)) : nothing
for outcome in outcomes
Ψ = factorialEstimand(
constructor,
treatments_levels,
outcome;
confounders=confounders,
outcome_extra_covariates=outcome_extra_covariates,
Ψ = _factorialEstimand(
constructor, treatments_settings, outcome;
confounders=confounders,
outcome_extra_covariates=outcome_extra_covariates,
freq_table=freq_table,
positivity_constraint=positivity_constraint,
verbosity=verbosity-1
)
)
if length.args) > 0
push!(estimands, Ψ)
else
Expand Down
11 changes: 10 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,16 @@ end

function satisfies_positivity(Ψ, freq_table::Nothing; positivity_constraint=nothing)
    # Without a frequency table there is nothing to check against: always accept.
    return true
end

function frequency_table(dataset, colnames)
# No positivity constraint and no dataset: no table is built. This method also
# settles the call `(nothing, nothing, colnames)`, which would otherwise be
# ambiguous between the two methods below.
function get_frequency_table(positivity_constraint::Nothing, dataset::Nothing, colnames)
    return nothing
end

# No positivity constraint: the dataset is ignored and no table is built.
function get_frequency_table(positivity_constraint::Nothing, dataset, colnames)
    return nothing
end

# A positivity constraint was requested but there is no dataset to count from.
function get_frequency_table(positivity_constraint, dataset::Nothing, colnames)
    throw(ArgumentError("A dataset should be provided to enforce a positivity constraint."))
end

# Constraint and dataset both present: defer to the two-argument method.
function get_frequency_table(positivity_constraint, dataset, colnames)
    return get_frequency_table(dataset, colnames)
end

function get_frequency_table(dataset, colnames)
iterator = zip((Tables.getcolumn(dataset, colname) for colname in sort(collect(colnames)))...)
counts = groupcount(x -> x, iterator)
for key in keys(counts)
Expand Down
Loading

0 comments on commit 56ca45b

Please sign in to comment.