1use std::str::FromStr;
4
5use jsonwebtoken::jwk::Jwk;
6use jsonwebtoken::{Algorithm, DecodingKey, Validation};
7use serde::{de, Deserialize, Deserializer, Serialize};
8use thiserror::Error;
9use zeroize::Zeroize;
10
11use crate::bundle::jwt::JwtBundle;
12use crate::bundle::BundleRefSource;
13use crate::spiffe_id::{SpiffeId, SpiffeIdError, TrustDomain};
14use crate::svid::Svid;
15use std::error::Error;
16use std::fmt;
17use std::marker::PhantomData;
18use time::{Date, OffsetDateTime};
19
/// Signature algorithms accepted in the 'alg' header of a JWT-SVID.
///
/// Parsing a token whose header names any algorithm outside this list fails
/// with [`JwtSvidError::UnsupportedAlgorithm`].
const SUPPORTED_ALGORITHMS: &[Algorithm; 8] = &[
    Algorithm::RS256,
    Algorithm::RS384,
    Algorithm::RS512,
    Algorithm::ES256,
    Algorithm::ES384,
    Algorithm::PS256,
    Algorithm::PS384,
    Algorithm::PS512,
];
30
/// A parsed JWT-SVID: the serialized token together with the SPIFFE ID,
/// expiry, claims and header values extracted from it.
#[derive(Debug, Clone, PartialEq)]
pub struct JwtSvid {
    /// SPIFFE ID parsed from the token's 'sub' claim.
    spiffe_id: SpiffeId,
    /// Calendar date derived from the 'exp' claim; the time of day is dropped.
    expiry: Date,
    /// Claims decoded from the token payload.
    claims: Claims,
    /// Value of the 'kid' header; used to look up the signing authority.
    kid: String,
    /// Value of the 'alg' header.
    alg: Algorithm,

