spl_token_2022_interface/extension/transfer_fee/
mod.rs

1#[cfg(feature = "serde")]
2use serde::{Deserialize, Serialize};
3use {
4    crate::{
5        error::TokenError,
6        extension::{Extension, ExtensionType},
7    },
8    bytemuck::{Pod, Zeroable},
9    solana_program_error::ProgramResult,
10    spl_pod::{
11        optional_keys::OptionalNonZeroPubkey,
12        primitives::{PodU16, PodU64},
13    },
14    std::{
15        cmp,
16        convert::{TryFrom, TryInto},
17    },
18};
19
20/// Transfer fee extension instructions
21pub mod instruction;
22
/// Maximum possible fee in basis points is `100%`, aka 10,000 basis points
pub const MAX_FEE_BASIS_POINTS: u16 = 10_000;
/// `MAX_FEE_BASIS_POINTS` widened to `u128`: the basis-point value of `100%`,
/// used as the scaling factor in this module's checked `u128` fee arithmetic.
const ONE_IN_BASIS_POINTS: u128 = MAX_FEE_BASIS_POINTS as u128;
26
/// Transfer fee information
///
/// A single fee schedule: a rate in basis points, a cap on the fee expressed
/// in tokens, and the first epoch in which the schedule applies. The struct
/// is `#[repr(C)]` and derives `Pod`/`Zeroable`, so the field order below is
/// part of its byte layout.
#[repr(C)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "camelCase"))]
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
pub struct TransferFee {
    /// First epoch where the transfer fee takes effect
    pub epoch: PodU64, // Epoch,
    /// Maximum fee assessed on transfers, expressed as an amount of tokens
    pub maximum_fee: PodU64,
    /// Amount of transfer collected as fees, expressed as basis points of the
    /// transfer amount (increments of `0.01%`)
    pub transfer_fee_basis_points: PodU16,
}
41impl TransferFee {
42    /// Calculate ceiling-division
43    ///
44    /// Ceiling-division
45    ///     `ceil[ numerator / denominator ]`
46    /// can be represented as a floor-division
47    ///     `floor[ (numerator + denominator - 1) / denominator]`
48    fn ceil_div(numerator: u128, denominator: u128) -> Option<u128> {
49        numerator
50            .checked_add(denominator)?
51            .checked_sub(1)?
52            .checked_div(denominator)
53    }
54
55    /// Calculate the transfer fee
56    pub fn calculate_fee(&self, pre_fee_amount: u64) -> Option<u64> {
57        let transfer_fee_basis_points = u16::from(self.transfer_fee_basis_points) as u128;
58        if transfer_fee_basis_points == 0 || pre_fee_amount == 0 {
59            Some(0)
60        } else {
61            let numerator = (pre_fee_amount as u128).checked_mul(transfer_fee_basis_points)?;
62            let raw_fee = Self::ceil_div(numerator, ONE_IN_BASIS_POINTS)?
63                .try_into() // guaranteed to be okay
64                .ok()?;
65
66            Some(cmp::min(raw_fee, u64::from(self.maximum_fee)))
67        }
68    }
69
70    /// Calculate the gross transfer amount after deducting fees
71    pub fn calculate_post_fee_amount(&self, pre_fee_amount: u64) -> Option<u64> {
72        pre_fee_amount.checked_sub(self.calculate_fee(pre_fee_amount)?)
73    }
74
75    /// Calculate the transfer amount that will result in a specified net
76    /// transfer amount.
77    ///
78    /// The original transfer amount may not always be unique due to rounding.
79    /// In this case, the smaller amount will be chosen.
80    /// e.g. Both transfer amount 10, 11 with `10%` fee rate results in net
81    /// transfer amount of 9. In this case, 10 will be chosen.
82    /// e.g. Fee rate is `100%`. In this case, 0 will be chosen.
83    ///
84    /// The original transfer amount may not always exist on large net transfer
85    /// amounts due to overflow. In this case, `None` is returned.
86    /// e.g. The net fee amount is `u64::MAX` with a positive fee rate.
87    pub fn calculate_pre_fee_amount(&self, post_fee_amount: u64) -> Option<u64> {
88        let maximum_fee = u64::from(self.maximum_fee);
89        let transfer_fee_basis_points = u16::from(self.transfer_fee_basis_points) as u128;
90        match (transfer_fee_basis_points, post_fee_amount) {
91            // no fee, same amount
92            (0, _) => Some(post_fee_amount),
93            // 0 zero out, 0 in
94            (_, 0) => Some(0),
95            // 100%, cap at max fee
96            (ONE_IN_BASIS_POINTS, _) => maximum_fee.checked_add(post_fee_amount),
97            _ => {
98                let numerator = (post_fee_amount as u128).checked_mul(ONE_IN_BASIS_POINTS)?;
99                let denominator = ONE_IN_BASIS_POINTS.checked_sub(transfer_fee_basis_points)?;
100                let raw_pre_fee_amount = Self::ceil_div(numerator, denominator)?;
101
102                if raw_pre_fee_amount.checked_sub(post_fee_amount as u128)? >= maximum_fee as u128 {
103                    post_fee_amount.checked_add(maximum_fee)
104                } else {
105                    // should return `None` if `pre_fee_amount` overflows
106                    u64::try_from(raw_pre_fee_amount).ok()
107                }
108            }
109        }
110    }
111
112    /// Calculate the fee that would produce the given output
113    ///
114    /// Note: this function is not an exact inverse operation of
115    /// `calculate_fee`. Meaning, it is not the case that:
116    ///
117    /// `calculate_fee(x) == calculate_inverse_fee(x - calculate_fee(x))`
118    ///
119    /// Only the following relationship holds:
120    ///
121    /// `calculate_fee(x) >= calculate_inverse_fee(x - calculate_fee(x))`
122    pub fn calculate_inverse_fee(&self, post_fee_amount: u64) -> Option<u64> {
123        let pre_fee_amount = self.calculate_pre_fee_amount(post_fee_amount)?;
124        self.calculate_fee(pre_fee_amount)
125    }
126}
127
/// Transfer fee extension data for mints.
///
/// Holds two fee schedules so that a new fee can be staged ahead of the epoch
/// in which it starts to apply. The struct is `#[repr(C)]` and derives
/// `Pod`/`Zeroable`, so the field order below is part of its byte layout.
#[repr(C)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "camelCase"))]
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
pub struct TransferFeeConfig {
    /// Optional authority to set the fee
    pub transfer_fee_config_authority: OptionalNonZeroPubkey,
    /// Withdraw from mint instructions must be signed by this key
    pub withdraw_withheld_authority: OptionalNonZeroPubkey,
    /// Withheld transfer fee tokens that have been moved to the mint for
    /// withdrawal
    pub withheld_amount: PodU64,
    /// Older transfer fee, used if `current epoch < newer_transfer_fee.epoch`
    pub older_transfer_fee: TransferFee,
    /// Newer transfer fee, used if `current epoch >= newer_transfer_fee.epoch`
    pub newer_transfer_fee: TransferFee,
}
146impl TransferFeeConfig {
147    /// Get the fee for the given epoch
148    pub fn get_epoch_fee(&self, epoch: u64) -> &TransferFee {
149        if epoch >= self.newer_transfer_fee.epoch.into() {
150            &self.newer_transfer_fee
151        } else {
152            &self.older_transfer_fee
153        }
154    }
155    /// Calculate the fee for the given epoch and input amount
156    pub fn calculate_epoch_fee(&self, epoch: u64, pre_fee_amount: u64) -> Option<u64> {
157        self.get_epoch_fee(epoch).calculate_fee(pre_fee_amount)
158    }
159    /// Calculate the fee for the given epoch and output amount
160    pub fn calculate_inverse_epoch_fee(&self, epoch: u64, post_fee_amount: u64) -> Option<u64> {
161        self.get_epoch_fee(epoch)
162            .calculate_inverse_fee(post_fee_amount)
163    }
164}
/// Registers `TransferFeeConfig` as the data for the
/// `ExtensionType::TransferFeeConfig` extension.
impl Extension for TransferFeeConfig {
    const TYPE: ExtensionType = ExtensionType::TransferFeeConfig;
}
168
/// Transfer fee extension data for accounts.
///
/// The struct is `#[repr(C)]` and derives `Pod`/`Zeroable`.
#[repr(C)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "camelCase"))]
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
pub struct TransferFeeAmount {
    /// Amount withheld during transfers, to be harvested to the mint
    pub withheld_amount: PodU64,
}
178impl TransferFeeAmount {
179    /// Check if the extension is in a closable state
180    pub fn closable(&self) -> ProgramResult {
181        if self.withheld_amount == 0.into() {
182            Ok(())
183        } else {
184            Err(TokenError::AccountHasWithheldTransferFees.into())
185        }
186    }
187}
/// Registers `TransferFeeAmount` as the data for the
/// `ExtensionType::TransferFeeAmount` extension.
impl Extension for TransferFeeAmount {
    const TYPE: ExtensionType = ExtensionType::TransferFeeAmount;
}
191
#[cfg(test)]
pub(crate) mod test {
    use {super::*, proptest::prelude::*, solana_pubkey::Pubkey, std::convert::TryFrom};

