// spl_pod/list/list_view_mut.rs

//! `ListViewMut`, a mutable, compact, zero-copy array wrapper.
2
3use {
4    crate::{
5        error::PodSliceError, list::list_trait::List, pod_length::PodLength, primitives::PodU32,
6    },
7    bytemuck::Pod,
8    solana_program_error::ProgramError,
9    std::ops::{Deref, DerefMut},
10};
11
12#[derive(Debug)]
13pub struct ListViewMut<'data, T: Pod, L: PodLength = PodU32> {
14    pub(crate) length: &'data mut L,
15    pub(crate) data: &'data mut [T],
16    pub(crate) capacity: usize,
17}
18
19impl<T: Pod, L: PodLength> ListViewMut<'_, T, L> {
20    /// Add another item to the slice
21    pub fn push(&mut self, item: T) -> Result<(), ProgramError> {
22        let length = (*self.length).into();
23        if length >= self.capacity {
24            Err(PodSliceError::BufferTooSmall.into())
25        } else {
26            self.data[length] = item;
27            *self.length = L::try_from(length.saturating_add(1))?;
28            Ok(())
29        }
30    }
31
32    /// Remove and return the element at `index`, shifting all later
33    /// elements one position to the left.
34    pub fn remove(&mut self, index: usize) -> Result<T, ProgramError> {
35        let len = (*self.length).into();
36        if index >= len {
37            return Err(ProgramError::InvalidArgument);
38        }
39
40        let removed_item = self.data[index];
41
42        // Move the tail left by one
43        let tail_start = index
44            .checked_add(1)
45            .ok_or(ProgramError::ArithmeticOverflow)?;
46        self.data.copy_within(tail_start..len, index);
47
48        // Store the new length (len - 1)
49        let new_len = len.checked_sub(1).unwrap();
50        *self.length = L::try_from(new_len)?;
51
52        Ok(removed_item)
53    }
54}
55
56impl<T: Pod, L: PodLength> Deref for ListViewMut<'_, T, L> {
57    type Target = [T];
58
59    fn deref(&self) -> &Self::Target {
60        let len = (*self.length).into();
61        &self.data[..len]
62    }
63}
64
65impl<T: Pod, L: PodLength> DerefMut for ListViewMut<'_, T, L> {
66    fn deref_mut(&mut self) -> &mut Self::Target {
67        let len = (*self.length).into();
68        &mut self.data[..len]
69    }
70}
71
72impl<T: Pod, L: PodLength> List for ListViewMut<'_, T, L> {
73    type Item = T;
74    type Length = L;
75
76    fn capacity(&self) -> usize {
77        self.capacity
78    }
79}
80
#[cfg(test)]
mod tests {
    use {
        super::*,
        crate::{
            list::{List, ListView},
            primitives::{PodU16, PodU32, PodU64},
        },
        bytemuck_derive::{Pod, Zeroable},
    };

