polars_compute/rolling/no_nulls/
sum.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2use super::*;
3
/// Incremental state for a rolling sum over a slice that contains no nulls.
///
/// For float inputs, finite values are accumulated with Kahan compensation
/// while NaN/infinite values are only counted; those counts decide whether a
/// window evaluates to `+inf`, `-inf` or NaN. Integer inputs are summed
/// directly into `sum`.
pub struct SumWindow<'a, T, S> {
    /// Backing data the window slides over.
    slice: &'a [T],
    /// Inclusive start index of the most recently evaluated window.
    last_start: usize,
    /// Exclusive end index of the most recently evaluated window.
    last_end: usize,
    /// Running sum of the values in the window (finite ones only, for floats).
    sum: S,
    /// Kahan compensation term carried alongside `sum`.
    err: S,
    /// Number of NaN or infinite values currently in the window.
    non_finite_count: usize,
    /// How many of the non-finite values are positive infinity.
    pos_inf_count: usize,
    /// How many of the non-finite values are negative infinity.
    neg_inf_count: usize,
}
14
15impl<T, S> SumWindow<'_, T, S>
16where
17    T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,
18    S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
19{
20    fn add_finite_kahan(&mut self, val: T) {
21        let val: S = NumCast::from(val).unwrap();
22        let y = val - self.err;
23        let new_sum = self.sum + y;
24        self.err = (new_sum - self.sum) - y;
25        self.sum = new_sum;
26    }
27
28    fn add(&mut self, val: T) {
29        if T::is_float() {
30            if val.is_finite() {
31                self.add_finite_kahan(val);
32            } else {
33                self.non_finite_count += 1;
34                self.pos_inf_count += (val > T::zeroed()) as usize;
35                self.neg_inf_count += (val < T::zeroed()) as usize;
36            }
37        } else {
38            let val: S = NumCast::from(val).unwrap();
39            self.sum += val;
40        }
41    }
42
43    fn sub(&mut self, val: T) {
44        if T::is_float() {
45            if val.is_finite() {
46                self.add_finite_kahan(T::zeroed() - val);
47            } else {
48                self.non_finite_count -= 1;
49                self.pos_inf_count -= (val > T::zeroed()) as usize;
50                self.neg_inf_count -= (val < T::zeroed()) as usize;
51            }
52        } else {
53            let val: S = NumCast::from(val).unwrap();
54            self.sum -= val;
55        }
56    }
57}
58
59impl<'a, T, S> RollingAggWindowNoNulls<'a, T> for SumWindow<'a, T, S>
60where
61    T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,
62    S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
63{
64    fn new(
65        slice: &'a [T],
66        start: usize,
67        end: usize,
68        _params: Option<RollingFnParams>,
69        _window_size: Option<usize>,
70    ) -> Self {
71        let mut out = Self {
72            slice,
73            sum: S::zeroed(),
74            err: S::zeroed(),
75            non_finite_count: 0,
76            pos_inf_count: 0,
77            neg_inf_count: 0,
78            last_start: 0,
79            last_end: 0,
80        };
81        unsafe { out.update(start, end) };
82        out
83    }
84
85    // # Safety
86    // The start, end range must be in-bounds.
87    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
88        if start >= self.last_end {
89            self.sum = S::zeroed();
90            self.err = S::zeroed();
91            self.non_finite_count = 0;
92            self.pos_inf_count = 0;
93            self.neg_inf_count = 0;
94            self.last_start = start;
95            self.last_end = start;
96        }
97
98        for val in &self.slice[self.last_start..start] {
99            self.sub(*val);
100        }
101
102        for val in &self.slice[self.last_end..end] {
103            self.add(*val);
104        }
105
106        self.last_start = start;
107        self.last_end = end;
108        if self.non_finite_count == 0 {
109            NumCast::from(self.sum)
110        } else if self.non_finite_count == self.pos_inf_count {
111            Some(T::pos_inf_value())
112        } else if self.non_finite_count == self.neg_inf_count {
113            Some(T::neg_inf_value())
114        } else {
115            Some(T::nan_value())
116        }
117    }
118}
119
120pub fn rolling_sum<T>(
121    values: &[T],
122    window_size: usize,
123    min_periods: usize,
124    center: bool,
125    weights: Option<&[f64]>,
126    _params: Option<RollingFnParams>,
127) -> PolarsResult<ArrayRef>
128where
129    T: NativeType
130        + std::iter::Sum
131        + NumCast
132        + Mul<Output = T>
133        + AddAssign
134        + SubAssign
135        + IsFloat
136        + Num
137        + PartialOrd,
138{
139    match (center, weights) {
140        (true, None) => rolling_apply_agg_window::<SumWindow<T, T>, _, _>(
141            values,
142            window_size,
143            min_periods,
144            det_offsets_center,
145            None,
146        ),
147        (false, None) => rolling_apply_agg_window::<SumWindow<T, T>, _, _>(
148            values,
149            window_size,
150            min_periods,
151            det_offsets,
152            None,
153        ),
154        (true, Some(weights)) => {
155            let weights = no_nulls::coerce_weights(weights);
156            no_nulls::rolling_apply_weights(
157                values,
158                window_size,
159                min_periods,
160                det_offsets_center,
161                no_nulls::compute_sum_weights,
162                &weights,
163            )
164        },
165        (false, Some(weights)) => {
166            let weights = no_nulls::coerce_weights(weights);
167            no_nulls::rolling_apply_weights(
168                values,
169                window_size,
170                min_periods,
171                det_offsets,
172                no_nulls::compute_sum_weights,
173                &weights,
174            )
175        },
176    }
177}
178
#[cfg(test)]
mod test {
    use super::*;

