1use oci_spec::distribution::Reference;
2use serde::Deserialize;
3use std::collections::BTreeMap;
4use std::fmt;
5use std::sync::Arc;
6use std::time::{SystemTime, UNIX_EPOCH};
7use tokio::sync::RwLock;
8use tracing::{debug, warn};
9
10#[derive(Deserialize, Clone)]
12#[serde(untagged)]
13#[serde(rename_all = "snake_case")]
14pub(crate) enum RegistryToken {
15 Token { token: String },
16 AccessToken { access_token: String },
17}
18
19impl fmt::Debug for RegistryToken {
20 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21 let redacted = String::from("<redacted>");
22 match self {
23 RegistryToken::Token { .. } => {
24 f.debug_struct("Token").field("token", &redacted).finish()
25 }
26 RegistryToken::AccessToken { .. } => f
27 .debug_struct("AccessToken")
28 .field("access_token", &redacted)
29 .finish(),
30 }
31 }
32}
33
34#[derive(Debug, Clone)]
35pub(crate) enum RegistryTokenType {
36 Bearer(RegistryToken),
37 Basic(String, String),
38}
39
40impl RegistryToken {
41 pub fn bearer_token(&self) -> String {
42 format!("Bearer {}", self.token())
43 }
44
45 pub fn token(&self) -> &str {
46 match self {
47 RegistryToken::Token { token } => token,
48 RegistryToken::AccessToken { access_token } => access_token,
49 }
50 }
51}
52
/// The kind of registry access a cached token was obtained for.
///
/// This is part of the cache key, so a token cached for one operation is never
/// returned for the other. Declaration order defines the derived `Ord`
/// (`Push` sorts before `Pull`).
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum RegistryOperation {
    /// A push operation.
    Push,
    /// A pull operation.
    Pull,
}
61
62#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
63struct TokenCacheKey {
64 registry: String,
65 repository: String,
66 operation: RegistryOperation,
67}
68
69struct TokenCacheValue {
70 token: RegistryTokenType,
71 expiration: u64,
72}
73
74#[derive(Clone)]
75pub(crate) struct TokenCache {
76 tokens: Arc<RwLock<BTreeMap<TokenCacheKey, TokenCacheValue>>>,
78 pub default_expiration_secs: usize,
80}
81
82impl TokenCache {
83 pub(crate) fn new(default_expiration_secs: usize) -> Self {
84 TokenCache {
85 tokens: Arc::new(RwLock::new(BTreeMap::new())),
86 default_expiration_secs,
87 }
88 }
89
90 pub(crate) async fn insert(
91 &self,
92 reference: &Reference,
93 op: RegistryOperation,
94 token: RegistryTokenType,
95 ) {
96 let expiration = match token {
97 RegistryTokenType::Basic(_, _) => u64::MAX,
98 RegistryTokenType::Bearer(ref t) => {
99 let token_str = t.token();
100 match jwt::Token::<
101 jwt::header::Header,
102 jwt::claims::Claims,
103 jwt::token::Unverified,
104 >::parse_unverified(token_str)
105 {
106 Ok(token) => token.claims().registered.expiration.unwrap_or(u64::MAX),
107 Err(jwt::Error::NoClaimsComponent) => {
108 let now = SystemTime::now();
117 let epoch = now
118 .duration_since(UNIX_EPOCH)
119 .expect("Time went backwards")
120 .as_secs();
121 let expiration = epoch + self.default_expiration_secs as u64;
122 debug!(?token, "Cannot extract expiration from token's claims, assuming a {} seconds validity", self.default_expiration_secs);
123 expiration
124 },
125 Err(error) => {
126 warn!(?error, "Invalid bearer token");
127 return;
128 }
129 }
130 }
131 };
132 let registry = reference.resolve_registry().to_string();
133 let repository = reference.repository().to_string();
134 debug!(%registry, %repository, ?op, %expiration, "Inserting token");
135 self.tokens.write().await.insert(
136 TokenCacheKey {
137 registry,
138 repository,
139 operation: op,
140 },
141 TokenCacheValue { token, expiration },
142 );
143 }
144
145 pub(crate) async fn get(
146 &self,
147 reference: &Reference,
148 op: RegistryOperation,
149 ) -> Option<RegistryTokenType> {
150 let registry = reference.resolve_registry().to_string();
151 let repository = reference.repository().to_string();
152 let key = TokenCacheKey {
153 registry,
154 repository,
155 operation: op,
156 };
157 match self.tokens.read().await.get(&key) {
158 Some(TokenCacheValue {
159 ref token,
160 expiration,
161 }) => {
162 let now = SystemTime::now();
163 let epoch = now
164 .duration_since(UNIX_EPOCH)
165 .expect("Time went backwards")
166 .as_secs();
167 if epoch > *expiration {
168 debug!(%key.registry, %key.repository, ?key.operation, %expiration, miss=false, expired=true, "Fetching token");
169 None
170 } else {
171 debug!(%key.registry, %key.repository, ?key.operation, %expiration, miss=false, expired=false, "Fetching token");
172 Some(token.clone())
173 }
174 }
175 None => {
176 debug!(%key.registry, %key.repository, ?key.operation, miss = true, "Fetching token");
177 None
178 }
179 }
180 }
181}