der_derive/
sequence.rs

1//! Support for deriving the `Sequence` trait on structs for the purposes of
2//! decoding/encoding ASN.1 `SEQUENCE` types as mapped to struct fields.
3
4mod field;
5
6use crate::{default_lifetime, TypeAttrs};
7use field::SequenceField;
8use proc_macro2::TokenStream;
9use quote::quote;
10use syn::{DeriveInput, GenericParam, Generics, Ident, LifetimeParam};
11
12/// Derive the `Sequence` trait for a struct
13pub(crate) struct DeriveSequence {
14    /// Name of the sequence struct.
15    ident: Ident,
16
17    /// Generics of the struct.
18    generics: Generics,
19
20    /// Fields of the struct.
21    fields: Vec<SequenceField>,
22}
23
24impl DeriveSequence {
25    /// Parse [`DeriveInput`].
26    pub fn new(input: DeriveInput) -> syn::Result<Self> {
27        let data = match input.data {
28            syn::Data::Struct(data) => data,
29            _ => abort!(
30                input.ident,
31                "can't derive `Sequence` on this type: only `struct` types are allowed",
32            ),
33        };
34
35        let type_attrs = TypeAttrs::parse(&input.attrs)?;
36
37        let fields = data
38            .fields
39            .iter()
40            .map(|field| SequenceField::new(field, &type_attrs))
41            .collect::<syn::Result<_>>()?;
42
43        Ok(Self {
44            ident: input.ident,
45            generics: input.generics.clone(),
46            fields,
47        })
48    }
49
50    /// Lower the derived output into a [`TokenStream`].
51    pub fn to_tokens(&self) -> TokenStream {
52        let ident = &self.ident;
53        let mut generics = self.generics.clone();
54
55        // Use the first lifetime parameter as lifetime for Decode/Encode lifetime
56        // if none found, add one.
57        let lifetime = generics
58            .lifetimes()
59            .next()
60            .map(|lt| lt.lifetime.clone())
61            .unwrap_or_else(|| {
62                let lt = default_lifetime();
63                generics
64                    .params
65                    .insert(0, GenericParam::Lifetime(LifetimeParam::new(lt.clone())));
66                lt
67            });
68
69        // We may or may not have inserted a lifetime.
70        let (_, ty_generics, where_clause) = self.generics.split_for_impl();
71        let (impl_generics, _, _) = generics.split_for_impl();
72
73        let mut decode_body = Vec::new();
74        let mut decode_result = Vec::new();
75        let mut encoded_lengths = Vec::new();
76        let mut encode_fields = Vec::new();
77
78        for field in &self.fields {
79            decode_body.push(field.to_decode_tokens());
80            decode_result.push(&field.ident);
81
82            let field = field.to_encode_tokens();
83            encoded_lengths.push(quote!(#field.encoded_len()?));
84            encode_fields.push(quote!(#field.encode(writer)?;));
85        }
86
87        quote! {
88            impl #impl_generics ::der::DecodeValue<#lifetime> for #ident #ty_generics #where_clause {
89                fn decode_value<R: ::der::Reader<#lifetime>>(
90                    reader: &mut R,
91                    header: ::der::Header,
92                ) -> ::der::Result<Self> {
93                    use ::der::{Decode as _, DecodeValue as _, Reader as _};
94
95                    reader.read_nested(header.length, |reader| {
96                        #(#decode_body)*
97
98                        Ok(Self {
99                            #(#decode_result),*
100                        })
101                    })
102                }
103            }
104
105            impl #impl_generics ::der::EncodeValue for #ident #ty_generics #where_clause {
106                fn value_len(&self) -> ::der::Result<::der::Length> {
107                    use ::der::Encode as _;
108
109                    [
110                        #(#encoded_lengths),*
111                    ]
112                        .into_iter()
113                        .try_fold(::der::Length::ZERO, |acc, len| acc + len)
114                }
115
116                fn encode_value(&self, writer: &mut impl ::der::Writer) -> ::der::Result<()> {
117                    use ::der::Encode as _;
118                    #(#encode_fields)*
119                    Ok(())
120                }
121            }
122
123            impl #impl_generics ::der::Sequence<#lifetime> for #ident #ty_generics #where_clause {}
124        }
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::DeriveSequence;
    use crate::{Asn1Type, TagMode};
    use syn::parse_quote;

    /// Name of the first lifetime parameter declared on the parsed struct.
    ///
    /// Panics when the struct declares no lifetime parameter.
    fn first_lifetime(ir: &DeriveSequence) -> String {
        ir.generics.lifetimes().next().unwrap().lifetime.to_string()
    }

    /// X.509 SPKI `AlgorithmIdentifier`.
    #[test]
    fn algorithm_identifier_example() {
        let ir = DeriveSequence::new(parse_quote! {
            #[derive(Sequence)]
            pub struct AlgorithmIdentifier<'a> {
                pub algorithm: ObjectIdentifier,
                pub parameters: Option<Any<'a>>,
            }
        })
        .unwrap();

        assert_eq!(ir.ident, "AlgorithmIdentifier");
        assert_eq!(first_lifetime(&ir), "'a");
        assert_eq!(ir.fields.len(), 2);

        // Neither field carries an `asn1` attribute, so the same expectations apply to both.
        for (field, name) in ir.fields.iter().zip(["algorithm", "parameters"]) {
            assert_eq!(field.ident, name);
            assert_eq!(field.attrs.asn1_type, None);
            assert_eq!(field.attrs.context_specific, None);
            assert_eq!(field.attrs.tag_mode, TagMode::Explicit);
        }
    }

    /// X.509 `SubjectPublicKeyInfo`.
    #[test]
    fn spki_example() {
        let ir = DeriveSequence::new(parse_quote! {
            #[derive(Sequence)]
            pub struct SubjectPublicKeyInfo<'a> {
                pub algorithm: AlgorithmIdentifier<'a>,

                #[asn1(type = "BIT STRING")]
                pub subject_public_key: &'a [u8],
            }
        })
        .unwrap();

        assert_eq!(ir.ident, "SubjectPublicKeyInfo");
        assert_eq!(first_lifetime(&ir), "'a");
        assert_eq!(ir.fields.len(), 2);

        // (field name, expected ASN.1 type override) in declaration order.
        let expected = [
            ("algorithm", None),
            ("subject_public_key", Some(Asn1Type::BitString)),
        ];

        for (field, (name, asn1_type)) in ir.fields.iter().zip(expected) {
            assert_eq!(field.ident, name);
            assert_eq!(field.attrs.asn1_type, asn1_type);
            assert_eq!(field.attrs.context_specific, None);
            assert_eq!(field.attrs.tag_mode, TagMode::Explicit);
        }
    }

    /// PKCS#8v2 `OneAsymmetricKey`.
    ///
    /// ```text
    /// OneAsymmetricKey ::= SEQUENCE {
    ///     version                   Version,
    ///     privateKeyAlgorithm       PrivateKeyAlgorithmIdentifier,
    ///     privateKey                PrivateKey,
    ///     attributes            [0] Attributes OPTIONAL,
    ///     ...,
    ///     [[2: publicKey        [1] PublicKey OPTIONAL ]],
    ///     ...
    ///   }
    ///
    /// Version ::= INTEGER { v1(0), v2(1) } (v1, ..., v2)
    ///
    /// PrivateKeyAlgorithmIdentifier ::= AlgorithmIdentifier
    ///
    /// PrivateKey ::= OCTET STRING
    ///
    /// Attributes ::= SET OF Attribute
    ///
    /// PublicKey ::= BIT STRING
    /// ```
    #[test]
    fn pkcs8_example() {
        let ir = DeriveSequence::new(parse_quote! {
            #[derive(Sequence)]
            pub struct OneAsymmetricKey<'a> {
                pub version: u8,
                pub private_key_algorithm: AlgorithmIdentifier<'a>,
                #[asn1(type = "OCTET STRING")]
                pub private_key: &'a [u8],
                #[asn1(context_specific = "0", extensible = "true", optional = "true")]
                pub attributes: Option<SetOf<Any<'a>, 1>>,
                #[asn1(
                    context_specific = "1",
                    extensible = "true",
                    optional = "true",
                    type = "BIT STRING"
                )]
                pub public_key: Option<&'a [u8]>,
            }
        })
        .unwrap();

        assert_eq!(ir.ident, "OneAsymmetricKey");
        assert_eq!(first_lifetime(&ir), "'a");
        assert_eq!(ir.fields.len(), 5);

        // The three leading fields are untagged and mandatory:
        // (field name, expected ASN.1 type override).
        let mandatory = [
            ("version", None),
            ("private_key_algorithm", None),
            ("private_key", Some(Asn1Type::OctetString)),
        ];

        for (field, (name, asn1_type)) in ir.fields[..3].iter().zip(mandatory) {
            assert_eq!(field.ident, name);
            assert_eq!(field.attrs.asn1_type, asn1_type);
            assert_eq!(field.attrs.context_specific, None);
            assert!(!field.attrs.extensible);
            assert!(!field.attrs.optional);
            assert_eq!(field.attrs.tag_mode, TagMode::Explicit);
        }

        // The two trailing fields are context-specific, extensible and optional:
        // (field name, expected ASN.1 type override, context-specific tag number).
        let tagged = [
            ("attributes", None, "0"),
            ("public_key", Some(Asn1Type::BitString), "1"),
        ];

        for (field, (name, asn1_type, number)) in ir.fields[3..].iter().zip(tagged) {
            assert_eq!(field.ident, name);
            assert_eq!(field.attrs.asn1_type, asn1_type);
            assert_eq!(
                field.attrs.context_specific,
                Some(number.parse().unwrap())
            );
            assert!(field.attrs.extensible);
            assert!(field.attrs.optional);
            assert_eq!(field.attrs.tag_mode, TagMode::Explicit);
        }
    }

    /// `IMPLICIT` tagged example
    #[test]
    fn implicit_example() {
        let ir = DeriveSequence::new(parse_quote! {
            #[asn1(tag_mode = "IMPLICIT")]
            pub struct ImplicitSequence<'a> {
                #[asn1(context_specific = "0", type = "BIT STRING")]
                bit_string: BitString<'a>,

                #[asn1(context_specific = "1", type = "GeneralizedTime")]
                time: GeneralizedTime,

                #[asn1(context_specific = "2", type = "UTF8String")]
                utf8_string: String,
            }
        })
        .unwrap();

        assert_eq!(ir.ident, "ImplicitSequence");
        assert_eq!(first_lifetime(&ir), "'a");
        assert_eq!(ir.fields.len(), 3);

        // (field name, expected ASN.1 type override, context-specific tag number).
        let expected = [
            ("bit_string", Asn1Type::BitString, "0"),
            ("time", Asn1Type::GeneralizedTime, "1"),
            ("utf8_string", Asn1Type::Utf8String, "2"),
        ];

        for (field, (name, asn1_type, number)) in ir.fields.iter().zip(expected) {
            assert_eq!(field.ident, name);
            assert_eq!(field.attrs.asn1_type, Some(asn1_type));
            assert_eq!(
                field.attrs.context_specific,
                Some(number.parse().unwrap())
            );
            // The type-level `tag_mode` attribute applies to every field.
            assert_eq!(field.attrs.tag_mode, TagMode::Implicit);
        }
    }
}