polars_compute/rolling/nulls/
mean.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2use super::*;
3
/// Rolling-window mean over a value slice paired with a validity bitmap.
///
/// All window bookkeeping (running sum and `null_count`) is delegated to the
/// wrapped [`SumWindow`]. The mean is derived from that window's sum and the
/// number of non-null slots in the window.
pub struct MeanWindow<'a, T> {
    /// Backing sum window. Its third type parameter is `f64` (presumably the
    /// accumulator type; confirm in `SumWindow`).
    sum: SumWindow<'a, T, f64>,
}
7
8impl<
9    'a,
10    T: NativeType
11        + IsFloat
12        + Add<Output = T>
13        + Sub<Output = T>
14        + NumCast
15        + Div<Output = T>
16        + AddAssign
17        + SubAssign
18        + PartialOrd,
19> RollingAggWindowNulls<'a, T> for MeanWindow<'a, T>
20{
21    unsafe fn new(
22        slice: &'a [T],
23        validity: &'a Bitmap,
24        start: usize,
25        end: usize,
26        params: Option<RollingFnParams>,
27        window_size: Option<usize>,
28    ) -> Self {
29        Self {
30            sum: SumWindow::new(slice, validity, start, end, params, window_size),
31        }
32    }
33
34    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
35        let sum = self.sum.update(start, end);
36        let len = end - start;
37        if self.sum.null_count == len {
38            None
39        } else {
40            sum.map(|sum| sum / NumCast::from(end - start - self.sum.null_count).unwrap())
41        }
42    }
43    fn is_valid(&self, min_periods: usize) -> bool {
44        self.sum.is_valid(min_periods)
45    }
46}
47
48pub fn rolling_mean<T>(
49    arr: &PrimitiveArray<T>,
50    window_size: usize,
51    min_periods: usize,
52    center: bool,
53    weights: Option<&[f64]>,
54    _params: Option<RollingFnParams>,
55) -> ArrayRef
56where
57    T: NativeType
58        + IsFloat
59        + PartialOrd
60        + Add<Output = T>
61        + Sub<Output = T>
62        + NumCast
63        + AddAssign
64        + SubAssign
65        + Div<Output = T>,
66{
67    if weights.is_some() {
68        panic!("weights not yet supported on array with null values")
69    }
70    if center {
71        rolling_apply_agg_window::<MeanWindow<_>, _, _>(
72            arr.values().as_slice(),
73            arr.validity().as_ref().unwrap(),
74            window_size,
75            min_periods,
76            det_offsets_center,
77            None,
78        )
79    } else {
80        rolling_apply_agg_window::<MeanWindow<_>, _, _>(
81            arr.values().as_slice(),
82            arr.validity().as_ref().unwrap(),
83            window_size,
84            min_periods,
85            det_offsets,
86            None,
87        )
88    }
89}