polars_compute/rolling/no_nulls/
mod.rs

1mod mean;
2mod min_max;
3mod moment;
4mod quantile;
5mod sum;
6use std::fmt::Debug;
7
8use arrow::array::PrimitiveArray;
9use arrow::datatypes::ArrowDataType;
10use arrow::legacy::error::PolarsResult;
11use arrow::legacy::utils::CustomIterTools;
12use arrow::types::NativeType;
13pub use mean::*;
14pub use min_max::*;
15pub use moment::*;
16use num_traits::{Float, Num, NumCast};
17pub use quantile::*;
18pub use sum::*;
19
20use super::*;
21
22pub trait RollingAggWindowNoNulls<'a, T: NativeType> {
23    fn new(
24        slice: &'a [T],
25        start: usize,
26        end: usize,
27        params: Option<RollingFnParams>,
28        window_size: Option<usize>,
29    ) -> Self;
30
31    /// Update and recompute the window
32    ///
33    /// # Safety
34    /// `start` and `end` must be within the windows bounds
35    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T>;
36}
37
38// Use an aggregation window that maintains the state
39pub(super) fn rolling_apply_agg_window<'a, Agg, T, Fo>(
40    values: &'a [T],
41    window_size: usize,
42    min_periods: usize,
43    det_offsets_fn: Fo,
44    params: Option<RollingFnParams>,
45) -> PolarsResult<ArrayRef>
46where
47    Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
48    Agg: RollingAggWindowNoNulls<'a, T>,
49    T: Debug + NativeType + Num,
50{
51    let len = values.len();
52    let (start, end) = det_offsets_fn(0, window_size, len);
53    let mut agg_window = Agg::new(values, start, end, params, Some(window_size));
54    if let Some(validity) = create_validity(min_periods, len, window_size, &det_offsets_fn) {
55        if validity.iter().all(|x| !x) {
56            return Ok(Box::new(PrimitiveArray::<T>::new_null(
57                T::PRIMITIVE.into(),
58                len,
59            )));
60        }
61    }
62
63    let out = (0..len).map(|idx| {
64        let (start, end) = det_offsets_fn(idx, window_size, len);
65        if end - start < min_periods {
66            None
67        } else {
68            // SAFETY:
69            // we are in bounds
70            unsafe { agg_window.update(start, end) }
71        }
72    });
73    let arr = PrimitiveArray::from_trusted_len_iter(out);
74    Ok(Box::new(arr))
75}
76
77pub(super) fn rolling_apply_weights<T, Fo, Fa>(
78    values: &[T],
79    window_size: usize,
80    min_periods: usize,
81    det_offsets_fn: Fo,
82    aggregator: Fa,
83    weights: &[T],
84) -> PolarsResult<ArrayRef>
85where
86    T: NativeType,
87    Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
88    Fa: Fn(&[T], &[T]) -> T,
89{
90    assert_eq!(weights.len(), window_size);
91    let len = values.len();
92    let out = (0..len)
93        .map(|idx| {
94            let (start, end) = det_offsets_fn(idx, window_size, len);
95            let vals = unsafe { values.get_unchecked(start..end) };
96
97            aggregator(vals, weights)
98        })
99        .collect_trusted::<Vec<T>>();
100
101    let validity = create_validity(min_periods, len, window_size, det_offsets_fn);
102    Ok(Box::new(PrimitiveArray::new(
103        ArrowDataType::from(T::PRIMITIVE),
104        out.into(),
105        validity.map(|b| b.into()),
106    )))
107}
108
109fn compute_var_weights<T>(vals: &[T], weights: &[T]) -> T
110where
111    T: Float + std::ops::AddAssign,
112{
113    // Assumes the weights have already been standardized to 1
114    debug_assert!(
115        weights.iter().fold(T::zero(), |acc, x| acc + *x) == T::one(),
116        "Rolling weighted variance Weights don't sum to 1"
117    );
118    let (wssq, wmean) = vals
119        .iter()
120        .zip(weights)
121        .fold((T::zero(), T::zero()), |(wssq, wsum), (&v, &w)| {
122            (wssq + v * v * w, wsum + v * w)
123        });
124
125    wssq - wmean * wmean
126}
127
/// Weighted sum of `values`: the sum of the element-wise products with `weights`.
///
/// Elements are paired with `zip`, so if the slices differ in length the surplus of
/// the longer one is ignored.
pub(crate) fn compute_sum_weights<T>(values: &[T], weights: &[T]) -> T
where
    T: std::iter::Sum<T> + Copy + std::ops::Mul<Output = T>,
{
    let products = values
        .iter()
        .zip(weights.iter())
        .map(|(&value, &weight)| value * weight);
    products.sum()
}
134
135pub(super) fn coerce_weights<T: NumCast>(weights: &[f64]) -> Vec<T>
136where
137{
138    weights
139        .iter()
140        .map(|v| NumCast::from(*v).unwrap())
141        .collect::<Vec<_>>()
142}