1use std::sync::Arc;
2
3use arrow_format::ipc::planus::ReadAsRoot;
4use arrow_format::ipc::{FieldRef, FixedSizeListRef, MapRef, TimeRef, TimestampRef, UnionRef};
5use polars_error::{PolarsResult, polars_bail, polars_err};
6use polars_utils::pl_str::PlSmallStr;
7
8use super::super::{IpcField, IpcSchema};
9use super::{OutOfSpecKind, StreamMetadata};
10use crate::datatypes::{
11 ArrowDataType, ArrowSchema, Extension, ExtensionType, Field, IntegerType, IntervalUnit,
12 Metadata, TimeUnit, UnionMode, UnionType, get_extension,
13};
14
15fn try_unzip_vec<A, B, I: Iterator<Item = PolarsResult<(A, B)>>>(
16 iter: I,
17) -> PolarsResult<(Vec<A>, Vec<B>)> {
18 let mut a = vec![];
19 let mut b = vec![];
20 for maybe_item in iter {
21 let (a_i, b_i) = maybe_item?;
22 a.push(a_i);
23 b.push(b_i);
24 }
25
26 Ok((a, b))
27}
28
29fn deserialize_field(ipc_field: arrow_format::ipc::FieldRef) -> PolarsResult<(Field, IpcField)> {
30 let metadata = read_metadata(&ipc_field)?;
31
32 let extension = metadata.as_ref().and_then(get_extension);
33
34 let (dtype, ipc_field_) = get_dtype(ipc_field, extension, true)?;
35
36 let field = Field {
37 name: PlSmallStr::from_str(
38 ipc_field
39 .name()?
40 .ok_or_else(|| polars_err!(oos = "Every field in IPC must have a name"))?,
41 ),
42 dtype,
43 is_nullable: ipc_field.nullable()?,
44 metadata: metadata.map(Arc::new),
45 };
46
47 Ok((field, ipc_field_))
48}
49
50fn read_metadata(field: &arrow_format::ipc::FieldRef) -> PolarsResult<Option<Metadata>> {
51 Ok(if let Some(list) = field.custom_metadata()? {
52 let mut metadata_map = Metadata::new();
53 for kv in list {
54 let kv = kv?;
55 if let (Some(k), Some(v)) = (kv.key()?, kv.value()?) {
56 metadata_map.insert(PlSmallStr::from_str(k), PlSmallStr::from_str(v));
57 }
58 }
59 Some(metadata_map)
60 } else {
61 None
62 })
63}
64
65fn deserialize_integer(int: arrow_format::ipc::IntRef) -> PolarsResult<IntegerType> {
66 Ok(match (int.bit_width()?, int.is_signed()?) {
67 (8, true) => IntegerType::Int8,
68 (8, false) => IntegerType::UInt8,
69 (16, true) => IntegerType::Int16,
70 (16, false) => IntegerType::UInt16,
71 (32, true) => IntegerType::Int32,
72 (32, false) => IntegerType::UInt32,
73 (64, true) => IntegerType::Int64,
74 (64, false) => IntegerType::UInt64,
75 (128, true) => IntegerType::Int128,
76 _ => polars_bail!(oos = "IPC: indexType can only be 8, 16, 32, 64 or 128."),
77 })
78}
79
80fn deserialize_timeunit(time_unit: arrow_format::ipc::TimeUnit) -> PolarsResult<TimeUnit> {
81 use arrow_format::ipc::TimeUnit::*;
82 Ok(match time_unit {
83 Second => TimeUnit::Second,
84 Millisecond => TimeUnit::Millisecond,
85 Microsecond => TimeUnit::Microsecond,
86 Nanosecond => TimeUnit::Nanosecond,
87 })
88}
89
90fn deserialize_time(time: TimeRef) -> PolarsResult<(ArrowDataType, IpcField)> {
91 let unit = deserialize_timeunit(time.unit()?)?;
92
93 let dtype = match (time.bit_width()?, unit) {
94 (32, TimeUnit::Second) => ArrowDataType::Time32(TimeUnit::Second),
95 (32, TimeUnit::Millisecond) => ArrowDataType::Time32(TimeUnit::Millisecond),
96 (64, TimeUnit::Microsecond) => ArrowDataType::Time64(TimeUnit::Microsecond),
97 (64, TimeUnit::Nanosecond) => ArrowDataType::Time64(TimeUnit::Nanosecond),
98 (bits, precision) => {
99 polars_bail!(ComputeError:
100 "Time type with bit width of {bits} and unit of {precision:?}"
101 )
102 },
103 };
104 Ok((dtype, IpcField::default()))
105}
106
107fn deserialize_timestamp(timestamp: TimestampRef) -> PolarsResult<(ArrowDataType, IpcField)> {
108 let timezone = timestamp.timezone()?;
109 let time_unit = deserialize_timeunit(timestamp.unit()?)?;
110 Ok((
111 ArrowDataType::Timestamp(time_unit, timezone.map(PlSmallStr::from_str)),
112 IpcField::default(),
113 ))
114}
115
116fn deserialize_union(union_: UnionRef, field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {
117 let mode = UnionMode::sparse(union_.mode()? == arrow_format::ipc::UnionMode::Sparse);
118 let ids = union_.type_ids()?.map(|x| x.iter().collect());
119
120 let fields = field
121 .children()?
122 .ok_or_else(|| polars_err!(oos = "IPC: Union must contain children"))?;
123 if fields.is_empty() {
124 polars_bail!(oos = "IPC: Union must contain at least one child");
125 }
126
127 let (fields, ipc_fields) = try_unzip_vec(fields.iter().map(|field| {
128 let (field, fields) = deserialize_field(field?)?;
129 Ok((field, fields))
130 }))?;
131 let ipc_field = IpcField {
132 fields: ipc_fields,
133 dictionary_id: None,
134 };
135 Ok((
136 ArrowDataType::Union(Box::new(UnionType { fields, ids, mode })),
137 ipc_field,
138 ))
139}
140
141fn deserialize_map(map: MapRef, field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {
142 let is_sorted = map.keys_sorted()?;
143
144 let children = field
145 .children()?
146 .ok_or_else(|| polars_err!(oos = "IPC: Map must contain children"))?;
147 let inner = children
148 .get(0)
149 .ok_or_else(|| polars_err!(oos = "IPC: Map must contain one child"))??;
150 let (field, ipc_field) = deserialize_field(inner)?;
151
152 let dtype = ArrowDataType::Map(Box::new(field), is_sorted);
153 Ok((
154 dtype,
155 IpcField {
156 fields: vec![ipc_field],
157 dictionary_id: None,
158 },
159 ))
160}
161
162fn deserialize_struct(field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {
163 let fields = field
164 .children()?
165 .ok_or_else(|| polars_err!(oos = "IPC: Struct must contain children"))?;
166 let (fields, ipc_fields) = try_unzip_vec(fields.iter().map(|field| {
167 let (field, fields) = deserialize_field(field?)?;
168 Ok((field, fields))
169 }))?;
170 let ipc_field = IpcField {
171 fields: ipc_fields,
172 dictionary_id: None,
173 };
174 Ok((ArrowDataType::Struct(fields), ipc_field))
175}
176
177fn deserialize_list(field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {
178 let children = field
179 .children()?
180 .ok_or_else(|| polars_err!(oos = "IPC: List must contain children"))?;
181 let inner = children
182 .get(0)
183 .ok_or_else(|| polars_err!(oos = "IPC: List must contain one child"))??;
184 let (field, ipc_field) = deserialize_field(inner)?;
185
186 Ok((
187 ArrowDataType::List(Box::new(field)),
188 IpcField {
189 fields: vec![ipc_field],
190 dictionary_id: None,
191 },
192 ))
193}
194
195fn deserialize_large_list(field: FieldRef) -> PolarsResult<(ArrowDataType, IpcField)> {
196 let children = field
197 .children()?
198 .ok_or_else(|| polars_err!(oos = "IPC: List must contain children"))?;
199 let inner = children
200 .get(0)
201 .ok_or_else(|| polars_err!(oos = "IPC: List must contain one child"))??;
202 let (field, ipc_field) = deserialize_field(inner)?;
203
204 Ok((
205 ArrowDataType::LargeList(Box::new(field)),
206 IpcField {
207 fields: vec![ipc_field],
208 dictionary_id: None,
209 },
210 ))
211}
212
213fn deserialize_fixed_size_list(
214 list: FixedSizeListRef,
215 field: FieldRef,
216) -> PolarsResult<(ArrowDataType, IpcField)> {
217 let children = field
218 .children()?
219 .ok_or_else(|| polars_err!(oos = "IPC: FixedSizeList must contain children"))?;
220 let inner = children
221 .get(0)
222 .ok_or_else(|| polars_err!(oos = "IPC: FixedSizeList must contain one child"))??;
223 let (field, ipc_field) = deserialize_field(inner)?;
224
225 let size = list
226 .list_size()?
227 .try_into()
228 .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
229
230 Ok((
231 ArrowDataType::FixedSizeList(Box::new(field), size),
232 IpcField {
233 fields: vec![ipc_field],
234 dictionary_id: None,
235 },
236 ))
237}
238
239fn get_dtype(
241 field: arrow_format::ipc::FieldRef,
242 extension: Extension,
243 may_be_dictionary: bool,
244) -> PolarsResult<(ArrowDataType, IpcField)> {
245 if let Some(dictionary) = field.dictionary()? {
246 if may_be_dictionary {
247 let int = dictionary
248 .index_type()?
249 .ok_or_else(|| polars_err!(oos = "indexType is mandatory in Dictionary."))?;
250 let index_type = deserialize_integer(int)?;
251 let (inner, mut ipc_field) = get_dtype(field, extension, false)?;
252 ipc_field.dictionary_id = Some(dictionary.id()?);
253 return Ok((
254 ArrowDataType::Dictionary(index_type, Box::new(inner), dictionary.is_ordered()?),
255 ipc_field,
256 ));
257 }
258 }
259
260 if let Some(extension) = extension {
261 let (name, metadata) = extension;
262 let (dtype, fields) = get_dtype(field, None, false)?;
263 return Ok((
264 ArrowDataType::Extension(Box::new(ExtensionType {
265 name,
266 inner: dtype,
267 metadata,
268 })),
269 fields,
270 ));
271 }
272
273 let type_ = field
274 .type_()?
275 .ok_or_else(|| polars_err!(oos = "IPC: field type is mandatory"))?;
276
277 use arrow_format::ipc::TypeRef::*;
278 Ok(match type_ {
279 Null(_) => (ArrowDataType::Null, IpcField::default()),
280 Bool(_) => (ArrowDataType::Boolean, IpcField::default()),
281 Int(int) => {
282 let dtype = deserialize_integer(int)?.into();
283 (dtype, IpcField::default())
284 },
285 Binary(_) => (ArrowDataType::Binary, IpcField::default()),
286 LargeBinary(_) => (ArrowDataType::LargeBinary, IpcField::default()),
287 Utf8(_) => (ArrowDataType::Utf8, IpcField::default()),
288 LargeUtf8(_) => (ArrowDataType::LargeUtf8, IpcField::default()),
289 BinaryView(_) => (ArrowDataType::BinaryView, IpcField::default()),
290 Utf8View(_) => (ArrowDataType::Utf8View, IpcField::default()),
291 FixedSizeBinary(fixed) => (
292 ArrowDataType::FixedSizeBinary(
293 fixed
294 .byte_width()?
295 .try_into()
296 .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?,
297 ),
298 IpcField::default(),
299 ),
300 FloatingPoint(float) => {
301 let dtype = match float.precision()? {
302 arrow_format::ipc::Precision::Half => ArrowDataType::Float16,
303 arrow_format::ipc::Precision::Single => ArrowDataType::Float32,
304 arrow_format::ipc::Precision::Double => ArrowDataType::Float64,
305 };
306 (dtype, IpcField::default())
307 },
308 Date(date) => {
309 let dtype = match date.unit()? {
310 arrow_format::ipc::DateUnit::Day => ArrowDataType::Date32,
311 arrow_format::ipc::DateUnit::Millisecond => ArrowDataType::Date64,
312 };
313 (dtype, IpcField::default())
314 },
315 Time(time) => deserialize_time(time)?,
316 Timestamp(timestamp) => deserialize_timestamp(timestamp)?,
317 Interval(interval) => {
318 let dtype = match interval.unit()? {
319 arrow_format::ipc::IntervalUnit::YearMonth => {
320 ArrowDataType::Interval(IntervalUnit::YearMonth)
321 },
322 arrow_format::ipc::IntervalUnit::DayTime => {
323 ArrowDataType::Interval(IntervalUnit::DayTime)
324 },
325 arrow_format::ipc::IntervalUnit::MonthDayNano => {
326 ArrowDataType::Interval(IntervalUnit::MonthDayNano)
327 },
328 };
329 (dtype, IpcField::default())
330 },
331 Duration(duration) => {
332 let time_unit = deserialize_timeunit(duration.unit()?)?;
333 (ArrowDataType::Duration(time_unit), IpcField::default())
334 },
335 Decimal(decimal) => {
336 let bit_width: usize = decimal
337 .bit_width()?
338 .try_into()
339 .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
340 let precision: usize = decimal
341 .precision()?
342 .try_into()
343 .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
344 let scale: usize = decimal
345 .scale()?
346 .try_into()
347 .map_err(|_| polars_err!(oos = OutOfSpecKind::NegativeFooterLength))?;
348
349 let dtype = match bit_width {
350 32 => ArrowDataType::Decimal32(precision, scale),
351 64 => ArrowDataType::Decimal64(precision, scale),
352 128 => ArrowDataType::Decimal(precision, scale),
353 256 => ArrowDataType::Decimal256(precision, scale),
354 _ => return Err(polars_err!(oos = OutOfSpecKind::NegativeFooterLength)),
355 };
356
357 (dtype, IpcField::default())
358 },
359 List(_) => deserialize_list(field)?,
360 LargeList(_) => deserialize_large_list(field)?,
361 FixedSizeList(list) => deserialize_fixed_size_list(list, field)?,
362 Struct(_) => deserialize_struct(field)?,
363 Union(union_) => deserialize_union(union_, field)?,
364 Map(map) => deserialize_map(map, field)?,
365 RunEndEncoded(_) => todo!(),
366 LargeListView(_) | ListView(_) => todo!(),
367 })
368}
369
370pub fn deserialize_schema(
372 message: &[u8],
373) -> PolarsResult<(ArrowSchema, IpcSchema, Option<Metadata>)> {
374 let message = arrow_format::ipc::MessageRef::read_as_root(message)
375 .map_err(|err| polars_err!(oos = format!("Unable deserialize message: {err:?}")))?;
376
377 let schema = match message
378 .header()?
379 .ok_or_else(|| polars_err!(oos = "Unable to convert header to a schema".to_string()))?
380 {
381 arrow_format::ipc::MessageHeaderRef::Schema(schema) => PolarsResult::Ok(schema),
382 _ => polars_bail!(ComputeError: "The message is expected to be a Schema message"),
383 }?;
384
385 fb_to_schema(schema)
386}
387
388pub(super) fn fb_to_schema(
390 schema: arrow_format::ipc::SchemaRef,
391) -> PolarsResult<(ArrowSchema, IpcSchema, Option<Metadata>)> {
392 let fields = schema
393 .fields()?
394 .ok_or_else(|| polars_err!(oos = OutOfSpecKind::MissingFields))?;
395
396 let mut arrow_schema = ArrowSchema::with_capacity(fields.len());
397 let mut ipc_fields = Vec::with_capacity(fields.len());
398
399 for field in fields {
400 let (field, ipc_field) = deserialize_field(field?)?;
401 arrow_schema.insert(field.name.clone(), field);
402 ipc_fields.push(ipc_field);
403 }
404
405 let is_little_endian = match schema.endianness()? {
406 arrow_format::ipc::Endianness::Little => true,
407 arrow_format::ipc::Endianness::Big => false,
408 };
409
410 let custom_schema_metadata = match schema.custom_metadata()? {
411 None => None,
412 Some(metadata) => {
413 let metadata: Metadata = metadata
414 .into_iter()
415 .filter_map(|kv_result| {
416 let kv_ref = kv_result.ok()?;
418 Some((kv_ref.key().ok()??.into(), kv_ref.value().ok()??.into()))
419 })
420 .collect();
421
422 if metadata.is_empty() {
423 None
424 } else {
425 Some(metadata)
426 }
427 },
428 };
429
430 Ok((
431 arrow_schema,
432 IpcSchema {
433 fields: ipc_fields,
434 is_little_endian,
435 },
436 custom_schema_metadata,
437 ))
438}
439
440pub(super) fn deserialize_stream_metadata(meta: &[u8]) -> PolarsResult<StreamMetadata> {
441 let message = arrow_format::ipc::MessageRef::read_as_root(meta)
442 .map_err(|err| polars_err!(oos = format!("Unable to get root as message: {err:?}")))?;
443 let version = message.version()?;
444 let header = message
446 .header()?
447 .ok_or_else(|| polars_err!(oos = "Unable to read the first IPC message"))?;
448 let schema = if let arrow_format::ipc::MessageHeaderRef::Schema(schema) = header {
449 schema
450 } else {
451 polars_bail!(oos = "The first IPC message of the stream must be a schema")
452 };
453 let (schema, ipc_schema, custom_schema_metadata) = fb_to_schema(schema)?;
454
455 Ok(StreamMetadata {
456 schema,
457 version,
458 ipc_schema,
459 custom_schema_metadata,
460 })
461}