From e99bd4a81a48aed04d4d48e54845f90d242222fe Mon Sep 17 00:00:00 2001 From: Frank Sauerburger <frank@sauerburger.com> Date: Mon, 16 Sep 2024 22:30:56 +0200 Subject: [PATCH] Add frac and points argument --- .gitlab-ci.yml | 4 ++-- Cargo.lock | 2 +- Cargo.toml | 2 +- polars_loess/__init__.py | 5 +++-- pyproject.toml | 2 +- run.py | 2 +- src/expressions.rs | 15 +++++++++++++-- 7 files changed, 22 insertions(+), 10 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 3cb51e8..32e4b5c 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -8,8 +8,8 @@ stages: name: ghcr.io/pyo3/maturin:v1.7.1 entrypoint: [""] script: - - sed -i "s/##VERSION##/$CI_COMMIT_TAG/g" Cargo.toml - - sed -i "s/##VERSION##/$CI_COMMIT_TAG/g" pyproject.toml + - sed -i "s/0.0.0/$CI_COMMIT_TAG/g" Cargo.toml + - sed -i "s/0.0.0/$CI_COMMIT_TAG/g" pyproject.toml - maturin publish rules: - if: $CI_COMMIT_TAG diff --git a/Cargo.lock b/Cargo.lock index 2715b38..79de3f0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -676,7 +676,7 @@ dependencies = [ [[package]] name = "polars-loess" -version = "0.1.0" +version = "0.0.0" dependencies = [ "jemallocator", "nalgebra", diff --git a/Cargo.toml b/Cargo.toml index cf049f3..3684764 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "polars-loess" -version = "##VERSION##" +version = "0.0.0" edition = "2021" [lib] diff --git a/polars_loess/__init__.py b/polars_loess/__init__.py index fc6f767..7f82cd5 100644 --- a/polars_loess/__init__.py +++ b/polars_loess/__init__.py @@ -19,11 +19,12 @@ else: lib = Path(__file__).parent -def loess(expr: IntoExpr, expr2: IntoExpr, expr3: IntoExpr) -> pl.Expr: +def loess(x: IntoExpr, y: IntoExpr, newx: IntoExpr, *, frac: float = None, points: int = None) -> pl.Expr: return register_plugin( - args=[expr, expr2, expr3], + args=[x, y, newx], symbol="loess", is_elementwise=True, lib=lib, + kwargs={"frac": frac, "points": points}, ) diff --git a/pyproject.toml b/pyproject.toml index 2a90baa..53e5f26 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = ["maturin>=1.0,<2.0", "polars>=0.20.6"] build-backend = "maturin" readme = "README.md" license = "MIT" -version = "##VERSION##" +version = "0.0.0" [project] name = "polars-loess" diff --git a/run.py b/run.py index 94a15b6..0025862 100644 --- a/run.py +++ b/run.py @@ -15,6 +15,6 @@ df = pl.DataFrame({ 160.78742, 168.55567, 152.42658, 221.70702, 222.69040, 243.18828] }) -result = df.with_columns(loess = loess('time', 'price', 'time')) +result = df.with_columns(loess = loess('time', 'price', 'time', frac=0.5)) print(result) diff --git a/src/expressions.rs b/src/expressions.rs index fc31cb8..3c3a627 100644 --- a/src/expressions.rs +++ b/src/expressions.rs @@ -2,6 +2,7 @@ use polars::prelude::*; use pyo3_polars::derive::polars_expr; use nalgebra::{convert, DMatrix, DVector}; +use serde::Deserialize; fn select_indices(values: &DVector<f64>, indices: &DVector<usize>) -> DVector<f64> { @@ -165,20 +166,30 @@ impl Loess { self.denormalize_y(y) } } +#[derive(Deserialize)] +struct LoessParams { + frac: Option<f64>, + points: Option<usize>, +} #[polars_expr(output_type=Float64)] -fn loess(inputs: &[Series]) -> PolarsResult<Series> { +fn loess(inputs: &[Series], kwargs: LoessParams) -> PolarsResult<Series> { let x: &Float64Chunked = inputs[0].f64()?; let y: &Float64Chunked = inputs[1].f64()?; let newx: &Float64Chunked = inputs[2].f64()?; + let window = match (kwargs.frac, kwargs.points) { + (Some(frac), None) => (frac * x.len() as f64).round() as usize, + (None, Some(points)) => points, + _ => polars_bail!(InvalidOperation:format!("Either `frac` or `points` must be provided")), + }; let xx = DVector::<f64>::from_iterator(x.len(), x.to_vec().iter().filter(|v| v.is_some()).map(|v| v.unwrap())); let yy = DVector::<f64>::from_iterator(y.len(), y.to_vec().iter().filter(|v| v.is_some()).map(|v| v.unwrap())); let loess: Loess = Loess::new(&xx, &yy); - let result = newx.apply(|optv| optv.map(|v| loess.estimate(v, 7, true, 1))); + let result = newx.apply(|optv| optv.map(|v| loess.estimate(v, window, true, 1))); Ok(result.into_series()) } -- GitLab