Skip to content

Commit

Permalink
feat(python, rust): Added pickle support.
Browse files Browse the repository at this point in the history
chore(python): Fixed formatting.

chore: Added changelog for rust package
  • Loading branch information
schmidmt committed Jul 25, 2023
1 parent fa284c8 commit 50908d9
Show file tree
Hide file tree
Showing 8 changed files with 169 additions and 12 deletions.
12 changes: 12 additions & 0 deletions changepoint/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Changelog

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.3.1] -

### Added

- (#7) Added `PartialEq` derive macro to `Bocpd` and `Argpcpd` structs.
2 changes: 1 addition & 1 deletion changepoint/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "changepoint"
version = "0.14.0"
version = "0.14.1"
authors = [
"Mike Schmidt <[email protected]>",
"Baxter Eaves <[email protected]>",
Expand Down
2 changes: 1 addition & 1 deletion changepoint/src/bocpd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use serde::{Deserialize, Serialize};

/// Online Bayesian Change Point Detection state container.
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
pub struct Bocpd<X, Fx, Pr>
where
Fx: Rv<X> + HasSuffStat<X>,
Expand Down
8 changes: 7 additions & 1 deletion pychangepoint/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

## v0.3.1
All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.3.1] -
- Added Pickle support for `ArgpCpd` and `Bocpd`
- Fix type error that occurred converting numpy `bool_` and `int` types
- Added default arguments `NormalInvChiSquared` constructor (`m=0, k=1, v=1, s2=1`)
4 changes: 3 additions & 1 deletion pychangepoint/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ name = "changepoint"
crate-type = ["cdylib"]

[dependencies]
changepoint = { path = "../changepoint" }
bincode = "1.3.3"
changepoint = { path = "../changepoint", features = ["serde1"] }
nalgebra = { version = "0.32" }
numpy = "0.19"
pyo3 = { version ="0.19", features = ["extension-module"] }
rand = { version = "0.8", features = ["small_rng"] }
rv = { version = "0.16", features = ["process", "arraydist"] }
serde = "1.0.175"

[features]
extension-module = ["pyo3/extension-module"]
Expand Down
34 changes: 33 additions & 1 deletion pychangepoint/src/argpcpd.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use bincode::{deserialize, serialize};
use changepoint::gp::Argpcp;
use changepoint::BocpdLike;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::pyclass::CompareOp;
use pyo3::types::PyBytes;
use rv::process::gaussian::kernel::{
AddKernel, ConstantKernel, ProductKernel, RBFKernel, WhiteKernel,
};
Expand Down Expand Up @@ -30,7 +33,7 @@ use rv::process::gaussian::kernel::{
/// Roughly the slope of the logistic hazard function
/// logistic_hazard_b: float
/// The offset of the logistic hazard function.
#[pyclass]
#[pyclass(module = "changepoint")]
pub struct ArgpCpd {
argpcpd: Argpcp<
AddKernel<ProductKernel<ConstantKernel, RBFKernel>, WhiteKernel>,
Expand Down Expand Up @@ -105,4 +108,33 @@ impl ArgpCpd {
pub fn step(&mut self, datum: f64) -> Vec<f64> {
self.argpcpd.step(&datum).to_vec()
}

pub fn __setstate__(
&mut self,
py: Python,
state: PyObject,
) -> PyResult<()> {
match state.extract::<&PyBytes>(py) {
Ok(s) => {
self.argpcpd = deserialize(s.as_bytes()).unwrap();
Ok(())
}
Err(e) => Err(e),
}
}

pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
Ok(PyBytes::new(py, &serialize(&self.argpcpd).unwrap()).to_object(py))
}

fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult<bool> {
match op {
CompareOp::Lt => Ok(false),
CompareOp::Le => Ok(false),
CompareOp::Eq => Ok(self.argpcpd == other.argpcpd),
CompareOp::Ne => Ok(self.argpcpd != other.argpcpd),
CompareOp::Gt => Ok(false),
CompareOp::Ge => Ok(false),
}
}
}
103 changes: 96 additions & 7 deletions pychangepoint/src/bocpd.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
use crate::convert;
use changepoint::BocpdLike;
use nalgebra::DVector;
use pyo3::exceptions::PyValueError;
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::pyclass::CompareOp;
use pyo3::types::PyTuple;
use rv::dist::{
Bernoulli, Beta, Gamma, Gaussian, MvGaussian, NormalGamma,
NormalInvChiSquared, NormalInvGamma, NormalInvWishart, Poisson,
};

use bincode::{deserialize, serialize};
use serde::{Deserialize, Serialize};

/// The variant of the prior distribution
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum PriorVariant {
NormalGamma(NormalGamma),
NormalInvGamma(NormalInvGamma),
Expand All @@ -21,14 +26,61 @@ pub enum PriorVariant {

/// Prior distribution, which also describes the liklihood distribution of the
/// change point detector.
#[pyclass]
#[derive(Clone, Debug)]
#[pyclass(module = "changepoint")]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Prior {
pub dist: PriorVariant,
}

macro_rules! count {
() => (0usize);
( $x:tt $($xs:tt)* ) => (1usize + count!($($xs)*));
}

macro_rules! handle_kind {
($name: ident, $args: ident, $($idx:tt $n:tt),+) => {{
if $args.len() == count!($($idx)*) {
Self::$name(
$(
$args.get_item($idx)?.extract()?,
)+
)
} else {
Err(PyTypeError::new_err(format!("Prior kind '{}' requires the following arguments: {}", stringify!($name), stringify!($($n),+))))
}
}}
}

#[pymethods]
impl Prior {
#[new]
#[pyo3(signature = (kind, *args))]
pub fn new(kind: &str, args: &PyTuple) -> PyResult<Self> {
match kind {
"normal_gamma" => {
handle_kind!(normal_gamma, args, 0 m, 1 r, 2 s, 3 v)
}
"normal_inv_gamma" => {
handle_kind!(normal_inv_gamma, args, 0 m, 1 v, 2 a, 3 b)
}
"normal_inv_chi_squared" => {
handle_kind!(normal_inv_chi_squared, args, 0 m, 1 k, 2 v, 3 s2)
}
"normal_inv_wishart" => {
handle_kind!(normal_inv_wishart, args, 0 m, 1 k, 2 df, 3 scale)
}
"beta_bernoulli" => {
handle_kind!(beta_bernoulli, args, 0 alpha, 1 beta)
}
"poisson_gamma" => {
handle_kind!(poisson_gamma, args, 0 shape, 1 rate)
}
unknown_kind => Err(PyTypeError::new_err(format!(
"Unknown prior kind '{unknown_kind}'"
))),
}
}

#[staticmethod]
#[pyo3(signature = (m = 0.0, r = 1.0, s = 1.0, v = 1.0))]
pub fn normal_gamma(m: f64, r: f64, s: f64, v: f64) -> PyResult<Self> {
Expand Down Expand Up @@ -143,6 +195,19 @@ impl Prior {
),
}
}

pub fn __setstate__(&mut self, state: Vec<u8>) -> PyResult<()> {
self.dist = deserialize(&state).unwrap();
Ok(())
}

pub fn __getstate__(&self) -> PyResult<Vec<u8>> {
Ok(serialize(&self.dist).unwrap())
}

pub fn __getnewargs__(&self) -> PyResult<(String, f64, f64, f64, f64)> {
Ok(("normal_gamma".to_string(), 0.0, 1.0, 1.0, 1.0))
}
}

/// Normal Gamma prior on univariate Normal random variable.
Expand Down Expand Up @@ -321,7 +386,7 @@ fn dist_to_bocpd(dist: Prior, lambda: f64) -> BocpdVariant {

/// The variant of the `Bocpd`. Describes the prior, likelihood, and the input
/// data type.
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum BocpdVariant {
NormalGamma(changepoint::Bocpd<f64, Gaussian, NormalGamma>),
NormalInvGamma(changepoint::Bocpd<f64, Gaussian, NormalInvGamma>),
Expand Down Expand Up @@ -379,8 +444,8 @@ impl BocpdVariant {
}

/// Online Bayesian Change Point Detection state container
#[derive(Clone, Debug)]
#[pyclass]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[pyclass(module = "changepoint")]
/// Create a new BOCPD
///
/// Parameters
Expand Down Expand Up @@ -423,4 +488,28 @@ impl Bocpd {
pub fn step(&mut self, datum: &PyAny) -> PyResult<Vec<f64>> {
self.bocpd.step(datum).map(|rs| rs.to_vec())
}

pub fn __setstate__(&mut self, state: Vec<u8>) -> PyResult<()> {
self.bocpd = deserialize(&state).unwrap();
Ok(())
}

pub fn __getstate__(&self) -> PyResult<Vec<u8>> {
Ok(serialize(&self.bocpd).unwrap())
}

pub fn __getnewargs__(&self) -> PyResult<(Prior, f64)> {
Ok((Prior::beta_bernoulli(0.5, 0.5).unwrap(), 1.0))
}

fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult<bool> {
match op {
CompareOp::Lt => Ok(false),
CompareOp::Le => Ok(false),
CompareOp::Eq => Ok(self.bocpd == other.bocpd),
CompareOp::Ne => Ok(self.bocpd != other.bocpd),
CompareOp::Gt => Ok(false),
CompareOp::Ge => Ok(false),
}
}
}
16 changes: 16 additions & 0 deletions pychangepoint/tests/test_pickle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from changepoint import Bocpd, ArgpCpd, NormalGamma
import pickle


def test_pickle_bocpd():
cpd = Bocpd(NormalGamma(), 12)
s = pickle.dumps(cpd)
cpd_b = pickle.loads(s)
assert cpd == cpd_b


def test_pickle_argpcpd():
cpd = ArgpCpd(logistic_hazard_h=-2, scale=3, noise_level=0.01)
s = pickle.dumps(cpd)
cpd_b = pickle.loads(s)
assert cpd == cpd_b

0 comments on commit 50908d9

Please sign in to comment.