polars_compute/rolling/nulls/
sum.rs1#![allow(unsafe_op_in_unsafe_fn)]
2use super::*;
3
4pub struct SumWindow<'a, T, S> {
5 slice: &'a [T],
6 validity: &'a Bitmap,
7 sum: S,
8 err: S,
9 non_finite_count: usize, pos_inf_count: usize,
11 neg_inf_count: usize,
12 pub(super) null_count: usize,
13 last_start: usize,
14 last_end: usize,
15}
16
17impl<T, S> SumWindow<'_, T, S>
18where
19 T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,
20 S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
21{
22 fn add_finite_kahan(&mut self, val: T) {
23 let val: S = NumCast::from(val).unwrap();
24 let y = val - self.err;
25 let new_sum = self.sum + y;
26 self.err = (new_sum - self.sum) - y;
27 self.sum = new_sum;
28 }
29
30 fn add(&mut self, val: T) {
31 if T::is_float() {
32 if val.is_finite() {
33 self.add_finite_kahan(val);
34 } else {
35 self.non_finite_count += 1;
36 self.pos_inf_count += (val > T::zeroed()) as usize;
37 self.neg_inf_count += (val < T::zeroed()) as usize;
38 }
39 } else {
40 let val: S = NumCast::from(val).unwrap();
41 self.sum += val;
42 }
43 }
44
45 fn sub(&mut self, val: T) {
46 if T::is_float() {
47 if val.is_finite() {
48 self.add_finite_kahan(T::zeroed() - val);
49 } else {
50 self.non_finite_count -= 1;
51 self.pos_inf_count -= (val > T::zeroed()) as usize;
52 self.neg_inf_count -= (val < T::zeroed()) as usize;
53 }
54 } else {
55 let val: S = NumCast::from(val).unwrap();
56 self.sum -= val;
57 }
58 }
59}
60
61impl<'a, T, S> RollingAggWindowNulls<'a, T> for SumWindow<'a, T, S>
62where
63 T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,
64 S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
65{
66 unsafe fn new(
67 slice: &'a [T],
68 validity: &'a Bitmap,
69 start: usize,
70 end: usize,
71 _params: Option<RollingFnParams>,
72 _window_size: Option<usize>,
73 ) -> Self {
74 let mut out = Self {
75 slice,
76 validity,
77 sum: S::zeroed(),
78 err: S::zeroed(),
79 non_finite_count: 0,
80 pos_inf_count: 0,
81 neg_inf_count: 0,
82 last_start: 0,
83 last_end: 0,
84 null_count: 0,
85 };
86 out.update(start, end);
87 out
88 }
89
90 unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
93 if start >= self.last_end {
94 self.sum = S::zeroed();
95 self.err = S::zeroed();
96 self.non_finite_count = 0;
97 self.pos_inf_count = 0;
98 self.neg_inf_count = 0;
99 self.null_count = 0;
100 self.last_start = start;
101 self.last_end = start;
102 }
103
104 for idx in self.last_start..start {
105 let valid = self.validity.get_bit_unchecked(idx);
106 if valid {
107 self.sub(unsafe { *self.slice.get_unchecked(idx) });
108 } else {
109 self.null_count -= 1;
110 }
111 }
112
113 for idx in self.last_end..end {
114 let valid = self.validity.get_bit_unchecked(idx);
115 if valid {
116 self.add(unsafe { *self.slice.get_unchecked(idx) });
117 } else {
118 self.null_count += 1;
119 }
120 }
121
122 self.last_start = start;
123 self.last_end = end;
124 if self.non_finite_count == 0 {
125 NumCast::from(self.sum)
126 } else if self.non_finite_count == self.pos_inf_count {
127 Some(T::pos_inf_value())
128 } else if self.non_finite_count == self.neg_inf_count {
129 Some(T::neg_inf_value())
130 } else {
131 Some(T::nan_value())
132 }
133 }
134
135 fn is_valid(&self, min_periods: usize) -> bool {
136 ((self.last_end - self.last_start) - self.null_count) >= min_periods
137 }
138}
139
140pub fn rolling_sum<T>(
141 arr: &PrimitiveArray<T>,
142 window_size: usize,
143 min_periods: usize,
144 center: bool,
145 weights: Option<&[f64]>,
146 _params: Option<RollingFnParams>,
147) -> ArrayRef
148where
149 T: NativeType
150 + IsFloat
151 + PartialOrd
152 + Add<Output = T>
153 + Sub<Output = T>
154 + SubAssign
155 + AddAssign
156 + NumCast,
157{
158 if weights.is_some() {
159 panic!("weights not yet supported on array with null values")
160 }
161 if center {
162 rolling_apply_agg_window::<SumWindow<T, T>, _, _>(
163 arr.values().as_slice(),
164 arr.validity().as_ref().unwrap(),
165 window_size,
166 min_periods,
167 det_offsets_center,
168 None,
169 )
170 } else {
171 rolling_apply_agg_window::<SumWindow<T, T>, _, _>(
172 arr.values().as_slice(),
173 arr.validity().as_ref().unwrap(),
174 window_size,
175 min_periods,
176 det_offsets,
177 None,
178 )
179 }
180}