1use std::str::FromStr;
4
5use jsonwebtoken::jwk::Jwk;
6use jsonwebtoken::{Algorithm, DecodingKey, Validation};
7use serde::{de, Deserialize, Deserializer, Serialize};
8use thiserror::Error;
9use zeroize::Zeroize;
10
11use crate::bundle::jwt::JwtBundle;
12use crate::bundle::BundleRefSource;
13use crate::spiffe_id::{SpiffeId, SpiffeIdError, TrustDomain};
14use crate::svid::Svid;
15use std::error::Error;
16use std::fmt;
17use std::marker::PhantomData;
18use time::{Date, OffsetDateTime};
19
/// Signature algorithms a JWT-SVID's `alg` header may name.
///
/// Parsing rejects any token whose `alg` is not listed here with
/// [`JwtSvidError::UnsupportedAlgorithm`]. Only asymmetric algorithms are listed; no HMAC
/// (`HS*`) algorithm is included.
// NOTE(review): the SPIFFE JWT-SVID spec also permits ES512, which is absent here —
// presumably because `jsonwebtoken::Algorithm` has no such variant; confirm.
const SUPPORTED_ALGORITHMS: &[Algorithm; 8] = &[
    // RSASSA-PKCS1-v1_5
    Algorithm::RS256,
    Algorithm::RS384,
    Algorithm::RS512,
    // ECDSA
    Algorithm::ES256,
    Algorithm::ES384,
    // RSASSA-PSS
    Algorithm::PS256,
    Algorithm::PS384,
    Algorithm::PS512,
];
30
/// A parsed JWT-SVID.
///
/// Holds the SPIFFE ID and claims decoded from a JWT, the header values needed to verify it,
/// and the original encoded token. Values are built by [`JwtSvid::parse_and_validate`], which
/// verifies the signature, or by [`JwtSvid::parse_insecure`] / `FromStr`, which do not.
#[derive(Debug, Clone, PartialEq)]
pub struct JwtSvid {
    /// SPIFFE ID parsed from the `sub` claim.
    spiffe_id: SpiffeId,
    /// Calendar date of the `exp` claim; the time of day is discarded.
    expiry: Date,
    /// The decoded `sub`, `aud` and `exp` claims.
    claims: Claims,

    /// Value of the `kid` header, used to look up the signing authority in a bundle.
    kid: String,
    /// Value of the `alg` header.
    alg: Algorithm,

