polars_compute/rolling/no_nulls/
quantile.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2use arrow::legacy::utils::CustomIterTools;
3use num_traits::ToPrimitive;
4use polars_error::polars_ensure;
5
6use super::QuantileMethod::*;
7use super::*;
8use crate::rolling::quantile_filter::SealedRolling;
9
10pub struct QuantileWindow<'a, T: NativeType> {
11    sorted: SortedBuf<'a, T>,
12    prob: f64,
13    method: QuantileMethod,
14}
15
16impl<
17    'a,
18    T: NativeType
19        + Float
20        + std::iter::Sum
21        + AddAssign
22        + SubAssign
23        + Div<Output = T>
24        + NumCast
25        + One
26        + Zero
27        + SealedRolling
28        + Sub<Output = T>,
29> RollingAggWindowNoNulls<'a, T> for QuantileWindow<'a, T>
30{
31    fn new(
32        slice: &'a [T],
33        start: usize,
34        end: usize,
35        params: Option<RollingFnParams>,
36        window_size: Option<usize>,
37    ) -> Self {
38        let params = params.unwrap();
39        let RollingFnParams::Quantile(params) = params else {
40            unreachable!("expected Quantile params");
41        };
42
43        Self {
44            sorted: SortedBuf::new(slice, start, end, window_size),
45            prob: params.prob,
46            method: params.method,
47        }
48    }
49
50    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
51        self.sorted.update(start, end);
52        let length = self.sorted.len();
53
54        let idx = match self.method {
55            Linear => {
56                // Maybe add a fast path for median case? They could branch depending on odd/even.
57                let length_f = length as f64;
58                let idx = ((length_f - 1.0) * self.prob).floor() as usize;
59
60                let float_idx_top = (length_f - 1.0) * self.prob;
61                let top_idx = float_idx_top.ceil() as usize;
62                return if idx == top_idx {
63                    Some(self.sorted.get(idx))
64                } else {
65                    let proportion = T::from(float_idx_top - idx as f64).unwrap();
66                    let mut vals = self.sorted.index_range(idx..top_idx + 1);
67                    let vi = *vals.next().unwrap();
68                    let vj = *vals.next().unwrap();
69
70                    Some(proportion * (vj - vi) + vi)
71                };
72            },
73            Midpoint => {
74                let length_f = length as f64;
75                let idx = (length_f * self.prob) as usize;
76                let idx = std::cmp::min(idx, length - 1);
77
78                let top_idx = ((length_f - 1.0) * self.prob).ceil() as usize;
79                return if top_idx == idx {
80                    Some(self.sorted.get(idx))
81                } else {
82                    let top_idx = idx + 1;
83                    let mut vals = self.sorted.index_range(idx..top_idx + 1);
84                    let mid = *vals.next().unwrap();
85                    let mid_plus_1 = *vals.next().unwrap();
86
87                    Some((mid + mid_plus_1) / (T::one() + T::one()))
88                };
89            },
90            Nearest => {
91                let idx = (((length as f64) - 1.0) * self.prob).round() as usize;
92                std::cmp::min(idx, length - 1)
93            },
94            Lower => ((length as f64 - 1.0) * self.prob).floor() as usize,
95            Higher => {
96                let idx = ((length as f64 - 1.0) * self.prob).ceil() as usize;
97                std::cmp::min(idx, length - 1)
98            },
99            Equiprobable => ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize,
100        };
101
102        Some(self.sorted.get(idx))
103    }
104}
105
106pub fn rolling_quantile<T>(
107    values: &[T],
108    window_size: usize,
109    min_periods: usize,
110    center: bool,
111    weights: Option<&[f64]>,
112    params: Option<RollingFnParams>,
113) -> PolarsResult<ArrayRef>
114where
115    T: NativeType
116        + IsFloat
117        + Float
118        + std::iter::Sum
119        + AddAssign
120        + SubAssign
121        + Div<Output = T>
122        + NumCast
123        + One
124        + Zero
125        + SealedRolling
126        + PartialOrd
127        + Sub<Output = T>,
128{
129    let offset_fn = match center {
130        true => det_offsets_center,
131        false => det_offsets,
132    };
133    match weights {
134        None => {
135            if !center {
136                let params = params.as_ref().unwrap();
137                let RollingFnParams::Quantile(params) = params else {
138                    unreachable!("expected Quantile params");
139                };
140                let out = super::quantile_filter::rolling_quantile::<_, Vec<_>>(
141                    params.method,
142                    min_periods,
143                    window_size,
144                    values,
145                    params.prob,
146                );
147                let validity = create_validity(min_periods, values.len(), window_size, offset_fn);
148                return Ok(Box::new(PrimitiveArray::new(
149                    T::PRIMITIVE.into(),
150                    out.into(),
151                    validity.map(|b| b.into()),
152                )));
153            }
154
155            rolling_apply_agg_window::<QuantileWindow<_>, _, _>(
156                values,
157                window_size,
158                min_periods,
159                offset_fn,
160                params,
161            )
162        },
163        Some(weights) => {
164            let wsum = weights.iter().sum();
165            polars_ensure!(
166                wsum != 0.0,
167                ComputeError: "Weighted quantile is undefined if weights sum to 0"
168            );
169            let params = params.unwrap();
170            let RollingFnParams::Quantile(params) = params else {
171                unreachable!("expected Quantile params");
172            };
173
174            Ok(rolling_apply_weighted_quantile(
175                values,
176                params.prob,
177                params.method,
178                window_size,
179                min_periods,
180                offset_fn,
181                weights,
182                wsum,
183            ))
184        },
185    }
186}
187
188#[inline]
189fn compute_wq<T>(buf: &[(T, f64)], p: f64, wsum: f64, method: QuantileMethod) -> T
190where
191    T: Debug + NativeType + Mul<Output = T> + Sub<Output = T> + NumCast + ToPrimitive + Zero,
192{
193    // There are a few ways to compute a weighted quantile but no "canonical" way.
194    // This is mostly taken from the Julia implementation which was readable and reasonable
195    // https://juliastats.org/StatsBase.jl/stable/scalarstats/#Quantile-and-Related-Functions-1
196    let (mut s, mut s_old, mut vk, mut v_old) = (0.0, 0.0, T::zero(), T::zero());
197
198    // Once the cumulative weight crosses h, we've found our ind{ex/ices}. The definition may look
199    // odd but it's the equivalent of taking h = p * (n - 1) + 1 if your data is indexed from 1.
200    let h: f64 = p * (wsum - buf[0].1) + buf[0].1;
201    for &(v, w) in buf.iter() {
202        if s > h {
203            break;
204        }
205        (s_old, v_old, vk) = (s, vk, v);
206        s += w;
207    }
208    match (h == s_old, method) {
209        (true, _) => v_old, // If we hit the break exactly interpolation shouldn't matter
210        (_, Lower) => v_old,
211        (_, Higher) => vk,
212        (_, Nearest) => {
213            if s - h > h - s_old {
214                v_old
215            } else {
216                vk
217            }
218        },
219        (_, Equiprobable) => {
220            let threshold = (wsum * p).ceil() - 1.0;
221            if s > threshold { vk } else { v_old }
222        },
223        (_, Midpoint) => (vk + v_old) * NumCast::from(0.5).unwrap(),
224        // This is seemingly the canonical way to do it.
225        (_, Linear) => {
226            v_old + <T as NumCast>::from((h - s_old) / (s - s_old)).unwrap() * (vk - v_old)
227        },
228    }
229}
230
231#[allow(clippy::too_many_arguments)]
232fn rolling_apply_weighted_quantile<T, Fo>(
233    values: &[T],
234    p: f64,
235    method: QuantileMethod,
236    window_size: usize,
237    min_periods: usize,
238    det_offsets_fn: Fo,
239    weights: &[f64],
240    wsum: f64,
241) -> ArrayRef
242where
243    Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
244    T: Debug + NativeType + Mul<Output = T> + Sub<Output = T> + NumCast + ToPrimitive + Zero,
245{
246    assert_eq!(weights.len(), window_size);
247    // Keep nonzero weights and their indices to know which values we need each iteration.
248    let nz_idx_wts: Vec<_> = weights.iter().enumerate().filter(|x| x.1 != &0.0).collect();
249    let mut buf = vec![(T::zero(), 0.0); nz_idx_wts.len()];
250    let len = values.len();
251    let out = (0..len)
252        .map(|idx| {
253            // Don't need end. Window size is constant and we computed offsets from start above.
254            let (start, _) = det_offsets_fn(idx, window_size, len);
255
256            // Sorting is not ideal, see https://github.com/tobiasschoch/wquantile for something faster
257            unsafe {
258                buf.iter_mut()
259                    .zip(nz_idx_wts.iter())
260                    .for_each(|(b, (i, w))| *b = (*values.get_unchecked(i + start), **w));
261            }
262            buf.sort_unstable_by(|&a, &b| a.0.tot_cmp(&b.0));
263            compute_wq(&buf, p, wsum, method)
264        })
265        .collect_trusted::<Vec<T>>();
266
267    let validity = create_validity(min_periods, len, window_size, det_offsets_fn);
268    Box::new(PrimitiveArray::new(
269        T::PRIMITIVE.into(),
270        out.into(),
271        validity.map(|b| b.into()),
272    ))
273}
274
#[cfg(test)]
mod test {
    use super::*;

