polars_row/
encode.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2use std::mem::MaybeUninit;
3
4use arrow::array::{
5    Array, BinaryArray, BinaryViewArray, BooleanArray, FixedSizeListArray, ListArray,
6    PrimitiveArray, StructArray, UInt8Array, UInt16Array, UInt32Array, Utf8Array, Utf8ViewArray,
7};
8use arrow::bitmap::Bitmap;
9use arrow::datatypes::ArrowDataType;
10use arrow::types::{NativeType, Offset};
11use polars_dtype::categorical::CatNative;
12
13use crate::fixed::numeric::FixedLengthEncoding;
14use crate::fixed::{boolean, decimal, numeric};
15use crate::row::{RowEncodingOptions, RowsEncoded};
16use crate::variable::{binary, no_order, utf8};
17use crate::widths::RowWidths;
18use crate::{
19    ArrayRef, RowEncodingCategoricalContext, RowEncodingContext, with_match_arrow_primitive_type,
20};
21
22pub fn convert_columns(
23    num_rows: usize,
24    columns: &[ArrayRef],
25    opts: &[RowEncodingOptions],
26    dicts: &[Option<RowEncodingContext>],
27) -> RowsEncoded {
28    let mut rows = RowsEncoded::new(vec![], vec![]);
29    convert_columns_amortized(
30        num_rows,
31        columns,
32        opts.iter().copied().zip(dicts.iter().map(|v| v.as_ref())),
33        &mut rows,
34    );
35    rows
36}
37
38pub fn convert_columns_no_order(
39    num_rows: usize,
40    columns: &[ArrayRef],
41    dicts: &[Option<RowEncodingContext>],
42) -> RowsEncoded {
43    let mut rows = RowsEncoded::new(vec![], vec![]);
44    convert_columns_amortized_no_order(num_rows, columns, dicts, &mut rows);
45    rows
46}
47
48pub fn convert_columns_amortized_no_order(
49    num_rows: usize,
50    columns: &[ArrayRef],
51    dicts: &[Option<RowEncodingContext>],
52    rows: &mut RowsEncoded,
53) {
54    convert_columns_amortized(
55        num_rows,
56        columns,
57        std::iter::repeat_n(RowEncodingOptions::default(), columns.len())
58            .zip(dicts.iter().map(|v| v.as_ref())),
59        rows,
60    );
61}
62
63pub fn convert_columns_amortized<'a>(
64    num_rows: usize,
65    columns: &[ArrayRef],
66    fields: impl IntoIterator<Item = (RowEncodingOptions, Option<&'a RowEncodingContext>)> + Clone,
67    rows: &mut RowsEncoded,
68) {
69    let mut masked_out_max_length = 0;
70    let mut row_widths = RowWidths::new(num_rows);
71    let mut encoders = columns
72        .iter()
73        .zip(fields.clone())
74        .map(|(column, (opt, dicts))| {
75            get_encoder(
76                column.as_ref(),
77                opt,
78                dicts,
79                &mut row_widths,
80                &mut masked_out_max_length,
81            )
82        })
83        .collect::<Vec<_>>();
84
85    // Create an offsets array, we append 0 at the beginning here so it can serve as the final
86    // offset array.
87    let mut offsets = Vec::with_capacity(num_rows + 1);
88    offsets.push(0);
89    row_widths.extend_with_offsets(&mut offsets);
90
91    // Create a buffer without initializing everything to zero.
92    let total_num_bytes = row_widths.sum();
93    let mut out = Vec::<u8>::with_capacity(total_num_bytes + masked_out_max_length);
94    let buffer = &mut out.spare_capacity_mut()[..total_num_bytes + masked_out_max_length];
95
96    let masked_out_write_offset = total_num_bytes;
97    let mut scratches = EncodeScratches::default();
98    for (encoder, (opt, dict)) in encoders.iter_mut().zip(fields) {
99        unsafe {
100            encode_array(
101                buffer,
102                encoder,
103                opt,
104                dict,
105                &mut offsets[1..],
106                masked_out_write_offset,
107                &mut scratches,
108            )
109        };
110    }
111    // SAFETY: All the bytes in out up to total_num_bytes should now be initialized.
112    unsafe {
113        out.set_len(total_num_bytes);
114    }
115
116    *rows = RowsEncoded {
117        values: out,
118        offsets,
119    };
120}
121
122fn list_num_column_bytes<O: Offset>(
123    array: &dyn Array,
124    opt: RowEncodingOptions,
125    dicts: Option<&RowEncodingContext>,
126    row_widths: &mut RowWidths,
127    masked_out_max_width: &mut usize,
128) -> Encoder {
129    let array = array.as_any().downcast_ref::<ListArray<O>>().unwrap();
130    let values = array.values();
131
132    let mut list_row_widths = RowWidths::new(values.len());
133    let encoder = get_encoder(
134        values.as_ref(),
135        opt.into_nested(),
136        dicts,
137        &mut list_row_widths,
138        masked_out_max_width,
139    );
140
141    match array.validity() {
142        None => row_widths.push_iter(array.offsets().offset_and_length_iter().map(
143            |(offset, length)| {
144                let mut sum = 0;
145                for i in offset..offset + length {
146                    sum += list_row_widths.get(i);
147                }
148                1 + length + sum
149            },
150        )),
151        Some(validity) => row_widths.push_iter(
152            array
153                .offsets()
154                .offset_and_length_iter()
155                .zip(validity.iter())
156                .map(|((offset, length), is_valid)| {
157                    if !is_valid {
158                        if length > 0 {
159                            for i in offset..offset + length {
160                                *masked_out_max_width =
161                                    (*masked_out_max_width).max(list_row_widths.get(i));
162                            }
163                        }
164                        return 1;
165                    }
166
167                    let mut sum = 0;
168                    for i in offset..offset + length {
169                        sum += list_row_widths.get(i);
170                    }
171                    1 + length + sum
172                }),
173        ),
174    };
175
176    Encoder {
177        array: array.to_boxed(),
178        state: Some(Box::new(EncoderState::List(
179            Box::new(encoder),
180            list_row_widths,
181        ))),
182    }
183}
184
185fn biniter_num_column_bytes(
186    array: &dyn Array,
187    iter: impl ExactSizeIterator<Item = usize>,
188    validity: Option<&Bitmap>,
189    opt: RowEncodingOptions,
190    row_widths: &mut RowWidths,
191) -> Encoder {
192    if opt.contains(RowEncodingOptions::NO_ORDER) {
193        match validity {
194            None => row_widths.push_iter(iter.map(|v| no_order::len_from_item(Some(v), opt))),
195            Some(validity) => row_widths.push_iter(
196                iter.zip(validity.iter())
197                    .map(|(v, is_valid)| no_order::len_from_item(is_valid.then_some(v), opt)),
198            ),
199        }
200    } else {
201        match validity {
202            None => row_widths.push_iter(
203                iter.map(|v| crate::variable::binary::encoded_len_from_len(Some(v), opt)),
204            ),
205            Some(validity) => row_widths.push_iter(
206                iter.zip(validity.iter())
207                    .map(|(v, is_valid)| binary::encoded_len_from_len(is_valid.then_some(v), opt)),
208            ),
209        }
210    };
211
212    Encoder {
213        array: array.to_boxed(),
214        state: None,
215    }
216}
217
218fn striter_num_column_bytes(
219    array: &dyn Array,
220    iter: impl ExactSizeIterator<Item = usize>,
221    validity: Option<&Bitmap>,
222    opt: RowEncodingOptions,
223    row_widths: &mut RowWidths,
224) -> Encoder {
225    if opt.contains(RowEncodingOptions::NO_ORDER) {
226        match validity {
227            None => row_widths.push_iter(iter.map(|v| no_order::len_from_item(Some(v), opt))),
228            Some(validity) => row_widths.push_iter(
229                iter.zip(validity.iter())
230                    .map(|(v, is_valid)| no_order::len_from_item(is_valid.then_some(v), opt)),
231            ),
232        }
233    } else {
234        match validity {
235            None => row_widths
236                .push_iter(iter.map(|v| crate::variable::utf8::len_from_item(Some(v), opt))),
237            Some(validity) => row_widths.push_iter(
238                iter.zip(validity.iter())
239                    .map(|(v, is_valid)| utf8::len_from_item(is_valid.then_some(v), opt)),
240            ),
241        }
242    };
243
244    Encoder {
245        array: array.to_boxed(),
246        state: None,
247    }
248}
249
250/// Get the encoder for a specific array.
251fn get_encoder(
252    array: &dyn Array,
253    opt: RowEncodingOptions,
254    dict: Option<&RowEncodingContext>,
255    row_widths: &mut RowWidths,
256    masked_out_max_width: &mut usize,
257) -> Encoder {
258    use ArrowDataType as D;
259    let dtype = array.dtype();
260
261    // Fast path: column has a fixed size encoding
262    if let Some(size) = fixed_size(dtype, opt, dict) {
263        row_widths.push_constant(size);
264        let state = match dtype {
265            D::FixedSizeList(_, width) => {
266                let array = array.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
267
268                debug_assert_eq!(array.values().len(), array.len() * width);
269                let mut nested_row_widths = RowWidths::new(array.values().len());
270                let nested_encoder = get_encoder(
271                    array.values().as_ref(),
272                    opt.into_nested(),
273                    dict,
274                    &mut nested_row_widths,
275                    masked_out_max_width,
276                );
277                Some(EncoderState::FixedSizeList(
278                    Box::new(nested_encoder),
279                    *width,
280                    nested_row_widths,
281                ))
282            },
283            D::Struct(_) => {
284                let struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();
285
286                Some(EncoderState::Struct(match dict {
287                    None => struct_array
288                        .values()
289                        .iter()
290                        .map(|array| {
291                            get_encoder(
292                                array.as_ref(),
293                                opt.into_nested(),
294                                None,
295                                &mut RowWidths::new(row_widths.num_rows()),
296                                masked_out_max_width,
297                            )
298                        })
299                        .collect(),
300                    Some(RowEncodingContext::Struct(dicts)) => struct_array
301                        .values()
302                        .iter()
303                        .zip(dicts)
304                        .map(|(array, dict)| {
305                            get_encoder(
306                                array.as_ref(),
307                                opt,
308                                dict.as_ref(),
309                                &mut RowWidths::new(row_widths.num_rows()),
310                                masked_out_max_width,
311                            )
312                        })
313                        .collect(),
314                    _ => unreachable!(),
315                }))
316            },
317            _ => None,
318        };
319
320        let state = state.map(Box::new);
321        return Encoder {
322            array: array.to_boxed(),
323            state,
324        };
325    }
326
327    // Non-fixed-size categorical path.
328    if let Some(RowEncodingContext::Categorical(ctx)) = dict {
329        match dtype {
330            D::UInt8 => {
331                assert!(opt.is_ordered() && !ctx.is_enum);
332                let dc_array = array.as_any().downcast_ref::<UInt8Array>().unwrap();
333                return striter_num_column_bytes(
334                    array,
335                    dc_array.values_iter().map(|cat| {
336                        ctx.mapping
337                            .cat_to_str(cat.as_cat())
338                            .map(|s| s.len())
339                            .unwrap_or(0)
340                    }),
341                    dc_array.validity(),
342                    opt,
343                    row_widths,
344                );
345            },
346            D::UInt16 => {
347                assert!(opt.is_ordered() && !ctx.is_enum);
348                let dc_array = array.as_any().downcast_ref::<UInt16Array>().unwrap();
349                return striter_num_column_bytes(
350                    array,
351                    dc_array.values_iter().map(|cat| {
352                        ctx.mapping
353                            .cat_to_str(cat.as_cat())
354                            .map(|s| s.len())
355                            .unwrap_or(0)
356                    }),
357                    dc_array.validity(),
358                    opt,
359                    row_widths,
360                );
361            },
362            D::UInt32 => {
363                assert!(opt.is_ordered() && !ctx.is_enum);
364                let dc_array = array.as_any().downcast_ref::<UInt32Array>().unwrap();
365                return striter_num_column_bytes(
366                    array,
367                    dc_array.values_iter().map(|cat| {
368                        ctx.mapping
369                            .cat_to_str(cat.as_cat())
370                            .map(|s| s.len())
371                            .unwrap_or(0)
372                    }),
373                    dc_array.validity(),
374                    opt,
375                    row_widths,
376                );
377            },
378            _ => {
379                // Fall through to below, should be nested type containing categorical.
380                debug_assert!(dtype.is_nested())
381            },
382        }
383    }
384
385    match dtype {
386        D::FixedSizeList(_, width) => {
387            let array = array.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
388
389            debug_assert_eq!(array.values().len(), array.len() * width);
390            let mut nested_row_widths = RowWidths::new(array.values().len());
391            let nested_encoder = get_encoder(
392                array.values().as_ref(),
393                opt.into_nested(),
394                dict,
395                &mut nested_row_widths,
396                masked_out_max_width,
397            );
398
399            let mut fsl_row_widths = nested_row_widths.collapse_chunks(*width, array.len());
400            fsl_row_widths.push_constant(1); // validity byte
401
402            row_widths.push(&fsl_row_widths);
403            Encoder {
404                array: array.to_boxed(),
405                state: Some(Box::new(EncoderState::FixedSizeList(
406                    Box::new(nested_encoder),
407                    *width,
408                    nested_row_widths,
409                ))),
410            }
411        },
412        D::Struct(_) => {
413            let array = array.as_any().downcast_ref::<StructArray>().unwrap();
414
415            let mut nested_encoders = Vec::with_capacity(array.values().len());
416            row_widths.push_constant(1); // validity byte
417            match dict {
418                None => {
419                    for array in array.values() {
420                        let encoder = get_encoder(
421                            array.as_ref(),
422                            opt.into_nested(),
423                            None,
424                            row_widths,
425                            masked_out_max_width,
426                        );
427                        nested_encoders.push(encoder);
428                    }
429                },
430                Some(RowEncodingContext::Struct(dicts)) => {
431                    for (array, dict) in array.values().iter().zip(dicts) {
432                        let encoder = get_encoder(
433                            array.as_ref(),
434                            opt.into_nested(),
435                            dict.as_ref(),
436                            row_widths,
437                            masked_out_max_width,
438                        );
439                        nested_encoders.push(encoder);
440                    }
441                },
442                _ => unreachable!(),
443            }
444            Encoder {
445                array: array.to_boxed(),
446                state: Some(Box::new(EncoderState::Struct(nested_encoders))),
447            }
448        },
449
450        D::List(_) => {
451            list_num_column_bytes::<i32>(array, opt, dict, row_widths, masked_out_max_width)
452        },
453        D::LargeList(_) => {
454            list_num_column_bytes::<i64>(array, opt, dict, row_widths, masked_out_max_width)
455        },
456
457        D::BinaryView => {
458            let dc_array = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
459            biniter_num_column_bytes(
460                array,
461                dc_array.views().iter().map(|v| v.length as usize),
462                dc_array.validity(),
463                opt,
464                row_widths,
465            )
466        },
467        D::Binary => {
468            let dc_array = array.as_any().downcast_ref::<BinaryArray<i32>>().unwrap();
469            biniter_num_column_bytes(
470                array,
471                dc_array.offsets().lengths(),
472                dc_array.validity(),
473                opt,
474                row_widths,
475            )
476        },
477        D::LargeBinary => {
478            let dc_array = array.as_any().downcast_ref::<BinaryArray<i64>>().unwrap();
479            biniter_num_column_bytes(
480                array,
481                dc_array.offsets().lengths(),
482                dc_array.validity(),
483                opt,
484                row_widths,
485            )
486        },
487
488        D::Utf8View => {
489            let dc_array = array.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
490            striter_num_column_bytes(
491                array,
492                dc_array.views().iter().map(|v| v.length as usize),
493                dc_array.validity(),
494                opt,
495                row_widths,
496            )
497        },
498        D::Utf8 => {
499            let dc_array = array.as_any().downcast_ref::<Utf8Array<i32>>().unwrap();
500            striter_num_column_bytes(
501                array,
502                dc_array.offsets().lengths(),
503                dc_array.validity(),
504                opt,
505                row_widths,
506            )
507        },
508        D::LargeUtf8 => {
509            let dc_array = array.as_any().downcast_ref::<Utf8Array<i64>>().unwrap();
510            striter_num_column_bytes(
511                array,
512                dc_array.offsets().lengths(),
513                dc_array.validity(),
514                opt,
515                row_widths,
516            )
517        },
518
519        D::Union(_) => unreachable!(),
520        D::Map(_, _) => unreachable!(),
521        D::Extension(_) => unreachable!(),
522        D::Unknown => unreachable!(),
523
524        // All non-physical types
525        D::Timestamp(_, _)
526        | D::Date32
527        | D::Date64
528        | D::Time32(_)
529        | D::Time64(_)
530        | D::Duration(_)
531        | D::Interval(_)
532        | D::Dictionary(_, _, _)
533        | D::Decimal(_, _)
534        | D::Decimal32(_, _)
535        | D::Decimal64(_, _)
536        | D::Decimal256(_, _) => unreachable!(),
537
538        // Should be fixed size type
539        _ => unreachable!(),
540    }
541}
542
543struct Encoder {
544    array: Box<dyn Array>,
545
546    /// State contains nested encoders and extra information needed to encode.
547    state: Option<Box<EncoderState>>,
548}
549
550enum EncoderState {
551    List(Box<Encoder>, RowWidths),
552    FixedSizeList(Box<Encoder>, usize, RowWidths),
553    Struct(Vec<Encoder>),
554}
555
556unsafe fn encode_strs<'a>(
557    buffer: &mut [MaybeUninit<u8>],
558    iter: impl Iterator<Item = Option<&'a str>>,
559    opt: RowEncodingOptions,
560    offsets: &mut [usize],
561) {
562    if opt.contains(RowEncodingOptions::NO_ORDER) {
563        no_order::encode_variable_no_order(
564            buffer,
565            iter.map(|v| v.map(str::as_bytes)),
566            opt,
567            offsets,
568        );
569    } else {
570        utf8::encode_str(buffer, iter, opt, offsets);
571    }
572}
573
574unsafe fn encode_bins<'a>(
575    buffer: &mut [MaybeUninit<u8>],
576    iter: impl Iterator<Item = Option<&'a [u8]>>,
577    opt: RowEncodingOptions,
578    offsets: &mut [usize],
579) {
580    if opt.contains(RowEncodingOptions::NO_ORDER) {
581        no_order::encode_variable_no_order(buffer, iter, opt, offsets);
582    } else {
583        binary::encode_iter(buffer, iter, opt, offsets);
584    }
585}
586
587unsafe fn encode_cat_array<T: NativeType + FixedLengthEncoding + CatNative>(
588    buffer: &mut [MaybeUninit<u8>],
589    keys: &PrimitiveArray<T>,
590    opt: RowEncodingOptions,
591    ctx: &RowEncodingCategoricalContext,
592    offsets: &mut [usize],
593) {
594    if ctx.is_enum || !opt.is_ordered() {
595        numeric::encode(buffer, keys, opt, offsets);
596    } else {
597        utf8::encode_str(
598            buffer,
599            keys.iter()
600                .map(|k| k.map(|&cat| ctx.mapping.cat_to_str_unchecked(cat.as_cat()))),
601            opt,
602            offsets,
603        );
604    }
605}
606
607unsafe fn encode_flat_array(
608    buffer: &mut [MaybeUninit<u8>],
609    array: &dyn Array,
610    opt: RowEncodingOptions,
611    dict: Option<&RowEncodingContext>,
612    offsets: &mut [usize],
613) {
614    use ArrowDataType as D;
615
616    if let Some(RowEncodingContext::Categorical(ctx)) = dict {
617        match array.dtype() {
618            D::UInt8 => {
619                let keys = array.as_any().downcast_ref::<PrimitiveArray<u8>>().unwrap();
620                encode_cat_array(buffer, keys, opt, ctx, offsets);
621            },
622            D::UInt16 => {
623                let keys = array
624                    .as_any()
625                    .downcast_ref::<PrimitiveArray<u16>>()
626                    .unwrap();
627                encode_cat_array(buffer, keys, opt, ctx, offsets);
628            },
629            D::UInt32 => {
630                let keys = array
631                    .as_any()
632                    .downcast_ref::<PrimitiveArray<u32>>()
633                    .unwrap();
634                encode_cat_array(buffer, keys, opt, ctx, offsets);
635            },
636            _ => unreachable!(),
637        };
638        return;
639    }
640
641    match array.dtype() {
642        D::Null => {},
643        D::Boolean => {
644            let array = array.as_any().downcast_ref::<BooleanArray>().unwrap();
645            boolean::encode_bool(buffer, array.iter(), opt, offsets);
646        },
647
648        dt if dt.is_numeric() => {
649            if matches!(dt, D::Int128) {
650                if let Some(RowEncodingContext::Decimal(precision)) = dict {
651                    decimal::encode(
652                        buffer,
653                        array
654                            .as_any()
655                            .downcast_ref::<PrimitiveArray<i128>>()
656                            .unwrap(),
657                        opt,
658                        offsets,
659                        *precision,
660                    );
661                    return;
662                }
663            }
664
665            with_match_arrow_primitive_type!(dt, |$T| {
666                let array = array.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap();
667                numeric::encode(buffer, array, opt, offsets);
668            })
669        },
670
671        D::Binary => {
672            let array = array.as_any().downcast_ref::<BinaryArray<i32>>().unwrap();
673            encode_bins(buffer, array.iter(), opt, offsets);
674        },
675        D::LargeBinary => {
676            let array = array.as_any().downcast_ref::<BinaryArray<i64>>().unwrap();
677            encode_bins(buffer, array.iter(), opt, offsets);
678        },
679        D::BinaryView => {
680            let array = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
681            encode_bins(buffer, array.iter(), opt, offsets);
682        },
683        D::Utf8 => {
684            let array = array.as_any().downcast_ref::<Utf8Array<i32>>().unwrap();
685            encode_strs(buffer, array.iter(), opt, offsets);
686        },
687        D::LargeUtf8 => {
688            let array = array.as_any().downcast_ref::<Utf8Array<i64>>().unwrap();
689            encode_strs(buffer, array.iter(), opt, offsets);
690        },
691        D::Utf8View => {
692            let array = array.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
693            encode_strs(buffer, array.iter(), opt, offsets);
694        },
695
696        // Lexical ordered Categorical are cast to PrimitiveArray above.
697        D::Dictionary(_, _, _) => todo!(),
698
699        D::FixedSizeBinary(_) => todo!(),
700        D::Decimal(_, _) => todo!(),
701        D::Decimal32(_, _) => todo!(),
702        D::Decimal64(_, _) => todo!(),
703        D::Decimal256(_, _) => todo!(),
704
705        D::Union(_) => todo!(),
706        D::Map(_, _) => todo!(),
707        D::Extension(_) => todo!(),
708        D::Unknown => todo!(),
709
710        // All are non-physical types.
711        D::Timestamp(_, _)
712        | D::Date32
713        | D::Date64
714        | D::Time32(_)
715        | D::Time64(_)
716        | D::Duration(_)
717        | D::Interval(_) => unreachable!(),
718
719        _ => unreachable!(),
720    }
721}
722
/// Scratch buffers shared by `encode_array` across the columns of one encoding run, so nested
/// list encodings can reuse allocations.
#[derive(Default)]
struct EncodeScratches {
    /// Write positions of the child values of the list currently being encoded.
    nested_offsets: Vec<usize>,
    // NOTE(review): not read by any code visible in this module section — confirm it is still
    // needed.
    nested_buffer: Vec<u8>,
}

