1use super::{future::InfallibleRouteFuture, IntoMakeService};
4#[cfg(feature = "tokio")]
5use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
6use crate::{
7 body::{Body, Bytes, HttpBody},
8 boxed::BoxedIntoRoute,
9 error_handling::{HandleError, HandleErrorLayer},
10 handler::Handler,
11 http::{Method, StatusCode},
12 response::Response,
13 routing::{future::RouteFuture, Fallback, MethodFilter, Route},
14};
15use axum_core::{extract::Request, response::IntoResponse, BoxError};
16use bytes::BytesMut;
17use std::{
18 borrow::Cow,
19 convert::Infallible,
20 fmt,
21 task::{Context, Poll},
22};
23use tower::service_fn;
24use tower_layer::Layer;
25use tower_service::Service;
26
// Generates a free function `$name(svc)` that builds a `MethodRouter` routing one HTTP method
// to a `Service` via `on_service`.
//
// Every forwarding arm must pass at least one attribute to the final
// `$(#[$m:meta])+ $name, $method` arm. Forwarding a bare `$name, GET` (or `CONNECT`) would
// re-match the very same arm and recurse until the macro recursion limit is hit.
macro_rules! top_level_service_fn {
    (
        $name:ident, GET
    ) => {
        top_level_service_fn!(
            /// Route `GET` requests to the given service.
            ///
            /// The service also receives `HEAD` requests unless a dedicated `HEAD`
            /// endpoint is registered on the same router.
            $name,
            GET
        );
    };

    (
        $name:ident, CONNECT
    ) => {
        top_level_service_fn!(
            /// Route `CONNECT` requests to the given service.
            ///
            /// See [`MethodFilter::CONNECT`] for when you'd want to use this.
            $name,
            CONNECT
        );
    };

    (
        $name:ident, $method:ident
    ) => {
        top_level_service_fn!(
            #[doc = concat!("Route `", stringify!($method) ,"` requests to the given service.")]
            $name,
            $method
        );
    };

    (
        $(#[$m:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$m])+
        pub fn $name<T, S>(svc: T) -> MethodRouter<S, T::Error>
        where
            T: Service<Request> + Clone + Send + Sync + 'static,
            T::Response: IntoResponse + 'static,
            T::Future: Send + 'static,
            S: Clone,
        {
            on_service(MethodFilter::$method, svc)
        }
    };
}
104
// Generates a free function `$name(handler)` that builds a `MethodRouter` routing one HTTP
// method to a `Handler` via `on`.
//
// Every forwarding arm must pass at least one attribute to the final
// `$(#[$m:meta])+ $name, $method` arm. Forwarding a bare `$name, GET` (or `CONNECT`) would
// re-match the very same arm and recurse until the macro recursion limit is hit.
macro_rules! top_level_handler_fn {
    (
        $name:ident, GET
    ) => {
        top_level_handler_fn!(
            /// Route `GET` requests to the given handler.
            ///
            /// The handler also receives `HEAD` requests unless a dedicated `HEAD`
            /// endpoint is registered on the same router.
            $name,
            GET
        );
    };

    (
        $name:ident, CONNECT
    ) => {
        top_level_handler_fn!(
            /// Route `CONNECT` requests to the given handler.
            ///
            /// See [`MethodFilter::CONNECT`] for when you'd want to use this.
            $name,
            CONNECT
        );
    };

    (
        $name:ident, $method:ident
    ) => {
        top_level_handler_fn!(
            #[doc = concat!("Route `", stringify!($method) ,"` requests to the given handler.")]
            $name,
            $method
        );
    };

    (
        $(#[$m:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$m])+
        pub fn $name<H, T, S>(handler: H) -> MethodRouter<S, Infallible>
        where
            H: Handler<T, S>,
            T: 'static,
            S: Clone + Send + Sync + 'static,
        {
            on(MethodFilter::$method, handler)
        }
    };
}
175
// Generates a builder method `$name(self, svc)` on `MethodRouter` that adds a `Service` for one
// HTTP method via `MethodRouter::on_service`.
//
// Every forwarding arm must pass at least one attribute to the final
// `$(#[$m:meta])+ $name, $method` arm. Forwarding a bare `$name, GET` (or `CONNECT`) would
// re-match the very same arm and recurse until the macro recursion limit is hit.
macro_rules! chained_service_fn {
    (
        $name:ident, GET
    ) => {
        chained_service_fn!(
            /// Chain an additional service that will only accept `GET` requests.
            ///
            /// The service also receives `HEAD` requests unless a dedicated `HEAD`
            /// endpoint is registered on the same router.
            $name,
            GET
        );
    };

    (
        $name:ident, CONNECT
    ) => {
        chained_service_fn!(
            /// Chain an additional service that will only accept `CONNECT` requests.
            ///
            /// See [`MethodFilter::CONNECT`] for when you'd want to use this.
            $name,
            CONNECT
        );
    };

    (
        $name:ident, $method:ident
    ) => {
        chained_service_fn!(
            #[doc = concat!("Chain an additional service that will only accept `", stringify!($method),"` requests.")]
            $name,
            $method
        );
    };

    (
        $(#[$m:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$m])+
        #[track_caller]
        pub fn $name<T>(self, svc: T) -> Self
        where
            T: Service<Request, Error = E>
                + Clone
                + Send
                + Sync
                + 'static,
            T::Response: IntoResponse + 'static,
            T::Future: Send + 'static,
        {
            self.on_service(MethodFilter::$method, svc)
        }
    };
}
262
// Generates a builder method `$name(self, handler)` on `MethodRouter` that adds a `Handler` for
// one HTTP method via `MethodRouter::on`.
//
// Every forwarding arm must pass at least one attribute to the final
// `$(#[$m:meta])+ $name, $method` arm. Forwarding a bare `$name, GET` (or `CONNECT`) would
// re-match the very same arm and recurse until the macro recursion limit is hit.
macro_rules! chained_handler_fn {
    (
        $name:ident, GET
    ) => {
        chained_handler_fn!(
            /// Chain an additional handler that will only accept `GET` requests.
            ///
            /// The handler also receives `HEAD` requests unless a dedicated `HEAD`
            /// endpoint is registered on the same router.
            $name,
            GET
        );
    };

    (
        $name:ident, CONNECT
    ) => {
        chained_handler_fn!(
            /// Chain an additional handler that will only accept `CONNECT` requests.
            ///
            /// See [`MethodFilter::CONNECT`] for when you'd want to use this.
            $name,
            CONNECT
        );
    };

    (
        $name:ident, $method:ident
    ) => {
        chained_handler_fn!(
            #[doc = concat!("Chain an additional handler that will only accept `", stringify!($method),"` requests.")]
            $name,
            $method
        );
    };

    (
        $(#[$m:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$m])+
        #[track_caller]
        pub fn $name<H, T>(self, handler: H) -> Self
        where
            H: Handler<T, S>,
            T: 'static,
            S: Send + Sync + 'static,
        {
            self.on(MethodFilter::$method, handler)
        }
    };
}
334
// Free functions that route a single HTTP method to a `Service`, e.g. `get_service(svc)`.
// Each expands to a call to `on_service(MethodFilter::<METHOD>, svc)`.
top_level_service_fn!(connect_service, CONNECT);
top_level_service_fn!(delete_service, DELETE);
top_level_service_fn!(get_service, GET);
top_level_service_fn!(head_service, HEAD);
top_level_service_fn!(options_service, OPTIONS);
top_level_service_fn!(patch_service, PATCH);
top_level_service_fn!(post_service, POST);
top_level_service_fn!(put_service, PUT);
top_level_service_fn!(trace_service, TRACE);
344
345pub fn on_service<T, S>(filter: MethodFilter, svc: T) -> MethodRouter<S, T::Error>
369where
370 T: Service<Request> + Clone + Send + Sync + 'static,
371 T::Response: IntoResponse + 'static,
372 T::Future: Send + 'static,
373 S: Clone,
374{
375 MethodRouter::new().on_service(filter, svc)
376}
377
378pub fn any_service<T, S>(svc: T) -> MethodRouter<S, T::Error>
428where
429 T: Service<Request> + Clone + Send + Sync + 'static,
430 T::Response: IntoResponse + 'static,
431 T::Future: Send + 'static,
432 S: Clone,
433{
434 MethodRouter::new()
435 .fallback_service(svc)
436 .skip_allow_header()
437}
438
// Free functions that route a single HTTP method to a `Handler`, e.g. `get(handler)`.
// Each expands to a call to `on(MethodFilter::<METHOD>, handler)`.
top_level_handler_fn!(connect, CONNECT);
top_level_handler_fn!(delete, DELETE);
top_level_handler_fn!(get, GET);
top_level_handler_fn!(head, HEAD);
top_level_handler_fn!(options, OPTIONS);
top_level_handler_fn!(patch, PATCH);
top_level_handler_fn!(post, POST);
top_level_handler_fn!(put, PUT);
top_level_handler_fn!(trace, TRACE);
448
449pub fn on<H, T, S>(filter: MethodFilter, handler: H) -> MethodRouter<S, Infallible>
467where
468 H: Handler<T, S>,
469 T: 'static,
470 S: Clone + Send + Sync + 'static,
471{
472 MethodRouter::new().on(filter, handler)
473}
474
475pub fn any<H, T, S>(handler: H) -> MethodRouter<S, Infallible>
509where
510 H: Handler<T, S>,
511 T: 'static,
512 S: Clone + Send + Sync + 'static,
513{
514 MethodRouter::new().fallback(handler).skip_allow_header()
515}
516
/// A [`Service`] that dispatches requests by HTTP method and allows chaining additional
/// handlers and services.
///
/// Each supported method has its own endpoint slot. Requests whose method has no endpoint are
/// sent to `fallback`, which by default (see `MethodRouter::new`) responds with
/// `405 Method Not Allowed`.
#[must_use]
pub struct MethodRouter<S = (), E = Infallible> {
    // Endpoint for `GET`; also serves `HEAD` requests when `head` is empty.
    get: MethodEndpoint<S, E>,
    head: MethodEndpoint<S, E>,
    delete: MethodEndpoint<S, E>,
    options: MethodEndpoint<S, E>,
    patch: MethodEndpoint<S, E>,
    post: MethodEndpoint<S, E>,
    put: MethodEndpoint<S, E>,
    trace: MethodEndpoint<S, E>,
    connect: MethodEndpoint<S, E>,
    // Handles every request whose method has no endpoint above.
    fallback: Fallback<S, E>,
    // Accumulated `Allow` value handed to the fallback's future in `call_with_state`.
    allow_header: AllowHeader,
}
560
/// Value of the `Allow` header that is attached when a request falls through to the fallback.
#[derive(Clone, Debug)]
enum AllowHeader {
    /// No methods have been registered yet; an empty `Allow` header is sent.
    None,
    /// Don't set an `Allow` header at all (used by `any` and `any_service`).
    Skip,
    /// Comma-separated list of registered method names.
    Bytes(BytesMut),
}
570
571impl AllowHeader {
572 fn merge(self, other: Self) -> Self {
573 match (self, other) {
574 (AllowHeader::Skip, _) | (_, AllowHeader::Skip) => AllowHeader::Skip,
575 (AllowHeader::None, AllowHeader::None) => AllowHeader::None,
576 (AllowHeader::None, AllowHeader::Bytes(pick)) => AllowHeader::Bytes(pick),
577 (AllowHeader::Bytes(pick), AllowHeader::None) => AllowHeader::Bytes(pick),
578 (AllowHeader::Bytes(mut a), AllowHeader::Bytes(b)) => {
579 a.extend_from_slice(b",");
580 a.extend_from_slice(&b);
581 AllowHeader::Bytes(a)
582 }
583 }
584 }
585}
586
587impl<S, E> fmt::Debug for MethodRouter<S, E> {
588 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
589 f.debug_struct("MethodRouter")
590 .field("get", &self.get)
591 .field("head", &self.head)
592 .field("delete", &self.delete)
593 .field("options", &self.options)
594 .field("patch", &self.patch)
595 .field("post", &self.post)
596 .field("put", &self.put)
597 .field("trace", &self.trace)
598 .field("connect", &self.connect)
599 .field("fallback", &self.fallback)
600 .field("allow_header", &self.allow_header)
601 .finish()
602 }
603}
604
605impl<S> MethodRouter<S, Infallible>
606where
607 S: Clone,
608{
609 #[track_caller]
631 pub fn on<H, T>(self, filter: MethodFilter, handler: H) -> Self
632 where
633 H: Handler<T, S>,
634 T: 'static,
635 S: Send + Sync + 'static,
636 {
637 self.on_endpoint(
638 filter,
639 MethodEndpoint::BoxedHandler(BoxedIntoRoute::from_handler(handler)),
640 )
641 }
642
643 chained_handler_fn!(connect, CONNECT);
644 chained_handler_fn!(delete, DELETE);
645 chained_handler_fn!(get, GET);
646 chained_handler_fn!(head, HEAD);
647 chained_handler_fn!(options, OPTIONS);
648 chained_handler_fn!(patch, PATCH);
649 chained_handler_fn!(post, POST);
650 chained_handler_fn!(put, PUT);
651 chained_handler_fn!(trace, TRACE);
652
653 pub fn fallback<H, T>(mut self, handler: H) -> Self
655 where
656 H: Handler<T, S>,
657 T: 'static,
658 S: Send + Sync + 'static,
659 {
660 self.fallback = Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler));
661 self
662 }
663
664 pub(crate) fn default_fallback<H, T>(self, handler: H) -> Self
666 where
667 H: Handler<T, S>,
668 T: 'static,
669 S: Send + Sync + 'static,
670 {
671 match self.fallback {
672 Fallback::Default(_) => self.fallback(handler),
673 _ => self,
674 }
675 }
676}
677
678impl MethodRouter<(), Infallible> {
679 #[must_use]
707 pub fn into_make_service(self) -> IntoMakeService<Self> {
708 IntoMakeService::new(self.with_state(()))
709 }
710
711 #[cfg(feature = "tokio")]
740 #[must_use]
741 pub fn into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C> {
742 IntoMakeServiceWithConnectInfo::new(self.with_state(()))
743 }
744}
745
746impl<S, E> MethodRouter<S, E>
747where
748 S: Clone,
749{
750 pub fn new() -> Self {
753 let fallback = Route::new(service_fn(|_: Request| async {
754 Ok(StatusCode::METHOD_NOT_ALLOWED)
755 }));
756
757 Self {
758 get: MethodEndpoint::None,
759 head: MethodEndpoint::None,
760 delete: MethodEndpoint::None,
761 options: MethodEndpoint::None,
762 patch: MethodEndpoint::None,
763 post: MethodEndpoint::None,
764 put: MethodEndpoint::None,
765 trace: MethodEndpoint::None,
766 connect: MethodEndpoint::None,
767 allow_header: AllowHeader::None,
768 fallback: Fallback::Default(fallback),
769 }
770 }
771
772 pub fn with_state<S2>(self, state: S) -> MethodRouter<S2, E> {
774 MethodRouter {
775 get: self.get.with_state(&state),
776 head: self.head.with_state(&state),
777 delete: self.delete.with_state(&state),
778 options: self.options.with_state(&state),
779 patch: self.patch.with_state(&state),
780 post: self.post.with_state(&state),
781 put: self.put.with_state(&state),
782 trace: self.trace.with_state(&state),
783 connect: self.connect.with_state(&state),
784 allow_header: self.allow_header,
785 fallback: self.fallback.with_state(state),
786 }
787 }
788
789 #[track_caller]
813 pub fn on_service<T>(self, filter: MethodFilter, svc: T) -> Self
814 where
815 T: Service<Request, Error = E> + Clone + Send + Sync + 'static,
816 T::Response: IntoResponse + 'static,
817 T::Future: Send + 'static,
818 {
819 self.on_endpoint(filter, MethodEndpoint::Route(Route::new(svc)))
820 }
821
822 #[track_caller]
823 fn on_endpoint(mut self, filter: MethodFilter, endpoint: MethodEndpoint<S, E>) -> Self {
824 #[track_caller]
826 fn set_endpoint<S, E>(
827 method_name: &str,
828 out: &mut MethodEndpoint<S, E>,
829 endpoint: &MethodEndpoint<S, E>,
830 endpoint_filter: MethodFilter,
831 filter: MethodFilter,
832 allow_header: &mut AllowHeader,
833 methods: &[&'static str],
834 ) where
835 MethodEndpoint<S, E>: Clone,
836 S: Clone,
837 {
838 if endpoint_filter.contains(filter) {
839 if out.is_some() {
840 panic!(
841 "Overlapping method route. Cannot add two method routes that both handle \
842 `{method_name}`",
843 )
844 }
845 *out = endpoint.clone();
846 for method in methods {
847 append_allow_header(allow_header, method);
848 }
849 }
850 }
851
852 set_endpoint(
853 "GET",
854 &mut self.get,
855 &endpoint,
856 filter,
857 MethodFilter::GET,
858 &mut self.allow_header,
859 &["GET", "HEAD"],
860 );
861
862 set_endpoint(
863 "HEAD",
864 &mut self.head,
865 &endpoint,
866 filter,
867 MethodFilter::HEAD,
868 &mut self.allow_header,
869 &["HEAD"],
870 );
871
872 set_endpoint(
873 "TRACE",
874 &mut self.trace,
875 &endpoint,
876 filter,
877 MethodFilter::TRACE,
878 &mut self.allow_header,
879 &["TRACE"],
880 );
881
882 set_endpoint(
883 "PUT",
884 &mut self.put,
885 &endpoint,
886 filter,
887 MethodFilter::PUT,
888 &mut self.allow_header,
889 &["PUT"],
890 );
891
892 set_endpoint(
893 "POST",
894 &mut self.post,
895 &endpoint,
896 filter,
897 MethodFilter::POST,
898 &mut self.allow_header,
899 &["POST"],
900 );
901
902 set_endpoint(
903 "PATCH",
904 &mut self.patch,
905 &endpoint,
906 filter,
907 MethodFilter::PATCH,
908 &mut self.allow_header,
909 &["PATCH"],
910 );
911
912 set_endpoint(
913 "OPTIONS",
914 &mut self.options,
915 &endpoint,
916 filter,
917 MethodFilter::OPTIONS,
918 &mut self.allow_header,
919 &["OPTIONS"],
920 );
921
922 set_endpoint(
923 "DELETE",
924 &mut self.delete,
925 &endpoint,
926 filter,
927 MethodFilter::DELETE,
928 &mut self.allow_header,
929 &["DELETE"],
930 );
931
932 set_endpoint(
933 "CONNECT",
934 &mut self.options,
935 &endpoint,
936 filter,
937 MethodFilter::CONNECT,
938 &mut self.allow_header,
939 &["CONNECT"],
940 );
941
942 self
943 }
944
945 chained_service_fn!(connect_service, CONNECT);
946 chained_service_fn!(delete_service, DELETE);
947 chained_service_fn!(get_service, GET);
948 chained_service_fn!(head_service, HEAD);
949 chained_service_fn!(options_service, OPTIONS);
950 chained_service_fn!(patch_service, PATCH);
951 chained_service_fn!(post_service, POST);
952 chained_service_fn!(put_service, PUT);
953 chained_service_fn!(trace_service, TRACE);
954
955 #[doc = include_str!("../docs/method_routing/fallback.md")]
956 pub fn fallback_service<T>(mut self, svc: T) -> Self
957 where
958 T: Service<Request, Error = E> + Clone + Send + Sync + 'static,
959 T::Response: IntoResponse + 'static,
960 T::Future: Send + 'static,
961 {
962 self.fallback = Fallback::Service(Route::new(svc));
963 self
964 }
965
966 #[doc = include_str!("../docs/method_routing/layer.md")]
967 pub fn layer<L, NewError>(self, layer: L) -> MethodRouter<S, NewError>
968 where
969 L: Layer<Route<E>> + Clone + Send + Sync + 'static,
970 L::Service: Service<Request> + Clone + Send + Sync + 'static,
971 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
972 <L::Service as Service<Request>>::Error: Into<NewError> + 'static,
973 <L::Service as Service<Request>>::Future: Send + 'static,
974 E: 'static,
975 S: 'static,
976 NewError: 'static,
977 {
978 let layer_fn = move |route: Route<E>| route.layer(layer.clone());
979
980 MethodRouter {
981 get: self.get.map(layer_fn.clone()),
982 head: self.head.map(layer_fn.clone()),
983 delete: self.delete.map(layer_fn.clone()),
984 options: self.options.map(layer_fn.clone()),
985 patch: self.patch.map(layer_fn.clone()),
986 post: self.post.map(layer_fn.clone()),
987 put: self.put.map(layer_fn.clone()),
988 trace: self.trace.map(layer_fn.clone()),
989 connect: self.connect.map(layer_fn.clone()),
990 fallback: self.fallback.map(layer_fn),
991 allow_header: self.allow_header,
992 }
993 }
994
995 #[doc = include_str!("../docs/method_routing/route_layer.md")]
996 #[track_caller]
997 pub fn route_layer<L>(mut self, layer: L) -> MethodRouter<S, E>
998 where
999 L: Layer<Route<E>> + Clone + Send + Sync + 'static,
1000 L::Service: Service<Request, Error = E> + Clone + Send + Sync + 'static,
1001 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
1002 <L::Service as Service<Request>>::Future: Send + 'static,
1003 E: 'static,
1004 S: 'static,
1005 {
1006 if self.get.is_none()
1007 && self.head.is_none()
1008 && self.delete.is_none()
1009 && self.options.is_none()
1010 && self.patch.is_none()
1011 && self.post.is_none()
1012 && self.put.is_none()
1013 && self.trace.is_none()
1014 && self.connect.is_none()
1015 {
1016 panic!(
1017 "Adding a route_layer before any routes is a no-op. \
1018 Add the routes you want the layer to apply to first."
1019 );
1020 }
1021
1022 let layer_fn = move |svc| Route::new(layer.layer(svc));
1023
1024 self.get = self.get.map(layer_fn.clone());
1025 self.head = self.head.map(layer_fn.clone());
1026 self.delete = self.delete.map(layer_fn.clone());
1027 self.options = self.options.map(layer_fn.clone());
1028 self.patch = self.patch.map(layer_fn.clone());
1029 self.post = self.post.map(layer_fn.clone());
1030 self.put = self.put.map(layer_fn.clone());
1031 self.trace = self.trace.map(layer_fn.clone());
1032 self.connect = self.connect.map(layer_fn);
1033
1034 self
1035 }
1036
1037 pub(crate) fn merge_for_path(
1038 mut self,
1039 path: Option<&str>,
1040 other: MethodRouter<S, E>,
1041 ) -> Result<Self, Cow<'static, str>> {
1042 fn merge_inner<S, E>(
1044 path: Option<&str>,
1045 name: &str,
1046 first: MethodEndpoint<S, E>,
1047 second: MethodEndpoint<S, E>,
1048 ) -> Result<MethodEndpoint<S, E>, Cow<'static, str>> {
1049 match (first, second) {
1050 (MethodEndpoint::None, MethodEndpoint::None) => Ok(MethodEndpoint::None),
1051 (pick, MethodEndpoint::None) | (MethodEndpoint::None, pick) => Ok(pick),
1052 _ => {
1053 if let Some(path) = path {
1054 Err(format!(
1055 "Overlapping method route. Handler for `{name} {path}` already exists"
1056 )
1057 .into())
1058 } else {
1059 Err(format!(
1060 "Overlapping method route. Cannot merge two method routes that both \
1061 define `{name}`"
1062 )
1063 .into())
1064 }
1065 }
1066 }
1067 }
1068
1069 self.get = merge_inner(path, "GET", self.get, other.get)?;
1070 self.head = merge_inner(path, "HEAD", self.head, other.head)?;
1071 self.delete = merge_inner(path, "DELETE", self.delete, other.delete)?;
1072 self.options = merge_inner(path, "OPTIONS", self.options, other.options)?;
1073 self.patch = merge_inner(path, "PATCH", self.patch, other.patch)?;
1074 self.post = merge_inner(path, "POST", self.post, other.post)?;
1075 self.put = merge_inner(path, "PUT", self.put, other.put)?;
1076 self.trace = merge_inner(path, "TRACE", self.trace, other.trace)?;
1077 self.connect = merge_inner(path, "CONNECT", self.connect, other.connect)?;
1078
1079 self.fallback = self
1080 .fallback
1081 .merge(other.fallback)
1082 .ok_or("Cannot merge two `MethodRouter`s that both have a fallback")?;
1083
1084 self.allow_header = self.allow_header.merge(other.allow_header);
1085
1086 Ok(self)
1087 }
1088
1089 #[doc = include_str!("../docs/method_routing/merge.md")]
1090 #[track_caller]
1091 pub fn merge(self, other: MethodRouter<S, E>) -> Self {
1092 match self.merge_for_path(None, other) {
1093 Ok(t) => t,
1094 Err(e) => panic!("{e}"),
1096 }
1097 }
1098
1099 pub fn handle_error<F, T>(self, f: F) -> MethodRouter<S, Infallible>
1103 where
1104 F: Clone + Send + Sync + 'static,
1105 HandleError<Route<E>, F, T>: Service<Request, Error = Infallible>,
1106 <HandleError<Route<E>, F, T> as Service<Request>>::Future: Send,
1107 <HandleError<Route<E>, F, T> as Service<Request>>::Response: IntoResponse + Send,
1108 T: 'static,
1109 E: 'static,
1110 S: 'static,
1111 {
1112 self.layer(HandleErrorLayer::new(f))
1113 }
1114
1115 fn skip_allow_header(mut self) -> Self {
1116 self.allow_header = AllowHeader::Skip;
1117 self
1118 }
1119
1120 pub(crate) fn call_with_state(&self, req: Request, state: S) -> RouteFuture<E> {
1121 macro_rules! call {
1122 (
1123 $req:expr,
1124 $method_variant:ident,
1125 $svc:expr
1126 ) => {
1127 if *req.method() == Method::$method_variant {
1128 match $svc {
1129 MethodEndpoint::None => {}
1130 MethodEndpoint::Route(route) => {
1131 return route.clone().oneshot_inner_owned($req);
1132 }
1133 MethodEndpoint::BoxedHandler(handler) => {
1134 let route = handler.clone().into_route(state);
1135 return route.oneshot_inner_owned($req);
1136 }
1137 }
1138 }
1139 };
1140 }
1141
1142 let Self {
1144 get,
1145 head,
1146 delete,
1147 options,
1148 patch,
1149 post,
1150 put,
1151 trace,
1152 connect,
1153 fallback,
1154 allow_header,
1155 } = self;
1156
1157 call!(req, HEAD, head);
1158 call!(req, HEAD, get);
1159 call!(req, GET, get);
1160 call!(req, POST, post);
1161 call!(req, OPTIONS, options);
1162 call!(req, PATCH, patch);
1163 call!(req, PUT, put);
1164 call!(req, DELETE, delete);
1165 call!(req, TRACE, trace);
1166 call!(req, CONNECT, connect);
1167
1168 let future = fallback.clone().call_with_state(req, state);
1169
1170 match allow_header {
1171 AllowHeader::None => future.allow_header(Bytes::new()),
1172 AllowHeader::Skip => future,
1173 AllowHeader::Bytes(allow_header) => future.allow_header(allow_header.clone().freeze()),
1174 }
1175 }
1176}
1177
1178fn append_allow_header(allow_header: &mut AllowHeader, method: &'static str) {
1179 match allow_header {
1180 AllowHeader::None => {
1181 *allow_header = AllowHeader::Bytes(BytesMut::from(method));
1182 }
1183 AllowHeader::Skip => {}
1184 AllowHeader::Bytes(allow_header) => {
1185 if let Ok(s) = std::str::from_utf8(allow_header) {
1186 if !s.contains(method) {
1187 allow_header.extend_from_slice(b",");
1188 allow_header.extend_from_slice(method.as_bytes());
1189 }
1190 } else {
1191 #[cfg(debug_assertions)]
1192 panic!("`allow_header` contained invalid uft-8. This should never happen")
1193 }
1194 }
1195 }
1196}
1197
1198impl<S, E> Clone for MethodRouter<S, E> {
1199 fn clone(&self) -> Self {
1200 Self {
1201 get: self.get.clone(),
1202 head: self.head.clone(),
1203 delete: self.delete.clone(),
1204 options: self.options.clone(),
1205 patch: self.patch.clone(),
1206 post: self.post.clone(),
1207 put: self.put.clone(),
1208 trace: self.trace.clone(),
1209 connect: self.connect.clone(),
1210 fallback: self.fallback.clone(),
1211 allow_header: self.allow_header.clone(),
1212 }
1213 }
1214}
1215
1216impl<S, E> Default for MethodRouter<S, E>
1217where
1218 S: Clone,
1219{
1220 fn default() -> Self {
1221 Self::new()
1222 }
1223}
1224
/// What is registered for a single HTTP method on a [`MethodRouter`].
enum MethodEndpoint<S, E> {
    /// Nothing registered for this method.
    None,
    /// A ready-to-call route (a service, or a handler that already received its state).
    Route(Route<E>),
    /// A handler still waiting for its state `S`; converted via `into_route`.
    BoxedHandler(BoxedIntoRoute<S, E>),
}
1230
1231impl<S, E> MethodEndpoint<S, E>
1232where
1233 S: Clone,
1234{
1235 fn is_some(&self) -> bool {
1236 matches!(self, Self::Route(_) | Self::BoxedHandler(_))
1237 }
1238
1239 fn is_none(&self) -> bool {
1240 matches!(self, Self::None)
1241 }
1242
1243 fn map<F, E2>(self, f: F) -> MethodEndpoint<S, E2>
1244 where
1245 S: 'static,
1246 E: 'static,
1247 F: FnOnce(Route<E>) -> Route<E2> + Clone + Send + Sync + 'static,
1248 E2: 'static,
1249 {
1250 match self {
1251 Self::None => MethodEndpoint::None,
1252 Self::Route(route) => MethodEndpoint::Route(f(route)),
1253 Self::BoxedHandler(handler) => MethodEndpoint::BoxedHandler(handler.map(f)),
1254 }
1255 }
1256
1257 fn with_state<S2>(self, state: &S) -> MethodEndpoint<S2, E> {
1258 match self {
1259 MethodEndpoint::None => MethodEndpoint::None,
1260 MethodEndpoint::Route(route) => MethodEndpoint::Route(route),
1261 MethodEndpoint::BoxedHandler(handler) => {
1262 MethodEndpoint::Route(handler.into_route(state.clone()))
1263 }
1264 }
1265 }
1266}
1267
1268impl<S, E> Clone for MethodEndpoint<S, E> {
1269 fn clone(&self) -> Self {
1270 match self {
1271 Self::None => Self::None,
1272 Self::Route(inner) => Self::Route(inner.clone()),
1273 Self::BoxedHandler(inner) => Self::BoxedHandler(inner.clone()),
1274 }
1275 }
1276}
1277
1278impl<S, E> fmt::Debug for MethodEndpoint<S, E> {
1279 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1280 match self {
1281 Self::None => f.debug_tuple("None").finish(),
1282 Self::Route(inner) => inner.fmt(f),
1283 Self::BoxedHandler(_) => f.debug_tuple("BoxedHandler").finish(),
1284 }
1285 }
1286}
1287
1288impl<B, E> Service<Request<B>> for MethodRouter<(), E>
1289where
1290 B: HttpBody<Data = Bytes> + Send + 'static,
1291 B::Error: Into<BoxError>,
1292{
1293 type Response = Response;
1294 type Error = E;
1295 type Future = RouteFuture<E>;
1296
1297 #[inline]
1298 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1299 Poll::Ready(Ok(()))
1300 }
1301
1302 #[inline]
1303 fn call(&mut self, req: Request<B>) -> Self::Future {
1304 let req = req.map(Body::new);
1305 self.call_with_state(req, ())
1306 }
1307}
1308
1309impl<S> Handler<(), S> for MethodRouter<S>
1310where
1311 S: Clone + 'static,
1312{
1313 type Future = InfallibleRouteFuture;
1314
1315 fn call(self, req: Request, state: S) -> Self::Future {
1316 InfallibleRouteFuture::new(self.call_with_state(req, state))
1317 }
1318}
1319
// Lets a `MethodRouter<()>` be passed to `crate::serve` directly: every incoming connection
// receives its own clone of the router.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = {
    use crate::serve;

    impl<L> Service<serve::IncomingStream<'_, L>> for MethodRouter<()>
    where
        L: serve::Listener,
    {
        type Response = Self;
        type Error = Infallible;
        type Future = std::future::Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future {
            let router = self.clone().with_state::<()>(());
            std::future::ready(Ok(router))
        }
    }
};
1342
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{extract::State, handler::HandlerWithoutStateExt};
    use http::{header::ALLOW, HeaderMap};
    use http_body_util::BodyExt;
    use std::time::Duration;
    use tower::ServiceExt;
    use tower_http::{
        services::fs::ServeDir, timeout::TimeoutLayer, validate_request::ValidateRequestHeaderLayer,
    };

    #[crate::test]
    async fn method_not_allowed_by_default() {
        let mut svc = MethodRouter::new();
        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert!(body.is_empty());
    }

    #[crate::test]
    async fn get_service_fn() {
        async fn handle(_req: Request) -> Result<Response<Body>, Infallible> {
            Ok(Response::new(Body::from("ok")))
        }

        let mut svc = get_service(service_fn(handle));

        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "ok");
    }

    #[crate::test]
    async fn get_handler() {
        let mut svc = MethodRouter::new().get(ok);
        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "ok");
    }

    #[crate::test]
    async fn get_accepts_head() {
        let mut svc = MethodRouter::new().get(ok);
        let (status, _, body) = call(Method::HEAD, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert!(body.is_empty());
    }

    #[crate::test]
    async fn head_takes_precedence_over_get() {
        let mut svc = MethodRouter::new().head(created).get(ok);
        let (status, _, body) = call(Method::HEAD, &mut svc).await;
        assert_eq!(status, StatusCode::CREATED);
        assert!(body.is_empty());
    }

    #[crate::test]
    async fn merge() {
        let mut svc = get(ok).merge(post(ok));

        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, _, _) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[crate::test]
    async fn layer() {
        let mut svc = MethodRouter::new()
            .get(|| async { std::future::pending::<()>().await })
            .layer(ValidateRequestHeaderLayer::bearer("password"));

        // the middleware rejects the request before the pending handler is reached
        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);

        // `layer` also wraps the fallback, so unrouted methods are rejected too
        let (status, _, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[crate::test]
    async fn route_layer() {
        let mut svc = MethodRouter::new()
            .get(|| async { std::future::pending::<()>().await })
            .route_layer(ValidateRequestHeaderLayer::bearer("password"));

        // the middleware rejects the request before the pending handler is reached
        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);

        // `route_layer` leaves the fallback alone, so unrouted methods still get 405
        let (status, _, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
    }

    #[allow(dead_code)]
    async fn building_complex_router() {
        let app = crate::Router::new().route(
            "/",
            get(ok)
                .post(ok)
                .route_layer(ValidateRequestHeaderLayer::bearer("password"))
                .merge(delete_service(ServeDir::new(".")))
                .fallback(|| async { StatusCode::NOT_FOUND })
                .put(ok)
                .layer(TimeoutLayer::new(Duration::from_secs(10))),
        );

        let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
        crate::serve(listener, app).await.unwrap();
    }

    #[crate::test]
    async fn sets_allow_header() {
        let mut svc = MethodRouter::new().put(ok).patch(ok);
        let (status, headers, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "PUT,PATCH");
    }

    #[crate::test]
    async fn sets_allow_header_get_head() {
        let mut svc = MethodRouter::new().get(ok).head(ok);
        let (status, headers, _) = call(Method::PUT, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    #[crate::test]
    async fn empty_allow_header_by_default() {
        let mut svc = MethodRouter::new();
        let (status, headers, _) = call(Method::PATCH, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "");
    }

    #[crate::test]
    async fn allow_header_when_merging() {
        let a = put(ok).patch(ok);
        let b = get(ok).head(ok);
        let mut svc = a.merge(b);

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "PUT,PATCH,GET,HEAD");
    }

    #[crate::test]
    async fn allow_header_any() {
        let mut svc = any(ok);

        let (status, headers, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert!(!headers.contains_key(ALLOW));
    }

    #[crate::test]
    async fn allow_header_with_fallback() {
        let mut svc = MethodRouter::new()
            .get(ok)
            .fallback(|| async { (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed") });

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    #[crate::test]
    async fn allow_header_with_fallback_that_sets_allow() {
        async fn fallback(method: Method) -> Response {
            if method == Method::POST {
                "OK".into_response()
            } else {
                (
                    StatusCode::METHOD_NOT_ALLOWED,
                    [(ALLOW, "GET,POST")],
                    "Method not allowed",
                )
                    .into_response()
            }
        }

        let mut svc = MethodRouter::new().get(ok).fallback(fallback);

        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, _, _) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,POST");
    }

    #[crate::test]
    async fn allow_header_noop_middleware() {
        let mut svc = MethodRouter::new()
            .get(ok)
            .layer(tower::layer::util::Identity::new());

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    #[crate::test]
    #[should_panic(
        expected = "Overlapping method route. Cannot add two method routes that both handle `GET`"
    )]
    async fn handler_overlaps() {
        let _: MethodRouter<()> = get(ok).get(ok);
    }

    #[crate::test]
    #[should_panic(
        expected = "Overlapping method route. Cannot add two method routes that both handle `POST`"
    )]
    async fn service_overlaps() {
        let _: MethodRouter<()> = post_service(ok.into_service()).post_service(ok.into_service());
    }

    #[crate::test]
    async fn get_head_does_not_overlap() {
        let _: MethodRouter<()> = get(ok).head(ok);
    }

    #[crate::test]
    async fn head_get_does_not_overlap() {
        let _: MethodRouter<()> = head(ok).get(ok);
    }

    #[crate::test]
    async fn accessing_state() {
        let mut svc = MethodRouter::new()
            .get(|State(state): State<&'static str>| async move { state })
            .with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;

        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    #[crate::test]
    async fn fallback_accessing_state() {
        let mut svc = MethodRouter::new()
            .fallback(|State(state): State<&'static str>| async move { state })
            .with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;

        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    #[crate::test]
    async fn merge_accessing_state() {
        let one = get(|State(state): State<&'static str>| async move { state });
        let two = post(|State(state): State<&'static str>| async move { state });

        let mut svc = one.merge(two).with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");

        // check the POST body itself, not the stale GET body
        let (status, _, text) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    async fn call<S>(method: Method, svc: &mut S) -> (StatusCode, HeaderMap, String)
    where
        S: Service<Request, Error = Infallible>,
        S::Response: IntoResponse,
    {
        let request = Request::builder()
            .uri("/")
            .method(method)
            .body(Body::empty())
            .unwrap();
        let response = svc
            .ready()
            .await
            .unwrap()
            .call(request)
            .await
            .unwrap()
            .into_response();
        let (parts, body) = response.into_parts();
        let body =
            String::from_utf8(BodyExt::collect(body).await.unwrap().to_bytes().to_vec()).unwrap();
        (parts.status, parts.headers, body)
    }

    async fn ok() -> (StatusCode, &'static str) {
        (StatusCode::OK, "ok")
    }

    async fn created() -> (StatusCode, &'static str) {
        (StatusCode::CREATED, "created")
    }
}