axum/routing/
method_routing.rs

1//! Route to services and handlers based on HTTP methods.
2
3use super::{future::InfallibleRouteFuture, IntoMakeService};
4#[cfg(feature = "tokio")]
5use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
6use crate::{
7    body::{Body, Bytes, HttpBody},
8    boxed::BoxedIntoRoute,
9    error_handling::{HandleError, HandleErrorLayer},
10    handler::Handler,
11    http::{Method, StatusCode},
12    response::Response,
13    routing::{future::RouteFuture, Fallback, MethodFilter, Route},
14};
15use axum_core::{extract::Request, response::IntoResponse, BoxError};
16use bytes::BytesMut;
17use std::{
18    borrow::Cow,
19    convert::Infallible,
20    fmt,
21    task::{Context, Poll},
22};
23use tower::service_fn;
24use tower_layer::Layer;
25use tower_service::Service;
26
// Generates one free `*_service` routing function (e.g. `get_service`) per invocation.
//
// The first three arms only choose the documentation to attach and then re-invoke the macro;
// the final arm emits the function itself, which forwards to
// `on_service(MethodFilter::$method, svc)`.
macro_rules! top_level_service_fn {
    // `GET`: documentation with a complete usage example.
    (
        $name:ident, GET
    ) => {
        top_level_service_fn!(
            /// Route `GET` requests to the given service.
            ///
            /// # Example
            ///
            /// ```rust
            /// use axum::{
            ///     extract::Request,
            ///     Router,
            ///     routing::get_service,
            ///     body::Body,
            /// };
            /// use http::Response;
            /// use std::convert::Infallible;
            ///
            /// let service = tower::service_fn(|request: Request| async {
            ///     Ok::<_, Infallible>(Response::new(Body::empty()))
            /// });
            ///
            /// // Requests to `GET /` will go to `service`.
            /// let app = Router::new().route("/", get_service(service));
            /// # let _: Router = app;
            /// ```
            ///
            /// Note that `get` routes will also be called for `HEAD` requests but will have
            /// the response body removed. Make sure to add explicit `HEAD` routes
            /// afterwards.
            $name,
            GET
        );
    };

    // `CONNECT`: documentation that points at `MethodFilter::CONNECT`.
    (
        $name:ident, CONNECT
    ) => {
        top_level_service_fn!(
            /// Route `CONNECT` requests to the given service.
            ///
            /// See [`MethodFilter::CONNECT`] for when you'd want to use this,
            /// and [`get_service`] for an example.
            $name,
            CONNECT
        );
    };

    // Any other method: the summary line is built from the method identifier.
    (
        $name:ident, $method:ident
    ) => {
        top_level_service_fn!(
            #[doc = concat!("Route `", stringify!($method) ,"` requests to the given service.")]
            ///
            /// See [`get_service`] for an example.
            $name,
            $method
        );
    };

    // Final arm: receives the attributes chosen above and emits the function.
    (
        $(#[$m:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$m])+
        pub fn $name<T, S>(svc: T) -> MethodRouter<S, T::Error>
        where
            T: Service<Request> + Clone + Send + Sync + 'static,
            T::Response: IntoResponse + 'static,
            T::Future: Send + 'static,
            S: Clone,
        {
            on_service(MethodFilter::$method, svc)
        }
    };
}
104
// Generates one free handler routing function (e.g. `get`, `post`) per invocation.
//
// The first three arms only choose the documentation to attach and then re-invoke the macro;
// the final arm emits the function itself, which forwards to
// `on(MethodFilter::$method, handler)`.
macro_rules! top_level_handler_fn {
    // `GET`: documentation with a complete usage example.
    (
        $name:ident, GET
    ) => {
        top_level_handler_fn!(
            /// Route `GET` requests to the given handler.
            ///
            /// # Example
            ///
            /// ```rust
            /// use axum::{
            ///     routing::get,
            ///     Router,
            /// };
            ///
            /// async fn handler() {}
            ///
            /// // Requests to `GET /` will go to `handler`.
            /// let app = Router::new().route("/", get(handler));
            /// # let _: Router = app;
            /// ```
            ///
            /// Note that `get` routes will also be called for `HEAD` requests but will have
            /// the response body removed. Make sure to add explicit `HEAD` routes
            /// afterwards.
            $name,
            GET
        );
    };

    // `CONNECT`: documentation that points at `MethodFilter::CONNECT`.
    (
        $name:ident, CONNECT
    ) => {
        top_level_handler_fn!(
            /// Route `CONNECT` requests to the given handler.
            ///
            /// See [`MethodFilter::CONNECT`] for when you'd want to use this,
            /// and [`get`] for an example.
            $name,
            CONNECT
        );
    };

    // Any other method: the summary line is built from the method identifier.
    (
        $name:ident, $method:ident
    ) => {
        top_level_handler_fn!(
            #[doc = concat!("Route `", stringify!($method) ,"` requests to the given handler.")]
            ///
            /// See [`get`] for an example.
            $name,
            $method
        );
    };

    // Final arm: receives the attributes chosen above and emits the function.
    (
        $(#[$m:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$m])+
        pub fn $name<H, T, S>(handler: H) -> MethodRouter<S, Infallible>
        where
            H: Handler<T, S>,
            T: 'static,
            S: Clone + Send + Sync + 'static,
        {
            on(MethodFilter::$method, handler)
        }
    };
}
175
// Generates one chained `*_service` method (e.g. `get_service`) per invocation.
//
// The first three arms only choose the documentation to attach and then re-invoke the macro;
// the final arm emits the method itself, which forwards to
// `self.on_service(MethodFilter::$method, svc)`. The emitted method refers to a type
// parameter `E`, so the macro must be invoked inside an `impl` that declares it.
macro_rules! chained_service_fn {
    // `GET`: documentation with a complete usage example.
    (
        $name:ident, GET
    ) => {
        chained_service_fn!(
            /// Chain an additional service that will only accept `GET` requests.
            ///
            /// # Example
            ///
            /// ```rust
            /// use axum::{
            ///     extract::Request,
            ///     Router,
            ///     routing::post_service,
            ///     body::Body,
            /// };
            /// use http::Response;
            /// use std::convert::Infallible;
            ///
            /// let service = tower::service_fn(|request: Request| async {
            ///     Ok::<_, Infallible>(Response::new(Body::empty()))
            /// });
            ///
            /// let other_service = tower::service_fn(|request: Request| async {
            ///     Ok::<_, Infallible>(Response::new(Body::empty()))
            /// });
            ///
            /// // Requests to `POST /` will go to `service` and `GET /` will go to
            /// // `other_service`.
            /// let app = Router::new().route("/", post_service(service).get_service(other_service));
            /// # let _: Router = app;
            /// ```
            ///
            /// Note that `get` routes will also be called for `HEAD` requests but will have
            /// the response body removed. Make sure to add explicit `HEAD` routes
            /// afterwards.
            $name,
            GET
        );
    };

    // `CONNECT`: documentation that points at `MethodFilter::CONNECT`.
    (
        $name:ident, CONNECT
    ) => {
        chained_service_fn!(
            /// Chain an additional service that will only accept `CONNECT` requests.
            ///
            /// See [`MethodFilter::CONNECT`] for when you'd want to use this,
            /// and [`MethodRouter::get_service`] for an example.
            $name,
            CONNECT
        );
    };

    // Any other method: the summary line is built from the method identifier.
    (
        $name:ident, $method:ident
    ) => {
        chained_service_fn!(
            #[doc = concat!("Chain an additional service that will only accept `", stringify!($method),"` requests.")]
            ///
            /// See [`MethodRouter::get_service`] for an example.
            $name,
            $method
        );
    };

    // Final arm: receives the attributes chosen above and emits the method.
    (
        $(#[$m:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$m])+
        #[track_caller]
        pub fn $name<T>(self, svc: T) -> Self
        where
            T: Service<Request, Error = E>
                + Clone
                + Send
                + Sync
                + 'static,
            T::Response: IntoResponse + 'static,
            T::Future: Send + 'static,
        {
            self.on_service(MethodFilter::$method, svc)
        }
    };
}
262
// Generates one chained handler method (e.g. `get`, `post`) per invocation.
//
// The first three arms only choose the documentation to attach and then re-invoke the macro;
// the final arm emits the method itself, which forwards to
// `self.on(MethodFilter::$method, handler)`. The emitted method refers to a type
// parameter `S`, so the macro must be invoked inside an `impl` that declares it.
macro_rules! chained_handler_fn {
    // `GET`: documentation with a complete usage example.
    (
        $name:ident, GET
    ) => {
        chained_handler_fn!(
            /// Chain an additional handler that will only accept `GET` requests.
            ///
            /// # Example
            ///
            /// ```rust
            /// use axum::{routing::post, Router};
            ///
            /// async fn handler() {}
            ///
            /// async fn other_handler() {}
            ///
            /// // Requests to `POST /` will go to `handler` and `GET /` will go to
            /// // `other_handler`.
            /// let app = Router::new().route("/", post(handler).get(other_handler));
            /// # let _: Router = app;
            /// ```
            ///
            /// Note that `get` routes will also be called for `HEAD` requests but will have
            /// the response body removed. Make sure to add explicit `HEAD` routes
            /// afterwards.
            $name,
            GET
        );
    };

    // `CONNECT`: documentation that points at `MethodFilter::CONNECT`.
    (
        $name:ident, CONNECT
    ) => {
        chained_handler_fn!(
            /// Chain an additional handler that will only accept `CONNECT` requests.
            ///
            /// See [`MethodFilter::CONNECT`] for when you'd want to use this,
            /// and [`MethodRouter::get`] for an example.
            $name,
            CONNECT
        );
    };

    // Any other method: the summary line is built from the method identifier.
    (
        $name:ident, $method:ident
    ) => {
        chained_handler_fn!(
            #[doc = concat!("Chain an additional handler that will only accept `", stringify!($method),"` requests.")]
            ///
            /// See [`MethodRouter::get`] for an example.
            $name,
            $method
        );
    };

    // Final arm: receives the attributes chosen above and emits the method.
    (
        $(#[$m:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$m])+
        #[track_caller]
        pub fn $name<H, T>(self, handler: H) -> Self
        where
            H: Handler<T, S>,
            T: 'static,
            S: Send + Sync + 'static,
        {
            self.on(MethodFilter::$method, handler)
        }
    };
}
334
// Free `*_service` functions, one per HTTP method; each forwards to `on_service` with the
// matching `MethodFilter` constant.
top_level_service_fn!(connect_service, CONNECT);
top_level_service_fn!(delete_service, DELETE);
top_level_service_fn!(get_service, GET);
top_level_service_fn!(head_service, HEAD);
top_level_service_fn!(options_service, OPTIONS);
top_level_service_fn!(patch_service, PATCH);
top_level_service_fn!(post_service, POST);
top_level_service_fn!(put_service, PUT);
top_level_service_fn!(trace_service, TRACE);
344
345/// Route requests with the given method to the service.
346///
347/// # Example
348///
349/// ```rust
350/// use axum::{
351///     extract::Request,
352///     routing::on,
353///     Router,
354///     body::Body,
355///     routing::{MethodFilter, on_service},
356/// };
357/// use http::Response;
358/// use std::convert::Infallible;
359///
360/// let service = tower::service_fn(|request: Request| async {
361///     Ok::<_, Infallible>(Response::new(Body::empty()))
362/// });
363///
364/// // Requests to `POST /` will go to `service`.
365/// let app = Router::new().route("/", on_service(MethodFilter::POST, service));
366/// # let _: Router = app;
367/// ```
368pub fn on_service<T, S>(filter: MethodFilter, svc: T) -> MethodRouter<S, T::Error>
369where
370    T: Service<Request> + Clone + Send + Sync + 'static,
371    T::Response: IntoResponse + 'static,
372    T::Future: Send + 'static,
373    S: Clone,
374{
375    MethodRouter::new().on_service(filter, svc)
376}
377
378/// Route requests to the given service regardless of its method.
379///
380/// # Example
381///
382/// ```rust
383/// use axum::{
384///     extract::Request,
385///     Router,
386///     routing::any_service,
387///     body::Body,
388/// };
389/// use http::Response;
390/// use std::convert::Infallible;
391///
392/// let service = tower::service_fn(|request: Request| async {
393///     Ok::<_, Infallible>(Response::new(Body::empty()))
394/// });
395///
396/// // All requests to `/` will go to `service`.
397/// let app = Router::new().route("/", any_service(service));
398/// # let _: Router = app;
399/// ```
400///
401/// Additional methods can still be chained:
402///
403/// ```rust
404/// use axum::{
405///     extract::Request,
406///     Router,
407///     routing::any_service,
408///     body::Body,
409/// };
410/// use http::Response;
411/// use std::convert::Infallible;
412///
413/// let service = tower::service_fn(|request: Request| async {
414///     # Ok::<_, Infallible>(Response::new(Body::empty()))
415///     // ...
416/// });
417///
418/// let other_service = tower::service_fn(|request: Request| async {
419///     # Ok::<_, Infallible>(Response::new(Body::empty()))
420///     // ...
421/// });
422///
423/// // `POST /` goes to `other_service`. All other requests go to `service`
424/// let app = Router::new().route("/", any_service(service).post_service(other_service));
425/// # let _: Router = app;
426/// ```
427pub fn any_service<T, S>(svc: T) -> MethodRouter<S, T::Error>
428where
429    T: Service<Request> + Clone + Send + Sync + 'static,
430    T::Response: IntoResponse + 'static,
431    T::Future: Send + 'static,
432    S: Clone,
433{
434    MethodRouter::new()
435        .fallback_service(svc)
436        .skip_allow_header()
437}
438
// Free handler functions, one per HTTP method; each forwards to `on` with the matching
// `MethodFilter` constant.
top_level_handler_fn!(connect, CONNECT);
top_level_handler_fn!(delete, DELETE);
top_level_handler_fn!(get, GET);
top_level_handler_fn!(head, HEAD);
top_level_handler_fn!(options, OPTIONS);
top_level_handler_fn!(patch, PATCH);
top_level_handler_fn!(post, POST);
top_level_handler_fn!(put, PUT);
top_level_handler_fn!(trace, TRACE);
448
449/// Route requests with the given method to the handler.
450///
451/// # Example
452///
453/// ```rust
454/// use axum::{
455///     routing::on,
456///     Router,
457///     routing::MethodFilter,
458/// };
459///
460/// async fn handler() {}
461///
462/// // Requests to `POST /` will go to `handler`.
463/// let app = Router::new().route("/", on(MethodFilter::POST, handler));
464/// # let _: Router = app;
465/// ```
466pub fn on<H, T, S>(filter: MethodFilter, handler: H) -> MethodRouter<S, Infallible>
467where
468    H: Handler<T, S>,
469    T: 'static,
470    S: Clone + Send + Sync + 'static,
471{
472    MethodRouter::new().on(filter, handler)
473}
474
475/// Route requests with the given handler regardless of the method.
476///
477/// # Example
478///
479/// ```rust
480/// use axum::{
481///     routing::any,
482///     Router,
483/// };
484///
485/// async fn handler() {}
486///
487/// // All requests to `/` will go to `handler`.
488/// let app = Router::new().route("/", any(handler));
489/// # let _: Router = app;
490/// ```
491///
492/// Additional methods can still be chained:
493///
494/// ```rust
495/// use axum::{
496///     routing::any,
497///     Router,
498/// };
499///
500/// async fn handler() {}
501///
502/// async fn other_handler() {}
503///
504/// // `POST /` goes to `other_handler`. All other requests go to `handler`
505/// let app = Router::new().route("/", any(handler).post(other_handler));
506/// # let _: Router = app;
507/// ```
508pub fn any<H, T, S>(handler: H) -> MethodRouter<S, Infallible>
509where
510    H: Handler<T, S>,
511    T: 'static,
512    S: Clone + Send + Sync + 'static,
513{
514    MethodRouter::new().fallback(handler).skip_allow_header()
515}
516
/// A [`Service`] that accepts requests based on a [`MethodFilter`] and
/// allows chaining additional handlers and services.
///
/// # When does `MethodRouter` implement [`Service`]?
///
/// Whether or not `MethodRouter` implements [`Service`] depends on the state type it requires.
///
/// ```
/// use tower::Service;
/// use axum::{routing::get, extract::{State, Request}, body::Body};
///
/// // this `MethodRouter` doesn't require any state, i.e. the state is `()`,
/// let method_router = get(|| async {});
/// // and thus it implements `Service`
/// assert_service(method_router);
///
/// // this requires a `String` and doesn't implement `Service`
/// let method_router = get(|_: State<String>| async {});
/// // until you provide the `String` with `.with_state(...)`
/// let method_router_with_state = method_router.with_state(String::new());
/// // and then it implements `Service`
/// assert_service(method_router_with_state);
///
/// // helper to check that a value implements `Service`
/// fn assert_service<S>(service: S)
/// where
///     S: Service<Request>,
/// {}
/// ```
#[must_use]
pub struct MethodRouter<S = (), E = Infallible> {
    // One endpoint slot per HTTP method. `new` sets every slot to `MethodEndpoint::None`;
    // a slot is filled when a handler or service is registered for that method.
    get: MethodEndpoint<S, E>,
    head: MethodEndpoint<S, E>,
    delete: MethodEndpoint<S, E>,
    options: MethodEndpoint<S, E>,
    patch: MethodEndpoint<S, E>,
    post: MethodEndpoint<S, E>,
    put: MethodEndpoint<S, E>,
    trace: MethodEndpoint<S, E>,
    connect: MethodEndpoint<S, E>,
    // Set by `fallback` / `fallback_service`; `new` installs a default that responds with
    // `405 Method Not Allowed`.
    fallback: Fallback<S, E>,
    // State of the `Allow` header, extended as methods are registered.
    allow_header: AllowHeader,
}
560
/// State of the `Allow` header value tracked by a `MethodRouter`.
#[derive(Clone, Debug)]
enum AllowHeader {
    /// No `Allow` header value has been built-up yet. This is the default state
    None,
    /// Don't set an `Allow` header. This is used when `any` or `any_service` are called.
    Skip,
    /// The current value of the `Allow` header.
    Bytes(BytesMut),
}
570
571impl AllowHeader {
572    fn merge(self, other: Self) -> Self {
573        match (self, other) {
574            (AllowHeader::Skip, _) | (_, AllowHeader::Skip) => AllowHeader::Skip,
575            (AllowHeader::None, AllowHeader::None) => AllowHeader::None,
576            (AllowHeader::None, AllowHeader::Bytes(pick)) => AllowHeader::Bytes(pick),
577            (AllowHeader::Bytes(pick), AllowHeader::None) => AllowHeader::Bytes(pick),
578            (AllowHeader::Bytes(mut a), AllowHeader::Bytes(b)) => {
579                a.extend_from_slice(b",");
580                a.extend_from_slice(&b);
581                AllowHeader::Bytes(a)
582            }
583        }
584    }
585}
586
587impl<S, E> fmt::Debug for MethodRouter<S, E> {
588    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
589        f.debug_struct("MethodRouter")
590            .field("get", &self.get)
591            .field("head", &self.head)
592            .field("delete", &self.delete)
593            .field("options", &self.options)
594            .field("patch", &self.patch)
595            .field("post", &self.post)
596            .field("put", &self.put)
597            .field("trace", &self.trace)
598            .field("connect", &self.connect)
599            .field("fallback", &self.fallback)
600            .field("allow_header", &self.allow_header)
601            .finish()
602    }
603}
604
605impl<S> MethodRouter<S, Infallible>
606where
607    S: Clone,
608{
609    /// Chain an additional handler that will accept requests matching the given
610    /// `MethodFilter`.
611    ///
612    /// # Example
613    ///
614    /// ```rust
615    /// use axum::{
616    ///     routing::get,
617    ///     Router,
618    ///     routing::MethodFilter
619    /// };
620    ///
621    /// async fn handler() {}
622    ///
623    /// async fn other_handler() {}
624    ///
625    /// // Requests to `GET /` will go to `handler` and `DELETE /` will go to
626    /// // `other_handler`
627    /// let app = Router::new().route("/", get(handler).on(MethodFilter::DELETE, other_handler));
628    /// # let _: Router = app;
629    /// ```
630    #[track_caller]
631    pub fn on<H, T>(self, filter: MethodFilter, handler: H) -> Self
632    where
633        H: Handler<T, S>,
634        T: 'static,
635        S: Send + Sync + 'static,
636    {
637        self.on_endpoint(
638            filter,
639            MethodEndpoint::BoxedHandler(BoxedIntoRoute::from_handler(handler)),
640        )
641    }
642
643    chained_handler_fn!(connect, CONNECT);
644    chained_handler_fn!(delete, DELETE);
645    chained_handler_fn!(get, GET);
646    chained_handler_fn!(head, HEAD);
647    chained_handler_fn!(options, OPTIONS);
648    chained_handler_fn!(patch, PATCH);
649    chained_handler_fn!(post, POST);
650    chained_handler_fn!(put, PUT);
651    chained_handler_fn!(trace, TRACE);
652
653    /// Add a fallback [`Handler`] to the router.
654    pub fn fallback<H, T>(mut self, handler: H) -> Self
655    where
656        H: Handler<T, S>,
657        T: 'static,
658        S: Send + Sync + 'static,
659    {
660        self.fallback = Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler));
661        self
662    }
663
664    /// Add a fallback [`Handler`] if no custom one has been provided.
665    pub(crate) fn default_fallback<H, T>(self, handler: H) -> Self
666    where
667        H: Handler<T, S>,
668        T: 'static,
669        S: Send + Sync + 'static,
670    {
671        match self.fallback {
672            Fallback::Default(_) => self.fallback(handler),
673            _ => self,
674        }
675    }
676}
677
678impl MethodRouter<(), Infallible> {
679    /// Convert the router into a [`MakeService`].
680    ///
681    /// This allows you to serve a single `MethodRouter` if you don't need any
682    /// routing based on the path:
683    ///
684    /// ```rust
685    /// use axum::{
686    ///     handler::Handler,
687    ///     http::{Uri, Method},
688    ///     response::IntoResponse,
689    ///     routing::get,
690    /// };
691    /// use std::net::SocketAddr;
692    ///
693    /// async fn handler(method: Method, uri: Uri, body: String) -> String {
694    ///     format!("received `{method} {uri}` with body `{body:?}`")
695    /// }
696    ///
697    /// let router = get(handler).post(handler);
698    ///
699    /// # async {
700    /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
701    /// axum::serve(listener, router.into_make_service()).await.unwrap();
702    /// # };
703    /// ```
704    ///
705    /// [`MakeService`]: tower::make::MakeService
706    #[must_use]
707    pub fn into_make_service(self) -> IntoMakeService<Self> {
708        IntoMakeService::new(self.with_state(()))
709    }
710
711    /// Convert the router into a [`MakeService`] which stores information
712    /// about the incoming connection.
713    ///
714    /// See [`Router::into_make_service_with_connect_info`] for more details.
715    ///
716    /// ```rust
717    /// use axum::{
718    ///     handler::Handler,
719    ///     response::IntoResponse,
720    ///     extract::ConnectInfo,
721    ///     routing::get,
722    /// };
723    /// use std::net::SocketAddr;
724    ///
725    /// async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
726    ///     format!("Hello {addr}")
727    /// }
728    ///
729    /// let router = get(handler).post(handler);
730    ///
731    /// # async {
732    /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
733    /// axum::serve(listener, router.into_make_service()).await.unwrap();
734    /// # };
735    /// ```
736    ///
737    /// [`MakeService`]: tower::make::MakeService
738    /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
739    #[cfg(feature = "tokio")]
740    #[must_use]
741    pub fn into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C> {
742        IntoMakeServiceWithConnectInfo::new(self.with_state(()))
743    }
744}
745
746impl<S, E> MethodRouter<S, E>
747where
748    S: Clone,
749{
750    /// Create a default `MethodRouter` that will respond with `405 Method Not Allowed` to all
751    /// requests.
752    pub fn new() -> Self {
753        let fallback = Route::new(service_fn(|_: Request| async {
754            Ok(StatusCode::METHOD_NOT_ALLOWED)
755        }));
756
757        Self {
758            get: MethodEndpoint::None,
759            head: MethodEndpoint::None,
760            delete: MethodEndpoint::None,
761            options: MethodEndpoint::None,
762            patch: MethodEndpoint::None,
763            post: MethodEndpoint::None,
764            put: MethodEndpoint::None,
765            trace: MethodEndpoint::None,
766            connect: MethodEndpoint::None,
767            allow_header: AllowHeader::None,
768            fallback: Fallback::Default(fallback),
769        }
770    }
771
772    /// Provide the state for the router.
773    pub fn with_state<S2>(self, state: S) -> MethodRouter<S2, E> {
774        MethodRouter {
775            get: self.get.with_state(&state),
776            head: self.head.with_state(&state),
777            delete: self.delete.with_state(&state),
778            options: self.options.with_state(&state),
779            patch: self.patch.with_state(&state),
780            post: self.post.with_state(&state),
781            put: self.put.with_state(&state),
782            trace: self.trace.with_state(&state),
783            connect: self.connect.with_state(&state),
784            allow_header: self.allow_header,
785            fallback: self.fallback.with_state(state),
786        }
787    }
788
789    /// Chain an additional service that will accept requests matching the given
790    /// `MethodFilter`.
791    ///
792    /// # Example
793    ///
794    /// ```rust
795    /// use axum::{
796    ///     extract::Request,
797    ///     Router,
798    ///     routing::{MethodFilter, on_service},
799    ///     body::Body,
800    /// };
801    /// use http::Response;
802    /// use std::convert::Infallible;
803    ///
804    /// let service = tower::service_fn(|request: Request| async {
805    ///     Ok::<_, Infallible>(Response::new(Body::empty()))
806    /// });
807    ///
808    /// // Requests to `DELETE /` will go to `service`
809    /// let app = Router::new().route("/", on_service(MethodFilter::DELETE, service));
810    /// # let _: Router = app;
811    /// ```
812    #[track_caller]
813    pub fn on_service<T>(self, filter: MethodFilter, svc: T) -> Self
814    where
815        T: Service<Request, Error = E> + Clone + Send + Sync + 'static,
816        T::Response: IntoResponse + 'static,
817        T::Future: Send + 'static,
818    {
819        self.on_endpoint(filter, MethodEndpoint::Route(Route::new(svc)))
820    }
821
822    #[track_caller]
823    fn on_endpoint(mut self, filter: MethodFilter, endpoint: MethodEndpoint<S, E>) -> Self {
824        // written as a separate function to generate less IR
825        #[track_caller]
826        fn set_endpoint<S, E>(
827            method_name: &str,
828            out: &mut MethodEndpoint<S, E>,
829            endpoint: &MethodEndpoint<S, E>,
830            endpoint_filter: MethodFilter,
831            filter: MethodFilter,
832            allow_header: &mut AllowHeader,
833            methods: &[&'static str],
834        ) where
835            MethodEndpoint<S, E>: Clone,
836            S: Clone,
837        {
838            if endpoint_filter.contains(filter) {
839                if out.is_some() {
840                    panic!(
841                        "Overlapping method route. Cannot add two method routes that both handle \
842                         `{method_name}`",
843                    )
844                }
845                *out = endpoint.clone();
846                for method in methods {
847                    append_allow_header(allow_header, method);
848                }
849            }
850        }
851
852        set_endpoint(
853            "GET",
854            &mut self.get,
855            &endpoint,
856            filter,
857            MethodFilter::GET,
858            &mut self.allow_header,
859            &["GET", "HEAD"],
860        );
861
862        set_endpoint(
863            "HEAD",
864            &mut self.head,
865            &endpoint,
866            filter,
867            MethodFilter::HEAD,
868            &mut self.allow_header,
869            &["HEAD"],
870        );
871
872        set_endpoint(
873            "TRACE",
874            &mut self.trace,
875            &endpoint,
876            filter,
877            MethodFilter::TRACE,
878            &mut self.allow_header,
879            &["TRACE"],
880        );
881
882        set_endpoint(
883            "PUT",
884            &mut self.put,
885            &endpoint,
886            filter,
887            MethodFilter::PUT,
888            &mut self.allow_header,
889            &["PUT"],
890        );
891
892        set_endpoint(
893            "POST",
894            &mut self.post,
895            &endpoint,
896            filter,
897            MethodFilter::POST,
898            &mut self.allow_header,
899            &["POST"],
900        );
901
902        set_endpoint(
903            "PATCH",
904            &mut self.patch,
905            &endpoint,
906            filter,
907            MethodFilter::PATCH,
908            &mut self.allow_header,
909            &["PATCH"],
910        );
911
912        set_endpoint(
913            "OPTIONS",
914            &mut self.options,
915            &endpoint,
916            filter,
917            MethodFilter::OPTIONS,
918            &mut self.allow_header,
919            &["OPTIONS"],
920        );
921
922        set_endpoint(
923            "DELETE",
924            &mut self.delete,
925            &endpoint,
926            filter,
927            MethodFilter::DELETE,
928            &mut self.allow_header,
929            &["DELETE"],
930        );
931
932        set_endpoint(
933            "CONNECT",
934            &mut self.options,
935            &endpoint,
936            filter,
937            MethodFilter::CONNECT,
938            &mut self.allow_header,
939            &["CONNECT"],
940        );
941
942        self
943    }
944
    // Chained `*_service` methods, one per HTTP method; each forwards to `on_service` with
    // the matching `MethodFilter` constant.
    chained_service_fn!(connect_service, CONNECT);
    chained_service_fn!(delete_service, DELETE);
    chained_service_fn!(get_service, GET);
    chained_service_fn!(head_service, HEAD);
    chained_service_fn!(options_service, OPTIONS);
    chained_service_fn!(patch_service, PATCH);
    chained_service_fn!(post_service, POST);
    chained_service_fn!(put_service, PUT);
    chained_service_fn!(trace_service, TRACE);
954
955    #[doc = include_str!("../docs/method_routing/fallback.md")]
956    pub fn fallback_service<T>(mut self, svc: T) -> Self
957    where
958        T: Service<Request, Error = E> + Clone + Send + Sync + 'static,
959        T::Response: IntoResponse + 'static,
960        T::Future: Send + 'static,
961    {
962        self.fallback = Fallback::Service(Route::new(svc));
963        self
964    }
965
966    #[doc = include_str!("../docs/method_routing/layer.md")]
967    pub fn layer<L, NewError>(self, layer: L) -> MethodRouter<S, NewError>
968    where
969        L: Layer<Route<E>> + Clone + Send + Sync + 'static,
970        L::Service: Service<Request> + Clone + Send + Sync + 'static,
971        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
972        <L::Service as Service<Request>>::Error: Into<NewError> + 'static,
973        <L::Service as Service<Request>>::Future: Send + 'static,
974        E: 'static,
975        S: 'static,
976        NewError: 'static,
977    {
978        let layer_fn = move |route: Route<E>| route.layer(layer.clone());
979
980        MethodRouter {
981            get: self.get.map(layer_fn.clone()),
982            head: self.head.map(layer_fn.clone()),
983            delete: self.delete.map(layer_fn.clone()),
984            options: self.options.map(layer_fn.clone()),
985            patch: self.patch.map(layer_fn.clone()),
986            post: self.post.map(layer_fn.clone()),
987            put: self.put.map(layer_fn.clone()),
988            trace: self.trace.map(layer_fn.clone()),
989            connect: self.connect.map(layer_fn.clone()),
990            fallback: self.fallback.map(layer_fn),
991            allow_header: self.allow_header,
992        }
993    }
994
995    #[doc = include_str!("../docs/method_routing/route_layer.md")]
996    #[track_caller]
997    pub fn route_layer<L>(mut self, layer: L) -> MethodRouter<S, E>
998    where
999        L: Layer<Route<E>> + Clone + Send + Sync + 'static,
1000        L::Service: Service<Request, Error = E> + Clone + Send + Sync + 'static,
1001        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
1002        <L::Service as Service<Request>>::Future: Send + 'static,
1003        E: 'static,
1004        S: 'static,
1005    {
1006        if self.get.is_none()
1007            && self.head.is_none()
1008            && self.delete.is_none()
1009            && self.options.is_none()
1010            && self.patch.is_none()
1011            && self.post.is_none()
1012            && self.put.is_none()
1013            && self.trace.is_none()
1014            && self.connect.is_none()
1015        {
1016            panic!(
1017                "Adding a route_layer before any routes is a no-op. \
1018                 Add the routes you want the layer to apply to first."
1019            );
1020        }
1021
1022        let layer_fn = move |svc| Route::new(layer.layer(svc));
1023
1024        self.get = self.get.map(layer_fn.clone());
1025        self.head = self.head.map(layer_fn.clone());
1026        self.delete = self.delete.map(layer_fn.clone());
1027        self.options = self.options.map(layer_fn.clone());
1028        self.patch = self.patch.map(layer_fn.clone());
1029        self.post = self.post.map(layer_fn.clone());
1030        self.put = self.put.map(layer_fn.clone());
1031        self.trace = self.trace.map(layer_fn.clone());
1032        self.connect = self.connect.map(layer_fn);
1033
1034        self
1035    }
1036
1037    pub(crate) fn merge_for_path(
1038        mut self,
1039        path: Option<&str>,
1040        other: MethodRouter<S, E>,
1041    ) -> Result<Self, Cow<'static, str>> {
1042        // written using inner functions to generate less IR
1043        fn merge_inner<S, E>(
1044            path: Option<&str>,
1045            name: &str,
1046            first: MethodEndpoint<S, E>,
1047            second: MethodEndpoint<S, E>,
1048        ) -> Result<MethodEndpoint<S, E>, Cow<'static, str>> {
1049            match (first, second) {
1050                (MethodEndpoint::None, MethodEndpoint::None) => Ok(MethodEndpoint::None),
1051                (pick, MethodEndpoint::None) | (MethodEndpoint::None, pick) => Ok(pick),
1052                _ => {
1053                    if let Some(path) = path {
1054                        Err(format!(
1055                            "Overlapping method route. Handler for `{name} {path}` already exists"
1056                        )
1057                        .into())
1058                    } else {
1059                        Err(format!(
1060                            "Overlapping method route. Cannot merge two method routes that both \
1061                             define `{name}`"
1062                        )
1063                        .into())
1064                    }
1065                }
1066            }
1067        }
1068
1069        self.get = merge_inner(path, "GET", self.get, other.get)?;
1070        self.head = merge_inner(path, "HEAD", self.head, other.head)?;
1071        self.delete = merge_inner(path, "DELETE", self.delete, other.delete)?;
1072        self.options = merge_inner(path, "OPTIONS", self.options, other.options)?;
1073        self.patch = merge_inner(path, "PATCH", self.patch, other.patch)?;
1074        self.post = merge_inner(path, "POST", self.post, other.post)?;
1075        self.put = merge_inner(path, "PUT", self.put, other.put)?;
1076        self.trace = merge_inner(path, "TRACE", self.trace, other.trace)?;
1077        self.connect = merge_inner(path, "CONNECT", self.connect, other.connect)?;
1078
1079        self.fallback = self
1080            .fallback
1081            .merge(other.fallback)
1082            .ok_or("Cannot merge two `MethodRouter`s that both have a fallback")?;
1083
1084        self.allow_header = self.allow_header.merge(other.allow_header);
1085
1086        Ok(self)
1087    }
1088
1089    #[doc = include_str!("../docs/method_routing/merge.md")]
1090    #[track_caller]
1091    pub fn merge(self, other: MethodRouter<S, E>) -> Self {
1092        match self.merge_for_path(None, other) {
1093            Ok(t) => t,
1094            // not using unwrap or unwrap_or_else to get a clean panic message + the right location
1095            Err(e) => panic!("{e}"),
1096        }
1097    }
1098
1099    /// Apply a [`HandleErrorLayer`].
1100    ///
1101    /// This is a convenience method for doing `self.layer(HandleErrorLayer::new(f))`.
1102    pub fn handle_error<F, T>(self, f: F) -> MethodRouter<S, Infallible>
1103    where
1104        F: Clone + Send + Sync + 'static,
1105        HandleError<Route<E>, F, T>: Service<Request, Error = Infallible>,
1106        <HandleError<Route<E>, F, T> as Service<Request>>::Future: Send,
1107        <HandleError<Route<E>, F, T> as Service<Request>>::Response: IntoResponse + Send,
1108        T: 'static,
1109        E: 'static,
1110        S: 'static,
1111    {
1112        self.layer(HandleErrorLayer::new(f))
1113    }
1114
1115    fn skip_allow_header(mut self) -> Self {
1116        self.allow_header = AllowHeader::Skip;
1117        self
1118    }
1119
1120    pub(crate) fn call_with_state(&self, req: Request, state: S) -> RouteFuture<E> {
1121        macro_rules! call {
1122            (
1123                $req:expr,
1124                $method_variant:ident,
1125                $svc:expr
1126            ) => {
1127                if *req.method() == Method::$method_variant {
1128                    match $svc {
1129                        MethodEndpoint::None => {}
1130                        MethodEndpoint::Route(route) => {
1131                            return route.clone().oneshot_inner_owned($req);
1132                        }
1133                        MethodEndpoint::BoxedHandler(handler) => {
1134                            let route = handler.clone().into_route(state);
1135                            return route.oneshot_inner_owned($req);
1136                        }
1137                    }
1138                }
1139            };
1140        }
1141
1142        // written with a pattern match like this to ensure we call all routes
1143        let Self {
1144            get,
1145            head,
1146            delete,
1147            options,
1148            patch,
1149            post,
1150            put,
1151            trace,
1152            connect,
1153            fallback,
1154            allow_header,
1155        } = self;
1156
1157        call!(req, HEAD, head);
1158        call!(req, HEAD, get);
1159        call!(req, GET, get);
1160        call!(req, POST, post);
1161        call!(req, OPTIONS, options);
1162        call!(req, PATCH, patch);
1163        call!(req, PUT, put);
1164        call!(req, DELETE, delete);
1165        call!(req, TRACE, trace);
1166        call!(req, CONNECT, connect);
1167
1168        let future = fallback.clone().call_with_state(req, state);
1169
1170        match allow_header {
1171            AllowHeader::None => future.allow_header(Bytes::new()),
1172            AllowHeader::Skip => future,
1173            AllowHeader::Bytes(allow_header) => future.allow_header(allow_header.clone().freeze()),
1174        }
1175    }
1176}
1177
1178fn append_allow_header(allow_header: &mut AllowHeader, method: &'static str) {
1179    match allow_header {
1180        AllowHeader::None => {
1181            *allow_header = AllowHeader::Bytes(BytesMut::from(method));
1182        }
1183        AllowHeader::Skip => {}
1184        AllowHeader::Bytes(allow_header) => {
1185            if let Ok(s) = std::str::from_utf8(allow_header) {
1186                if !s.contains(method) {
1187                    allow_header.extend_from_slice(b",");
1188                    allow_header.extend_from_slice(method.as_bytes());
1189                }
1190            } else {
1191                #[cfg(debug_assertions)]
1192                panic!("`allow_header` contained invalid uft-8. This should never happen")
1193            }
1194        }
1195    }
1196}
1197
1198impl<S, E> Clone for MethodRouter<S, E> {
1199    fn clone(&self) -> Self {
1200        Self {
1201            get: self.get.clone(),
1202            head: self.head.clone(),
1203            delete: self.delete.clone(),
1204            options: self.options.clone(),
1205            patch: self.patch.clone(),
1206            post: self.post.clone(),
1207            put: self.put.clone(),
1208            trace: self.trace.clone(),
1209            connect: self.connect.clone(),
1210            fallback: self.fallback.clone(),
1211            allow_header: self.allow_header.clone(),
1212        }
1213    }
1214}
1215
1216impl<S, E> Default for MethodRouter<S, E>
1217where
1218    S: Clone,
1219{
1220    fn default() -> Self {
1221        Self::new()
1222    }
1223}
1224
/// What a `MethodRouter` holds for a single HTTP method.
enum MethodEndpoint<S, E> {
    /// Nothing is registered for the method.
    None,
    /// A route that can be called directly.
    Route(Route<E>),
    /// A boxed handler that is turned into a [`Route`] via `into_route(state)` once the state
    /// is available.
    BoxedHandler(BoxedIntoRoute<S, E>),
}
1230
1231impl<S, E> MethodEndpoint<S, E>
1232where
1233    S: Clone,
1234{
1235    fn is_some(&self) -> bool {
1236        matches!(self, Self::Route(_) | Self::BoxedHandler(_))
1237    }
1238
1239    fn is_none(&self) -> bool {
1240        matches!(self, Self::None)
1241    }
1242
1243    fn map<F, E2>(self, f: F) -> MethodEndpoint<S, E2>
1244    where
1245        S: 'static,
1246        E: 'static,
1247        F: FnOnce(Route<E>) -> Route<E2> + Clone + Send + Sync + 'static,
1248        E2: 'static,
1249    {
1250        match self {
1251            Self::None => MethodEndpoint::None,
1252            Self::Route(route) => MethodEndpoint::Route(f(route)),
1253            Self::BoxedHandler(handler) => MethodEndpoint::BoxedHandler(handler.map(f)),
1254        }
1255    }
1256
1257    fn with_state<S2>(self, state: &S) -> MethodEndpoint<S2, E> {
1258        match self {
1259            MethodEndpoint::None => MethodEndpoint::None,
1260            MethodEndpoint::Route(route) => MethodEndpoint::Route(route),
1261            MethodEndpoint::BoxedHandler(handler) => {
1262                MethodEndpoint::Route(handler.into_route(state.clone()))
1263            }
1264        }
1265    }
1266}
1267
1268impl<S, E> Clone for MethodEndpoint<S, E> {
1269    fn clone(&self) -> Self {
1270        match self {
1271            Self::None => Self::None,
1272            Self::Route(inner) => Self::Route(inner.clone()),
1273            Self::BoxedHandler(inner) => Self::BoxedHandler(inner.clone()),
1274        }
1275    }
1276}
1277
1278impl<S, E> fmt::Debug for MethodEndpoint<S, E> {
1279    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1280        match self {
1281            Self::None => f.debug_tuple("None").finish(),
1282            Self::Route(inner) => inner.fmt(f),
1283            Self::BoxedHandler(_) => f.debug_tuple("BoxedHandler").finish(),
1284        }
1285    }
1286}
1287
1288impl<B, E> Service<Request<B>> for MethodRouter<(), E>
1289where
1290    B: HttpBody<Data = Bytes> + Send + 'static,
1291    B::Error: Into<BoxError>,
1292{
1293    type Response = Response;
1294    type Error = E;
1295    type Future = RouteFuture<E>;
1296
1297    #[inline]
1298    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1299        Poll::Ready(Ok(()))
1300    }
1301
1302    #[inline]
1303    fn call(&mut self, req: Request<B>) -> Self::Future {
1304        let req = req.map(Body::new);
1305        self.call_with_state(req, ())
1306    }
1307}
1308
1309impl<S> Handler<(), S> for MethodRouter<S>
1310where
1311    S: Clone + 'static,
1312{
1313    type Future = InfallibleRouteFuture;
1314
1315    fn call(self, req: Request, state: S) -> Self::Future {
1316        InfallibleRouteFuture::new(self.call_with_state(req, state))
1317    }
1318}
1319
// Lets a `MethodRouter` be passed straight to `axum::serve(listener, router)`.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = {
    use crate::serve;

    impl<L> Service<serve::IncomingStream<'_, L>> for MethodRouter<()>
    where
        L: serve::Listener,
    {
        type Response = Self;
        type Error = Infallible;
        type Future = std::future::Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future {
            // The incoming stream itself is unused: the response is a clone of this router.
            let service = self.clone().with_state(());
            std::future::ready(Ok(service))
        }
    }
};
1342
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{extract::State, handler::HandlerWithoutStateExt};
    use http::{header::ALLOW, HeaderMap};
    use http_body_util::BodyExt;
    use std::time::Duration;
    use tower::ServiceExt;
    use tower_http::{
        services::fs::ServeDir, timeout::TimeoutLayer, validate_request::ValidateRequestHeaderLayer,
    };

    #[crate::test]
    async fn method_not_allowed_by_default() {
        let mut svc = MethodRouter::new();
        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert!(body.is_empty());
    }

    #[crate::test]
    async fn get_service_fn() {
        async fn handle(_req: Request) -> Result<Response<Body>, Infallible> {
            Ok(Response::new(Body::from("ok")))
        }

        let mut svc = get_service(service_fn(handle));

        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "ok");
    }

    #[crate::test]
    async fn get_handler() {
        let mut svc = MethodRouter::new().get(ok);
        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "ok");
    }

    #[crate::test]
    async fn get_accepts_head() {
        let mut svc = MethodRouter::new().get(ok);
        let (status, _, body) = call(Method::HEAD, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert!(body.is_empty());
    }

    #[crate::test]
    async fn head_takes_precedence_over_get() {
        let mut svc = MethodRouter::new().head(created).get(ok);
        let (status, _, body) = call(Method::HEAD, &mut svc).await;
        assert_eq!(status, StatusCode::CREATED);
        assert!(body.is_empty());
    }

    #[crate::test]
    async fn merge() {
        let mut svc = get(ok).merge(post(ok));

        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, _, _) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[crate::test]
    async fn layer() {
        let mut svc = MethodRouter::new()
            .get(|| async { std::future::pending::<()>().await })
            .layer(ValidateRequestHeaderLayer::bearer("password"));

        // method with route
        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);

        // method without route
        let (status, _, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[crate::test]
    async fn route_layer() {
        let mut svc = MethodRouter::new()
            .get(|| async { std::future::pending::<()>().await })
            .route_layer(ValidateRequestHeaderLayer::bearer("password"));

        // method with route
        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);

        // method without route
        let (status, _, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
    }

    // Compile-only check that the builder methods compose; never executed.
    #[allow(dead_code)]
    async fn building_complex_router() {
        let app = crate::Router::new().route(
            "/",
            // use all the things 💣️
            get(ok)
                .post(ok)
                .route_layer(ValidateRequestHeaderLayer::bearer("password"))
                .merge(delete_service(ServeDir::new(".")))
                .fallback(|| async { StatusCode::NOT_FOUND })
                .put(ok)
                .layer(TimeoutLayer::new(Duration::from_secs(10))),
        );

        let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
        crate::serve(listener, app).await.unwrap();
    }

    #[crate::test]
    async fn sets_allow_header() {
        let mut svc = MethodRouter::new().put(ok).patch(ok);
        let (status, headers, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "PUT,PATCH");
    }

    #[crate::test]
    async fn sets_allow_header_get_head() {
        let mut svc = MethodRouter::new().get(ok).head(ok);
        let (status, headers, _) = call(Method::PUT, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    #[crate::test]
    async fn empty_allow_header_by_default() {
        let mut svc = MethodRouter::new();
        let (status, headers, _) = call(Method::PATCH, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "");
    }

    #[crate::test]
    async fn allow_header_when_merging() {
        let a = put(ok).patch(ok);
        let b = get(ok).head(ok);
        let mut svc = a.merge(b);

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "PUT,PATCH,GET,HEAD");
    }

    #[crate::test]
    async fn allow_header_any() {
        let mut svc = any(ok);

        let (status, headers, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert!(!headers.contains_key(ALLOW));
    }

    #[crate::test]
    async fn allow_header_with_fallback() {
        let mut svc = MethodRouter::new()
            .get(ok)
            .fallback(|| async { (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed") });

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    #[crate::test]
    async fn allow_header_with_fallback_that_sets_allow() {
        async fn fallback(method: Method) -> Response {
            if method == Method::POST {
                "OK".into_response()
            } else {
                (
                    StatusCode::METHOD_NOT_ALLOWED,
                    [(ALLOW, "GET,POST")],
                    "Method not allowed",
                )
                    .into_response()
            }
        }

        let mut svc = MethodRouter::new().get(ok).fallback(fallback);

        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, _, _) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,POST");
    }

    #[crate::test]
    async fn allow_header_noop_middleware() {
        let mut svc = MethodRouter::new()
            .get(ok)
            .layer(tower::layer::util::Identity::new());

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    #[crate::test]
    #[should_panic(
        expected = "Overlapping method route. Cannot add two method routes that both handle `GET`"
    )]
    async fn handler_overlaps() {
        let _: MethodRouter<()> = get(ok).get(ok);
    }

    #[crate::test]
    #[should_panic(
        expected = "Overlapping method route. Cannot add two method routes that both handle `POST`"
    )]
    async fn service_overlaps() {
        let _: MethodRouter<()> = post_service(ok.into_service()).post_service(ok.into_service());
    }

    #[crate::test]
    async fn get_head_does_not_overlap() {
        let _: MethodRouter<()> = get(ok).head(ok);
    }

    #[crate::test]
    async fn head_get_does_not_overlap() {
        let _: MethodRouter<()> = head(ok).get(ok);
    }

    #[crate::test]
    async fn accessing_state() {
        let mut svc = MethodRouter::new()
            .get(|State(state): State<&'static str>| async move { state })
            .with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;

        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    #[crate::test]
    async fn fallback_accessing_state() {
        let mut svc = MethodRouter::new()
            .fallback(|State(state): State<&'static str>| async move { state })
            .with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;

        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    #[crate::test]
    async fn merge_accessing_state() {
        let one = get(|State(state): State<&'static str>| async move { state });
        let two = post(|State(state): State<&'static str>| async move { state });

        let mut svc = one.merge(two).with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");

        // Bind the POST body as well; otherwise the assertion below would only re-check the
        // body of the GET response above.
        let (status, _, text) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    /// Send a bodiless request for `/` with the given method through `svc` and return the
    /// response status, headers and body (decoded as UTF-8).
    async fn call<S>(method: Method, svc: &mut S) -> (StatusCode, HeaderMap, String)
    where
        S: Service<Request, Error = Infallible>,
        S::Response: IntoResponse,
    {
        let request = Request::builder()
            .uri("/")
            .method(method)
            .body(Body::empty())
            .unwrap();
        let response = svc
            .ready()
            .await
            .unwrap()
            .call(request)
            .await
            .unwrap()
            .into_response();
        let (parts, body) = response.into_parts();
        let body =
            String::from_utf8(BodyExt::collect(body).await.unwrap().to_bytes().to_vec()).unwrap();
        (parts.status, parts.headers, body)
    }

    /// Handler that answers `200 OK` with the body `"ok"`.
    async fn ok() -> (StatusCode, &'static str) {
        (StatusCode::OK, "ok")
    }

    /// Handler that answers `201 Created` with the body `"created"`.
    async fn created() -> (StatusCode, &'static str) {
        (StatusCode::CREATED, "created")
    }
}