    /// The original encoded token text.
    token: Token,
}
45
// Marks `JwtSvid` as an SVID type; the impl is empty because `Svid` has no required items.
impl Svid for JwtSvid {}
47
/// Errors that can occur while parsing or validating a [`JwtSvid`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum JwtSvidError {
    /// The `sub` claim is not a valid SPIFFE ID.
    #[error("invalid spiffe_id in token 'sub' claim")]
    InvalidSubject(#[from] SpiffeIdError),

    /// The token header has no `kid`.
    #[error("token header 'kid' not found")]
    MissingKeyId,

    /// The token's `typ` header was rejected; the accepted values are `JWT` and `JOSE`.
    #[error("token header 'typ' should be 'JWT' or 'JOSE'")]
    InvalidTyp,

    /// The `alg` header names an algorithm outside the supported set.
    #[error("algorithm in 'alg' header is not supported")]
    UnsupportedAlgorithm,

    /// A required claim is missing; holds the claim's name.
    #[error("one of the required claims ({0}) is missing")]
    RequiredClaimMissing(String),

    /// No JWT bundle is available for the token's trust domain.
    #[error("cannot find JWT bundle for trust domain: {0}")]
    BundleNotFound(TrustDomain),

    /// The trust domain's bundle has no authority with the token's `kid`.
    #[error("cannot find JWT authority for key_id: {0}")]
    AuthorityNotFound(String),

    /// The audience check failed; holds the expected audiences and the token's audiences.
    #[error("expected audience in {0:?} (audience={1:?})")]
    InvalidAudience(Vec<String>, Vec<String>),

    /// The token could not be decoded, or failed a `jsonwebtoken` validation
    /// (for example signature, expiration or audience).
    #[error("cannot decode token")]
    InvalidToken(#[from] jsonwebtoken::errors::Error),

    /// Any other boxed error.
    #[error("error parsing JWT-SVID")]
    Other(#[from] Box<dyn Error + Send + Sync + 'static>),
}
94
95#[derive(Debug, Clone, Eq, PartialEq, Zeroize)]
96#[zeroize(drop)]
97struct Token {
98 inner: String,
99}
100
101impl From<&str> for Token {
102 fn from(token: &str) -> Self {
103 Self {
104 inner: token.to_owned(),
105 }
106 }
107}
108
109impl AsRef<str> for Token {
110 fn as_ref(&self) -> &str {
111 self.inner.as_ref()
112 }
113}
114
/// The JWT claims decoded from a JWT-SVID payload.
#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)]
pub struct Claims {
    /// Subject; parsed into the SVID's SPIFFE ID.
    sub: String,
    /// Audience. When deserializing, either a single string or an array of strings is
    /// accepted; a single string becomes a one-element vector.
    #[serde(deserialize_with = "string_or_seq_string")]
    aud: Vec<String>,
    /// Expiration time as seconds since the Unix epoch.
    // NOTE(review): a `u32` cannot hold expirations after early 2106 or fractional
    // NumericDate values, so such tokens fail to deserialize — confirm this is acceptable.
    exp: u32,
}
124
125impl Claims {
126 pub fn sub(&self) -> &str {
128 &self.sub
129 }
130
131 pub fn aud(&self) -> &Vec<String> {
133 &self.aud
134 }
135
136 pub fn exp(&self) -> u32 {
138 self.exp
139 }
140}
141
142impl JwtSvid {
143 pub fn parse_and_validate<T: AsRef<str> + ToString + std::fmt::Debug>(
159 token: &str,
160 bundle_source: &impl BundleRefSource<Item = JwtBundle>,
161 expected_audience: &[T],
162 ) -> Result<Self, JwtSvidError> {
163 let jwt_svid = JwtSvid::parse_insecure(token)?;
164
165 let jwt_authority = JwtSvid::find_jwt_authority(
166 bundle_source,
167 jwt_svid.spiffe_id.trust_domain(),
168 &jwt_svid.kid,
169 )?;
170
171 let mut validation = jsonwebtoken::Validation::new(jwt_svid.alg.to_owned());
172 validation.validate_exp = true;
173 validation.set_audience(expected_audience);
174 let dec_key = DecodingKey::from_jwk(jwt_authority)?;
175 jsonwebtoken::decode::<Claims>(token, &dec_key, &validation)?;
176 Ok(jwt_svid)
177 }
178
179 pub fn parse_insecure(token: &str) -> Result<Self, JwtSvidError> {
183 JwtSvid::from_str(token)
184 }
185
186 pub fn token(&self) -> &str {
188 self.token.as_ref()
189 }
190
191 pub fn spiffe_id(&self) -> &SpiffeId {
193 &self.spiffe_id
194 }
195
196 pub fn audience(&self) -> &Vec<String> {
198 &self.claims.aud
199 }
200
201 pub fn expiry(&self) -> &Date {
203 &self.expiry
204 }
205
206 pub fn key_id(&self) -> &str {
208 &self.kid
209 }
210
211 pub fn claims(&self) -> &Claims {
213 &self.claims
214 }
215
216 fn find_jwt_authority<'a>(
219 bundle_source: &'a impl BundleRefSource<Item = JwtBundle>,
220 trust_domain: &TrustDomain,
221 key_id: &str,
222 ) -> Result<&'a Jwk, JwtSvidError> {
223 let bundle = match bundle_source.get_bundle_for_trust_domain(trust_domain)? {
224 None => return Err(JwtSvidError::BundleNotFound(trust_domain.to_owned())),
225 Some(b) => b,
226 };
227
228 let jwt_authority = bundle
229 .find_jwt_authority(key_id)
230 .ok_or_else(|| JwtSvidError::AuthorityNotFound(key_id.to_owned()))?;
231
232 Ok(jwt_authority)
233 }
234}
235
236impl FromStr for JwtSvid {
237 type Err = JwtSvidError;
238
239 fn from_str(token: &str) -> Result<Self, Self::Err> {
244 let mut validation = Validation::default();
246 validation.validate_aud = false;
248 validation.insecure_disable_signature_validation();
249 let token_data =
250 jsonwebtoken::decode::<Claims>(token, &DecodingKey::from_secret(&[]), &validation)?;
251
252 let claims = token_data.claims;
253 let spiffe_id = SpiffeId::from_str(&claims.sub)?;
254
255 let expiry = OffsetDateTime::from_unix_timestamp(claims.exp as i64).unwrap();
256 let expiry = expiry.date();
257
258 let kid = match token_data.header.kid {
259 None => return Err(JwtSvidError::MissingKeyId),
260 Some(k) => k,
261 };
262
263 match token_data.header.typ {
264 None => return Err(JwtSvidError::InvalidTyp),
265 Some(t) => match t.as_str() {
266 "JWT" => {}
267 "JOSE" => {}
268 _ => return Err(JwtSvidError::InvalidTyp),
269 },
270 }
271
272 if !SUPPORTED_ALGORITHMS.contains(&token_data.header.alg) {
273 return Err(JwtSvidError::UnsupportedAlgorithm);
274 }
275
276 let alg = token_data.header.alg;
277
278 Ok(Self {
279 spiffe_id,
280 expiry,
281 claims,
282 kid,
283 alg,
284
285 token: Token::from(token),
286 })
287 }
288}
289
290fn string_or_seq_string<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
292where
293 D: Deserializer<'de>,
294{
295 struct StringOrVec(PhantomData<Vec<String>>);
296
297 impl<'de> de::Visitor<'de> for StringOrVec {
298 type Value = Vec<String>;
299
300 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
301 formatter.write_str("string or sequence of strings")
302 }
303
304 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
305 where
306 E: de::Error,
307 {
308 Ok(vec![value.to_owned()])
309 }
310
311 fn visit_seq<S>(self, visitor: S) -> Result<Self::Value, S::Error>
312 where
313 S: de::SeqAccess<'de>,
314 {
315 Deserialize::deserialize(de::value::SeqAccessDeserializer::new(visitor))
316 }
317 }
318
319 deserializer.deserialize_any(StringOrVec(PhantomData))
320}
321
#[cfg(test)]
mod test {
    use super::*;
    use crate::bundle::jwt::JwtBundleSet;
    use jsonwebtoken::*;

