polars_compute/rolling/no_nulls/
sum.rs1#![allow(unsafe_op_in_unsafe_fn)]
2use super::*;
3
/// Incremental state for a rolling sum over a slice of values.
///
/// `T` is the element type of the input slice and `S` the type the running sum is
/// accumulated in.
pub struct SumWindow<'a, T, S> {
    /// The complete input; the current window is `slice[last_start..last_end]`.
    slice: &'a [T],
    /// Inclusive start of the window currently reflected in the state.
    last_start: usize,
    /// Exclusive end of the window currently reflected in the state.
    last_end: usize,
    /// Running sum of the finite values in the current window.
    sum: S,
    /// Compensation term carried alongside `sum` to limit floating-point rounding error;
    /// reset together with `sum`.
    err: S,
    /// Number of non-finite values (NaN or ±inf) in the current window.
    non_finite_count: usize,
    /// Number of `+inf` values in the current window.
    pos_inf_count: usize,
    /// Number of `-inf` values in the current window.
    neg_inf_count: usize,
}
14
15impl<T, S> SumWindow<'_, T, S>
16where
17 T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,
18 S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
19{
20 fn add_finite_kahan(&mut self, val: T) {
21 let val: S = NumCast::from(val).unwrap();
22 let y = val - self.err;
23 let new_sum = self.sum + y;
24 self.err = (new_sum - self.sum) - y;
25 self.sum = new_sum;
26 }
27
28 fn add(&mut self, val: T) {
29 if T::is_float() {
30 if val.is_finite() {
31 self.add_finite_kahan(val);
32 } else {
33 self.non_finite_count += 1;
34 self.pos_inf_count += (val > T::zeroed()) as usize;
35 self.neg_inf_count += (val < T::zeroed()) as usize;
36 }
37 } else {
38 let val: S = NumCast::from(val).unwrap();
39 self.sum += val;
40 }
41 }
42
43 fn sub(&mut self, val: T) {
44 if T::is_float() {
45 if val.is_finite() {
46 self.add_finite_kahan(T::zeroed() - val);
47 } else {
48 self.non_finite_count -= 1;
49 self.pos_inf_count -= (val > T::zeroed()) as usize;
50 self.neg_inf_count -= (val < T::zeroed()) as usize;
51 }
52 } else {
53 let val: S = NumCast::from(val).unwrap();
54 self.sum -= val;
55 }
56 }
57}
58
59impl<'a, T, S> RollingAggWindowNoNulls<'a, T> for SumWindow<'a, T, S>
60where
61 T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,
62 S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
63{
64 fn new(
65 slice: &'a [T],
66 start: usize,
67 end: usize,
68 _params: Option<RollingFnParams>,
69 _window_size: Option<usize>,
70 ) -> Self {
71 let mut out = Self {
72 slice,
73 sum: S::zeroed(),
74 err: S::zeroed(),
75 non_finite_count: 0,
76 pos_inf_count: 0,
77 neg_inf_count: 0,
78 last_start: 0,
79 last_end: 0,
80 };
81 unsafe { out.update(start, end) };
82 out
83 }
84
85 unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
88 if start >= self.last_end {
89 self.sum = S::zeroed();
90 self.err = S::zeroed();
91 self.non_finite_count = 0;
92 self.pos_inf_count = 0;
93 self.neg_inf_count = 0;
94 self.last_start = start;
95 self.last_end = start;
96 }
97
98 for val in &self.slice[self.last_start..start] {
99 self.sub(*val);
100 }
101
102 for val in &self.slice[self.last_end..end] {
103 self.add(*val);
104 }
105
106 self.last_start = start;
107 self.last_end = end;
108 if self.non_finite_count == 0 {
109 NumCast::from(self.sum)
110 } else if self.non_finite_count == self.pos_inf_count {
111 Some(T::pos_inf_value())
112 } else if self.non_finite_count == self.neg_inf_count {
113 Some(T::neg_inf_value())
114 } else {
115 Some(T::nan_value())
116 }
117 }
118}
119
120pub fn rolling_sum<T>(
121 values: &[T],
122 window_size: usize,
123 min_periods: usize,
124 center: bool,
125 weights: Option<&[f64]>,
126 _params: Option<RollingFnParams>,
127) -> PolarsResult<ArrayRef>
128where
129 T: NativeType
130 + std::iter::Sum
131 + NumCast
132 + Mul<Output = T>
133 + AddAssign
134 + SubAssign
135 + IsFloat
136 + Num
137 + PartialOrd,
138{
139 match (center, weights) {
140 (true, None) => rolling_apply_agg_window::<SumWindow<T, T>, _, _>(
141 values,
142 window_size,
143 min_periods,
144 det_offsets_center,
145 None,
146 ),
147 (false, None) => rolling_apply_agg_window::<SumWindow<T, T>, _, _>(
148 values,
149 window_size,
150 min_periods,
151 det_offsets,
152 None,
153 ),
154 (true, Some(weights)) => {
155 let weights = no_nulls::coerce_weights(weights);
156 no_nulls::rolling_apply_weights(
157 values,
158 window_size,
159 min_periods,
160 det_offsets_center,
161 no_nulls::compute_sum_weights,
162 &weights,
163 )
164 },
165 (false, Some(weights)) => {
166 let weights = no_nulls::coerce_weights(weights);
167 no_nulls::rolling_apply_weights(
168 values,
169 window_size,
170 min_periods,
171 det_offsets,
172 no_nulls::compute_sum_weights,
173 &weights,
174 )
175 },
176 }
177}
178
#[cfg(test)]
mod test {
    use super::*;

