der_derive/
enumerated.rs
use crate::attributes::AttrNameValue;
use crate::{default_lifetime, ATTR_NAME};
use proc_macro2::TokenStream;
use quote::quote;
use syn::{DeriveInput, Expr, ExprLit, Ident, Lit, LitInt, Variant};
/// Integer types permitted in the `#[repr(...)]` attribute of an `Enumerated` enum.
const REPR_TYPES: &'static [&'static str] = &["u8", "u16", "u32"];

/// Parsed form of an `enum` that derives `Enumerated`.
///
/// Built by `DeriveEnumerated::new` and turned into trait impls by
/// `DeriveEnumerated::to_tokens`.
pub(crate) struct DeriveEnumerated {
    /// Name of the enum the impls are generated for.
    ident: Ident,

    /// Every variant of the enum, each with its integer-literal discriminant.
    variants: Vec<EnumeratedVariant>,

    /// Integer type named in `#[repr(...)]`; always one of `REPR_TYPES`.
    repr: Ident,

    /// `true` to use the ASN.1 `INTEGER` tag, `false` for `ENUMERATED`.
    integer: bool,
}
impl DeriveEnumerated {
pub fn new(input: DeriveInput) -> syn::Result<Self> {
let data = match input.data {
syn::Data::Enum(data) => data,
_ => abort!(
input.ident,
"can't derive `Enumerated` on this type: only `enum` types are allowed",
),
};
let mut repr: Option<Ident> = None;
let mut integer = false;
for attr in &input.attrs {
if attr.path().is_ident(ATTR_NAME) {
let kvs = match AttrNameValue::parse_attribute(attr) {
Ok(kvs) => kvs,
Err(e) => abort!(attr, e),
};
for anv in kvs {
if anv.name.is_ident("type") {
match anv.value.value().as_str() {
"ENUMERATED" => integer = false,
"INTEGER" => integer = true,
s => abort!(anv.value, format_args!("`type = \"{s}\"` is unsupported")),
}
}
}
} else if attr.path().is_ident("repr") {
if repr.is_some() {
abort!(
attr,
"multiple `#[repr]` attributes encountered on `Enumerated`",
);
}
let r = attr.parse_args::<Ident>().map_err(|_| {
syn::Error::new_spanned(attr, "error parsing `#[repr]` attribute")
})?;
if !REPR_TYPES.contains(&r.to_string().as_str()) {
abort!(
attr,
format_args!("invalid `#[repr]` type: allowed types are {REPR_TYPES:?}"),
);
}
repr = Some(r);
}
}
let variants = data
.variants
.iter()
.map(EnumeratedVariant::new)
.collect::<syn::Result<_>>()?;
Ok(Self {
ident: input.ident.clone(),
repr: repr.ok_or_else(|| {
syn::Error::new_spanned(
&input.ident,
format_args!("no `#[repr]` attribute on enum: must be one of {REPR_TYPES:?}"),
)
})?,
variants,
integer,
})
}
pub fn to_tokens(&self) -> TokenStream {
let default_lifetime = default_lifetime();
let ident = &self.ident;
let repr = &self.repr;
let tag = match self.integer {
false => quote! { ::der::Tag::Enumerated },
true => quote! { ::der::Tag::Integer },
};
let mut try_from_body = Vec::new();
for variant in &self.variants {
try_from_body.push(variant.to_try_from_tokens());
}
quote! {
impl<#default_lifetime> ::der::DecodeValue<#default_lifetime> for #ident {
fn decode_value<R: ::der::Reader<#default_lifetime>>(
reader: &mut R,
header: ::der::Header
) -> ::der::Result<Self> {
<#repr as ::der::DecodeValue>::decode_value(reader, header)?.try_into()
}
}
impl ::der::EncodeValue for #ident {
fn value_len(&self) -> ::der::Result<::der::Length> {
::der::EncodeValue::value_len(&(*self as #repr))
}
fn encode_value(&self, encoder: &mut impl ::der::Writer) -> ::der::Result<()> {
::der::EncodeValue::encode_value(&(*self as #repr), encoder)
}
}
impl ::der::FixedTag for #ident {
const TAG: ::der::Tag = #tag;
}
impl TryFrom<#repr> for #ident {
type Error = ::der::Error;
fn try_from(n: #repr) -> ::der::Result<Self> {
match n {
#(#try_from_body)*
_ => Err(#tag.value_error())
}
}
}
}
}
}
/// One variant of an `Enumerated` enum, paired with its explicit discriminant.
pub struct EnumeratedVariant {
    /// Integer literal written after `=` in the variant declaration.
    discriminant: LitInt,

