Skip to content

Commit

Permalink
feat: pbv_topn_v
Browse files Browse the repository at this point in the history
  • Loading branch information
Yvictor committed May 29, 2024
1 parent cf3701c commit 8173ca5
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 10 deletions.
28 changes: 28 additions & 0 deletions polars_pbv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,5 +88,33 @@ def pbv_topn_vp(
"n": n,
"center_label": center,
"round": round,
"pct": False,
},
)

def pbv_topn_v(
    price: IntoExpr,
    volume: IntoExpr,
    window_size: int,
    bins: int,
    n: int,
    center: bool = True,
    round: int = -1,
    pct: bool = False,
) -> pl.Expr:
    """Build the expression that calls the ``pbv_topn_v`` plugin function.

    Both inputs are converted with ``parse_into_expr`` and handed to
    ``register_plugin`` as a non-elementwise expression.

    Parameters
    ----------
    price, volume
        Price and volume inputs (anything ``parse_into_expr`` accepts).
    window_size
        Forwarded as ``window_size``: number of rows in each rolling window.
    bins
        Forwarded as ``bins``: number of price bins per window.
    n
        Forwarded as ``n``: how many of the largest bin volumes to keep.
    center
        Forwarded as ``center_label``. It only selects how a bin's price
        label is derived and does not change the volumes that are returned.
    round
        Forwarded as ``round``: decimals to round to; a negative value
        disables rounding.
    pct
        Forwarded as ``pct``: when true, bin volumes are divided by the
        window's total volume.

    Returns
    -------
    pl.Expr
        The plugin expression registered under the symbol ``pbv_topn_v``.
    """
    price_expr = parse_into_expr(price)
    volume_expr = parse_into_expr(volume)
    # Keys must match the field names the plugin deserializes on the Rust side.
    plugin_kwargs = {
        "window_size": window_size,
        "bins": bins,
        "n": n,
        "center_label": center,
        "round": round,
        "pct": pct,
    }
    return register_plugin(
        args=[price_expr, volume_expr],
        symbol="pbv_topn_v",
        is_elementwise=False,
        lib=lib,
        kwargs=plugin_kwargs,
    )
100 changes: 99 additions & 1 deletion src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub struct PriceByVolumeTopNKwargs {
n: usize,
center_label: bool,
round: i32,
pct: bool,
}

// fn price_by_volume_dtype(input_fields: &[Field]) -> PolarsResult<Field> {
Expand Down Expand Up @@ -229,5 +230,102 @@ fn pbv_topn_vp(inputs: &[Series], kwargs: PriceByVolumeTopNKwargs) -> PolarsResu
// label.push(Some(price_label_s));
}
}
Ok(Series::new("pbv_topn", pbv_topn))
Ok(Series::new("pbv_topn_vp", pbv_topn))
}


/// Output-type function for `pbv_topn_v`.
///
/// Ignores the input fields and always reports a field named `pbv_topn_v` whose dtype is a
/// list with `Float64Type::get_dtype()` as its inner type.
fn price_by_volume_topn_volume_dtype(_input_fields: &[Field]) -> PolarsResult<Field> {
    let inner_dtype = Float64Type::get_dtype();
    let list_dtype = DataType::List(Box::new(inner_dtype));
    Ok(Field::new("pbv_topn_v", list_dtype))
}

