Skip to content

Commit

Permalink
Merge pull request #98 from TARGENE/generate_ates
Browse files Browse the repository at this point in the history
add generateATEs
  • Loading branch information
olivierlabayle authored Dec 11, 2023
2 parents 230d124 + a54d72f commit 0769421
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 7 deletions.
24 changes: 19 additions & 5 deletions docs/src/user_guide/estimands.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ statisticalΨ = ATE(
)
```

- generating all ``ATEs``

All possible ATEs can be generated, either from a set of treatment values or directly from a dataset, with the `generateATEs` function.

## The Interaction Average Treatment Effect

- Causal Question:
Expand Down Expand Up @@ -176,12 +180,22 @@ statisticalΨ = IATE(
)
```

## Any function of the previous Estimands
## Composed Estimands

As a result of Julia's automatic differentiation facilities, given a set of predefined estimands ``(\Psi_1, ..., \Psi_k)``, we can automatically compute an estimator for ``f(\Psi_1, ..., \Psi_k)``. This is done via the `ComposedEstimand` type.

As a result of Julia's automatic differentiation facilities, given a set of already estimated estimands ``(\Psi_1, ..., \Psi_k)``, we can automatically compute an estimator for $f(\Psi_1, ..., \Psi_k)$. This is done via the `compose` function:
For example, the difference in ATE for a treatment with 3 levels (0, 1, 2) can be defined as follows:

```julia
compose(f, args...)
ATE₁ = ATE(
outcome = :Y,
treatment_values = (T = (control = 0, case = 1),),
treatment_confounders = [:W]
)
ATE₂ = ATE(
outcome = :Y,
treatment_values = (T = (control = 1, case = 2),),
treatment_confounders = [:W]
)
ATEdiff = ComposedEstimand(-, (ATE₁, ATE₂))
```

