1#![allow(unsafe_op_in_unsafe_fn)]
2use arrow::legacy::utils::CustomIterTools;
3use num_traits::ToPrimitive;
4use polars_error::polars_ensure;
5
6use super::QuantileMethod::*;
7use super::*;
8use crate::rolling::quantile_filter::SealedRolling;
9
10pub struct QuantileWindow<'a, T: NativeType> {
11 sorted: SortedBuf<'a, T>,
12 prob: f64,
13 method: QuantileMethod,
14}
15
16impl<
17 'a,
18 T: NativeType
19 + Float
20 + std::iter::Sum
21 + AddAssign
22 + SubAssign
23 + Div<Output = T>
24 + NumCast
25 + One
26 + Zero
27 + SealedRolling
28 + Sub<Output = T>,
29> RollingAggWindowNoNulls<'a, T> for QuantileWindow<'a, T>
30{
31 fn new(
32 slice: &'a [T],
33 start: usize,
34 end: usize,
35 params: Option<RollingFnParams>,
36 window_size: Option<usize>,
37 ) -> Self {
38 let params = params.unwrap();
39 let RollingFnParams::Quantile(params) = params else {
40 unreachable!("expected Quantile params");
41 };
42
43 Self {
44 sorted: SortedBuf::new(slice, start, end, window_size),
45 prob: params.prob,
46 method: params.method,
47 }
48 }
49
50 unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
51 self.sorted.update(start, end);
52 let length = self.sorted.len();
53
54 let idx = match self.method {
55 Linear => {
56 let length_f = length as f64;
58 let idx = ((length_f - 1.0) * self.prob).floor() as usize;
59
60 let float_idx_top = (length_f - 1.0) * self.prob;
61 let top_idx = float_idx_top.ceil() as usize;
62 return if idx == top_idx {
63 Some(self.sorted.get(idx))
64 } else {
65 let proportion = T::from(float_idx_top - idx as f64).unwrap();
66 let mut vals = self.sorted.index_range(idx..top_idx + 1);
67 let vi = *vals.next().unwrap();
68 let vj = *vals.next().unwrap();
69
70 Some(proportion * (vj - vi) + vi)
71 };
72 },
73 Midpoint => {
74 let length_f = length as f64;
75 let idx = (length_f * self.prob) as usize;
76 let idx = std::cmp::min(idx, length - 1);
77
78 let top_idx = ((length_f - 1.0) * self.prob).ceil() as usize;
79 return if top_idx == idx {
80 Some(self.sorted.get(idx))
81 } else {
82 let top_idx = idx + 1;
83 let mut vals = self.sorted.index_range(idx..top_idx + 1);
84 let mid = *vals.next().unwrap();
85 let mid_plus_1 = *vals.next().unwrap();
86
87 Some((mid + mid_plus_1) / (T::one() + T::one()))
88 };
89 },
90 Nearest => {
91 let idx = (((length as f64) - 1.0) * self.prob).round() as usize;
92 std::cmp::min(idx, length - 1)
93 },
94 Lower => ((length as f64 - 1.0) * self.prob).floor() as usize,
95 Higher => {
96 let idx = ((length as f64 - 1.0) * self.prob).ceil() as usize;
97 std::cmp::min(idx, length - 1)
98 },
99 Equiprobable => ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize,
100 };
101
102 Some(self.sorted.get(idx))
103 }
104}
105
106pub fn rolling_quantile<T>(
107 values: &[T],
108 window_size: usize,
109 min_periods: usize,
110 center: bool,
111 weights: Option<&[f64]>,
112 params: Option<RollingFnParams>,
113) -> PolarsResult<ArrayRef>
114where
115 T: NativeType
116 + IsFloat
117 + Float
118 + std::iter::Sum
119 + AddAssign
120 + SubAssign
121 + Div<Output = T>
122 + NumCast
123 + One
124 + Zero
125 + SealedRolling
126 + PartialOrd
127 + Sub<Output = T>,
128{
129 let offset_fn = match center {
130 true => det_offsets_center,
131 false => det_offsets,
132 };
133 match weights {
134 None => {
135 if !center {
136 let params = params.as_ref().unwrap();
137 let RollingFnParams::Quantile(params) = params else {
138 unreachable!("expected Quantile params");
139 };
140 let out = super::quantile_filter::rolling_quantile::<_, Vec<_>>(
141 params.method,
142 min_periods,
143 window_size,
144 values,
145 params.prob,
146 );
147 let validity = create_validity(min_periods, values.len(), window_size, offset_fn);
148 return Ok(Box::new(PrimitiveArray::new(
149 T::PRIMITIVE.into(),
150 out.into(),
151 validity.map(|b| b.into()),
152 )));
153 }
154
155 rolling_apply_agg_window::<QuantileWindow<_>, _, _>(
156 values,
157 window_size,
158 min_periods,
159 offset_fn,
160 params,
161 )
162 },
163 Some(weights) => {
164 let wsum = weights.iter().sum();
165 polars_ensure!(
166 wsum != 0.0,
167 ComputeError: "Weighted quantile is undefined if weights sum to 0"
168 );
169 let params = params.unwrap();
170 let RollingFnParams::Quantile(params) = params else {
171 unreachable!("expected Quantile params");
172 };
173
174 Ok(rolling_apply_weighted_quantile(
175 values,
176 params.prob,
177 params.method,
178 window_size,
179 min_periods,
180 offset_fn,
181 weights,
182 wsum,
183 ))
184 },
185 }
186}
187
188#[inline]
189fn compute_wq<T>(buf: &[(T, f64)], p: f64, wsum: f64, method: QuantileMethod) -> T
190where
191 T: Debug + NativeType + Mul<Output = T> + Sub<Output = T> + NumCast + ToPrimitive + Zero,
192{
193 let (mut s, mut s_old, mut vk, mut v_old) = (0.0, 0.0, T::zero(), T::zero());
197
198 let h: f64 = p * (wsum - buf[0].1) + buf[0].1;
201 for &(v, w) in buf.iter() {
202 if s > h {
203 break;
204 }
205 (s_old, v_old, vk) = (s, vk, v);
206 s += w;
207 }
208 match (h == s_old, method) {
209 (true, _) => v_old, (_, Lower) => v_old,
211 (_, Higher) => vk,
212 (_, Nearest) => {
213 if s - h > h - s_old {
214 v_old
215 } else {
216 vk
217 }
218 },
219 (_, Equiprobable) => {
220 let threshold = (wsum * p).ceil() - 1.0;
221 if s > threshold { vk } else { v_old }
222 },
223 (_, Midpoint) => (vk + v_old) * NumCast::from(0.5).unwrap(),
224 (_, Linear) => {
226 v_old + <T as NumCast>::from((h - s_old) / (s - s_old)).unwrap() * (vk - v_old)
227 },
228 }
229}
230
231#[allow(clippy::too_many_arguments)]
232fn rolling_apply_weighted_quantile<T, Fo>(
233 values: &[T],
234 p: f64,
235 method: QuantileMethod,
236 window_size: usize,
237 min_periods: usize,
238 det_offsets_fn: Fo,
239 weights: &[f64],
240 wsum: f64,
241) -> ArrayRef
242where
243 Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
244 T: Debug + NativeType + Mul<Output = T> + Sub<Output = T> + NumCast + ToPrimitive + Zero,
245{
246 assert_eq!(weights.len(), window_size);
247 let nz_idx_wts: Vec<_> = weights.iter().enumerate().filter(|x| x.1 != &0.0).collect();
249 let mut buf = vec![(T::zero(), 0.0); nz_idx_wts.len()];
250 let len = values.len();
251 let out = (0..len)
252 .map(|idx| {
253 let (start, _) = det_offsets_fn(idx, window_size, len);
255
256 unsafe {
258 buf.iter_mut()
259 .zip(nz_idx_wts.iter())
260 .for_each(|(b, (i, w))| *b = (*values.get_unchecked(i + start), **w));
261 }
262 buf.sort_unstable_by(|&a, &b| a.0.tot_cmp(&b.0));
263 compute_wq(&buf, p, wsum, method)
264 })
265 .collect_trusted::<Vec<T>>();
266
267 let validity = create_validity(min_periods, len, window_size, det_offsets_fn);
268 Box::new(PrimitiveArray::new(
269 T::PRIMITIVE.into(),
270 out.into(),
271 validity.map(|b| b.into()),
272 ))
273}
274
#[cfg(test)]
mod test {
    use super::*;

