Skip to content

Commit

Permalink
Merge pull request #7 from promised-ai/feature/pickle
Browse files Browse the repository at this point in the history
Pickle Support
  • Loading branch information
schmidmt authored Jul 26, 2023
2 parents 1fa82c9 + 50908d9 commit 84c77a2
Show file tree
Hide file tree
Showing 17 changed files with 376 additions and 82 deletions.
12 changes: 12 additions & 0 deletions changepoint/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Changelog

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.3.1] -

### Added

- (#7) Added `PartialEq` derive macro to `Bocpd` and `Argpcpd` structs.
2 changes: 1 addition & 1 deletion changepoint/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "changepoint"
version = "0.14.0"
version = "0.14.1"
authors = [
"Mike Schmidt <[email protected]>",
"Baxter Eaves <[email protected]>",
Expand Down
2 changes: 1 addition & 1 deletion changepoint/src/bocpd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use serde::{Deserialize, Serialize};

/// Online Bayesian Change Point Detection state container.
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
pub struct Bocpd<X, Fx, Pr>
where
Fx: Rv<X> + HasSuffStat<X>,
Expand Down
11 changes: 11 additions & 0 deletions pychangepoint/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Changelog

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.3.1] -
- Added Pickle support for `ArgpCpd` and `Bocpd`
- Fix type error that occurred converting numpy `bool_` and `int` types
- Added default arguments `NormalInvChiSquared` constructor (`m=0, k=1, v=1, s2=1`)
6 changes: 4 additions & 2 deletions pychangepoint/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "pychangepoint"
description = "A Python module for changepoint"
version = "0.3.0"
version = "0.3.1"
repository = "https://github.com/promised-ai/changepoint"
authors = ["Redpoll <[email protected]>"]
edition = "2021"
Expand All @@ -16,12 +16,14 @@ name = "changepoint"
crate-type = ["cdylib"]

[dependencies]
changepoint = { path = "../changepoint" }
bincode = "1.3.3"
changepoint = { path = "../changepoint", features = ["serde1"] }
nalgebra = { version = "0.32" }
numpy = "0.19"
pyo3 = { version ="0.19", features = ["extension-module"] }
rand = { version = "0.8", features = ["small_rng"] }
rv = { version = "0.16", features = ["process", "arraydist"] }
serde = "1.0.175"

[features]
extension-module = ["pyo3/extension-module"]
Expand Down
21 changes: 16 additions & 5 deletions pychangepoint/ChangePointExample.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,19 @@
"\n",
" data_lower, data_upper = 0.975 * np.min(data), 1.05 * np.max(data)\n",
"\n",
" regime_bot, regime_top = data_lower - 0.1 * (data_upper - data_lower), data_lower\n",
" regime_bot, regime_top = (\n",
" data_lower - 0.1 * (data_upper - data_lower),\n",
" data_lower,\n",
" )\n",
" data_ax.set_ylim(regime_bot, data_upper)\n",
"\n",
" map_change_points = np.hstack([chp.map_changepoints(change_point_history), n])\n",
" map_change_points = np.hstack(\n",
" [chp.map_changepoints(change_point_history), n]\n",
" )\n",
"\n",
" for i, (a, b) in enumerate(zip(map_change_points[:-1], map_change_points[1:])):\n",
" for i, (a, b) in enumerate(\n",
" zip(map_change_points[:-1], map_change_points[1:])\n",
" ):\n",
" data_ax.fill_between(\n",
" [a, b],\n",
" [regime_bot, regime_bot],\n",
Expand Down Expand Up @@ -496,7 +503,9 @@
}
],
"source": [
"_, (data_ax, _data_ax), change_points = change_point_plot(cpi, change_point_history)\n",
"_, (data_ax, _data_ax), change_points = change_point_plot(\n",
" cpi, change_point_history\n",
")\n",
"\n",
"data_ax.set_xticks([5 * 12 * i for i in range(cpi.shape[0] // (5 * 12))])\n",
"data_ax.set_xticklabels([cpi.index[i][:4] for i in data_ax.get_xticks()])\n",
Expand Down Expand Up @@ -826,7 +835,9 @@
}
],
"source": [
"_, (data_ax, _data_ax), change_points = change_point_plot(cpi, change_point_history)\n",
"_, (data_ax, _data_ax), change_points = change_point_plot(\n",
" cpi, change_point_history\n",
")\n",
"\n",
"data_ax.set_xticks([5 * 12 * i for i in range(cpi.shape[0] // (5 * 12))])\n",
"data_ax.set_xticklabels([cpi.index[i][:4] for i in data_ax.get_xticks()])\n",
Expand Down
24 changes: 12 additions & 12 deletions pychangepoint/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,26 @@ See [rustup.rs](https://rustup.rs/) for instructions on installing Rust.
## Quick Docs
By convention in these docs and examples,
```python
import changepoint as chp
import changepoint as cpt
```

### Models
#### Bocpd

The Bayesian change point detector, `Bocpd`, takes a prior distribution, aka one of
```python,ignore
chp.BetaBernoulli
chp.NormalGamma
chp.NormalInvChiSquared
chp.NormalInvGamma
chp.NormalInvWishart
chp.PoissonGamma
cpt.BetaBernoulli
cpt.NormalGamma
cpt.NormalInvChiSquared
cpt.NormalInvGamma
cpt.NormalInvWishart
cpt.PoissonGamma
```

Then, a `Bocpd` may be created:
```python
cpd = chp.Bocpd(
prior=chp.NormalGamma(),
cpd = cpt.Bocpd(
prior=cpt.NormalGamma(),
lam=12,
)
```
Expand All @@ -58,7 +58,7 @@ change_point_history = np.zeros((n, n))
for i, x in enumerate(data):
change_point_history[i, : i + 1] = cpd.step(x)

print(chp.map_changepoints(change_point_history))
print(cpt.map_changepoints(change_point_history))
```


Expand All @@ -69,7 +69,7 @@ where `c` is the scale, `X_{i-l-1, ..., i-1}` is the previous vales in the seque
It behaves similarity to the `Bocpd` class; for example,

```python
argp = chp.ArgpCpd(logistic_hazard_h=-2, scale=3, noise_level=0.01)
argp = cpt.ArgpCpd(logistic_hazard_h=-2, scale=3, noise_level=0.01)
n = len(data)
change_point_history = np.zeros((n + 1, n + 1))
xs = []
Expand All @@ -78,7 +78,7 @@ for i, x in enumerate(data):
cps = argp.step(x)
change_point_history[i, : len(cps)] = cps

print(chp.map_changepoints(change_point_history))
print(cpt.map_changepoints(change_point_history))
```

## Example
Expand Down
5 changes: 4 additions & 1 deletion pychangepoint/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "maturin"

[project]
name = "changepoint"
version = "0.3.0"
version = "0.3.1"
authors = [
{name = "Redpoll", email = "[email protected]" },
]
Expand All @@ -28,3 +28,6 @@ Repository = "https://github.com/promised-ai/changepoint"

[tool.maturin]
bindings = "pyo3"

[tool.black]
line-length = 80
34 changes: 33 additions & 1 deletion pychangepoint/src/argpcpd.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use bincode::{deserialize, serialize};
use changepoint::gp::Argpcp;
use changepoint::BocpdLike;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::pyclass::CompareOp;
use pyo3::types::PyBytes;
use rv::process::gaussian::kernel::{
AddKernel, ConstantKernel, ProductKernel, RBFKernel, WhiteKernel,
};
Expand Down Expand Up @@ -30,7 +33,7 @@ use rv::process::gaussian::kernel::{
/// Roughly the slope of the logistic hazard function
/// logistic_hazard_b: float
/// The offset of the logistic hazard function.
#[pyclass]
#[pyclass(module = "changepoint")]
pub struct ArgpCpd {
argpcpd: Argpcp<
AddKernel<ProductKernel<ConstantKernel, RBFKernel>, WhiteKernel>,
Expand Down Expand Up @@ -105,4 +108,33 @@ impl ArgpCpd {
pub fn step(&mut self, datum: f64) -> Vec<f64> {
self.argpcpd.step(&datum).to_vec()
}

pub fn __setstate__(
&mut self,
py: Python,
state: PyObject,
) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
Ok(s) => {
self.argpcpd = deserialize(s.as_bytes()).unwrap();
Ok(())
}
Err(e) => Err(e),
}
}

pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
Ok(PyBytes::new(py, &serialize(&self.argpcpd).unwrap()).to_object(py))
}

fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult<bool> {
match op {
CompareOp::Lt => Ok(false),
CompareOp::Le => Ok(false),
CompareOp::Eq => Ok(self.argpcpd == other.argpcpd),
CompareOp::Ne => Ok(self.argpcpd != other.argpcpd),
CompareOp::Gt => Ok(false),
CompareOp::Ge => Ok(false),
}
}
}
Loading

0 comments on commit 84c77a2

Please sign in to comment.