1mod field;
5
6use crate::{default_lifetime, TypeAttrs};
7use field::SequenceField;
8use proc_macro2::TokenStream;
9use quote::quote;
10use syn::{DeriveInput, GenericParam, Generics, Ident, LifetimeParam};
11
/// Intermediate representation of a `struct` that `#[derive(Sequence)]` is
/// applied to, built by `DeriveSequence::new` and lowered to impls by
/// `DeriveSequence::to_tokens`.
pub(crate) struct DeriveSequence {
    /// Name of the struct the derive is applied to.
    ident: Ident,

    /// Generic parameters and `where` clause of the struct, as written.
    generics: Generics,

    /// Parsed fields of the struct, in declaration order.
    fields: Vec<SequenceField>,
}
23
24impl DeriveSequence {
25 pub fn new(input: DeriveInput) -> syn::Result<Self> {
27 let data = match input.data {
28 syn::Data::Struct(data) => data,
29 _ => abort!(
30 input.ident,
31 "can't derive `Sequence` on this type: only `struct` types are allowed",
32 ),
33 };
34
35 let type_attrs = TypeAttrs::parse(&input.attrs)?;
36
37 let fields = data
38 .fields
39 .iter()
40 .map(|field| SequenceField::new(field, &type_attrs))
41 .collect::<syn::Result<_>>()?;
42
43 Ok(Self {
44 ident: input.ident,
45 generics: input.generics.clone(),
46 fields,
47 })
48 }
49
50 pub fn to_tokens(&self) -> TokenStream {
52 let ident = &self.ident;
53 let mut generics = self.generics.clone();
54
55 let lifetime = generics
58 .lifetimes()
59 .next()
60 .map(|lt| lt.lifetime.clone())
61 .unwrap_or_else(|| {
62 let lt = default_lifetime();
63 generics
64 .params
65 .insert(0, GenericParam::Lifetime(LifetimeParam::new(lt.clone())));
66 lt
67 });
68
69 let (_, ty_generics, where_clause) = self.generics.split_for_impl();
71 let (impl_generics, _, _) = generics.split_for_impl();
72
73 let mut decode_body = Vec::new();
74 let mut decode_result = Vec::new();
75 let mut encoded_lengths = Vec::new();
76 let mut encode_fields = Vec::new();
77
78 for field in &self.fields {
79 decode_body.push(field.to_decode_tokens());
80 decode_result.push(&field.ident);
81
82 let field = field.to_encode_tokens();
83 encoded_lengths.push(quote!(#field.encoded_len()?));
84 encode_fields.push(quote!(#field.encode(writer)?;));
85 }
86
87 quote! {
88 impl #impl_generics ::der::DecodeValue<#lifetime> for #ident #ty_generics #where_clause {
89 fn decode_value<R: ::der::Reader<#lifetime>>(
90 reader: &mut R,
91 header: ::der::Header,
92 ) -> ::der::Result<Self> {
93 use ::der::{Decode as _, DecodeValue as _, Reader as _};
94
95 reader.read_nested(header.length, |reader| {
96 #(#decode_body)*
97
98 Ok(Self {
99 #(#decode_result),*
100 })
101 })
102 }
103 }
104
105 impl #impl_generics ::der::EncodeValue for #ident #ty_generics #where_clause {
106 fn value_len(&self) -> ::der::Result<::der::Length> {
107 use ::der::Encode as _;
108
109 [
110 #(#encoded_lengths),*
111 ]
112 .into_iter()
113 .try_fold(::der::Length::ZERO, |acc, len| acc + len)
114 }
115
116 fn encode_value(&self, writer: &mut impl ::der::Writer) -> ::der::Result<()> {
117 use ::der::Encode as _;
118 #(#encode_fields)*
119 Ok(())
120 }
121 }
122
123 impl #impl_generics ::der::Sequence<#lifetime> for #ident #ty_generics #where_clause {}
124 }
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::DeriveSequence;
    use crate::{Asn1Type, TagMode};
    use syn::parse_quote;

    /// Two fields without `#[asn1]` attributes.
    ///
    /// Neither field gets an ASN.1 type override or a context-specific tag, and
    /// both use `TagMode::Explicit`. The struct's `'a` lifetime is kept in the
    /// IR generics.
    #[test]
    fn algorithm_identifier_example() {
        let input = parse_quote! {
            #[derive(Sequence)]
            pub struct AlgorithmIdentifier<'a> {
                pub algorithm: ObjectIdentifier,
                pub parameters: Option<Any<'a>>,
            }
        };

        let ir = DeriveSequence::new(input).unwrap();
        assert_eq!(ir.ident, "AlgorithmIdentifier");
        assert_eq!(
            ir.generics.lifetimes().next().unwrap().lifetime.to_string(),
            "'a"
        );
        assert_eq!(ir.fields.len(), 2);

        let algorithm_field = &ir.fields[0];
        assert_eq!(algorithm_field.ident, "algorithm");
        assert_eq!(algorithm_field.attrs.asn1_type, None);
        assert_eq!(algorithm_field.attrs.context_specific, None);
        assert_eq!(algorithm_field.attrs.tag_mode, TagMode::Explicit);

        let parameters_field = &ir.fields[1];
        assert_eq!(parameters_field.ident, "parameters");
        assert_eq!(parameters_field.attrs.asn1_type, None);
        assert_eq!(parameters_field.attrs.context_specific, None);
        assert_eq!(parameters_field.attrs.tag_mode, TagMode::Explicit);
    }

    /// A field-level `#[asn1(type = "BIT STRING")]` sets
    /// `Asn1Type::BitString` on that field only. The undecorated field keeps
    /// the defaults.
    #[test]
    fn spki_example() {
        let input = parse_quote! {
            #[derive(Sequence)]
            pub struct SubjectPublicKeyInfo<'a> {
                pub algorithm: AlgorithmIdentifier<'a>,

                #[asn1(type = "BIT STRING")]
                pub subject_public_key: &'a [u8],
            }
        };

        let ir = DeriveSequence::new(input).unwrap();
        assert_eq!(ir.ident, "SubjectPublicKeyInfo");
        assert_eq!(
            ir.generics.lifetimes().next().unwrap().lifetime.to_string(),
            "'a"
        );
        assert_eq!(ir.fields.len(), 2);

        let algorithm_field = &ir.fields[0];
        assert_eq!(algorithm_field.ident, "algorithm");
        assert_eq!(algorithm_field.attrs.asn1_type, None);
        assert_eq!(algorithm_field.attrs.context_specific, None);
        assert_eq!(algorithm_field.attrs.tag_mode, TagMode::Explicit);

        let subject_public_key_field = &ir.fields[1];
        assert_eq!(subject_public_key_field.ident, "subject_public_key");
        assert_eq!(
            subject_public_key_field.attrs.asn1_type,
            Some(Asn1Type::BitString)
        );
        assert_eq!(subject_public_key_field.attrs.context_specific, None);
        assert_eq!(subject_public_key_field.attrs.tag_mode, TagMode::Explicit);
    }

    /// Per-field `context_specific`, `extensible`, and `optional` attributes
    /// are parsed, including a multi-line `#[asn1(...)]` list combined with
    /// `type`.
    ///
    /// Fields without these attributes default to `None`/`false`.
    #[test]
    fn pkcs8_example() {
        let input = parse_quote! {
            #[derive(Sequence)]
            pub struct OneAsymmetricKey<'a> {
                pub version: u8,
                pub private_key_algorithm: AlgorithmIdentifier<'a>,
                #[asn1(type = "OCTET STRING")]
                pub private_key: &'a [u8],
                #[asn1(context_specific = "0", extensible = "true", optional = "true")]
                pub attributes: Option<SetOf<Any<'a>, 1>>,
                #[asn1(
                    context_specific = "1",
                    extensible = "true",
                    optional = "true",
                    type = "BIT STRING"
                )]
                pub public_key: Option<&'a [u8]>,
            }
        };

        let ir = DeriveSequence::new(input).unwrap();
        assert_eq!(ir.ident, "OneAsymmetricKey");
        assert_eq!(
            ir.generics.lifetimes().next().unwrap().lifetime.to_string(),
            "'a"
        );
        assert_eq!(ir.fields.len(), 5);

        let version_field = &ir.fields[0];
        assert_eq!(version_field.ident, "version");
        assert_eq!(version_field.attrs.asn1_type, None);
        assert_eq!(version_field.attrs.context_specific, None);
        assert!(!version_field.attrs.extensible);
        assert!(!version_field.attrs.optional);
        assert_eq!(version_field.attrs.tag_mode, TagMode::Explicit);

        let algorithm_field = &ir.fields[1];
        assert_eq!(algorithm_field.ident, "private_key_algorithm");
        assert_eq!(algorithm_field.attrs.asn1_type, None);
        assert_eq!(algorithm_field.attrs.context_specific, None);
        assert!(!algorithm_field.attrs.extensible);
        assert!(!algorithm_field.attrs.optional);
        assert_eq!(algorithm_field.attrs.tag_mode, TagMode::Explicit);

        let private_key_field = &ir.fields[2];
        assert_eq!(private_key_field.ident, "private_key");
        assert_eq!(
            private_key_field.attrs.asn1_type,
            Some(Asn1Type::OctetString)
        );
        assert_eq!(private_key_field.attrs.context_specific, None);
        assert!(!private_key_field.attrs.extensible);
        assert!(!private_key_field.attrs.optional);
        assert_eq!(private_key_field.attrs.tag_mode, TagMode::Explicit);

        let attributes_field = &ir.fields[3];
        assert_eq!(attributes_field.ident, "attributes");
        assert_eq!(attributes_field.attrs.asn1_type, None);
        assert_eq!(
            attributes_field.attrs.context_specific,
            Some("0".parse().unwrap())
        );
        assert!(attributes_field.attrs.extensible);
        assert!(attributes_field.attrs.optional);
        assert_eq!(attributes_field.attrs.tag_mode, TagMode::Explicit);

        let public_key_field = &ir.fields[4];
        assert_eq!(public_key_field.ident, "public_key");
        assert_eq!(public_key_field.attrs.asn1_type, Some(Asn1Type::BitString));
        assert_eq!(
            public_key_field.attrs.context_specific,
            Some("1".parse().unwrap())
        );
        assert!(public_key_field.attrs.extensible);
        assert!(public_key_field.attrs.optional);
        assert_eq!(public_key_field.attrs.tag_mode, TagMode::Explicit);
    }

    /// A type-level `#[asn1(tag_mode = "IMPLICIT")]` shows up as
    /// `TagMode::Implicit` on every field, alongside each field's own
    /// `context_specific` and `type` attributes.
    #[test]
    fn implicit_example() {
        let input = parse_quote! {
            #[asn1(tag_mode = "IMPLICIT")]
            pub struct ImplicitSequence<'a> {
                #[asn1(context_specific = "0", type = "BIT STRING")]
                bit_string: BitString<'a>,

                #[asn1(context_specific = "1", type = "GeneralizedTime")]
                time: GeneralizedTime,

                #[asn1(context_specific = "2", type = "UTF8String")]
                utf8_string: String,
            }
        };

        let ir = DeriveSequence::new(input).unwrap();
        assert_eq!(ir.ident, "ImplicitSequence");
        assert_eq!(
            ir.generics.lifetimes().next().unwrap().lifetime.to_string(),
            "'a"
        );
        assert_eq!(ir.fields.len(), 3);

        let bit_string = &ir.fields[0];
        assert_eq!(bit_string.ident, "bit_string");
        assert_eq!(bit_string.attrs.asn1_type, Some(Asn1Type::BitString));
        assert_eq!(
            bit_string.attrs.context_specific,
            Some("0".parse().unwrap())
        );
        assert_eq!(bit_string.attrs.tag_mode, TagMode::Implicit);

        let time = &ir.fields[1];
        assert_eq!(time.ident, "time");
        assert_eq!(time.attrs.asn1_type, Some(Asn1Type::GeneralizedTime));
        assert_eq!(time.attrs.context_specific, Some("1".parse().unwrap()));
        assert_eq!(time.attrs.tag_mode, TagMode::Implicit);

        let utf8_string = &ir.fields[2];
        assert_eq!(utf8_string.ident, "utf8_string");
        assert_eq!(utf8_string.attrs.asn1_type, Some(Asn1Type::Utf8String));
        assert_eq!(
            utf8_string.attrs.context_specific,
            Some("2".parse().unwrap())
        );
        assert_eq!(utf8_string.attrs.tag_mode, TagMode::Implicit);
    }
}