Skip to content

Commit

Permalink
Merge pull request #161 from MilesCranmer/snoop-compile2
Browse files Browse the repository at this point in the history
Proper SnoopCompilation
  • Loading branch information
MilesCranmer committed Nov 28, 2022
2 parents 9857863 + 5e2afab commit a093714
Show file tree
Hide file tree
Showing 17 changed files with 166 additions and 86 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ build
*.csv.out*
pysr_recorder.json
docs/src/index.md
*.code-workspace
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SymbolicRegression"
uuid = "8254be44-1295-4e6a-a16d-46603ac705cb"
authors = ["MilesCranmer <[email protected]>"]
version = "0.14.4"
version = "0.14.5"

[deps]
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand All @@ -15,20 +15,22 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"

[compat]
DynamicExpressions = "0.4"
DynamicExpressions = "0.4.2"
JSON3 = "1"
LineSearches = "7"
LossFunctions = "0.6, 0.7, 0.8"
Optim = "0.19, 1.1"
Pkg = "1"
ProgressBars = "1.4"
Reexport = "1"
SnoopPrecompile = "1"
SpecialFunctions = "0.10.1, 1, 2"
StatsBase = "0.33"
SymbolicUtils = "0.19"
Expand Down
2 changes: 1 addition & 1 deletion src/AdaptiveParsimony.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ for an equation at size `size`.
@inline function update_frequencies!(
running_search_statistics::RunningSearchStatistics; size=nothing
)
if size <= length(running_search_statistics.frequencies)
if 0 < size <= length(running_search_statistics.frequencies)
running_search_statistics.frequencies[size] += 1
end
return nothing
Expand Down
3 changes: 2 additions & 1 deletion src/CheckConstraints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,8 @@ end

"""Check if user-passed constraints are violated or not"""
function check_constraints(tree::Node, options::Options, maxsize::Int)::Bool
if compute_complexity(tree, options) > maxsize
size = compute_complexity(tree, options)
if 0 > size > maxsize
return false
end
for i in 1:(options.nbin)
Expand Down
4 changes: 1 addition & 3 deletions src/Complexity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@ function compute_complexity(tree::Node, options::Options)::Int
end
end

function _compute_complexity(
tree::Node, options::Options{C,complexity_type}
)::complexity_type where {C,complexity_type<:Real}
function _compute_complexity(tree::Node, options::Options{CT})::CT where {CT<:Real}
if tree.degree == 0
if tree.constant
return options.complexity_mapping.constant_complexity
Expand Down
10 changes: 1 addition & 9 deletions src/Core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,7 @@ include("OptionsStruct.jl")
include("Operators.jl")
include("Options.jl")