    /// Key ID under which the test signing key is registered.
    const TEST_KEY_ID: &str = "test-key-id";
    /// SPIFFE ID placed in the `sub` claim of the test tokens.
    const TEST_SPIFFE_ID: &str = "spiffe://example.org/service";
    /// Trust domain of `TEST_SPIFFE_ID`, used to register the test bundle.
    const TEST_TRUST_DOMAIN: &str = "example.org";
    /// Largest `exp` a `Claims` can hold, so tokens using it never expire during a test run.
    const FAR_FUTURE_EXP: u32 = u32::MAX;

    #[test]
    fn test_parse_and_validate_jwt_svid() {
        let (encoding_key, jwk) = generate_es256_key(TEST_KEY_ID);
        let bundle_source = bundle_set_with_authority(jwk);

        let target_audience = vec!["audience".to_owned()];
        let token = generate_token(
            target_audience.clone(),
            TEST_SPIFFE_ID.to_string(),
            Some("JWT".to_string()),
            Some(TEST_KEY_ID.to_string()),
            FAR_FUTURE_EXP,
            jsonwebtoken::Algorithm::ES256,
            &encoding_key,
        );

        let jwt_svid = JwtSvid::parse_and_validate(&token, &bundle_source, &["audience"]).unwrap();

        assert_eq!(jwt_svid.spiffe_id, SpiffeId::new(TEST_SPIFFE_ID).unwrap());
        assert_eq!(jwt_svid.audience(), &target_audience);
        assert_eq!(jwt_svid.token(), token);
    }

    #[test]
    fn test_parse_jwt_svid_with_unsupported_algorithm() {
        // HS256 is a symmetric algorithm and is not accepted for JWT-SVIDs.
        let token = generate_token(
            vec!["audience".to_owned()],
            TEST_SPIFFE_ID.to_string(),
            Some("JWT".to_string()),
            Some("some_key_id".to_string()),
            FAR_FUTURE_EXP,
            jsonwebtoken::Algorithm::HS256,
            &EncodingKey::from_secret(b"secret"),
        );

        let result = JwtSvid::parse_insecure(&token).unwrap_err();

        assert!(matches!(result, JwtSvidError::UnsupportedAlgorithm));
    }

