// jito_searcher_client/client_interceptor.rs

1use crate::SearcherClientResult;
2use jito_protos::auth::{
3  auth_service_client::AuthServiceClient, GenerateAuthChallengeRequest, GenerateAuthTokensRequest, RefreshAccessTokenRequest, Role,
4  Token,
5};
6use prost_types::Timestamp;
7use solana_metrics::datapoint_info;
8use solana_sdk::signature::{Keypair, Signer};
9use std::{
10  sync::{Arc, RwLock},
11  time::{Duration, SystemTime},
12};
13use tokio::time::sleep;
14use tonic::{service::Interceptor, transport::Channel, Request, Status};
15
/// gRPC metadata key under which the bearer token is sent.
const AUTHORIZATION_HEADER: &str = "authorization";

/// Adds the current access token to each request's `authorization` header.
///
/// Token renewal is performed by a background tokio task spawned from
/// [`ClientInterceptor::new`]. Clones of the interceptor share the same
/// token storage, so every clone observes renewed tokens.
#[derive(Clone)]
pub struct ClientInterceptor {
  /// The access token added to each request header. While it is empty, no
  /// `authorization` header is attached.
  bearer_token: Arc<RwLock<String>>,
}
25
26impl ClientInterceptor {
27  pub async fn new(
28    mut auth_service_client: AuthServiceClient<Channel>,
29    keypair: &Arc<Keypair>,
30  ) -> SearcherClientResult<Self> {
31    const ROLE: Role = Role::Searcher;
32
33    let (access_token, refresh_token) = Self::auth(&mut auth_service_client, keypair, ROLE).await?;
34
35    let bearer_token = Arc::new(RwLock::new(access_token.value.clone()));
36
37    tokio::spawn(Self::token_refresh_loop(
38      auth_service_client,
39      bearer_token.clone(),
40      refresh_token,
41      access_token.expires_at_utc.unwrap(),
42      keypair.clone(),
43      ROLE,
44    ));
45
46    Ok(Self { bearer_token })
47  }
48
49  async fn auth(
50    auth_service_client: &mut AuthServiceClient<Channel>,
51    keypair: &Keypair,
52    role: Role,
53  ) -> SearcherClientResult<(Token, Token)> {
54    let challenge_resp = auth_service_client
55      .generate_auth_challenge(GenerateAuthChallengeRequest {
56        role: role as i32,
57        pubkey: keypair.pubkey().as_ref().to_vec(),
58      })
59      .await?
60      .into_inner();
61
62    let challenge = format!("{}-{}", keypair.pubkey(), challenge_resp.challenge);
63    let signed_challenge = keypair.sign_message(challenge.as_bytes()).as_ref().to_vec();
64
65    let tokens = auth_service_client
66      .generate_auth_tokens(GenerateAuthTokensRequest {
67        challenge,
68        client_pubkey: keypair.pubkey().as_ref().to_vec(),
69        signed_challenge,
70      })
71      .await?
72      .into_inner();
73
74    Ok((tokens.access_token.unwrap(), tokens.refresh_token.unwrap()))
75  }
76
77  async fn token_refresh_loop(
78    mut auth_service_client: AuthServiceClient<Channel>,
79    bearer_token: Arc<RwLock<String>>,
80    refresh_token: Token,
81    access_token_expiration: Timestamp,
82    keypair: Arc<Keypair>,
83    role: Role,
84  ) {
85    let mut refresh_token = refresh_token;
86    let mut access_token_expiration = access_token_expiration;
87
88    loop {
89      let access_token_ttl = SystemTime::try_from(access_token_expiration.clone())
90        .unwrap()
91        .duration_since(SystemTime::now())
92        .unwrap_or_else(|_| Duration::from_secs(0));
93      let refresh_token_ttl =
94        SystemTime::try_from(refresh_token.expires_at_utc.as_ref().unwrap().clone())
95          .unwrap()
96          .duration_since(SystemTime::now())
97          .unwrap_or_else(|_| Duration::from_secs(0));
98
99      let does_access_token_expire_soon = access_token_ttl < Duration::from_secs(5 * 60);
100      let does_refresh_token_expire_soon = refresh_token_ttl < Duration::from_secs(5 * 60);
101
102      match (
103        does_refresh_token_expire_soon,
104        does_access_token_expire_soon,
105      ) {
106        // re-run entire auth workflow is refresh token expiring soon
107        (true, _) => {
108          let is_error = {
109            if let Ok((new_access_token, new_refresh_token)) =
110              Self::auth(&mut auth_service_client, &keypair, role).await
111            {
112              *bearer_token.write().unwrap() = new_access_token.value.clone();
113              access_token_expiration = new_access_token.expires_at_utc.unwrap();
114              refresh_token = new_refresh_token;
115              false
116            } else {
117              true
118            }
119          };
120          datapoint_info!("searcher-full-auth", ("is_error", is_error, bool));
121        }
122        // re-up the access token if it expires soon
123        (_, true) => {
124          let is_error = {
125            if let Ok(refresh_resp) = auth_service_client
126              .refresh_access_token(RefreshAccessTokenRequest {
127                refresh_token: refresh_token.value.clone(),
128              })
129              .await
130            {
131              let access_token = refresh_resp.into_inner().access_token.unwrap();
132              *bearer_token.write().unwrap() = access_token.value.clone();
133              access_token_expiration = access_token.expires_at_utc.unwrap();
134              false
135            } else {
136              true
137            }
138          };
139
140          datapoint_info!("searcher-refresh-auth", ("is_error", is_error, bool));
141        }
142        _ => {
143          sleep(Duration::from_secs(60)).await;
144        }
145      }
146    }
147  }
148}
149
150impl Interceptor for ClientInterceptor {
151  fn call(&mut self, mut request: Request<()>) -> Result<Request<()>, Status> {
152    let l_token = self.bearer_token.read().unwrap();
153
154    if !l_token.is_empty() {
155      request.metadata_mut().insert(
156        AUTHORIZATION_HEADER,
157        format!("Bearer {l_token}").parse().unwrap(),
158      );
159    }
160
161    Ok(request)
162  }
163}