der_derive/
enumerated.rs

//! Support for deriving the `Decode` and `Encode` traits on enums for
//! the purposes of decoding/encoding ASN.1 `ENUMERATED` types as mapped to
//! enum variants. With `#[asn1(type = "INTEGER")]` on the enum, values are
//! instead tagged as ASN.1 `INTEGER`.
4
5use crate::attributes::AttrNameValue;
6use crate::{default_lifetime, ATTR_NAME};
7use proc_macro2::TokenStream;
8use quote::quote;
9use syn::{DeriveInput, Expr, ExprLit, Ident, Lit, LitInt, Variant};
10
/// Primitive integer types accepted by the `#[repr(...)]` attribute of an `Enumerated` enum.
///
/// The list is also echoed verbatim (via `{:?}`) in the diagnostics emitted when the
/// `#[repr]` attribute is missing or names some other type.
const REPR_TYPES: &[&str] = &[
    "u8",  // 1 byte
    "u16", // 2 bytes
    "u32", // 4 bytes
];
13
14/// Derive the `Enumerated` trait for an enum.
15pub(crate) struct DeriveEnumerated {
16    /// Name of the enum type.
17    ident: Ident,
18
19    /// Value of the `repr` attribute.
20    repr: Ident,
21
22    /// Whether or not to tag the enum as an integer
23    integer: bool,
24
25    /// Variants of this enum.
26    variants: Vec<EnumeratedVariant>,
27}
28
29impl DeriveEnumerated {
30    /// Parse [`DeriveInput`].
31    pub fn new(input: DeriveInput) -> syn::Result<Self> {
32        let data = match input.data {
33            syn::Data::Enum(data) => data,
34            _ => abort!(
35                input.ident,
36                "can't derive `Enumerated` on this type: only `enum` types are allowed",
37            ),
38        };
39
40        // Reject `asn1` attributes, parse the `repr` attribute
41        let mut repr: Option<Ident> = None;
42        let mut integer = false;
43
44        for attr in &input.attrs {
45            if attr.path().is_ident(ATTR_NAME) {
46                let kvs = match AttrNameValue::parse_attribute(attr) {
47                    Ok(kvs) => kvs,
48                    Err(e) => abort!(attr, e),
49                };
50                for anv in kvs {
51                    if anv.name.is_ident("type") {
52                        match anv.value.value().as_str() {
53                            "ENUMERATED" => integer = false,
54                            "INTEGER" => integer = true,
55                            s => abort!(anv.value, format_args!("`type = \"{s}\"` is unsupported")),
56                        }
57                    }
58                }
59            } else if attr.path().is_ident("repr") {
60                if repr.is_some() {
61                    abort!(
62                        attr,
63                        "multiple `#[repr]` attributes encountered on `Enumerated`",
64                    );
65                }
66
67                let r = attr.parse_args::<Ident>().map_err(|_| {
68                    syn::Error::new_spanned(attr, "error parsing `#[repr]` attribute")
69                })?;
70
71                // Validate
72                if !REPR_TYPES.contains(&r.to_string().as_str()) {
73                    abort!(
74                        attr,
75                        format_args!("invalid `#[repr]` type: allowed types are {REPR_TYPES:?}"),
76                    );
77                }
78
79                repr = Some(r);
80            }
81        }
82
83        // Parse enum variants
84        let variants = data
85            .variants
86            .iter()
87            .map(EnumeratedVariant::new)
88            .collect::<syn::Result<_>>()?;
89
90        Ok(Self {
91            ident: input.ident.clone(),
92            repr: repr.ok_or_else(|| {
93                syn::Error::new_spanned(
94                    &input.ident,
95                    format_args!("no `#[repr]` attribute on enum: must be one of {REPR_TYPES:?}"),
96                )
97            })?,
98            variants,
99            integer,
100        })
101    }
102
103    /// Lower the derived output into a [`TokenStream`].
104    pub fn to_tokens(&self) -> TokenStream {
105        let default_lifetime = default_lifetime();
106        let ident = &self.ident;
107        let repr = &self.repr;
108        let tag = match self.integer {
109            false => quote! { ::der::Tag::Enumerated },
110            true => quote! { ::der::Tag::Integer },
111        };
112
113        let mut try_from_body = Vec::new();
114        for variant in &self.variants {
115            try_from_body.push(variant.to_try_from_tokens());
116        }
117
118        quote! {
119            impl<#default_lifetime> ::der::DecodeValue<#default_lifetime> for #ident {
120                fn decode_value<R: ::der::Reader<#default_lifetime>>(
121                    reader: &mut R,
122                    header: ::der::Header
123                ) -> ::der::Result<Self> {
124                    <#repr as ::der::DecodeValue>::decode_value(reader, header)?.try_into()
125                }
126            }
127
128            impl ::der::EncodeValue for #ident {
129                fn value_len(&self) -> ::der::Result<::der::Length> {
130                    ::der::EncodeValue::value_len(&(*self as #repr))
131                }
132
133                fn encode_value(&self, encoder: &mut impl ::der::Writer) -> ::der::Result<()> {
134                    ::der::EncodeValue::encode_value(&(*self as #repr), encoder)
135                }
136            }
137
138            impl ::der::FixedTag for #ident {
139                const TAG: ::der::Tag = #tag;
140            }
141
142            impl TryFrom<#repr> for #ident {
143                type Error = ::der::Error;
144
145                fn try_from(n: #repr) -> ::der::Result<Self> {
146                    match n {
147                        #(#try_from_body)*
148                        _ => Err(#tag.value_error())
149                    }
150                }
151            }
152        }
153    }
154}
155
156/// "IR" for a variant of a derived `Enumerated`.
157pub struct EnumeratedVariant {
158    /// Variant name.
159    ident: Ident,
160
161    /// Integer value that this variant corresponds to.
162    discriminant: LitInt,
163}
164
165impl EnumeratedVariant {
166    /// Create a new [`ChoiceVariant`] from the input [`Variant`].
167    fn new(input: &Variant) -> syn::Result<Self> {
168        for attr in &input.attrs {
169            if attr.path().is_ident(ATTR_NAME) {
170                abort!(
171                    attr,
172                    "`asn1` attribute is not allowed on fields of `Enumerated` types"
173                );
174            }
175        }
176
177        match &input.discriminant {
178            Some((
179                _,
180                Expr::Lit(ExprLit {
181                    lit: Lit::Int(discriminant),
182                    ..
183                }),
184            )) => Ok(Self {
185                ident: input.ident.clone(),
186                discriminant: discriminant.clone(),
187            }),
188            Some((_, other)) => abort!(other, "invalid discriminant for `Enumerated`"),
189            None => abort!(input, "`Enumerated` variant has no discriminant"),
190        }
191    }
192
193    /// Write the body for the derived [`TryFrom`] impl.
194    pub fn to_try_from_tokens(&self) -> TokenStream {
195        let ident = &self.ident;
196        let discriminant = &self.discriminant;
197        quote! {
198            #discriminant => Ok(Self::#ident),
199        }
200    }
201}
202
#[cfg(test)]
mod tests {
    use super::DeriveEnumerated;
    use syn::parse_quote;