    /// The serialized token this SVID was parsed from.
    token: Token,
}
45
// Marks `JwtSvid` as an `Svid`; the impl body is empty, so nothing is overridden here.
impl Svid for JwtSvid {}
47
/// Errors that can arise while parsing or validating a [`JwtSvid`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum JwtSvidError {
    /// The 'sub' claim could not be parsed as a SPIFFE ID.
    #[error("invalid spiffe_id in token 'sub' claim")]
    InvalidSubject(#[from] SpiffeIdError),

    /// The token header carries no 'kid' value.
    #[error("token header 'kid' not found")]
    MissingKeyId,

    /// The token's 'typ' header was rejected; the error message names the accepted values.
    #[error("token header 'typ' should be 'JWT' or 'JOSE'")]
    InvalidTyp,

    /// The 'alg' header names an algorithm outside `SUPPORTED_ALGORITHMS`.
    #[error("algorithm in 'alg' header is not supported")]
    UnsupportedAlgorithm,

    /// A required claim (named by the payload) is missing.
    // NOTE(review): not constructed anywhere in the code visible in this section.
    #[error("one of the required claims ({0}) is missing")]
    RequiredClaimMissing(String),

    /// The bundle source returned no bundle for the token's trust domain.
    #[error("cannot find JWT bundle for trust domain: {0}")]
    BundleNotFound(TrustDomain),

    /// The bundle holds no JWT authority for the token's key ID.
    #[error("cannot find JWT authority for key_id: {0}")]
    AuthorityNotFound(String),

    /// Audience mismatch; the payload carries two audience lists that the
    /// error message prints as "expected audience in {0:?} (audience={1:?})".
    // NOTE(review): not constructed anywhere in the code visible in this section.
    #[error("expected audience in {0:?} (audience={1:?})")]
    InvalidAudience(Vec<String>, Vec<String>),

    /// The `jsonwebtoken` crate failed to decode or validate the token.
    #[error("cannot decode token")]
    InvalidToken(#[from] jsonwebtoken::errors::Error),

    /// Any other boxed error, e.g. one propagated with `?` from the bundle source lookup.
    #[error("error parsing JWT-SVID")]
    Other(#[from] Box<dyn Error + Send + Sync + 'static>),
}
94
/// Private wrapper around the serialized token string.
///
/// Derives `Zeroize` with `#[zeroize(drop)]`, so the token text is meant to be
/// wiped from memory when the value is dropped.
#[derive(Debug, Clone, Eq, PartialEq, Zeroize)]
#[zeroize(drop)]
struct Token {
    /// The serialized token text.
    inner: String,
}
100
101impl From<&str> for Token {
102 fn from(token: &str) -> Self {
103 Self {
104 inner: token.to_owned(),
105 }
106 }
107}
108
109impl AsRef<str> for Token {
110 fn as_ref(&self) -> &str {
111 self.inner.as_ref()
112 }
113}
114
/// The claims read from (and written to) a JWT-SVID payload.
#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)]
pub struct Claims {
    /// The 'sub' claim, kept as the raw string.
    sub: String,
    /// The 'aud' claim; accepted either as a single string or as a sequence
    /// of strings when deserializing (see `string_or_seq_string`).
    #[serde(deserialize_with = "string_or_seq_string")]
    aud: Vec<String>,
    /// The 'exp' claim; parsing converts it to a date via
    /// `OffsetDateTime::from_unix_timestamp`, i.e. it is a Unix timestamp.
    exp: u32,
}
124
125impl Claims {
126 pub fn get_sub(&self) -> &str {
128 &self.sub
129 }
130
131 pub fn get_aud(&self) -> &Vec<String> {
133 &self.aud
134 }
135
136 pub fn get_exp(&self) -> u32 {
138 self.exp
139 }
140}
141
142impl JwtSvid {
143 pub fn parse_and_validate<T: AsRef<str> + ToString + std::fmt::Debug>(
159 token: &str,
160 bundle_source: &impl BundleRefSource<Item = JwtBundle>,
161 expected_audience: &[T],
162 ) -> Result<Self, JwtSvidError> {
163 let jwt_svid = JwtSvid::parse_insecure(token)?;
164
165 let jwt_authority = JwtSvid::find_jwt_authority(
166 bundle_source,
167 jwt_svid.spiffe_id.trust_domain(),
168 &jwt_svid.kid,
169 )?;
170
171 let mut validation = jsonwebtoken::Validation::new(jwt_svid.alg.to_owned());
172 validation.validate_exp = true;
173 validation.set_audience(expected_audience);
174 let dec_key = DecodingKey::from_jwk(jwt_authority)?;
175 jsonwebtoken::decode::<Claims>(token, &dec_key, &validation)?;
176 Ok(jwt_svid)
177 }
178
179 pub fn parse_insecure(token: &str) -> Result<Self, JwtSvidError> {
183 JwtSvid::from_str(token)
184 }
185
186 pub fn token(&self) -> &str {
188 self.token.as_ref()
189 }
190
191 pub fn spiffe_id(&self) -> &SpiffeId {
193 &self.spiffe_id
194 }
195
196 pub fn audience(&self) -> &Vec<String> {
198 &self.claims.aud
199 }
200
201 pub fn expiry(&self) -> &Date {
203 &self.expiry
204 }
205
206 pub fn key_id(&self) -> &str {
208 &self.kid
209 }
210
211 fn find_jwt_authority<'a>(
214 bundle_source: &'a impl BundleRefSource<Item = JwtBundle>,
215 trust_domain: &TrustDomain,
216 key_id: &str,
217 ) -> Result<&'a Jwk, JwtSvidError> {
218 let bundle = match bundle_source.get_bundle_for_trust_domain(trust_domain)? {
219 None => return Err(JwtSvidError::BundleNotFound(trust_domain.to_owned())),
220 Some(b) => b,
221 };
222
223 let jwt_authority = bundle
224 .find_jwt_authority(key_id)
225 .ok_or_else(|| JwtSvidError::AuthorityNotFound(key_id.to_owned()))?;
226
227 Ok(jwt_authority)
228 }
229}
230
231impl FromStr for JwtSvid {
232 type Err = JwtSvidError;
233
234 fn from_str(token: &str) -> Result<Self, Self::Err> {
239 let mut validation = Validation::default();
241 validation.validate_aud = false;
243 validation.insecure_disable_signature_validation();
244 let token_data =
245 jsonwebtoken::decode::<Claims>(token, &DecodingKey::from_secret(&[]), &validation)?;
246
247 let claims = token_data.claims;
248 let spiffe_id = SpiffeId::from_str(&claims.sub)?;
249
250 let expiry = OffsetDateTime::from_unix_timestamp(claims.exp as i64).unwrap();
251 let expiry = expiry.date();
252
253 let kid = match token_data.header.kid {
254 None => return Err(JwtSvidError::MissingKeyId),
255 Some(k) => k,
256 };
257
258 match token_data.header.typ {
259 None => return Err(JwtSvidError::InvalidTyp),
260 Some(t) => match t.as_str() {
261 "JWT" => {}
262 "JOSE" => {}
263 _ => return Err(JwtSvidError::InvalidTyp),
264 },
265 }
266
267 if !SUPPORTED_ALGORITHMS.contains(&token_data.header.alg) {
268 return Err(JwtSvidError::UnsupportedAlgorithm);
269 }
270
271 let alg = token_data.header.alg;
272
273 Ok(Self {
274 spiffe_id,
275 expiry,
276 claims,
277 kid,
278 alg,
279
280 token: Token::from(token),
281 })
282 }
283}
284
285fn string_or_seq_string<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
287where
288 D: Deserializer<'de>,
289{
290 struct StringOrVec(PhantomData<Vec<String>>);
291
292 impl<'de> de::Visitor<'de> for StringOrVec {
293 type Value = Vec<String>;
294
295 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
296 formatter.write_str("string or sequence of strings")
297 }
298
299 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
300 where
301 E: de::Error,
302 {
303 Ok(vec![value.to_owned()])
304 }
305
306 fn visit_seq<S>(self, visitor: S) -> Result<Self::Value, S::Error>
307 where
308 S: de::SeqAccess<'de>,
309 {
310 Deserialize::deserialize(de::value::SeqAccessDeserializer::new(visitor))
311 }
312 }
313
314 deserializer.deserialize_any(StringOrVec(PhantomData))
315}
316
#[cfg(test)]
mod test {
    use super::*;
    use crate::bundle::jwt::JwtBundleSet;
    use jsonwebtoken::*;

