snafu_derive/
shared.rs

1use std::collections::BTreeSet;
2
3pub(crate) use self::context_module::ContextModule;
4pub(crate) use self::context_selector::ContextSelector;
5pub(crate) use self::display::{Display, DisplayMatchArm};
6pub(crate) use self::error::{Error, ErrorProvideMatchArm, ErrorSourceMatchArm};
7pub(crate) use self::error_compat::{ErrorCompat, ErrorCompatBacktraceMatchArm};
8
9pub(crate) struct StaticIdent(&'static str);
10
11impl quote::ToTokens for StaticIdent {
12    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
13        proc_macro2::Ident::new(self.0, proc_macro2::Span::call_site()).to_tokens(tokens)
14    }
15}
16
17struct AllFieldNames<'a>(&'a crate::FieldContainer);
18
19impl<'a> AllFieldNames<'a> {
20    fn field_names(&self) -> BTreeSet<&'a proc_macro2::Ident> {
21        let user_fields = self.0.selector_kind.user_fields();
22        let backtrace_field = self.0.backtrace_field.as_ref();
23        let implicit_fields = &self.0.implicit_fields;
24        let message_field = self.0.selector_kind.message_field();
25        let source_field = self.0.selector_kind.source_field();
26
27        user_fields
28            .iter()
29            .chain(backtrace_field)
30            .chain(implicit_fields)
31            .chain(message_field)
32            .map(crate::Field::name)
33            .chain(source_field.map(crate::SourceField::name))
34            .collect()
35    }
36}
37
38pub mod context_module {
39    use crate::ModuleName;
40    use heck::ToSnakeCase;
41    use proc_macro2::TokenStream;
42    use quote::{quote, ToTokens};
43    use syn::Ident;
44
45    #[derive(Copy, Clone)]
46    pub(crate) struct ContextModule<'a, T> {
47        pub container_name: &'a Ident,
48        pub module_name: &'a ModuleName,
49        pub visibility: Option<&'a dyn ToTokens>,
50        pub body: &'a T,
51    }
52
53    impl<'a, T> ToTokens for ContextModule<'a, T>
54    where
55        T: ToTokens,
56    {
57        fn to_tokens(&self, stream: &mut TokenStream) {
58            let module_name = match self.module_name {
59                ModuleName::Default => {
60                    let name_str = self.container_name.to_string().to_snake_case();
61                    syn::Ident::new(&name_str, self.container_name.span())
62                }
63                ModuleName::Custom(name) => name.clone(),
64            };
65
66            let visibility = self.visibility;
67            let body = self.body;
68
69            let module_tokens = quote! {
70                #visibility mod #module_name {
71                    use super::*;
72
73                    #body
74                }
75            };
76
77            stream.extend(module_tokens);
78        }
79    }
80}
81
82pub mod context_selector {
83    use crate::{ContextSelectorKind, Field, SuffixKind};
84    use proc_macro2::TokenStream;
85    use quote::{format_ident, quote, ToTokens};
86
87    #[derive(Copy, Clone)]
88    pub(crate) struct ContextSelector<'a> {
89        pub backtrace_field: Option<&'a Field>,
90        pub implicit_fields: &'a [Field],
91        pub crate_root: &'a dyn ToTokens,
92        pub error_constructor_name: &'a dyn ToTokens,
93        pub original_generics_without_defaults: &'a [TokenStream],
94        pub parameterized_error_name: &'a dyn ToTokens,
95        pub selector_doc_string: &'a str,
96        pub selector_kind: &'a ContextSelectorKind,
97        pub selector_base_name: &'a proc_macro2::Ident,
98        pub user_fields: &'a [Field],
99        pub visibility: Option<&'a dyn ToTokens>,
100        pub where_clauses: &'a [TokenStream],
101        pub default_suffix: &'a SuffixKind,
102    }
103
104    impl ToTokens for ContextSelector<'_> {
105        fn to_tokens(&self, stream: &mut TokenStream) {
106            use self::ContextSelectorKind::*;
107
108            let context_selector = match self.selector_kind {
109                Context { source_field, .. } => {
110                    let context_selector_type = self.generate_type();
111                    let context_selector_impl = match source_field {
112                        Some(_) => None,
113                        None => Some(self.generate_leaf()),
114                    };
115                    let context_selector_into_error_impl =
116                        self.generate_into_error(source_field.as_ref());
117
118                    quote! {
119                        #context_selector_type
120                        #context_selector_impl
121                        #context_selector_into_error_impl
122                    }
123                }
124                Whatever {
125                    source_field,
126                    message_field,
127                } => self.generate_whatever(source_field.as_ref(), message_field),
128                NoContext { source_field } => self.generate_from_source(source_field),
129            };
130
131            stream.extend(context_selector)
132        }
133    }
134
135    impl ContextSelector<'_> {
136        fn user_field_generics(&self) -> Vec<proc_macro2::Ident> {
137            (0..self.user_fields.len())
138                .map(|i| format_ident!("__T{}", i))
139                .collect()
140        }
141
142        fn user_field_names(&self) -> Vec<&syn::Ident> {
143            self.user_fields
144                .iter()
145                .map(|Field { name, .. }| name)
146                .collect()
147        }
148
149        fn parameterized_selector_name(&self) -> TokenStream {
150            let selector_name = self
151                .selector_kind
152                .resolve_name(self.default_suffix, self.selector_base_name);
153            let user_generics = self.user_field_generics();
154
155            quote! { #selector_name<#(#user_generics,)*> }
156        }
157
158        fn extended_where_clauses(&self) -> Vec<TokenStream> {
159            let user_fields = self.user_fields;
160            let user_field_generics = self.user_field_generics();
161            let where_clauses = self.where_clauses;
162
163            let target_types = user_fields
164                .iter()
165                .map(|Field { ty, .. }| quote! { ::core::convert::Into<#ty>});
166
167            user_field_generics
168                .into_iter()
169                .zip(target_types)
170                .map(|(gen, bound)| quote! { #gen: #bound })
171                .chain(where_clauses.iter().cloned())
172                .collect()
173        }
174
175        fn transfer_user_fields(&self) -> Vec<TokenStream> {
176            self.user_field_names()
177                .into_iter()
178                .map(|name| {
179                    quote! { #name: ::core::convert::Into::into(self.#name) }
180                })
181                .collect()
182        }
183
184        fn construct_implicit_fields(&self) -> TokenStream {
185            let crate_root = self.crate_root;
186            let expression = quote! {
187                #crate_root::GenerateImplicitData::generate()
188            };
189
190            self.construct_implicit_fields_with_expression(expression)
191        }
192
193        fn construct_implicit_fields_with_source(&self) -> TokenStream {
194            let crate_root = self.crate_root;
195            let expression = quote! { {
196                use #crate_root::AsErrorSource;
197                let error = error.as_error_source();
198                #crate_root::GenerateImplicitData::generate_with_source(error)
199            } };
200
201            self.construct_implicit_fields_with_expression(expression)
202        }
203
204        fn construct_implicit_fields_with_expression(
205            &self,
206            expression: TokenStream,
207        ) -> TokenStream {
208            self.implicit_fields
209                .iter()
210                .chain(self.backtrace_field)
211                .map(|field| {
212                    let name = &field.name;
213                    quote! { #name: #expression, }
214                })
215                .collect()
216        }
217
218        fn generate_type(self) -> TokenStream {
219            let visibility = self.visibility;
220            let parameterized_selector_name = self.parameterized_selector_name();
221            let user_field_generics = self.user_field_generics();
222            let user_field_names = self.user_field_names();
223            let selector_doc_string = self.selector_doc_string;
224
225            let body = if user_field_names.is_empty() {
226                quote! { ; }
227            } else {
228                quote! {
229                    {
230                        #(
231                            #[allow(missing_docs)]
232                            #visibility #user_field_names: #user_field_generics
233                        ),*
234                    }
235                }
236            };
237
238            quote! {
239                #[derive(Debug, Copy, Clone)]
240                #[doc = #selector_doc_string]
241                #visibility struct #parameterized_selector_name #body
242            }
243        }
244
245        fn generate_leaf(self) -> TokenStream {
246            let error_constructor_name = self.error_constructor_name;
247            let original_generics_without_defaults = self.original_generics_without_defaults;
248            let parameterized_error_name = self.parameterized_error_name;
249            let parameterized_selector_name = self.parameterized_selector_name();
250            let user_field_generics = self.user_field_generics();
251            let visibility = self.visibility;
252            let extended_where_clauses = self.extended_where_clauses();
253            let transfer_user_fields = self.transfer_user_fields();
254            let construct_implicit_fields = self.construct_implicit_fields();
255
256            quote! {
257                impl<#(#user_field_generics,)*> #parameterized_selector_name {
258                    #[doc = "Consume the selector and return the associated error"]
259                    #[must_use]
260                    #[track_caller]
261                    #visibility fn build<#(#original_generics_without_defaults,)*>(self) -> #parameterized_error_name
262                    where
263                        #(#extended_where_clauses),*
264                    {
265                        #error_constructor_name {
266                            #construct_implicit_fields
267                            #(#transfer_user_fields,)*
268                        }
269                    }
270
271                    #[doc = "Consume the selector and return a `Result` with the associated error"]
272                    #[allow(dead_code)]
273                    #[track_caller]
274                    #visibility fn fail<#(#original_generics_without_defaults,)* __T>(self) -> ::core::result::Result<__T, #parameterized_error_name>
275                    where
276                        #(#extended_where_clauses),*
277                    {
278                        ::core::result::Result::Err(self.build())
279                    }
280                }
281            }
282        }
283
284        fn generate_into_error(self, source_field: Option<&crate::SourceField>) -> TokenStream {
285            let crate_root = self.crate_root;
286            let error_constructor_name = self.error_constructor_name;
287            let original_generics_without_defaults = self.original_generics_without_defaults;
288            let parameterized_error_name = self.parameterized_error_name;
289            let parameterized_selector_name = self.parameterized_selector_name();
290            let user_field_generics = self.user_field_generics();
291            let extended_where_clauses = self.extended_where_clauses();
292            let transfer_user_fields = self.transfer_user_fields();
293            let construct_implicit_fields = if source_field.is_some() {
294                self.construct_implicit_fields_with_source()
295            } else {
296                self.construct_implicit_fields()
297            };
298
299            let (source_ty, transform_source, transfer_source_field) = match source_field {
300                Some(source_field) => {
301                    let SourceInfo {
302                        source_field_type,
303                        transform_source,
304                        transfer_source_field,
305                    } = build_source_info(source_field);
306                    (
307                        quote! { #source_field_type },
308                        Some(transform_source),
309                        Some(transfer_source_field),
310                    )
311                }
312                None => (quote! { #crate_root::NoneError }, None, None),
313            };
314
315            quote! {
316                impl<#(#original_generics_without_defaults,)* #(#user_field_generics,)*> #crate_root::IntoError<#parameterized_error_name> for #parameterized_selector_name
317                where
318                    #parameterized_error_name: #crate_root::Error + #crate_root::ErrorCompat,
319                    #(#extended_where_clauses),*
320                {
321                    type Source = #source_ty;
322
323                    #[track_caller]
324                    fn into_error(self, error: Self::Source) -> #parameterized_error_name {
325                        #transform_source;
326                        #error_constructor_name {
327                            #construct_implicit_fields
328                            #transfer_source_field
329                            #(#transfer_user_fields),*
330                        }
331                    }
332                }
333            }
334        }
335
336        fn generate_whatever(
337            self,
338            source_field: Option<&crate::SourceField>,
339            message_field: &crate::Field,
340        ) -> TokenStream {
341            let crate_root = self.crate_root;
342            let parameterized_error_name = self.parameterized_error_name;
343            let error_constructor_name = self.error_constructor_name;
344            let construct_implicit_fields = self.construct_implicit_fields();
345            let original_generics_without_defaults = self.original_generics_without_defaults;
346            let construct_implicit_fields_with_source =
347                self.construct_implicit_fields_with_source();
348            let extended_where_clauses = self.extended_where_clauses();
349
350            // testme: transform
351
352            let (source_ty, transfer_source_field, empty_source_field) = match source_field {
353                Some(f) => {
354                    let source_field_type = f.transformation.source_ty();
355                    let source_field_name = &f.name;
356                    let source_transformation = f.transformation.transformation();
357
358                    (
359                        quote! { #source_field_type },
360                        Some(quote! { #source_field_name: (#source_transformation)(error), }),
361                        Some(quote! { #source_field_name: core::option::Option::None, }),
362                    )
363                }
364                None => (quote! { #crate_root::NoneError }, None, None),
365            };
366
367            let message_field_name = &message_field.name;
368
369            quote! {
370                impl<#(#original_generics_without_defaults,)*> #crate_root::FromString for #parameterized_error_name
371                where
372                    #(#extended_where_clauses),*
373                {
374                    type Source = #source_ty;
375
376                    #[track_caller]
377                    fn without_source(message: String) -> Self {
378                        #error_constructor_name {
379                            #construct_implicit_fields
380                            #empty_source_field
381                            #message_field_name: message,
382                        }
383                    }
384
385                    #[track_caller]
386                    fn with_source(error: Self::Source, message: String) -> Self {
387                        #error_constructor_name {
388                            #construct_implicit_fields_with_source
389                            #transfer_source_field
390                            #message_field_name: message,
391                        }
392                    }
393                }
394            }
395        }
396
397        fn generate_from_source(self, source_field: &crate::SourceField) -> TokenStream {
398            let parameterized_error_name = self.parameterized_error_name;
399            let error_constructor_name = self.error_constructor_name;
400            let construct_implicit_fields_with_source =
401                self.construct_implicit_fields_with_source();
402            let original_generics_without_defaults = self.original_generics_without_defaults;
403            let user_field_generics = self.user_field_generics();
404            let where_clauses = self.where_clauses;
405
406            let SourceInfo {
407                source_field_type,
408                transform_source,
409                transfer_source_field,
410            } = build_source_info(source_field);
411
412            quote! {
413                impl<#(#original_generics_without_defaults,)* #(#user_field_generics,)*> ::core::convert::From<#source_field_type> for #parameterized_error_name
414                where
415                    #(#where_clauses),*
416                {
417                    #[track_caller]
418                    fn from(error: #source_field_type) -> Self {
419                        #transform_source;
420                        #error_constructor_name {
421                            #construct_implicit_fields_with_source
422                            #transfer_source_field
423                        }
424                    }
425                }
426            }
427        }
428    }
429
430    struct SourceInfo<'a> {
431        source_field_type: &'a syn::Type,
432        transform_source: TokenStream,
433        transfer_source_field: TokenStream,
434    }
435
436    // Assumes that the error is in a variable called "error"
437    fn build_source_info(source_field: &crate::SourceField) -> SourceInfo<'_> {
438        let source_field_name = source_field.name();
439        let source_field_type = source_field.transformation.source_ty();
440        let target_field_type = source_field.transformation.target_ty();
441        let source_transformation = source_field.transformation.transformation();
442
443        let transform_source =
444            quote! { let error: #target_field_type = (#source_transformation)(error) };
445        let transfer_source_field = quote! { #source_field_name: error, };
446
447        SourceInfo {
448            source_field_type,
449            transform_source,
450            transfer_source_field,
451        }
452    }
453}
454
455pub mod display {
456    use super::StaticIdent;
457    use proc_macro2::TokenStream;
458    use quote::{quote, ToTokens};
459    use std::collections::BTreeSet;
460
461    const FORMATTER_ARG: StaticIdent = StaticIdent("__snafu_display_formatter");
462
463    pub(crate) struct Display<'a> {
464        pub(crate) arms: &'a [TokenStream],
465        pub(crate) original_generics: &'a [TokenStream],
466        pub(crate) parameterized_error_name: &'a dyn ToTokens,
467        pub(crate) where_clauses: &'a [TokenStream],
468    }
469
470    impl ToTokens for Display<'_> {
471        fn to_tokens(&self, stream: &mut TokenStream) {
472            let Self {
473                arms,
474                original_generics,
475                parameterized_error_name,
476                where_clauses,
477            } = *self;
478
479            let display_impl = quote! {
480                #[allow(single_use_lifetimes)]
481                impl<#(#original_generics),*> ::core::fmt::Display for #parameterized_error_name
482                where
483                    #(#where_clauses),*
484                {
485                    fn fmt(&self, #FORMATTER_ARG: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
486                        #[allow(unused_variables)]
487                        match *self {
488                            #(#arms),*
489                        }
490                    }
491                }
492            };
493
494            stream.extend(display_impl);
495        }
496    }
497
498    pub(crate) struct DisplayMatchArm<'a> {
499        pub(crate) field_container: &'a crate::FieldContainer,
500        pub(crate) default_name: &'a dyn ToTokens,
501        pub(crate) display_format: Option<&'a crate::Display>,
502        pub(crate) doc_comment: Option<&'a crate::DocComment>,
503        pub(crate) pattern_ident: &'a dyn ToTokens,
504        pub(crate) selector_kind: &'a crate::ContextSelectorKind,
505    }
506
507    impl ToTokens for DisplayMatchArm<'_> {
508        fn to_tokens(&self, stream: &mut TokenStream) {
509            let Self {
510                field_container,
511                default_name,
512                display_format,
513                doc_comment,
514                pattern_ident,
515                selector_kind,
516            } = *self;
517
518            let source_field = selector_kind.source_field();
519
520            if field_container.is_transparent {
521                // transparent errors always have a source field
522                let source_field_name = source_field.unwrap().name();
523
524                let match_arm = quote! {
525                    #pattern_ident { ref #source_field_name, .. } => {
526                        ::core::fmt::Display::fmt(#source_field_name, #FORMATTER_ARG)
527                    }
528                };
529
530                stream.extend(match_arm);
531                return;
532            }
533
534            let mut shorthand_names = &BTreeSet::new();
535            let mut assigned_names = &BTreeSet::new();
536
537            let format = match (display_format, doc_comment) {
538                (Some(v), _) => {
539                    let exprs = &v.exprs;
540                    shorthand_names = &v.shorthand_names;
541                    assigned_names = &v.assigned_names;
542                    quote! { #(#exprs),* }
543                }
544                (_, Some(d)) => {
545                    let content = &d.content;
546                    shorthand_names = &d.shorthand_names;
547                    quote! { #content }
548                }
549                _ => quote! { stringify!(#default_name) },
550            };
551
552            let field_names = super::AllFieldNames(field_container).field_names();
553
554            let shorthand_names = shorthand_names.iter().collect::<BTreeSet<_>>();
555            let assigned_names = assigned_names.iter().collect::<BTreeSet<_>>();
556
557            let shorthand_fields = &shorthand_names & &field_names;
558            let shorthand_fields = &shorthand_fields - &assigned_names;
559
560            let shorthand_assignments = quote! { #( #shorthand_fields = #shorthand_fields ),* };
561
562            let match_arm = quote! {
563                #pattern_ident { #(ref #field_names),* } => {
564                    write!(#FORMATTER_ARG, #format, #shorthand_assignments)
565                }
566            };
567
568            stream.extend(match_arm);
569        }
570    }
571}
572
573pub mod error {
574    use super::StaticIdent;
575    use crate::{FieldContainer, Provide, SourceField};
576    use proc_macro2::TokenStream;
577    use quote::{format_ident, quote, ToTokens};
578
579    pub(crate) const PROVIDE_ARG: StaticIdent = StaticIdent("__snafu_provide_demand");
580
581    pub(crate) struct Error<'a> {
582        pub(crate) crate_root: &'a dyn ToTokens,
583        pub(crate) description_arms: &'a [TokenStream],
584        pub(crate) original_generics: &'a [TokenStream],
585        pub(crate) parameterized_error_name: &'a dyn ToTokens,
586        pub(crate) provide_arms: &'a [TokenStream],
587        pub(crate) source_arms: &'a [TokenStream],
588        pub(crate) where_clauses: &'a [TokenStream],
589    }
590
591    impl ToTokens for Error<'_> {
592        fn to_tokens(&self, stream: &mut TokenStream) {
593            let Self {
594                crate_root,
595                description_arms,
596                original_generics,
597                parameterized_error_name,
598                provide_arms,
599                source_arms,
600                where_clauses,
601            } = *self;
602
603            let description_fn = quote! {
604                fn description(&self) -> &str {
605                    match *self {
606                        #(#description_arms)*
607                    }
608                }
609            };
610
611            let source_body = quote! {
612                use #crate_root::AsErrorSource;
613                match *self {
614                    #(#source_arms)*
615                }
616            };
617
618            let cause_fn = quote! {
619                fn cause(&self) -> ::core::option::Option<&dyn #crate_root::Error> {
620                    #source_body
621                }
622            };
623
624            let source_fn = quote! {
625                fn source(&self) -> ::core::option::Option<&(dyn #crate_root::Error + 'static)> {
626                    #source_body
627                }
628            };
629
630            let provide_fn = if cfg!(feature = "unstable-provider-api") {
631                Some(quote! {
632                    fn provide<'a>(&'a self, #PROVIDE_ARG: &mut #crate_root::error::Request<'a>) {
633                        match *self {
634                            #(#provide_arms,)*
635                        };
636                    }
637                })
638            } else {
639                None
640            };
641
642            let error = quote! {
643                #[allow(single_use_lifetimes)]
644                impl<#(#original_generics),*> #crate_root::Error for #parameterized_error_name
645                where
646                    Self: ::core::fmt::Debug + ::core::fmt::Display,
647                    #(#where_clauses),*
648                {
649                    #description_fn
650                    #cause_fn
651                    #source_fn
652                    #provide_fn
653                }
654            };
655
656            stream.extend(error);
657        }
658    }
659
660    pub(crate) struct ErrorSourceMatchArm<'a> {
661        pub(crate) field_container: &'a FieldContainer,
662        pub(crate) pattern_ident: &'a dyn ToTokens,
663    }
664
665    impl ToTokens for ErrorSourceMatchArm<'_> {
666        fn to_tokens(&self, stream: &mut TokenStream) {
667            let Self {
668                field_container:
669                    FieldContainer {
670                        selector_kind,
671                        is_transparent,
672                        ..
673                    },
674                pattern_ident,
675            } = *self;
676
677            let source_field = selector_kind.source_field();
678
679            let arm = match source_field {
680                Some(source_field) => {
681                    let SourceField {
682                        name: field_name, ..
683                    } = source_field;
684
685                    let convert_to_error_source = if selector_kind.is_whatever() {
686                        quote! {
687                            #field_name.as_ref().map(|e| e.as_error_source())
688                        }
689                    } else if *is_transparent {
690                        quote! {
691                            #field_name.as_error_source().source()
692                        }
693                    } else {
694                        quote! {
695                            ::core::option::Option::Some(#field_name.as_error_source())
696                        }
697                    };
698
699                    quote! {
700                        #pattern_ident { ref #field_name, .. } => {
701                            #convert_to_error_source
702                        }
703                    }
704                }
705                None => {
706                    quote! {
707                        #pattern_ident { .. } => { ::core::option::Option::None }
708                    }
709                }
710            };
711
712            stream.extend(arm);
713        }
714    }
715
716    pub(crate) struct ProvidePlus<'a> {
717        provide: &'a Provide,
718        cached_name: proc_macro2::Ident,
719    }
720
721    pub(crate) struct ErrorProvideMatchArm<'a> {
722        pub(crate) crate_root: &'a dyn ToTokens,
723        pub(crate) field_container: &'a FieldContainer,
724        pub(crate) pattern_ident: &'a dyn ToTokens,
725    }
726
727    impl<'a> ToTokens for ErrorProvideMatchArm<'a> {
728        fn to_tokens(&self, stream: &mut TokenStream) {
729            let Self {
730                crate_root,
731                field_container,
732                pattern_ident,
733            } = *self;
734
735            let user_fields = field_container.user_fields();
736            let provides = enhance_provider_list(field_container.provides());
737            let field_names = super::AllFieldNames(field_container).field_names();
738
739            let (hi_explicit_calls, lo_explicit_calls) = build_explicit_provide_calls(&provides);
740
741            let cached_expressions = quote_cached_expressions(&provides);
742
743            let provide_refs = user_fields
744                .iter()
745                .chain(&field_container.implicit_fields)
746                .chain(field_container.selector_kind.message_field())
747                .flat_map(|f| {
748                    if f.provide {
749                        Some((&f.ty, f.name()))
750                    } else {
751                        None
752                    }
753                });
754
755            let provided_source = field_container
756                .selector_kind
757                .source_field()
758                .filter(|f| f.provide);
759
760            let source_provide_ref =
761                provided_source.map(|f| (f.transformation.source_ty(), f.name()));
762
763            let provide_refs = provide_refs.chain(source_provide_ref);
764
765            let source_chain = provided_source.map(|f| {
766                let name = f.name();
767                quote! {
768                    #name.provide(#PROVIDE_ARG);
769                }
770            });
771
772            let user_chained = quote_chained(crate_root, &provides);
773
774            let shorthand_calls = provide_refs.map(|(ty, name)| {
775                quote! { #PROVIDE_ARG.provide_ref::<#ty>(#name) }
776            });
777
778            let provided_backtrace = field_container
779                .backtrace_field
780                .as_ref()
781                .filter(|f| f.provide);
782
783            let provide_backtrace = provided_backtrace.map(|f| {
784                let name = f.name();
785                quote! {
786                    if #PROVIDE_ARG.would_be_satisfied_by_ref_of::<#crate_root::Backtrace>() {
787                        if let ::core::option::Option::Some(bt) = #crate_root::AsBacktrace::as_backtrace(#name) {
788                            #PROVIDE_ARG.provide_ref::<#crate_root::Backtrace>(bt);
789                        }
790                    }
791                }
792            });
793
794            let arm = quote! {
795                #pattern_ident { #(ref #field_names,)* .. } => {
796                    #(#cached_expressions;)*
797                    #(#hi_explicit_calls;)*
798                    #source_chain;
799                    #(#user_chained;)*
800                    #provide_backtrace;
801                    #(#shorthand_calls;)*
802                    #(#lo_explicit_calls;)*
803                }
804            };
805
806            stream.extend(arm);
807        }
808    }
809
810    pub(crate) fn enhance_provider_list(provides: &[Provide]) -> Vec<ProvidePlus<'_>> {
811        provides
812            .iter()
813            .enumerate()
814            .map(|(i, provide)| {
815                let cached_name = format_ident!("__snafu_cached_expr_{}", i);
816                ProvidePlus {
817                    provide,
818                    cached_name,
819                }
820            })
821            .collect()
822    }
823
824    pub(crate) fn quote_cached_expressions<'a>(
825        provides: &'a [ProvidePlus<'a>],
826    ) -> impl Iterator<Item = proc_macro2::TokenStream> + 'a {
827        provides.iter().filter(|pp| pp.provide.is_chain).map(|pp| {
828            let cached_name = &pp.cached_name;
829            let expr = &pp.provide.expr;
830
831            quote! {
832                let #cached_name = #expr;
833            }
834        })
835    }
836
837    pub(crate) fn quote_chained<'a>(
838        crate_root: &'a dyn ToTokens,
839        provides: &'a [ProvidePlus<'a>],
840    ) -> impl Iterator<Item = proc_macro2::TokenStream> + 'a {
841        provides
842            .iter()
843            .filter(|pp| pp.provide.is_chain)
844            .map(move |pp| {
845                let arm = if pp.provide.is_opt {
846                    quote! { ::core::option::Option::Some(chained_item) }
847                } else {
848                    quote! { chained_item }
849                };
850                let cached_name = &pp.cached_name;
851
852                quote! {
853                    if let #arm = #cached_name {
854                        #crate_root::Error::provide(chained_item, #PROVIDE_ARG);
855                    }
856                }
857            })
858    }
859
860    fn quote_provides<'a, I>(provides: I) -> impl Iterator<Item = proc_macro2::TokenStream> + 'a
861    where
862        I: IntoIterator<Item = &'a ProvidePlus<'a>>,
863        I::IntoIter: 'a,
864    {
865        provides.into_iter().map(|pp| {
866            let ProvidePlus {
867                provide:
868                    Provide {
869                        is_chain,
870                        is_opt,
871                        is_priority: _,
872                        is_ref,
873                        ty,
874                        expr,
875                    },
876                cached_name,
877            } = pp;
878
879            let effective_expr = if *is_chain {
880                quote! { #cached_name }
881            } else {
882                quote! { #expr }
883            };
884
885            match (is_opt, is_ref) {
886                (true, true) => {
887                    quote! {
888                        if #PROVIDE_ARG.would_be_satisfied_by_ref_of::<#ty>() {
889                            if let ::core::option::Option::Some(v) = #effective_expr {
890                                #PROVIDE_ARG.provide_ref::<#ty>(v);
891                            }
892                        }
893                    }
894                }
895                (true, false) => {
896                    quote! {
897                        if #PROVIDE_ARG.would_be_satisfied_by_value_of::<#ty>() {
898                            if let ::core::option::Option::Some(v) = #effective_expr {
899                                #PROVIDE_ARG.provide_value::<#ty>(v);
900                            }
901                        }
902                    }
903                }
904                (false, true) => {
905                    quote! { #PROVIDE_ARG.provide_ref_with::<#ty>(|| #effective_expr) }
906                }
907                (false, false) => {
908                    quote! { #PROVIDE_ARG.provide_value_with::<#ty>(|| #effective_expr) }
909                }
910            }
911        })
912    }
913
914    pub(crate) fn build_explicit_provide_calls<'a>(
915        provides: &'a [ProvidePlus<'a>],
916    ) -> (
917        impl Iterator<Item = TokenStream> + 'a,
918        impl Iterator<Item = TokenStream> + 'a,
919    ) {
920        let (high_priority, low_priority): (Vec<_>, Vec<_>) =
921            provides.iter().partition(|pp| pp.provide.is_priority);
922
923        let hi_explicit_calls = quote_provides(high_priority);
924        let lo_explicit_calls = quote_provides(low_priority);
925
926        (hi_explicit_calls, lo_explicit_calls)
927    }
928}
929
930pub mod error_compat {
931    use crate::{Field, FieldContainer, SourceField};
932    use proc_macro2::TokenStream;
933    use quote::{quote, ToTokens};
934
935    pub(crate) struct ErrorCompat<'a> {
936        pub(crate) crate_root: &'a dyn ToTokens,
937        pub(crate) parameterized_error_name: &'a dyn ToTokens,
938        pub(crate) backtrace_arms: &'a [TokenStream],
939        pub(crate) original_generics: &'a [TokenStream],
940        pub(crate) where_clauses: &'a [TokenStream],
941    }
942
943    impl ToTokens for ErrorCompat<'_> {
944        fn to_tokens(&self, stream: &mut TokenStream) {
945            let Self {
946                crate_root,
947                parameterized_error_name,
948                backtrace_arms,
949                original_generics,
950                where_clauses,
951            } = *self;
952
953            let backtrace_fn = quote! {
954                fn backtrace(&self) -> ::core::option::Option<&#crate_root::Backtrace> {
955                    match *self {
956                        #(#backtrace_arms),*
957                    }
958                }
959            };
960
961            let error_compat_impl = quote! {
962                #[allow(single_use_lifetimes)]
963                impl<#(#original_generics),*> #crate_root::ErrorCompat for #parameterized_error_name
964                where
965                    #(#where_clauses),*
966                {
967                    #backtrace_fn
968                }
969            };
970
971            stream.extend(error_compat_impl);
972        }
973    }
974
975    pub(crate) struct ErrorCompatBacktraceMatchArm<'a> {
976        pub(crate) crate_root: &'a dyn ToTokens,
977        pub(crate) field_container: &'a FieldContainer,
978        pub(crate) pattern_ident: &'a dyn ToTokens,
979    }
980
981    impl ToTokens for ErrorCompatBacktraceMatchArm<'_> {
982        fn to_tokens(&self, stream: &mut TokenStream) {
983            let Self {
984                crate_root,
985                field_container:
986                    FieldContainer {
987                        backtrace_field,
988                        selector_kind,
989                        ..
990                    },
991                pattern_ident,
992            } = *self;
993
994            let match_arm = match (selector_kind.source_field(), backtrace_field) {
995                (Some(source_field), _) if source_field.backtrace_delegate => {
996                    let SourceField {
997                        name: field_name, ..
998                    } = source_field;
999                    quote! {
1000                        #pattern_ident { ref #field_name, .. } => { #crate_root::ErrorCompat::backtrace(#field_name) }
1001                    }
1002                }
1003                (_, Some(backtrace_field)) => {
1004                    let Field {
1005                        name: field_name, ..
1006                    } = backtrace_field;
1007                    quote! {
1008                        #pattern_ident { ref #field_name, .. } => { #crate_root::AsBacktrace::as_backtrace(#field_name) }
1009                    }
1010                }
1011                _ => {
1012                    quote! {
1013                        #pattern_ident { .. } => { ::core::option::Option::None }
1014                    }
1015                }
1016            };
1017
1018            stream.extend(match_arm);
1019        }
1020    }
1021}