    /// Downcast a rolling result to `PrimitiveArray<f64>` and collect it as
    /// optional values (null slots become `None`).
    fn to_vec(arr: ArrayRef) -> Vec<Option<f64>> {
        arr.as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap()
            .into_iter()
            .map(|v| v.copied())
            .collect()
    }

    /// Wrap a probability and method into rolling quantile parameters.
    fn quantile_params(prob: f64, method: QuantileMethod) -> Option<RollingFnParams> {
        Some(RollingFnParams::Quantile(RollingQuantileParams { prob, method }))
    }

    #[test]
    fn test_rolling_median() {
        let values = &[1.0, 2.0, 3.0, 4.0];
        let med_pars = quantile_params(0.5, Linear);
        let run = |window_size: usize, min_periods: usize, center: bool| {
            to_vec(
                rolling_quantile(values, window_size, min_periods, center, None, med_pars)
                    .unwrap(),
            )
        };

        assert_eq!(run(2, 2, false), &[None, Some(1.5), Some(2.5), Some(3.5)]);
        assert_eq!(run(2, 1, false), &[Some(1.0), Some(1.5), Some(2.5), Some(3.5)]);
        assert_eq!(run(4, 1, false), &[Some(1.0), Some(1.5), Some(2.0), Some(2.5)]);
        assert_eq!(run(4, 1, true), &[Some(1.5), Some(2.0), Some(2.5), Some(3.0)]);
        assert_eq!(run(4, 4, true), &[None, None, Some(2.5), None]);
    }

    #[test]
    fn test_rolling_quantile_limits() {
        let values = &[1.0f64, 2.0, 3.0, 4.0];

        for method in [
            QuantileMethod::Lower,
            QuantileMethod::Higher,
            QuantileMethod::Nearest,
            QuantileMethod::Midpoint,
            QuantileMethod::Linear,
            QuantileMethod::Equiprobable,
        ] {
            // p = 0 must reproduce the rolling minimum.
            let expected_min = to_vec(rolling_min(values, 2, 2, false, None, None).unwrap());
            let actual_min = to_vec(
                rolling_quantile(values, 2, 2, false, None, quantile_params(0.0, method))
                    .unwrap(),
            );
            assert_eq!(expected_min, actual_min);

            // p = 1 must reproduce the rolling maximum.
            let expected_max = to_vec(rolling_max(values, 2, 2, false, None, None).unwrap());
            let actual_max = to_vec(
                rolling_quantile(values, 2, 2, false, None, quantile_params(1.0, method))
                    .unwrap(),
            );
            assert_eq!(expected_max, actual_max);
        }
    }
}