    /// Key ID under which the generated test authority is registered.
    const TEST_KEY_ID: &str = "test-key-id";
    /// Trust domain of `TEST_SUBJECT`; the test bundle is created for it.
    const TEST_TRUST_DOMAIN: &str = "example.org";
    /// Subject placed in every generated token.
    const TEST_SUBJECT: &str = "spiffe://example.org/service";
    /// Audience placed in every generated token.
    const TEST_AUDIENCE: &str = "audience";
    /// Largest value the `exp` claim can hold; used for tokens that must not be expired.
    const MAX_EXPIRY: u32 = u32::MAX;

    /// Generates a fresh P-256 key and returns an ES256 signing key for it.
    ///
    /// Used by the tests that only parse a token and never consult a bundle.
    fn generate_encoding_key() -> EncodingKey {
        let test_key = jsonwebkey::Key::generate_p256();
        EncodingKey::from_ec_der(&test_key.to_der())
    }

    /// Generates a fresh P-256 key and returns its signing key together with a
    /// bundle source whose `TEST_TRUST_DOMAIN` bundle holds the matching JWT
    /// authority under `TEST_KEY_ID`.
    fn generate_encoding_key_and_bundle_source() -> (EncodingKey, JwtBundleSet) {
        let test_key = jsonwebkey::Key::generate_p256();
        let encoding_key = EncodingKey::from_ec_der(&test_key.to_der());

        let mut jwt_key = jsonwebkey::JsonWebKey::new(test_key);
        jwt_key.set_algorithm(jsonwebkey::Algorithm::ES256).unwrap();
        jwt_key.key_id = Some(TEST_KEY_ID.to_string());

        // Round-trip through JSON to turn the `jsonwebkey` value into the JWK
        // type the bundle expects.
        let json = serde_json::to_string(&jwt_key).expect("JWK should be serializable");
        let jwk = serde_json::from_str(&json).expect("JWK should be deserializable");

        let trust_domain = TrustDomain::new(TEST_TRUST_DOMAIN).unwrap();
        let mut bundle = JwtBundle::new(trust_domain);
        bundle.add_jwt_authority(jwk).unwrap();

        let mut bundle_source = JwtBundleSet::default();
        bundle_source.add_bundle(bundle);

        (encoding_key, bundle_source)
    }

