polars_compute/
propagate_nulls.rs

1use arrow::array::{Array, FixedSizeListArray, ListArray, StructArray};
2use arrow::bitmap::BitmapBuilder;
3use arrow::bitmap::bitmask::BitMask;
4use arrow::types::Offset;
5
6/// Propagate nulls down to masked-out values in lower nesting levels.
7pub fn propagate_nulls(arr: &dyn Array) -> Option<Box<dyn Array>> {
8    let arr = arr.as_any();
9    if let Some(arr) = arr.downcast_ref::<ListArray<i32>>() {
10        return propagate_nulls_list(arr).map(|arr| Box::new(arr) as _);
11    }
12    if let Some(arr) = arr.downcast_ref::<ListArray<i64>>() {
13        return propagate_nulls_list(arr).map(|arr| Box::new(arr) as _);
14    }
15    if let Some(arr) = arr.downcast_ref::<FixedSizeListArray>() {
16        return propagate_nulls_fsl(arr).map(|arr| Box::new(arr) as _);
17    }
18    if let Some(arr) = arr.downcast_ref::<StructArray>() {
19        return propagate_nulls_struct(arr).map(|arr| Box::new(arr) as _);
20    }
21
22    None
23}
24
25pub fn propagate_nulls_list<O: Offset>(arr: &ListArray<O>) -> Option<ListArray<O>> {
26    let Some(validity) = arr.validity() else {
27        return propagate_nulls(arr.values().as_ref()).map(|values| {
28            ListArray::new(arr.dtype().clone(), arr.offsets().clone(), values, None)
29        });
30    };
31
32    let mut last_idx = 0;
33    let old_child_validity = arr.values().validity();
34    let mut new_child_validity = BitmapBuilder::new();
35
36    let mut new_values = None;
37
38    // Find the first element that does not have propagated nulls.
39    let null_mask = !validity;
40    for i in null_mask.true_idx_iter() {
41        last_idx = i;
42        let (start, end) = arr.offsets().start_end(i);
43        if end == start {
44            continue;
45        }
46
47        if old_child_validity.is_none_or(|v| {
48            BitMask::from_bitmap(v)
49                .sliced(start, end - start)
50                .set_bits()
51                > 0
52        }) {
53            new_child_validity.subslice_extend_from_opt_validity(old_child_validity, 0, start);
54            new_child_validity.extend_constant(end - start, false);
55            break;
56        }
57    }
58
59    if !new_child_validity.is_empty() {
60        // If nulls need to be propagated, create a new validity mask for the child array.
61        let null_mask = null_mask.sliced(last_idx + 1, arr.len() - last_idx - 1);
62
63        for i in null_mask.true_idx_iter() {
64            let i = i + last_idx + 1;
65            let (start, end) = arr.offsets().start_end(i);
66            if end == start {
67                continue;
68            }
69
70            new_child_validity.subslice_extend_from_opt_validity(
71                old_child_validity,
72                new_child_validity.len(),
73                start - new_child_validity.len(),
74            );
75            new_child_validity.extend_constant(end - start, false);
76        }
77
78        new_child_validity.subslice_extend_from_opt_validity(
79            old_child_validity,
80            new_child_validity.len(),
81            arr.values().len() - new_child_validity.len(),
82        );
83
84        let new_child_validity = new_child_validity.freeze();
85        new_values = Some(arr.values().with_validity(Some(new_child_validity)));
86    }
87
88    let Some(values) = new_values
89        .as_ref()
90        .and_then(|v| propagate_nulls(v.as_ref()))
91        .or(new_values)
92    else {
93        // Nothing was changed. Return the original array.
94        return None;
95    };
96
97    Some(ListArray::new(
98        arr.dtype().clone(),
99        arr.offsets().clone(),
100        values,
101        Some(validity.clone()),
102    ))
103}
104
105pub fn propagate_nulls_fsl(arr: &FixedSizeListArray) -> Option<FixedSizeListArray> {
106    let Some(validity) = arr.validity() else {
107        return propagate_nulls(arr.values().as_ref())
108            .map(|values| FixedSizeListArray::new(arr.dtype().clone(), arr.len(), values, None));
109    };
110
111    if arr.size() == 0 || validity.unset_bits() == 0 {
112        return None;
113    }
114
115    let start_point = match arr.values().validity() {
116        None => Some(validity.leading_ones()),
117        Some(old_child_validity) => {
118            // Find the first element that does not have propagated nulls.
119            let null_mask = !validity;
120            null_mask.true_idx_iter().find(|i| {
121                BitMask::from_bitmap(old_child_validity)
122                    .sliced(i * arr.size(), arr.size())
123                    .set_bits()
124                    > 0
125            })
126        },
127    };
128
129    let mut new_values = None;
130    if let Some(start_point) = start_point {
131        // Nulls need to be propagated, create a new validity mask.
132        let mut new_child_validity = BitmapBuilder::with_capacity(arr.size() * arr.len());
133
134        let mut validity = validity.clone();
135        validity.slice(start_point, validity.len() - start_point);
136        match arr.values().validity() {
137            None => {
138                new_child_validity.extend_constant(start_point * arr.size(), true);
139
140                while !validity.is_empty() {
141                    let num_zeroes = validity.take_leading_zeros();
142                    new_child_validity.extend_constant(num_zeroes * arr.size(), false);
143
144                    let num_ones = validity.take_leading_ones();
145                    new_child_validity.extend_constant(num_ones * arr.size(), true);
146                }
147            },
148
149            Some(old_child_validity) => {
150                new_child_validity.subslice_extend_from_bitmap(
151                    old_child_validity,
152                    0,
153                    start_point * arr.size(),
154                );
155                while !validity.is_empty() {
156                    let num_zeroes = validity.take_leading_zeros();
157                    new_child_validity.extend_constant(num_zeroes * arr.size(), false);
158
159                    let num_ones = validity.take_leading_ones();
160                    new_child_validity.subslice_extend_from_bitmap(
161                        old_child_validity,
162                        new_child_validity.len(),
163                        num_ones * arr.size(),
164                    );
165                }
166            },
167        }
168
169        let new_child_validity = new_child_validity.freeze();
170        new_values = Some(arr.values().with_validity(Some(new_child_validity)));
171    }
172
173    let Some(values) = new_values
174        .as_ref()
175        .and_then(|v| propagate_nulls(v.as_ref()))
176        .or(new_values)
177    else {
178        // Nothing was changed. Return the original array.
179        return None;
180    };
181
182    // The child array was changed.
183    Some(FixedSizeListArray::new(
184        arr.dtype().clone(),
185        arr.len(),
186        values,
187        Some(validity.clone()),
188    ))
189}
190
191pub fn propagate_nulls_struct(arr: &StructArray) -> Option<StructArray> {
192    let Some(validity) = arr.validity() else {
193        let mut new_values = Vec::new();
194        for (i, field_array) in arr.values().iter().enumerate() {
195            if let Some(field_array) = propagate_nulls(field_array.as_ref()) {
196                new_values.reserve(arr.values().len());
197                new_values.extend(arr.values()[..i].iter().cloned());
198                new_values.push(field_array);
199                break;
200            }
201        }
202
203        if new_values.is_empty() {
204            return None;
205        }
206
207        new_values.extend(arr.values()[new_values.len()..].iter().map(|field_array| {
208            propagate_nulls(field_array.as_ref()).unwrap_or_else(|| field_array.to_boxed())
209        }));
210        return Some(StructArray::new(
211            arr.dtype().clone(),
212            arr.len(),
213            new_values,
214            None,
215        ));
216    };
217
218    if arr.values().is_empty() || validity.unset_bits() == 0 {
219        return None;
220    }
221
222    let mut new_values = Vec::new();
223    for (i, field_array) in arr.values().iter().enumerate() {
224        let new_field_array = match field_array.validity() {
225            None => Some(field_array.with_validity(Some(validity.clone()))),
226            Some(v) if v.num_intersections_with(validity) == validity.set_bits() => None,
227            Some(v) => Some(field_array.with_validity(Some(v & validity))),
228        };
229
230        let Some(new_field_array) = new_field_array
231            .as_ref()
232            .and_then(|v| propagate_nulls(v.as_ref()))
233            .or(new_field_array)
234        else {
235            // Nothing was changed. Return the original array.
236            continue;
237        };
238
239        new_values.reserve(arr.values().len());
240        new_values.extend(arr.values()[..i].iter().cloned());
241        new_values.push(new_field_array);
242        break;
243    }
244
245    if new_values.is_empty() {
246        return None;
247    }
248
249    new_values.extend(arr.values()[new_values.len()..].iter().map(|field_array| {
250        let new_field_array = match field_array.validity() {
251            None => Some(field_array.with_validity(Some(validity.clone()))),
252            Some(v) if v.num_intersections_with(validity) == validity.set_bits() => None,
253            Some(v) => Some(field_array.with_validity(Some(v & validity))),
254        };
255
256        new_field_array
257            .as_ref()
258            .and_then(|v| propagate_nulls(v.as_ref()))
259            .or(new_field_array)
260            .unwrap_or_else(|| field_array.clone())
261    }));
262
263    Some(StructArray::new(
264        arr.dtype().clone(),
265        arr.len(),
266        new_values,
267        Some(validity.clone()),
268    ))
269}
270
#[cfg(test)]
mod tests {
    use arrow::array::proptest::array;
    use proptest::proptest;

    use super::propagate_nulls;

    proptest! {
        // Null propagation only rewrites values hidden behind an outer null,
        // so whenever a new array is produced it must compare equal to the input.
        #[test]
        fn test_proptest(array in array(0..100)) {
            let propagated = propagate_nulls(array.as_ref());
            if let Some(propagated) = propagated {
                proptest::prop_assert_eq!(array, propagated);
            }
        }
    }
}