1mod variant;
6
7use self::variant::ChoiceVariant;
8use crate::{default_lifetime, TypeAttrs};
9use proc_macro2::TokenStream;
10use quote::quote;
11use syn::{DeriveInput, Ident, Lifetime};
12
13pub(crate) struct DeriveChoice {
15 ident: Ident,
17
18 lifetime: Option<Lifetime>,
20
21 variants: Vec<ChoiceVariant>,
23}
24
25impl DeriveChoice {
26 pub fn new(input: DeriveInput) -> syn::Result<Self> {
28 let data = match input.data {
29 syn::Data::Enum(data) => data,
30 _ => abort!(
31 input.ident,
32 "can't derive `Choice` on this type: only `enum` types are allowed",
33 ),
34 };
35
36 let lifetime = input
38 .generics
39 .lifetimes()
40 .next()
41 .map(|lt| lt.lifetime.clone());
42
43 let type_attrs = TypeAttrs::parse(&input.attrs)?;
44 let variants = data
45 .variants
46 .iter()
47 .map(|variant| ChoiceVariant::new(variant, &type_attrs))
48 .collect::<syn::Result<_>>()?;
49
50 Ok(Self {
51 ident: input.ident,
52 lifetime,
53 variants,
54 })
55 }
56
57 pub fn to_tokens(&self) -> TokenStream {
59 let ident = &self.ident;
60
61 let lifetime = match self.lifetime {
62 Some(ref lifetime) => quote!(#lifetime),
63 None => {
64 let lifetime = default_lifetime();
65 quote!(#lifetime)
66 }
67 };
68
69 let lt_params = self
72 .lifetime
73 .as_ref()
74 .map(|_| lifetime.clone())
75 .unwrap_or_default();
76
77 let mut can_decode_body = Vec::new();
78 let mut decode_body = Vec::new();
79 let mut encode_body = Vec::new();
80 let mut value_len_body = Vec::new();
81 let mut tagged_body = Vec::new();
82
83 for variant in &self.variants {
84 can_decode_body.push(variant.tag.to_tokens());
85 decode_body.push(variant.to_decode_tokens());
86 encode_body.push(variant.to_encode_value_tokens());
87 value_len_body.push(variant.to_value_len_tokens());
88 tagged_body.push(variant.to_tagged_tokens());
89 }
90
91 quote! {
92 impl<#lifetime> ::der::Choice<#lifetime> for #ident<#lt_params> {
93 fn can_decode(tag: ::der::Tag) -> bool {
94 matches!(tag, #(#can_decode_body)|*)
95 }
96 }
97
98 impl<#lifetime> ::der::Decode<#lifetime> for #ident<#lt_params> {
99 fn decode<R: ::der::Reader<#lifetime>>(reader: &mut R) -> ::der::Result<Self> {
100 use der::Reader as _;
101 match reader.peek_tag()? {
102 #(#decode_body)*
103 actual => Err(der::ErrorKind::TagUnexpected {
104 expected: None,
105 actual
106 }
107 .into()),
108 }
109 }
110 }
111
112 impl<#lt_params> ::der::EncodeValue for #ident<#lt_params> {
113 fn encode_value(&self, encoder: &mut impl ::der::Writer) -> ::der::Result<()> {
114 match self {
115 #(#encode_body)*
116 }
117 }
118
119 fn value_len(&self) -> ::der::Result<::der::Length> {
120 match self {
121 #(#value_len_body)*
122 }
123 }
124 }
125
126 impl<#lt_params> ::der::Tagged for #ident<#lt_params> {
127 fn tag(&self) -> ::der::Tag {
128 match self {
129 #(#tagged_body)*
130 }
131 }
132 }
133 }
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::DeriveChoice;
    use crate::{Asn1Type, Tag, TagMode};
    use syn::parse_quote;

    /// Expected `Tag::ContextSpecific` value (with `constructed: false`) for
    /// the given decimal tag number.
    fn context_specific_tag(number: &str) -> Tag {
        Tag::ContextSpecific {
            constructed: false,
            number: number.parse().unwrap(),
        }
    }

    /// Two universally-tagged alternatives on an `enum` without a lifetime.
    #[test]
    fn time_example() {
        let derived = DeriveChoice::new(parse_quote! {
            pub enum Time {
                #[asn1(type = "UTCTime")]
                UtcTime(UtcTime),

                #[asn1(type = "GeneralizedTime")]
                GeneralTime(GeneralizedTime),
            }
        })
        .unwrap();

        assert_eq!(derived.ident, "Time");
        assert!(derived.lifetime.is_none());

        let (utc_time, general_time) = match derived.variants.as_slice() {
            [first, second] => (first, second),
            other => panic!("expected 2 variants, got {}", other.len()),
        };

        // No `tag_mode` attribute: each alternative is EXPLICIT and carries
        // the universal tag of its ASN.1 type.
        assert_eq!(utc_time.ident, "UtcTime");
        assert_eq!(utc_time.attrs.asn1_type, Some(Asn1Type::UtcTime));
        assert_eq!(utc_time.attrs.context_specific, None);
        assert_eq!(utc_time.attrs.tag_mode, TagMode::Explicit);
        assert_eq!(utc_time.tag, Tag::Universal(Asn1Type::UtcTime));

        assert_eq!(general_time.ident, "GeneralTime");
        assert_eq!(
            general_time.attrs.asn1_type,
            Some(Asn1Type::GeneralizedTime)
        );
        assert_eq!(general_time.attrs.context_specific, None);
        assert_eq!(general_time.attrs.tag_mode, TagMode::Explicit);
        assert_eq!(general_time.tag, Tag::Universal(Asn1Type::GeneralizedTime));
    }

    /// IMPLICIT context-specific alternatives on an `enum` with a lifetime.
    #[test]
    fn implicit_example() {
        let derived = DeriveChoice::new(parse_quote! {
            #[asn1(tag_mode = "IMPLICIT")]
            pub enum ImplicitChoice<'a> {
                #[asn1(context_specific = "0", type = "BIT STRING")]
                BitString(BitString<'a>),

                #[asn1(context_specific = "1", type = "GeneralizedTime")]
                Time(GeneralizedTime),

                #[asn1(context_specific = "2", type = "UTF8String")]
                Utf8String(String),
            }
        })
        .unwrap();

        assert_eq!(derived.ident, "ImplicitChoice");
        assert_eq!(
            derived.lifetime.as_ref().map(|lt| lt.to_string()).as_deref(),
            Some("'a")
        );

        let (bit_string, time, utf8_string) = match derived.variants.as_slice() {
            [first, second, third] => (first, second, third),
            other => panic!("expected 3 variants, got {}", other.len()),
        };

        // The type-level `tag_mode = "IMPLICIT"` is inherited by every variant.
        assert_eq!(bit_string.ident, "BitString");
        assert_eq!(bit_string.attrs.asn1_type, Some(Asn1Type::BitString));
        assert_eq!(
            bit_string.attrs.context_specific,
            Some("0".parse().unwrap())
        );
        assert_eq!(bit_string.attrs.tag_mode, TagMode::Implicit);
        assert_eq!(bit_string.tag, context_specific_tag("0"));

        assert_eq!(time.ident, "Time");
        assert_eq!(time.attrs.asn1_type, Some(Asn1Type::GeneralizedTime));
        assert_eq!(time.attrs.context_specific, Some("1".parse().unwrap()));
        assert_eq!(time.attrs.tag_mode, TagMode::Implicit);
        assert_eq!(time.tag, context_specific_tag("1"));

        assert_eq!(utf8_string.ident, "Utf8String");
        assert_eq!(utf8_string.attrs.asn1_type, Some(Asn1Type::Utf8String));
        assert_eq!(
            utf8_string.attrs.context_specific,
            Some("2".parse().unwrap())
        );
        assert_eq!(utf8_string.attrs.tag_mode, TagMode::Implicit);
        assert_eq!(utf8_string.tag, context_specific_tag("2"));
    }
}