Skip to content

Commit

Permalink
Merge branch 'main' into loco-gwas
Browse files Browse the repository at this point in the history
Conflicts:
	Manifest.toml
  • Loading branch information
joshua-slaughter committed Jun 6, 2024
2 parents 0832103 + 58148df commit c1c303f
Show file tree
Hide file tree
Showing 7 changed files with 197 additions and 177 deletions.
35 changes: 35 additions & 0 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,9 @@ weakdeps = ["JSON", "RecipesBase", "SentinelArrays", "StructTypes"]
[[deps.CategoricalDistributions]]
deps = ["CategoricalArrays", "Distributions", "Missings", "OrderedCollections", "Random", "ScientificTypes"]
git-tree-sha1 = "926862f549a82d6c3a7145bc7f1adff2a91a39f0"
git-tree-sha1 = "926862f549a82d6c3a7145bc7f1adff2a91a39f0"
uuid = "af321ab8-2d2e-40a6-b165-3d674595d28e"
version = "0.1.15"
version = "0.1.15"

[deps.CategoricalDistributions.extensions]
Expand Down Expand Up @@ -492,8 +494,10 @@ version = "0.17.6"
[[deps.ConstructionBase]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "260fd2400ed2dab602a7c15cf10c1933c59930a2"
git-tree-sha1 = "260fd2400ed2dab602a7c15cf10c1933c59930a2"
uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
version = "1.5.5"
version = "1.5.5"
weakdeps = ["IntervalSets", "StaticArrays"]

[deps.ConstructionBase.extensions]
Expand All @@ -508,8 +512,10 @@ version = "0.1.3"

[[deps.Contour]]
git-tree-sha1 = "439e35b0b36e2e5881738abc8857bd92ad6ff9a8"
git-tree-sha1 = "439e35b0b36e2e5881738abc8857bd92ad6ff9a8"
uuid = "d38c429a-6771-53c6-b99e-75d170b6e991"
version = "0.6.3"
version = "0.6.3"

[[deps.CpuId]]
deps = ["Markdown"]
Expand Down Expand Up @@ -656,7 +662,9 @@ version = "1.0.4"
[[deps.EvoTrees]]
deps = ["BSON", "CategoricalArrays", "Distributions", "MLJModelInterface", "NetworkLayout", "Random", "RecipesBase", "Statistics", "StatsBase", "Tables"]
git-tree-sha1 = "92d1f78f95f4794bf29bd972dacfa37ea1fec9f4"
git-tree-sha1 = "92d1f78f95f4794bf29bd972dacfa37ea1fec9f4"
uuid = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
version = "0.16.7"
version = "0.16.7"

[deps.EvoTrees.extensions]
Expand Down Expand Up @@ -731,8 +739,10 @@ version = "0.1.1"
[[deps.FileIO]]
deps = ["Pkg", "Requires", "UUIDs"]
git-tree-sha1 = "82d8afa92ecf4b52d78d869f038ebfb881267322"
git-tree-sha1 = "82d8afa92ecf4b52d78d869f038ebfb881267322"
uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
version = "1.16.3"
version = "1.16.3"

[[deps.FilePaths]]
deps = ["FilePathsBase", "MacroTools", "Reexport", "Requires"]
Expand Down Expand Up @@ -791,8 +801,10 @@ version = "2.13.96+0"

[[deps.Format]]
git-tree-sha1 = "9c68794ef81b08086aeb32eeaf33531668d5f5fc"
git-tree-sha1 = "9c68794ef81b08086aeb32eeaf33531668d5f5fc"
uuid = "1fa38f19-a742-5d3f-a2b9-30dd87b9d5f8"
version = "1.3.7"
version = "1.3.7"

[[deps.ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"]
Expand Down Expand Up @@ -912,8 +924,10 @@ version = "1.11.0"
[[deps.GridLayoutBase]]
deps = ["GeometryBasics", "InteractiveUtils", "Observables"]
git-tree-sha1 = "6f93a83ca11346771a93bbde2bdad2f65b61498f"
git-tree-sha1 = "6f93a83ca11346771a93bbde2bdad2f65b61498f"
uuid = "3955a311-db13-416c-9275-1d80ed98e5e9"
version = "0.10.2"
version = "0.10.2"

[[deps.Grisu]]
git-tree-sha1 = "53bb909d1151e57e2484c3d1b53e19552b887fb2"
Expand Down Expand Up @@ -1219,7 +1233,9 @@ version = "3.100.2+0"
[[deps.LLVM]]
deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"]
git-tree-sha1 = "839c82932db86740ae729779e610f07a1640be9a"
git-tree-sha1 = "839c82932db86740ae729779e610f07a1640be9a"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "6.6.3"
version = "6.6.3"

[deps.LLVM.extensions]
Expand Down Expand Up @@ -1484,8 +1500,10 @@ version = "0.10.0"
[[deps.MLJModelInterface]]
deps = ["Random", "ScientificTypesBase", "StatisticalTraits"]
git-tree-sha1 = "d2a45e1b5998ba3fdfb6cfe0c81096d4c7fb40e7"
git-tree-sha1 = "d2a45e1b5998ba3fdfb6cfe0c81096d4c7fb40e7"
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
version = "1.9.6"
version = "1.9.6"

[[deps.MLJModels]]
deps = ["CategoricalArrays", "CategoricalDistributions", "Combinatorics", "Dates", "Distances", "Distributions", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "Markdown", "OrderedCollections", "Parameters", "Pkg", "PrettyPrinting", "REPL", "Random", "RelocatableFolders", "ScientificTypes", "StatisticalTraits", "Statistics", "StatsBase", "Tables"]
Expand Down Expand Up @@ -1751,8 +1769,10 @@ version = "1.4.3"
[[deps.OpenSSL_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
git-tree-sha1 = "3da7367955dcc5c54c1ba4d402ccdc09a1a3e046"
git-tree-sha1 = "3da7367955dcc5c54c1ba4d402ccdc09a1a3e046"
uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95"
version = "3.0.13+1"
version = "3.0.13+1"

[[deps.OpenSpecFun_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
Expand All @@ -1763,12 +1783,21 @@ version = "0.5.5+0"
[[deps.Optim]]
deps = ["Compat", "FillArrays", "ForwardDiff", "LineSearches", "LinearAlgebra", "NLSolversBase", "NaNMath", "Parameters", "PositiveFactorizations", "Printf", "SparseArrays", "StatsBase"]
git-tree-sha1 = "d9b79c4eed437421ac4285148fcadf42e0700e89"
deps = ["Compat", "FillArrays", "ForwardDiff", "LineSearches", "LinearAlgebra", "NLSolversBase", "NaNMath", "Parameters", "PositiveFactorizations", "Printf", "SparseArrays", "StatsBase"]
git-tree-sha1 = "d9b79c4eed437421ac4285148fcadf42e0700e89"
uuid = "429524aa-4258-5aef-a3af-852621145aeb"
version = "1.9.4"

[deps.Optim.extensions]
OptimMOIExt = "MathOptInterface"

[deps.Optim.weakdeps]
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
version = "1.9.4"

[deps.Optim.extensions]
OptimMOIExt = "MathOptInterface"

[deps.Optim.weakdeps]
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"

Expand Down Expand Up @@ -2283,8 +2312,10 @@ version = "1.7.0"
[[deps.StatsBase]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21"
git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.34.3"
version = "0.34.3"

[[deps.StatsFuns]]
deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"]
Expand Down Expand Up @@ -2564,8 +2595,10 @@ version = "1.1.34+0"
[[deps.XZ_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
git-tree-sha1 = "ac88fb95ae6447c8dda6a5503f3bafd496ae8632"
git-tree-sha1 = "ac88fb95ae6447c8dda6a5503f3bafd496ae8632"
uuid = "ffd25f8a-64ca-5728-b0f7-c24cf3aae800"
version = "5.4.6+0"
version = "5.4.6+0"

[[deps.Xorg_libX11_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libxcb_jll", "Xorg_xtrans_jll"]
Expand Down Expand Up @@ -2629,8 +2662,10 @@ version = "1.2.13+1"
[[deps.Zstd_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl"]
git-tree-sha1 = "e678132f07ddb5bfa46857f0d7620fb9be675d3b"
git-tree-sha1 = "e678132f07ddb5bfa46857f0d7620fb9be675d3b"
uuid = "3161d3a3-bdf6-5164-811a-617609db77b4"
version = "1.5.6+0"
version = "1.5.6+0"

[[deps.Zygote]]
deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"]
Expand Down
6 changes: 3 additions & 3 deletions src/tl_inputs/from_actors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ function control_case_settings(::Type{TMLE.StatisticalATE}, treatments, data)
end

function addEstimands!(estimands, treatments, variables, data; positivity_constraint=0.)
freqs = TargeneCore.frequency_table(data, treatments)
freqs = TMLE.frequency_table(data, treatments)
# This loop adds all ATE estimands where all other treatments than
# the bQTL are fixed, at the order 1, this is the simple bQTL's ATE
for setting in control_case_settings(TMLE.StatisticalATE, treatments, data)
Expand All @@ -134,7 +134,7 @@ function addEstimands!(estimands, treatments, variables, data; positivity_constr
treatment_confounders = NamedTuple{keys(setting)}([variables.confounders for key in keys(setting)]),
outcome_extra_covariates = variables.covariates
)
if satisfies_positivity(Ψ, freqs; positivity_constraint=positivity_constraint)
if TMLE.satisfies_positivity(Ψ, freqs; positivity_constraint=positivity_constraint)
update_estimands_from_outcomes!(estimands, Ψ, variables.targets)
end
end
Expand All @@ -147,7 +147,7 @@ function addEstimands!(estimands, treatments, variables, data; positivity_constr
treatment_confounders = NamedTuple{keys(setting)}([variables.confounders for key in keys(setting)]),
outcome_extra_covariates = variables.covariates
)
if satisfies_positivity(Ψ, freqs; positivity_constraint=positivity_constraint)
if TMLE.satisfies_positivity(Ψ, freqs; positivity_constraint=positivity_constraint)
update_estimands_from_outcomes!(estimands, Ψ, variables.targets)
end
end
Expand Down
81 changes: 66 additions & 15 deletions src/tl_inputs/from_param_files.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ MismatchedCaseControlEncodingError() =

NoRemainingParamsError(positivity_constraint) = ArgumentError(string("No parameter passed the given positivity constraint: ", positivity_constraint))

MismatchedVariableError(variable) = ArgumentError(string("Each component of a ComposedEstimand should contain the same ", variable, " variables."))

function check_genotypes_encoding(val::NamedTuple, type)
if !(typeof(val.case) <: type && typeof(val.control) <: type)
Expand All @@ -27,17 +28,66 @@ check_genotypes_encoding(val::T, type) where T =
T <: type || throw(MismatchedCaseControlEncodingError())


get_treatments(Ψ) = keys.treatment_values)

function get_treatments::ComposedEstimand)
treatments = get_treatments(first.args))
if length.args) > 1
for arg in Ψ.args[2:end]
get_treatments(arg) == treatments || throw(MismatchedVariableError("treatments"))
end
end
return treatments
end

get_confounders(Ψ) = Tuple(Iterators.flatten((Tconf for Tconf Ψ.treatment_confounders)))

function get_confounders::ComposedEstimand)
confounders = get_confounders(first.args))
if length.args) > 1
for arg in Ψ.args[2:end]
get_confounders(arg) == confounders || throw(MismatchedVariableError("confounders"))
end
end
return confounders
end

get_outcome_extra_covariates(Ψ) = Ψ.outcome_extra_covariates

function get_outcome_extra_covariates::ComposedEstimand)
outcome_extra_covariates = get_outcome_extra_covariates(first.args))
if length.args) > 1
for arg in Ψ.args[2:end]
get_outcome_extra_covariates(arg) == outcome_extra_covariates || throw(MismatchedVariableError("outcome extra covariates"))
end
end
return outcome_extra_covariates
end

get_outcome(Ψ) = Ψ.outcome

function get_outcome::ComposedEstimand)
outcome = get_outcome(first.args))
if length.args) > 1
for arg in Ψ.args[2:end]
get_outcome(arg) == outcome || throw(MismatchedVariableError("outcome"))
end
end
return outcome
end

function get_variables(estimands, traits, pcs)
genetic_variants = Set{Symbol}()
others = Set{Symbol}()
pcs = Set{Symbol}(filter(x -> x != :SAMPLE_ID, propertynames(pcs)))
alltraits = Set{Symbol}(filter(x -> x != :SAMPLE_ID, propertynames(traits)))
for Ψ in estimands
treatments = keys.treatment_values)
confounders = Iterators.flatten((Tconf for Tconf Ψ.treatment_confounders))
treatments = get_treatments(Ψ)
confounders = get_confounders(Ψ)
outcome_extra_covariates = get_outcome_extra_covariates(Ψ)
push!(
others,
Ψ.outcome_extra_covariates...,
outcome_extra_covariates...,
confounders...,
treatments...
)
Expand Down Expand Up @@ -123,6 +173,8 @@ function adjust_parameter_sections(Ψ::T, variants_alleles, pcs) where T<:TMLE.E
return T(outcome=Ψ.outcome, treatment_values=treatments, treatment_confounders=confounders, outcome_extra_covariates=Ψ.outcome_extra_covariates)
end

