polars_compute/rolling/nulls/quantile.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2use super::*;
3use crate::rolling::quantile_filter::SealedRolling;
4
5pub struct QuantileWindow<'a, T: NativeType + IsFloat + PartialOrd> {
6    sorted: SortedBufNulls<'a, T>,
7    prob: f64,
8    method: QuantileMethod,
9}
10
11impl<
12    'a,
13    T: NativeType
14        + IsFloat
15        + Float
16        + std::iter::Sum
17        + AddAssign
18        + SubAssign
19        + Div<Output = T>
20        + NumCast
21        + One
22        + Zero
23        + SealedRolling
24        + PartialOrd
25        + Sub<Output = T>,
26> RollingAggWindowNulls<'a, T> for QuantileWindow<'a, T>
27{
28    unsafe fn new(
29        slice: &'a [T],
30        validity: &'a Bitmap,
31        start: usize,
32        end: usize,
33        params: Option<RollingFnParams>,
34        window_size: Option<usize>,
35    ) -> Self {
36        let params = params.unwrap();
37        let RollingFnParams::Quantile(params) = params else {
38            unreachable!("expected Quantile params");
39        };
40        Self {
41            sorted: SortedBufNulls::new(slice, validity, start, end, window_size),
42            prob: params.prob,
43            method: params.method,
44        }
45    }
46
47    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
48        let null_count = self.sorted.update(start, end);
49        let mut length = self.sorted.len();
50        // The min periods_issue will be taken care of when actually rolling
51        if null_count == length {
52            return None;
53        }
54        // Nulls are guaranteed to be at the front
55        length -= null_count;
56        let mut idx = match self.method {
57            QuantileMethod::Nearest => (((length as f64) - 1.0) * self.prob).round() as usize,
58            QuantileMethod::Lower | QuantileMethod::Midpoint | QuantileMethod::Linear => {
59                ((length as f64 - 1.0) * self.prob).floor() as usize
60            },
61            QuantileMethod::Higher => ((length as f64 - 1.0) * self.prob).ceil() as usize,
62            QuantileMethod::Equiprobable => {
63                ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize
64            },
65        };
66
67        idx = std::cmp::min(idx, length - 1);
68
69        // we can unwrap because we sliced of the nulls
70        match self.method {
71            QuantileMethod::Midpoint => {
72                let top_idx = ((length as f64 - 1.0) * self.prob).ceil() as usize;
73
74                debug_assert!(idx <= top_idx);
75                let v = if idx != top_idx {
76                    let mut vals = self
77                        .sorted
78                        .index_range(idx + null_count..top_idx + null_count + 1);
79                    let low = vals.next().unwrap().unwrap();
80                    let high = vals.next().unwrap().unwrap();
81                    (low + high) / T::from::<f64>(2.0f64).unwrap()
82                } else {
83                    self.sorted.get(idx + null_count).unwrap()
84                };
85
86                Some(v)
87            },
88            QuantileMethod::Linear => {
89                let float_idx = (length as f64 - 1.0) * self.prob;
90                let top_idx = f64::ceil(float_idx) as usize;
91
92                if top_idx == idx {
93                    Some(self.sorted.get(idx + null_count).unwrap())
94                } else {
95                    let mut vals = self
96                        .sorted
97                        .index_range(idx + null_count..top_idx + null_count + 1);
98                    let low = vals.next().unwrap().unwrap();
99                    let high = vals.next().unwrap().unwrap();
100
101                    let proportion = T::from(float_idx - idx as f64).unwrap();
102                    Some(proportion * (high - low) + low)
103                }
104            },
105            _ => Some(self.sorted.get(idx + null_count).unwrap()),
106        }
107    }
108
109    fn is_valid(&self, min_periods: usize) -> bool {
110        self.sorted.is_valid(min_periods)
111    }
112}
113
114pub fn rolling_quantile<T>(
115    arr: &PrimitiveArray<T>,
116    window_size: usize,
117    min_periods: usize,
118    center: bool,
119    weights: Option<&[f64]>,
120    params: Option<RollingFnParams>,
121) -> ArrayRef
122where
123    T: NativeType
124        + IsFloat
125        + Float
126        + std::iter::Sum
127        + AddAssign
128        + SubAssign
129        + Div<Output = T>
130        + NumCast
131        + One
132        + Zero
133        + SealedRolling
134        + PartialOrd
135        + Sub<Output = T>,
136{
137    if weights.is_some() {
138        panic!("weights not yet supported on array with null values")
139    }
140    let offset_fn = match center {
141        true => det_offsets_center,
142        false => det_offsets,
143    };
144    /*
145    TODO: fix or remove the dancing links based rolling implementation
146    see https://github.com/pola-rs/polars/issues/23480
147    if !center {
148        let params = params.as_ref().unwrap();
149        let RollingFnParams::Quantile(params) = params else {
150            unreachable!("expected Quantile params");
151        };
152
153        let out = super::quantile_filter::rolling_quantile::<_, MutablePrimitiveArray<_>>(
154            params.method,
155            min_periods,
156            window_size,
157            arr.clone(),
158            params.prob,
159        );
160        let out: PrimitiveArray<T> = out.into();
161        return Box::new(out);
162    }
163    */
164    rolling_apply_agg_window::<QuantileWindow<_>, _, _>(
165        arr.values().as_slice(),
166        arr.validity().as_ref().unwrap(),
167        window_size,
168        min_periods,
169        offset_fn,
170        params,
171    )
172}
173
#[cfg(test)]
mod test {
    use arrow::buffer::Buffer;
    use arrow::datatypes::ArrowDataType;

