// axum/extension.rs

1use crate::{extract::rejection::*, response::IntoResponseParts};
2use axum_core::extract::OptionalFromRequestParts;
3use axum_core::{
4    extract::FromRequestParts,
5    response::{IntoResponse, Response, ResponseParts},
6};
7use http::{request::Parts, Extensions, Request};
8use std::{
9    convert::Infallible,
10    task::{Context, Poll},
11};
12use tower_service::Service;
13
/// Extractor and response for extensions.
///
/// # As extractor
///
/// This is commonly used to share state across handlers.
///
/// ```rust,no_run
/// use axum::{
///     Router,
///     Extension,
///     routing::get,
/// };
/// use std::sync::Arc;
///
/// // Some shared state used throughout our application
/// struct State {
///     // ...
/// }
///
/// async fn handler(state: Extension<Arc<State>>) {
///     // ...
/// }
///
/// let state = Arc::new(State { /* ... */ });
///
/// let app = Router::new().route("/", get(handler))
///     // Add middleware that inserts the state into the extensions of every
///     // incoming request.
///     .layer(Extension(state));
/// # let _: Router = app;
/// ```
///
/// If the extension is missing, the request is rejected with a `500 Internal
/// Server Error` response. Alternatively, you can use `Option<Extension<T>>` to
/// make the extension extractor optional: it yields `None` instead of rejecting.
///
/// # As response
///
/// Response extensions can be used to share state with middleware.
///
/// ```rust
/// use axum::Extension;
///
/// async fn handler() -> (Extension<Foo>, &'static str) {
///     (
///         Extension(Foo("foo")),
///         "Hello, World!"
///     )
/// }
///
/// #[derive(Clone)]
/// struct Foo(&'static str);
/// ```
///
/// # As middleware
///
/// `Extension<T>` also implements `tower_layer::Layer`. Calling
/// `.layer(Extension(value))` wraps a service in [`AddExtension`], which inserts
/// a clone of `value` into every request it handles.
#[derive(Debug, Clone, Copy, Default)]
#[must_use]
pub struct Extension<T>(pub T);
73
74impl<T> Extension<T>
75where
76    T: Clone + Send + Sync + 'static,
77{
78    fn from_extensions(extensions: &Extensions) -> Option<Self> {
79        extensions.get::<T>().cloned().map(Extension)
80    }
81}
82
83impl<T, S> FromRequestParts<S> for Extension<T>
84where
85    T: Clone + Send + Sync + 'static,
86    S: Send + Sync,
87{
88    type Rejection = ExtensionRejection;
89
90    async fn from_request_parts(req: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
91        Ok(Self::from_extensions(&req.extensions).ok_or_else(|| {
92            MissingExtension::from_err(format!(
93                "Extension of type `{}` was not found. Perhaps you forgot to add it? See `axum::Extension`.",
94                std::any::type_name::<T>()
95            ))
96        })?)
97    }
98}
99
100impl<T, S> OptionalFromRequestParts<S> for Extension<T>
101where
102    T: Clone + Send + Sync + 'static,
103    S: Send + Sync,
104{
105    type Rejection = Infallible;
106
107    async fn from_request_parts(
108        req: &mut Parts,
109        _state: &S,
110    ) -> Result<Option<Self>, Self::Rejection> {
111        Ok(Self::from_extensions(&req.extensions))
112    }
113}
114
115axum_core::__impl_deref!(Extension);
116
117impl<T> IntoResponseParts for Extension<T>
118where
119    T: Clone + Send + Sync + 'static,
120{
121    type Error = Infallible;
122
123    fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
124        res.extensions_mut().insert(self.0);
125        Ok(res)
126    }
127}
128
129impl<T> IntoResponse for Extension<T>
130where
131    T: Clone + Send + Sync + 'static,
132{
133    fn into_response(self) -> Response {
134        let mut res = ().into_response();
135        res.extensions_mut().insert(self.0);
136        res
137    }
138}
139
140impl<S, T> tower_layer::Layer<S> for Extension<T>
141where
142    T: Clone + Send + Sync + 'static,
143{
144    type Service = AddExtension<S, T>;
145
146    fn layer(&self, inner: S) -> Self::Service {
147        AddExtension {
148            inner,
149            value: self.0.clone(),
150        }
151    }
152}
153
/// Middleware that adds a shareable value to the [request extensions] of each
/// request before handing it to the wrapped service.
///
/// Its fields are crate-private, so code outside this crate obtains one through
/// the [Layer](tower::Layer) implementation of [Extension]. Use that
/// implementation whenever you need a layer that adds an extension to every
/// request.
///
/// See [Passing state from middleware to handlers](index.html#passing-state-from-middleware-to-handlers)
/// for more details.
///
/// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html
#[derive(Clone, Copy, Debug)]
pub struct AddExtension<S, T> {
    // The wrapped service that receives every request.
    pub(crate) inner: S,
    // The value cloned into each request's extensions.
    pub(crate) value: T,
}
168
169impl<ResBody, S, T> Service<Request<ResBody>> for AddExtension<S, T>
170where
171    S: Service<Request<ResBody>>,
172    T: Clone + Send + Sync + 'static,
173{
174    type Response = S::Response;
175    type Error = S::Error;
176    type Future = S::Future;
177
178    #[inline]
179    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
180        self.inner.poll_ready(cx)
181    }
182
183    fn call(&mut self, mut req: Request<ResBody>) -> Self::Future {
184        req.extensions_mut().insert(self.value.clone());
185        self.inner.call(req)
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use crate::routing::get;
    use crate::test_helpers::TestClient;
    use crate::Router;
    use http::StatusCode;

    #[derive(Clone)]
    struct Foo(String);

    #[derive(Clone)]
    struct Bar(String);

    #[crate::test]
    async fn extension_extractor() {
        async fn requires_foo(Extension(foo): Extension<Foo>) -> String {
            foo.0
        }

        async fn optional_foo(extension: Option<Extension<Foo>>) -> String {
            // `map_or_else` only allocates the fallback when the extension is absent.
            extension.map_or_else(|| "none".to_owned(), |Extension(foo)| foo.0)
        }

        async fn requires_bar(Extension(bar): Extension<Bar>) -> String {
            bar.0
        }

        async fn optional_bar(extension: Option<Extension<Bar>>) -> String {
            extension.map_or_else(|| "none".to_owned(), |Extension(bar)| bar.0)
        }

        let app = Router::new()
            .route("/requires_foo", get(requires_foo))
            .route("/optional_foo", get(optional_foo))
            .route("/requires_bar", get(requires_bar))
            .route("/optional_bar", get(optional_bar))
            .layer(Extension(Foo("foo".to_owned())));

        let client = TestClient::new(app);

        let response = client.get("/requires_foo").await;
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.text().await, "foo");

        let response = client.get("/optional_foo").await;
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.text().await, "foo");

        let response = client.get("/requires_bar").await;
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(response.text().await, "Missing request extension: Extension of type `axum::extension::tests::Bar` was not found. Perhaps you forgot to add it? See `axum::Extension`.");

        let response = client.get("/optional_bar").await;
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.text().await, "none");
    }

    // `Extension` used directly as a response stores its value in the response
    // extensions.
    #[test]
    fn extension_into_response_inserts_value() {
        let response = Extension(Foo("foo".to_owned())).into_response();
        assert_eq!(response.status(), StatusCode::OK);
        let stored = response.extensions().get::<Foo>().map(|foo| foo.0.as_str());
        assert_eq!(stored, Some("foo"));
    }
}