    #[test]
    fn test_parse_invalid_jwt_svid_without_key_id() {
        let (encoding_key, _jwk) = generate_es256_key(TEST_KEY_ID);
        let token = generate_token(
            vec!["audience".to_owned()],
            TEST_SPIFFE_ID.to_string(),
            Some("JWT".to_string()),
            None,
            FAR_FUTURE_EXP,
            jsonwebtoken::Algorithm::ES256,
            &encoding_key,
        );

        let result = JwtSvid::parse_insecure(&token).unwrap_err();

        assert!(matches!(result, JwtSvidError::MissingKeyId));
    }

    #[test]
    fn test_parse_invalid_jwt_svid_with_invalid_header_typ() {
        let (encoding_key, _jwk) = generate_es256_key(TEST_KEY_ID);
        let token = generate_token(
            vec!["audience".to_owned()],
            TEST_SPIFFE_ID.to_string(),
            Some("OTHER".to_string()),
            Some("kid".to_string()),
            FAR_FUTURE_EXP,
            jsonwebtoken::Algorithm::ES256,
            &encoding_key,
        );

        let result = JwtSvid::parse_insecure(&token).unwrap_err();

        assert!(matches!(result, JwtSvidError::InvalidTyp));
    }

    #[test]
    fn test_parse_and_validate_jwt_svid_from_expired_token() {
        let (encoding_key, jwk) = generate_es256_key(TEST_KEY_ID);
        let bundle_source = bundle_set_with_authority(jwk);

        // `exp = 1` is one second after the Unix epoch, i.e. long expired.
        let token = generate_token(
            vec!["audience".to_owned()],
            TEST_SPIFFE_ID.to_string(),
            Some("JWT".to_string()),
            Some(TEST_KEY_ID.to_string()),
            1,
            jsonwebtoken::Algorithm::ES256,
            &encoding_key,
        );

        let result =
            JwtSvid::parse_and_validate(&token, &bundle_source, &["audience"]).unwrap_err();

        assert!(matches!(result, JwtSvidError::InvalidToken(..)));
    }

    /// Generates a fresh P-256 key and returns its signing key together with the matching
    /// JWK (algorithm `ES256`, key ID `key_id`).
    fn generate_es256_key(key_id: &str) -> (EncodingKey, Jwk) {
        let key = jsonwebkey::Key::generate_p256();
        let encoding_key = EncodingKey::from_ec_der(&key.to_der());

        let mut json_web_key = jsonwebkey::JsonWebKey::new(key);
        json_web_key
            .set_algorithm(jsonwebkey::Algorithm::ES256)
            .unwrap();
        json_web_key.key_id = Some(key_id.to_string());

        // Round-trip through JSON to convert the `jsonwebkey` JWK into a `jsonwebtoken` one.
        let serialized = serde_json::to_string(&json_web_key).expect("JWK should be serializable");
        let jwk = serde_json::from_str(&serialized).expect("JWK should be deserializable");
        (encoding_key, jwk)
    }

    /// Builds a bundle set holding one bundle for `TEST_TRUST_DOMAIN` whose only JWT
    /// authority is `jwk`.
    fn bundle_set_with_authority(jwk: Jwk) -> JwtBundleSet {
        let trust_domain = TrustDomain::new(TEST_TRUST_DOMAIN).unwrap();
        let mut bundle = JwtBundle::new(trust_domain);
        bundle.add_jwt_authority(jwk).unwrap();

        let mut bundle_set = JwtBundleSet::default();
        bundle_set.add_bundle(bundle);
        bundle_set
    }

    /// Signs a JWT carrying the given claims and `typ`/`kid`/`alg` headers with
    /// `encoding_key`.
    fn generate_token(
        aud: Vec<String>,
        sub: String,
        typ: Option<String>,
        kid: Option<String>,
        exp: u32,
        alg: jsonwebtoken::Algorithm,
        encoding_key: &EncodingKey,
    ) -> String {
        let claims = Claims { sub, aud, exp };

        // Start from the library defaults so this keeps compiling if `Header` gains fields;
        // only the headers under test are overridden.
        let header = Header {
            typ,
            kid,
            ..Header::new(alg)
        };
        encode(&header, &claims, encoding_key).unwrap()
    }
}