spiffe/svid/jwt/
mod.rs

//! Types for parsing and validating SPIFFE JWT-SVIDs.
2
3use std::str::FromStr;
4
5use jsonwebtoken::jwk::Jwk;
6use jsonwebtoken::{Algorithm, DecodingKey, Validation};
7use serde::{de, Deserialize, Deserializer, Serialize};
8use thiserror::Error;
9use zeroize::Zeroize;
10
11use crate::bundle::jwt::JwtBundle;
12use crate::bundle::BundleRefSource;
13use crate::spiffe_id::{SpiffeId, SpiffeIdError, TrustDomain};
14use crate::svid::Svid;
15use std::error::Error;
16use std::fmt;
17use std::marker::PhantomData;
18use time::{Date, OffsetDateTime};
19
20const SUPPORTED_ALGORITHMS: &[Algorithm; 8] = &[
21    Algorithm::RS256,
22    Algorithm::RS384,
23    Algorithm::RS512,
24    Algorithm::ES256,
25    Algorithm::ES384,
26    Algorithm::PS256,
27    Algorithm::PS384,
28    Algorithm::PS512,
29];
30
31/// This type represents a [SPIFFE JWT-SVID](https://github.com/spiffe/spiffe/blob/main/standards/JWT-SVID.md).
32///
33/// The token field is zeroized on drop.
34#[derive(Debug, Clone, PartialEq)]
35pub struct JwtSvid {
36    spiffe_id: SpiffeId,
37    expiry: Date,
38    // expiry: DateTime<Utc>,
39    claims: Claims,
40    kid: String,
41    alg: Algorithm,
42
43    token: Token,
44}
45
46impl Svid for JwtSvid {}
47
48/// An error that can arise trying to parse a [`JwtSvid`] from a JWT token. It also represents
49/// errors that can happen validating the token signature or the token audience.
50#[derive(Debug, Error)]
51#[non_exhaustive]
52pub enum JwtSvidError {
53    /// The 'sub' claim is not a valid SPIFFE ID.
54    #[error("invalid spiffe_id in token 'sub' claim")]
55    InvalidSubject(#[from] SpiffeIdError),
56
57    /// The header 'kid' is not present.
58    #[error("token header 'kid' not found")]
59    MissingKeyId,
60
61    /// The header 'typ' contains a value other than 'JWT' or 'JOSE'.
62    #[error("token header 'typ' should be 'JWT' or 'JOSE'")]
63    InvalidTyp,
64
65    /// The header 'alg' contains an algorithm that is not supported.
66    /// Supported algorithms are ['RS256', 'RS384', 'RS512', 'ES256', 'ES384', 'PS256', 'PS384', 'PS512'].
67    #[error("algorithm in 'alg' header is not supported")]
68    UnsupportedAlgorithm,
69
70    /// One of the required claims is missing. "aud", "sub" and "exp" must be present.
71    #[error("one of the required claims ({0}) is missing")]
72    RequiredClaimMissing(String),
73
74    /// Cannot find a JWT bundle for the trust domain, to validate the token signature.
75    #[error("cannot find JWT bundle for trust domain: {0}")]
76    BundleNotFound(TrustDomain),
77
78    /// Cannot find the JWT authority with key_id, to validate the token signature.
79    #[error("cannot find JWT authority for key_id: {0}")]
80    AuthorityNotFound(String),
81
82    /// The token doesn't have the expected audience.
83    #[error("expected audience in {0:?} (audience={1:?})")]
84    InvalidAudience(Vec<String>, Vec<String>),
85
86    /// Error returned by the JWT decoding library.
87    #[error("cannot decode token")]
88    InvalidToken(#[from] jsonwebtoken::errors::Error),
89
90    /// Other errors that can arise.
91    #[error("error parsing JWT-SVID")]
92    Other(#[from] Box<dyn Error + Send + Sync + 'static>),
93}
94
95#[derive(Debug, Clone, Eq, PartialEq, Zeroize)]
96#[zeroize(drop)]
97struct Token {
98    inner: String,
99}
100
101impl From<&str> for Token {
102    fn from(token: &str) -> Self {
103        Self {
104            inner: token.to_owned(),
105        }
106    }
107}
108
109impl AsRef<str> for Token {
110    fn as_ref(&self) -> &str {
111        self.inner.as_ref()
112    }
113}
114
115#[derive(Debug, Clone, Eq, PartialEq, Serialize, Deserialize)]
116/// Representation of the required
117/// [claims](https://github.com/spiffe/spiffe/blob/main/standards/JWT-SVID.md#3-jwt-claims) in a SPIFFE JWT-SVID.
118pub struct Claims {
119    sub: String,
120    #[serde(deserialize_with = "string_or_seq_string")]
121    aud: Vec<String>,
122    exp: u32,
123}
124
125impl Claims {
126    /// Get the sub claim.
127    pub fn get_sub(&self) -> &str {
128        &self.sub
129    }
130
131    /// Get the aud claim.
132    pub fn get_aud(&self) -> &Vec<String> {
133        &self.aud
134    }
135
136    /// Get the exp claim.
137    pub fn get_exp(&self) -> u32 {
138        self.exp
139    }
140}
141
142impl JwtSvid {
143    /// Parses the given token verifying the token signature using the provided [`BundleSource`] as
144    /// a source of [`JwtBundle`], validating the audience in the token with the expected audience,
145    /// and validating the expiration datetime.
146    ///
147    /// Returns a validated instance of `JwtSvid`.
148    ///
149    /// # Arguments
150    ///
151    /// * `token`: JWT token to parse.
152    /// * `bundle_source`: Struct that implements a [`BundleSource`] for the type [`JwtBundle`].
153    /// * `expected_audience`: List of audience strings that should be present in the token 'aud' claim.
154    ///
155    /// # Errors
156    ///
157    /// If the function cannot parse or verify the signature of the token, a [`JwtSvidError`] variant will be returned.
158    pub fn parse_and_validate<T: AsRef<str> + ToString + std::fmt::Debug>(
159        token: &str,
160        bundle_source: &impl BundleRefSource<Item = JwtBundle>,
161        expected_audience: &[T],
162    ) -> Result<Self, JwtSvidError> {
163        let jwt_svid = JwtSvid::parse_insecure(token)?;
164
165        let jwt_authority = JwtSvid::find_jwt_authority(
166            bundle_source,
167            jwt_svid.spiffe_id.trust_domain(),
168            &jwt_svid.kid,
169        )?;
170
171        let mut validation = jsonwebtoken::Validation::new(jwt_svid.alg.to_owned());
172        validation.validate_exp = true;
173        validation.set_audience(expected_audience);
174        let dec_key = DecodingKey::from_jwk(jwt_authority)?;
175        jsonwebtoken::decode::<Claims>(token, &dec_key, &validation)?;
176        Ok(jwt_svid)
177    }
178
179    /// Creates a new [`JwtSvid`] with the given token without signature verification.
180    ///
181    /// IMPORTANT: For parsing and validating the signature of untrusted tokens, use `parse_and_validate` method.
182    pub fn parse_insecure(token: &str) -> Result<Self, JwtSvidError> {
183        JwtSvid::from_str(token)
184    }
185
186    /// Returns the serialized JWT token.
187    pub fn token(&self) -> &str {
188        self.token.as_ref()
189    }
190
191    /// Returns the SPIFFE ID ('aud' claim) of the token.
192    pub fn spiffe_id(&self) -> &SpiffeId {
193        &self.spiffe_id
194    }
195
196    /// Returns the audience as present in the 'aud' claim.
197    pub fn audience(&self) -> &Vec<String> {
198        &self.claims.aud
199    }
200
201    /// Returns the expiration date of the JWT token.
202    pub fn expiry(&self) -> &Date {
203        &self.expiry
204    }
205
206    /// Returns the key id header of the JWT token.
207    pub fn key_id(&self) -> &str {
208        &self.kid
209    }
210
211    // Get the bundle associated to the trust_domain in the bundle_source, then from the bundle
212    // return the jwt_authority with the key_id
213    fn find_jwt_authority<'a>(
214        bundle_source: &'a impl BundleRefSource<Item = JwtBundle>,
215        trust_domain: &TrustDomain,
216        key_id: &str,
217    ) -> Result<&'a Jwk, JwtSvidError> {
218        let bundle = match bundle_source.get_bundle_for_trust_domain(trust_domain)? {
219            None => return Err(JwtSvidError::BundleNotFound(trust_domain.to_owned())),
220            Some(b) => b,
221        };
222
223        let jwt_authority = bundle
224            .find_jwt_authority(key_id)
225            .ok_or_else(|| JwtSvidError::AuthorityNotFound(key_id.to_owned()))?;
226
227        Ok(jwt_authority)
228    }
229}
230
231impl FromStr for JwtSvid {
232    type Err = JwtSvidError;
233
234    /// Creates a new [`JwtSvid`] with the given token without signature verification.
235    /// Any result from this function is untrusted.
236    ///
237    /// IMPORTANT: For parsing and validating the signature of untrusted tokens, use `parse_and_validate` method.
238    fn from_str(token: &str) -> Result<Self, Self::Err> {
239        // decode token without signature or expiration validation
240        let mut validation = Validation::default();
241        // We later on validate audience separately with `parse_and_validate`
242        validation.validate_aud = false;
243        validation.insecure_disable_signature_validation();
244        let token_data =
245            jsonwebtoken::decode::<Claims>(token, &DecodingKey::from_secret(&[]), &validation)?;
246
247        let claims = token_data.claims;
248        let spiffe_id = SpiffeId::from_str(&claims.sub)?;
249
250        let expiry = OffsetDateTime::from_unix_timestamp(claims.exp as i64).unwrap();
251        let expiry = expiry.date();
252
253        let kid = match token_data.header.kid {
254            None => return Err(JwtSvidError::MissingKeyId),
255            Some(k) => k,
256        };
257
258        match token_data.header.typ {
259            None => return Err(JwtSvidError::InvalidTyp),
260            Some(t) => match t.as_str() {
261                "JWT" => {}
262                "JOSE" => {}
263                _ => return Err(JwtSvidError::InvalidTyp),
264            },
265        }
266
267        if !SUPPORTED_ALGORITHMS.contains(&token_data.header.alg) {
268            return Err(JwtSvidError::UnsupportedAlgorithm);
269        }
270
271        let alg = token_data.header.alg;
272
273        Ok(Self {
274            spiffe_id,
275            expiry,
276            claims,
277            kid,
278            alg,
279
280            token: Token::from(token),
281        })
282    }
283}
284
285// Used to deserialize 'aud' claim being either a String or a sequence of strings.
286fn string_or_seq_string<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
287where
288    D: Deserializer<'de>,
289{
290    struct StringOrVec(PhantomData<Vec<String>>);
291
292    impl<'de> de::Visitor<'de> for StringOrVec {
293        type Value = Vec<String>;
294
295        fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
296            formatter.write_str("string or sequence of strings")
297        }
298
299        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
300        where
301            E: de::Error,
302        {
303            Ok(vec![value.to_owned()])
304        }
305
306        fn visit_seq<S>(self, visitor: S) -> Result<Self::Value, S::Error>
307        where
308            S: de::SeqAccess<'de>,
309        {
310            Deserialize::deserialize(de::value::SeqAccessDeserializer::new(visitor))
311        }
312    }
313
314    deserializer.deserialize_any(StringOrVec(PhantomData))
315}
316
#[cfg(test)]
mod test {
    use super::*;
    use crate::bundle::jwt::JwtBundleSet;
    use jsonwebtoken::*;

