// polars_compute/rolling/no_nulls/
mod.rs1mod mean;
2mod min_max;
3mod moment;
4mod quantile;
5mod sum;
6use std::fmt::Debug;
7
8use arrow::array::PrimitiveArray;
9use arrow::datatypes::ArrowDataType;
10use arrow::legacy::error::PolarsResult;
11use arrow::legacy::utils::CustomIterTools;
12use arrow::types::NativeType;
13pub use mean::*;
14pub use min_max::*;
15pub use moment::*;
16use num_traits::{Float, Num, NumCast};
17pub use quantile::*;
18pub use sum::*;
19
20use super::*;
21
22pub trait RollingAggWindowNoNulls<'a, T: NativeType> {
23 fn new(
24 slice: &'a [T],
25 start: usize,
26 end: usize,
27 params: Option<RollingFnParams>,
28 window_size: Option<usize>,
29 ) -> Self;
30
31 unsafe fn update(&mut self, start: usize, end: usize) -> Option<T>;
36}
37
38pub(super) fn rolling_apply_agg_window<'a, Agg, T, Fo>(
40 values: &'a [T],
41 window_size: usize,
42 min_periods: usize,
43 det_offsets_fn: Fo,
44 params: Option<RollingFnParams>,
45) -> PolarsResult<ArrayRef>
46where
47 Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
48 Agg: RollingAggWindowNoNulls<'a, T>,
49 T: Debug + NativeType + Num,
50{
51 let len = values.len();
52 let (start, end) = det_offsets_fn(0, window_size, len);
53 let mut agg_window = Agg::new(values, start, end, params, Some(window_size));
54 if let Some(validity) = create_validity(min_periods, len, window_size, &det_offsets_fn) {
55 if validity.iter().all(|x| !x) {
56 return Ok(Box::new(PrimitiveArray::<T>::new_null(
57 T::PRIMITIVE.into(),
58 len,
59 )));
60 }
61 }
62
63 let out = (0..len).map(|idx| {
64 let (start, end) = det_offsets_fn(idx, window_size, len);
65 if end - start < min_periods {
66 None
67 } else {
68 unsafe { agg_window.update(start, end) }
71 }
72 });
73 let arr = PrimitiveArray::from_trusted_len_iter(out);
74 Ok(Box::new(arr))
75}
76
77pub(super) fn rolling_apply_weights<T, Fo, Fa>(
78 values: &[T],
79 window_size: usize,
80 min_periods: usize,
81 det_offsets_fn: Fo,
82 aggregator: Fa,
83 weights: &[T],
84) -> PolarsResult<ArrayRef>
85where
86 T: NativeType,
87 Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
88 Fa: Fn(&[T], &[T]) -> T,
89{
90 assert_eq!(weights.len(), window_size);
91 let len = values.len();
92 let out = (0..len)
93 .map(|idx| {
94 let (start, end) = det_offsets_fn(idx, window_size, len);
95 let vals = unsafe { values.get_unchecked(start..end) };
96
97 aggregator(vals, weights)
98 })
99 .collect_trusted::<Vec<T>>();
100
101 let validity = create_validity(min_periods, len, window_size, det_offsets_fn);
102 Ok(Box::new(PrimitiveArray::new(
103 ArrowDataType::from(T::PRIMITIVE),
104 out.into(),
105 validity.map(|b| b.into()),
106 )))
107}
108
109fn compute_var_weights<T>(vals: &[T], weights: &[T]) -> T
110where
111 T: Float + std::ops::AddAssign,
112{
113 debug_assert!(
115 weights.iter().fold(T::zero(), |acc, x| acc + *x) == T::one(),
116 "Rolling weighted variance Weights don't sum to 1"
117 );
118 let (wssq, wmean) = vals
119 .iter()
120 .zip(weights)
121 .fold((T::zero(), T::zero()), |(wssq, wsum), (&v, &w)| {
122 (wssq + v * v * w, wsum + v * w)
123 });
124
125 wssq - wmean * wmean
126}
127
/// Weighted sum `Σ values[i] * weights[i]`.
///
/// Elements are paired by position. Pairing stops at the end of the shorter slice, so any
/// surplus elements of the longer slice are ignored.
pub(crate) fn compute_sum_weights<T>(values: &[T], weights: &[T]) -> T
where
    T: std::iter::Sum<T> + Copy + std::ops::Mul<Output = T>,
{
    let products = values
        .iter()
        .copied()
        .zip(weights.iter().copied())
        .map(|(value, weight)| value * weight);
    products.sum::<T>()
}
134
135pub(super) fn coerce_weights<T: NumCast>(weights: &[f64]) -> Vec<T>
136where
137{
138 weights
139 .iter()
140 .map(|v| NumCast::from(*v).unwrap())
141 .collect::<Vec<_>>()
142}