polars_compute/rolling/nulls/
quantile.rs1#![allow(unsafe_op_in_unsafe_fn)]
2use super::*;
3use crate::rolling::quantile_filter::SealedRolling;
4
5pub struct QuantileWindow<'a, T: NativeType + IsFloat + PartialOrd> {
6 sorted: SortedBufNulls<'a, T>,
7 prob: f64,
8 method: QuantileMethod,
9}
10
11impl<
12 'a,
13 T: NativeType
14 + IsFloat
15 + Float
16 + std::iter::Sum
17 + AddAssign
18 + SubAssign
19 + Div<Output = T>
20 + NumCast
21 + One
22 + Zero
23 + SealedRolling
24 + PartialOrd
25 + Sub<Output = T>,
26> RollingAggWindowNulls<'a, T> for QuantileWindow<'a, T>
27{
28 unsafe fn new(
29 slice: &'a [T],
30 validity: &'a Bitmap,
31 start: usize,
32 end: usize,
33 params: Option<RollingFnParams>,
34 window_size: Option<usize>,
35 ) -> Self {
36 let params = params.unwrap();
37 let RollingFnParams::Quantile(params) = params else {
38 unreachable!("expected Quantile params");
39 };
40 Self {
41 sorted: SortedBufNulls::new(slice, validity, start, end, window_size),
42 prob: params.prob,
43 method: params.method,
44 }
45 }
46
47 unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
48 let null_count = self.sorted.update(start, end);
49 let mut length = self.sorted.len();
50 if null_count == length {
52 return None;
53 }
54 length -= null_count;
56 let mut idx = match self.method {
57 QuantileMethod::Nearest => (((length as f64) - 1.0) * self.prob).round() as usize,
58 QuantileMethod::Lower | QuantileMethod::Midpoint | QuantileMethod::Linear => {
59 ((length as f64 - 1.0) * self.prob).floor() as usize
60 },
61 QuantileMethod::Higher => ((length as f64 - 1.0) * self.prob).ceil() as usize,
62 QuantileMethod::Equiprobable => {
63 ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize
64 },
65 };
66
67 idx = std::cmp::min(idx, length - 1);
68
69 match self.method {
71 QuantileMethod::Midpoint => {
72 let top_idx = ((length as f64 - 1.0) * self.prob).ceil() as usize;
73
74 debug_assert!(idx <= top_idx);
75 let v = if idx != top_idx {
76 let mut vals = self
77 .sorted
78 .index_range(idx + null_count..top_idx + null_count + 1);
79 let low = vals.next().unwrap().unwrap();
80 let high = vals.next().unwrap().unwrap();
81 (low + high) / T::from::<f64>(2.0f64).unwrap()
82 } else {
83 self.sorted.get(idx + null_count).unwrap()
84 };
85
86 Some(v)
87 },
88 QuantileMethod::Linear => {
89 let float_idx = (length as f64 - 1.0) * self.prob;
90 let top_idx = f64::ceil(float_idx) as usize;
91
92 if top_idx == idx {
93 Some(self.sorted.get(idx + null_count).unwrap())
94 } else {
95 let mut vals = self
96 .sorted
97 .index_range(idx + null_count..top_idx + null_count + 1);
98 let low = vals.next().unwrap().unwrap();
99 let high = vals.next().unwrap().unwrap();
100
101 let proportion = T::from(float_idx - idx as f64).unwrap();
102 Some(proportion * (high - low) + low)
103 }
104 },
105 _ => Some(self.sorted.get(idx + null_count).unwrap()),
106 }
107 }
108
109 fn is_valid(&self, min_periods: usize) -> bool {
110 self.sorted.is_valid(min_periods)
111 }
112}
113
114pub fn rolling_quantile<T>(
115 arr: &PrimitiveArray<T>,
116 window_size: usize,
117 min_periods: usize,
118 center: bool,
119 weights: Option<&[f64]>,
120 params: Option<RollingFnParams>,
121) -> ArrayRef
122where
123 T: NativeType
124 + IsFloat
125 + Float
126 + std::iter::Sum
127 + AddAssign
128 + SubAssign
129 + Div<Output = T>
130 + NumCast
131 + One
132 + Zero
133 + SealedRolling
134 + PartialOrd
135 + Sub<Output = T>,
136{
137 if weights.is_some() {
138 panic!("weights not yet supported on array with null values")
139 }
140 let offset_fn = match center {
141 true => det_offsets_center,
142 false => det_offsets,
143 };
144 rolling_apply_agg_window::<QuantileWindow<_>, _, _>(
165 arr.values().as_slice(),
166 arr.validity().as_ref().unwrap(),
167 window_size,
168 min_periods,
169 offset_fn,
170 params,
171 )
172}
173
#[cfg(test)]
mod test {
    use arrow::buffer::Buffer;
    use arrow::datatypes::ArrowDataType;

    use super::*;

    /// Downcasts a rolling kernel's `f64` output and collects it into plain options.
    fn to_f64_vec(out: ArrayRef) -> Vec<Option<f64>> {
        out.as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap()
            .into_iter()
            .map(|v| v.copied())
            .collect()
    }

    #[test]
    fn test_rolling_median_nulls() {
        let arr = &PrimitiveArray::new(
            ArrowDataType::Float64,
            Buffer::from(vec![1.0, 2.0, 3.0, 4.0]),
            Some(Bitmap::from(&[true, false, true, true])),
        );
        let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
            prob: 0.5,
            method: QuantileMethod::Linear,
        }));

        // (window_size, min_periods, center, expected output)
        let cases: [(usize, usize, bool, [Option<f64>; 4]); 5] = [
            (2, 2, false, [None, None, None, Some(3.5)]),
            (2, 1, false, [Some(1.0), Some(1.0), Some(3.0), Some(3.5)]),
            (4, 1, false, [Some(1.0), Some(1.0), Some(2.0), Some(3.0)]),
            (4, 1, true, [Some(1.0), Some(2.0), Some(3.0), Some(3.5)]),
            (4, 4, true, [None, None, None, None]),
        ];

        for (window_size, min_periods, center, expected) in cases {
            let got = to_f64_vec(rolling_quantile(
                arr,
                window_size,
                min_periods,
                center,
                None,
                med_pars,
            ));
            assert_eq!(got, expected);
        }
    }

    #[test]
    fn test_rolling_quantile_nulls_limits() {
        let values = &PrimitiveArray::new(
            ArrowDataType::Float64,
            Buffer::<f64>::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]),
            Some(Bitmap::from(&[true, false, false, true, true])),
        );

        let methods = [
            QuantileMethod::Lower,
            QuantileMethod::Higher,
            QuantileMethod::Nearest,
            QuantileMethod::Midpoint,
            QuantileMethod::Linear,
            QuantileMethod::Equiprobable,
        ];

        for method in methods {
            let quantile_pars =
                |prob: f64| Some(RollingFnParams::Quantile(RollingQuantileParams { prob, method }));

            // The 0-quantile must agree with the rolling minimum for every method.
            let expected_min = to_f64_vec(rolling_min(values, 2, 1, false, None, None));
            let got_min =
                to_f64_vec(rolling_quantile(values, 2, 1, false, None, quantile_pars(0.0)));
            assert_eq!(expected_min, got_min);

            // The 1-quantile must agree with the rolling maximum for every method.
            let expected_max = to_f64_vec(rolling_max(values, 2, 1, false, None, None));
            let got_max =
                to_f64_vec(rolling_quantile(values, 2, 1, false, None, quantile_pars(1.0)));
            assert_eq!(expected_max, got_max);
        }
    }
}