impl EncodeScratches {
    /// Empties both buffers while keeping their allocations.
    fn clear(&mut self) {
        self.nested_buffer.clear();
        self.nested_offsets.clear();
    }
}
735
736unsafe fn encode_array(
737    buffer: &mut [MaybeUninit<u8>],
738    encoder: &Encoder,
739    opt: RowEncodingOptions,
740    dict: Option<&RowEncodingContext>,
741    offsets: &mut [usize],
742    masked_out_write_offset: usize, // Masked out values need to be written somewhere. We just
743    // reserved space at the end and tell all values to write
744    // there.
745    scratches: &mut EncodeScratches,
746) {
747    let Some(state) = &encoder.state else {
748        // This is actually the main path.
749        //
750        // If no nested types or special types are needed, this path is taken.
751        return encode_flat_array(buffer, encoder.array.as_ref(), opt, dict, offsets);
752    };
753
754    match state.as_ref() {
755        EncoderState::List(nested_encoder, nested_row_widths) => {
756            // @TODO: make more general.
757            let array = encoder
758                .array
759                .as_any()
760                .downcast_ref::<ListArray<i64>>()
761                .unwrap();
762
763            scratches.clear();
764
765            scratches
766                .nested_offsets
767                .reserve(nested_row_widths.num_rows());
768            let nested_offsets = &mut scratches.nested_offsets;
769
770            let list_null_sentinel = opt.list_null_sentinel();
771            let list_continuation_token = opt.list_continuation_token();
772            let list_termination_token = opt.list_termination_token();
773
774            match array.validity() {
775                None => {
776                    for (i, (offset, length)) in
777                        array.offsets().offset_and_length_iter().enumerate()
778                    {
779                        for j in offset..offset + length {
780                            buffer[offsets[i]] = MaybeUninit::new(list_continuation_token);
781                            offsets[i] += 1;
782
783                            nested_offsets.push(offsets[i]);
784                            offsets[i] += nested_row_widths.get(j);
785                        }
786                        buffer[offsets[i]] = MaybeUninit::new(list_termination_token);
787                        offsets[i] += 1;
788                    }
789                },
790                Some(validity) => {
791                    for (i, ((offset, length), is_valid)) in array
792                        .offsets()
793                        .offset_and_length_iter()
794                        .zip(validity.iter())
795                        .enumerate()
796                    {
797                        if !is_valid {
798                            buffer[offsets[i]] = MaybeUninit::new(list_null_sentinel);
799                            offsets[i] += 1;
800
801                            // Values might have been masked out.
802                            if length > 0 {
803                                nested_offsets
804                                    .extend(std::iter::repeat_n(masked_out_write_offset, length));
805                            }
806
807                            continue;
808                        }
809
810                        for j in offset..offset + length {
811                            buffer[offsets[i]] = MaybeUninit::new(list_continuation_token);
812                            offsets[i] += 1;
813
814                            nested_offsets.push(offsets[i]);
815                            offsets[i] += nested_row_widths.get(j);
816                        }
817                        buffer[offsets[i]] = MaybeUninit::new(list_termination_token);
818                        offsets[i] += 1;
819                    }
820                },
821            }
822
823            unsafe {
824                encode_array(
825                    buffer,
826                    nested_encoder,
827                    opt.into_nested(),
828                    dict,
829                    nested_offsets,
830                    masked_out_write_offset,
831                    &mut EncodeScratches::default(),
832                )
833            };
834        },
835        EncoderState::FixedSizeList(array, width, nested_row_widths) => {
836            encode_validity(buffer, encoder.array.validity(), opt, offsets);
837
838            if *width == 0 {
839                return;
840            }
841
842            let mut child_offsets = Vec::with_capacity(offsets.len() * width);
843            for (i, offset) in offsets.iter_mut().enumerate() {
844                for j in 0..*width {
845                    child_offsets.push(*offset);
846                    *offset += nested_row_widths.get((i * width) + j);
847                }
848            }
849
850            encode_array(
851                buffer,
852                array.as_ref(),
853                opt.into_nested(),
854                dict,
855                &mut child_offsets,
856                masked_out_write_offset,
857                scratches,
858            );
859            for (i, offset) in offsets.iter_mut().enumerate() {
860                *offset = child_offsets[(i + 1) * width - 1];
861            }
862        },
863        EncoderState::Struct(arrays) => {
864            encode_validity(buffer, encoder.array.validity(), opt, offsets);
865
866            match dict {
867                None => {
868                    for array in arrays {
869                        encode_array(
870                            buffer,
871                            array,
872                            opt.into_nested(),
873                            None,
874                            offsets,
875                            masked_out_write_offset,
876                            scratches,
877                        );
878                    }
879                },
880                Some(RowEncodingContext::Struct(dicts)) => {
881                    for (array, dict) in arrays.iter().zip(dicts) {
882                        encode_array(
883                            buffer,
884                            array,
885                            opt.into_nested(),
886                            dict.as_ref(),
887                            offsets,
888                            masked_out_write_offset,
889                            scratches,
890                        );
891                    }
892                },
893                _ => unreachable!(),
894            }
895        },
896    }
897}
898
899unsafe fn encode_validity(
900    buffer: &mut [MaybeUninit<u8>],
901    validity: Option<&Bitmap>,
902    opt: RowEncodingOptions,
903    row_starts: &mut [usize],
904) {
905    let null_sentinel = opt.null_sentinel();
906    match validity {
907        None => {
908            for row_start in row_starts.iter_mut() {
909                buffer[*row_start] = MaybeUninit::new(1);
910                *row_start += 1;
911            }
912        },
913        Some(validity) => {
914            for (row_start, is_valid) in row_starts.iter_mut().zip(validity.iter()) {
915                let v = if is_valid {
916                    MaybeUninit::new(1)
917                } else {
918                    MaybeUninit::new(null_sentinel)
919                };
920                buffer[*row_start] = v;
921                *row_start += 1;
922            }
923        },
924    }
925}
926
927pub fn fixed_size(
928    dtype: &ArrowDataType,
929    opt: RowEncodingOptions,
930    dict: Option<&RowEncodingContext>,
931) -> Option<usize> {
932    use ArrowDataType as D;
933    use numeric::FixedLengthEncoding;
934
935    if let Some(RowEncodingContext::Categorical(ctx)) = dict {
936        // If ordered categorical (non-enum) we encode strings, otherwise physical.
937        if !ctx.is_enum && opt.is_ordered() {
938            return None;
939        }
940    }
941
942    Some(match dtype {
943        D::Null => 0,
944        D::Boolean => 1,
945
946        D::UInt8 => u8::ENCODED_LEN,
947        D::UInt16 => u16::ENCODED_LEN,
948        D::UInt32 => u32::ENCODED_LEN,
949        D::UInt64 => u64::ENCODED_LEN,
950
951        D::Int8 => i8::ENCODED_LEN,
952        D::Int16 => i16::ENCODED_LEN,
953        D::Int32 => i32::ENCODED_LEN,
954        D::Int64 => i64::ENCODED_LEN,
955        D::Int128 => match dict {
956            None => i128::ENCODED_LEN,
957            Some(RowEncodingContext::Decimal(precision)) => decimal::len_from_precision(*precision),
958            _ => unreachable!(),
959        },
960
961        D::Float32 => f32::ENCODED_LEN,
962        D::Float64 => f64::ENCODED_LEN,
963        D::FixedSizeList(f, width) => 1 + width * fixed_size(f.dtype(), opt, dict)?,
964        D::Struct(fs) => match dict {
965            None => {
966                let mut sum = 0;
967                for f in fs {
968                    sum += fixed_size(f.dtype(), opt, None)?;
969                }
970                1 + sum
971            },
972            Some(RowEncodingContext::Struct(dicts)) => {
973                let mut sum = 0;
974                for (f, dict) in fs.iter().zip(dicts) {
975                    sum += fixed_size(f.dtype(), opt, dict.as_ref())?;
976                }
977                1 + sum
978            },
979            _ => unreachable!(),
980        },
981        _ => return None,
982    })
983}
984
#[cfg(test)]
mod tests {
    use arrow::array::proptest::{
        ArrayArbitraryOptions, ArrowDataTypeArbitraryOptions, ArrowDataTypeArbitrarySelection,
        array_with_options,
    };

