spl_token_2022_interface/extension/transfer_fee/
mod.rs1#[cfg(feature = "serde")]
2use serde::{Deserialize, Serialize};
3use {
4 crate::{
5 error::TokenError,
6 extension::{Extension, ExtensionType},
7 },
8 bytemuck::{Pod, Zeroable},
9 solana_program_error::ProgramResult,
10 spl_pod::{
11 optional_keys::OptionalNonZeroPubkey,
12 primitives::{PodU16, PodU64},
13 },
14 std::{
15 cmp,
16 convert::{TryFrom, TryInto},
17 },
18};
19
20pub mod instruction;
22
/// Fee rate, in basis points, that equals 100% of the amount (the fee math
/// below divides rates by this value). It is also the largest rate that
/// `TransferFee::calculate_pre_fee_amount` accepts; higher rates make it return `None`.
pub const MAX_FEE_BASIS_POINTS: u16 = 10_000;
/// `MAX_FEE_BASIS_POINTS` widened to `u128`. It is the divisor that turns a
/// `amount * basis_points` product back into token units.
const ONE_IN_BASIS_POINTS: u128 = MAX_FEE_BASIS_POINTS as u128;
26
/// A single transfer-fee schedule: a rate in basis points whose resulting fee
/// is clamped to `maximum_fee`.
///
/// NOTE: `#[repr(C)]` + `Pod`: the field order and types fix the in-memory
/// byte layout. Do not reorder or resize fields.
#[repr(C)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "camelCase"))]
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
pub struct TransferFee {
    /// Epoch at which this schedule starts to apply.
    /// `TransferFeeConfig::get_epoch_fee` picks the newer schedule once the
    /// requested epoch is `>=` this value.
    pub epoch: PodU64,
    /// Upper bound on the fee for any single amount. `calculate_fee` clamps
    /// the computed fee to this value.
    pub maximum_fee: PodU64,
    /// Fee rate in basis points of the amount (1 = 0.01%,
    /// `MAX_FEE_BASIS_POINTS` = 100%).
    pub transfer_fee_basis_points: PodU16,
}
41impl TransferFee {
42 fn ceil_div(numerator: u128, denominator: u128) -> Option<u128> {
49 numerator
50 .checked_add(denominator)?
51 .checked_sub(1)?
52 .checked_div(denominator)
53 }
54
55 pub fn calculate_fee(&self, pre_fee_amount: u64) -> Option<u64> {
57 let transfer_fee_basis_points = u16::from(self.transfer_fee_basis_points) as u128;
58 if transfer_fee_basis_points == 0 || pre_fee_amount == 0 {
59 Some(0)
60 } else {
61 let numerator = (pre_fee_amount as u128).checked_mul(transfer_fee_basis_points)?;
62 let raw_fee = Self::ceil_div(numerator, ONE_IN_BASIS_POINTS)?
63 .try_into() .ok()?;
65
66 Some(cmp::min(raw_fee, u64::from(self.maximum_fee)))
67 }
68 }
69
70 pub fn calculate_post_fee_amount(&self, pre_fee_amount: u64) -> Option<u64> {
72 pre_fee_amount.checked_sub(self.calculate_fee(pre_fee_amount)?)
73 }
74
75 pub fn calculate_pre_fee_amount(&self, post_fee_amount: u64) -> Option<u64> {
88 let maximum_fee = u64::from(self.maximum_fee);
89 let transfer_fee_basis_points = u16::from(self.transfer_fee_basis_points) as u128;
90 match (transfer_fee_basis_points, post_fee_amount) {
91 (0, _) => Some(post_fee_amount),
93 (_, 0) => Some(0),
95 (ONE_IN_BASIS_POINTS, _) => maximum_fee.checked_add(post_fee_amount),
97 _ => {
98 let numerator = (post_fee_amount as u128).checked_mul(ONE_IN_BASIS_POINTS)?;
99 let denominator = ONE_IN_BASIS_POINTS.checked_sub(transfer_fee_basis_points)?;
100 let raw_pre_fee_amount = Self::ceil_div(numerator, denominator)?;
101
102 if raw_pre_fee_amount.checked_sub(post_fee_amount as u128)? >= maximum_fee as u128 {
103 post_fee_amount.checked_add(maximum_fee)
104 } else {
105 u64::try_from(raw_pre_fee_amount).ok()
107 }
108 }
109 }
110 }
111
112 pub fn calculate_inverse_fee(&self, post_fee_amount: u64) -> Option<u64> {
123 let pre_fee_amount = self.calculate_pre_fee_amount(post_fee_amount)?;
124 self.calculate_fee(pre_fee_amount)
125 }
126}
127
/// Transfer-fee configuration extension holding two fee schedules. Which one
/// applies is decided per epoch by `TransferFeeConfig::get_epoch_fee`.
///
/// NOTE: `#[repr(C)]` + `Pod`: the field order and types fix the in-memory
/// byte layout. Do not reorder or resize fields.
#[repr(C)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "camelCase"))]
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
pub struct TransferFeeConfig {
    /// Optional key, presumably the authority allowed to update the fee
    /// configuration. NOTE(review): the name implies this, but any permission
    /// checks live outside this file. Confirm.
    pub transfer_fee_config_authority: OptionalNonZeroPubkey,
    /// Optional key, presumably the authority allowed to withdraw withheld
    /// fees. NOTE(review): the name implies this, but any permission checks
    /// live outside this file. Confirm.
    pub withdraw_withheld_authority: OptionalNonZeroPubkey,
    /// Withheld fee amount stored in the config. How it is accumulated or
    /// drained is not shown here. TODO confirm against the instruction processor.
    pub withheld_amount: PodU64,
    /// Schedule used for epochs before `newer_transfer_fee.epoch`.
    pub older_transfer_fee: TransferFee,
    /// Schedule used from `newer_transfer_fee.epoch` onward (inclusive).
    pub newer_transfer_fee: TransferFee,
}
146impl TransferFeeConfig {
147 pub fn get_epoch_fee(&self, epoch: u64) -> &TransferFee {
149 if epoch >= self.newer_transfer_fee.epoch.into() {
150 &self.newer_transfer_fee
151 } else {
152 &self.older_transfer_fee
153 }
154 }
155 pub fn calculate_epoch_fee(&self, epoch: u64, pre_fee_amount: u64) -> Option<u64> {
157 self.get_epoch_fee(epoch).calculate_fee(pre_fee_amount)
158 }
159 pub fn calculate_inverse_epoch_fee(&self, epoch: u64, post_fee_amount: u64) -> Option<u64> {
161 self.get_epoch_fee(epoch)
162 .calculate_inverse_fee(post_fee_amount)
163 }
164}
/// Tags `TransferFeeConfig` with its extension type,
/// `ExtensionType::TransferFeeConfig`.
impl Extension for TransferFeeConfig {
    const TYPE: ExtensionType = ExtensionType::TransferFeeConfig;
}
168
/// Extension data recording a withheld transfer-fee amount.
/// `TransferFeeAmount::closable` refuses while this amount is non-zero.
///
/// NOTE: `#[repr(C)]` + `Pod`: the field layout is the in-memory byte layout.
/// Do not change it.
#[repr(C)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "camelCase"))]
#[derive(Clone, Copy, Debug, Default, PartialEq, Pod, Zeroable)]
pub struct TransferFeeAmount {
    /// Fees currently withheld. Presumably updated by transfer and withdraw
    /// instructions outside this file. TODO confirm.
    pub withheld_amount: PodU64,
}
178impl TransferFeeAmount {
179 pub fn closable(&self) -> ProgramResult {
181 if self.withheld_amount == 0.into() {
182 Ok(())
183 } else {
184 Err(TokenError::AccountHasWithheldTransferFees.into())
185 }
186 }
187}
/// Tags `TransferFeeAmount` with its extension type,
/// `ExtensionType::TransferFeeAmount`.
impl Extension for TransferFeeAmount {
    const TYPE: ExtensionType = ExtensionType::TransferFeeAmount;
}
191
#[cfg(test)]
pub(crate) mod test {
    use {super::*, proptest::prelude::*, solana_pubkey::Pubkey, std::convert::TryFrom};

