polars_arrow/legacy/kernels/ewm/mod.rs

1mod average;
2mod variance;
3
4use std::hash::{Hash, Hasher};
5
6pub use average::*;
7#[cfg(feature = "serde")]
8use serde::{Deserialize, Serialize};
9pub use variance::*;
10
/// Parameters for the exponentially weighted moving (EWM) kernels of this module.
///
/// Start from [`Default::default`] and refine the value with the `and_*` builder methods.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
#[must_use]
pub struct EWMOptions {
    /// Smoothing factor. The `and_span`, `and_half_life` and `and_com` builders derive it
    /// from alternative decay parameterisations.
    pub alpha: f64,
    /// `adjust` flag handed to the kernels (presumably pandas-style `adjust`; confirm in
    /// the `average`/`variance` modules).
    pub adjust: bool,
    /// `bias` flag handed to the kernels (presumably only relevant for variance/std;
    /// confirm in the `variance` module).
    pub bias: bool,
    /// Minimum number of observations required before a value is produced (assumption —
    /// confirm against the kernels).
    pub min_periods: usize,
    /// Whether nulls are skipped when computing the weights (assumption — confirm
    /// against the kernels).
    pub ignore_nulls: bool,
}
22
23impl Default for EWMOptions {
24    fn default() -> Self {
25        Self {
26            alpha: 0.5,
27            adjust: true,
28            bias: false,
29            min_periods: 1,
30            ignore_nulls: true,
31        }
32    }
33}
34
35impl Hash for EWMOptions {
36    fn hash<H: Hasher>(&self, state: &mut H) {
37        self.alpha.to_bits().hash(state);
38        self.adjust.hash(state);
39        self.bias.hash(state);
40        self.min_periods.hash(state);
41        self.ignore_nulls.hash(state);
42    }
43}
44
45impl EWMOptions {
46    pub fn and_min_periods(mut self, min_periods: usize) -> Self {
47        self.min_periods = min_periods;
48        self
49    }
50    pub fn and_adjust(mut self, adjust: bool) -> Self {
51        self.adjust = adjust;
52        self
53    }
54    pub fn and_span(mut self, span: usize) -> Self {
55        assert!(span >= 1);
56        self.alpha = 2.0 / (span as f64 + 1.0);
57        self
58    }
59    pub fn and_half_life(mut self, half_life: f64) -> Self {
60        assert!(half_life > 0.0);
61        self.alpha = 1.0 - (-(2.0f64.ln()) / half_life).exp();
62        self
63    }
64    pub fn and_com(mut self, com: f64) -> Self {
65        assert!(com > 0.0);
66        self.alpha = 1.0 / (1.0 + com);
67        self
68    }
69    pub fn and_ignore_nulls(mut self, ignore_nulls: bool) -> Self {
70        self.ignore_nulls = ignore_nulls;
71        self
72    }
73}
74
/// Test helper: asserts that two sequences of optional floats are element-wise close.
///
/// Both arguments must provide `.iter()` yielding `Option`-like items. Two `Some` values
/// match when their absolute difference is strictly below `$tol`. Two `None`s match.
/// A `Some`/`None` pair never matches. The sequences must also have the same length;
/// `zip` alone would silently ignore trailing elements.
///
/// Each argument expression is evaluated exactly once.
#[cfg(test)]
macro_rules! assert_allclose {
    ($xs:expr, $ys:expr, $tol:expr) => {{
        let xs = &$xs;
        let ys = &$ys;
        let (n_xs, n_ys) = (xs.iter().count(), ys.iter().count());
        assert_eq!(n_xs, n_ys, "assert_allclose: length mismatch");
        for (idx, (x, y)) in xs.iter().zip(ys.iter()).enumerate() {
            let close = match (x, y) {
                (Some(a), Some(b)) => (a - b).abs() < $tol,
                (None, None) => true,
                _ => false,
            };
            assert!(
                close,
                "assert_allclose: mismatch at index {}: {:?} vs {:?}",
                idx, x, y
            );
        }
    }};
}
#[cfg(test)]
pub(crate) use assert_allclose;