/// Rolling price-by-volume profile that keeps the `n` largest bin volumes of every window.
///
/// `inputs[0]` is the price column and `inputs[1]` the volume column; both are converted
/// with `to_float`. For each row, the trailing `window_size` rows are split into
/// `kwargs.bins` equal-width price bins spanning the window's minimum and maximum price,
/// the volume falling into each bin is summed, and the `kwargs.n` largest sums are emitted
/// as one list, largest first.
///
/// Behaviour of the keyword arguments:
/// * `window_size` - rows per window; rows that do not yet have a full window yield null.
/// * `bins` - number of price bins. Every bin is half-open `[lower, upper)` except the
///   last one, which has no upper bound so the window's maximum price is included.
/// * `n` - number of top bins to keep (fewer are returned when `bins < n`).
/// * `pct` - when true, each bin volume is divided by the sum over all bins of the window.
/// * `round` - decimals used to round the reported values; negative disables rounding.
/// * `center_label` - part of the shared kwargs struct; it has no effect on the output here.
///
/// Ranking always uses the raw bin volumes, so rounding never changes the order.
///
/// A window that contains no non-null price (which includes `window_size == 0`) yields a
/// null entry instead of panicking.
///
/// # Errors
/// Propagates any `PolarsError` raised by the casts, comparisons, filters, aggregations
/// or rounding performed on the window.
#[polars_expr(output_type_func=price_by_volume_topn_volume_dtype)]
fn pbv_topn_v(inputs: &[Series], kwargs: PriceByVolumeTopNKwargs) -> PolarsResult<Series> {
    let price = &inputs[0].to_float()?;
    let volume = &inputs[1].to_float()?;
    let window_size = kwargs.window_size as usize;
    let mut pbv_topn = vec![];

    // `end` is the exclusive end row of the current window.
    for end in 1..=price.len() {
        if end < window_size {
            // Not enough history for a full window yet.
            pbv_topn.push(None);
            continue;
        }

        let start = (end - window_size) as i64;
        let window_price = price.slice(start, window_size);
        let window_volume = volume.slice(start, window_size);

        // `max`/`min` return `None` for an empty or all-null window; emit null for that
        // row rather than unwrapping and panicking.
        let max_price: Option<f64> = window_price.max()?;
        let min_price: Option<f64> = window_price.min()?;
        let (max_price, min_price) = match (max_price, min_price) {
            (Some(hi), Some(lo)) => (hi, lo),
            _ => {
                pbv_topn.push(None);
                continue;
            }
        };

        let interval = (max_price - min_price) / kwargs.bins as f64;

        let mut volume_at_price = vec![];
        for bin in 0..kwargs.bins {
            let lower_bound = min_price + bin as f64 * interval;
            let in_bin = if bin == kwargs.bins - 1 {
                // Last bin is open-ended so the maximum price is not lost to rounding.
                window_price.gt_eq(lower_bound)?
            } else {
                let upper_bound = min_price + (bin + 1) as f64 * interval;
                window_price.gt_eq(lower_bound)? & window_price.lt(upper_bound)?
            };
            let v: f64 = window_volume.filter(&in_bin)?.sum()?;
            volume_at_price.push(v);
        }

        let pbv_s = Series::new("volume", &volume_at_price);

        // Values to report: optionally normalised by the window total, optionally rounded.
        let scaled = if kwargs.pct {
            let total_v: f64 = pbv_s.sum()?;
            pbv_s.clone() / total_v
        } else {
            pbv_s.clone()
        };
        let reported_s = if kwargs.round < 0 {
            scaled
        } else {
            scaled.round(kwargs.round as u32)?
        };
        let reported = reported_s.f64()?;

        // Rank bins by raw volume, largest first, and keep the first `n` positions.
        let order = pbv_s.arg_sort(SortOptions::default().with_order_descending(true));
        let pbv_topn_s = order
            .slice(0, kwargs.n)
            .iter()
            .map(|opt_idx| opt_idx.and_then(|idx| reported.get(idx as usize)))
            .collect::<Vec<Option<f64>>>();

        pbv_topn.push(Some(Series::new("pbv_topn", &pbv_topn_s)));
    }
    Ok(Series::new("pbv_topn_v", pbv_topn))
}
51 changes: 42 additions & 9 deletions tests/test_pbv.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import polars as pl
from polars_pbv import pbv, pbv_pct, pbv_topn_vp
from polars_pbv import pbv, pbv_pct, pbv_topn_vp, pbv_topn_v


def test_pbv():
Expand Down Expand Up @@ -88,7 +88,6 @@ def test_pbv_pct():
assert result.equals(expected_df)



def test_pbv_top_vp():
price_col = [100, 101, 102, 103, 104, 105, 106]
volume_col = [200, 220, 250, 240, 260, 300, 280]
Expand All @@ -97,19 +96,53 @@ def test_pbv_top_vp():
n = 2

df = pl.DataFrame({"price": price_col, "volume": volume_col})
expected_df = pl.DataFrame({
"pbv_top_vp": [
None, None, None, None, None,
[104.17, 102.5], [105.17, 103.5]
]
})
expected_df = pl.DataFrame(
{"pbv_top_vp": [None, None, None, None, None, [104.17, 102.5], [105.17, 103.5]]}
)

result_df = df.select(
pbv_topn_vp(
"price", "volume", window_size=window_size, bins=bins, n=n, center=True, round=2
"price",
"volume",
window_size=window_size,
bins=bins,
n=n,
center=True,
round=2,
).alias("pbv_top_vp")
)

print(expected_df)
print(result_df)
assert result_df.equals(expected_df)


# pbv_topn_v test case
def test_pbv_top_v():
    """Top-2 raw bin volumes over a 6-row window split into 3 price bins.

    The first five rows have no full window and are expected to be null; the
    last two rows hold the two largest bin volumes in descending order.
    """
    df = pl.DataFrame(
        {
            "price": [100, 101, 102, 103, 104, 105, 106],
            "volume": [200, 220, 250, 240, 260, 300, 280],
        }
    )

    # Five leading nulls (incomplete windows) followed by the two full windows.
    expected_rows = [None] * 5 + [[560.0, 490.0], [580.0, 500.0]]
    expected_df = pl.DataFrame({"pbv_top_v": expected_rows})

    top_volumes = pbv_topn_v(
        "price",
        "volume",
        window_size=6,
        bins=3,
        n=2,
        center=False,
        round=-1,
        pct=False,
    )
    result_df = df.select(top_volumes.alias("pbv_top_v"))

    print(expected_df)
    print(result_df)
    assert result_df.equals(expected_df)

0 comments on commit 8173ca5

Please sign in to comment.