Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gen threshold fixed #33

Merged
merged 1 commit into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "perpetual"
version = "0.7.7"
version = "0.7.8"
edition = "2021"
authors = ["Mutlu Simsek <[email protected]>"]
homepage = "https://perpetual-ml.com"
Expand Down
4 changes: 2 additions & 2 deletions python-package/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "py-perpetual"
version = "0.7.7"
version = "0.7.8"
edition = "2021"
authors = ["Mutlu Simsek <[email protected]>"]
homepage = "https://perpetual-ml.com"
Expand All @@ -19,7 +19,7 @@ crate-type = ["cdylib", "rlib"]

[dependencies]
pyo3 = { version = "0.22.6", features = ["extension-module"] }
perpetual_rs = {package="perpetual", version = "0.7.7", path = "../" }
perpetual_rs = {package="perpetual", version = "0.7.8", path = "../" }
numpy = "0.22.1"
ndarray = "0.16.1"
serde_plain = { version = "1.0" }
Expand Down
2 changes: 2 additions & 0 deletions python-package/examples/benchmark_perpetual.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, log_loss
from sklearn.datasets import fetch_covtype, fetch_california_housing
from importlib.metadata import version


def prepare_data(cal_housing, seed):
Expand All @@ -24,6 +25,7 @@ def prepare_data(cal_housing, seed):


if __name__ == "__main__":
print(f"perpetual: {version('perpetual')}")
budget = 1.0
num_threads = 2
cal_housing = True # True -> California Housing, False -> Cover Types
Expand Down
2 changes: 1 addition & 1 deletion python-package/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "maturin"

[project]
name = "perpetual"
version = "0.7.7"
version = "0.7.8"
description = "A self-generalizing gradient boosting machine which doesn't need hyperparameter optimization"
license = { file = "LICENSE" }
keywords = [
Expand Down
6 changes: 3 additions & 3 deletions src/booster.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::bin::Bin;
use crate::binning::bin_matrix;
use crate::constants::{
FREE_MEM_ALLOC_FACTOR, GENERALIZATION_THRESHOLD, ITER_LIMIT, MIN_COL_AMOUNT, N_NODES_ALLOC_LIMIT, STOPPING_ROUNDS,
TIMEOUT_FACTOR,
FREE_MEM_ALLOC_FACTOR, GENERALIZATION_THRESHOLD_FLEX, ITER_LIMIT, MIN_COL_AMOUNT, N_NODES_ALLOC_LIMIT,
STOPPING_ROUNDS, TIMEOUT_FACTOR,
};
use crate::constraints::ConstraintMap;
use crate::data::Matrix;
Expand Down Expand Up @@ -524,7 +524,7 @@ impl PerpetualBooster {
.map(|n| n.generalization.unwrap_or(0.0))
.max_by(|a, b| a.total_cmp(b))
.unwrap_or(0.0);
if generalization < GENERALIZATION_THRESHOLD && tree.stopper != TreeStopper::LossDecrement {
if generalization < GENERALIZATION_THRESHOLD_FLEX && tree.stopper != TreeStopper::LossDecrement {
stopping += 1;
// If root node cannot be split due to no positive split gain, stop boosting.
if tree.nodes.len() == 1 {
Expand Down
3 changes: 2 additions & 1 deletion src/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ pub const STOPPING_ROUNDS: usize = 3;
pub const FREE_MEM_ALLOC_FACTOR: f32 = 0.9;
pub const N_NODES_ALLOC_LIMIT: usize = 3000;
pub const ITER_LIMIT: usize = 1000;
pub const GENERALIZATION_THRESHOLD: f32 = 0.99;
pub const GENERALIZATION_THRESHOLD: f32 = 1.0;
pub const GENERALIZATION_THRESHOLD_FLEX: f32 = 0.99;
pub const MIN_COL_AMOUNT: usize = 40;
pub const HESSIAN_EPS: f32 = 1e-3;
pub const TIMEOUT_FACTOR: f32 = 0.95;
Loading