1use std::convert::identity;
2
3use arrow::array::{Array, BooleanArray, PrimitiveArray};
4use arrow::bitmap::{binary_fold, intersects_with};
5use arrow::datatypes::ArrowDataType;
6use arrow::legacy::utils::CustomIterTools;
7
/// Element-wise bit counting and bitwise reductions over an array.
///
/// The counting kernels (`count_*`, `leading_*`, `trailing_*`) produce one `u32` per slot of
/// the input. The implementations in this module copy the input's validity mask onto the
/// output, so null slots stay null.
pub trait BitwiseKernel {
    /// Element type returned by the reductions and combined by the `bit_*` helpers.
    type Scalar;

    /// Number of set bits in each element.
    fn count_ones(&self) -> PrimitiveArray<u32>;
    /// Number of unset bits in each element.
    fn count_zeros(&self) -> PrimitiveArray<u32>;

    /// Number of consecutive set bits starting from the most significant bit of each element.
    fn leading_ones(&self) -> PrimitiveArray<u32>;
    /// Number of consecutive unset bits starting from the most significant bit of each element.
    fn leading_zeros(&self) -> PrimitiveArray<u32>;

    /// Number of consecutive set bits starting from the least significant bit of each element.
    fn trailing_ones(&self) -> PrimitiveArray<u32>;
    /// Number of consecutive unset bits starting from the least significant bit of each element.
    fn trailing_zeros(&self) -> PrimitiveArray<u32>;

    /// Bitwise AND of all non-null elements, or `None` when there is no non-null element.
    fn reduce_and(&self) -> Option<Self::Scalar>;
    /// Bitwise OR of all non-null elements, or `None` when there is no non-null element.
    fn reduce_or(&self) -> Option<Self::Scalar>;
    /// Bitwise XOR of all non-null elements, or `None` when there is no non-null element.
    fn reduce_xor(&self) -> Option<Self::Scalar>;

