polars_compute/bitwise/mod.rs

1use std::convert::identity;
2
3use arrow::array::{Array, BooleanArray, PrimitiveArray};
4use arrow::bitmap::{binary_fold, intersects_with};
5use arrow::datatypes::ArrowDataType;
6use arrow::legacy::utils::CustomIterTools;
7
8pub trait BitwiseKernel {
9    type Scalar;
10
11    fn count_ones(&self) -> PrimitiveArray<u32>;
12    fn count_zeros(&self) -> PrimitiveArray<u32>;
13
14    fn leading_ones(&self) -> PrimitiveArray<u32>;
15    fn leading_zeros(&self) -> PrimitiveArray<u32>;
16
17    fn trailing_ones(&self) -> PrimitiveArray<u32>;
18    fn trailing_zeros(&self) -> PrimitiveArray<u32>;
19
20    fn reduce_and(&self) -> Option<Self::Scalar>;
21    fn reduce_or(&self) -> Option<Self::Scalar>;
22    fn reduce_xor(&self) -> Option<Self::Scalar>;
23
24    fn bit_and(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar;
25    fn bit_or(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar;
26    fn bit_xor(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar;
27}
28
/// Implements [`BitwiseKernel`] for `PrimitiveArray<$T>` for every listed entry.
///
/// Each entry is `(type, to_bits, from_bits)`:
/// - `to_bits` maps a value to an integer on which `&`, `|`, `^` and the
///   `count_*`/`leading_*`/`trailing_*` methods are meaningful. This is
///   `identity` for integers and `f32::to_bits`/`f64::to_bits` for floats.
/// - `from_bits` maps such an integer back to the element type.
///
/// The `@unary` and `@reduce` arms are internal helpers. Every generated
/// method goes through them so that the per-method bodies cannot drift apart.
macro_rules! impl_bitwise_kernel {
    // Maps every value slot, including slots masked by a null, through
    // `$to_bits(v).$op()`. The output reuses the input's validity, so nulls
    // stay null.
    (@unary $arr:expr, $to_bits:expr, $op:ident) => {
        PrimitiveArray::new(
            ArrowDataType::UInt32,
            $arr.values_iter()
                .map(|&v| $to_bits(v).$op())
                .collect_trusted::<Vec<_>>()
                .into(),
            $arr.validity().cloned(),
        )
    };
    // Folds the non-null values with the binary operator `$op` in bit space.
    // The result is `None` when there are no non-null values
    // (`Iterator::reduce` on an empty iterator).
    (@reduce $arr:expr, $to_bits:expr, $from_bits:expr, $op:tt) => {
        if !$arr.has_nulls() {
            // No nulls: walk the raw value buffer directly.
            $arr.values_iter()
                .copied()
                .map($to_bits)
                .reduce(|a, b| a $op b)
                .map($from_bits)
        } else {
            $arr.non_null_values_iter()
                .map($to_bits)
                .reduce(|a, b| a $op b)
                .map($from_bits)
        }
    };
    ($(($T:ty, $to_bits:expr, $from_bits:expr)),+ $(,)?) => {
        $(
        impl BitwiseKernel for PrimitiveArray<$T> {
            type Scalar = $T;

            #[inline(never)]
            fn count_ones(&self) -> PrimitiveArray<u32> {
                impl_bitwise_kernel!(@unary self, $to_bits, count_ones)
            }

            #[inline(never)]
            fn count_zeros(&self) -> PrimitiveArray<u32> {
                impl_bitwise_kernel!(@unary self, $to_bits, count_zeros)
            }

            #[inline(never)]
            fn leading_ones(&self) -> PrimitiveArray<u32> {
                impl_bitwise_kernel!(@unary self, $to_bits, leading_ones)
            }

            #[inline(never)]
            fn leading_zeros(&self) -> PrimitiveArray<u32> {
                impl_bitwise_kernel!(@unary self, $to_bits, leading_zeros)
            }

            #[inline(never)]
            fn trailing_ones(&self) -> PrimitiveArray<u32> {
                impl_bitwise_kernel!(@unary self, $to_bits, trailing_ones)
            }

            #[inline(never)]
            fn trailing_zeros(&self) -> PrimitiveArray<u32> {
                impl_bitwise_kernel!(@unary self, $to_bits, trailing_zeros)
            }

            #[inline(never)]
            fn reduce_and(&self) -> Option<Self::Scalar> {
                impl_bitwise_kernel!(@reduce self, $to_bits, $from_bits, &)
            }

            #[inline(never)]
            fn reduce_or(&self) -> Option<Self::Scalar> {
                impl_bitwise_kernel!(@reduce self, $to_bits, $from_bits, |)
            }

            #[inline(never)]
            fn reduce_xor(&self) -> Option<Self::Scalar> {
                impl_bitwise_kernel!(@reduce self, $to_bits, $from_bits, ^)
            }

            fn bit_and(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
                $from_bits($to_bits(lhs) & $to_bits(rhs))
            }
            fn bit_or(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
                $from_bits($to_bits(lhs) | $to_bits(rhs))
            }
            fn bit_xor(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
                $from_bits($to_bits(lhs) ^ $to_bits(rhs))
            }
        }
        )+
    };
}
147
148impl_bitwise_kernel! {
149    (i8, identity, identity),
150    (i16, identity, identity),
151    (i32, identity, identity),
152    (i64, identity, identity),
153    (u8, identity, identity),
154    (u16, identity, identity),
155    (u32, identity, identity),
156    (u64, identity, identity),
157    (f32, f32::to_bits, f32::from_bits),
158    (f64, f64::to_bits, f64::from_bits),
159}
160
161#[cfg(feature = "dtype-i128")]
162impl_bitwise_kernel! {
163    (i128, identity, identity),
164}
165
166impl BitwiseKernel for BooleanArray {
167    type Scalar = bool;
168
169    #[inline(never)]
170    fn count_ones(&self) -> PrimitiveArray<u32> {
171        PrimitiveArray::new(
172            ArrowDataType::UInt32,
173            self.values_iter()
174                .map(u32::from)
175                .collect_trusted::<Vec<_>>()
176                .into(),
177            self.validity().cloned(),
178        )
179    }
180
181    #[inline(never)]
182    fn count_zeros(&self) -> PrimitiveArray<u32> {
183        PrimitiveArray::new(
184            ArrowDataType::UInt32,
185            self.values_iter()
186                .map(|v| u32::from(!v))
187                .collect_trusted::<Vec<_>>()
188                .into(),
189            self.validity().cloned(),
190        )
191    }
192
193    #[inline(always)]
194    fn leading_ones(&self) -> PrimitiveArray<u32> {
195        self.count_ones()
196    }
197
198    #[inline(always)]
199    fn leading_zeros(&self) -> PrimitiveArray<u32> {
200        self.count_zeros()
201    }
202
203    #[inline(always)]
204    fn trailing_ones(&self) -> PrimitiveArray<u32> {
205        self.count_ones()
206    }
207
208    #[inline(always)]
209    fn trailing_zeros(&self) -> PrimitiveArray<u32> {
210        self.count_zeros()
211    }
212
213    fn reduce_and(&self) -> Option<Self::Scalar> {
214        if self.len() == self.null_count() {
215            None
216        } else if !self.has_nulls() {
217            Some(self.values().unset_bits() == 0)
218        } else {
219            let false_found = binary_fold(
220                self.values(),
221                self.validity().unwrap(),
222                |lhs, rhs| (!lhs & rhs) != 0,
223                false,
224                |a, b| a || b,
225            );
226            Some(!false_found)
227        }
228    }
229
230    fn reduce_or(&self) -> Option<Self::Scalar> {
231        if self.len() == self.null_count() {
232            None
233        } else if !self.has_nulls() {
234            Some(self.values().set_bits() > 0)
235        } else {
236            Some(intersects_with(self.values(), self.validity().unwrap()))
237        }
238    }
239
240    fn reduce_xor(&self) -> Option<Self::Scalar> {
241        if self.len() == self.null_count() {
242            None
243        } else if !self.has_nulls() {
244            Some(self.values().set_bits() % 2 == 1)
245        } else {
246            let nonnull_parity = binary_fold(
247                self.values(),
248                self.validity().unwrap(),
249                |lhs, rhs| lhs & rhs,
250                0,
251                |a, b| a ^ b,
252            );
253            Some(nonnull_parity.count_ones() % 2 == 1)
254        }
255    }
256
257    fn bit_and(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
258        lhs & rhs
259    }
260    fn bit_or(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
261        lhs | rhs
262    }
263    fn bit_xor(lhs: Self::Scalar, rhs: Self::Scalar) -> Self::Scalar {
264        lhs ^ rhs
265    }
266}