polars_compute/gather/
mod.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2// Licensed to the Apache Software Foundation (ASF) under one
3// or more contributor license agreements.  See the NOTICE file
4// distributed with this work for additional information
5// regarding copyright ownership.  The ASF licenses this file
6// to you under the Apache License, Version 2.0 (the
7// "License"); you may not use this file except in compliance
8// with the License.  You may obtain a copy of the License at
9//
10//   http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing,
13// software distributed under the License is distributed on an
14// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15// KIND, either express or implied.  See the License for the
16// specific language governing permissions and limitations
17// under the License.
18
19//! Defines take kernel for [`Array`]
20
21use arrow::array::{
22    self, Array, ArrayCollectIterExt, ArrayFromIterDtype, BinaryViewArray, NullArray, StaticArray,
23    Utf8ViewArray, new_empty_array,
24};
25use arrow::datatypes::{ArrowDataType, IdxArr};
26use arrow::types::Index;
27
28pub mod binary;
29pub mod binview;
30pub mod bitmap;
31pub mod boolean;
32pub mod fixed_size_list;
33pub mod generic_binary;
34pub mod list;
35pub mod primitive;
36pub mod structure;
37pub mod sublist;
38
39use arrow::with_match_primitive_type_full;
40
41/// Returns a new [`Array`] with only indices at `indices`. Null indices are taken as nulls.
42/// The returned array has a length equal to `indices.len()`.
43/// # Safety
44/// Doesn't do bound checks
45pub unsafe fn take_unchecked(values: &dyn Array, indices: &IdxArr) -> Box<dyn Array> {
46    if indices.len() == 0 {
47        return new_empty_array(values.dtype().clone());
48    }
49
50    use arrow::datatypes::PhysicalType::*;
51    match values.dtype().to_physical_type() {
52        Null => Box::new(NullArray::new(values.dtype().clone(), indices.len())),
53        Boolean => {
54            let values = values.as_any().downcast_ref().unwrap();
55            Box::new(boolean::take_unchecked(values, indices))
56        },
57        Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| {
58            let values = values.as_any().downcast_ref().unwrap();
59            Box::new(primitive::take_primitive_unchecked::<$T>(&values, indices))
60        }),
61        LargeBinary => {
62            let values = values.as_any().downcast_ref().unwrap();
63            Box::new(binary::take_unchecked::<i64, _>(values, indices))
64        },
65        Struct => {
66            let array = values.as_any().downcast_ref().unwrap();
67            structure::take_unchecked(array, indices).boxed()
68        },
69        LargeList => {
70            let array = values.as_any().downcast_ref().unwrap();
71            Box::new(list::take_unchecked::<i64>(array, indices))
72        },
73        FixedSizeList => {
74            let array = values.as_any().downcast_ref().unwrap();
75            fixed_size_list::take_unchecked(array, indices)
76        },
77        BinaryView => {
78            let array: &BinaryViewArray = values.as_any().downcast_ref().unwrap();
79            binview::take_binview_unchecked(array, indices).boxed()
80        },
81        Utf8View => {
82            let array: &Utf8ViewArray = values.as_any().downcast_ref().unwrap();
83            binview::take_binview_unchecked(array, indices).boxed()
84        },
85        t => unimplemented!("Take not supported for data type {:?}", t),
86    }
87}
88
89/// Naive default implementation
90unsafe fn take_unchecked_impl_generic<T>(
91    values: &T,
92    indices: &IdxArr,
93    new_null_func: &dyn Fn(ArrowDataType, usize) -> T,
94) -> T
95where
96    T: StaticArray + ArrayFromIterDtype<std::option::Option<Box<dyn array::Array>>>,
97{
98    if values.null_count() == values.len() || indices.null_count() == indices.len() {
99        return new_null_func(values.dtype().clone(), indices.len());
100    }
101
102    match (indices.has_nulls(), values.has_nulls()) {
103        (true, true) => {
104            let values_validity = values.validity().unwrap();
105
106            indices
107                .iter()
108                .map(|i| {
109                    if let Some(i) = i {
110                        let i = *i as usize;
111                        if values_validity.get_bit_unchecked(i) {
112                            return Some(values.value_unchecked(i));
113                        }
114                    }
115                    None
116                })
117                .collect_arr_trusted_with_dtype(values.dtype().clone())
118        },
119        (true, false) => indices
120            .iter()
121            .map(|i| {
122                if let Some(i) = i {
123                    let i = *i as usize;
124                    return Some(values.value_unchecked(i));
125                }
126                None
127            })
128            .collect_arr_trusted_with_dtype(values.dtype().clone()),
129        (false, true) => {
130            let values_validity = values.validity().unwrap();
131
132            indices
133                .values_iter()
134                .map(|i| {
135                    let i = *i as usize;
136                    if values_validity.get_bit_unchecked(i) {
137                        return Some(values.value_unchecked(i));
138                    }
139                    None
140                })
141                .collect_arr_trusted_with_dtype(values.dtype().clone())
142        },
143        (false, false) => indices
144            .values_iter()
145            .map(|i| {
146                let i = *i as usize;
147                Some(values.value_unchecked(i))
148            })
149            .collect_arr_trusted_with_dtype(values.dtype().clone()),
150    }
151}