    const TEST_KEY_ID: &str = "test-key-id";
    const TEST_SUBJECT: &str = "spiffe://example.org/service";
    /// Largest value the `exp` claim can hold, so the token is effectively never expired.
    const FAR_FUTURE_EXP: u32 = u32::MAX;

    #[test]
    fn test_parse_and_validate_jwt_svid() {
        let (encoding_key, jwk_json) = generate_es256_key(TEST_KEY_ID);

        let target_audience = vec!["audience".to_owned()];
        // generate signed token
        let token = generate_token(
            target_audience.clone(),
            TEST_SUBJECT.to_string(),
            Some("JWT".to_string()),
            Some(TEST_KEY_ID.to_string()),
            FAR_FUTURE_EXP,
            jsonwebtoken::Algorithm::ES256,
            &encoding_key,
        );

        // create a new source of JWT bundles trusting the signing key
        let bundle_source = bundle_source_with(&jwk_json);

        // parse and validate JWT-SVID from signed token using the bundle source to validate the signature
        let jwt_svid = JwtSvid::parse_and_validate(&token, &bundle_source, &["audience"]).unwrap();

        assert_eq!(jwt_svid.spiffe_id, SpiffeId::new(TEST_SUBJECT).unwrap());
        assert_eq!(jwt_svid.audience(), &target_audience);
        assert_eq!(jwt_svid.token(), token);
    }

