polars_row/
decode.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2use arrow::bitmap::{Bitmap, BitmapBuilder};
3use arrow::buffer::Buffer;
4use arrow::datatypes::ArrowDataType;
5use arrow::offset::OffsetsBuffer;
6use arrow::types::NativeType;
7use polars_dtype::categorical::CatNative;
8
9use self::encode::fixed_size;
10use self::row::{RowEncodingCategoricalContext, RowEncodingOptions};
11use self::variable::utf8::decode_str;
12use super::*;
13use crate::fixed::numeric::{FixedLengthEncoding, FromSlice};
14use crate::fixed::{boolean, decimal, numeric};
15use crate::variable::{binary, no_order, utf8};
16
17/// Decode `rows` into a arrow format
18/// # Safety
19/// This will not do any bound checks. Caller must ensure the `rows` are valid
20/// encodings.
21pub unsafe fn decode_rows_from_binary<'a>(
22    arr: &'a BinaryArray<i64>,
23    opts: &[RowEncodingOptions],
24    dicts: &[Option<RowEncodingContext>],
25    dtypes: &[ArrowDataType],
26    rows: &mut Vec<&'a [u8]>,
27) -> Vec<ArrayRef> {
28    assert_eq!(arr.null_count(), 0);
29    rows.clear();
30    rows.extend(arr.values_iter());
31    decode_rows(rows, opts, dicts, dtypes)
32}
33
34/// Decode `rows` into a arrow format
35/// # Safety
36/// This will not do any bound checks. Caller must ensure the `rows` are valid
37/// encodings.
38pub unsafe fn decode_rows(
39    // the rows will be updated while the data is decoded
40    rows: &mut [&[u8]],
41    opts: &[RowEncodingOptions],
42    dicts: &[Option<RowEncodingContext>],
43    dtypes: &[ArrowDataType],
44) -> Vec<ArrayRef> {
45    assert_eq!(opts.len(), dtypes.len());
46    assert_eq!(dicts.len(), dtypes.len());
47
48    dtypes
49        .iter()
50        .zip(opts)
51        .zip(dicts)
52        .map(|((dtype, opt), dict)| decode(rows, *opt, dict.as_ref(), dtype))
53        .collect()
54}
55
56unsafe fn decode_validity(rows: &mut [&[u8]], opt: RowEncodingOptions) -> Option<Bitmap> {
57    // 2 loop system to avoid the overhead of allocating the bitmap if all the elements are valid.
58
59    let null_sentinel = opt.null_sentinel();
60    let first_null = (0..rows.len()).find(|&i| {
61        let v;
62        (v, rows[i]) = rows[i].split_at_unchecked(1);
63        v[0] == null_sentinel
64    });
65
66    // No nulls just return None
67    let first_null = first_null?;
68
69    let mut bm = BitmapBuilder::new();
70    bm.reserve(rows.len());
71    bm.extend_constant(first_null, true);
72    bm.push(false);
73    bm.extend_trusted_len_iter(rows[first_null + 1..].iter_mut().map(|row| {
74        let v;
75        (v, *row) = row.split_at_unchecked(1);
76        v[0] != null_sentinel
77    }));
78    bm.into_opt_validity()
79}
80
81// We inline this in an attempt to avoid the dispatch cost.
82#[inline(always)]
83fn dtype_and_data_to_encoded_item_len(
84    dtype: &ArrowDataType,
85    data: &[u8],
86    opt: RowEncodingOptions,
87    dict: Option<&RowEncodingContext>,
88) -> usize {
89    // Fast path: if the size is fixed, we can just divide.
90    if let Some(size) = fixed_size(dtype, opt, dict) {
91        return size;
92    }
93
94    use ArrowDataType as D;
95    match dtype {
96        D::Binary | D::LargeBinary | D::BinaryView | D::Utf8 | D::LargeUtf8 | D::Utf8View
97            if opt.contains(RowEncodingOptions::NO_ORDER) =>
98        unsafe { no_order::len_from_buffer(data, opt) },
99        D::Binary | D::LargeBinary | D::BinaryView => unsafe {
100            binary::encoded_item_len(data, opt)
101        },
102        D::Utf8 | D::LargeUtf8 | D::Utf8View => unsafe { utf8::len_from_buffer(data, opt) },
103
104        D::List(list_field) | D::LargeList(list_field) => {
105            let mut data = data;
106            let mut item_len = 0;
107
108            let list_continuation_token = opt.list_continuation_token();
109
110            while data[0] == list_continuation_token {
111                data = &data[1..];
112                let len = dtype_and_data_to_encoded_item_len(list_field.dtype(), data, opt, dict);
113                data = &data[len..];
114                item_len += 1 + len;
115            }
116            1 + item_len
117        },
118
119        D::FixedSizeBinary(_) => todo!(),
120        D::FixedSizeList(fsl_field, width) => {
121            let mut data = &data[1..];
122            let mut item_len = 1; // validity byte
123
124            for _ in 0..*width {
125                let len = dtype_and_data_to_encoded_item_len(
126                    fsl_field.dtype(),
127                    data,
128                    opt.into_nested(),
129                    dict,
130                );
131                data = &data[len..];
132                item_len += len;
133            }
134            item_len
135        },
136        D::Struct(struct_fields) => {
137            let mut data = &data[1..];
138            let mut item_len = 1; // validity byte
139
140            for struct_field in struct_fields {
141                let len = dtype_and_data_to_encoded_item_len(
142                    struct_field.dtype(),
143                    data,
144                    opt.into_nested(),
145                    dict,
146                );
147                data = &data[len..];
148                item_len += len;
149            }
150            item_len
151        },
152
153        D::Union(_) => todo!(),
154        D::Map(_, _) => todo!(),
155        D::Decimal32(_, _) => todo!(),
156        D::Decimal64(_, _) => todo!(),
157        D::Decimal256(_, _) => todo!(),
158        D::Extension(_) => todo!(),
159        D::Unknown => todo!(),
160
161        _ => unreachable!(),
162    }
163}
164
165fn rows_for_fixed_size_list<'a>(
166    dtype: &ArrowDataType,
167    opt: RowEncodingOptions,
168    dict: Option<&RowEncodingContext>,
169    width: usize,
170    rows: &mut [&'a [u8]],
171    nested_rows: &mut Vec<&'a [u8]>,
172) {
173    nested_rows.clear();
174    nested_rows.reserve(rows.len() * width);
175
176    // Fast path: if the size is fixed, we can just divide.
177    if let Some(size) = fixed_size(dtype, opt, dict) {
178        for row in rows.iter_mut() {
179            for i in 0..width {
180                nested_rows.push(&row[(i * size)..][..size]);
181            }
182            *row = &row[size * width..];
183        }
184        return;
185    }
186
187    // @TODO: This is quite slow since we need to dispatch for possibly every nested type
188    for row in rows.iter_mut() {
189        for _ in 0..width {
190            let length = dtype_and_data_to_encoded_item_len(dtype, row, opt.into_nested(), dict);
191            let v;
192            (v, *row) = row.split_at(length);
193            nested_rows.push(v);
194        }
195    }
196}
197
198unsafe fn decode_cat<T: NativeType + FixedLengthEncoding + CatNative>(
199    rows: &mut [&[u8]],
200    opt: RowEncodingOptions,
201    ctx: &RowEncodingCategoricalContext,
202) -> PrimitiveArray<T>
203where
204    T::Encoded: FromSlice,
205{
206    if ctx.is_enum || !opt.is_ordered() {
207        numeric::decode_primitive::<T>(rows, opt)
208    } else {
209        variable::utf8::decode_str_as_cat::<T>(rows, opt, &ctx.mapping)
210    }
211}
212
213unsafe fn decode(
214    rows: &mut [&[u8]],
215    opt: RowEncodingOptions,
216    dict: Option<&RowEncodingContext>,
217    dtype: &ArrowDataType,
218) -> ArrayRef {
219    use ArrowDataType as D;
220
221    if let Some(RowEncodingContext::Categorical(ctx)) = dict {
222        return match dtype {
223            D::UInt8 => decode_cat::<u8>(rows, opt, ctx).to_boxed(),
224            D::UInt16 => decode_cat::<u16>(rows, opt, ctx).to_boxed(),
225            D::UInt32 => decode_cat::<u32>(rows, opt, ctx).to_boxed(),
226            _ => unreachable!(),
227        };
228    }
229
230    match dtype {
231        D::Null => NullArray::new(D::Null, rows.len()).to_boxed(),
232        D::Boolean => boolean::decode_bool(rows, opt).to_boxed(),
233        D::Binary | D::LargeBinary | D::BinaryView | D::Utf8 | D::LargeUtf8 | D::Utf8View
234            if opt.contains(RowEncodingOptions::NO_ORDER) =>
235        {
236            let array = no_order::decode_variable_no_order(rows, opt);
237
238            if matches!(dtype, D::Utf8 | D::LargeUtf8 | D::Utf8View) {
239                unsafe { array.to_utf8view_unchecked() }.to_boxed()
240            } else {
241                array.to_boxed()
242            }
243        },
244        D::Binary | D::LargeBinary | D::BinaryView => binary::decode_binview(rows, opt).to_boxed(),
245        D::Utf8 | D::LargeUtf8 | D::Utf8View => decode_str(rows, opt).boxed(),
246
247        D::Struct(fields) => {
248            let validity = decode_validity(rows, opt);
249
250            let values = match dict {
251                None => fields
252                    .iter()
253                    .map(|struct_fld| decode(rows, opt.into_nested(), None, struct_fld.dtype()))
254                    .collect(),
255                Some(RowEncodingContext::Struct(dicts)) => fields
256                    .iter()
257                    .zip(dicts)
258                    .map(|(struct_fld, dict)| {
259                        decode(rows, opt.into_nested(), dict.as_ref(), struct_fld.dtype())
260                    })
261                    .collect(),
262                _ => unreachable!(),
263            };
264            StructArray::new(dtype.clone(), rows.len(), values, validity).to_boxed()
265        },
266        D::FixedSizeList(fsl_field, width) => {
267            let validity = decode_validity(rows, opt);
268
269            // @TODO: we could consider making this into a scratchpad
270            let mut nested_rows = Vec::new();
271            rows_for_fixed_size_list(
272                fsl_field.dtype(),
273                opt.into_nested(),
274                dict,
275                *width,
276                rows,
277                &mut nested_rows,
278            );
279
280            let values = decode(&mut nested_rows, opt.into_nested(), dict, fsl_field.dtype());
281
282            FixedSizeListArray::new(dtype.clone(), rows.len(), values, validity).to_boxed()
283        },
284        D::List(list_field) | D::LargeList(list_field) => {
285            let mut validity = BitmapBuilder::new();
286
287            // @TODO: we could consider making this into a scratchpad
288            let num_rows = rows.len();
289            let mut nested_rows = Vec::new();
290            let mut offsets = Vec::with_capacity(rows.len() + 1);
291            offsets.push(0);
292
293            let list_null_sentinel = opt.list_null_sentinel();
294            let list_continuation_token = opt.list_continuation_token();
295            let list_termination_token = opt.list_termination_token();
296
297            // @TODO: make a specialized loop for fixed size list_field.dtype()
298            for (i, row) in rows.iter_mut().enumerate() {
299                while row[0] == list_continuation_token {
300                    *row = &row[1..];
301                    let len = dtype_and_data_to_encoded_item_len(
302                        list_field.dtype(),
303                        row,
304                        opt.into_nested(),
305                        dict,
306                    );
307                    nested_rows.push(&row[..len]);
308                    *row = &row[len..];
309                }
310
311                offsets.push(nested_rows.len() as i64);
312
313                // @TODO: Might be better to make this a 2-loop system.
314                if row[0] == list_null_sentinel {
315                    *row = &row[1..];
316                    validity.reserve(num_rows);
317                    validity.extend_constant(i - validity.len(), true);
318                    validity.push(false);
319                    continue;
320                }
321
322                assert_eq!(row[0], list_termination_token);
323                *row = &row[1..];
324            }
325
326            let validity = if validity.is_empty() {
327                None
328            } else {
329                validity.extend_constant(num_rows - validity.len(), true);
330                validity.into_opt_validity()
331            };
332            assert_eq!(offsets.len(), rows.len() + 1);
333
334            let values = decode(
335                &mut nested_rows,
336                opt.into_nested(),
337                dict,
338                list_field.dtype(),
339            );
340
341            ListArray::<i64>::new(
342                dtype.clone(),
343                unsafe { OffsetsBuffer::new_unchecked(Buffer::from(offsets)) },
344                values,
345                validity,
346            )
347            .to_boxed()
348        },
349
350        dt => {
351            if matches!(dt, D::Int128) {
352                if let Some(dict) = dict {
353                    return match dict {
354                        RowEncodingContext::Decimal(precision) => {
355                            decimal::decode(rows, opt, *precision).to_boxed()
356                        },
357                        _ => unreachable!(),
358                    };
359                }
360            }
361
362            with_match_arrow_primitive_type!(dt, |$T| {
363                numeric::decode_primitive::<$T>(rows, opt).to_boxed()
364            })
365        },
366    }
367}