der_derive/
choice.rs

//! Support for deriving the `Choice`, `Decode`, `EncodeValue`, and `Tagged`
//! traits on enums for the purposes of decoding/encoding ASN.1 `CHOICE` types
//! as mapped to enum variants.
4
5mod variant;
6
7use self::variant::ChoiceVariant;
8use crate::{default_lifetime, TypeAttrs};
9use proc_macro2::TokenStream;
10use quote::quote;
11use syn::{DeriveInput, Ident, Lifetime};
12
13/// Derive the `Choice` trait for an enum.
14pub(crate) struct DeriveChoice {
15    /// Name of the enum type.
16    ident: Ident,
17
18    /// Lifetime of the type.
19    lifetime: Option<Lifetime>,
20
21    /// Variants of this `Choice`.
22    variants: Vec<ChoiceVariant>,
23}
24
25impl DeriveChoice {
26    /// Parse [`DeriveInput`].
27    pub fn new(input: DeriveInput) -> syn::Result<Self> {
28        let data = match input.data {
29            syn::Data::Enum(data) => data,
30            _ => abort!(
31                input.ident,
32                "can't derive `Choice` on this type: only `enum` types are allowed",
33            ),
34        };
35
36        // TODO(tarcieri): properly handle multiple lifetimes
37        let lifetime = input
38            .generics
39            .lifetimes()
40            .next()
41            .map(|lt| lt.lifetime.clone());
42
43        let type_attrs = TypeAttrs::parse(&input.attrs)?;
44        let variants = data
45            .variants
46            .iter()
47            .map(|variant| ChoiceVariant::new(variant, &type_attrs))
48            .collect::<syn::Result<_>>()?;
49
50        Ok(Self {
51            ident: input.ident,
52            lifetime,
53            variants,
54        })
55    }
56
57    /// Lower the derived output into a [`TokenStream`].
58    pub fn to_tokens(&self) -> TokenStream {
59        let ident = &self.ident;
60
61        let lifetime = match self.lifetime {
62            Some(ref lifetime) => quote!(#lifetime),
63            None => {
64                let lifetime = default_lifetime();
65                quote!(#lifetime)
66            }
67        };
68
69        // Lifetime parameters
70        // TODO(tarcieri): support multiple lifetimes
71        let lt_params = self
72            .lifetime
73            .as_ref()
74            .map(|_| lifetime.clone())
75            .unwrap_or_default();
76
77        let mut can_decode_body = Vec::new();
78        let mut decode_body = Vec::new();
79        let mut encode_body = Vec::new();
80        let mut value_len_body = Vec::new();
81        let mut tagged_body = Vec::new();
82
83        for variant in &self.variants {
84            can_decode_body.push(variant.tag.to_tokens());
85            decode_body.push(variant.to_decode_tokens());
86            encode_body.push(variant.to_encode_value_tokens());
87            value_len_body.push(variant.to_value_len_tokens());
88            tagged_body.push(variant.to_tagged_tokens());
89        }
90
91        quote! {
92            impl<#lifetime> ::der::Choice<#lifetime> for #ident<#lt_params> {
93                fn can_decode(tag: ::der::Tag) -> bool {
94                    matches!(tag, #(#can_decode_body)|*)
95                }
96            }
97
98            impl<#lifetime> ::der::Decode<#lifetime> for #ident<#lt_params> {
99                fn decode<R: ::der::Reader<#lifetime>>(reader: &mut R) -> ::der::Result<Self> {
100                    use der::Reader as _;
101                    match reader.peek_tag()? {
102                        #(#decode_body)*
103                        actual => Err(der::ErrorKind::TagUnexpected {
104                            expected: None,
105                            actual
106                        }
107                        .into()),
108                    }
109                }
110            }
111
112            impl<#lt_params> ::der::EncodeValue for #ident<#lt_params> {
113                fn encode_value(&self, encoder: &mut impl ::der::Writer) -> ::der::Result<()> {
114                    match self {
115                        #(#encode_body)*
116                    }
117                }
118
119                fn value_len(&self) -> ::der::Result<::der::Length> {
120                    match self {
121                        #(#value_len_body)*
122                    }
123                }
124            }
125
126            impl<#lt_params> ::der::Tagged for #ident<#lt_params> {
127                fn tag(&self) -> ::der::Tag {
128                    match self {
129                        #(#tagged_body)*
130                    }
131                }
132            }
133        }
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::{ChoiceVariant, DeriveChoice};
    use crate::{Asn1Type, Tag, TagMode};
    use syn::parse_quote;

    /// Assert that `variant` is an `IMPLICIT`, context-specific alternative
    /// named `name`, declared with `asn1_type`, whose computed tag is the
    /// primitive context-specific tag `[number]`.
    fn check_implicit_context_specific(
        variant: &ChoiceVariant,
        name: &str,
        asn1_type: Asn1Type,
        number: &str,
    ) {
        assert_eq!(variant.ident, name);
        assert_eq!(variant.attrs.asn1_type, Some(asn1_type));
        assert_eq!(
            variant.attrs.context_specific,
            Some(number.parse().unwrap())
        );
        assert_eq!(variant.attrs.tag_mode, TagMode::Implicit);
        assert_eq!(
            variant.tag,
            Tag::ContextSpecific {
                constructed: false,
                number: number.parse().unwrap(),
            }
        );
    }

    /// Based on `Time` as defined in RFC 5280:
    /// <https://tools.ietf.org/html/rfc5280#page-117>
    ///
    /// ```text
    /// Time ::= CHOICE {
    ///      utcTime        UTCTime,
    ///      generalTime    GeneralizedTime }
    /// ```
    #[test]
    fn time_example() {
        let input = parse_quote! {
            pub enum Time {
                #[asn1(type = "UTCTime")]
                UtcTime(UtcTime),

                #[asn1(type = "GeneralizedTime")]
                GeneralTime(GeneralizedTime),
            }
        };

        let derived = DeriveChoice::new(input).unwrap();
        assert_eq!(derived.ident, "Time");
        assert_eq!(derived.lifetime, None);
        assert_eq!(derived.variants.len(), 2);

        let (utc, generalized) = (&derived.variants[0], &derived.variants[1]);

        // Untagged alternatives use their universal tag in EXPLICIT mode.
        assert_eq!(utc.ident, "UtcTime");
        assert_eq!(utc.attrs.asn1_type, Some(Asn1Type::UtcTime));
        assert_eq!(utc.attrs.context_specific, None);
        assert_eq!(utc.attrs.tag_mode, TagMode::Explicit);
        assert_eq!(utc.tag, Tag::Universal(Asn1Type::UtcTime));

        assert_eq!(generalized.ident, "GeneralTime");
        assert_eq!(
            generalized.attrs.asn1_type,
            Some(Asn1Type::GeneralizedTime)
        );
        assert_eq!(generalized.attrs.context_specific, None);
        assert_eq!(generalized.attrs.tag_mode, TagMode::Explicit);
        assert_eq!(generalized.tag, Tag::Universal(Asn1Type::GeneralizedTime));
    }

    /// `IMPLICIT` tagged example
    #[test]
    fn implicit_example() {
        let input = parse_quote! {
            #[asn1(tag_mode = "IMPLICIT")]
            pub enum ImplicitChoice<'a> {
                #[asn1(context_specific = "0", type = "BIT STRING")]
                BitString(BitString<'a>),

                #[asn1(context_specific = "1", type = "GeneralizedTime")]
                Time(GeneralizedTime),

                #[asn1(context_specific = "2", type = "UTF8String")]
                Utf8String(String),
            }
        };

        let derived = DeriveChoice::new(input).unwrap();
        assert_eq!(derived.ident, "ImplicitChoice");
        assert_eq!(
            derived.lifetime.as_ref().map(|lt| lt.to_string()).as_deref(),
            Some("'a")
        );
        assert_eq!(derived.variants.len(), 3);

        let variants = &derived.variants;
        check_implicit_context_specific(&variants[0], "BitString", Asn1Type::BitString, "0");
        check_implicit_context_specific(&variants[1], "Time", Asn1Type::GeneralizedTime, "1");
        check_implicit_context_specific(&variants[2], "Utf8String", Asn1Type::Utf8String, "2");
    }
}