use std::iter;
use proc_macro2::TokenStream;
use quote::quote;
use syn::{
parse::{Parse, ParseStream},
parse_macro_input, parse_quote,
punctuated::Punctuated,
token::Plus,
Error, FnArg, GenericParam, Ident, ItemTrait, Pat, PatType, Result, ReturnType, Signature,
Token, TraitBound, TraitItem, TraitItemConst, TraitItemFn, TraitItemType, Type, TypeGenerics,
TypeImplTrait, TypeParam, TypeParamBound,
};
/// Parsed arguments of the attribute macro handled by [`make`].
struct Attrs {
/// Which form of the attribute was written; see [`MakeVariant`].
variant: MakeVariant,
}
impl Parse for Attrs {
fn parse(input: ParseStream) -> Result<Self> {
Ok(Self {
variant: MakeVariant::parse(input)?,
})
}
}
/// The two argument forms accepted by the attribute.
enum MakeVariant {
/// `Name: Bound + Bound…` — `make` keeps the annotated trait, emits an
/// additional trait called `name` carrying the extra bounds, and emits a
/// blanket impl of the annotated trait for implementors of the new one.
Create {
name: Ident,
_colon: Token![:],
bounds: Punctuated<TraitBound, Plus>,
},
/// `Bound + Bound…` — `make` emits only the transformed trait, reusing the
/// annotated trait's own identifier.
Rewrite {
bounds: Punctuated<TraitBound, Plus>,
},
}
impl Parse for MakeVariant {
    /// Parses either `Name: Bound + …` ([`MakeVariant::Create`]) or
    /// `Bound + …` ([`MakeVariant::Rewrite`]).
    ///
    /// # Errors
    ///
    /// Returns the underlying `syn` error if the name, colon, or any bound
    /// fails to parse.
    fn parse(input: ParseStream) -> Result<Self> {
        // Peeking for `:` also succeeds on the first half of `::`, so a
        // path-qualified bound such as `core::marker::Send` would otherwise be
        // mistaken for the `Name: …` form and then fail to parse. Explicitly
        // rule out the path separator.
        let names_new_trait =
            input.peek(Ident) && input.peek2(Token![:]) && !input.peek2(Token![::]);

        let variant = if names_new_trait {
            MakeVariant::Create {
                name: input.parse()?,
                _colon: input.parse()?,
                bounds: input.parse_terminated(TraitBound::parse, Token![+])?,
            }
        } else {
            MakeVariant::Rewrite {
                bounds: input.parse_terminated(TraitBound::parse, Token![+])?,
            }
        };
        Ok(variant)
    }
}
pub fn make(
attr: proc_macro::TokenStream,
item: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let attrs = parse_macro_input!(attr as Attrs);
let item = parse_macro_input!(item as ItemTrait);
match attrs.variant {
MakeVariant::Create { name, bounds, .. } => {
let maybe_allow_async_lint = if bounds
.iter()
.any(|b| b.path.segments.last().unwrap().ident == "Send")
{
quote! { #[allow(async_fn_in_trait)] }
} else {
quote! {}
};
let variant = mk_variant(&name, bounds, &item);
let blanket_impl = mk_blanket_impl(&name, &item);
quote! {
#maybe_allow_async_lint
#item
#variant
#blanket_impl
}
.into()
}
MakeVariant::Rewrite { bounds, .. } => {
let variant = mk_variant(&item.ident, bounds, &item);
quote! {
#variant
}
.into()
}
}
}
fn mk_variant(
variant: &Ident,
with_bounds: Punctuated<TraitBound, Plus>,
tr: &ItemTrait,
) -> TokenStream {
let bounds: Vec<_> = with_bounds
.into_iter()
.map(|b| TypeParamBound::Trait(b.clone()))
.collect();
let variant = ItemTrait {
ident: variant.clone(),
supertraits: tr.supertraits.iter().chain(&bounds).cloned().collect(),
items: tr
.items
.iter()
.map(|item| transform_item(item, &bounds))
.collect(),
..tr.clone()
};
quote! { #variant }
}
fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
let TraitItem::Fn(fn_item @ TraitItemFn { sig, .. }) = item else {
return item.clone();
};
let (arrow, output) = if sig.asyncness.is_some() {
let orig = match &sig.output {
ReturnType::Default => quote! { () },
ReturnType::Type(_, ty) => quote! { #ty },
};
let future = syn::parse2(quote! { ::core::future::Future<Output = #orig> }).unwrap();
let ty = Type::ImplTrait(TypeImplTrait {
impl_token: syn::parse2(quote! { impl }).unwrap(),
bounds: iter::once(TypeParamBound::Trait(future))
.chain(bounds.iter().cloned())
.collect(),
});
(syn::parse2(quote! { -> }).unwrap(), ty)
} else {
match &sig.output {
ReturnType::Type(arrow, ty) => match &**ty {
Type::ImplTrait(it) => {
let ty = Type::ImplTrait(TypeImplTrait {
impl_token: it.impl_token,
bounds: it.bounds.iter().chain(bounds).cloned().collect(),
});
(*arrow, ty)
}
_ => return item.clone(),
},
ReturnType::Default => return item.clone(),
}
};
TraitItem::Fn(TraitItemFn {
sig: Signature {
asyncness: None,
output: ReturnType::Type(arrow, Box::new(output)),
..sig.clone()
},
..fn_item.clone()
})
}
/// Generates `impl<…, TraitVariantBlanketType: Variant<…>> Orig<…> for
/// TraitVariantBlanketType`, so every implementor of the trait named `variant`
/// also implements `tr`.
///
/// The impl reuses `tr`'s generics and where-clause with the extra type
/// parameter appended; each item body comes from [`blanket_impl_item`].
fn mk_blanket_impl(variant: &Ident, tr: &ItemTrait) -> TokenStream {
    let orig = &tr.ident;
    let (_, orig_ty_generics, _) = tr.generics.split_for_impl();

    let blanket_param: TypeParam =
        parse_quote!(TraitVariantBlanketType: #variant #orig_ty_generics);
    let blanket_ident = blanket_param.ident.clone();

    let mut extended_generics = tr.generics.clone();
    extended_generics
        .params
        .push(GenericParam::Type(blanket_param));
    let (impl_generics, _, where_clause) = extended_generics.split_for_impl();

    let forwarded_items = tr
        .items
        .iter()
        .map(|member| blanket_impl_item(member, variant, &orig_ty_generics));

    quote! {
        impl #impl_generics #orig #orig_ty_generics for #blanket_ident #where_clause
        {
            #(#forwarded_items)*
        }
    }
}
/// Produces the blanket-impl definition of a single trait item by forwarding
/// to the same-named item on `<Self as #variant #trait_ty_generics>`.
///
/// * Associated consts are defined as the variant's const.
/// * Functions keep their original signature and call the variant's function
///   with the same arguments, appending `.await` for `async fn`.
/// * Associated types alias the variant's associated type.
/// * Any other item kind yields a compile error.
///
/// Function parameters must be plain identifier patterns; any other pattern
/// produces a "patterns are not supported in arguments" compile error in place
/// of that argument.
fn blanket_impl_item(
    item: &TraitItem,
    variant: &Ident,
    trait_ty_generics: &TypeGenerics<'_>,
) -> TokenStream {
    match item {
        TraitItem::Const(TraitItemConst {
            ident,
            generics,
            ty,
            ..
        }) => {
            quote! {
                const #ident #generics: #ty = <Self as #variant #trait_ty_generics>::#ident;
            }
        }
        TraitItem::Fn(TraitItemFn { sig, .. }) => {
            let ident = &sig.ident;
            let args = sig.inputs.iter().map(|arg| match arg {
                FnArg::Receiver(_) => quote! { self },
                FnArg::Typed(PatType { pat, .. }) => match &**pat {
                    Pat::Ident(binding) => {
                        // Forward only the bound name. Emitting the whole
                        // pattern would copy `mut`/`ref`, attributes, or an
                        // `@` sub-pattern into the call expression, which is
                        // not valid there (e.g. `f(mut x)`).
                        let name = &binding.ident;
                        quote! { #name }
                    }
                    _ => Error::new_spanned(pat, "patterns are not supported in arguments")
                        .to_compile_error(),
                },
            });
            let maybe_await = if sig.asyncness.is_some() {
                quote! { .await }
            } else {
                quote! {}
            };
            quote! {
                #sig {
                    <Self as #variant #trait_ty_generics>::#ident(#(#args),*)#maybe_await
                }
            }
        }
        TraitItem::Type(TraitItemType {
            ident, generics, ..
        }) => {
            let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
            quote! {
                type #ident #impl_generics = <Self as #variant #trait_ty_generics>::#ident #ty_generics #where_clause;
            }
        }
        _ => Error::new_spanned(item, "unsupported item type").into_compile_error(),
    }
}