import .ProgramConstantsModule:
MAX_DEGREE,
BATCH_DIM,
FEATURE_DIM,
RecordType,
SRConcurrency,
SRSerial,
SRThreaded,
SRDistributed
import .ProgramConstantsModule: MAX_DEGREE, BATCH_DIM, FEATURE_DIM, RecordType
import .DatasetModule: Dataset
import .OptionsStructModule: Options, MutationWeights, sample_mutation
import .OptionsModule: Options
Expand Down
4 changes: 2 additions & 2 deletions src/Mutate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,12 +231,12 @@ function next_generation(
if options.use_frequency
oldSize = compute_complexity(prev, options)
newSize = compute_complexity(tree, options)
old_frequency = if (oldSize <= options.maxsize)
old_frequency = if (0 < oldSize <= options.maxsize)
running_search_statistics.normalized_frequencies[oldSize]
else
1e-6
end
new_frequency = if (newSize <= options.maxsize)
new_frequency = if (0 < newSize <= options.maxsize)
running_search_statistics.normalized_frequencies[newSize]
else
1e-6
Expand Down
9 changes: 3 additions & 6 deletions src/Options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -600,17 +600,13 @@ function Options(;
end
end

options = Options{
typeof(loss),
eltype(complexity_mapping),
tournament_selection_p,
tournament_selection_n,
}(
options = Options{eltype(complexity_mapping)}(
operators,
bin_constraints,
una_constraints,
complexity_mapping,
tournament_selection_n,
tournament_selection_p,
parsimony,
alpha,
maxsize,
Expand Down Expand Up @@ -659,6 +655,7 @@ function Options(;
skip_mutation_failures,
nested_constraints,
deterministic,
define_helper_functions,
)

return options
Expand Down
8 changes: 5 additions & 3 deletions src/OptionsStruct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,13 @@ function ComplexityMapping(;
)
end

struct Options{LossType<:Union{SupervisedLoss,Function},ComplexityType,_prob_pick_first,_ns}
struct Options{CT}
operators::AbstractOperatorEnum
bin_constraints::Vector{Tuple{Int,Int}}
una_constraints::Vector{Int}
complexity_mapping::ComplexityMapping{ComplexityType}
complexity_mapping::ComplexityMapping{CT}
tournament_selection_n::Int
tournament_selection_p::Float32
parsimony::Float32
alpha::Float32
maxsize::Int
Expand Down Expand Up @@ -140,7 +141,7 @@ struct Options{LossType<:Union{SupervisedLoss,Function},ComplexityType,_prob_pic
nuna::Int
nbin::Int
seed::Union{Int,Nothing}
loss::LossType
loss::Union{SupervisedLoss,Function}
progress::Bool
terminal_width::Union{Int,Nothing}
optimizer_algorithm::String
Expand All @@ -157,6 +158,7 @@ struct Options{LossType<:Union{SupervisedLoss,Function},ComplexityType,_prob_pic
skip_mutation_failures::Bool
nested_constraints::Union{Vector{Tuple{Int,Int,Vector{Tuple{Int,Int,Int}}}},Nothing}
deterministic::Bool
define_helper_functions::Bool
end

function Base.print(io::IO, options::Options)
Expand Down
9 changes: 6 additions & 3 deletions src/Population.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,13 @@ end
function best_of_sample(
pop::Population{T},
running_search_statistics::RunningSearchStatistics,
options::Options{A,B,p,tournament_selection_n},
)::PopMember where {T<:Real,A,B,p,tournament_selection_n}
options::Options{CT},
)::PopMember where {T<:Real,CT}
sample = sample_pop(pop, options)

p = options.tournament_selection_p
tournament_selection_n = options.tournament_selection_n

if options.use_frequency_in_tournament
# Score based on frequency of that size occurring.
# In the end, all sizes should be just as common in the population.
Expand All @@ -92,7 +95,7 @@ function best_of_sample(
scores = Vector{T}(undef, tournament_selection_n)
for (i, member) in enumerate(sample.members)
size = compute_complexity(member.tree, options)
frequency = if (size <= options.maxsize)
frequency = if (0 < size <= options.maxsize)
running_search_statistics.normalized_frequencies[size]
else
T(0)
Expand Down
6 changes: 0 additions & 6 deletions src/ProgramConstants.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,4 @@ const BATCH_DIM = 2
const FEATURE_DIM = 1
const RecordType = Dict{String,Any}

"""Enum for concurrency type (to get function specialization)"""
abstract type SRConcurrency end
struct SRSerial <: SRConcurrency end
struct SRThreaded <: SRConcurrency end
struct SRDistributed <: SRConcurrency end

end
20 changes: 10 additions & 10 deletions src/SearchUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import Printf: @printf, @sprintf
using Distributed
import StatsBase: mean

import ..CoreModule: SRThreaded, SRSerial, SRDistributed, Dataset, Options
import ..CoreModule: Dataset, Options
import ..ComplexityModule: compute_complexity
import ..PopulationModule: Population, copy_population
import ..HallOfFameModule:
Expand All @@ -32,11 +32,11 @@ end

macro sr_spawner(parallel, p, expr)
quote
if $(esc(parallel)) == SRSerial
if $(esc(parallel)) == :serial
$(esc(expr))
elseif $(esc(parallel)) == SRDistributed
elseif $(esc(parallel)) == :multiprocessing
@spawnat($(esc(p)), $(esc(expr)))
elseif $(esc(parallel)) == SRThreaded
elseif $(esc(parallel)) == :multithreading
Threads.@spawn($(esc(expr)))
else
error("Invalid parallel type.")
Expand Down Expand Up @@ -197,8 +197,8 @@ function estimate_work_fraction(monitor::ResourceMonitor)::Float64
return mean(work_intervals) / (mean(work_intervals) + mean(rest_intervals))
end

function get_load_string(; head_node_occupation::Float64, ConcurrencyType=SRSerial)
ConcurrencyType == SRSerial && return ""
function get_load_string(; head_node_occupation::Float64, parallelism=:serial)
parallelism == :serial && return ""
out = @sprintf("Head worker occupation: %.1f%%", head_node_occupation * 100)

raise_usage_warning = head_node_occupation > 0.2
Expand All @@ -218,11 +218,11 @@ function update_progress_bar!(
dataset::Dataset{T},
options::Options,
head_node_occupation::Float64,
ConcurrencyType=SRSerial,
parallelism=:serial,
) where {T}
equation_strings = string_dominating_pareto_curve(hall_of_fame, dataset, options)
# TODO - include command about "q" here.
load_string = get_load_string(; head_node_occupation, ConcurrencyType)
load_string = get_load_string(; head_node_occupation, parallelism)
load_string *= @sprintf("Press 'q' and then <enter> to stop execution early.\n")
equation_strings = load_string * equation_strings
set_multiline_postfix!(progress_bar, equation_strings)
Expand All @@ -238,14 +238,14 @@ function print_search_state(
total_cycles::Int,
cycles_remaining::Vector{Int},
head_node_occupation::Float64,
ConcurrencyType=SRSerial,
parallelism=:serial,
) where {T}
nout = length(datasets)
average_speed = sum(equation_speed) / length(equation_speed)

@printf("\n")
@printf("Cycles per second: %.3e\n", round(average_speed, sigdigits=3))
load_string = get_load_string(; head_node_occupation, ConcurrencyType)
load_string = get_load_string(; head_node_occupation, parallelism)
print(load_string)
cycles_elapsed = total_cycles * nout - sum(cycles_remaining)
@printf(
Expand Down
2 changes: 1 addition & 1 deletion src/SingleIteration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ function s_r_cycle(
for member in pop.members
size = compute_complexity(member.tree, options)
score = member.score
if size <= options.maxsize && (
if 0 < size <= options.maxsize && (
!best_examples_seen.exists[size] ||
score < best_examples_seen.members[size].score
)
Expand Down

2 comments on commit a093714

@MilesCranmer
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/73009

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.14.5 -m "<description of version>" a093714e9416b9ccc0678a48c566948908a9708e
git push origin v0.14.5

Please sign in to comment.