    /// Downcasts a rolling-kernel result to `f64` and collects it as optional values.
    fn collect_f64(arr: ArrayRef) -> Vec<Option<f64>> {
        arr.as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap()
            .into_iter()
            .map(|v| v.copied())
            .collect()
    }

    #[test]
    fn test_rolling_median() {
        let values: &[f64] = &[1.0, 2.0, 3.0, 4.0];
        let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
            prob: 0.5,
            method: Linear,
        }));

        // (window_size, min_periods, center, expected)
        let cases: [(usize, usize, bool, &[Option<f64>]); 5] = [
            (2, 2, false, &[None, Some(1.5), Some(2.5), Some(3.5)]),
            (2, 1, false, &[Some(1.0), Some(1.5), Some(2.5), Some(3.5)]),
            (4, 1, false, &[Some(1.0), Some(1.5), Some(2.0), Some(2.5)]),
            (4, 1, true, &[Some(1.5), Some(2.0), Some(2.5), Some(3.0)]),
            (4, 4, true, &[None, None, Some(2.5), None]),
        ];
        for (window_size, min_periods, center, expected) in cases {
            let out =
                rolling_quantile(values, window_size, min_periods, center, None, med_pars).unwrap();
            assert_eq!(collect_f64(out), expected);
        }
    }

    #[test]
    fn test_rolling_quantile_limits() {
        let values: &[f64] = &[1.0, 2.0, 3.0, 4.0];
        // At prob 0 / 1 every interpolation method must reduce to the rolling min / max.
        let expected_min = collect_f64(rolling_min(values, 2, 2, false, None, None).unwrap());
        let expected_max = collect_f64(rolling_max(values, 2, 2, false, None, None).unwrap());

        let methods = [
            QuantileMethod::Lower,
            QuantileMethod::Higher,
            QuantileMethod::Nearest,
            QuantileMethod::Midpoint,
            QuantileMethod::Linear,
            QuantileMethod::Equiprobable,
        ];
        for method in methods {
            for (prob, expected) in [(0.0, &expected_min), (1.0, &expected_max)] {
                let pars = Some(RollingFnParams::Quantile(RollingQuantileParams { prob, method }));
                let out = collect_f64(rolling_quantile(values, 2, 2, false, None, pars).unwrap());
                assert_eq!(&out, expected);
            }
        }
    }
}