    #[test]
    fn test_parse_jwt_svid_with_unsupported_algorithm() {
        // `Algorithm::default()` is HS256, which is not in SUPPORTED_ALGORITHMS.
        let token = generate_token(
            vec!["audience".to_owned()],
            TEST_SUBJECT.to_string(),
            Some("JWT".to_string()),
            Some("some_key_id".to_string()),
            FAR_FUTURE_EXP,
            jsonwebtoken::Algorithm::default(),
            &EncodingKey::from_secret("secret".as_ref()),
        );

        let result = JwtSvid::parse_insecure(&token).unwrap_err();

        assert!(matches!(result, JwtSvidError::UnsupportedAlgorithm));
    }

    #[test]
    fn test_parse_invalid_jwt_svid_without_key_id() {
        let (encoding_key, _) = generate_es256_key(TEST_KEY_ID);

        // generate signed token without a 'kid' header
        let token = generate_token(
            vec!["audience".to_owned()],
            TEST_SUBJECT.to_string(),
            Some("JWT".to_string()),
            None,
            FAR_FUTURE_EXP,
            jsonwebtoken::Algorithm::ES256,
            &encoding_key,
        );

        let result = JwtSvid::parse_insecure(&token).unwrap_err();

        assert!(matches!(result, JwtSvidError::MissingKeyId));
    }

