From cac043ebbffe367c5c0e0a76708638abc69e59b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?mutlu=20=C5=9Fim=C5=9Fek?= Date: Thu, 28 Nov 2024 18:58:18 +0300 Subject: [PATCH] gen threshold fixed --- Cargo.toml | 2 +- python-package/Cargo.toml | 4 ++-- python-package/examples/benchmark_perpetual.py | 2 ++ python-package/pyproject.toml | 2 +- src/booster.rs | 6 +++--- src/constants.rs | 3 ++- 6 files changed, 11 insertions(+), 8 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6ba8f23..ad10ab9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "perpetual" -version = "0.7.7" +version = "0.7.8" edition = "2021" authors = ["Mutlu Simsek "] homepage = "https://perpetual-ml.com" diff --git a/python-package/Cargo.toml b/python-package/Cargo.toml index 459420d..8125ef7 100644 --- a/python-package/Cargo.toml +++ b/python-package/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "py-perpetual" -version = "0.7.7" +version = "0.7.8" edition = "2021" authors = ["Mutlu Simsek "] homepage = "https://perpetual-ml.com" @@ -19,7 +19,7 @@ crate-type = ["cdylib", "rlib"] [dependencies] pyo3 = { version = "0.22.6", features = ["extension-module"] } -perpetual_rs = {package="perpetual", version = "0.7.7", path = "../" } +perpetual_rs = {package="perpetual", version = "0.7.8", path = "../" } numpy = "0.22.1" ndarray = "0.16.1" serde_plain = { version = "1.0" } diff --git a/python-package/examples/benchmark_perpetual.py b/python-package/examples/benchmark_perpetual.py index f69146f..4ac1856 100644 --- a/python-package/examples/benchmark_perpetual.py +++ b/python-package/examples/benchmark_perpetual.py @@ -4,6 +4,7 @@ from sklearn.model_selection import train_test_split from sklearn.metrics import mean_squared_error, log_loss from sklearn.datasets import fetch_covtype, fetch_california_housing +from importlib.metadata import version def prepare_data(cal_housing, seed): @@ -24,6 +25,7 @@ def prepare_data(cal_housing, seed): if __name__ == "__main__": + print(f"perpetual: {version('perpetual')}") budget = 1.0 num_threads = 2 cal_housing = True # True -> California Housing, False -> Cover Types diff --git a/python-package/pyproject.toml b/python-package/pyproject.toml index 641747b..e6b0a43 100644 --- a/python-package/pyproject.toml +++ b/python-package/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "maturin" [project] name = "perpetual" -version = "0.7.7" +version = "0.7.8" description = "A self-generalizing gradient boosting machine which doesn't need hyperparameter optimization" license = { file = "LICENSE" } keywords = [ diff --git a/src/booster.rs b/src/booster.rs index 703ee94..274ae4d 100644 --- a/src/booster.rs +++ b/src/booster.rs @@ -1,8 +1,8 @@ use crate::bin::Bin; use crate::binning::bin_matrix; use crate::constants::{ - FREE_MEM_ALLOC_FACTOR, GENERALIZATION_THRESHOLD, ITER_LIMIT, MIN_COL_AMOUNT, N_NODES_ALLOC_LIMIT, STOPPING_ROUNDS, - TIMEOUT_FACTOR, + FREE_MEM_ALLOC_FACTOR, GENERALIZATION_THRESHOLD_FLEX, ITER_LIMIT, MIN_COL_AMOUNT, N_NODES_ALLOC_LIMIT, + STOPPING_ROUNDS, TIMEOUT_FACTOR, }; use crate::constraints::ConstraintMap; use crate::data::Matrix; @@ -524,7 +524,7 @@ impl PerpetualBooster { .map(|n| n.generalization.unwrap_or(0.0)) .max_by(|a, b| a.total_cmp(b)) .unwrap_or(0.0); - if generalization < GENERALIZATION_THRESHOLD && tree.stopper != TreeStopper::LossDecrement { + if generalization < GENERALIZATION_THRESHOLD_FLEX && tree.stopper != TreeStopper::LossDecrement { stopping += 1; // If root node cannot be split due to no positive split gain, stop boosting. if tree.nodes.len() == 1 { diff --git a/src/constants.rs b/src/constants.rs index c224d19..285eb1c 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -2,7 +2,8 @@ pub const STOPPING_ROUNDS: usize = 3; pub const FREE_MEM_ALLOC_FACTOR: f32 = 0.9; pub const N_NODES_ALLOC_LIMIT: usize = 3000; pub const ITER_LIMIT: usize = 1000; -pub const GENERALIZATION_THRESHOLD: f32 = 0.99; +pub const GENERALIZATION_THRESHOLD: f32 = 1.0; +pub const GENERALIZATION_THRESHOLD_FLEX: f32 = 0.99; pub const MIN_COL_AMOUNT: usize = 40; pub const HESSIAN_EPS: f32 = 1e-3; pub const TIMEOUT_FACTOR: f32 = 0.95;