adjust_parameter_sections::ComposedEstimand, variants_alleles, pcs) =
ComposedEstimand.f, Tuple(adjust_parameter_sections(arg, variants_alleles, pcs) for arg in Ψ.args))

function append_from_valid_estimands!(
estimands::Vector{<:TMLE.Estimand},
Expand All @@ -136,29 +188,28 @@ function append_from_valid_estimands!(
# Update treatment's and confounders's sections of Ψ
Ψ = adjust_parameter_sections(Ψ, variants_alleles, variables.pcs)
# Update frequency tables with current treatments
treatments = sorted_treatment_names(Ψ)
treatments = get_treatments(Ψ)
if !haskey(frequency_tables, treatments)
frequency_tables[treatments] = TargeneCore.frequency_table(data, collect(treatments))
frequency_tables[treatments] = TMLE.frequency_table(data, treatments)
end
# Check if parameter satisfies positivity
satisfies_positivity(Ψ, frequency_tables[treatments];
positivity_constraint=positivity_constraint) || return
# Expand wildcard to all outcomes
if Ψ.outcome === :ALL
update_estimands_from_outcomes!(estimands, Ψ, variables.outcomes)
else
# Ψ.target || MissingVariableError(variable)
push!(estimands, Ψ)
if TMLE.satisfies_positivity(Ψ, frequency_tables[treatments]; positivity_constraint=positivity_constraint)
# Expand wildcard to all outcomes
if get_outcome(Ψ) === :ALL
update_estimands_from_outcomes!(estimands, Ψ, variables.outcomes)
else
push!(estimands, Ψ)
end
end
end

function adjusted_estimands(estimands, variables, data; positivity_constraint=0.)
final_estimands = TMLE.Estimand[]
variants_alleles = Dict(v => Set(unique(skipmissing(data[!, v]))) for v in variables.genetic_variants)
freqency_tables = Dict()
frequency_tables = Dict()
for Ψ in estimands
# If the genotypes encoding is a string representation make sure they match the actual genotypes
append_from_valid_estimands!(final_estimands, freqency_tables, Ψ, data, variants_alleles, variables; positivity_constraint=positivity_constraint)
append_from_valid_estimands!(final_estimands, frequency_tables, Ψ, data, variants_alleles, variables; positivity_constraint=positivity_constraint)
end

length(final_estimands) > 0 || throw(NoRemainingParamsError(positivity_constraint))
Expand Down
Loading

0 comments on commit c1c303f

Please sign in to comment.