    #[test]
    fn test_parse_invalid_jwt_svid_with_invalid_header_typ() {
        let (encoding_key, _) = generate_es256_key(TEST_KEY_ID);

        // generate signed token whose 'typ' header is neither 'JWT' nor 'JOSE'
        let token = generate_token(
            vec!["audience".to_owned()],
            TEST_SUBJECT.to_string(),
            Some("OTHER".to_string()),
            Some("kid".to_string()),
            FAR_FUTURE_EXP,
            jsonwebtoken::Algorithm::ES256,
            &encoding_key,
        );

        // parse JWT-SVID from token without validating
        let result = JwtSvid::parse_insecure(&token).unwrap_err();

        assert!(matches!(result, JwtSvidError::InvalidTyp));
    }

    #[test]
    fn test_parse_and_validate_jwt_svid_from_expired_token() {
        let (encoding_key, jwk_json) = generate_es256_key(TEST_KEY_ID);

        // generate signed token that expired one second after the Unix epoch
        let token = generate_token(
            vec!["audience".to_owned()],
            TEST_SUBJECT.to_string(),
            Some("JWT".to_string()),
            Some(TEST_KEY_ID.to_string()),
            1,
            jsonwebtoken::Algorithm::ES256,
            &encoding_key,
        );

        // create a new source of JWT bundles trusting the signing key
        let bundle_source = bundle_source_with(&jwk_json);

        // parse and validate JWT-SVID from signed token using the bundle source to validate the signature
        let result =
            JwtSvid::parse_and_validate(&token, &bundle_source, &["audience"]).unwrap_err();

        assert!(matches!(result, JwtSvidError::InvalidToken(..)));
    }

    /// Generates a fresh P-256 key pair for ES256 and returns the key used to sign test tokens,
    /// together with the JSON form of the matching JWK tagged with `key_id`.
    fn generate_es256_key(key_id: &str) -> (EncodingKey, String) {
        let key = jsonwebkey::Key::generate_p256();
        let encoding_key = EncodingKey::from_ec_der(&key.to_der());

        let mut jwk = jsonwebkey::JsonWebKey::new(key);
        jwk.set_algorithm(jsonwebkey::Algorithm::ES256).unwrap();
        jwk.key_id = Some(key_id.to_string());

        let jwk_json = serde_json::to_string(&jwk).expect("JWK should be serializable");
        (encoding_key, jwk_json)
    }

    /// Builds a bundle source holding a single `example.org` JWT bundle whose only authority
    /// is the JWK serialized in `jwk_json`.
    fn bundle_source_with(jwk_json: &str) -> JwtBundleSet {
        let jwk = serde_json::from_str(jwk_json).expect("JWK should be deserializable");

        let trust_domain = TrustDomain::new("example.org").unwrap();
        let mut bundle = JwtBundle::new(trust_domain);
        bundle.add_jwt_authority(jwk).unwrap();

        let mut bundle_source = JwtBundleSet::default();
        bundle_source.add_bundle(bundle);
        bundle_source
    }

    /// Builds and signs a JWT for the tests. `typ` and `kid` are written to the header as given
    /// (`None` omits them); every other header field keeps its `Header::new` default.
    fn generate_token(
        aud: Vec<String>,
        sub: String,
        typ: Option<String>,
        kid: Option<String>,
        exp: u32,
        alg: jsonwebtoken::Algorithm,
        encoding_key: &EncodingKey,
    ) -> String {
        let claims = Claims { sub, aud, exp };

        let header = jsonwebtoken::Header {
            typ,
            kid,
            ..jsonwebtoken::Header::new(alg)
        };
        encode(&header, &claims, encoding_key).unwrap()
    }
}