    use super::*;

    /// Downcasts a rolling kernel's `f64` output and collects it as optional values.
    fn to_opt_vec(out: ArrayRef) -> Vec<Option<f64>> {
        let prim = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
        prim.into_iter().map(|v| v.copied()).collect()
    }

    #[test]
    fn test_rolling_median_nulls() {
        let arr = &PrimitiveArray::new(
            ArrowDataType::Float64,
            Buffer::from(vec![1.0, 2.0, 3.0, 4.0]),
            Some(Bitmap::from(&[true, false, true, true])),
        );
        let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
            prob: 0.5,
            method: QuantileMethod::Linear,
        }));

        // (window_size, min_periods, center, expected output)
        let cases: [(usize, usize, bool, [Option<f64>; 4]); 5] = [
            (2, 2, false, [None, None, None, Some(3.5)]),
            (2, 1, false, [Some(1.0), Some(1.0), Some(3.0), Some(3.5)]),
            (4, 1, false, [Some(1.0), Some(1.0), Some(2.0), Some(3.0)]),
            (4, 1, true, [Some(1.0), Some(2.0), Some(3.0), Some(3.5)]),
            (4, 4, true, [None; 4]),
        ];
        for (window_size, min_periods, center, expected) in cases {
            let got = to_opt_vec(rolling_quantile(
                arr,
                window_size,
                min_periods,
                center,
                None,
                med_pars,
            ));
            assert_eq!(
                got,
                &expected,
                "window_size={window_size}, min_periods={min_periods}, center={center}"
            );
        }
    }

    #[test]
    fn test_rolling_quantile_nulls_limits() {
        // Quantiles at prob 0.0 / 1.0 must agree with rolling min / max for every method.
        let values = &PrimitiveArray::new(
            ArrowDataType::Float64,
            Buffer::<f64>::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]),
            Some(Bitmap::from(&[true, false, false, true, true])),
        );

        let methods = [
            QuantileMethod::Lower,
            QuantileMethod::Higher,
            QuantileMethod::Nearest,
            QuantileMethod::Midpoint,
            QuantileMethod::Linear,
            QuantileMethod::Equiprobable,
        ];

        for method in methods {
            let quantile_at = |prob: f64| {
                let pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
                    prob,
                    method,
                }));
                to_opt_vec(rolling_quantile(values, 2, 1, false, None, pars))
            };

            let expected_min = to_opt_vec(rolling_min(values, 2, 1, false, None, None));
            assert_eq!(expected_min, quantile_at(0.0));

            let expected_max = to_opt_vec(rolling_max(values, 2, 1, false, None, None));
            assert_eq!(expected_max, quantile_at(1.0));
        }
    }
}