rustify_derive/
lib.rs

1//! Provides a derive macro for easily implementing an `Endpoint` from the
2//! [rustify][1] crate. See the documentation for `rustify` for details on how
3//! to use this macro.
4//!
5//! [1]: https://docs.rs/rustify/
6
7#[macro_use]
8extern crate synstructure;
9extern crate proc_macro;
10
11mod error;
12mod params;
13mod parse;
14
15use std::{collections::HashMap, convert::TryFrom};
16
17use error::Error;
18use params::Parameters;
19use proc_macro2::Span;
20use quote::quote;
21use regex::Regex;
22use syn::{self, spanned::Spanned, Field, Generics, Ident, Meta};
23
/// Name of the derive macro; only used when composing diagnostics.
const MACRO_NAME: &str = "Endpoint";
/// Name of the attribute that is looked up on the deriving struct and
/// mentioned in diagnostics.
const ATTR_NAME: &str = "endpoint";
26
/// Classification of a struct field, used as the key when grouping fields for
/// code generation.
#[derive(Debug, PartialEq, Eq, Hash)]
pub(crate) enum EndpointAttribute {
    /// Field is serialized as part of the request body.
    Body,
    /// Field is serialized as part of the query string.
    Query,
    /// Field is returned as-is as the request body; at most one field may
    /// carry this marker.
    Raw,
    /// Parsed from `skip`; not consumed by any generator in this file.
    /// NOTE(review): presumably excludes the field from the request - confirm
    /// against the `parse` module.
    Skip,
    /// Bucket used as the fallback source for the request body. It is never
    /// produced by the `TryFrom<&Meta>` conversion.
    /// NOTE(review): presumably assigned to fields without an attribute by
    /// `parse::field_attributes` - confirm.
    Untagged,
}
35
36impl TryFrom<&Meta> for EndpointAttribute {
37    type Error = Error;
38    fn try_from(m: &Meta) -> Result<Self, Self::Error> {
39        match m.path().get_ident() {
40            Some(i) => match i.to_string().to_lowercase().as_str() {
41                "body" => Ok(EndpointAttribute::Body),
42                "query" => Ok(EndpointAttribute::Query),
43                "raw" => Ok(EndpointAttribute::Raw),
44                "skip" => Ok(EndpointAttribute::Skip),
45                _ => Err(Error::new(
46                    m.span(),
47                    format!("Unknown attribute: {}", i).as_str(),
48                )),
49            },
50            None => Err(Error::new(m.span(), "Invalid attribute")),
51        }
52    }
53}
54
55/// Generates the path string for the endpoint.
56///
57/// The string supplied by the end-user supports basic interpolation using curly
58/// braces. For example,
59/// ```
60/// endpoint(path = "user/{self.name}")
61/// ```
62/// Should produce:
63/// ```
64/// format!("user/{}", self.name);
65/// ```
66/// This is currently accomplished using a basic regular expression which
67/// matches contents in the braces, extracts them out, leaving behind the empty
68/// braces and placing the contents into the proper position in `format!`.
69///
70/// If no interpolation is needed the user provided string is fed into
71/// `String::from` without modification.
72fn gen_path(path: &syn::LitStr) -> Result<proc_macro2::TokenStream, Error> {
73    let re = Regex::new(r"\{(.*?)\}").unwrap();
74    let mut fmt_args: Vec<syn::Expr> = Vec::new();
75    for cap in re.captures_iter(path.value().as_str()) {
76        let expr = syn::parse_str(&cap[1]);
77        match expr {
78            Ok(ex) => fmt_args.push(ex),
79            Err(_) => {
80                return Err(Error::new(
81                    path.span(),
82                    format!("Failed parsing format argument as expression: {}", &cap[1]).as_str(),
83                ));
84            }
85        }
86    }
87    let path = syn::LitStr::new(
88        re.replace_all(path.value().as_str(), "{}")
89            .to_string()
90            .as_str(),
91        Span::call_site(),
92    );
93
94    if !fmt_args.is_empty() {
95        Ok(quote! {
96            format!(#path, #(#fmt_args),*)
97        })
98    } else {
99        Ok(quote! {
100            String::from(#path)
101        })
102    }
103}
104
105/// Generates the query method for generating query parameters.
106///
107/// If any fields are found with the [EndpointAttribute::Query] attribute they
108/// are combined into a new struct and then serialized into a query string. If
109/// the attribute is not found on any of the fields the query method is not
110/// generated.
111fn gen_query(
112    fields: &HashMap<EndpointAttribute, Vec<Field>>,
113    serde_attrs: &[Meta],
114) -> proc_macro2::TokenStream {
115    let query_fields = fields.get(&EndpointAttribute::Query);
116    if let Some(v) = query_fields {
117        // Construct query function
118        let temp = parse::fields_to_struct(v, serde_attrs);
119        quote! {
120            fn query(&self) -> Result<Option<String>, ClientError> {
121                #temp
122
123                Ok(Some(build_query(&__temp)?))
124            }
125        }
126    } else {
127        quote! {}
128    }
129}
130
131/// Generates the body method for generating the request body.
132///
133/// The final result is determined by which attributes are present and/or
134/// missing on the struct fields. The following order is respected:
135///
136/// * If a field is found with the [EndpointAttribute::Raw] attribute that field
137///   is returned directly as the request body. The assumption is this field
138///   will always be a [Vec<u8>].
139/// * If any fields are found with the [EndpointAttribute::Body] attribute they
140///   are combined into a new struct and then serialized into the request body
141///   depending on the request type of the Endpoint.
142/// * If neither of the above two conditions are true, and there are fields
143///   found that don't have any attribute, those fields are combined into a new
144///   struct and then serialized into the request body depending on the request
145///   type of the Endpoint.
146/// * If none of the above is true, the body method is not generated.
147fn gen_body(
148    fields: &HashMap<EndpointAttribute, Vec<Field>>,
149    serde_attrs: &[Meta],
150) -> Result<proc_macro2::TokenStream, Error> {
151    // Check for a raw field first
152    if let Some(v) = fields.get(&EndpointAttribute::Raw) {
153        if v.len() > 1 {
154            return Err(Error::new(v[1].span(), "May only mark one field as raw"));
155        }
156
157        let id = v[0].ident.clone().unwrap();
158        Ok(quote! {
159            fn body(&self) -> Result<Option<Vec<u8>>, ClientError>{
160                Ok(Some(self.#id.clone()))
161            }
162        })
163    // Then for any body fields
164    } else if let Some(v) = fields.get(&EndpointAttribute::Body) {
165        let temp = parse::fields_to_struct(v, serde_attrs);
166        Ok(quote! {
167            fn body(&self) -> Result<Option<Vec<u8>>, ClientError> {
168                #temp
169
170                Ok(Some(build_body(&__temp, Self::REQUEST_BODY_TYPE)?))
171            }
172        })
173    // Then for any untagged fields
174    } else if let Some(v) = fields.get(&EndpointAttribute::Untagged) {
175        let temp = parse::fields_to_struct(v, serde_attrs);
176        Ok(quote! {
177            fn body(&self) -> Result<Option<Vec<u8>>, ClientError> {
178                #temp
179
180                Ok(Some(build_body(&__temp, Self::REQUEST_BODY_TYPE)?))
181            }
182        })
183    // Leave it undefined if no body fields found
184    } else {
185        Ok(quote! {})
186    }
187}
188
189/// Generates `builder()` and `exec_*` helper methods for use with
190/// `derive_builder`.
191///
192/// Adds an implementation to the base struct which provides a `builder` method
193/// for returning instances of the Builder variant of the struct. This removes
194/// the need to explicitly import it.
195fn gen_builder(id: &Ident, generics: &Generics) -> proc_macro2::TokenStream {
196    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
197    let builder_id: syn::Type = syn::parse_str(format!("{}Builder", id).as_str()).unwrap();
198    let builder_func: syn::Expr =
199        syn::parse_str(format!("{}Builder::default()", id).as_str()).unwrap();
200
201    quote! {
202        impl #impl_generics #id #ty_generics #where_clause {
203            pub fn builder() -> #builder_id #ty_generics {
204                #builder_func
205            }
206        }
207    }
208}
209
210/// Parses parameters passed into the `endpoint` attribute attached to the
211/// struct.
212fn parse_params(attr: &Meta) -> Result<Parameters, Error> {
213    // Parse the attribute as a key/value pair list
214    let kv = parse::attr_kv(attr)?;
215
216    // Create map from key/value pair list
217    let map = parse::to_map(&kv)?;
218
219    // Convert map to Parameters
220    params::Parameters::new(map)
221}
222
223/// Implements `Endpoint` on the provided struct.
224fn endpoint_derive(s: synstructure::Structure) -> proc_macro2::TokenStream {
225    // Parse `endpoint` attributes attached to input struct
226    let attrs = match parse::attributes(&s.ast().attrs, ATTR_NAME) {
227        Ok(v) => v,
228        Err(e) => return e.into_tokens(),
229    };
230
231    // Parse `endpoint` attributes attached to input struct fields
232    let field_attrs = match parse::field_attributes(&s.ast().data) {
233        Ok(v) => v,
234        Err(e) => return e.into_tokens(),
235    };
236
237    // Verify attribute is present
238    if attrs.is_empty() {
239        return Error::new(
240            Span::call_site(),
241            format!(
242                "Deriving `{}` requires attaching an `{}` attribute",
243                MACRO_NAME, ATTR_NAME
244            )
245            .as_str(),
246        )
247        .into_tokens();
248    }
249
250    // Verify there's only one instance of the attribute present
251    if attrs.len() > 1 {
252        return Error::new(
253            Span::call_site(),
254            format!("Cannot define the {} attribute more than once", ATTR_NAME).as_str(),
255        )
256        .into_tokens();
257    }
258
259    // Parse endpoint attribute parameters
260    let params = match parse_params(&attrs[0]) {
261        Ok(v) => v,
262        Err(e) => return e.into_tokens(),
263    };
264
265    let path = params.path;
266    let method = params.method;
267    let response = params.response;
268    let request_type = params.request_type;
269    let response_type = params.response_type;
270    let id = &s.ast().ident;
271
272    // Find serde attributes
273    let serde_attrs = parse::attributes(&s.ast().attrs, "serde");
274    let serde_attrs = serde_attrs.unwrap_or_default();
275
276    // Generate path string
277    let path = match gen_path(&path) {
278        Ok(a) => a,
279        Err(e) => return e.into_tokens(),
280    };
281
282    // Generate query function
283    let query = gen_query(&field_attrs, &serde_attrs);
284
285    // Generate body function
286    let body = match gen_body(&field_attrs, &serde_attrs) {
287        Ok(d) => d,
288        Err(e) => return e.into_tokens(),
289    };
290
291    // Generate helper functions when deriving Builder
292    let builder = match params.builder {
293        true => gen_builder(&s.ast().ident, &s.ast().generics),
294        false => quote! {},
295    };
296
297    // Capture generic information
298    let (impl_generics, ty_generics, where_clause) = s.ast().generics.split_for_impl();
299
300    // Generate Endpoint implementation
301    let const_name = format!("_DERIVE_Endpoint_FOR_{}", id);
302    let const_ident = Ident::new(const_name.as_str(), Span::call_site());
303    quote! {
304        #[allow(non_local_definitions)]
305        const #const_ident: () = {
306            use rustify::__private::serde::Serialize;
307            use rustify::http::{build_body, build_query};
308            use rustify::client::Client;
309            use rustify::endpoint::Endpoint;
310            use rustify::enums::{RequestMethod, RequestType, ResponseType};
311            use rustify::errors::ClientError;
312
313            impl #impl_generics Endpoint for #id #ty_generics #where_clause {
314                type Response = #response;
315                const REQUEST_BODY_TYPE: RequestType = RequestType::#request_type;
316                const RESPONSE_BODY_TYPE: ResponseType = ResponseType::#response_type;
317
318                fn path(&self) -> String {
319                    #path
320                }
321
322                fn method(&self) -> RequestMethod {
323                    RequestMethod::#method
324                }
325
326                #query
327
328
329                #body
330            }
331
332            #builder
333        };
334    }
335}
336
// Declares the `Endpoint` derive with `endpoint` as its helper attribute and
// routes expansion to `endpoint_derive`.
synstructure::decl_derive!([Endpoint, attributes(endpoint)] => endpoint_derive);