polars_compute/rolling/no_nulls/mean.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2
3use polars_error::polars_ensure;
4
5use super::*;
6
7pub struct MeanWindow<'a, T> {
8    sum: SumWindow<'a, T, f64>,
9}
10
11impl<'a, T> RollingAggWindowNoNulls<'a, T> for MeanWindow<'a, T>
12where
13    T: NativeType
14        + IsFloat
15        + std::iter::Sum
16        + AddAssign
17        + SubAssign
18        + Div<Output = T>
19        + NumCast
20        + Add<Output = T>
21        + Sub<Output = T>
22        + PartialOrd,
23{
24    fn new(
25        slice: &'a [T],
26        start: usize,
27        end: usize,
28        params: Option<RollingFnParams>,
29        window_size: Option<usize>,
30    ) -> Self {
31        Self {
32            sum: SumWindow::<T, f64>::new(slice, start, end, params, window_size),
33        }
34    }
35
36    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
37        let sum = self.sum.update(start, end).unwrap_unchecked();
38        Some(sum / NumCast::from(end - start).unwrap())
39    }
40}
41
42pub fn rolling_mean<T>(
43    values: &[T],
44    window_size: usize,
45    min_periods: usize,
46    center: bool,
47    weights: Option<&[f64]>,
48    _params: Option<RollingFnParams>,
49) -> PolarsResult<ArrayRef>
50where
51    T: NativeType + Float + std::iter::Sum<T> + SubAssign + AddAssign + IsFloat,
52{
53    let offset_fn = match center {
54        true => det_offsets_center,
55        false => det_offsets,
56    };
57    match weights {
58        None => rolling_apply_agg_window::<MeanWindow<_>, _, _>(
59            values,
60            window_size,
61            min_periods,
62            offset_fn,
63            None,
64        ),
65        Some(weights) => {
66            // A weighted mean is a weighted sum with normalized weights
67            let mut wts = no_nulls::coerce_weights(weights);
68            let wsum = wts.iter().fold(T::zero(), |acc, x| acc + *x);
69            polars_ensure!(
70                wsum != T::zero(),
71                ComputeError: "Weighted mean is undefined if weights sum to 0"
72            );
73            wts.iter_mut().for_each(|w| *w = *w / wsum);
74            no_nulls::rolling_apply_weights(
75                values,
76                window_size,
77                min_periods,
78                offset_fn,
79                no_nulls::compute_sum_weights,
80                &wts,
81            )
82        },
83    }
84}