    /// Bitwise AND of two scalars (for floats: of their bit patterns in this module's impls).
    fn bit_and(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar;
    /// Bitwise OR of two scalars (for floats: of their bit patterns in this module's impls).
    fn bit_or(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar;
    /// Bitwise XOR of two scalars (for floats: of their bit patterns in this module's impls).
    fn bit_xor(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar;
}
28
/// Implements [`BitwiseKernel`] for `PrimitiveArray<$T>` for every listed type.
///
/// Each entry is `($T, $to_bits, $from_bits)`:
/// - `$to_bits` converts a value into the integer whose bits are counted or combined.
/// - `$from_bits` converts such an integer back into `$T`.
///
/// Integer types pass `identity` for both. Floating-point types pass their `to_bits` /
/// `from_bits`, so every operation acts on the raw IEEE 754 bit pattern rather than on the
/// numeric value.
///
/// The `@per_value` and `@reduce` arms are internal helpers. They let the six counting
/// kernels and the three reductions share one implementation each, so the generated methods
/// cannot drift apart.
macro_rules! impl_bitwise_kernel {
    // Internal: map every value slot through `$to_bits(v).$method()` into a `u32` array.
    // Null slots are computed too; the copied validity mask keeps them null in the output.
    (@per_value $arr:expr, $to_bits:expr, $method:ident) => {
        PrimitiveArray::new(
            ArrowDataType::UInt32,
            $arr.values_iter()
                .map(|&v| $to_bits(v).$method())
                .collect_trusted::<Vec<_>>()
                .into(),
            $arr.validity().cloned(),
        )
    };
    // Internal: fold the non-null values with the binary bit operator `$op` applied to their
    // bit representations. Yields `None` when no non-null value exists (e.g. empty input).
    (@reduce $arr:expr, $to_bits:expr, $from_bits:expr, $op:tt) => {
        if !$arr.has_nulls() {
            // Fast path: no validity to consult, so iterate the raw values directly.
            $arr.values_iter()
                .copied()
                .map($to_bits)
                .reduce(|a, b| a $op b)
                .map($from_bits)
        } else {
            $arr.non_null_values_iter()
                .map($to_bits)
                .reduce(|a, b| a $op b)
                .map($from_bits)
        }
    };
    ($(($T:ty, $to_bits:expr, $from_bits:expr)),+ $(,)?) => {
        $(
            impl BitwiseKernel for PrimitiveArray<$T> {
                type Scalar = $T;

                #[inline(never)]
                fn count_ones(&self) -> PrimitiveArray<u32> {
                    impl_bitwise_kernel!(@per_value self, $to_bits, count_ones)
                }

                #[inline(never)]
                fn count_zeros(&self) -> PrimitiveArray<u32> {
                    impl_bitwise_kernel!(@per_value self, $to_bits, count_zeros)
                }

                #[inline(never)]
                fn leading_ones(&self) -> PrimitiveArray<u32> {
                    impl_bitwise_kernel!(@per_value self, $to_bits, leading_ones)
                }

                #[inline(never)]
                fn leading_zeros(&self) -> PrimitiveArray<u32> {
                    impl_bitwise_kernel!(@per_value self, $to_bits, leading_zeros)
                }

                #[inline(never)]
                fn trailing_ones(&self) -> PrimitiveArray<u32> {
                    impl_bitwise_kernel!(@per_value self, $to_bits, trailing_ones)
                }

                #[inline(never)]
                fn trailing_zeros(&self) -> PrimitiveArray<u32> {
                    impl_bitwise_kernel!(@per_value self, $to_bits, trailing_zeros)
                }

                #[inline(never)]
                fn reduce_and(&self) -> Option<Self::Scalar> {
                    impl_bitwise_kernel!(@reduce self, $to_bits, $from_bits, &)
                }

                #[inline(never)]
                fn reduce_or(&self) -> Option<Self::Scalar> {
                    impl_bitwise_kernel!(@reduce self, $to_bits, $from_bits, |)
                }

                #[inline(never)]
                fn reduce_xor(&self) -> Option<Self::Scalar> {
                    impl_bitwise_kernel!(@reduce self, $to_bits, $from_bits, ^)
                }

                fn bit_and(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
                    $from_bits($to_bits(lhs) & $to_bits(rhs))
                }
                fn bit_or(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
                    $from_bits($to_bits(lhs) | $to_bits(rhs))
                }
                fn bit_xor(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
                    $from_bits($to_bits(lhs) ^ $to_bits(rhs))
                }
            }
        )+
    };
}
147
148impl_bitwise_kernel! {
149 (i8, identity, identity),
150 (i16, identity, identity),
151 (i32, identity, identity),
152 (i64, identity, identity),
153 (u8, identity, identity),
154 (u16, identity, identity),
155 (u32, identity, identity),
156 (u64, identity, identity),
157 (f32, f32::to_bits, f32::from_bits),
158 (f64, f64::to_bits, f64::from_bits),
159}
160
161#[cfg(feature = "dtype-i128")]
162impl_bitwise_kernel! {
163 (i128, identity, identity),
164}
165
166impl BitwiseKernel for BooleanArray {
167 type Scalar = bool;
168
169 #[inline(never)]
170 fn count_ones(&self) -> PrimitiveArray<u32> {
171 PrimitiveArray::new(
172 ArrowDataType::UInt32,
173 self.values_iter()
174 .map(u32::from)
175 .collect_trusted::<Vec<_>>()
176 .into(),
177 self.validity().cloned(),
178 )
179 }
180
181 #[inline(never)]
182 fn count_zeros(&self) -> PrimitiveArray<u32> {
183 PrimitiveArray::new(
184 ArrowDataType::UInt32,
185 self.values_iter()
186 .map(|v| u32::from(!v))
187 .collect_trusted::<Vec<_>>()
188 .into(),
189 self.validity().cloned(),
190 )
191 }
192
193 #[inline(always)]
194 fn leading_ones(&self) -> PrimitiveArray<u32> {
195 self.count_ones()
196 }
197
198 #[inline(always)]
199 fn leading_zeros(&self) -> PrimitiveArray<u32> {
200 self.count_zeros()
201 }
202
203 #[inline(always)]
204 fn trailing_ones(&self) -> PrimitiveArray<u32> {
205 self.count_ones()
206 }
207
208 #[inline(always)]
209 fn trailing_zeros(&self) -> PrimitiveArray<u32> {
210 self.count_zeros()
211 }
212
213 fn reduce_and(&self) -> Option<Self::Scalar> {
214 if self.len() == self.null_count() {
215 None
216 } else if !self.has_nulls() {
217 Some(self.values().unset_bits() == 0)
218 } else {
219 let false_found = binary_fold(
220 self.values(),
221 self.validity().unwrap(),
222 |lhs, rhs| (!lhs & rhs) != 0,
223 false,
224 |a, b| a || b,
225 );
226 Some(!false_found)
227 }
228 }
229
230 fn reduce_or(&self) -> Option<Self::Scalar> {
231 if self.len() == self.null_count() {
232 None
233 } else if !self.has_nulls() {
234 Some(self.values().set_bits() > 0)
235 } else {
236 Some(intersects_with(self.values(), self.validity().unwrap()))
237 }
238 }
239
240 fn reduce_xor(&self) -> Option<Self::Scalar> {
241 if self.len() == self.null_count() {
242 None
243 } else if !self.has_nulls() {
244 Some(self.values().set_bits() % 2 == 1)
245 } else {
246 let nonnull_parity = binary_fold(
247 self.values(),
248 self.validity().unwrap(),
249 |lhs, rhs| lhs & rhs,
250 0,
251 |a, b| a ^ b,
252 );
253 Some(nonnull_parity.count_ones() % 2 == 1)
254 }
255 }
256
257 fn bit_and(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
258 lhs & rhs
259 }
260 fn bit_or(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
261 lhs | rhs
262 }
263 fn bit_xor(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
264 lhs ^ rhs
265 }
266}