Skip to content

Commit

Permalink
up docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierlabayle committed Dec 18, 2023
1 parent 7f5cd62 commit fad1a38
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 43 deletions.
110 changes: 74 additions & 36 deletions src/counterfactual_mean_based/estimands.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,61 @@ function generateComposedEstimandFromContrasts(
return ComposedEstimand(joint_estimand, Tuple(components))
end

GENERATE_DOCSTRING = """
The components of this estimand are generated from the treatment variables contrasts.
For example, consider two treatment variables T₁ and T₂ each taking three possible values (0, 1, 2).
For each treatment variable, the marginal contrasts are defined by (0 → 1, 1 → 2, 0 → 2), there are thus
3 x 3 = 9 joint contrasts to be generated:
- (T₁: 0 → 1, T₂: 0 → 1)
- (T₁: 0 → 1, T₂: 1 → 2)
- (T₁: 0 → 1, T₂: 0 → 2)
- (T₁: 1 → 2, T₂: 0 → 1)
- (T₁: 1 → 2, T₂: 1 → 2)
- (T₁: 1 → 2, T₂: 0 → 2)
- (T₁: 0 → 2, T₂: 0 → 1)
- (T₁: 0 → 2, T₂: 1 → 2)
- (T₁: 0 → 2, T₂: 0 → 2)
# Return
A `ComposedEstimand` with causal or statistical components.
# Args
- `treatments_levels`: A NamedTuple providing the unique levels each treatment variable can take.
- `outcome`: The outcome variable.
- `confounders=nothing`: The generated components will inherit these confounding variables.
If `nothing`, causal estimands are generated.
- `outcome_extra_covariates=()`: The generated components will inherit these `outcome_extra_covariates`.
- `positivity_constraint=nothing`: Only components that pass the positivity constraint are added to the `ComposedEstimand`
"""
generateATEs(treatments_unique_values, outcome; confounders=nothing, outcome_extra_covariates=())

Generate all possible ATEs from the `treatments_unique_values`.
"""
generateATEs(
treatments_levels::NamedTuple{names}, outcome;
confounders=nothing,
outcome_extra_covariates=(),
freq_table=nothing,
positivity_constraint=nothing
) where names
Generate a `ComposedEstimand` of ATEs from the `treatments_levels`. $GENERATE_DOCSTRING
# Example:
To generate a causal composed estimand with 3 components:
```@example
generateATEs((T₁ = (0, 1), T₂=(0, 1, 2)), :Y₁)
```
To generate a statistical composed estimand with 9 components:
```@example
generateATEs((T₁ = (0, 1, 2), T₂=(0, 1, 2)), :Y₁, confounders=[:W₁, :W₂])
```
"""
function generateATEs(
treatments_levels::NamedTuple{names}, outcome;
Expand All @@ -283,7 +334,11 @@ function generateATEs(
end

"""
generateATEs(dataset, treatments, outcome; confounders=nothing, outcome_extra_covariates=())
generateATEs(dataset, treatments, outcome;
confounders=nothing,
outcome_extra_covariates=(),
positivity_constraint=nothing
)
Find all unique values for each treatment variable in the dataset and generate all possible ATEs from these values.
"""
Expand All @@ -305,39 +360,15 @@ function generateATEs(dataset, treatments, outcome;
end

"""
generateIATEs(treatments_levels, outcome;
generateIATEs(
treatments_levels::NamedTuple{names}, outcome;
confounders=nothing,
outcome_extra_covariates=()
)
Generates a `ComposedEstimand` of Average Interation Effects from `treatments_levels`.
The components of this estimand are generated from the treatment variables contrasts.
For example, consider two treatment variables T₁ and T₂ each taking three possible values (0, 1, 2).
For each treatment variable, the marginal contrasts are defined by (0 → 1, 1 → 2, 0 → 2), there are thus
3 x 3 = 9 joint contrasts to be generated:
- (T₁: 0 → 1, T₂: 0 → 1)
- (T₁: 0 → 1, T₂: 1 → 2)
- (T₁: 0 → 1, T₂: 0 → 2)
- (T₁: 1 → 2, T₂: 0 → 1)
- (T₁: 1 → 2, T₂: 1 → 2)
- (T₁: 1 → 2, T₂: 0 → 2)
- (T₁: 0 → 2, T₂: 0 → 1)
- (T₁: 0 → 2, T₂: 1 → 2)
- (T₁: 0 → 2, T₂: 0 → 2)
# Args
- `treatments_levels`: A NamedTuple providing the unique levels each treatment variable can take.
- `outcome`: The outcome variable.
- `confounders=nothing`: The generated interaction components will inherit these confounding variables.
If `nothing`, `CausalIATE`s are generated, otherwise `StatisticalIATE`s are generated.
- `outcome_extra_covariates=()`: The generated interaction components will inherit these `outcome_extra_covariates`.
- `positivity_constraint=nothing`: Only components that pass the positivity constraint are added to the `ComposedEstimand`
# Return
outcome_extra_covariates=(),
freq_table=nothing,
positivity_constraint=nothing
) where names
A `ComposedEstimand` with causal or statistical interaction components.
Generates a `ComposedEstimand` of Average Interation Effects from `treatments_levels`. $GENERATE_DOCSTRING
# Example:
Expand Down Expand Up @@ -378,7 +409,7 @@ end
positivity_constraint=nothing
)
Finds treatments levels from the dataset and generates a `ComposedEstimand` of Average Interation Effects from these
Finds treatments levels from the dataset and generates a `ComposedEstimand` of Average Interation Effects from them
(see [`generateIATEs(treatments_levels, outcome; confounders=nothing, outcome_extra_covariates=())`](@ref)).
"""
function generateIATEs(dataset, treatments, outcome;
Expand All @@ -396,4 +427,11 @@ function generateIATEs(dataset, treatments, outcome;
freq_table=freq_table,
positivity_constraint=positivity_constraint
)
end
end

joint_levels::StatisticalIATE) = Iterators.product(values.treatment_values)...)

joint_levels::StatisticalATE) =
(Tuple.treatment_values[T][c] for T keys.treatment_values)) for c in (:case, :control))

joint_levels::StatisticalCM) = (values.treatment_values),)
7 changes: 0 additions & 7 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,6 @@ default_models(;Q_binary=LinearBinaryClassifier(), Q_continuous=LinearRegressor(

is_binary(dataset, columnname) = Set(skipmissing(Tables.getcolumn(dataset, columnname))) == Set([0, 1])

joint_levels::TMLE.StatisticalIATE) = Iterators.product(values.treatment_values)...)

joint_levels::TMLE.StatisticalATE) =
(Tuple.treatment_values[T][c] for T keys.treatment_values)) for c in (:case, :control))

joint_levels::TMLE.StatisticalCM) = (values.treatment_values),)

function satisfies_positivity(Ψ, freq_table; positivity_constraint=0.01)
for jointlevel in joint_levels(Ψ)
if !haskey(freq_table, jointlevel) || freq_table[jointlevel] < positivity_constraint
Expand Down

0 comments on commit fad1a38

Please sign in to comment.