1use crate::attributes::AttrNameValue;
6use crate::{default_lifetime, ATTR_NAME};
7use proc_macro2::TokenStream;
8use quote::quote;
9use syn::{DeriveInput, Expr, ExprLit, Ident, Lit, LitInt, Variant};
10
11const REPR_TYPES: &[&str] = &["u8", "u16", "u32"];
13
14pub(crate) struct DeriveEnumerated {
16 ident: Ident,
18
19 repr: Ident,
21
22 integer: bool,
24
25 variants: Vec<EnumeratedVariant>,
27}
28
29impl DeriveEnumerated {
30 pub fn new(input: DeriveInput) -> syn::Result<Self> {
32 let data = match input.data {
33 syn::Data::Enum(data) => data,
34 _ => abort!(
35 input.ident,
36 "can't derive `Enumerated` on this type: only `enum` types are allowed",
37 ),
38 };
39
40 let mut repr: Option<Ident> = None;
42 let mut integer = false;
43
44 for attr in &input.attrs {
45 if attr.path().is_ident(ATTR_NAME) {
46 let kvs = match AttrNameValue::parse_attribute(attr) {
47 Ok(kvs) => kvs,
48 Err(e) => abort!(attr, e),
49 };
50 for anv in kvs {
51 if anv.name.is_ident("type") {
52 match anv.value.value().as_str() {
53 "ENUMERATED" => integer = false,
54 "INTEGER" => integer = true,
55 s => abort!(anv.value, format_args!("`type = \"{s}\"` is unsupported")),
56 }
57 }
58 }
59 } else if attr.path().is_ident("repr") {
60 if repr.is_some() {
61 abort!(
62 attr,
63 "multiple `#[repr]` attributes encountered on `Enumerated`",
64 );
65 }
66
67 let r = attr.parse_args::<Ident>().map_err(|_| {
68 syn::Error::new_spanned(attr, "error parsing `#[repr]` attribute")
69 })?;
70
71 if !REPR_TYPES.contains(&r.to_string().as_str()) {
73 abort!(
74 attr,
75 format_args!("invalid `#[repr]` type: allowed types are {REPR_TYPES:?}"),
76 );
77 }
78
79 repr = Some(r);
80 }
81 }
82
83 let variants = data
85 .variants
86 .iter()
87 .map(EnumeratedVariant::new)
88 .collect::<syn::Result<_>>()?;
89
90 Ok(Self {
91 ident: input.ident.clone(),
92 repr: repr.ok_or_else(|| {
93 syn::Error::new_spanned(
94 &input.ident,
95 format_args!("no `#[repr]` attribute on enum: must be one of {REPR_TYPES:?}"),
96 )
97 })?,
98 variants,
99 integer,
100 })
101 }
102
103 pub fn to_tokens(&self) -> TokenStream {
105 let default_lifetime = default_lifetime();
106 let ident = &self.ident;
107 let repr = &self.repr;
108 let tag = match self.integer {
109 false => quote! { ::der::Tag::Enumerated },
110 true => quote! { ::der::Tag::Integer },
111 };
112
113 let mut try_from_body = Vec::new();
114 for variant in &self.variants {
115 try_from_body.push(variant.to_try_from_tokens());
116 }
117
118 quote! {
119 impl<#default_lifetime> ::der::DecodeValue<#default_lifetime> for #ident {
120 fn decode_value<R: ::der::Reader<#default_lifetime>>(
121 reader: &mut R,
122 header: ::der::Header
123 ) -> ::der::Result<Self> {
124 <#repr as ::der::DecodeValue>::decode_value(reader, header)?.try_into()
125 }
126 }
127
128 impl ::der::EncodeValue for #ident {
129 fn value_len(&self) -> ::der::Result<::der::Length> {
130 ::der::EncodeValue::value_len(&(*self as #repr))
131 }
132
133 fn encode_value(&self, encoder: &mut impl ::der::Writer) -> ::der::Result<()> {
134 ::der::EncodeValue::encode_value(&(*self as #repr), encoder)
135 }
136 }
137
138 impl ::der::FixedTag for #ident {
139 const TAG: ::der::Tag = #tag;
140 }
141
142 impl TryFrom<#repr> for #ident {
143 type Error = ::der::Error;
144
145 fn try_from(n: #repr) -> ::der::Result<Self> {
146 match n {
147 #(#try_from_body)*
148 _ => Err(#tag.value_error())
149 }
150 }
151 }
152 }
153 }
154}
155
156pub struct EnumeratedVariant {
158 ident: Ident,
160
161 discriminant: LitInt,
163}
164
165impl EnumeratedVariant {
166 fn new(input: &Variant) -> syn::Result<Self> {
168 for attr in &input.attrs {
169 if attr.path().is_ident(ATTR_NAME) {
170 abort!(
171 attr,
172 "`asn1` attribute is not allowed on fields of `Enumerated` types"
173 );
174 }
175 }
176
177 match &input.discriminant {
178 Some((
179 _,
180 Expr::Lit(ExprLit {
181 lit: Lit::Int(discriminant),
182 ..
183 }),
184 )) => Ok(Self {
185 ident: input.ident.clone(),
186 discriminant: discriminant.clone(),
187 }),
188 Some((_, other)) => abort!(other, "invalid discriminant for `Enumerated`"),
189 None => abort!(input, "`Enumerated` variant has no discriminant"),
190 }
191 }
192
193 pub fn to_try_from_tokens(&self) -> TokenStream {
195 let ident = &self.ident;
196 let discriminant = &self.discriminant;
197 quote! {
198 #discriminant => Ok(Self::#ident),
199 }
200 }
201}
202
#[cfg(test)]
mod tests {
    use super::DeriveEnumerated;
    use syn::parse_quote;