where args are asymptotically linear estimates (see [Composing Estimands](@ref)).
6 changes: 5 additions & 1 deletion docs/src/user_guide/misc.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,8 @@ where `encoder` is a [OneHotEncoder](https://alan-turing-institute.github.io/MLJ

The `with_encoder(model; encoder=TreatmentTransformer())` provides a shorthand to combine a `TreatmentTransformer` with another MLJ model in a pipeline.

Of course you are also free to define your own strategy!
Of course you are also free to define your own strategy!

## Serialization

Many objects from TMLE.jl can be serialized to various file formats. This is achieved by first converting these structures to dictionaries, which can then be written in JSON or YAML format. For that purpose, you can use the `TMLE.read_json`, `TMLE.write_json`, `TMLE.read_yaml` and `TMLE.write_yaml` functions.
1 change: 1 addition & 0 deletions src/TMLE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ using Combinatorics
export SCM, StaticSCM, add_equations!, add_equation!, parents, vertices
export CM, ATE, IATE
export AVAILABLE_ESTIMANDS
export generateATEs
export TMLEE, OSE, NAIVE
export ComposedEstimand
export var, estimate, OneSampleTTest, OneSampleZTest, OneSampleHotellingT2Test,pvalue, confint, emptyIC
Expand Down
6 changes: 6 additions & 0 deletions src/configuration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ struct Configuration
adjustment::Union{Nothing, <:AdjustmentMethod}
end

"""
Configuration(;estimands, scm=nothing, adjustment=nothing) = Configuration(estimands, scm, adjustment)
A Configuration is a set of estimands to be estimated. If the set of estimands contains causal (identifiable) estimands,
these will be identified using the provided `scm` and `adjustment` method.
"""
Configuration(;estimands, scm=nothing, adjustment=nothing) = Configuration(estimands, scm, adjustment)

function to_dict(configuration::Configuration)
Expand Down
45 changes: 44 additions & 1 deletion src/counterfactual_mean_based/estimands.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ for (estimand, (formula,)) ∈ ESTIMANDS_DOCS
outcome = Symbol(outcome)
treatment_values = get_treatment_specs(treatment_values)
treatment_variables = Tuple(keys(treatment_values))
treatment_confounders = NamedTuple{treatment_variables}([unique_sorted_tuple(treatment_confounders[T]) for T treatment_variables])
treatment_confounders = NamedTuple{treatment_variables}([confounders_values(treatment_confounders, T) for T treatment_variables])
outcome_extra_covariates = unique_sorted_tuple(outcome_extra_covariates)
return new(outcome, treatment_values, treatment_confounders, outcome_extra_covariates)
end
Expand Down Expand Up @@ -157,6 +157,10 @@ treatment_specs_to_dict(treatment_values::NamedTuple) = Dict(pairs(treatment_val
"""
    treatment_values(d)

Normalise a treatment specification: an `AbstractDict` is splatted into a `NamedTuple`
(its keys become the field names), anything else is returned unchanged.
"""
function treatment_values(d::AbstractDict)
    return (; d...)
end

function treatment_values(d)
    return d
end

"""
    confounders_values(key_value_iterable::Union{NamedTuple, AbstractDict}, T)
    confounders_values(iterable, T)

Return the confounders associated with treatment `T`, passed through `unique_sorted_tuple`.

- When the confounders are keyed by treatment (a `NamedTuple` or any `AbstractDict`),
  the entry stored under `T` is used.
- Otherwise the whole `iterable` is used and `T` is ignored, i.e. the same confounders
  apply to every treatment.
"""
# `AbstractDict` (rather than `Dict`) so that other dictionary types are also looked up
# by treatment instead of falling through to the flat-collection method below.
confounders_values(key_value_iterable::Union{NamedTuple, AbstractDict}, T) =
    unique_sorted_tuple(key_value_iterable[T])

confounders_values(iterable, T) = unique_sorted_tuple(iterable)

"""
    get_treatment_specs(treatment_specs::NamedTuple)

Return a `NamedTuple` holding the same fields as `treatment_specs`, reordered so that
the field names are sorted.
"""
function get_treatment_specs(treatment_specs::NamedTuple{names}) where names
    sorted_names = Tuple(sort(collect(names)))
    return NamedTuple{sorted_names}(treatment_specs)
end

Expand Down Expand Up @@ -221,4 +225,43 @@ function identify(method::BackdoorAdjustment, causal_estimand::T, scm::SCM) wher
treatment_confounders = treatment_confounders,
outcome_extra_covariates = method.outcome_extra_covariates
)
end

"""
    unique_non_missing(dataset, colname)

Return the unique values of column `colname` of `dataset` (retrieved with
`Tables.getcolumn`), ignoring `missing` entries.
"""
function unique_non_missing(dataset, colname)
    column = Tables.getcolumn(dataset, colname)
    return unique(skipmissing(column))
end

"""
    unique_treatment_values(dataset, colnames)

Return a `NamedTuple` mapping each column name in `colnames` to the unique non-missing
values of that column in `dataset`.
"""
function unique_treatment_values(dataset, colnames)
    values_by_column = (name => unique_non_missing(dataset, name) for name in colnames)
    return (; values_by_column...)
end

"""
generateATEs(dataset, treatments, outcome; confounders=nothing, outcome_extra_covariates=())
Find all unique values for each treatment variable in the dataset and generate all possible ATEs from these values.
"""
function generateATEs(dataset, treatments, outcome; confounders=nothing, outcome_extra_covariates=())
treatments_unique_values = unique_treatment_values(dataset, treatments)
return generateATEs(treatments_unique_values, outcome; confounders=confounders, outcome_extra_covariates=outcome_extra_covariates)
end

"""
generateATEs(treatments_unique_values, outcome; confounders=nothing, outcome_extra_covariates=())
Generate all possible ATEs from the `treatments_unique_values`.
"""
function generateATEs(treatments_unique_values, outcome; confounders=nothing, outcome_extra_covariates=())
treatments = Tuple(Symbol.(keys(treatments_unique_values)))
treatments_control_case = [collect(Combinatorics.combinations(treatments_unique_values[T], 2)) for T in treatments]

ATEs = []
for combo Iterators.product(treatments_control_case...)
treatments_control_case = [NamedTuple{(:control, :case)}(treatment_control_case) for treatment_control_case combo]
push!(
ATEs,
ATE(
outcome=outcome,
treatment_values=NamedTuple{treatments}(treatments_control_case),
treatment_confounders = confounders,
outcome_extra_covariates=outcome_extra_covariates
)
)
end
return ATEs
end
32 changes: 32 additions & 0 deletions test/counterfactual_mean_based/estimands.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,38 @@ end
Ψreconstructed = TMLE.from_dict!(d)
@test Ψreconstructed == Ψ
end

@testset "Test generateATEs" begin
dataset = (
T₁=[0, 1, 2, missing],
T₂ = ["AC", "CC", missing, "AA"],
W₁ = [1, 2, 3, 4],
W₂ = [1, 2, 3, 4],
C = [1, 2, 3, 4],
Y₁ = [1, 2, 3, 4],
Y₂ = [1, 2, 3, 4]
)
# No confounders, 1 treatment, no extra covariate: 3 causal ATEs
ATEs = generateATEs(dataset, [:T₁], :Y₁)
@test ATEs == [
TMLE.CausalATE(:Y₁, (T₁ = (case = 1, control = 0),)),
TMLE.CausalATE(:Y₁, (T₁ = (case = 2, control = 0),)),
TMLE.CausalATE(:Y₁, (T₁ = (case = 2, control = 1),))
]
# 2 treatments
ATEs = generateATEs(dataset, [:T₁, :T₂], :Y₁;
confounders=[:W₁, :W₂],
outcome_extra_covariates=[:C]
)
# 9 expected different treatment settings
@test length(ATEs) == 9
@test length(unique([Ψ.treatment_values for Ψ in ATEs])) == 9
# Covariates and confounders
@test all.outcome_extra_covariates == (:C,) for Ψ in ATEs)
@test all.treatment_confounders == (T₁ = (:W₁, :W₂), T₂ = (:W₁, :W₂)) for Ψ in ATEs)
end


end

true

0 comments on commit 0769421

Please sign in to comment.