polars_arrow/ffi/
schema.rs

1use std::collections::BTreeMap;
2use std::ffi::{CStr, CString};
3use std::ptr;
4
5use polars_error::{PolarsResult, polars_bail, polars_err};
6use polars_utils::pl_str::PlSmallStr;
7
8use super::ArrowSchema;
9use crate::datatypes::{
10    ArrowDataType, Extension, ExtensionType, Field, IntegerType, IntervalUnit, Metadata, TimeUnit,
11    UnionMode, UnionType,
12};
13
14#[allow(dead_code)]
15struct SchemaPrivateData {
16    name: CString,
17    format: CString,
18    metadata: Option<Vec<u8>>,
19    children_ptr: Box<[*mut ArrowSchema]>,
20    dictionary: Option<*mut ArrowSchema>,
21}
22
23// callback used to drop [ArrowSchema] when it is exported.
24unsafe extern "C" fn c_release_schema(schema: *mut ArrowSchema) {
25    if schema.is_null() {
26        return;
27    }
28    let schema = &mut *schema;
29
30    let private = Box::from_raw(schema.private_data as *mut SchemaPrivateData);
31    for child in private.children_ptr.iter() {
32        let _ = Box::from_raw(*child);
33    }
34
35    if let Some(ptr) = private.dictionary {
36        let _ = Box::from_raw(ptr);
37    }
38
39    schema.release = None;
40}
41
42/// allocate (and hold) the children
43fn schema_children(dtype: &ArrowDataType, flags: &mut i64) -> Box<[*mut ArrowSchema]> {
44    match dtype {
45        ArrowDataType::List(field)
46        | ArrowDataType::FixedSizeList(field, _)
47        | ArrowDataType::LargeList(field) => {
48            Box::new([Box::into_raw(Box::new(ArrowSchema::new(field.as_ref())))])
49        },
50        ArrowDataType::Map(field, is_sorted) => {
51            *flags += (*is_sorted as i64) * 4;
52            Box::new([Box::into_raw(Box::new(ArrowSchema::new(field.as_ref())))])
53        },
54        ArrowDataType::Struct(fields) => fields
55            .iter()
56            .map(|field| Box::into_raw(Box::new(ArrowSchema::new(field))))
57            .collect::<Box<[_]>>(),
58        ArrowDataType::Union(u) => u
59            .fields
60            .iter()
61            .map(|field| Box::into_raw(Box::new(ArrowSchema::new(field))))
62            .collect::<Box<[_]>>(),
63        ArrowDataType::Extension(ext) => schema_children(&ext.inner, flags),
64        _ => Box::new([]),
65    }
66}
67
68impl ArrowSchema {
69    /// creates a new [ArrowSchema]
70    pub(crate) fn new(field: &Field) -> Self {
71        let format = to_format(field.dtype());
72        let name = field.name.clone();
73
74        let mut flags = field.is_nullable as i64 * 2;
75
76        // note: this cannot be done along with the above because the above is fallible and this op leaks.
77        let children_ptr = schema_children(field.dtype(), &mut flags);
78        let n_children = children_ptr.len() as i64;
79
80        let dictionary = if let ArrowDataType::Dictionary(_, values, is_ordered) = field.dtype() {
81            flags += *is_ordered as i64;
82            // we do not store field info in the dict values, so can't recover it all :(
83            let field = Field::new(PlSmallStr::EMPTY, values.as_ref().clone(), true);
84            Some(Box::new(ArrowSchema::new(&field)))
85        } else {
86            None
87        };
88
89        let metadata = field
90            .metadata
91            .as_ref()
92            .map(|inner| (**inner).clone())
93            .unwrap_or_default();
94
95        let metadata = if let ArrowDataType::Extension(ext) = field.dtype() {
96            // append extension information.
97            let mut metadata = metadata;
98
99            // metadata
100            if let Some(extension_metadata) = &ext.metadata {
101                metadata.insert(
102                    PlSmallStr::from_static("ARROW:extension:metadata"),
103                    extension_metadata.clone(),
104                );
105            }
106
107            metadata.insert(
108                PlSmallStr::from_static("ARROW:extension:name"),
109                ext.name.clone(),
110            );
111
112            Some(metadata_to_bytes(&metadata))
113        } else if !metadata.is_empty() {
114            Some(metadata_to_bytes(&metadata))
115        } else {
116            None
117        };
118
119        let name = CString::new(name.as_bytes()).unwrap();
120        let format = CString::new(format).unwrap();
121
122        let mut private = Box::new(SchemaPrivateData {
123            name,
124            format,
125            metadata,
126            children_ptr,
127            dictionary: dictionary.map(Box::into_raw),
128        });
129
130        // <https://arrow.apache.org/docs/format/CDataInterface.html#c.ArrowSchema>
131        Self {
132            format: private.format.as_ptr(),
133            name: private.name.as_ptr(),
134            metadata: private
135                .metadata
136                .as_ref()
137                .map(|x| x.as_ptr())
138                .unwrap_or(std::ptr::null()) as *const ::std::os::raw::c_char,
139            flags,
140            n_children,
141            children: private.children_ptr.as_mut_ptr(),
142            dictionary: private.dictionary.unwrap_or(std::ptr::null_mut()),
143            release: Some(c_release_schema),
144            private_data: Box::into_raw(private) as *mut ::std::os::raw::c_void,
145        }
146    }
147
148    /// create an empty [ArrowSchema]
149    pub fn empty() -> Self {
150        Self {
151            format: std::ptr::null_mut(),
152            name: std::ptr::null_mut(),
153            metadata: std::ptr::null_mut(),
154            flags: 0,
155            n_children: 0,
156            children: ptr::null_mut(),
157            dictionary: std::ptr::null_mut(),
158            release: None,
159            private_data: std::ptr::null_mut(),
160        }
161    }
162
163    pub fn is_null(&self) -> bool {
164        self.private_data.is_null()
165    }
166
167    /// returns the format of this schema.
168    pub(crate) fn format(&self) -> &str {
169        assert!(!self.format.is_null());
170        // safe because the lifetime of `self.format` equals `self`
171        unsafe { CStr::from_ptr(self.format) }
172            .to_str()
173            .expect("The external API has a non-utf8 as format")
174    }
175
176    /// returns the name of this schema.
177    ///
178    /// Since this field is optional, `""` is returned if it is not set (as per the spec).
179    pub(crate) fn name(&self) -> &str {
180        if self.name.is_null() {
181            return "";
182        }
183        // safe because the lifetime of `self.name` equals `self`
184        unsafe { CStr::from_ptr(self.name) }.to_str().unwrap()
185    }
186
187    pub(crate) fn child(&self, index: usize) -> &'static Self {
188        assert!(index < self.n_children as usize);
189        unsafe { self.children.add(index).as_ref().unwrap().as_ref().unwrap() }
190    }
191
192    pub(crate) fn dictionary(&self) -> Option<&'static Self> {
193        if self.dictionary.is_null() {
194            return None;
195        };
196        Some(unsafe { self.dictionary.as_ref().unwrap() })
197    }
198
199    pub(crate) fn nullable(&self) -> bool {
200        (self.flags / 2) & 1 == 1
201    }
202}
203
204impl Drop for ArrowSchema {
205    fn drop(&mut self) {
206        match self.release {
207            None => (),
208            Some(release) => unsafe { release(self) },
209        };
210    }
211}
212
213pub(crate) unsafe fn to_field(schema: &ArrowSchema) -> PolarsResult<Field> {
214    let dictionary = schema.dictionary();
215    let dtype = if let Some(dictionary) = dictionary {
216        let indices = to_integer_type(schema.format())?;
217        let values = to_field(dictionary)?;
218        let is_ordered = schema.flags & 1 == 1;
219        ArrowDataType::Dictionary(indices, Box::new(values.dtype().clone()), is_ordered)
220    } else {
221        to_dtype(schema)?
222    };
223    let (metadata, extension) = unsafe { metadata_from_bytes(schema.metadata) };
224
225    let dtype = if let Some((name, extension_metadata)) = extension {
226        ArrowDataType::Extension(Box::new(ExtensionType {
227            name,
228            inner: dtype,
229            metadata: extension_metadata,
230        }))
231    } else {
232        dtype
233    };
234
235    Ok(Field::new(
236        PlSmallStr::from_str(schema.name()),
237        dtype,
238        schema.nullable(),
239    )
240    .with_metadata(metadata))
241}
242
243fn to_integer_type(format: &str) -> PolarsResult<IntegerType> {
244    use IntegerType::*;
245    Ok(match format {
246        "c" => Int8,
247        "C" => UInt8,
248        "s" => Int16,
249        "S" => UInt16,
250        "i" => Int32,
251        "I" => UInt32,
252        "l" => Int64,
253        "L" => UInt64,
254        _ => {
255            polars_bail!(
256                ComputeError:
257                "dictionary indices can only be integers"
258            )
259        },
260    })
261}
262
263unsafe fn to_dtype(schema: &ArrowSchema) -> PolarsResult<ArrowDataType> {
264    Ok(match schema.format() {
265        "n" => ArrowDataType::Null,
266        "b" => ArrowDataType::Boolean,
267        "c" => ArrowDataType::Int8,
268        "C" => ArrowDataType::UInt8,
269        "s" => ArrowDataType::Int16,
270        "S" => ArrowDataType::UInt16,
271        "i" => ArrowDataType::Int32,
272        "I" => ArrowDataType::UInt32,
273        "l" => ArrowDataType::Int64,
274        "L" => ArrowDataType::UInt64,
275        "_pli128" => ArrowDataType::Int128,
276        "e" => ArrowDataType::Float16,
277        "f" => ArrowDataType::Float32,
278        "g" => ArrowDataType::Float64,
279        "z" => ArrowDataType::Binary,
280        "Z" => ArrowDataType::LargeBinary,
281        "u" => ArrowDataType::Utf8,
282        "U" => ArrowDataType::LargeUtf8,
283        "tdD" => ArrowDataType::Date32,
284        "tdm" => ArrowDataType::Date64,
285        "tts" => ArrowDataType::Time32(TimeUnit::Second),
286        "ttm" => ArrowDataType::Time32(TimeUnit::Millisecond),
287        "ttu" => ArrowDataType::Time64(TimeUnit::Microsecond),
288        "ttn" => ArrowDataType::Time64(TimeUnit::Nanosecond),
289        "tDs" => ArrowDataType::Duration(TimeUnit::Second),
290        "tDm" => ArrowDataType::Duration(TimeUnit::Millisecond),
291        "tDu" => ArrowDataType::Duration(TimeUnit::Microsecond),
292        "tDn" => ArrowDataType::Duration(TimeUnit::Nanosecond),
293        "tiM" => ArrowDataType::Interval(IntervalUnit::YearMonth),
294        "tiD" => ArrowDataType::Interval(IntervalUnit::DayTime),
295        "vu" => ArrowDataType::Utf8View,
296        "vz" => ArrowDataType::BinaryView,
297        "+l" => {
298            let child = schema.child(0);
299            ArrowDataType::List(Box::new(to_field(child)?))
300        },
301        "+L" => {
302            let child = schema.child(0);
303            ArrowDataType::LargeList(Box::new(to_field(child)?))
304        },
305        "+m" => {
306            let child = schema.child(0);
307
308            let is_sorted = (schema.flags & 4) != 0;
309            ArrowDataType::Map(Box::new(to_field(child)?), is_sorted)
310        },
311        "+s" => {
312            let children = (0..schema.n_children as usize)
313                .map(|x| to_field(schema.child(x)))
314                .collect::<PolarsResult<Vec<_>>>()?;
315            ArrowDataType::Struct(children)
316        },
317        other => {
318            match other.splitn(2, ':').collect::<Vec<_>>()[..] {
319                // Timestamps with no timezone
320                ["tss", ""] => ArrowDataType::Timestamp(TimeUnit::Second, None),
321                ["tsm", ""] => ArrowDataType::Timestamp(TimeUnit::Millisecond, None),
322                ["tsu", ""] => ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
323                ["tsn", ""] => ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
324
325                // Timestamps with timezone
326                ["tss", tz] => {
327                    ArrowDataType::Timestamp(TimeUnit::Second, Some(PlSmallStr::from_str(tz)))
328                },
329                ["tsm", tz] => {
330                    ArrowDataType::Timestamp(TimeUnit::Millisecond, Some(PlSmallStr::from_str(tz)))
331                },
332                ["tsu", tz] => {
333                    ArrowDataType::Timestamp(TimeUnit::Microsecond, Some(PlSmallStr::from_str(tz)))
334                },
335                ["tsn", tz] => {
336                    ArrowDataType::Timestamp(TimeUnit::Nanosecond, Some(PlSmallStr::from_str(tz)))
337                },
338
339                ["w", size_raw] => {
340                    // Example: "w:42" fixed-width binary [42 bytes]
341                    let size = size_raw
342                        .parse::<usize>()
343                        .map_err(|_| polars_err!(ComputeError: "size is not a valid integer"))?;
344                    ArrowDataType::FixedSizeBinary(size)
345                },
346                ["+w", size_raw] => {
347                    // Example: "+w:123" fixed-sized list [123 items]
348                    let size = size_raw
349                        .parse::<usize>()
350                        .map_err(|_| polars_err!(ComputeError: "size is not a valid integer"))?;
351                    let child = to_field(schema.child(0))?;
352                    ArrowDataType::FixedSizeList(Box::new(child), size)
353                },
354                ["d", raw] => {
355                    // Decimal
356                    let (precision, scale) = match raw.split(',').collect::<Vec<_>>()[..] {
357                        [precision_raw, scale_raw] => {
358                            // Example: "d:19,10" decimal128 [precision 19, scale 10]
359                            (precision_raw, scale_raw)
360                        },
361                        [precision_raw, scale_raw, width_raw] => {
362                            // Example: "d:19,10,NNN" decimal bitwidth = NNN [precision 19, scale 10]
363                            // Only bitwdth of 128 currently supported
364                            let bit_width = width_raw.parse::<usize>().map_err(|_| {
365                                polars_err!(ComputeError: "Decimal bit width is not a valid integer")
366                            })?;
367                            match bit_width {
368                                32 => return Ok(ArrowDataType::Decimal32(
369                                    precision_raw.parse::<usize>().map_err(|_| {
370                                        polars_err!(ComputeError: "Decimal precision is not a valid integer")
371                                    })?,
372                                    scale_raw.parse::<usize>().map_err(|_| {
373                                        polars_err!(ComputeError: "Decimal scale is not a valid integer")
374                                    })?,
375                                )),
376                                64 => return Ok(ArrowDataType::Decimal64(
377                                    precision_raw.parse::<usize>().map_err(|_| {
378                                        polars_err!(ComputeError: "Decimal precision is not a valid integer")
379                                    })?,
380                                    scale_raw.parse::<usize>().map_err(|_| {
381                                        polars_err!(ComputeError: "Decimal scale is not a valid integer")
382                                    })?,
383                                )),
384                                256 => return Ok(ArrowDataType::Decimal256(
385                                    precision_raw.parse::<usize>().map_err(|_| {
386                                        polars_err!(ComputeError: "Decimal precision is not a valid integer")
387                                    })?,
388                                    scale_raw.parse::<usize>().map_err(|_| {
389                                        polars_err!(ComputeError: "Decimal scale is not a valid integer")
390                                    })?,
391                                )),
392                                _ => {},
393                            }
394                            (precision_raw, scale_raw)
395                        },
396                        _ => {
397                            polars_bail!(ComputeError:
398                                "Decimal must contain 2 or 3 comma-separated values"
399                            )
400                        },
401                    };
402
403                    ArrowDataType::Decimal(
404                        precision.parse::<usize>().map_err(|_| {
405                            polars_err!(ComputeError:
406                            "Decimal precision is not a valid integer"
407                            )
408                        })?,
409                        scale.parse::<usize>().map_err(|_| {
410                            polars_err!(ComputeError:
411                            "Decimal scale is not a valid integer"
412                            )
413                        })?,
414                    )
415                },
416                [union_type @ "+us", union_parts] | [union_type @ "+ud", union_parts] => {
417                    // union, sparse
418                    // Example "+us:I,J,..." sparse union with type ids I,J...
419                    // Example: "+ud:I,J,..." dense union with type ids I,J...
420                    let mode = UnionMode::sparse(union_type == "+us");
421                    let type_ids = union_parts
422                        .split(',')
423                        .map(|x| {
424                            x.parse::<i32>().map_err(|_| {
425                                polars_err!(ComputeError:
426                                "Union type id is not a valid integer"
427                                )
428                            })
429                        })
430                        .collect::<PolarsResult<Vec<_>>>()?;
431                    let fields = (0..schema.n_children as usize)
432                        .map(|x| to_field(schema.child(x)))
433                        .collect::<PolarsResult<Vec<_>>>()?;
434                    ArrowDataType::Union(Box::new(UnionType {
435                        fields,
436                        ids: Some(type_ids),
437                        mode,
438                    }))
439                },
440                _ => {
441                    polars_bail!(ComputeError:
442                    "The datatype \"{other}\" is still not supported in Rust implementation",
443                        )
444                },
445            }
446        },
447    })
448}
449
450/// the inverse of [to_field]
451fn to_format(dtype: &ArrowDataType) -> String {
452    match dtype {
453        ArrowDataType::Null => "n".to_string(),
454        ArrowDataType::Boolean => "b".to_string(),
455        ArrowDataType::Int8 => "c".to_string(),
456        ArrowDataType::UInt8 => "C".to_string(),
457        ArrowDataType::Int16 => "s".to_string(),
458        ArrowDataType::UInt16 => "S".to_string(),
459        ArrowDataType::Int32 => "i".to_string(),
460        ArrowDataType::UInt32 => "I".to_string(),
461        ArrowDataType::Int64 => "l".to_string(),
462        ArrowDataType::UInt64 => "L".to_string(),
463        // Doesn't exist in arrow, '_pl' prefixed is Polars specific
464        ArrowDataType::Int128 => "_pli128".to_string(),
465        ArrowDataType::Float16 => "e".to_string(),
466        ArrowDataType::Float32 => "f".to_string(),
467        ArrowDataType::Float64 => "g".to_string(),
468        ArrowDataType::Binary => "z".to_string(),
469        ArrowDataType::LargeBinary => "Z".to_string(),
470        ArrowDataType::Utf8 => "u".to_string(),
471        ArrowDataType::LargeUtf8 => "U".to_string(),
472        ArrowDataType::Date32 => "tdD".to_string(),
473        ArrowDataType::Date64 => "tdm".to_string(),
474        ArrowDataType::Time32(TimeUnit::Second) => "tts".to_string(),
475        ArrowDataType::Time32(TimeUnit::Millisecond) => "ttm".to_string(),
476        ArrowDataType::Time32(_) => {
477            unreachable!("Time32 is only supported for seconds and milliseconds")
478        },
479        ArrowDataType::Time64(TimeUnit::Microsecond) => "ttu".to_string(),
480        ArrowDataType::Time64(TimeUnit::Nanosecond) => "ttn".to_string(),
481        ArrowDataType::Time64(_) => {
482            unreachable!("Time64 is only supported for micro and nanoseconds")
483        },
484        ArrowDataType::Duration(TimeUnit::Second) => "tDs".to_string(),
485        ArrowDataType::Duration(TimeUnit::Millisecond) => "tDm".to_string(),
486        ArrowDataType::Duration(TimeUnit::Microsecond) => "tDu".to_string(),
487        ArrowDataType::Duration(TimeUnit::Nanosecond) => "tDn".to_string(),
488        ArrowDataType::Interval(IntervalUnit::YearMonth) => "tiM".to_string(),
489        ArrowDataType::Interval(IntervalUnit::DayTime) => "tiD".to_string(),
490        ArrowDataType::Interval(IntervalUnit::MonthDayNano) => {
491            todo!("Spec for FFI for MonthDayNano still not defined.")
492        },
493        ArrowDataType::Timestamp(unit, tz) => {
494            let unit = match unit {
495                TimeUnit::Second => "s",
496                TimeUnit::Millisecond => "m",
497                TimeUnit::Microsecond => "u",
498                TimeUnit::Nanosecond => "n",
499            };
500            format!(
501                "ts{}:{}",
502                unit,
503                tz.as_ref().map(|x| x.as_str()).unwrap_or("")
504            )
505        },
506        ArrowDataType::Utf8View => "vu".to_string(),
507        ArrowDataType::BinaryView => "vz".to_string(),
508        ArrowDataType::Decimal(precision, scale) => format!("d:{precision},{scale}"),
509        ArrowDataType::Decimal32(precision, scale) => format!("d:{precision},{scale},32"),
510        ArrowDataType::Decimal64(precision, scale) => format!("d:{precision},{scale},64"),
511        ArrowDataType::Decimal256(precision, scale) => format!("d:{precision},{scale},256"),
512        ArrowDataType::List(_) => "+l".to_string(),
513        ArrowDataType::LargeList(_) => "+L".to_string(),
514        ArrowDataType::Struct(_) => "+s".to_string(),
515        ArrowDataType::FixedSizeBinary(size) => format!("w:{size}"),
516        ArrowDataType::FixedSizeList(_, size) => format!("+w:{size}"),
517        ArrowDataType::Union(u) => {
518            let sparsness = if u.mode.is_sparse() { 's' } else { 'd' };
519            let mut r = format!("+u{sparsness}:");
520            let ids = if let Some(ids) = &u.ids {
521                ids.iter()
522                    .fold(String::new(), |a, b| a + b.to_string().as_str() + ",")
523            } else {
524                (0..u.fields.len()).fold(String::new(), |a, b| a + b.to_string().as_str() + ",")
525            };
526            let ids = &ids[..ids.len() - 1]; // take away last ","
527            r.push_str(ids);
528            r
529        },
530        ArrowDataType::Map(_, _) => "+m".to_string(),
531        ArrowDataType::Dictionary(index, _, _) => to_format(&(*index).into()),
532        ArrowDataType::Extension(ext) => to_format(&ext.inner),
533        ArrowDataType::Unknown => unimplemented!(),
534    }
535}
536
537pub(super) fn get_child(dtype: &ArrowDataType, index: usize) -> PolarsResult<ArrowDataType> {
538    match (index, dtype) {
539        (0, ArrowDataType::List(field)) => Ok(field.dtype().clone()),
540        (0, ArrowDataType::FixedSizeList(field, _)) => Ok(field.dtype().clone()),
541        (0, ArrowDataType::LargeList(field)) => Ok(field.dtype().clone()),
542        (0, ArrowDataType::Map(field, _)) => Ok(field.dtype().clone()),
543        (index, ArrowDataType::Struct(fields)) => Ok(fields[index].dtype().clone()),
544        (index, ArrowDataType::Union(u)) => Ok(u.fields[index].dtype().clone()),
545        (index, ArrowDataType::Extension(ext)) => get_child(&ext.inner, index),
546        (child, dtype) => polars_bail!(ComputeError:
547            "Requested child {child} to type {dtype:?} that has no such child",
548        ),
549    }
550}
551
552fn metadata_to_bytes(metadata: &BTreeMap<PlSmallStr, PlSmallStr>) -> Vec<u8> {
553    let a = (metadata.len() as i32).to_ne_bytes().to_vec();
554    metadata.iter().fold(a, |mut acc, (key, value)| {
555        acc.extend((key.len() as i32).to_ne_bytes());
556        acc.extend(key.as_bytes());
557        acc.extend((value.len() as i32).to_ne_bytes());
558        acc.extend(value.as_bytes());
559        acc
560    })
561}
562
/// Reads a native-endian `i32` from `ptr`, which need not be aligned.
///
/// # Safety
/// `ptr` must be valid for reading 4 bytes.
unsafe fn read_ne_i32(ptr: *const u8) -> i32 {
    ptr.cast::<i32>().read_unaligned()
}
567
568unsafe fn read_bytes(ptr: *const u8, len: usize) -> &'static str {
569    let slice = std::slice::from_raw_parts(ptr, len);
570    simdutf8::basic::from_utf8(slice).unwrap()
571}
572
573unsafe fn metadata_from_bytes(data: *const ::std::os::raw::c_char) -> (Metadata, Extension) {
574    let mut data = data as *const u8; // u8 = i8
575    if data.is_null() {
576        return (Metadata::default(), None);
577    };
578    let len = read_ne_i32(data);
579    data = data.add(4);
580
581    let mut result = BTreeMap::new();
582    let mut extension_name = None;
583    let mut extension_metadata = None;
584    for _ in 0..len {
585        let key_len = read_ne_i32(data) as usize;
586        data = data.add(4);
587        let key = read_bytes(data, key_len);
588        data = data.add(key_len);
589        let value_len = read_ne_i32(data) as usize;
590        data = data.add(4);
591        let value = read_bytes(data, value_len);
592        data = data.add(value_len);
593        match key {
594            "ARROW:extension:name" => {
595                extension_name = Some(PlSmallStr::from_str(value));
596            },
597            "ARROW:extension:metadata" => {
598                extension_metadata = Some(PlSmallStr::from_str(value));
599            },
600            _ => {
601                result.insert(PlSmallStr::from_str(key), PlSmallStr::from_str(value));
602            },
603        };
604    }
605    let extension = extension_name.map(|name| (name, extension_metadata));
606    (result, extension)
607}
608
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array::LIST_VALUES_NAME;

    /// Shorthand for a [`Field`] with a static name.
    fn field(name: &'static str, dtype: ArrowDataType, nullable: bool) -> Field {
        Field::new(PlSmallStr::from_static(name), dtype, nullable)
    }

    /// The fields `a: Int64` and `b: List<Int32>` shared by the struct and union cases.
    fn nested_fields() -> Vec<Field> {
        vec![
            field("a", ArrowDataType::Int64, true),
            field(
                "b",
                ArrowDataType::List(Box::new(Field::new(
                    LIST_VALUES_NAME,
                    ArrowDataType::Int32,
                    true,
                ))),
                true,
            ),
        ]
    }

    /// Exports each data type via `ArrowSchema::new` and checks that `to_dtype` reads the
    /// same type back.
    #[test]
    fn test_all() {
        let example = || Box::new(field("example", ArrowDataType::Boolean, false));

        let mut dts = vec![
            ArrowDataType::Null,
            ArrowDataType::Boolean,
            ArrowDataType::UInt8,
            ArrowDataType::UInt16,
            ArrowDataType::UInt32,
            ArrowDataType::UInt64,
            ArrowDataType::Int8,
            ArrowDataType::Int16,
            ArrowDataType::Int32,
            ArrowDataType::Int64,
            ArrowDataType::Float32,
            ArrowDataType::Float64,
            ArrowDataType::Date32,
            ArrowDataType::Date64,
            ArrowDataType::Time32(TimeUnit::Second),
            ArrowDataType::Time32(TimeUnit::Millisecond),
            ArrowDataType::Time64(TimeUnit::Microsecond),
            ArrowDataType::Time64(TimeUnit::Nanosecond),
            ArrowDataType::Decimal(5, 5),
            ArrowDataType::Utf8,
            ArrowDataType::LargeUtf8,
            ArrowDataType::Binary,
            ArrowDataType::LargeBinary,
            ArrowDataType::FixedSizeBinary(2),
            ArrowDataType::List(example()),
            ArrowDataType::FixedSizeList(example(), 2),
            ArrowDataType::LargeList(example()),
            ArrowDataType::Struct(nested_fields()),
            ArrowDataType::Map(Box::new(field("a", ArrowDataType::Int64, true)), true),
            ArrowDataType::Union(Box::new(UnionType {
                fields: nested_fields(),
                ids: Some(vec![1, 2]),
                mode: UnionMode::Dense,
            })),
            ArrowDataType::Union(Box::new(UnionType {
                fields: nested_fields(),
                ids: Some(vec![0, 1]),
                mode: UnionMode::Sparse,
            })),
        ];

        let time_units = [
            TimeUnit::Second,
            TimeUnit::Millisecond,
            TimeUnit::Microsecond,
            TimeUnit::Nanosecond,
        ];
        dts.extend(time_units.into_iter().flat_map(|unit| {
            [
                ArrowDataType::Timestamp(unit, None),
                ArrowDataType::Timestamp(unit, Some(PlSmallStr::from_static("00:00"))),
                ArrowDataType::Duration(unit),
            ]
        }));
        // IntervalUnit::MonthDayNano is not covered by this round-trip test.
        dts.extend([IntervalUnit::DayTime, IntervalUnit::YearMonth].map(ArrowDataType::Interval));

        for expected in dts {
            let schema = ArrowSchema::new(&field("a", expected.clone(), true));
            let actual = unsafe { super::to_dtype(&schema).unwrap() };
            assert_eq!(actual, expected);
        }
    }
}