polars_compute/horizontal_flatten/
mod.rs1#![allow(unsafe_op_in_unsafe_fn)]
2use arrow::array::{
3 Array, ArrayCollectIterExt, BinaryArray, BinaryViewArray, BooleanArray, FixedSizeListArray,
4 ListArray, NullArray, PrimitiveArray, StaticArray, StructArray, Utf8ViewArray,
5};
6use arrow::bitmap::Bitmap;
7use arrow::datatypes::{ArrowDataType, PhysicalType};
8use arrow::with_match_primitive_type_full;
9use strength_reduce::StrengthReducedUsize;
10mod struct_;
11
12pub unsafe fn horizontal_flatten_unchecked(
24 arrays: &[Box<dyn Array>],
25 widths: &[usize],
26 output_height: usize,
27) -> Box<dyn Array> {
28 use PhysicalType::*;
29
30 let dtype = arrays[0].dtype();
31
32 match dtype.to_physical_type() {
33 Null => Box::new(NullArray::new(
34 dtype.clone(),
35 output_height * widths.iter().copied().sum::<usize>(),
36 )),
37 Boolean => Box::new(horizontal_flatten_unchecked_impl_generic(
38 &arrays
39 .iter()
40 .map(|x| x.as_any().downcast_ref::<BooleanArray>().unwrap().clone())
41 .collect::<Vec<_>>(),
42 widths,
43 output_height,
44 dtype,
45 )),
46 Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| {
47 Box::new(horizontal_flatten_unchecked_impl_generic(
48 &arrays
49 .iter()
50 .map(|x| x.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap().clone())
51 .collect::<Vec<_>>(),
52 widths,
53 output_height,
54 dtype
55 ))
56 }),
57 LargeBinary => Box::new(horizontal_flatten_unchecked_impl_generic(
58 &arrays
59 .iter()
60 .map(|x| {
61 x.as_any()
62 .downcast_ref::<BinaryArray<i64>>()
63 .unwrap()
64 .clone()
65 })
66 .collect::<Vec<_>>(),
67 widths,
68 output_height,
69 dtype,
70 )),
71 Struct => Box::new(struct_::horizontal_flatten_unchecked(
72 &arrays
73 .iter()
74 .map(|x| x.as_any().downcast_ref::<StructArray>().unwrap().clone())
75 .collect::<Vec<_>>(),
76 widths,
77 output_height,
78 )),
79 LargeList => Box::new(horizontal_flatten_unchecked_impl_generic(
80 &arrays
81 .iter()
82 .map(|x| x.as_any().downcast_ref::<ListArray<i64>>().unwrap().clone())
83 .collect::<Vec<_>>(),
84 widths,
85 output_height,
86 dtype,
87 )),
88 FixedSizeList => Box::new(horizontal_flatten_unchecked_impl_generic(
89 &arrays
90 .iter()
91 .map(|x| {
92 x.as_any()
93 .downcast_ref::<FixedSizeListArray>()
94 .unwrap()
95 .clone()
96 })
97 .collect::<Vec<_>>(),
98 widths,
99 output_height,
100 dtype,
101 )),
102 BinaryView => Box::new(horizontal_flatten_unchecked_impl_generic(
103 &arrays
104 .iter()
105 .map(|x| {
106 x.as_any()
107 .downcast_ref::<BinaryViewArray>()
108 .unwrap()
109 .clone()
110 })
111 .collect::<Vec<_>>(),
112 widths,
113 output_height,
114 dtype,
115 )),
116 Utf8View => Box::new(horizontal_flatten_unchecked_impl_generic(
117 &arrays
118 .iter()
119 .map(|x| x.as_any().downcast_ref::<Utf8ViewArray>().unwrap().clone())
120 .collect::<Vec<_>>(),
121 widths,
122 output_height,
123 dtype,
124 )),
125 t => unimplemented!("horizontal_flatten not supported for data type {:?}", t),
126 }
127}
128
129unsafe fn horizontal_flatten_unchecked_impl_generic<T>(
130 arrays: &[T],
131 widths: &[usize],
132 output_height: usize,
133 dtype: &ArrowDataType,
134) -> T
135where
136 T: StaticArray,
137{
138 assert!(!arrays.is_empty());
139 assert_eq!(widths.len(), arrays.len());
140
141 debug_assert!(widths.iter().all(|x| *x > 0));
142 debug_assert!(
143 arrays
144 .iter()
145 .zip(widths)
146 .all(|(arr, width)| arr.len() == output_height * *width || arr.len() == *width)
147 );
148
149 let lengths = arrays
151 .iter()
152 .map(|x| StrengthReducedUsize::new(x.len()))
153 .collect::<Vec<_>>();
154 let out_row_width: usize = widths.iter().cloned().sum();
155 let out_len = out_row_width.checked_mul(output_height).unwrap();
156
157 let mut col_idx = 0;
158 let mut row_idx = 0;
159 let mut until = widths[0];
160 let mut outer_row_idx = 0;
161
162 (0..out_len)
164 .map(|_| {
165 let arr = arrays.get_unchecked(col_idx);
166 let out = arr.get_unchecked(row_idx % *lengths.get_unchecked(col_idx));
167
168 row_idx += 1;
169
170 if row_idx == until {
171 col_idx = if 1 + col_idx == widths.len() {
173 outer_row_idx += 1;
174 0
175 } else {
176 1 + col_idx
177 };
178 row_idx = outer_row_idx * *widths.get_unchecked(col_idx);
179 until = (1 + outer_row_idx) * *widths.get_unchecked(col_idx)
180 }
181
182 out
183 })
184 .collect_arr_trusted_with_dtype(dtype.clone())
185}