    /// Run [`DeriveEnumerated::new`] on `input`, expecting it to fail, and return the error text.
    ///
    /// Uses `.err().expect(..)` rather than `unwrap_err`, because `DeriveEnumerated` does not
    /// implement `Debug`.
    fn parse_err(input: syn::DeriveInput) -> String {
        DeriveEnumerated::new(input)
            .err()
            .expect("expected `DeriveEnumerated::new` to reject the input")
            .to_string()
    }

    /// X.509 `CRLReason`.
    #[test]
    fn crlreason_example() {
        let input = parse_quote! {
            #[repr(u32)]
            pub enum CrlReason {
                Unspecified = 0,
                KeyCompromise = 1,
                CaCompromise = 2,
                AffiliationChanged = 3,
                Superseded = 4,
                CessationOfOperation = 5,
                CertificateHold = 6,
                RemoveFromCrl = 8,
                PrivilegeWithdrawn = 9,
                AaCompromised = 10,
            }
        };

        let ir = DeriveEnumerated::new(input).unwrap();
        assert_eq!(ir.ident, "CrlReason");
        assert_eq!(ir.repr, "u32");
        assert!(!ir.integer, "`ENUMERATED` is the default tag");
        assert_eq!(ir.variants.len(), 10);

        let unspecified = &ir.variants[0];
        assert_eq!(unspecified.ident, "Unspecified");
        assert_eq!(unspecified.discriminant.to_string(), "0");

        let key_compromise = &ir.variants[1];
        assert_eq!(key_compromise.ident, "KeyCompromise");
        assert_eq!(key_compromise.discriminant.to_string(), "1");

        let ca_compromise = &ir.variants[2];
        assert_eq!(ca_compromise.ident, "CaCompromise");
        assert_eq!(ca_compromise.discriminant.to_string(), "2");

        // Gaps in the discriminant sequence (7 is unassigned) are kept as written.
        let remove_from_crl = &ir.variants[7];
        assert_eq!(remove_from_crl.ident, "RemoveFromCrl");
        assert_eq!(remove_from_crl.discriminant.to_string(), "8");
    }

    /// `#[asn1(type = "INTEGER")]` switches the tag from `ENUMERATED` to `INTEGER`.
    #[test]
    fn integer_type_attribute() {
        let input = parse_quote! {
            #[asn1(type = "INTEGER")]
            #[repr(u8)]
            pub enum Version {
                V1 = 0,
                V2 = 1,
                V3 = 2,
            }
        };

        let ir = DeriveEnumerated::new(input).unwrap();
        assert!(ir.integer);
        assert_eq!(ir.repr, "u8");
        assert_eq!(ir.variants.len(), 3);
    }

    #[test]
    fn rejects_non_enum() {
        let err = parse_err(parse_quote! {
            #[repr(u8)]
            pub struct NotAnEnum;
        });
        assert!(err.contains("only `enum` types are allowed"), "{err}");
    }

    #[test]
    fn rejects_missing_repr() {
        let err = parse_err(parse_quote! {
            pub enum NoRepr {
                A = 0,
            }
        });
        assert!(err.contains("no `#[repr]` attribute"), "{err}");
    }

    #[test]
    fn rejects_invalid_repr() {
        let err = parse_err(parse_quote! {
            #[repr(u64)]
            pub enum Wide {
                A = 0,
            }
        });
        assert!(err.contains("invalid `#[repr]` type"), "{err}");
    }

    #[test]
    fn rejects_duplicate_repr() {
        let err = parse_err(parse_quote! {
            #[repr(u8)]
            #[repr(u16)]
            pub enum Twice {
                A = 0,
            }
        });
        assert!(err.contains("multiple `#[repr]` attributes"), "{err}");
    }

    #[test]
    fn rejects_unsupported_type() {
        let err = parse_err(parse_quote! {
            #[asn1(type = "BOOLEAN")]
            #[repr(u8)]
            pub enum Flag {
                A = 0,
            }
        });
        assert!(err.contains("is unsupported"), "{err}");
    }

    #[test]
    fn rejects_missing_discriminant() {
        let err = parse_err(parse_quote! {
            #[repr(u8)]
            pub enum Implicit {
                A,
            }
        });
        assert!(err.contains("has no discriminant"), "{err}");
    }
}