polars_compute/rolling/nulls/
sum.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2use super::*;
3
4pub struct SumWindow<'a, T, S> {
5    slice: &'a [T],
6    validity: &'a Bitmap,
7    sum: S,
8    err: S,
9    non_finite_count: usize, // NaN or infinity.
10    pos_inf_count: usize,
11    neg_inf_count: usize,
12    pub(super) null_count: usize,
13    last_start: usize,
14    last_end: usize,
15}
16
17impl<T, S> SumWindow<'_, T, S>
18where
19    T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,
20    S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
21{
22    fn add_finite_kahan(&mut self, val: T) {
23        let val: S = NumCast::from(val).unwrap();
24        let y = val - self.err;
25        let new_sum = self.sum + y;
26        self.err = (new_sum - self.sum) - y;
27        self.sum = new_sum;
28    }
29
30    fn add(&mut self, val: T) {
31        if T::is_float() {
32            if val.is_finite() {
33                self.add_finite_kahan(val);
34            } else {
35                self.non_finite_count += 1;
36                self.pos_inf_count += (val > T::zeroed()) as usize;
37                self.neg_inf_count += (val < T::zeroed()) as usize;
38            }
39        } else {
40            let val: S = NumCast::from(val).unwrap();
41            self.sum += val;
42        }
43    }
44
45    fn sub(&mut self, val: T) {
46        if T::is_float() {
47            if val.is_finite() {
48                self.add_finite_kahan(T::zeroed() - val);
49            } else {
50                self.non_finite_count -= 1;
51                self.pos_inf_count -= (val > T::zeroed()) as usize;
52                self.neg_inf_count -= (val < T::zeroed()) as usize;
53            }
54        } else {
55            let val: S = NumCast::from(val).unwrap();
56            self.sum -= val;
57        }
58    }
59}
60
61impl<'a, T, S> RollingAggWindowNulls<'a, T> for SumWindow<'a, T, S>
62where
63    T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,
64    S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
65{
66    unsafe fn new(
67        slice: &'a [T],
68        validity: &'a Bitmap,
69        start: usize,
70        end: usize,
71        _params: Option<RollingFnParams>,
72        _window_size: Option<usize>,
73    ) -> Self {
74        let mut out = Self {
75            slice,
76            validity,
77            sum: S::zeroed(),
78            err: S::zeroed(),
79            non_finite_count: 0,
80            pos_inf_count: 0,
81            neg_inf_count: 0,
82            last_start: 0,
83            last_end: 0,
84            null_count: 0,
85        };
86        out.update(start, end);
87        out
88    }
89
90    // # Safety
91    // The start, end range must be in-bounds.
92    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
93        if start >= self.last_end {
94            self.sum = S::zeroed();
95            self.err = S::zeroed();
96            self.non_finite_count = 0;
97            self.pos_inf_count = 0;
98            self.neg_inf_count = 0;
99            self.null_count = 0;
100            self.last_start = start;
101            self.last_end = start;
102        }
103
104        for idx in self.last_start..start {
105            let valid = self.validity.get_bit_unchecked(idx);
106            if valid {
107                self.sub(unsafe { *self.slice.get_unchecked(idx) });
108            } else {
109                self.null_count -= 1;
110            }
111        }
112
113        for idx in self.last_end..end {
114            let valid = self.validity.get_bit_unchecked(idx);
115            if valid {
116                self.add(unsafe { *self.slice.get_unchecked(idx) });
117            } else {
118                self.null_count += 1;
119            }
120        }
121
122        self.last_start = start;
123        self.last_end = end;
124        if self.non_finite_count == 0 {
125            NumCast::from(self.sum)
126        } else if self.non_finite_count == self.pos_inf_count {
127            Some(T::pos_inf_value())
128        } else if self.non_finite_count == self.neg_inf_count {
129            Some(T::neg_inf_value())
130        } else {
131            Some(T::nan_value())
132        }
133    }
134
135    fn is_valid(&self, min_periods: usize) -> bool {
136        ((self.last_end - self.last_start) - self.null_count) >= min_periods
137    }
138}
139
140pub fn rolling_sum<T>(
141    arr: &PrimitiveArray<T>,
142    window_size: usize,
143    min_periods: usize,
144    center: bool,
145    weights: Option<&[f64]>,
146    _params: Option<RollingFnParams>,
147) -> ArrayRef
148where
149    T: NativeType
150        + IsFloat
151        + PartialOrd
152        + Add<Output = T>
153        + Sub<Output = T>
154        + SubAssign
155        + AddAssign
156        + NumCast,
157{
158    if weights.is_some() {
159        panic!("weights not yet supported on array with null values")
160    }
161    if center {
162        rolling_apply_agg_window::<SumWindow<T, T>, _, _>(
163            arr.values().as_slice(),
164            arr.validity().as_ref().unwrap(),
165            window_size,
166            min_periods,
167            det_offsets_center,
168            None,
169        )
170    } else {
171        rolling_apply_agg_window::<SumWindow<T, T>, _, _>(
172            arr.values().as_slice(),
173            arr.validity().as_ref().unwrap(),
174            window_size,
175            min_periods,
176            det_offsets,
177            None,
178        )
179    }
180}