    /// Mirrors the X.509 `CRLReason` enumeration, including the gap at discriminant 7.
    #[test]
    fn crlreason_example() {
        let input = parse_quote! {
            #[repr(u32)]
            pub enum CrlReason {
                Unspecified = 0,
                KeyCompromise = 1,
                CaCompromise = 2,
                AffiliationChanged = 3,
                Superseded = 4,
                CessationOfOperation = 5,
                CertificateHold = 6,
                RemoveFromCrl = 8,
                PrivilegeWithdrawn = 9,
                AaCompromised = 10,
            }
        };

        let ir = DeriveEnumerated::new(input).unwrap();
        assert_eq!(ir.ident, "CrlReason");
        assert_eq!(ir.repr, "u32");
        assert!(!ir.integer);
        assert_eq!(ir.variants.len(), 10);

        let unspecified = &ir.variants[0];
        assert_eq!(unspecified.ident, "Unspecified");
        assert_eq!(unspecified.discriminant.to_string(), "0");

        let key_compromise = &ir.variants[1];
        assert_eq!(key_compromise.ident, "KeyCompromise");
        assert_eq!(key_compromise.discriminant.to_string(), "1");

        let ca_compromise = &ir.variants[2];
        assert_eq!(ca_compromise.ident, "CaCompromise");
        assert_eq!(ca_compromise.discriminant.to_string(), "2");

        let remove_from_crl = &ir.variants[7];
        assert_eq!(remove_from_crl.ident, "RemoveFromCrl");
        assert_eq!(remove_from_crl.discriminant.to_string(), "8");
    }

    /// `type = "INTEGER"` switches the encoded tag from `ENUMERATED` to `INTEGER`.
    #[test]
    fn integer_type_attribute() {
        let input = parse_quote! {
            #[repr(u8)]
            #[asn1(type = "INTEGER")]
            pub enum Version {
                V1 = 0,
                V2 = 1,
            }
        };

        let ir = DeriveEnumerated::new(input).unwrap();
        assert!(ir.integer);
        assert_eq!(ir.repr, "u8");
        assert_eq!(ir.variants.len(), 2);
    }

    /// An enum without `#[repr(...)]` is rejected with an error, not a panic.
    #[test]
    fn missing_repr_is_rejected() {
        let input = parse_quote! {
            pub enum NoRepr {
                A = 0,
            }
        };

        assert!(DeriveEnumerated::new(input).is_err());
    }
}