    #[repr(C)]
    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Pod, Zeroable)]
    struct TestStruct {
        a: u64,
        b: u32,
        _padding: [u8; 4],
    }

    impl TestStruct {
        fn new(a: u64, b: u32) -> Self {
            Self {
                a,
                b,
                _padding: [0; 4],
            }
        }
    }

    /// Resize `buffer` to hold exactly `capacity` elements (plus the length
    /// header) and initialize an empty `ListViewMut` over it.
    fn init_view_mut<T: Pod, L: PodLength>(
        buffer: &mut Vec<u8>,
        capacity: usize,
    ) -> ListViewMut<'_, T, L> {
        let size = ListView::<T, L>::size_of(capacity).unwrap();
        buffer.resize(size, 0);
        ListView::<T, L>::init(buffer).unwrap()
    }

    #[test]
    fn test_push() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 3);

        assert_eq!(view.len(), 0);
        assert!(view.is_empty());
        assert_eq!(view.capacity(), 3);

        // Push first item
        let item1 = TestStruct::new(1, 10);
        view.push(item1).unwrap();
        assert_eq!(view.len(), 1);
        assert!(!view.is_empty());
        assert_eq!(*view, [item1]);

        // Push second item
        let item2 = TestStruct::new(2, 20);
        view.push(item2).unwrap();
        assert_eq!(view.len(), 2);
        assert_eq!(*view, [item1, item2]);

        // Push third item to fill capacity
        let item3 = TestStruct::new(3, 30);
        view.push(item3).unwrap();
        assert_eq!(view.len(), 3);
        assert_eq!(*view, [item1, item2, item3]);

        // Try to push beyond capacity
        let item4 = TestStruct::new(4, 40);
        let err = view.push(item4).unwrap_err();
        assert_eq!(err, PodSliceError::BufferTooSmall.into());

        // Ensure state is unchanged
        assert_eq!(view.len(), 3);
        assert_eq!(*view, [item1, item2, item3]);
    }

    #[test]
    fn test_remove() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 4);

        let item1 = TestStruct::new(1, 10);
        let item2 = TestStruct::new(2, 20);
        let item3 = TestStruct::new(3, 30);
        let item4 = TestStruct::new(4, 40);
        view.push(item1).unwrap();
        view.push(item2).unwrap();
        view.push(item3).unwrap();
        view.push(item4).unwrap();

        assert_eq!(view.len(), 4);
        assert_eq!(*view, [item1, item2, item3, item4]);

        // Remove from the middle
        let removed = view.remove(1).unwrap();
        assert_eq!(removed, item2);
        assert_eq!(view.len(), 3);
        assert_eq!(*view, [item1, item3, item4]);

        // Remove from the end
        let removed = view.remove(2).unwrap();
        assert_eq!(removed, item4);
        assert_eq!(view.len(), 2);
        assert_eq!(*view, [item1, item3]);

        // Remove from the start
        let removed = view.remove(0).unwrap();
        assert_eq!(removed, item1);
        assert_eq!(view.len(), 1);
        assert_eq!(*view, [item3]);

        // Remove the last element
        let removed = view.remove(0).unwrap();
        assert_eq!(removed, item3);
        assert_eq!(view.len(), 0);
        assert!(view.is_empty());
        assert_eq!(*view, []);
    }

    #[test]
    fn test_remove_out_of_bounds() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 2);

        view.push(TestStruct::new(1, 10)).unwrap();
        view.push(TestStruct::new(2, 20)).unwrap();

        // Try to remove at index == len
        let err = view.remove(2).unwrap_err();
        assert_eq!(err, ProgramError::InvalidArgument);
        assert_eq!(view.len(), 2); // Unchanged

        // Try to remove at index > len
        let err = view.remove(100).unwrap_err();
        assert_eq!(err, ProgramError::InvalidArgument);
        assert_eq!(view.len(), 2); // Unchanged

        // Empty the view
        view.remove(1).unwrap();
        view.remove(0).unwrap();
        assert!(view.is_empty());

        // Try to remove from empty view
        let err = view.remove(0).unwrap_err();
        assert_eq!(err, ProgramError::InvalidArgument);
    }

    #[test]
    fn test_iter_mut() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 4);

        let item1 = TestStruct::new(1, 10);
        let item2 = TestStruct::new(2, 20);
        let item3 = TestStruct::new(3, 30);
        view.push(item1).unwrap();
        view.push(item2).unwrap();
        view.push(item3).unwrap();

        assert_eq!(view.len(), 3);
        assert_eq!(view.capacity(), 4);

        // Modify items using iter_mut
        for item in view.iter_mut() {
            item.a *= 10;
        }

        let expected_item1 = TestStruct::new(10, 10);
        let expected_item2 = TestStruct::new(20, 20);
        let expected_item3 = TestStruct::new(30, 30);

        // Check that the underlying data is modified
        assert_eq!(view.len(), 3);
        assert_eq!(*view, [expected_item1, expected_item2, expected_item3]);

        // Check that iter_mut only iterates over `len` items, not `capacity`
        assert_eq!(view.iter_mut().count(), 3);
    }

    #[test]
    fn test_iter_mut_empty() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU64>(&mut buffer, 5);

        let mut count = 0;
        for _ in view.iter_mut() {
            count += 1;
        }
        assert_eq!(count, 0);
        assert_eq!(view.iter_mut().next(), None);
    }

    #[test]
    fn test_zero_capacity() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 0);

        assert_eq!(view.len(), 0);
        assert_eq!(view.capacity(), 0);
        assert!(view.is_empty());

        let err = view.push(TestStruct::new(1, 1)).unwrap_err();
        assert_eq!(err, PodSliceError::BufferTooSmall.into());

        let err = view.remove(0).unwrap_err();
        assert_eq!(err, ProgramError::InvalidArgument);
    }

    #[test]
    fn test_default_length_type() {
        let capacity = 2;
        let mut buffer = vec![];
        // Size the buffer for the *default* length type (`PodU32`), so the
        // test does not silently rely on header padding to line up.
        let size = ListView::<TestStruct>::size_of(capacity).unwrap();
        buffer.resize(size, 0);

        // Initialize the view *without* specifying L. The compiler uses the default.
        let view = ListView::<TestStruct>::init(&mut buffer).unwrap();

        // Check that the capacity is correct for the default `PodU32` length.
        assert_eq!(view.capacity(), capacity);
        assert_eq!(view.len(), 0);

        // Verify the size of the length field.
        assert_eq!(size_of_val(view.length), size_of::<PodU32>());
    }

    #[test]
    fn test_bytes_used_and_allocated_mut() {
        // capacity 3, start empty
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU16>(&mut buffer, 3);

        // Empty view: expected sizes use the same length type as the view.
        assert_eq!(
            view.bytes_used().unwrap(),
            ListView::<TestStruct, PodU16>::size_of(0).unwrap()
        );
        assert_eq!(
            view.bytes_allocated().unwrap(),
            ListView::<TestStruct, PodU16>::size_of(view.capacity()).unwrap()
        );

        // After pushing elements
        view.push(TestStruct::new(1, 2)).unwrap();
        view.push(TestStruct::new(3, 4)).unwrap();
        view.push(TestStruct::new(5, 6)).unwrap();
        assert_eq!(
            view.bytes_used().unwrap(),
            ListView::<TestStruct, PodU16>::size_of(3).unwrap()
        );
        assert_eq!(
            view.bytes_allocated().unwrap(),
            ListView::<TestStruct, PodU16>::size_of(view.capacity()).unwrap()
        );
    }

    #[test]
    fn test_get_and_get_mut() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 3);

        let item0 = TestStruct::new(1, 10);
        let item1 = TestStruct::new(2, 20);
        view.push(item0).unwrap();
        view.push(item1).unwrap();

        // Test get()
        assert_eq!(view.first(), Some(&item0));
        assert_eq!(view.get(0), Some(&item0));
        assert_eq!(view.get(1), Some(&item1));
        assert_eq!(view.get(2), None); // out of bounds
        assert_eq!(view.get(100), None); // way out of bounds

        // Test get_mut() to modify an item
        let modified_item0 = TestStruct::new(111, 110);
        let item_ref = view.get_mut(0).unwrap();
        *item_ref = modified_item0;

        // Verify the modification
        assert_eq!(view.first(), Some(&modified_item0));
        assert_eq!(*view, [modified_item0, item1]);

        // Test get_mut() out of bounds
        assert_eq!(view.get_mut(2), None);
    }

    #[test]
    fn test_mutable_access_via_indexing() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 3);

        let item0 = TestStruct::new(1, 10);
        let item1 = TestStruct::new(2, 20);
        view.push(item0).unwrap();
        view.push(item1).unwrap();

        assert_eq!(view.len(), 2);

        // Modify via the mutable slice
        view[0].a = 99;

        let expected_item0 = TestStruct::new(99, 10);
        assert_eq!(view.first(), Some(&expected_item0));
        assert_eq!(*view, [expected_item0, item1]);
    }

    #[test]
    fn test_sort_by() {
        let mut buffer = vec![];
        let mut view = init_view_mut::<TestStruct, PodU32>(&mut buffer, 5);

        let item0 = TestStruct::new(5, 1);
        let item1 = TestStruct::new(2, 2);
        let item2 = TestStruct::new(5, 3);
        let item3 = TestStruct::new(1, 4);
        let item4 = TestStruct::new(2, 5);

        view.push(item0).unwrap();
        view.push(item1).unwrap();
        view.push(item2).unwrap();
        view.push(item3).unwrap();
        view.push(item4).unwrap();

        // Sort by `b` field in descending order.
        view.sort_by(|a, b| b.b.cmp(&a.b));
        let expected_order_by_b_desc = [
            item4, // b: 5
            item3, // b: 4
            item2, // b: 3
            item1, // b: 2
            item0, // b: 1
        ];
        assert_eq!(*view, expected_order_by_b_desc);

        // Now, sort by `a` in ascending order. A stable sort preserves the relative
        // order of equal elements from the previous state of the list.
        view.sort_by(|x, y| x.a.cmp(&y.a));

        let expected_order_by_a_stable = [
            item3, // a: 1
            item4, // a: 2 (was before item1 in the previous state)
            item1, // a: 2
            item2, // a: 5 (was before item0 in the previous state)
            item0, // a: 5
        ];
        assert_eq!(*view, expected_order_by_a_stable);
    }
}