tonic/transport/channel/service/
user_agent.rs1use http::{header::USER_AGENT, HeaderValue, Request};
2use std::task::{Context, Poll};
3use tower_service::Service;
4
5const TONIC_USER_AGENT: &str = concat!("tonic/", env!("CARGO_PKG_VERSION"));
6
7#[derive(Debug)]
8pub(crate) struct UserAgent<T> {
9 inner: T,
10 user_agent: HeaderValue,
11}
12
13impl<T> UserAgent<T> {
14 pub(crate) fn new(inner: T, user_agent: Option<HeaderValue>) -> Self {
15 let user_agent = user_agent
16 .map(|value| {
17 let mut buf = Vec::new();
18 buf.extend(value.as_bytes());
19 buf.push(b' ');
20 buf.extend(TONIC_USER_AGENT.as_bytes());
21 HeaderValue::from_bytes(&buf).expect("user-agent should be valid")
22 })
23 .unwrap_or_else(|| HeaderValue::from_static(TONIC_USER_AGENT));
24
25 Self { inner, user_agent }
26 }
27}
28
29impl<T, ReqBody> Service<Request<ReqBody>> for UserAgent<T>
30where
31 T: Service<Request<ReqBody>>,
32{
33 type Response = T::Response;
34 type Error = T::Error;
35 type Future = T::Future;
36
37 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
38 self.inner.poll_ready(cx)
39 }
40
41 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
42 if let Ok(Some(user_agent)) = req
43 .headers_mut()
44 .try_insert(USER_AGENT, self.user_agent.clone())
45 {
46 let mut buf = Vec::new();
49 buf.extend(user_agent.as_bytes());
50 buf.push(b' ');
51 buf.extend(self.user_agent.as_bytes());
52 req.headers_mut().insert(
53 USER_AGENT,
54 HeaderValue::from_bytes(&buf).expect("user-agent should be valid"),
55 );
56 }
57
58 self.inner.call(req)
59 }
60}
61
#[cfg(test)]
mod tests {
    use super::*;

    /// Placeholder inner service for tests that only inspect the precomputed header value.
    struct Svc;

    /// Inner service that asserts the `user-agent` header it receives equals `expected`.
    struct TestSvc {
        expected: String,
    }

    impl Service<Request<()>> for TestSvc {
        type Response = ();
        type Error = ();
        type Future = std::future::Ready<Result<(), ()>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: Request<()>) -> Self::Future {
            let header = req.headers().get(USER_AGENT).unwrap();
            let received = header.to_str().unwrap();
            assert_eq!(received, self.expected);
            std::future::ready(Ok(()))
        }
    }

    /// Sends `req` through a `UserAgent` built with `custom` and checks that the inner
    /// service observes `expected` as the `user-agent` header.
    async fn check_forwarded(req: Request<()>, custom: Option<HeaderValue>, expected: String) {
        let mut svc = UserAgent::new(TestSvc { expected }, custom);
        let _ = svc.call(req).await;
    }

    /// Builds an empty request that already carries `request-ua/x.y` as its user agent.
    fn request_with_user_agent() -> Request<()> {
        let mut req = Request::default();
        req.headers_mut()
            .insert(USER_AGENT, HeaderValue::from_static("request-ua/x.y"));
        req
    }

    #[test]
    fn sets_default_if_no_custom_user_agent() {
        let built = UserAgent::new(Svc, None);
        assert_eq!(built.user_agent, HeaderValue::from_static(TONIC_USER_AGENT));
    }

    #[test]
    fn prepends_custom_user_agent_to_default() {
        let built = UserAgent::new(Svc, Some(HeaderValue::from_static("Greeter 1.1")));
        let wanted = HeaderValue::from_str(&format!("Greeter 1.1 {TONIC_USER_AGENT}")).unwrap();
        assert_eq!(built.user_agent, wanted);
    }

    #[tokio::test]
    async fn sets_default_user_agent_if_none_present() {
        check_forwarded(Request::default(), None, TONIC_USER_AGENT.to_string()).await;
    }

    #[tokio::test]
    async fn sets_custom_user_agent_if_none_present() {
        check_forwarded(
            Request::default(),
            Some(HeaderValue::from_static("Greeter 1.1")),
            format!("Greeter 1.1 {TONIC_USER_AGENT}"),
        )
        .await;
    }

    #[tokio::test]
    async fn appends_default_user_agent_to_request_user_agent() {
        check_forwarded(
            request_with_user_agent(),
            None,
            format!("request-ua/x.y {TONIC_USER_AGENT}"),
        )
        .await;
    }

    #[tokio::test]
    async fn appends_custom_user_agent_to_request_user_agent() {
        check_forwarded(
            request_with_user_agent(),
            Some(HeaderValue::from_static("Greeter 1.1")),
            format!("request-ua/x.y Greeter 1.1 {TONIC_USER_AGENT}"),
        )
        .await;
    }
}