axum/extract/
state.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
use async_trait::async_trait;
use axum_core::extract::{FromRef, FromRequestParts};
use http::request::Parts;
use std::{
    convert::Infallible,
    ops::{Deref, DerefMut},
};

/// Extractor for state.
///
/// See ["Accessing state in middleware"][state-from-middleware] for how to
/// access state in middleware.
///
/// State is global and used in every request a router with state receives.
/// For accessing data derived from requests, such as authorization data, see [`Extension`].
///
/// [state-from-middleware]: crate::middleware#accessing-state-in-middleware
/// [`Extension`]: crate::Extension
///
/// # With `Router`
///
/// ```
/// use axum::{Router, routing::get, extract::State};
///
/// // the application state
/// //
/// // here you can put configuration, database connection pools, or whatever
/// // state you need
/// //
/// // see "When states need to implement `Clone`" for more details on why we need
/// // `#[derive(Clone)]` here.
/// #[derive(Clone)]
/// struct AppState {}
///
/// let state = AppState {};
///
/// // create a `Router` that holds our state
/// let app = Router::new()
///     .route("/", get(handler))
///     // provide the state so the router can access it
///     .with_state(state);
///
/// async fn handler(
///     // access the state via the `State` extractor
///     // extracting a state of the wrong type results in a compile error
///     State(state): State<AppState>,
/// ) {
///     // use `state`...
/// }
/// # let _: axum::Router = app;
/// ```
///
/// Note that `State` is an extractor, so be sure to put it before any body
/// extractors, see ["the order of extractors"][order-of-extractors].
///
/// [order-of-extractors]: crate::extract#the-order-of-extractors
///
/// ## Combining stateful routers
///
/// Multiple [`Router`]s can be combined with [`Router::nest`] or [`Router::merge`].
/// When combining [`Router`]s with one of these methods, the [`Router`]s must have
/// the same state type. Generally, this can be inferred automatically:
///
/// ```
/// use axum::{Router, routing::get, extract::State};
///
/// #[derive(Clone)]
/// struct AppState {}
///
/// let state = AppState {};
///
/// // create a `Router` that will be nested within another
/// let api = Router::new()
///     .route("/posts", get(posts_handler));
///
/// let app = Router::new()
///     .nest("/api", api)
///     .with_state(state);
///
/// async fn posts_handler(State(state): State<AppState>) {
///     // use `state`...
/// }
/// # let _: axum::Router = app;
/// ```
///
/// However, if you are composing [`Router`]s that are defined in separate scopes,
/// you may need to annotate the [`State`] type explicitly:
///
/// ```
/// use axum::{Router, routing::get, extract::State};
///
/// #[derive(Clone)]
/// struct AppState {}
///
/// fn make_app() -> Router {
///     let state = AppState {};
///
///     Router::new()
///         .nest("/api", make_api())
///         .with_state(state) // the outer Router's state is inferred
/// }
///
/// // the inner Router must specify its state type to compose with the
/// // outer router
/// fn make_api() -> Router<AppState> {
///     Router::new()
///         .route("/posts", get(posts_handler))
/// }
///
/// async fn posts_handler(State(state): State<AppState>) {
///     // use `state`...
/// }
/// # let _: axum::Router = make_app();
/// ```
///
/// In short, a [`Router`]'s generic state type defaults to `()`
/// (no state) unless [`Router::with_state`] is called or the value
/// of the generic type is given explicitly.
///
/// [`Router`]: crate::Router
/// [`Router::merge`]: crate::Router::merge
/// [`Router::nest`]: crate::Router::nest
/// [`Router::with_state`]: crate::Router::with_state
///
/// # With `MethodRouter`
///
/// ```
/// use axum::{routing::get, extract::State};
///
/// #[derive(Clone)]
/// struct AppState {}
///
/// let state = AppState {};
///
/// let method_router_with_state = get(handler)
///     // provide the state so the handler can access it
///     .with_state(state);
/// # let _: axum::routing::MethodRouter = method_router_with_state;
///
/// async fn handler(State(state): State<AppState>) {
///     // use `state`...
/// }
/// ```
///
/// # With `Handler`
///
/// ```
/// use axum::{routing::get, handler::Handler, extract::State};
///
/// #[derive(Clone)]
/// struct AppState {}
///
/// let state = AppState {};
///
/// async fn handler(State(state): State<AppState>) {
///     // use `state`...
/// }
///
/// // provide the state so the handler can access it
/// let handler_with_state = handler.with_state(state);
///
/// # async {
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// axum::serve(listener, handler_with_state.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// # Substates
///
/// [`State`] only allows a single state type but you can use [`FromRef`] to extract "substates":
///
/// ```
/// use axum::{Router, routing::get, extract::{State, FromRef}};
///
/// // the application state
/// #[derive(Clone)]
/// struct AppState {
///     // that holds some api specific state
///     api_state: ApiState,
/// }
///
/// // the api specific state
/// #[derive(Clone)]
/// struct ApiState {}
///
/// // support converting an `AppState` into an `ApiState`
/// impl FromRef<AppState> for ApiState {
///     fn from_ref(app_state: &AppState) -> ApiState {
///         app_state.api_state.clone()
///     }
/// }
///
/// let state = AppState {
///     api_state: ApiState {},
/// };
///
/// let app = Router::new()
///     .route("/", get(handler))
///     .route("/api/users", get(api_users))
///     .with_state(state);
///
/// async fn api_users(
///     // access the api specific state
///     State(api_state): State<ApiState>,
/// ) {
/// }
///
/// async fn handler(
///     // we can still access the top level state
///     State(state): State<AppState>,
/// ) {
/// }
/// # let _: axum::Router = app;
/// ```
///
/// For convenience `FromRef` can also be derived using `#[derive(FromRef)]`.
///
/// # For library authors
///
/// If you're writing a library that has an extractor that needs state, this is the recommended way
/// to do it:
///
/// ```rust
/// use axum_core::extract::{FromRequestParts, FromRef};
/// use http::request::Parts;
/// use async_trait::async_trait;
/// use std::convert::Infallible;
///
/// // the extractor your library provides
/// struct MyLibraryExtractor;
///
/// #[async_trait]
/// impl<S> FromRequestParts<S> for MyLibraryExtractor
/// where
///     // keep `S` generic but require that it can produce a `MyLibraryState`
///     // this means users will have to implement `FromRef<UserState> for MyLibraryState`
///     MyLibraryState: FromRef<S>,
///     S: Send + Sync,
/// {
///     type Rejection = Infallible;
///
///     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
///         // get a `MyLibraryState` from a reference to the state
///         let state = MyLibraryState::from_ref(state);
///
///         // ...
///         # todo!()
///     }
/// }
///
/// // the state your library needs
/// struct MyLibraryState {
///     // ...
/// }
/// ```
///
/// # When states need to implement `Clone`
///
/// Your top level state type must implement `Clone` to be extractable with `State`:
///
/// ```
/// use axum::extract::State;
///
/// // no substates, so to extract `State<AppState>` we must implement `Clone` for `AppState`
/// #[derive(Clone)]
/// struct AppState {}
///
/// async fn handler(State(state): State<AppState>) {
///     // ...
/// }
/// ```
///
/// This works because of [`impl<S> FromRef<S> for S where S: Clone`][`FromRef`].
///
/// This is also true if you're extracting substates, unless you _never_ extract the top level
/// state itself:
///
/// ```
/// use axum::extract::{State, FromRef};
///
/// // we never extract `State<AppState>`, just `State<InnerState>`. So `AppState` doesn't need to
/// // implement `Clone`
/// struct AppState {
///     inner: InnerState,
/// }
///
/// #[derive(Clone)]
/// struct InnerState {}
///
/// impl FromRef<AppState> for InnerState {
///     fn from_ref(app_state: &AppState) -> InnerState {
///         app_state.inner.clone()
///     }
/// }
///
/// async fn api_users(State(inner): State<InnerState>) {
///     // ...
/// }
/// ```
///
/// In general, however, we recommend that you implement `Clone` for all your state types to avoid
/// potential type errors.
///
/// # Shared mutable state
///
/// [As state is global within a `Router`][global] you can't directly get a mutable reference to
/// the state.
///
/// The most basic solution is to use an `Arc<Mutex<_>>`. Which kind of mutex you need depends on
/// your use case. See [the tokio docs] for more details.
///
/// Note that holding a locked `std::sync::Mutex` across `.await` points will result in `!Send`
/// futures which are incompatible with axum. If you need to hold a mutex across `.await` points,
/// consider using a `tokio::sync::Mutex` instead.
///
/// ## Example
///
/// ```
/// use axum::{Router, routing::get, extract::State};
/// use std::sync::{Arc, Mutex};
///
/// #[derive(Clone)]
/// struct AppState {
///     data: Arc<Mutex<String>>,
/// }
///
/// async fn handler(State(state): State<AppState>) {
///     {
///         let mut data = state.data.lock().expect("mutex was poisoned");
///         *data = "updated foo".to_owned();
///     }
///
///     // ...
/// }
///
/// let state = AppState {
///     data: Arc::new(Mutex::new("foo".to_owned())),
/// };
///
/// let app = Router::new()
///     .route("/", get(handler))
///     .with_state(state);
/// # let _: Router = app;
/// ```
///
/// [global]: crate::Router::with_state
/// [the tokio docs]: https://docs.rs/tokio/1.25.0/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use
#[derive(Clone, Copy, Debug, Default)]
pub struct State<S>(
    /// The state value handed to the handler.
    pub S,
);

#[async_trait]
impl<OuterState, InnerState> FromRequestParts<OuterState> for State<InnerState>
where
    InnerState: FromRef<OuterState>,
    OuterState: Send + Sync,
{
    type Rejection = Infallible;

    async fn from_request_parts(
        _parts: &mut Parts,
        state: &OuterState,
    ) -> Result<Self, Self::Rejection> {
        let inner_state = InnerState::from_ref(state);
        Ok(Self(inner_state))
    }
}

impl<S> Deref for State<S> {
    type Target = S;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<S> DerefMut for State<S> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}