Replace the `##VERSION##` placeholders with a real `0.0.0` version that CI
rewrites to the release tag, and let callers of `loess` choose the smoothing
window through the `frac` / `points` keyword arguments.

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 3cb51e8afbca19139b9a669e411c1bd929733ddb..32e4b5c0d7093488587ff93bc4fe9e587635a584 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -8,8 +8,10 @@ stages:
     name: ghcr.io/pyo3/maturin:v1.7.1
     entrypoint: [""]
   script:
-    - sed -i "s/##VERSION##/$CI_COMMIT_TAG/g" Cargo.toml
-    - sed -i "s/##VERSION##/$CI_COMMIT_TAG/g" pyproject.toml
+    # Anchor on the whole `version = "0.0.0"` line: in a bare `0.0.0` pattern the
+    # dots match any character, and `g` would rewrite every such occurrence.
+    - sed -i "s/^version = \"0\.0\.0\"/version = \"$CI_COMMIT_TAG\"/" Cargo.toml
+    - sed -i "s/^version = \"0\.0\.0\"/version = \"$CI_COMMIT_TAG\"/" pyproject.toml
     - maturin publish
   rules:
     - if: $CI_COMMIT_TAG
diff --git a/Cargo.lock b/Cargo.lock
index 2715b3819dabe0040db206d169de5465b1adf318..79de3f043930a9e61507c2ca63014f9a5a8d8cee 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -676,7 +676,7 @@ dependencies = [
 
 [[package]]
 name = "polars-loess"
-version = "0.1.0"
+version = "0.0.0"
 dependencies = [
  "jemallocator",
  "nalgebra",
diff --git a/Cargo.toml b/Cargo.toml
index cf049f36cfbeef9779af11a0b9f1882c7c410557..3684764de3e2826bce8a13d920e24db0b2758071 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "polars-loess"
-version = "##VERSION##"
+version = "0.0.0"
 edition = "2021"
 
 [lib]
diff --git a/polars_loess/__init__.py b/polars_loess/__init__.py
index fc6f76738dd1ef7284071056d051a88ea296d1d1..7f82cd5bbba00935b404d8920c64a2b5f06f9685 100644
--- a/polars_loess/__init__.py
+++ b/polars_loess/__init__.py
@@ -19,11 +19,28 @@ else:
     lib = Path(__file__).parent
 
 
-def loess(expr: IntoExpr, expr2: IntoExpr, expr3: IntoExpr) -> pl.Expr:
+def loess(x: IntoExpr, y: IntoExpr, newx: IntoExpr, *, frac: float = None, points: int = None) -> pl.Expr:
+    """Evaluate a LOESS fit of ``y`` against ``x`` at the points ``newx``.
+
+    Exactly one of ``frac`` and ``points`` must be given; the plugin reports
+    an error otherwise.
+
+    Parameters
+    ----------
+    x, y
+        Columns holding the observations to fit.
+    newx
+        Column of positions at which the fit is evaluated.
+    frac
+        Smoothing window as a fraction of the number of fitted points.
+    points
+        Smoothing window as an absolute number of points.
+    """
     return register_plugin(
-        args=[expr, expr2, expr3],
+        args=[x, y, newx],
         symbol="loess",
         is_elementwise=True,
         lib=lib,
+        kwargs={"frac": frac, "points": points},
     )
 
diff --git a/pyproject.toml b/pyproject.toml
index 2a90baae60bdf92b8aa1dc9955d57d2463e921d9..53e5f26002d814081f3a39cf0df24351873311e2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,7 @@ requires = ["maturin>=1.0,<2.0", "polars>=0.20.6"]
 build-backend = "maturin"
 readme = "README.md"
 license = "MIT"
-version = "##VERSION##"
+version = "0.0.0"
 
 [project]
 name = "polars-loess"
diff --git a/run.py b/run.py
index 94a15b6225de1954243e438a3d37ea426b07bfaa..00258620c91fa3b8bf22ea746b2ac3e10e2fadb5 100644
--- a/run.py
+++ b/run.py
@@ -15,6 +15,6 @@ df = pl.DataFrame({
               160.78742, 168.55567, 152.42658, 221.70702, 222.69040, 243.18828]
 })
 
-result = df.with_columns(loess = loess('time', 'price', 'time'))
+result = df.with_columns(loess = loess('time', 'price', 'time', frac=0.5))
 
 print(result)
diff --git a/src/expressions.rs b/src/expressions.rs
index fc31cb8037648ec4c4388fb4541f903f70a85f75..3c3a627f70d84bebcf6b01ae9644e60291fb9cf6 100644
--- a/src/expressions.rs
+++ b/src/expressions.rs
@@ -2,6 +2,7 @@ use polars::prelude::*;
 use pyo3_polars::derive::polars_expr;
 use nalgebra::{convert, DMatrix, DVector};
 
+use serde::Deserialize;
 
 
 fn select_indices(values: &DVector<f64>, indices: &DVector<usize>) -> DVector<f64> {
@@ -165,20 +166,76 @@ impl Loess {
         self.denormalize_y(y)
     }
 }
+
+/// Keyword arguments forwarded from the Python `loess` wrapper.
+///
+/// Exactly one of the two fields must be set; `loess` rejects any other
+/// combination.
+#[derive(Deserialize)]
+struct LoessParams {
+    /// Smoothing window as a fraction of the number of fitted points.
+    frac: Option<f64>,
+    /// Smoothing window as an absolute number of points.
+    points: Option<usize>,
+}
 
 
+/// Evaluates the `Loess` model fitted on `x`/`y` at every non-null `newx`.
+///
+/// `inputs` holds three Float64 series in the order `x`, `y`, `newx`. Rows
+/// where `x` or `y` is null are left out of the fit; null `newx` values stay
+/// null in the output.
+///
+/// The smoothing window comes from `kwargs`: `frac` is multiplied by the
+/// number of fitted points and rounded, while `points` is used as given.
+///
+/// # Errors
+///
+/// Fails when an input is not Float64, when `frac` and `points` are both set
+/// or both missing, when `frac` is not a finite positive number, or when
+/// `points` is zero.
 #[polars_expr(output_type=Float64)]
-fn loess(inputs: &[Series]) -> PolarsResult<Series> {
+fn loess(inputs: &[Series], kwargs: LoessParams) -> PolarsResult<Series> {
     let x: &Float64Chunked = inputs[0].f64()?;
     let y: &Float64Chunked = inputs[1].f64()?;
     let newx: &Float64Chunked = inputs[2].f64()?;
 
-    let xx = DVector::<f64>::from_iterator(x.len(), x.to_vec().iter().filter(|v| v.is_some()).map(|v| v.unwrap()));
-    let yy = DVector::<f64>::from_iterator(y.len(), y.to_vec().iter().filter(|v| v.is_some()).map(|v| v.unwrap()));
+    // Fit only on rows where both coordinates are present: filtering `x` and
+    // `y` separately would misalign the pairs, and `DVector::from_iterator`
+    // panics when it receives fewer items than the declared length.
+    let (xs, ys): (Vec<f64>, Vec<f64>) = x
+        .into_iter()
+        .zip(y)
+        .filter_map(|(xv, yv)| xv.zip(yv))
+        .unzip();
+
+    let window = match (kwargs.frac, kwargs.points) {
+        (Some(frac), None) => {
+            // `as usize` saturates, so a NaN or negative `frac` would silently
+            // turn into a window of 0 instead of being reported.
+            if !frac.is_finite() || frac <= 0.0 {
+                polars_bail!(InvalidOperation: "`frac` must be a finite positive number, got {}", frac);
+            }
+            (frac * xs.len() as f64).round() as usize
+        }
+        (None, Some(0)) => {
+            polars_bail!(InvalidOperation: "`points` must be at least 1")
+        }
+        (None, Some(points)) => points,
+        (Some(_), Some(_)) => {
+            polars_bail!(InvalidOperation: "`frac` and `points` are mutually exclusive; provide only one")
+        }
+        (None, None) => {
+            polars_bail!(InvalidOperation: "Either `frac` or `points` must be provided")
+        }
+    };
+
+    let xx = DVector::<f64>::from_vec(xs);
+    let yy = DVector::<f64>::from_vec(ys);
 
     let loess: Loess = Loess::new(&xx, &yy);
 
-    let result = newx.apply(|optv| optv.map(|v| loess.estimate(v, 7, true, 1)));
+    let result = newx.apply(|optv| optv.map(|v| loess.estimate(v, window, true, 1)));
 
     Ok(result.into_series())
 }