    const NEWER_EPOCH: u64 = 100;
    const OLDER_EPOCH: u64 = 1;

    /// Fully populated `TransferFeeConfig` fixture. Its older and newer
    /// schedules differ in epoch (`OLDER_EPOCH` / `NEWER_EPOCH`), cap and rate.
    /// Exposed crate-wide via `pub(crate)`.
    pub(crate) fn test_transfer_fee_config() -> TransferFeeConfig {
        TransferFeeConfig {
            transfer_fee_config_authority: OptionalNonZeroPubkey::try_from(Some(
                Pubkey::new_from_array([10; 32]),
            ))
            .unwrap(),
            withdraw_withheld_authority: OptionalNonZeroPubkey::try_from(Some(
                Pubkey::new_from_array([11; 32]),
            ))
            .unwrap(),
            withheld_amount: PodU64::from(u64::MAX),
            older_transfer_fee: TransferFee {
                epoch: PodU64::from(OLDER_EPOCH),
                maximum_fee: PodU64::from(10),
                transfer_fee_basis_points: PodU16::from(100),
            },
            newer_transfer_fee: TransferFee {
                epoch: PodU64::from(NEWER_EPOCH),
                maximum_fee: PodU64::from(5_000),
                transfer_fee_basis_points: PodU16::from(1),
            },
        }
    }

