// oci_client/token_cache.rs

1use oci_spec::distribution::Reference;
2use serde::Deserialize;
3use std::collections::BTreeMap;
4use std::fmt;
5use std::sync::Arc;
6use std::time::{SystemTime, UNIX_EPOCH};
7use tokio::sync::RwLock;
8use tracing::{debug, warn};
9
10/// A token granted during the OAuth2-like workflow for OCI registries.
11#[derive(Deserialize, Clone)]
12#[serde(untagged)]
13#[serde(rename_all = "snake_case")]
14pub(crate) enum RegistryToken {
15    Token { token: String },
16    AccessToken { access_token: String },
17}
18
19impl fmt::Debug for RegistryToken {
20    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21        let redacted = String::from("<redacted>");
22        match self {
23            RegistryToken::Token { .. } => {
24                f.debug_struct("Token").field("token", &redacted).finish()
25            }
26            RegistryToken::AccessToken { .. } => f
27                .debug_struct("AccessToken")
28                .field("access_token", &redacted)
29                .finish(),
30        }
31    }
32}
33
34#[derive(Debug, Clone)]
35pub(crate) enum RegistryTokenType {
36    Bearer(RegistryToken),
37    Basic(String, String),
38}
39
40impl RegistryToken {
41    pub fn bearer_token(&self) -> String {
42        format!("Bearer {}", self.token())
43    }
44
45    pub fn token(&self) -> &str {
46        match self {
47            RegistryToken::Token { token } => token,
48            RegistryToken::AccessToken { access_token } => access_token,
49        }
50    }
51}
52
/// The registry operation a credential is requested for.
///
/// It forms part of the token cache key, so push and pull credentials for the same repository
/// are cached separately. Variant order is significant: the derived `Ord` sorts `Push` first.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum RegistryOperation {
    /// Authenticate for push operations.
    Push,
    /// Authenticate for pull operations.
    Pull,
}
61
62#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
63struct TokenCacheKey {
64    registry: String,
65    repository: String,
66    operation: RegistryOperation,
67}
68
69struct TokenCacheValue {
70    token: RegistryTokenType,
71    expiration: u64,
72}
73
74#[derive(Clone)]
75pub(crate) struct TokenCache {
76    // (registry, repository, scope) -> (token, expiration)
77    tokens: Arc<RwLock<BTreeMap<TokenCacheKey, TokenCacheValue>>>,
78    /// Default token expiration in seconds, to use when claim doesn't specify a value
79    pub default_expiration_secs: usize,
80}
81
82impl TokenCache {
83    pub(crate) fn new(default_expiration_secs: usize) -> Self {
84        TokenCache {
85            tokens: Arc::new(RwLock::new(BTreeMap::new())),
86            default_expiration_secs,
87        }
88    }
89
90    pub(crate) async fn insert(
91        &self,
92        reference: &Reference,
93        op: RegistryOperation,
94        token: RegistryTokenType,
95    ) {
96        let expiration = match token {
97            RegistryTokenType::Basic(_, _) => u64::MAX,
98            RegistryTokenType::Bearer(ref t) => {
99                let token_str = t.token();
100                match jwt::Token::<
101                        jwt::header::Header,
102                        jwt::claims::Claims,
103                        jwt::token::Unverified,
104                    >::parse_unverified(token_str)
105                    {
106                        Ok(token) => token.claims().registered.expiration.unwrap_or(u64::MAX),
107                        Err(jwt::Error::NoClaimsComponent) => {
108                            // the token doesn't have a claim that states a
109                            // value for the expiration. We assume it has a 60
110                            // seconds validity as indicated here:
111                            // https://docs.docker.com/registry/spec/auth/token/#requesting-a-token
112                            // > (Optional) The duration in seconds since the token was issued
113                            // > that it will remain valid. When omitted, this defaults to 60 seconds.
114                            // > For compatibility with older clients, a token should never be returned
115                            // > with less than 60 seconds to live.
116                            let now = SystemTime::now();
117                            let epoch = now
118                                .duration_since(UNIX_EPOCH)
119                                .expect("Time went backwards")
120                                .as_secs();
121                            let expiration = epoch + self.default_expiration_secs as u64;
122                            debug!(?token, "Cannot extract expiration from token's claims, assuming a {} seconds validity", self.default_expiration_secs);
123                            expiration
124                        },
125                        Err(error) => {
126                            warn!(?error, "Invalid bearer token");
127                            return;
128                        }
129                    }
130            }
131        };
132        let registry = reference.resolve_registry().to_string();
133        let repository = reference.repository().to_string();
134        debug!(%registry, %repository, ?op, %expiration, "Inserting token");
135        self.tokens.write().await.insert(
136            TokenCacheKey {
137                registry,
138                repository,
139                operation: op,
140            },
141            TokenCacheValue { token, expiration },
142        );
143    }
144
145    pub(crate) async fn get(
146        &self,
147        reference: &Reference,
148        op: RegistryOperation,
149    ) -> Option<RegistryTokenType> {
150        let registry = reference.resolve_registry().to_string();
151        let repository = reference.repository().to_string();
152        let key = TokenCacheKey {
153            registry,
154            repository,
155            operation: op,
156        };
157        match self.tokens.read().await.get(&key) {
158            Some(TokenCacheValue {
159                ref token,
160                expiration,
161            }) => {
162                let now = SystemTime::now();
163                let epoch = now
164                    .duration_since(UNIX_EPOCH)
165                    .expect("Time went backwards")
166                    .as_secs();
167                if epoch > *expiration {
168                    debug!(%key.registry, %key.repository, ?key.operation, %expiration, miss=false, expired=true, "Fetching token");
169                    None
170                } else {
171                    debug!(%key.registry, %key.repository, ?key.operation, %expiration, miss=false, expired=false, "Fetching token");
172                    Some(token.clone())
173                }
174            }
175            None => {
176                debug!(%key.registry, %key.repository, ?key.operation, miss = true, "Fetching token");
177                None
178            }
179        }
180    }
181}