    #[test]
    fn test_parse_and_validate_jwt_svid() {
        let (encoding_key, bundle_source) = generate_encoding_key_and_bundle_source();

        let target_audience = vec![TEST_AUDIENCE.to_owned()];
        let token = generate_token(
            target_audience.clone(),
            TEST_SUBJECT.to_string(),
            Some("JWT".to_string()),
            Some(TEST_KEY_ID.to_string()),
            MAX_EXPIRY,
            jsonwebtoken::Algorithm::ES256,
            &encoding_key,
        );

        let jwt_svid =
            JwtSvid::parse_and_validate(&token, &bundle_source, &[TEST_AUDIENCE]).unwrap();

        assert_eq!(jwt_svid.spiffe_id, SpiffeId::new(TEST_SUBJECT).unwrap());
        assert_eq!(jwt_svid.audience(), &target_audience);
        assert_eq!(jwt_svid.token(), token);
    }

    #[test]
    fn test_parse_jwt_svid_with_unsupported_algorithm() {
        // The default `jsonwebtoken` algorithm is not in `SUPPORTED_ALGORITHMS`.
        let token = generate_token(
            vec![TEST_AUDIENCE.to_owned()],
            TEST_SUBJECT.to_string(),
            Some("JWT".to_string()),
            Some("some_key_id".to_string()),
            MAX_EXPIRY,
            jsonwebtoken::Algorithm::default(),
            &EncodingKey::from_secret("secret".as_ref()),
        );

        let result = JwtSvid::parse_insecure(&token).unwrap_err();

        assert!(matches!(result, JwtSvidError::UnsupportedAlgorithm));
    }

    #[test]
    fn test_parse_invalid_jwt_svid_without_key_id() {
        let encoding_key = generate_encoding_key();

        let token = generate_token(
            vec![TEST_AUDIENCE.to_owned()],
            TEST_SUBJECT.to_string(),
            Some("JWT".to_string()),
            None,
            MAX_EXPIRY,
            jsonwebtoken::Algorithm::ES256,
            &encoding_key,
        );

        let result = JwtSvid::parse_insecure(&token).unwrap_err();

        assert!(matches!(result, JwtSvidError::MissingKeyId));
    }

    #[test]
    fn test_parse_invalid_jwt_svid_with_invalid_header_typ() {
        let encoding_key = generate_encoding_key();

        let token = generate_token(
            vec![TEST_AUDIENCE.to_owned()],
            TEST_SUBJECT.to_string(),
            Some("OTHER".to_string()),
            Some("kid".to_string()),
            MAX_EXPIRY,
            jsonwebtoken::Algorithm::ES256,
            &encoding_key,
        );

        let result = JwtSvid::parse_insecure(&token).unwrap_err();

        assert!(matches!(result, JwtSvidError::InvalidTyp));
    }

    #[test]
    fn test_parse_and_validate_jwt_svid_from_expired_token() {
        let (encoding_key, bundle_source) = generate_encoding_key_and_bundle_source();

        // An `exp` of 1 second after the Unix epoch is long in the past.
        let token = generate_token(
            vec![TEST_AUDIENCE.to_owned()],
            TEST_SUBJECT.to_string(),
            Some("JWT".to_string()),
            Some(TEST_KEY_ID.to_string()),
            1,
            jsonwebtoken::Algorithm::ES256,
            &encoding_key,
        );

        let result =
            JwtSvid::parse_and_validate(&token, &bundle_source, &[TEST_AUDIENCE]).unwrap_err();

        assert!(matches!(result, JwtSvidError::InvalidToken(..)));
    }

    /// Builds and signs a token with the given claims and header values.
    ///
    /// Header fields other than `typ`, `alg` and `kid` keep the values of
    /// `Header::default()`.
    fn generate_token(
        aud: Vec<String>,
        sub: String,
        typ: Option<String>,
        kid: Option<String>,
        exp: u32,
        alg: jsonwebtoken::Algorithm,
        encoding_key: &EncodingKey,
    ) -> String {
        let claims = Claims { sub, aud, exp };

        let header = jsonwebtoken::Header {
            typ,
            alg,
            kid,
            ..Default::default()
        };
        encode(&header, &claims, encoding_key).unwrap()
    }
}