tonic/transport/service/
grpc_timeout.rs1use crate::{metadata::GRPC_TIMEOUT_HEADER, TimeoutExpired};
2use http::{HeaderMap, HeaderValue, Request};
3use pin_project::pin_project;
4use std::{
5 future::Future,
6 pin::Pin,
7 task::{ready, Context, Poll},
8 time::Duration,
9};
10use tokio::time::Sleep;
11use tower_service::Service;
12
/// Service wrapper that enforces a deadline on each request.
///
/// The deadline is derived from the request's `grpc-timeout` header and/or a
/// server-configured timeout; when both are present the shorter one wins.
#[derive(Clone, Debug)]
pub(crate) struct GrpcTimeout<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
    /// Server-side upper bound on a call's duration; `None` means no server cap.
    server_timeout: Option<Duration>,
}

impl<S> GrpcTimeout<S> {
    /// Wraps `inner`, optionally capping every call at `server_timeout`.
    pub(crate) fn new(inner: S, server_timeout: Option<Duration>) -> Self {
        GrpcTimeout {
            server_timeout,
            inner,
        }
    }
}
27
28impl<S, ReqBody> Service<Request<ReqBody>> for GrpcTimeout<S>
29where
30 S: Service<Request<ReqBody>>,
31 S::Error: Into<crate::BoxError>,
32{
33 type Response = S::Response;
34 type Error = crate::BoxError;
35 type Future = ResponseFuture<S::Future>;
36
37 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
38 self.inner.poll_ready(cx).map_err(Into::into)
39 }
40
41 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
42 let client_timeout = try_parse_grpc_timeout(req.headers()).unwrap_or_else(|e| {
43 tracing::trace!("Error parsing `grpc-timeout` header {:?}", e);
44 None
45 });
46
47 let timeout_duration = match (client_timeout, self.server_timeout) {
49 (None, None) => None,
50 (Some(dur), None) => Some(dur),
51 (None, Some(dur)) => Some(dur),
52 (Some(header), Some(server)) => {
53 let shorter_duration = std::cmp::min(header, server);
54 Some(shorter_duration)
55 }
56 };
57
58 ResponseFuture {
59 inner: self.inner.call(req),
60 sleep: timeout_duration.map(tokio::time::sleep),
61 }
62 }
63}
64
65#[pin_project]
66pub(crate) struct ResponseFuture<F> {
67 #[pin]
68 inner: F,
69 #[pin]
70 sleep: Option<Sleep>,
71}
72
73impl<F, Res, E> Future for ResponseFuture<F>
74where
75 F: Future<Output = Result<Res, E>>,
76 E: Into<crate::BoxError>,
77{
78 type Output = Result<Res, crate::BoxError>;
79
80 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
81 let this = self.project();
82
83 if let ready @ Poll::Ready(_) = this.inner.poll(cx) {
84 return ready.map_err(Into::into);
85 }
86
87 if let Some(sleep) = this.sleep.as_pin_mut() {
88 ready!(sleep.poll(cx));
89 return Poll::Ready(Err(TimeoutExpired(()).into()));
90 }
91
92 Poll::Pending
93 }
94}
95
/// Seconds per hour; scales the `H` unit of `grpc-timeout`.
const SECONDS_IN_HOUR: u64 = 60 * SECONDS_IN_MINUTE;
/// Seconds per minute; scales the `M` unit of `grpc-timeout`.
const SECONDS_IN_MINUTE: u64 = 60;
98
99fn try_parse_grpc_timeout(
104 headers: &HeaderMap<HeaderValue>,
105) -> Result<Option<Duration>, &HeaderValue> {
106 let Some(val) = headers.get(GRPC_TIMEOUT_HEADER) else {
107 return Ok(None);
108 };
109
110 let (timeout_value, timeout_unit) = val
111 .to_str()
112 .map_err(|_| val)
113 .and_then(|s| if s.is_empty() { Err(val) } else { Ok(s) })?
114 .split_at(val.len() - 1);
120
121 if timeout_value.len() > 8 {
124 return Err(val);
125 }
126
127 let timeout_value: u64 = timeout_value.parse().map_err(|_| val)?;
128
129 let duration = match timeout_unit {
130 "H" => Duration::from_secs(timeout_value * SECONDS_IN_HOUR),
132 "M" => Duration::from_secs(timeout_value * SECONDS_IN_MINUTE),
134 "S" => Duration::from_secs(timeout_value),
136 "m" => Duration::from_millis(timeout_value),
138 "u" => Duration::from_micros(timeout_value),
140 "n" => Duration::from_nanos(timeout_value),
142 _ => return Err(val),
143 };
144
145 Ok(Some(duration))
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::{Arbitrary, Gen};
    use quickcheck_macros::quickcheck;

    /// Builds a header map that contains `grpc-timeout: val` when `val` is `Some` and parses
    /// it. The borrowed error value is cloned so the result can outlive the map.
    fn setup_map_try_parse(val: Option<&str>) -> Result<Option<Duration>, HeaderValue> {
        let mut hm = HeaderMap::new();
        if let Some(v) = val {
            let hv = HeaderValue::from_str(v).unwrap();
            hm.insert(GRPC_TIMEOUT_HEADER, hv);
        };

        try_parse_grpc_timeout(&hm).map_err(|e| e.clone())
    }

    #[test]
    fn test_hours() {
        let parsed_duration = setup_map_try_parse(Some("3H")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(3 * 60 * 60), parsed_duration);
    }

    #[test]
    fn test_minutes() {
        let parsed_duration = setup_map_try_parse(Some("1M")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(60), parsed_duration);
    }

    #[test]
    fn test_seconds() {
        let parsed_duration = setup_map_try_parse(Some("42S")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(42), parsed_duration);
    }

    #[test]
    fn test_milliseconds() {
        let parsed_duration = setup_map_try_parse(Some("13m")).unwrap().unwrap();
        assert_eq!(Duration::from_millis(13), parsed_duration);
    }

    #[test]
    fn test_microseconds() {
        let parsed_duration = setup_map_try_parse(Some("2u")).unwrap().unwrap();
        assert_eq!(Duration::from_micros(2), parsed_duration);
    }

    #[test]
    fn test_nanoseconds() {
        let parsed_duration = setup_map_try_parse(Some("82n")).unwrap().unwrap();
        assert_eq!(Duration::from_nanos(82), parsed_duration);
    }

    /// The largest value the spec allows (8 digits) must parse without overflow.
    #[test]
    fn test_max_digits() {
        let parsed_duration = setup_map_try_parse(Some("99999999H")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(99_999_999 * 60 * 60), parsed_duration);
    }

    #[test]
    fn test_header_not_present() {
        let parsed_duration = setup_map_try_parse(None).unwrap();
        assert!(parsed_duration.is_none());
    }

    #[test]
    fn test_invalid_unit() {
        assert_eq!(
            setup_map_try_parse(Some("82f")),
            Err(HeaderValue::from_static("82f"))
        );
    }

    #[test]
    fn test_too_many_digits() {
        assert_eq!(
            setup_map_try_parse(Some("123456789H")),
            Err(HeaderValue::from_static("123456789H"))
        );
    }

    #[test]
    fn test_invalid_digits() {
        assert_eq!(
            setup_map_try_parse(Some("oneH")),
            Err(HeaderValue::from_static("oneH"))
        );
    }

    #[test]
    fn test_empty_value() {
        assert_eq!(setup_map_try_parse(Some("")), Err(HeaderValue::from_static("")));
    }

    #[test]
    fn test_unit_without_digits() {
        assert_eq!(
            setup_map_try_parse(Some("S")),
            Err(HeaderValue::from_static("S"))
        );
    }

    /// Parsing arbitrary header text must never panic.
    #[quickcheck]
    fn fuzz(header_value: HeaderValueGen) -> bool {
        let header_value = header_value.0;

        let _ = setup_map_try_parse(Some(&header_value));

        true
    }

    /// Random header text, biased toward characters that can appear in a `grpc-timeout`.
    #[derive(Clone, Debug)]
    struct HeaderValueGen(String);

    impl Arbitrary for HeaderValueGen {
        fn arbitrary(g: &mut Gen) -> Self {
            let max = g.choose(&(1..70).collect::<Vec<_>>()).copied().unwrap();
            Self(gen_string(g, 0, max))
        }
    }

    /// Generates a string of `max - min` characters. The alphabet includes digits, every
    /// valid unit letter (`H M S m u n`) and `+` so the fuzz reaches the numeric-parse and
    /// unit-matching paths, not just the early rejections.
    fn gen_string(g: &mut Gen, min: usize, max: usize) -> String {
        let bytes: Vec<_> = (min..max)
            .map(|_| {
                g.choose(b"0123456789ABCDEFGHIJKLMNOPQRSTUVabcdefghijklmnopqrstuvwxyz+----")
                    .copied()
                    .unwrap()
            })
            .collect();

        String::from_utf8(bytes).unwrap()
    }
}