polars_compute/rolling/no_nulls/
min_max.rs1use polars_utils::min_max::{MaxPropagateNan, MinMaxPolicy, MinPropagateNan};
2
3use super::super::min_max::MinMaxWindow;
4use super::*;
5
6pub type MinWindow<'a, T> = MinMaxWindow<'a, T, MinPropagateNan>;
7pub type MaxWindow<'a, T> = MinMaxWindow<'a, T, MaxPropagateNan>;
8
9fn weighted_min_max<T, P>(values: &[T], weights: &[T]) -> T
10where
11 T: NativeType + std::ops::Mul<Output = T>,
12 P: MinMaxPolicy,
13{
14 values
15 .iter()
16 .zip(weights)
17 .map(|(v, w)| *v * *w)
18 .reduce(P::best)
19 .unwrap()
20}
21
22macro_rules! rolling_minmax_func {
23 ($rolling_m:ident, $policy:ident) => {
24 pub fn $rolling_m<T>(
25 values: &[T],
26 window_size: usize,
27 min_periods: usize,
28 center: bool,
29 weights: Option<&[f64]>,
30 _params: Option<RollingFnParams>,
31 ) -> PolarsResult<ArrayRef>
32 where
33 T: NativeType + PartialOrd + IsFloat + Bounded + NumCast + Mul<Output = T> + Num,
34 {
35 let offset_fn = match center {
36 true => det_offsets_center,
37 false => det_offsets,
38 };
39 match weights {
40 None => rolling_apply_agg_window::<MinMaxWindow<T, $policy>, _, _>(
41 values,
42 window_size,
43 min_periods,
44 offset_fn,
45 None,
46 ),
47 Some(weights) => {
48 assert!(
49 T::is_float(),
50 "implementation error, should only be reachable by float types"
51 );
52 let weights = weights
53 .iter()
54 .map(|v| NumCast::from(*v).unwrap())
55 .collect::<Vec<_>>();
56 no_nulls::rolling_apply_weights(
57 values,
58 window_size,
59 min_periods,
60 offset_fn,
61 weighted_min_max::<T, $policy>,
62 &weights,
63 )
64 },
65 }
66 }
67 };
68}
69
70rolling_minmax_func!(rolling_min, MinPropagateNan);
71rolling_minmax_func!(rolling_max, MaxPropagateNan);
72
#[cfg(test)]
mod test {
    use super::*;

    /// Downcasts a kernel output to `PrimitiveArray<f64>` and collects it as
    /// one `Option<f64>` per slot (`None` for null slots).
    fn as_f64_vec(out: &ArrayRef) -> Vec<Option<f64>> {
        out.as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap()
            .into_iter()
            .map(|v| v.copied())
            .collect()
    }

    #[test]
    fn test_rolling_min_max() {
        let values = &[1.0f64, 5.0, 3.0, 4.0];

        let out = as_f64_vec(&rolling_min(values, 2, 2, false, None, None).unwrap());
        assert_eq!(out, &[None, Some(1.0), Some(3.0), Some(3.0)]);
        let out = as_f64_vec(&rolling_max(values, 2, 2, false, None, None).unwrap());
        assert_eq!(out, &[None, Some(5.0), Some(5.0), Some(4.0)]);

        let out = as_f64_vec(&rolling_min(values, 2, 1, false, None, None).unwrap());
        assert_eq!(out, &[Some(1.0), Some(1.0), Some(3.0), Some(3.0)]);
        let out = as_f64_vec(&rolling_max(values, 2, 1, false, None, None).unwrap());
        assert_eq!(out, &[Some(1.0), Some(5.0), Some(5.0), Some(4.0)]);

        let out = as_f64_vec(&rolling_max(values, 3, 1, false, None, None).unwrap());
        assert_eq!(out, &[Some(1.0), Some(5.0), Some(5.0), Some(5.0)]);

        // NaN != NaN, so NaN propagation is compared through the Debug output.
        let values = &[1.0, 2.0, 3.0, f64::nan(), 5.0, 6.0, 7.0];
        let out = as_f64_vec(&rolling_min(values, 3, 3, false, None, None).unwrap());
        assert_eq!(
            format!("{:?}", out.as_slice()),
            format!(
                "{:?}",
                &[
                    None,
                    None,
                    Some(1.0),
                    Some(f64::nan()),
                    Some(f64::nan()),
                    Some(f64::nan()),
                    Some(5.0)
                ]
            )
        );

        let out = as_f64_vec(&rolling_max(values, 3, 3, false, None, None).unwrap());
        assert_eq!(
            format!("{:?}", out.as_slice()),
            format!(
                "{:?}",
                &[
                    None,
                    None,
                    Some(3.0),
                    Some(f64::nan()),
                    Some(f64::nan()),
                    Some(f64::nan()),
                    Some(7.0)
                ]
            )
        );
    }

    #[test]
    fn test_rolling_min_max_weighted() {
        let values = &[1.0f64, 5.0, 3.0, 4.0];
        let weights: &[f64] = &[1.0, 2.0];

        // Full windows scaled by [1, 2]:
        // [1*1, 5*2] = [1, 10], [5*1, 3*2] = [5, 6], [3*1, 4*2] = [3, 8].
        let out = as_f64_vec(&rolling_min(values, 2, 2, false, Some(weights), None).unwrap());
        assert_eq!(out, &[None, Some(1.0), Some(5.0), Some(3.0)]);
        let out = as_f64_vec(&rolling_max(values, 2, 2, false, Some(weights), None).unwrap());
        assert_eq!(out, &[None, Some(10.0), Some(6.0), Some(8.0)]);
    }
}