    use super::*;

    // Strategy producing 1 or 2 (range `1..3`) arbitrary arrays that all share one row count
    // `length` drawn from `0..100`. BINARY dtypes are excluded from generation.
    // NOTE(review): the reason BINARY is excluded is not visible here — confirm before changing.
    proptest::prop_compose! {
        fn arrays
            ()
            (length in 0..100usize)
            (arrays in proptest::collection::vec(array_with_options(length, ArrayArbitraryOptions {
                dtype: ArrowDataTypeArbitraryOptions {
                    allowed_dtypes: ArrowDataTypeArbitrarySelection::all() & !ArrowDataTypeArbitrarySelection::BINARY,
                    ..Default::default()
                }
            }), 1..3))
        -> Vec<Box<dyn Array>> {
            arrays
        }
    }

    // Property: unordered row encoding of the generated columns, with no per-column encoding
    // context, completes without panicking. The encoded result is discarded; nothing about its
    // contents is asserted.
    proptest::proptest! {
        #[test]
        fn test_encode_arrays
            (arrays in arrays())
         {
            // One `None` context per column.
            let dicts: Vec<Option<RowEncodingContext>> = (0..arrays.len()).map(|_| None).collect();
            // `arrays` is never empty (the strategy yields at least one), so `arrays[0]` is safe.
            convert_columns_no_order(arrays[0].len(), &arrays, &dicts);
        }
    }
}