solana_borsh/
macros.rs

1//! Macros for implementing functions across multiple versions of Borsh
2
macro_rules! impl_get_packed_len_v1 {
    ($borsh:ident $(,#[$meta:meta])?) => {
        /// Get the worst-case packed length for the given BorshSchema
        ///
        /// Note: due to the serializer currently used by Borsh, this function cannot
        /// be used on-chain in the Solana SBF execution environment.
        ///
        /// # Panics
        ///
        /// Panics if the schema contains a sequence whose `length_width` is
        /// non-zero, if a declaration has no definition and is not one of the
        /// built-in primitive names, or if the size of a fixed-size sequence does
        /// not fit in a `usize`.
        $(#[$meta])?
        pub fn get_packed_len<S: $borsh::BorshSchema>() -> usize {
            let container = $borsh::schema_container_of::<S>();
            get_declaration_packed_len(container.declaration(), &container)
        }

        /// Get packed length for the given BorshSchema Declaration
        ///
        /// Looks up `declaration` in `container` and recursively adds up the
        /// sizes of everything it is made of. Declarations without a definition
        /// fall back to a table of well-known primitive names.
        fn get_declaration_packed_len(
            declaration: &str,
            container: &$borsh::schema::BorshSchemaContainer,
        ) -> usize {
            match container.get_definition(declaration) {
                Some($borsh::schema::Definition::Primitive(size)) => usize::from(*size),
                // No length prefix is written (`length_width == 0`), so the
                // worst case is the upper bound of the range times the element size.
                Some($borsh::schema::Definition::Sequence { length_width, length_range, elements }) if *length_width == 0 => {
                    let max_elements = *length_range.end();
                    // Reject instead of silently truncating on targets where
                    // `usize` is narrower than `u64`.
                    assert!(
                        max_elements <= usize::MAX as u64,
                        "Sequence length {} does not fit in usize",
                        max_elements
                    );
                    // A wrapped product would under-report the worst case, so
                    // overflow is a hard error in every build profile.
                    (max_elements as usize)
                        .checked_mul(get_declaration_packed_len(elements, container))
                        .expect("Packed length of sequence overflows usize")
                }
                // Any other sequence has a length prefix; its size cannot be
                // derived here.
                Some($borsh::schema::Definition::Sequence {
                    ..
                }) => panic!("Missing support for Definition::Sequence"),
                Some($borsh::schema::Definition::Tuple { elements }) => elements
                    .iter()
                    .map(|element| get_declaration_packed_len(element, container))
                    .sum(),
                // Tag bytes plus the largest variant payload.
                Some($borsh::schema::Definition::Enum { tag_width, variants }) => {
                    usize::from(*tag_width) + variants
                        .iter()
                        .map(|(_, _, declaration)| get_declaration_packed_len(declaration, container))
                        .max()
                        .unwrap_or(0)
                }
                Some($borsh::schema::Definition::Struct { fields }) => match fields {
                    $borsh::schema::Fields::NamedFields(named_fields) => named_fields
                        .iter()
                        .map(|(_, declaration)| get_declaration_packed_len(declaration, container))
                        .sum(),
                    $borsh::schema::Fields::UnnamedFields(declarations) => declarations
                        .iter()
                        .map(|declaration| get_declaration_packed_len(declaration, container))
                        .sum(),
                    $borsh::schema::Fields::Empty => 0,
                },
                // No definition registered: fall back to the primitive name.
                None => match declaration {
                    "bool" | "u8" | "i8" => 1,
                    "u16" | "i16" => 2,
                    "u32" | "i32" => 4,
                    "u64" | "i64" => 8,
                    "u128" | "i128" => 16,
                    "nil" => 0,
                    _ => panic!("Missing primitive type: {declaration}"),
                },
            }
        }
    }
}
63pub(crate) use impl_get_packed_len_v1;
64
macro_rules! impl_try_from_slice_unchecked {
    ($borsh:ident, $borsh_io:ident $(,#[$meta:meta])?) => {
        /// Deserializes without checking that the entire slice has been consumed
        ///
        /// Normally, `try_from_slice` checks the length of the final slice to ensure
        /// that the deserialization uses up all of the bytes in the slice.
        ///
        /// Note that there is a potential issue with this function. Any buffer greater than
        /// or equal to the expected size will properly deserialize. For example, if the
        /// user passes a buffer destined for a different type, the error won't get caught
        /// as easily.
        ///
        /// # Errors
        ///
        /// Returns the error produced by `T::deserialize` when `data` cannot be
        /// decoded as a `T`.
        $(#[$meta])?
        pub fn try_from_slice_unchecked<T: $borsh::BorshDeserialize>(data: &[u8]) -> Result<T, $borsh_io::Error> {
            // `deserialize` takes a mutable reference to the slice, so hand it a
            // local copy of the reference; whatever is left over afterwards is
            // deliberately ignored.
            let mut remaining = data;
            T::deserialize(&mut remaining)
        }
    }
}
84pub(crate) use impl_try_from_slice_unchecked;
85
macro_rules! impl_get_instance_packed_len {
    ($borsh:ident, $borsh_io:ident $(,#[$meta:meta])?) => {
        /// Writer that discards everything written to it and only keeps a
        /// running total of the number of bytes it was given.
        #[derive(Default)]
        struct WriteCounter {
            count: usize,
        }

        impl $borsh_io::Write for WriteCounter {
            /// Accepts the whole buffer and adds its length to the total.
            fn write(&mut self, data: &[u8]) -> Result<usize, $borsh_io::Error> {
                self.count += data.len();
                Ok(data.len())
            }

            /// Nothing is buffered, so there is nothing to flush.
            fn flush(&mut self) -> Result<(), $borsh_io::Error> {
                Ok(())
            }
        }

        /// Get the packed length for the serialized form of this object instance.
        ///
        /// Serializes `instance` into a byte-counting writer and returns the
        /// number of bytes written.
        ///
        /// Useful when working with instances of types that contain a variable-length
        /// sequence, such as a Vec or HashMap.  Since it is impossible to know the packed
        /// length only from the type's schema, this can be used when an instance already
        /// exists, to figure out how much space to allocate in an account.
        ///
        /// # Errors
        ///
        /// Returns any error reported by the instance's `serialize` implementation.
        $(#[$meta])?
        pub fn get_instance_packed_len<T: $borsh::BorshSerialize>(instance: &T) -> Result<usize, $borsh_io::Error> {
            let mut sink = WriteCounter { count: 0 };
            instance.serialize(&mut sink)?;
            Ok(sink.count)
        }
    }
}
120pub(crate) use impl_get_instance_packed_len;
121
#[cfg(test)]
macro_rules! impl_tests {
    ($borsh:ident, $borsh_io:ident) => {
        extern crate alloc;
        use {
            super::*,
            std::{collections::HashMap, mem::size_of},
            $borsh::{BorshDeserialize, BorshSerialize},
            $borsh_io::ErrorKind,
        };

        type Child = [u8; 64];
        type Parent = Vec<Child>;

        /// An exactly-sized buffer is accepted by both deserializers; a buffer
        /// with trailing bytes is accepted only by the unchecked one.
        #[test]
        fn unchecked_deserialization() {
            let parent = vec![[0u8; 64], [1u8; 64], [2u8; 64]];
            let child_len = get_packed_len::<Child>();

            // exact size, both work
            let mut exact = vec![0u8; 4 + child_len * 3];
            parent.serialize(&mut exact.as_mut_slice()).unwrap();
            assert_eq!(Parent::try_from_slice(&exact).unwrap(), parent);
            assert_eq!(try_from_slice_unchecked::<Parent>(&exact).unwrap(), parent);

            // too big, only unchecked works
            let mut oversized = vec![0u8; 4 + child_len * 10];
            parent.serialize(&mut oversized.as_mut_slice()).unwrap();
            let strict_err = Parent::try_from_slice(&oversized).unwrap_err();
            assert_eq!(strict_err.kind(), ErrorKind::InvalidData);
            assert_eq!(
                try_from_slice_unchecked::<Parent>(&oversized).unwrap(),
                parent
            );
        }

        /// Schema-derived lengths of fixed-size types match their in-memory size.
        #[test]
        fn packed_len() {
            let u64_len = get_packed_len::<u64>();
            assert_eq!(u64_len, size_of::<u64>());

            let child_len = get_packed_len::<Child>();
            assert_eq!(child_len, size_of::<u8>() * 64);
        }

        /// For fixed-size types the schema-derived length and the length
        /// measured from a concrete instance agree.
        #[test]
        fn instance_packed_len_matches_packed_len() {
            let child = [0u8; 64];
            assert_eq!(
                get_packed_len::<Child>(),
                get_instance_packed_len(&child).unwrap(),
            );

            // unsigned integers of every width
            assert_eq!(get_packed_len::<u8>(), get_instance_packed_len(&0u8).unwrap());
            assert_eq!(get_packed_len::<u16>(), get_instance_packed_len(&0u16).unwrap());
            assert_eq!(get_packed_len::<u32>(), get_instance_packed_len(&0u32).unwrap());
            assert_eq!(get_packed_len::<u64>(), get_instance_packed_len(&0u64).unwrap());
            assert_eq!(get_packed_len::<u128>(), get_instance_packed_len(&0u128).unwrap());

            let small_array = [0u8; 10];
            assert_eq!(
                get_packed_len::<[u8; 10]>(),
                get_instance_packed_len(&small_array).unwrap(),
            );

            let signed_tuple = (i8::MAX, i16::MAX, i32::MAX, i64::MAX, i128::MAX);
            assert_eq!(
                get_packed_len::<(i8, i16, i32, i64, i128)>(),
                get_instance_packed_len(&signed_tuple).unwrap(),
            );
        }

        /// A vector's instance length is 4 bytes plus one `Child` per element.
        #[test]
        fn instance_packed_len_with_vec() {
            // Six children filled with 0, 1, ..., 5 respectively.
            let parent: Parent = (0u8..6).map(|fill| [fill; 64]).collect();
            let expected = 4 + parent.len() * get_packed_len::<Child>();
            assert_eq!(get_instance_packed_len(&parent).unwrap(), expected);
        }

        /// A map's instance length is 4 bytes plus the instance length of every
        /// key and value it holds.
        #[test]
        fn instance_packed_len_with_varying_sizes_in_hashmap() {
            let entries = [
                ("the first string, it's actually really really long", ""),
                ("second string, shorter", "a real value"),
                ("third", "an even longer value"),
            ];

            let data: HashMap<String, String> = entries
                .iter()
                .map(|(key, value)| (key.to_string(), value.to_string()))
                .collect();

            let expected = 4 + entries
                .iter()
                .map(|(key, value)| {
                    get_instance_packed_len(&key.to_string()).unwrap()
                        + get_instance_packed_len(&value.to_string()).unwrap()
                })
                .sum::<usize>();

            assert_eq!(get_instance_packed_len(&data).unwrap(), expected);
        }
    };
}
238#[cfg(test)]
239pub(crate) use impl_tests;