    /// Runs an unweighted `rolling_sum` over `f64` input and returns the output
    /// array as a plain vector of optional values.
    fn sum_f64(
        values: &[f64],
        window_size: usize,
        min_periods: usize,
        center: bool,
    ) -> Vec<Option<f64>> {
        let out = rolling_sum(values, window_size, min_periods, center, None, None).unwrap();
        let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
        out.into_iter().map(|v| v.copied()).collect()
    }

    #[test]
    fn test_rolling_sum() {
        let values = &[1.0f64, 2.0, 3.0, 4.0];

        assert_eq!(
            sum_f64(values, 2, 2, false),
            &[None, Some(3.0), Some(5.0), Some(7.0)]
        );
        assert_eq!(
            sum_f64(values, 2, 1, false),
            &[Some(1.0), Some(3.0), Some(5.0), Some(7.0)]
        );
        assert_eq!(
            sum_f64(values, 4, 1, false),
            &[Some(1.0), Some(3.0), Some(6.0), Some(10.0)]
        );
        assert_eq!(
            sum_f64(values, 4, 1, true),
            &[Some(3.0), Some(6.0), Some(10.0), Some(9.0)]
        );
        assert_eq!(sum_f64(values, 4, 4, true), &[None, None, Some(10.0), None]);

        // test nan handling.
        let values = &[1.0, 2.0, 3.0, f64::nan(), 5.0, 6.0, 7.0];
        let out = sum_f64(values, 3, 3, false);

        // NaN != NaN, so compare through the Debug representation.
        assert_eq!(
            format!("{:?}", out.as_slice()),
            format!(
                "{:?}",
                &[
                    None,
                    None,
                    Some(6.0),
                    Some(f64::nan()),
                    Some(f64::nan()),
                    Some(f64::nan()),
                    Some(18.0)
                ]
            )
        );
    }

    #[test]
    fn test_rolling_sum_infinity() {
        // Same-signed infinities propagate, and the finite sum is usable again
        // once the infinity has left the window.
        let values = &[1.0f64, f64::INFINITY, 2.0, f64::NEG_INFINITY, 3.0, 4.0];
        assert_eq!(
            sum_f64(values, 2, 2, false),
            &[
                None,
                Some(f64::INFINITY),
                Some(f64::INFINITY),
                Some(f64::NEG_INFINITY),
                Some(f64::NEG_INFINITY),
                Some(7.0)
            ]
        );

        // Opposite infinities in the same window yield NaN; once +inf leaves,
        // only -inf remains and dominates the finite value.
        let values = &[f64::INFINITY, f64::NEG_INFINITY, 1.0];
        let out = sum_f64(values, 2, 2, false);
        assert_eq!(out.len(), 3);
        assert_eq!(out[0], None);
        assert!(out[1].unwrap().is_nan());
        assert_eq!(out[2], Some(f64::NEG_INFINITY));
    }
}
233}