Skip to content

Commit

Permalink
Log all complexities over time
Browse files: browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Jan 6, 2024
1 parent 65f413b commit 8fb1ab6
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 10 deletions.
4 changes: 4 additions & 0 deletions src/MLJInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ function modelexpr(model_name::Symbol)
addprocs_function::Union{Function,Nothing} = nothing
heap_size_hint_in_bytes::Union{Integer,Nothing} = nothing
logger::Union{AbstractLogger,Nothing} = nothing
logging_callback::Union{Function,Nothing} = nothing
log_every_n::Int = 1
runtests::Bool = true
loss_type::L = Nothing
selection_method::Function = choose_best
Expand Down Expand Up @@ -177,6 +179,8 @@ function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options)
y_units=y_units_clean,
verbosity=verbosity,
logger=m.logger,
logging_callback=m.logging_callback,
log_every_n=m.log_every_n,
# Help out with inference:
v_dim_out=isa(m, SRRegressor) ? Val(1) : Val(2),
)
Expand Down
20 changes: 11 additions & 9 deletions src/SearchUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ end
function default_logging_callback(logger; options, num_evals, hall_of_fame, datasets, _...)
L = typeof(first(datasets).baseline_loss)
with_logger(logger) do
@info("search_state", num_evals = sum(sum, num_evals))
d = Dict()
for (i, (hof, dataset)) in enumerate(zip(hall_of_fame, datasets))
dominating = calculate_pareto_frontier(hof)
best_loss = length(dominating) > 0 ? dominating[end].loss : L(Inf)
Expand All @@ -417,15 +417,17 @@ function default_logging_callback(logger; options, num_evals, hall_of_fame, data
string_tree(member.tree, options; variable_names=dataset.variable_names) for
member in dominating
]
@info(
"search_state_$(i)",
best_loss = best_loss,
equations = equations,
losses = losses,
complexities = complexities,
log_step_increment = 0,
)
d[string(i)] = Dict()
d[string(i)]["best_loss"] = best_loss
d[string(i)]["equations"] = Dict()
for (complexity, loss, equation) in zip(complexities, losses, equations)
d[string(i)]["equations"][string(complexity)] = Dict(
"loss" => loss, "equation" => equation
)
end
end
d["num_evals"] = sum(sum, num_evals)
@info("search_state", data = d)
end
end

Expand Down
8 changes: 7 additions & 1 deletion src/SymbolicRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ function equation_search(
verbosity::Union{Integer,Nothing}=nothing,
logger::Union{AbstractLogger,Nothing}=nothing,
logging_callback::Union{Function,Nothing}=nothing,
log_every_n::Int=1,
progress::Union{Bool,Nothing}=nothing,
X_units::Union{AbstractVector,Nothing}=nothing,
y_units=nothing,
Expand Down Expand Up @@ -403,6 +404,7 @@ function equation_search(
verbosity=verbosity,
logger=logger,
logging_callback=logging_callback,
log_every_n=log_every_n,
progress=progress,
v_dim_out=Val(DIM_OUT),
)
Expand Down Expand Up @@ -574,6 +576,7 @@ function equation_search(
saved_state,
_verbosity,
_logging_callback,
log_every_n,
_progress,
Val(_return_state),
)
Expand All @@ -593,6 +596,7 @@ function _equation_search(
saved_state,
verbosity,
logging_callback,
log_every_n,
progress,
::Val{RETURN_STATE},
) where {T<:DATA_TYPE,L<:LOSS_TYPE,D<:Dataset{T,L},PARALLELISM,RETURN_STATE,DIM_OUT}
Expand Down Expand Up @@ -800,6 +804,7 @@ function _equation_search(
)
end

log_step = 0
last_print_time = time()
last_speed_recording_time = time()
num_evals_last = sum(sum, num_evals)
Expand Down Expand Up @@ -971,7 +976,7 @@ function _equation_search(
PARALLELISM,
)
end
if logging_callback !== nothing
if logging_callback !== nothing && log_step % log_every_n == 0
logging_callback(;
options,
num_evals,
Expand All @@ -982,6 +987,7 @@ function _equation_search(
datasets=datasets,
)
end
log_step += 1
end
sleep(1e-6)

Expand Down

0 comments on commit 8fb1ab6

Please sign in to comment.