    /// Variant name, emitted as `Self::<ident>` in the generated `TryFrom` arm.
    ident: Ident,
}
impl EnumeratedVariant {
    /// Parses one variant of an `Enumerated` enum.
    ///
    /// The variant must be a unit variant with an explicit integer-literal
    /// discriminant, e.g. `KeyCompromise = 1`.
    ///
    /// # Errors
    ///
    /// Fails under any of these conditions:
    /// - the variant carries an `asn1` attribute;
    /// - it is a tuple or struct variant;
    /// - it has no discriminant;
    /// - its discriminant is not an integer literal.
    fn new(input: &Variant) -> syn::Result<Self> {
        for attr in &input.attrs {
            if attr.path().is_ident(ATTR_NAME) {
                abort!(
                    attr,
                    "`asn1` attribute is not allowed on fields of `Enumerated` types"
                );
            }
        }

        // The generated code converts with `*self as repr` and builds values as
        // `Self::Variant` (see `to_try_from_tokens`). Neither compiles for tuple or
        // struct variants. Rejecting them here makes the error point at the
        // offending variant instead of at macro-generated code.
        if !matches!(input.fields, syn::Fields::Unit) {
            abort!(
                &input.fields,
                "`Enumerated` variants must be unit variants without fields",
            );
        }

        match &input.discriminant {
            Some((
                _,
                Expr::Lit(ExprLit {
                    lit: Lit::Int(discriminant),
                    ..
                }),
            )) => Ok(Self {
                ident: input.ident.clone(),
                discriminant: discriminant.clone(),
            }),
            Some((_, other)) => abort!(other, "invalid discriminant for `Enumerated`"),
            None => abort!(input, "`Enumerated` variant has no discriminant"),
        }
    }

    /// Emits the `TryFrom<repr>` match arm `<discriminant> => Ok(Self::<ident>),`.
    pub fn to_try_from_tokens(&self) -> TokenStream {
        let ident = &self.ident;
        let discriminant = &self.discriminant;
        quote! {
            #discriminant => Ok(Self::#ident),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::DeriveEnumerated;
    use syn::parse_quote;

    /// Parses a `CrlReason` enum and checks the recorded ident, repr, tag choice,
    /// and variant discriminants.
    #[test]
    fn crlreason_example() {
        let input = parse_quote! {
            #[repr(u32)]
            pub enum CrlReason {
                Unspecified = 0,
                KeyCompromise = 1,
                CaCompromise = 2,
                AffiliationChanged = 3,
                Superseded = 4,
                CessationOfOperation = 5,
                CertificateHold = 6,
                RemoveFromCrl = 8,
                PrivilegeWithdrawn = 9,
                AaCompromised = 10,
            }
        };

        let ir = DeriveEnumerated::new(input).unwrap();
        assert_eq!(ir.ident, "CrlReason");
        assert_eq!(ir.repr, "u32");
        // With no `#[asn1(type = ...)]` attribute the enum uses the ENUMERATED tag.
        assert!(!ir.integer);
        assert_eq!(ir.variants.len(), 10);

        let unspecified = &ir.variants[0];
        assert_eq!(unspecified.ident, "Unspecified");
        assert_eq!(unspecified.discriminant.to_string(), "0");

        let key_compromise = &ir.variants[1];
        assert_eq!(key_compromise.ident, "KeyCompromise");
        assert_eq!(key_compromise.discriminant.to_string(), "1");

        let ca_compromise = &ir.variants[2];
        assert_eq!(ca_compromise.ident, "CaCompromise");
        assert_eq!(ca_compromise.discriminant.to_string(), "2");

        // Discriminants need not be contiguous: the input skips 7.
        let remove_from_crl = &ir.variants[7];
        assert_eq!(remove_from_crl.ident, "RemoveFromCrl");
        assert_eq!(remove_from_crl.discriminant.to_string(), "8");
    }

    #[test]
    fn asn1_type_integer() {
        let input = parse_quote! {
            #[asn1(type = "INTEGER")]
            #[repr(u8)]
            pub enum Version {
                V1 = 0,
                V2 = 1,
            }
        };

        let ir = DeriveEnumerated::new(input).unwrap();
        assert!(ir.integer);
        assert_eq!(ir.repr, "u8");
        assert_eq!(ir.variants.len(), 2);
    }

    #[test]
    fn rejects_non_enum() {
        let input = parse_quote! {
            #[repr(u8)]
            pub struct NotAnEnum;
        };
        assert!(DeriveEnumerated::new(input).is_err());
    }

    #[test]
    fn rejects_missing_repr() {
        let input = parse_quote! {
            pub enum NoRepr {
                A = 0,
            }
        };
        assert!(DeriveEnumerated::new(input).is_err());
    }

    #[test]
    fn rejects_duplicate_repr() {
        let input = parse_quote! {
            #[repr(u8)]
            #[repr(u16)]
            pub enum TwoReprs {
                A = 0,
            }
        };
        assert!(DeriveEnumerated::new(input).is_err());
    }

    #[test]
    fn rejects_unsupported_repr() {
        let input = parse_quote! {
            #[repr(i8)]
            pub enum SignedRepr {
                A = 0,
            }
        };
        assert!(DeriveEnumerated::new(input).is_err());
    }

    #[test]
    fn rejects_unsupported_asn1_type() {
        let input = parse_quote! {
            #[asn1(type = "BOOLEAN")]
            #[repr(u8)]
            pub enum BadType {
                A = 0,
            }
        };
        assert!(DeriveEnumerated::new(input).is_err());
    }

    #[test]
    fn rejects_variant_without_discriminant() {
        let input = parse_quote! {
            #[repr(u8)]
            pub enum MissingDiscriminant {
                A,
            }
        };
        assert!(DeriveEnumerated::new(input).is_err());
    }
}