tonic/transport/service/
grpc_timeout.rs

1use crate::{metadata::GRPC_TIMEOUT_HEADER, TimeoutExpired};
2use http::{HeaderMap, HeaderValue, Request};
3use pin_project::pin_project;
4use std::{
5    future::Future,
6    pin::Pin,
7    task::{ready, Context, Poll},
8    time::Duration,
9};
10use tokio::time::Sleep;
11use tower_service::Service;
12
/// Service wrapper that carries an optional server-side timeout alongside the
/// wrapped service.
#[derive(Clone, Debug)]
pub(crate) struct GrpcTimeout<S> {
    /// The wrapped service.
    inner: S,
    /// Timeout configured for the server, if any.
    server_timeout: Option<Duration>,
}

impl<S> GrpcTimeout<S> {
    /// Wraps `inner`, remembering `server_timeout` (if any) for later calls.
    pub(crate) fn new(inner: S, server_timeout: Option<Duration>) -> Self {
        GrpcTimeout {
            inner,
            server_timeout,
        }
    }
}
27
28impl<S, ReqBody> Service<Request<ReqBody>> for GrpcTimeout<S>
29where
30    S: Service<Request<ReqBody>>,
31    S::Error: Into<crate::BoxError>,
32{
33    type Response = S::Response;
34    type Error = crate::BoxError;
35    type Future = ResponseFuture<S::Future>;
36
37    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
38        self.inner.poll_ready(cx).map_err(Into::into)
39    }
40
41    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
42        let client_timeout = try_parse_grpc_timeout(req.headers()).unwrap_or_else(|e| {
43            tracing::trace!("Error parsing `grpc-timeout` header {:?}", e);
44            None
45        });
46
47        // Use the shorter of the two durations, if either are set
48        let timeout_duration = match (client_timeout, self.server_timeout) {
49            (None, None) => None,
50            (Some(dur), None) => Some(dur),
51            (None, Some(dur)) => Some(dur),
52            (Some(header), Some(server)) => {
53                let shorter_duration = std::cmp::min(header, server);
54                Some(shorter_duration)
55            }
56        };
57
58        ResponseFuture {
59            inner: self.inner.call(req),
60            sleep: timeout_duration.map(tokio::time::sleep),
61        }
62    }
63}
64
65#[pin_project]
66pub(crate) struct ResponseFuture<F> {
67    #[pin]
68    inner: F,
69    #[pin]
70    sleep: Option<Sleep>,
71}
72
73impl<F, Res, E> Future for ResponseFuture<F>
74where
75    F: Future<Output = Result<Res, E>>,
76    E: Into<crate::BoxError>,
77{
78    type Output = Result<Res, crate::BoxError>;
79
80    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
81        let this = self.project();
82
83        if let ready @ Poll::Ready(_) = this.inner.poll(cx) {
84            return ready.map_err(Into::into);
85        }
86
87        if let Some(sleep) = this.sleep.as_pin_mut() {
88            ready!(sleep.poll(cx));
89            return Poll::Ready(Err(TimeoutExpired(()).into()));
90        }
91
92        Poll::Pending
93    }
94}
95
/// Number of seconds in one minute.
const SECONDS_IN_MINUTE: u64 = 60;
/// Number of seconds in one hour.
const SECONDS_IN_HOUR: u64 = 60 * SECONDS_IN_MINUTE;
98
99/// Tries to parse the `grpc-timeout` header if it is present. If we fail to parse, returns
100/// the value we attempted to parse.
101///
102/// Follows the [gRPC over HTTP2 spec](https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md).
103fn try_parse_grpc_timeout(
104    headers: &HeaderMap<HeaderValue>,
105) -> Result<Option<Duration>, &HeaderValue> {
106    let Some(val) = headers.get(GRPC_TIMEOUT_HEADER) else {
107        return Ok(None);
108    };
109
110    let (timeout_value, timeout_unit) = val
111        .to_str()
112        .map_err(|_| val)
113        .and_then(|s| if s.is_empty() { Err(val) } else { Ok(s) })?
114        // `HeaderValue::to_str` only returns `Ok` if the header contains ASCII so this
115        // `split_at` will never panic from trying to split in the middle of a character.
116        // See https://docs.rs/http/1/http/header/struct.HeaderValue.html#method.to_str
117        //
118        // `len - 1` also wont panic since we just checked `s.is_empty`.
119        .split_at(val.len() - 1);
120
121    // gRPC spec specifies `TimeoutValue` will be at most 8 digits
122    // Caping this at 8 digits also prevents integer overflow from ever occurring
123    if timeout_value.len() > 8 {
124        return Err(val);
125    }
126
127    let timeout_value: u64 = timeout_value.parse().map_err(|_| val)?;
128
129    let duration = match timeout_unit {
130        // Hours
131        "H" => Duration::from_secs(timeout_value * SECONDS_IN_HOUR),
132        // Minutes
133        "M" => Duration::from_secs(timeout_value * SECONDS_IN_MINUTE),
134        // Seconds
135        "S" => Duration::from_secs(timeout_value),
136        // Milliseconds
137        "m" => Duration::from_millis(timeout_value),
138        // Microseconds
139        "u" => Duration::from_micros(timeout_value),
140        // Nanoseconds
141        "n" => Duration::from_nanos(timeout_value),
142        _ => return Err(val),
143    };
144
145    Ok(Some(duration))
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::{Arbitrary, Gen};
    use quickcheck_macros::quickcheck;

    /// Builds a header map that optionally contains a `grpc-timeout` header set
    /// to `val`, runs the parser on it, and returns an owned copy of the result.
    fn setup_map_try_parse(val: Option<&str>) -> Result<Option<Duration>, HeaderValue> {
        let mut headers = HeaderMap::new();
        if let Some(raw) = val {
            headers.insert(GRPC_TIMEOUT_HEADER, HeaderValue::from_str(raw).unwrap());
        }

        try_parse_grpc_timeout(&headers).map_err(HeaderValue::clone)
    }

    #[test]
    fn test_hours() {
        let parsed = setup_map_try_parse(Some("3H")).unwrap().unwrap();
        assert_eq!(parsed, Duration::from_secs(3 * 60 * 60));
    }

    #[test]
    fn test_minutes() {
        let parsed = setup_map_try_parse(Some("1M")).unwrap().unwrap();
        assert_eq!(parsed, Duration::from_secs(60));
    }

    #[test]
    fn test_seconds() {
        let parsed = setup_map_try_parse(Some("42S")).unwrap().unwrap();
        assert_eq!(parsed, Duration::from_secs(42));
    }

    #[test]
    fn test_milliseconds() {
        let parsed = setup_map_try_parse(Some("13m")).unwrap().unwrap();
        assert_eq!(parsed, Duration::from_millis(13));
    }

    #[test]
    fn test_microseconds() {
        let parsed = setup_map_try_parse(Some("2u")).unwrap().unwrap();
        assert_eq!(parsed, Duration::from_micros(2));
    }

    #[test]
    fn test_nanoseconds() {
        let parsed = setup_map_try_parse(Some("82n")).unwrap().unwrap();
        assert_eq!(parsed, Duration::from_nanos(82));
    }

    #[test]
    fn test_header_not_present() {
        let parsed = setup_map_try_parse(None).unwrap();
        assert!(parsed.is_none());
    }

    #[test]
    #[should_panic(expected = "82f")]
    fn test_invalid_unit() {
        // "f" is not a valid TimeoutUnit
        setup_map_try_parse(Some("82f")).unwrap().unwrap();
    }

    #[test]
    #[should_panic(expected = "123456789H")]
    fn test_too_many_digits() {
        // gRPC spec states TimeoutValue will be at most 8 digits
        setup_map_try_parse(Some("123456789H")).unwrap().unwrap();
    }

    #[test]
    #[should_panic(expected = "oneH")]
    fn test_invalid_digits() {
        // "one" is not a numeric TimeoutValue
        setup_map_try_parse(Some("oneH")).unwrap().unwrap();
    }

    #[quickcheck]
    fn fuzz(header_value: HeaderValueGen) -> bool {
        let HeaderValueGen(raw) = header_value;

        // The result is irrelevant; parsing just must not panic.
        let _ = setup_map_try_parse(Some(&raw));

        true
    }

    /// Newtype to implement `Arbitrary` for generating `String`s that are valid `HeaderValue`s.
    #[derive(Clone, Debug)]
    struct HeaderValueGen(String);

    impl Arbitrary for HeaderValueGen {
        fn arbitrary(g: &mut Gen) -> Self {
            // Pick the upper bound of the generated length from 1..70.
            let candidates: Vec<usize> = (1..70).collect();
            let max = *g.choose(&candidates).unwrap();
            Self(gen_string(g, 0, max))
        }
    }

    // copied from https://github.com/hyperium/http/blob/master/tests/header_map_fuzz.rs
    /// Produces a string of `max - min` characters drawn from a fixed alphabet.
    fn gen_string(g: &mut Gen, min: usize, max: usize) -> String {
        // Chars to pick from
        const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVabcdefghilpqrstuvwxyz----";

        let bytes: Vec<u8> = (min..max)
            .map(|_| *g.choose(ALPHABET).unwrap())
            .collect();

        String::from_utf8(bytes).unwrap()
    }
}