polars_compute/horizontal_flatten/
mod.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2use arrow::array::{
3    Array, ArrayCollectIterExt, BinaryArray, BinaryViewArray, BooleanArray, FixedSizeListArray,
4    ListArray, NullArray, PrimitiveArray, StaticArray, StructArray, Utf8ViewArray,
5};
6use arrow::bitmap::Bitmap;
7use arrow::datatypes::{ArrowDataType, PhysicalType};
8use arrow::with_match_primitive_type_full;
9use strength_reduce::StrengthReducedUsize;
10mod struct_;
11
12/// Low-level operation used by `concat_arr`. This should be called with the inner values array of
13/// every FixedSizeList array.
14///
15/// # Safety
16/// * `arrays` is non-empty
17/// * `arrays` and `widths` have equal length
18/// * All widths in `widths` are non-zero
19/// * Every array `arrays[i]` has a length of either
20///   * `widths[i] * output_height`
21///   * `widths[i]` (this would be broadcasted)
22/// * All arrays in `arrays` have the same type
23pub unsafe fn horizontal_flatten_unchecked(
24    arrays: &[Box<dyn Array>],
25    widths: &[usize],
26    output_height: usize,
27) -> Box<dyn Array> {
28    use PhysicalType::*;
29
30    let dtype = arrays[0].dtype();
31
32    match dtype.to_physical_type() {
33        Null => Box::new(NullArray::new(
34            dtype.clone(),
35            output_height * widths.iter().copied().sum::<usize>(),
36        )),
37        Boolean => Box::new(horizontal_flatten_unchecked_impl_generic(
38            &arrays
39                .iter()
40                .map(|x| x.as_any().downcast_ref::<BooleanArray>().unwrap().clone())
41                .collect::<Vec<_>>(),
42            widths,
43            output_height,
44            dtype,
45        )),
46        Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| {
47            Box::new(horizontal_flatten_unchecked_impl_generic(
48                &arrays
49                    .iter()
50                    .map(|x| x.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap().clone())
51                    .collect::<Vec<_>>(),
52                widths,
53                output_height,
54                dtype
55            ))
56        }),
57        LargeBinary => Box::new(horizontal_flatten_unchecked_impl_generic(
58            &arrays
59                .iter()
60                .map(|x| {
61                    x.as_any()
62                        .downcast_ref::<BinaryArray<i64>>()
63                        .unwrap()
64                        .clone()
65                })
66                .collect::<Vec<_>>(),
67            widths,
68            output_height,
69            dtype,
70        )),
71        Struct => Box::new(struct_::horizontal_flatten_unchecked(
72            &arrays
73                .iter()
74                .map(|x| x.as_any().downcast_ref::<StructArray>().unwrap().clone())
75                .collect::<Vec<_>>(),
76            widths,
77            output_height,
78        )),
79        LargeList => Box::new(horizontal_flatten_unchecked_impl_generic(
80            &arrays
81                .iter()
82                .map(|x| x.as_any().downcast_ref::<ListArray<i64>>().unwrap().clone())
83                .collect::<Vec<_>>(),
84            widths,
85            output_height,
86            dtype,
87        )),
88        FixedSizeList => Box::new(horizontal_flatten_unchecked_impl_generic(
89            &arrays
90                .iter()
91                .map(|x| {
92                    x.as_any()
93                        .downcast_ref::<FixedSizeListArray>()
94                        .unwrap()
95                        .clone()
96                })
97                .collect::<Vec<_>>(),
98            widths,
99            output_height,
100            dtype,
101        )),
102        BinaryView => Box::new(horizontal_flatten_unchecked_impl_generic(
103            &arrays
104                .iter()
105                .map(|x| {
106                    x.as_any()
107                        .downcast_ref::<BinaryViewArray>()
108                        .unwrap()
109                        .clone()
110                })
111                .collect::<Vec<_>>(),
112            widths,
113            output_height,
114            dtype,
115        )),
116        Utf8View => Box::new(horizontal_flatten_unchecked_impl_generic(
117            &arrays
118                .iter()
119                .map(|x| x.as_any().downcast_ref::<Utf8ViewArray>().unwrap().clone())
120                .collect::<Vec<_>>(),
121            widths,
122            output_height,
123            dtype,
124        )),
125        t => unimplemented!("horizontal_flatten not supported for data type {:?}", t),
126    }
127}
128
129unsafe fn horizontal_flatten_unchecked_impl_generic<T>(
130    arrays: &[T],
131    widths: &[usize],
132    output_height: usize,
133    dtype: &ArrowDataType,
134) -> T
135where
136    T: StaticArray,
137{
138    assert!(!arrays.is_empty());
139    assert_eq!(widths.len(), arrays.len());
140
141    debug_assert!(widths.iter().all(|x| *x > 0));
142    debug_assert!(
143        arrays
144            .iter()
145            .zip(widths)
146            .all(|(arr, width)| arr.len() == output_height * *width || arr.len() == *width)
147    );
148
149    // We modulo the array length to support broadcasting.
150    let lengths = arrays
151        .iter()
152        .map(|x| StrengthReducedUsize::new(x.len()))
153        .collect::<Vec<_>>();
154    let out_row_width: usize = widths.iter().cloned().sum();
155    let out_len = out_row_width.checked_mul(output_height).unwrap();
156
157    let mut col_idx = 0;
158    let mut row_idx = 0;
159    let mut until = widths[0];
160    let mut outer_row_idx = 0;
161
162    // We do `0..out_len` to get an `ExactSizeIterator`.
163    (0..out_len)
164        .map(|_| {
165            let arr = arrays.get_unchecked(col_idx);
166            let out = arr.get_unchecked(row_idx % *lengths.get_unchecked(col_idx));
167
168            row_idx += 1;
169
170            if row_idx == until {
171                // Safety: All widths are non-zero so we only need to increment once.
172                col_idx = if 1 + col_idx == widths.len() {
173                    outer_row_idx += 1;
174                    0
175                } else {
176                    1 + col_idx
177                };
178                row_idx = outer_row_idx * *widths.get_unchecked(col_idx);
179                until = (1 + outer_row_idx) * *widths.get_unchecked(col_idx)
180            }
181
182            out
183        })
184        .collect_arr_trusted_with_dtype(dtype.clone())
185}