tonic/transport/channel/service/
user_agent.rs

1use http::{header::USER_AGENT, HeaderValue, Request};
2use std::task::{Context, Poll};
3use tower_service::Service;
4
5const TONIC_USER_AGENT: &str = concat!("tonic/", env!("CARGO_PKG_VERSION"));
6
7#[derive(Debug)]
8pub(crate) struct UserAgent<T> {
9    inner: T,
10    user_agent: HeaderValue,
11}
12
13impl<T> UserAgent<T> {
14    pub(crate) fn new(inner: T, user_agent: Option<HeaderValue>) -> Self {
15        let user_agent = user_agent
16            .map(|value| {
17                let mut buf = Vec::new();
18                buf.extend(value.as_bytes());
19                buf.push(b' ');
20                buf.extend(TONIC_USER_AGENT.as_bytes());
21                HeaderValue::from_bytes(&buf).expect("user-agent should be valid")
22            })
23            .unwrap_or_else(|| HeaderValue::from_static(TONIC_USER_AGENT));
24
25        Self { inner, user_agent }
26    }
27}
28
29impl<T, ReqBody> Service<Request<ReqBody>> for UserAgent<T>
30where
31    T: Service<Request<ReqBody>>,
32{
33    type Response = T::Response;
34    type Error = T::Error;
35    type Future = T::Future;
36
37    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
38        self.inner.poll_ready(cx)
39    }
40
41    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
42        if let Ok(Some(user_agent)) = req
43            .headers_mut()
44            .try_insert(USER_AGENT, self.user_agent.clone())
45        {
46            // The User-Agent header has already been set on the request. Let's
47            // append our user agent to the end.
48            let mut buf = Vec::new();
49            buf.extend(user_agent.as_bytes());
50            buf.push(b' ');
51            buf.extend(self.user_agent.as_bytes());
52            req.headers_mut().insert(
53                USER_AGENT,
54                HeaderValue::from_bytes(&buf).expect("user-agent should be valid"),
55            );
56        }
57
58        self.inner.call(req)
59    }
60}
61
#[cfg(test)]
mod tests {
    use super::*;

    /// Placeholder inner service for tests that only inspect the precomputed value.
    struct Svc;

    #[test]
    fn sets_default_if_no_custom_user_agent() {
        let ua = UserAgent::new(Svc, None);
        assert_eq!(ua.user_agent, HeaderValue::from_static(TONIC_USER_AGENT))
    }

    #[test]
    fn prepends_custom_user_agent_to_default() {
        let custom = HeaderValue::from_static("Greeter 1.1");
        let expected = HeaderValue::from_str(&format!("Greeter 1.1 {TONIC_USER_AGENT}")).unwrap();
        assert_eq!(UserAgent::new(Svc, Some(custom)).user_agent, expected)
    }

    /// Inner service that asserts on the `User-Agent` header it receives.
    struct TestSvc {
        pub expected_user_agent: String,
    }

    impl Service<Request<()>> for TestSvc {
        type Response = ();
        type Error = ();
        type Future = std::future::Ready<Result<(), ()>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: Request<()>) -> Self::Future {
            let header = req.headers().get(USER_AGENT).unwrap();
            assert_eq!(header.to_str().unwrap(), self.expected_user_agent);
            std::future::ready(Ok(()))
        }
    }

    /// Builds a request that already carries the given `User-Agent` header.
    fn request_with_user_agent(value: &'static str) -> Request<()> {
        let mut req = Request::default();
        req.headers_mut().insert(USER_AGENT, HeaderValue::from_static(value));
        req
    }

    /// Sends `req` through a `UserAgent` configured with `custom`. `TestSvc`
    /// checks that the forwarded header equals `expected`.
    async fn check_forwarded(req: Request<()>, custom: Option<HeaderValue>, expected: String) {
        let mut ua = UserAgent::new(
            TestSvc {
                expected_user_agent: expected,
            },
            custom,
        );
        let _ = ua.call(req).await;
    }

    #[tokio::test]
    async fn sets_default_user_agent_if_none_present() {
        check_forwarded(Request::default(), None, TONIC_USER_AGENT.to_string()).await;
    }

    #[tokio::test]
    async fn sets_custom_user_agent_if_none_present() {
        check_forwarded(
            Request::default(),
            Some(HeaderValue::from_static("Greeter 1.1")),
            format!("Greeter 1.1 {TONIC_USER_AGENT}"),
        )
        .await;
    }

    #[tokio::test]
    async fn appends_default_user_agent_to_request_user_agent() {
        check_forwarded(
            request_with_user_agent("request-ua/x.y"),
            None,
            format!("request-ua/x.y {TONIC_USER_AGENT}"),
        )
        .await;
    }

    #[tokio::test]
    async fn appends_custom_user_agent_to_request_user_agent() {
        check_forwarded(
            request_with_user_agent("request-ua/x.y"),
            Some(HeaderValue::from_static("Greeter 1.1")),
            format!("request-ua/x.y Greeter 1.1 {TONIC_USER_AGENT}"),
        )
        .await;
    }
}