polars_compute/rolling/no_nulls/
mean.rs1#![allow(unsafe_op_in_unsafe_fn)]
2
3use polars_error::polars_ensure;
4
5use super::*;
6
7pub struct MeanWindow<'a, T> {
8 sum: SumWindow<'a, T, f64>,
9}
10
11impl<'a, T> RollingAggWindowNoNulls<'a, T> for MeanWindow<'a, T>
12where
13 T: NativeType
14 + IsFloat
15 + std::iter::Sum
16 + AddAssign
17 + SubAssign
18 + Div<Output = T>
19 + NumCast
20 + Add<Output = T>
21 + Sub<Output = T>
22 + PartialOrd,
23{
24 fn new(
25 slice: &'a [T],
26 start: usize,
27 end: usize,
28 params: Option<RollingFnParams>,
29 window_size: Option<usize>,
30 ) -> Self {
31 Self {
32 sum: SumWindow::<T, f64>::new(slice, start, end, params, window_size),
33 }
34 }
35
36 unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
37 let sum = self.sum.update(start, end).unwrap_unchecked();
38 Some(sum / NumCast::from(end - start).unwrap())
39 }
40}
41
42pub fn rolling_mean<T>(
43 values: &[T],
44 window_size: usize,
45 min_periods: usize,
46 center: bool,
47 weights: Option<&[f64]>,
48 _params: Option<RollingFnParams>,
49) -> PolarsResult<ArrayRef>
50where
51 T: NativeType + Float + std::iter::Sum<T> + SubAssign + AddAssign + IsFloat,
52{
53 let offset_fn = match center {
54 true => det_offsets_center,
55 false => det_offsets,
56 };
57 match weights {
58 None => rolling_apply_agg_window::<MeanWindow<_>, _, _>(
59 values,
60 window_size,
61 min_periods,
62 offset_fn,
63 None,
64 ),
65 Some(weights) => {
66 let mut wts = no_nulls::coerce_weights(weights);
68 let wsum = wts.iter().fold(T::zero(), |acc, x| acc + *x);
69 polars_ensure!(
70 wsum != T::zero(),
71 ComputeError: "Weighted mean is undefined if weights sum to 0"
72 );
73 wts.iter_mut().for_each(|w| *w = *w / wsum);
74 no_nulls::rolling_apply_weights(
75 values,
76 window_size,
77 min_periods,
78 offset_fn,
79 no_nulls::compute_sum_weights,
80 &wts,
81 )
82 },
83 }
84}