polars_compute/
find_validity_mismatch.rs

1use arrow::array::{Array, FixedSizeListArray, ListArray, StructArray};
2use arrow::datatypes::ArrowDataType;
3use arrow::types::Offset;
4use polars_utils::IdxSize;
5use polars_utils::itertools::Itertools;
6
7use crate::cast::CastOptionsImpl;
8
9/// Find the indices of the values where the validity mismatches.
10///
11/// This is done recursively, meaning that a validity mismatch at a deeper level will result as at
12/// the level above at the corresponding index.
13///
14/// This procedure requires that
15/// - Nulls are propagated recursively
16/// - Lists to be
17///     - trimmed to normalized offsets
18///     - have the same number of child elements below each element (even nulls)
19pub fn find_validity_mismatch(left: &dyn Array, right: &dyn Array, idxs: &mut Vec<IdxSize>) {
20    assert_eq!(left.len(), right.len());
21
22    // Handle the top-level.
23    //
24    // NOTE: This is done always, even if left and right have different nestings. This is
25    // intentional and needed.
26    let original_idxs_length = idxs.len();
27    match (left.validity(), right.validity()) {
28        (None, None) => {},
29        (Some(l), Some(r)) => {
30            if l != r {
31                let mismatches = arrow::bitmap::xor(l, r);
32                idxs.extend(mismatches.true_idx_iter().map(|i| i as IdxSize));
33            }
34        },
35        (Some(v), _) | (_, Some(v)) => {
36            if v.unset_bits() > 0 {
37                let mismatches = !v;
38                idxs.extend(mismatches.true_idx_iter().map(|i| i as IdxSize));
39            }
40        },
41    }
42
43    let left = left.as_any();
44    let right = right.as_any();
45
46    let pre_nesting_length = idxs.len();
47    // (Struct, Struct)
48    if let (Some(left), Some(right)) = (
49        left.downcast_ref::<StructArray>(),
50        right.downcast_ref::<StructArray>(),
51    ) {
52        assert_eq!(left.fields().len(), right.fields().len());
53        for (l, r) in left.values().iter().zip(right.values().iter()) {
54            find_validity_mismatch(l.as_ref(), r.as_ref(), idxs);
55        }
56    }
57
58    // (List, List)
59    if let (Some(left), Some(right)) = (
60        left.downcast_ref::<ListArray<i32>>(),
61        right.downcast_ref::<ListArray<i32>>(),
62    ) {
63        find_validity_mismatch_list_list_nested(left, right, idxs);
64    }
65    if let (Some(left), Some(right)) = (
66        left.downcast_ref::<ListArray<i64>>(),
67        right.downcast_ref::<ListArray<i64>>(),
68    ) {
69        find_validity_mismatch_list_list_nested(left, right, idxs);
70    }
71
72    // (FixedSizeList, FixedSizeList)
73    if let (Some(left), Some(right)) = (
74        left.downcast_ref::<FixedSizeListArray>(),
75        right.downcast_ref::<FixedSizeListArray>(),
76    ) {
77        assert_eq!(left.size(), right.size());
78        find_validity_mismatch_fsl_fsl_nested(
79            left.values().as_ref(),
80            right.values().as_ref(),
81            left.size(),
82            idxs,
83        )
84    }
85
86    // (List, Array) / (Array, List)
87    if let (Some(left), Some(right)) = (
88        left.downcast_ref::<ListArray<i32>>(),
89        right.downcast_ref::<FixedSizeListArray>(),
90    ) {
91        find_validity_mismatch_list_fsl_impl(left, right, idxs);
92    }
93    if let (Some(left), Some(right)) = (
94        left.downcast_ref::<ListArray<i64>>(),
95        right.downcast_ref::<FixedSizeListArray>(),
96    ) {
97        find_validity_mismatch_list_fsl_impl(left, right, idxs);
98    }
99    if let (Some(right), Some(left)) = (
100        left.downcast_ref::<FixedSizeListArray>(),
101        right.downcast_ref::<ListArray<i32>>(),
102    ) {
103        find_validity_mismatch_list_fsl_impl(left, right, idxs);
104    }
105    if let (Some(right), Some(left)) = (
106        left.downcast_ref::<FixedSizeListArray>(),
107        right.downcast_ref::<ListArray<i64>>(),
108    ) {
109        find_validity_mismatch_list_fsl_impl(left, right, idxs);
110    }
111
112    if pre_nesting_length == idxs.len() {
113        return;
114    }
115    idxs[original_idxs_length..].sort_unstable();
116}
117
118fn find_validity_mismatch_fsl_fsl_nested(
119    left: &dyn Array,
120    right: &dyn Array,
121    size: usize,
122    idxs: &mut Vec<IdxSize>,
123) {
124    assert_eq!(left.len(), right.len());
125    let start_length = idxs.len();
126    find_validity_mismatch(left, right, idxs);
127    if idxs.len() > start_length {
128        let mut offset = 0;
129        idxs[start_length] /= size as IdxSize;
130        for i in start_length + 1..idxs.len() {
131            idxs[i - offset] = idxs[i] / size as IdxSize;
132
133            if idxs[i - offset] == idxs[i - offset - 1] {
134                offset += 1;
135            }
136        }
137        idxs.truncate(idxs.len() - offset);
138    }
139}
140
141fn find_validity_mismatch_list_list_nested<O: Offset>(
142    left: &ListArray<O>,
143    right: &ListArray<O>,
144    idxs: &mut Vec<IdxSize>,
145) {
146    let mut nested_idxs = Vec::new();
147    find_validity_mismatch(
148        left.values().as_ref(),
149        right.values().as_ref(),
150        &mut nested_idxs,
151    );
152
153    if nested_idxs.is_empty() {
154        return;
155    }
156
157    assert_eq!(left.offsets().first().to_usize(), 0);
158    assert_eq!(left.offsets().range().to_usize(), left.values().len());
159
160    // @TODO: Optimize. This is only used on the error path so it is find, right?
161    let mut j = 0;
162    for (i, (start, length)) in left.offsets().offset_and_length_iter().enumerate_idx() {
163        if j < nested_idxs.len() && (nested_idxs[j] as usize) < start + length {
164            idxs.push(i);
165            j += 1;
166
167            // Loop over remaining items in same element.
168            while j < nested_idxs.len() && (nested_idxs[j] as usize) < start + length {
169                j += 1;
170            }
171        }
172
173        if j == nested_idxs.len() {
174            break;
175        }
176    }
177}
178
179fn find_validity_mismatch_list_fsl_impl<O: Offset>(
180    left: &ListArray<O>,
181    right: &FixedSizeListArray,
182    idxs: &mut Vec<IdxSize>,
183) {
184    if left.validity().is_none() && right.validity().is_none() {
185        find_validity_mismatch_fsl_fsl_nested(
186            left.values().as_ref(),
187            right.values().as_ref(),
188            right.size(),
189            idxs,
190        );
191        return;
192    }
193
194    let (ArrowDataType::List(f) | ArrowDataType::LargeList(f)) = left.dtype() else {
195        unreachable!();
196    };
197    let left = crate::cast::cast_list_to_fixed_size_list(
198        left,
199        f,
200        right.size(),
201        CastOptionsImpl::default(),
202    )
203    .unwrap();
204    find_validity_mismatch_fsl_fsl_nested(
205        left.values().as_ref(),
206        right.values().as_ref(),
207        right.size(),
208        idxs,
209    )
210}