// polars_compute/rolling/no_nulls/min_max.rs

1use polars_utils::min_max::{MaxPropagateNan, MinMaxPolicy, MinPropagateNan};
2
3use super::super::min_max::MinMaxWindow;
4use super::*;
5
6pub type MinWindow<'a, T> = MinMaxWindow<'a, T, MinPropagateNan>;
7pub type MaxWindow<'a, T> = MinMaxWindow<'a, T, MaxPropagateNan>;
8
9fn weighted_min_max<T, P>(values: &[T], weights: &[T]) -> T
10where
11    T: NativeType + std::ops::Mul<Output = T>,
12    P: MinMaxPolicy,
13{
14    values
15        .iter()
16        .zip(weights)
17        .map(|(v, w)| *v * *w)
18        .reduce(P::best)
19        .unwrap()
20}
21
22macro_rules! rolling_minmax_func {
23    ($rolling_m:ident, $policy:ident) => {
24        pub fn $rolling_m<T>(
25            values: &[T],
26            window_size: usize,
27            min_periods: usize,
28            center: bool,
29            weights: Option<&[f64]>,
30            _params: Option<RollingFnParams>,
31        ) -> PolarsResult<ArrayRef>
32        where
33            T: NativeType + PartialOrd + IsFloat + Bounded + NumCast + Mul<Output = T> + Num,
34        {
35            let offset_fn = match center {
36                true => det_offsets_center,
37                false => det_offsets,
38            };
39            match weights {
40                None => rolling_apply_agg_window::<MinMaxWindow<T, $policy>, _, _>(
41                    values,
42                    window_size,
43                    min_periods,
44                    offset_fn,
45                    None,
46                ),
47                Some(weights) => {
48                    assert!(
49                        T::is_float(),
50                        "implementation error, should only be reachable by float types"
51                    );
52                    let weights = weights
53                        .iter()
54                        .map(|v| NumCast::from(*v).unwrap())
55                        .collect::<Vec<_>>();
56                    no_nulls::rolling_apply_weights(
57                        values,
58                        window_size,
59                        min_periods,
60                        offset_fn,
61                        weighted_min_max::<T, $policy>,
62                        &weights,
63                    )
64                },
65            }
66        }
67    };
68}
69
70rolling_minmax_func!(rolling_min, MinPropagateNan);
71rolling_minmax_func!(rolling_max, MaxPropagateNan);
72
#[cfg(test)]
mod test {
    use super::*;

    /// Downcasts a kernel result to `PrimitiveArray<f64>` and collects it as
    /// nullable values, so it can be compared against an expected slice.
    fn collect_f64(out: ArrayRef) -> Vec<Option<f64>> {
        out.as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .expect("f64 input should produce a PrimitiveArray<f64>")
            .into_iter()
            .map(|v| v.copied())
            .collect()
    }

    /// Asserts element-wise equality and treats two NaNs as equal. This lets
    /// NaN-propagating windows be checked directly instead of via `Debug` strings.
    fn assert_nan_eq(actual: &[Option<f64>], expected: &[Option<f64>]) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: {actual:?} vs {expected:?}"
        );
        for (idx, (a, e)) in actual.iter().zip(expected).enumerate() {
            let both_nan = matches!((a, e), (Some(x), Some(y)) if x.is_nan() && y.is_nan());
            assert!(
                both_nan || a == e,
                "mismatch at index {idx}: {actual:?} vs {expected:?}"
            );
        }
    }

    #[test]
    fn test_rolling_min_max() {
        let values = &[1.0f64, 5.0, 3.0, 4.0];

        let out = collect_f64(rolling_min(values, 2, 2, false, None, None).unwrap());
        assert_eq!(out, &[None, Some(1.0), Some(3.0), Some(3.0)]);
        let out = collect_f64(rolling_max(values, 2, 2, false, None, None).unwrap());
        assert_eq!(out, &[None, Some(5.0), Some(5.0), Some(4.0)]);

        let out = collect_f64(rolling_min(values, 2, 1, false, None, None).unwrap());
        assert_eq!(out, &[Some(1.0), Some(1.0), Some(3.0), Some(3.0)]);
        let out = collect_f64(rolling_max(values, 2, 1, false, None, None).unwrap());
        assert_eq!(out, &[Some(1.0), Some(5.0), Some(5.0), Some(4.0)]);

        let out = collect_f64(rolling_max(values, 3, 1, false, None, None).unwrap());
        assert_eq!(out, &[Some(1.0), Some(5.0), Some(5.0), Some(5.0)]);

        // NaN handling: every window that contains the NaN must yield NaN.
        let values = &[1.0, 2.0, 3.0, f64::NAN, 5.0, 6.0, 7.0];
        let nan = Some(f64::NAN);

        let out = collect_f64(rolling_min(values, 3, 3, false, None, None).unwrap());
        assert_nan_eq(&out, &[None, None, Some(1.0), nan, nan, nan, Some(5.0)]);

        let out = collect_f64(rolling_max(values, 3, 3, false, None, None).unwrap());
        assert_nan_eq(&out, &[None, None, Some(3.0), nan, nan, nan, Some(7.0)]);
    }
}