polars_compute/
arity.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2use arrow::array::PrimitiveArray;
3use arrow::compute::utils::combine_validities_and;
4use arrow::types::NativeType;
5
/// Element-wise unary kernel that works directly on raw pointers.
///
/// The input and output buffers are allowed to overlap, which is why this
/// helper takes pointers instead of slices. It is marked to never be inlined
/// so that only a single unrolled kernel gets generated, even when it is
/// called in several different ways; this keeps codegen down.
///
/// # Safety
///  - `arr` must point to a readable slice of length `len`.
///  - `out` must point to a writable slice of length `len`.
#[inline(never)]
unsafe fn ptr_apply_unary_kernel<I: Copy, O, F: Fn(I) -> O>(
    arr: *const I,
    out: *mut O,
    len: usize,
    op: F,
) {
    for idx in 0..len {
        // Element `idx` is read before slot `idx` is written, so `out` may
        // refer to the same memory as `arr`.
        let input = arr.add(idx).read();
        out.add(idx).write(op(input));
    }
}
25
/// Element-wise binary kernel that works directly on raw pointers.
///
/// For every index below `len`, one element is read from `left` and one from
/// `right`, `op` is applied to the pair, and the result is stored at the same
/// index of `out`. Marked to never be inlined.
///
/// # Safety
///  - `left` must point to a readable slice of length `len`.
///  - `right` must point to a readable slice of length `len`.
///  - `out` must point to a writable slice of length `len`.
#[inline(never)]
unsafe fn ptr_apply_binary_kernel<L: Copy, R: Copy, O, F: Fn(L, R) -> O>(
    left: *const L,
    right: *const R,
    out: *mut O,
    len: usize,
    op: F,
) {
    for idx in 0..len {
        // Both operands at `idx` are loaded before slot `idx` is stored.
        let lhs_value = left.add(idx).read();
        let rhs_value = right.add(idx).read();
        out.add(idx).write(op(lhs_value, rhs_value));
    }
}
43
44/// Applies a function to all the values (regardless of nullability).
45///
46/// May reuse the memory of the array if possible.
47pub fn prim_unary_values<I, O, F>(mut arr: PrimitiveArray<I>, op: F) -> PrimitiveArray<O>
48where
49    I: NativeType,
50    O: NativeType,
51    F: Fn(I) -> O,
52{
53    let len = arr.len();
54
55    // Reuse memory if possible.
56    if size_of::<I>() == size_of::<O>() && align_of::<I>() == align_of::<O>() {
57        if let Some(values) = arr.get_mut_values() {
58            let ptr = values.as_mut_ptr();
59            // SAFETY: checked same size & alignment I/O, NativeType is always Pod.
60            unsafe { ptr_apply_unary_kernel(ptr, ptr as *mut O, len, op) }
61            return arr.transmute::<O>();
62        }
63    }
64
65    let mut out = Vec::with_capacity(len);
66    unsafe {
67        // SAFETY: checked pointers point to slices of length len.
68        ptr_apply_unary_kernel(arr.values().as_ptr(), out.as_mut_ptr(), len, op);
69        out.set_len(len);
70    }
71    PrimitiveArray::from_vec(out).with_validity(arr.take_validity())
72}
73
74/// Apply a binary function to all the values (regardless of nullability)
75/// in (lhs, rhs). Combines the validities with a bitand.
76///
77/// May reuse the memory of one of its arguments if possible.
78pub fn prim_binary_values<L, R, O, F>(
79    mut lhs: PrimitiveArray<L>,
80    mut rhs: PrimitiveArray<R>,
81    op: F,
82) -> PrimitiveArray<O>
83where
84    L: NativeType,
85    R: NativeType,
86    O: NativeType,
87    F: Fn(L, R) -> O,
88{
89    assert_eq!(lhs.len(), rhs.len());
90    let len = lhs.len();
91
92    let validity = combine_validities_and(lhs.validity(), rhs.validity());
93
94    // Reuse memory if possible.
95    if size_of::<L>() == size_of::<O>() && align_of::<L>() == align_of::<O>() {
96        if let Some(lv) = lhs.get_mut_values() {
97            let lp = lv.as_mut_ptr();
98            let rp = rhs.values().as_ptr();
99            unsafe {
100                // SAFETY: checked same size & alignment L/O, NativeType is always Pod.
101                ptr_apply_binary_kernel(lp, rp, lp as *mut O, len, op);
102            }
103            return lhs.transmute::<O>().with_validity(validity);
104        }
105    }
106    if size_of::<R>() == size_of::<O>() && align_of::<R>() == align_of::<O>() {
107        if let Some(rv) = rhs.get_mut_values() {
108            let lp = lhs.values().as_ptr();
109            let rp = rv.as_mut_ptr();
110            unsafe {
111                // SAFETY: checked same size & alignment R/O, NativeType is always Pod.
112                ptr_apply_binary_kernel(lp, rp, rp as *mut O, len, op);
113            }
114            return rhs.transmute::<O>().with_validity(validity);
115        }
116    }
117
118    let mut out = Vec::with_capacity(len);
119    unsafe {
120        // SAFETY: checked pointers point to slices of length len.
121        let lp = lhs.values().as_ptr();
122        let rp = rhs.values().as_ptr();
123        ptr_apply_binary_kernel(lp, rp, out.as_mut_ptr(), len, op);
124        out.set_len(len);
125    }
126    PrimitiveArray::from_vec(out).with_validity(validity)
127}