    /// Runs the unweighted `rolling_sum` and collects the output as `Option`s.
    fn run<T>(values: &[T], window_size: usize, min_periods: usize, center: bool) -> Vec<Option<T>>
    where
        T: NativeType
            + std::iter::Sum
            + NumCast
            + Mul<Output = T>
            + AddAssign
            + SubAssign
            + IsFloat
            + Num
            + PartialOrd,
    {
        let out = rolling_sum(values, window_size, min_periods, center, None, None).unwrap();
        let out = out.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
        out.into_iter().map(|v| v.copied()).collect()
    }

    /// Debug-formats float results so that NaN entries compare equal.
    fn dbg_fmt(values: &[Option<f64>]) -> String {
        format!("{values:?}")
    }

    #[test]
    fn test_rolling_sum() {
        let values = &[1.0f64, 2.0, 3.0, 4.0];

        assert_eq!(run(values, 2, 2, false), &[None, Some(3.0), Some(5.0), Some(7.0)]);
        assert_eq!(
            run(values, 2, 1, false),
            &[Some(1.0), Some(3.0), Some(5.0), Some(7.0)]
        );
        assert_eq!(
            run(values, 4, 1, false),
            &[Some(1.0), Some(3.0), Some(6.0), Some(10.0)]
        );
        assert_eq!(
            run(values, 4, 1, true),
            &[Some(3.0), Some(6.0), Some(10.0), Some(9.0)]
        );
        assert_eq!(run(values, 4, 4, true), &[None, None, Some(10.0), None]);

        let values = &[1.0, 2.0, 3.0, f64::NAN, 5.0, 6.0, 7.0];
        assert_eq!(
            dbg_fmt(&run(values, 3, 3, false)),
            dbg_fmt(&[
                None,
                None,
                Some(6.0),
                Some(f64::NAN),
                Some(f64::NAN),
                Some(f64::NAN),
                Some(18.0),
            ])
        );
    }

    #[test]
    fn test_rolling_sum_infinities() {
        let inf = f64::INFINITY;
        let values = &[1.0, inf, 2.0, -inf, 3.0, 4.0];

        // A single infinity dominates its window; the finite sum recovers once it leaves.
        assert_eq!(
            run(values, 2, 1, false),
            &[Some(1.0), Some(inf), Some(inf), Some(-inf), Some(-inf), Some(7.0)]
        );

        // +inf and -inf in the same window give NaN.
        assert_eq!(
            dbg_fmt(&run(values, 3, 1, false)),
            dbg_fmt(&[
                Some(1.0),
                Some(inf),
                Some(inf),
                Some(f64::NAN),
                Some(-inf),
                Some(-inf),
            ])
        );
    }

    #[test]
    fn test_rolling_sum_integers() {
        let values = &[1i64, 2, 3, 4];
        assert_eq!(run(values, 2, 2, false), &[None, Some(3), Some(5), Some(7)]);
        assert_eq!(run(values, 4, 1, true), &[Some(3), Some(6), Some(10), Some(9)]);
    }
}