tonic/transport/server/
io_stream.rs1#[cfg(feature = "_tls-any")]
2use std::future::Future;
3use std::{
4 io,
5 ops::ControlFlow,
6 pin::{pin, Pin},
7 task::{ready, Context, Poll},
8};
9
10use pin_project::pin_project;
11use tokio::io::{AsyncRead, AsyncWrite};
12#[cfg(feature = "_tls-any")]
13use tokio::task::JoinSet;
14use tokio_stream::Stream;
15#[cfg(feature = "_tls-any")]
16use tokio_stream::StreamExt as _;
17
18use super::service::ServerIo;
19#[cfg(feature = "_tls-any")]
20use super::service::TlsAcceptor;
21
/// TLS-only state of a `ServerIoStream`.
#[cfg(feature = "_tls-any")]
struct State<IO>(
    /// Acceptor that is cloned into every spawned handshake task.
    TlsAcceptor,
    /// Handshake tasks spawned by `poll_next` whose results have not been collected yet.
    JoinSet<Result<ServerIo<IO>, crate::BoxError>>,
);
24
25#[pin_project]
26pub(crate) struct ServerIoStream<S, IO, IE>
27where
28 S: Stream<Item = Result<IO, IE>>,
29{
30 #[pin]
31 inner: S,
32 #[cfg(feature = "_tls-any")]
33 state: Option<State<IO>>,
34}
35
36impl<S, IO, IE> ServerIoStream<S, IO, IE>
37where
38 S: Stream<Item = Result<IO, IE>>,
39{
40 pub(crate) fn new(incoming: S, #[cfg(feature = "_tls-any")] tls: Option<TlsAcceptor>) -> Self {
41 Self {
42 inner: incoming,
43 #[cfg(feature = "_tls-any")]
44 state: tls.map(|tls| State(tls, JoinSet::new())),
45 }
46 }
47
48 fn poll_next_without_tls(
49 mut self: Pin<&mut Self>,
50 cx: &mut Context<'_>,
51 ) -> Poll<Option<Result<ServerIo<IO>, crate::BoxError>>>
52 where
53 IE: Into<crate::BoxError>,
54 {
55 match ready!(self.as_mut().project().inner.poll_next(cx)) {
56 Some(Ok(io)) => Poll::Ready(Some(Ok(ServerIo::new_io(io)))),
57 Some(Err(e)) => match handle_tcp_accept_error(e) {
58 ControlFlow::Continue(()) => {
59 cx.waker().wake_by_ref();
60 Poll::Pending
61 }
62 ControlFlow::Break(e) => Poll::Ready(Some(Err(e))),
63 },
64 None => Poll::Ready(None),
65 }
66 }
67}
68
69impl<S, IO, IE> Stream for ServerIoStream<S, IO, IE>
70where
71 S: Stream<Item = Result<IO, IE>>,
72 IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
73 IE: Into<crate::BoxError>,
74{
75 type Item = Result<ServerIo<IO>, crate::BoxError>;
76
77 #[cfg(not(feature = "_tls-any"))]
78 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
79 self.poll_next_without_tls(cx)
80 }
81
82 #[cfg(feature = "_tls-any")]
83 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
84 let mut projected = self.as_mut().project();
85
86 let Some(State(tls, tasks)) = projected.state else {
87 return self.poll_next_without_tls(cx);
88 };
89
90 let select_output = ready!(pin!(select(&mut projected.inner, tasks)).poll(cx));
91
92 match select_output {
93 SelectOutput::Incoming(stream) => {
94 let tls = tls.clone();
95 tasks.spawn(async move {
96 let io = tls.accept(stream).await?;
97 Ok(ServerIo::new_tls_io(io))
98 });
99 cx.waker().wake_by_ref();
100 Poll::Pending
101 }
102
103 SelectOutput::Io(io) => Poll::Ready(Some(Ok(io))),
104
105 SelectOutput::TcpErr(e) => match handle_tcp_accept_error(e) {
106 ControlFlow::Continue(()) => {
107 cx.waker().wake_by_ref();
108 Poll::Pending
109 }
110 ControlFlow::Break(e) => Poll::Ready(Some(Err(e))),
111 },
112
113 SelectOutput::TlsErr(e) => {
114 tracing::debug!(error = %e, "tls accept error");
115 cx.waker().wake_by_ref();
116 Poll::Pending
117 }
118
119 SelectOutput::Done => Poll::Ready(None),
120 }
121 }
122}
123
124fn handle_tcp_accept_error(e: impl Into<crate::BoxError>) -> ControlFlow<crate::BoxError> {
125 let e = e.into();
126 tracing::debug!(error = %e, "accept loop error");
127 if let Some(e) = e.downcast_ref::<io::Error>() {
128 if matches!(
129 e.kind(),
130 io::ErrorKind::ConnectionAborted
131 | io::ErrorKind::ConnectionReset
132 | io::ErrorKind::BrokenPipe
133 | io::ErrorKind::Interrupted
134 | io::ErrorKind::WouldBlock
135 | io::ErrorKind::TimedOut
136 ) {
137 return ControlFlow::Continue(());
138 }
139 }
140
141 ControlFlow::Break(e)
142}
143
/// Waits for whichever comes first: the next item from `incoming` or a finished task in
/// `tasks`.
///
/// Possible outcomes:
/// - `Incoming(stream)`: a new connection was accepted.
/// - `Done`: `incoming` is exhausted.
/// - `TcpErr`: `incoming` yielded an error.
/// - `Io`: a task completed successfully.
/// - `TlsErr`: a task returned an error or could not be joined.
///
/// NOTE(review): `poll_next` rebuilds this future on every poll, so correctness relies on
/// `StreamExt::next` and `JoinSet::join_next` being cancel-safe.
#[cfg(feature = "_tls-any")]
async fn select<IO: 'static, IE>(
    incoming: &mut (impl Stream<Item = Result<IO, IE>> + Unpin),
    tasks: &mut JoinSet<Result<ServerIo<IO>, crate::BoxError>>,
) -> SelectOutput<IO>
where
    IE: Into<crate::BoxError>,
{
    let next_connection = async {
        match incoming.next().await {
            Some(Ok(stream)) => SelectOutput::Incoming(stream),
            Some(Err(e)) => SelectOutput::TcpErr(e.into()),
            None => SelectOutput::Done,
        }
    };

    // `join_next` on an empty set resolves to `None` immediately, which would trip the
    // `expect` below, so only the listener is awaited in that case.
    if tasks.is_empty() {
        return next_connection.await;
    }

    tokio::select! {
        output = next_connection => output,
        joined = tasks.join_next() => {
            match joined.expect("JoinSet should never end") {
                Ok(handshake) => handshake.map_or_else(SelectOutput::TlsErr, SelectOutput::Io),
                Err(join_err) => SelectOutput::TlsErr(join_err.into()),
            }
        }
    }
}
175
/// Outcome of a single `select` round in the TLS accept loop.
#[cfg(feature = "_tls-any")]
enum SelectOutput<IO> {
    /// A freshly accepted connection that still has to go through the TLS acceptor.
    Incoming(IO),
    /// A connection whose handshake task finished successfully.
    Io(ServerIo<IO>),
    /// The incoming connection stream yielded an error.
    TcpErr(crate::BoxError),
    /// A handshake task returned an error or could not be joined.
    TlsErr(crate::BoxError),
    /// The incoming connection stream is exhausted.
    Done,
}