    const NEWER_EPOCH: u64 = 100;
    const OLDER_EPOCH: u64 = 1;

    /// `ONE_IN_BASIS_POINTS` (100%) narrowed to `u64` for building amounts.
    fn one_in_basis_points() -> u64 {
        u64::try_from(ONE_IN_BASIS_POINTS).unwrap()
    }

    /// Build an epoch-0 fee schedule with the given rate and fee cap.
    fn fee_schedule(basis_points: u16, maximum_fee: u64) -> TransferFee {
        TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(maximum_fee),
            transfer_fee_basis_points: PodU16::from(basis_points),
        }
    }

    /// Fixture config: a 1% / cap-10 schedule from `OLDER_EPOCH` and a
    /// 0.01% / cap-5_000 schedule from `NEWER_EPOCH`.
    pub(crate) fn test_transfer_fee_config() -> TransferFeeConfig {
        let authority = |fill: u8| {
            OptionalNonZeroPubkey::try_from(Some(Pubkey::new_from_array([fill; 32]))).unwrap()
        };
        TransferFeeConfig {
            transfer_fee_config_authority: authority(10),
            withdraw_withheld_authority: authority(11),
            withheld_amount: PodU64::from(u64::MAX),
            older_transfer_fee: TransferFee {
                epoch: PodU64::from(OLDER_EPOCH),
                maximum_fee: PodU64::from(10),
                transfer_fee_basis_points: PodU16::from(100),
            },
            newer_transfer_fee: TransferFee {
                epoch: PodU64::from(NEWER_EPOCH),
                maximum_fee: PodU64::from(5_000),
                transfer_fee_basis_points: PodU16::from(1),
            },
        }
    }

    #[test]
    fn epoch_fee() {
        let config = test_transfer_fee_config();
        // from epoch 100 onwards, the newer transfer fee applies
        for epoch in [NEWER_EPOCH, NEWER_EPOCH + 1, u64::MAX] {
            assert_eq!(config.get_epoch_fee(epoch).epoch, NEWER_EPOCH.into());
        }
        // any earlier epoch gets the older transfer fee
        for epoch in [NEWER_EPOCH - 1, OLDER_EPOCH, OLDER_EPOCH + 1] {
            assert_eq!(config.get_epoch_fee(epoch).epoch, OLDER_EPOCH.into());
        }
    }

    #[test]
    fn calculate_fee_max() {
        let one = one_in_basis_points();
        let schedule = fee_schedule(1, 5_000);
        let cap = u64::from(schedule.maximum_fee);
        let amounts = [
            // far beyond the point where the cap kicks in
            u64::MAX,
            // uncapped fee is exactly the cap
            cap * one,
            // one token above: would round up past the cap, but is capped
            cap * one + 1,
            // one token below: rounds up to the cap
            cap * one - 1,
        ];
        for amount in amounts {
            assert_eq!(cap, schedule.calculate_fee(amount).unwrap());
        }
    }

    #[test]
    fn calculate_fee_min() {
        let one = one_in_basis_points();
        let schedule = fee_schedule(1, 5_000);
        let minimum_fee = 1;
        let cases = [
            // a single token already pays the minimum fee
            (1, minimum_fee),
            // still the minimum at 2 tokens
            (2, minimum_fee),
            // still the minimum at 10_000 tokens
            (one, minimum_fee),
            // 2 token fee at 10_001
            (one + 1, minimum_fee + 1),
            // zero is always zero
            (0, 0),
        ];
        for (amount, expected_fee) in cases {
            assert_eq!(expected_fee, schedule.calculate_fee(amount).unwrap());
        }
    }

    #[test]
    fn calculate_fee_zero() {
        let one = one_in_basis_points();
        let schedules = [
            // zero rate, unlimited cap
            fee_schedule(0, u64::MAX),
            // full rate, zero cap
            fee_schedule(MAX_FEE_BASIS_POINTS, 0),
        ];
        for schedule in schedules {
            // always zero fee
            for amount in [0, u64::MAX, 1, one] {
                assert_eq!(0, schedule.calculate_fee(amount).unwrap());
            }
        }
    }

    #[test]
    fn calculate_fee_exact_out_max() {
        let one = one_in_basis_points();
        let schedule = fee_schedule(1, 5_000);
        let cap = u64::from(schedule.maximum_fee);
        let net_amounts = [
            // far beyond the point where the cap kicks in
            u64::MAX - cap,
            // uncapped fee is exactly the cap
            cap * one - cap,
            // one token above: would round up past the cap, but is capped
            cap * one - cap + 1,
            // one token below: rounds up to the cap
            cap * one - cap - 1,
        ];
        for net_amount in net_amounts {
            assert_eq!(cap, schedule.calculate_inverse_fee(net_amount).unwrap());
        }
    }

    #[test]
    fn calculate_pre_fee_amount_edge_cases() {
        let maximum_fee = 5_000;
        let full_rate = fee_schedule(
            u16::try_from(ONE_IN_BASIS_POINTS).unwrap(),
            maximum_fee,
        );

        // 0 out, 0 in
        assert_eq!(0, full_rate.calculate_pre_fee_amount(0).unwrap());

        // cap at max fee
        assert_eq!(
            1 + maximum_fee,
            full_rate.calculate_pre_fee_amount(1).unwrap()
        );

        // no fee, same amount
        let zero_rate = fee_schedule(0, maximum_fee);
        assert_eq!(1, zero_rate.calculate_pre_fee_amount(1).unwrap());
    }

    #[test]
    fn calculate_fee_exact_out_min() {
        let one = one_in_basis_points();
        let schedule = fee_schedule(1, 5_000);
        let minimum_fee = 1;
        let cases = [
            // a single token already pays the minimum fee
            (1, minimum_fee),
            // still the minimum at 2 tokens
            (2, minimum_fee),
            // still the minimum at 9_999 tokens
            (one - 1, minimum_fee),
            // 2 token fee at 10_000
            (one, minimum_fee + 1),
            // zero out means zero fee
            (0, 0),
        ];
        for (net_amount, expected_fee) in cases {
            assert_eq!(
                expected_fee,
                schedule.calculate_inverse_fee(net_amount).unwrap()
            );
        }
    }

    proptest! {
        #[test]
        fn round_trip_fee_calculation(
            transfer_fee_basis_points in 0u16..MAX_FEE_BASIS_POINTS,
            maximum_fee in u64::MIN..=u64::MAX,
            amount_in in 0..=u64::MAX
        ) {
            let schedule = fee_schedule(transfer_fee_basis_points, maximum_fee);
            let fee = schedule.calculate_fee(amount_in).unwrap();
            let amount_out = amount_in.checked_sub(fee).unwrap();
            let fee_exact_out = schedule.calculate_inverse_fee(amount_out).unwrap();
            let diff = fee.max(fee_exact_out) - fee.min(fee_exact_out);
            // Every division by 10000 loses precision, so for huge amounts the
            // forward and inverse fee can differ by hundreds of tokens. The
            // allowed difference is below 1 / 10^12 of the input amount.
            // NOTE(review): `precision` is 0 for `amount_in < 10^12`, so the
            // strict `<` cannot hold there; the test appears to rely on such
            // inputs being extremely unlikely to be sampled - confirm.
            let one = MAX_FEE_BASIS_POINTS as u64;
            let precision = amount_in / one.pow(3);
            assert!(diff < precision, "diff is {} for precision {}", diff, precision);
        }

        #[test]
        fn inverse_fee_relationship(
            transfer_fee_basis_points in 0u16..MAX_FEE_BASIS_POINTS,
            maximum_fee in u64::MIN..=u64::MAX,
            amount_in in 0..=u64::MAX
        ) {
            let schedule = fee_schedule(transfer_fee_basis_points, maximum_fee);
            let fee = schedule.calculate_fee(amount_in).unwrap();
            let amount_out = amount_in.checked_sub(fee).unwrap();
            let fee_exact_out = schedule.calculate_inverse_fee(amount_out).unwrap();
            // the inverse fee never exceeds the forward fee
            assert!(fee >= fee_exact_out);
        }
    }
}