    /// Round-trips `amount_in` through `calculate_fee` and
    /// `calculate_inverse_fee`, and asserts the two fees differ by strictly
    /// less than `10_000 / (10_000 - transfer_fee_basis_points)`.
    ///
    /// Why this bound holds: `calculate_fee` rounds the fee up, so the post-fee
    /// amount can sit up to one unit below its exact-rate value. Scaling it
    /// back by `10_000 / (10_000 - bps)` then leaves the reconstructed pre-fee
    /// amount short of `amount_in` by less than that factor, and the recomputed
    /// fee obeys the same bound. When `maximum_fee` caps the fee, both
    /// directions yield the cap and the difference is zero.
    ///
    /// The bound is checked multiplicatively to stay in integers. The previous
    /// relative check, `diff < amount_in / 10_000^3`, degenerated to `diff < 0`
    /// (always false) for every `amount_in < 10^12`, e.g. `amount_in == 0`.
    ///
    /// Requires `transfer_fee_basis_points < MAX_FEE_BASIS_POINTS`.
    fn assert_round_trip_within_bound(
        transfer_fee_basis_points: u16,
        maximum_fee: u64,
        amount_in: u64,
    ) {
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(maximum_fee),
            transfer_fee_basis_points: PodU16::from(transfer_fee_basis_points),
        };
        let fee = transfer_fee.calculate_fee(amount_in).unwrap();
        let amount_out = amount_in.checked_sub(fee).unwrap();
        let fee_exact_out = transfer_fee.calculate_inverse_fee(amount_out).unwrap();
        let diff = if fee > fee_exact_out {
            fee - fee_exact_out
        } else {
            fee_exact_out - fee
        };
        let kept_share = ONE_IN_BASIS_POINTS - u128::from(transfer_fee_basis_points);
        assert!(
            u128::from(diff) * kept_share < ONE_IN_BASIS_POINTS,
            "diff is {} for transfer_fee_basis_points {} (amount_in {}, maximum_fee {})",
            diff,
            transfer_fee_basis_points,
            amount_in,
            maximum_fee
        );
    }

    #[test]
    fn epoch_fee() {
        let transfer_fee_config = test_transfer_fee_config();
        // From NEWER_EPOCH onward (inclusive) the newer schedule applies.
        assert_eq!(
            transfer_fee_config.get_epoch_fee(NEWER_EPOCH).epoch,
            NEWER_EPOCH.into()
        );
        assert_eq!(
            transfer_fee_config.get_epoch_fee(NEWER_EPOCH + 1).epoch,
            NEWER_EPOCH.into()
        );
        assert_eq!(
            transfer_fee_config.get_epoch_fee(u64::MAX).epoch,
            NEWER_EPOCH.into()
        );
        // Before NEWER_EPOCH the older schedule applies.
        assert_eq!(
            transfer_fee_config.get_epoch_fee(NEWER_EPOCH - 1).epoch,
            OLDER_EPOCH.into()
        );
        assert_eq!(
            transfer_fee_config.get_epoch_fee(OLDER_EPOCH).epoch,
            OLDER_EPOCH.into()
        );
        assert_eq!(
            transfer_fee_config.get_epoch_fee(OLDER_EPOCH + 1).epoch,
            OLDER_EPOCH.into()
        );
    }

    #[test]
    fn calculate_fee_max() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(5_000),
            transfer_fee_basis_points: PodU16::from(1),
        };
        let maximum_fee = u64::from(transfer_fee.maximum_fee);
        // The largest possible amount hits the cap.
        assert_eq!(maximum_fee, transfer_fee.calculate_fee(u64::MAX).unwrap());
        // Exactly at the cap.
        assert_eq!(
            maximum_fee,
            transfer_fee.calculate_fee(maximum_fee * one).unwrap()
        );
        // Just above the cap.
        assert_eq!(
            maximum_fee,
            transfer_fee.calculate_fee(maximum_fee * one + 1).unwrap()
        );
        // Just below the cap: rounding up still reaches it.
        assert_eq!(
            maximum_fee,
            transfer_fee.calculate_fee(maximum_fee * one - 1).unwrap()
        );
    }

    #[test]
    fn calculate_fee_min() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(5_000),
            transfer_fee_basis_points: PodU16::from(1),
        };
        let minimum_fee = 1;
        // Any non-zero amount pays at least 1 because the fee rounds up.
        assert_eq!(minimum_fee, transfer_fee.calculate_fee(1).unwrap());
        assert_eq!(minimum_fee, transfer_fee.calculate_fee(2).unwrap());
        assert_eq!(minimum_fee, transfer_fee.calculate_fee(one).unwrap());
        // One unit past a full 10_000 rounds up to the next fee unit.
        assert_eq!(
            minimum_fee + 1,
            transfer_fee.calculate_fee(one + 1).unwrap()
        );
        // A zero amount pays nothing.
        assert_eq!(0, transfer_fee.calculate_fee(0).unwrap());
    }

    #[test]
    fn calculate_fee_zero() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        // A zero rate never charges.
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(u64::MAX),
            transfer_fee_basis_points: PodU16::from(0),
        };
        assert_eq!(0, transfer_fee.calculate_fee(0).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(u64::MAX).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(1).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(one).unwrap());

        // A zero cap never charges, even at a 100% rate.
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(0),
            transfer_fee_basis_points: PodU16::from(MAX_FEE_BASIS_POINTS),
        };
        assert_eq!(0, transfer_fee.calculate_fee(0).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(u64::MAX).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(1).unwrap());
        assert_eq!(0, transfer_fee.calculate_fee(one).unwrap());
    }

    #[test]
    fn calculate_fee_exact_out_max() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(5_000),
            transfer_fee_basis_points: PodU16::from(1),
        };
        let maximum_fee = u64::from(transfer_fee.maximum_fee);
        // The largest post-fee amount that does not overflow once the cap is added.
        assert_eq!(
            maximum_fee,
            transfer_fee
                .calculate_inverse_fee(u64::MAX - maximum_fee)
                .unwrap()
        );
        // Exactly at the cap.
        assert_eq!(
            maximum_fee,
            transfer_fee
                .calculate_inverse_fee(maximum_fee * one - maximum_fee)
                .unwrap()
        );
        // Just above the cap.
        assert_eq!(
            maximum_fee,
            transfer_fee
                .calculate_inverse_fee(maximum_fee * one - maximum_fee + 1)
                .unwrap()
        );
        // Just below the cap.
        assert_eq!(
            maximum_fee,
            transfer_fee
                .calculate_inverse_fee(maximum_fee * one - maximum_fee - 1)
                .unwrap()
        );
    }

    #[test]
    fn calculate_pre_fee_amount_edge_cases() {
        let maximum_fee = 5_000;
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(maximum_fee),
            transfer_fee_basis_points: PodU16::from(u16::try_from(ONE_IN_BASIS_POINTS).unwrap()),
        };

        // A zero post-fee amount needs a zero pre-fee amount.
        assert_eq!(0, transfer_fee.calculate_pre_fee_amount(0).unwrap());

        // At a 100% rate the cap is added on top.
        assert_eq!(
            1 + maximum_fee,
            transfer_fee.calculate_pre_fee_amount(1).unwrap()
        );

        // A zero rate leaves the amount unchanged.
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(maximum_fee),
            transfer_fee_basis_points: PodU16::from(0),
        };
        assert_eq!(1, transfer_fee.calculate_pre_fee_amount(1).unwrap());
    }

    #[test]
    fn calculate_fee_exact_out_min() {
        let one = u64::try_from(ONE_IN_BASIS_POINTS).unwrap();
        let transfer_fee = TransferFee {
            epoch: PodU64::from(0),
            maximum_fee: PodU64::from(5_000),
            transfer_fee_basis_points: PodU16::from(1),
        };
        let minimum_fee = 1;
        // Any non-zero post-fee amount implies a fee of at least 1.
        assert_eq!(minimum_fee, transfer_fee.calculate_inverse_fee(1).unwrap());
        assert_eq!(minimum_fee, transfer_fee.calculate_inverse_fee(2).unwrap());
        assert_eq!(
            minimum_fee,
            transfer_fee.calculate_inverse_fee(one - 1).unwrap()
        );
        // A post-fee amount of exactly 10_000 needs a pre-fee amount past
        // 10_000, which rounds the fee up to 2.
        assert_eq!(
            minimum_fee + 1,
            transfer_fee.calculate_inverse_fee(one).unwrap()
        );
        // A zero post-fee amount implies no fee.
        assert_eq!(0, transfer_fee.calculate_inverse_fee(0).unwrap());
    }

    /// Deterministic coverage of the small amounts that the old relative
    /// precision check mishandled.
    #[test]
    fn round_trip_fee_calculation_small_amounts() {
        for transfer_fee_basis_points in [0, 1, 100, 5_000, 9_999] {
            for maximum_fee in [0, 5, u64::MAX] {
                for amount_in in 0..=20_000 {
                    assert_round_trip_within_bound(
                        transfer_fee_basis_points,
                        maximum_fee,
                        amount_in,
                    );
                }
            }
        }
    }

    proptest! {
        #[test]
        fn round_trip_fee_calculation(
            transfer_fee_basis_points in 0u16..MAX_FEE_BASIS_POINTS,
            maximum_fee in u64::MIN..=u64::MAX,
            amount_in in 0..=u64::MAX
        ) {
            assert_round_trip_within_bound(transfer_fee_basis_points, maximum_fee, amount_in);
        }
    }

    proptest! {
        #[test]
        fn inverse_fee_relationship(
            transfer_fee_basis_points in 0u16..MAX_FEE_BASIS_POINTS,
            maximum_fee in u64::MIN..=u64::MAX,
            amount_in in 0..=u64::MAX
        ) {
            let transfer_fee = TransferFee {
                epoch: PodU64::from(0),
                maximum_fee: PodU64::from(maximum_fee),
                transfer_fee_basis_points: PodU16::from(transfer_fee_basis_points),
            };
            let fee = transfer_fee.calculate_fee(amount_in).unwrap();
            let amount_out = amount_in.checked_sub(fee).unwrap();
            let fee_exact_out = transfer_fee.calculate_inverse_fee(amount_out).unwrap();
            // The fee inferred from the post-fee amount never exceeds the fee actually charged.
            assert!(fee >= fee_exact_out);
        }
    }
}