tonic/transport/server/
io_stream.rs

1#[cfg(feature = "_tls-any")]
2use std::future::Future;
3use std::{
4    io,
5    ops::ControlFlow,
6    pin::{pin, Pin},
7    task::{ready, Context, Poll},
8};
9
10use pin_project::pin_project;
11use tokio::io::{AsyncRead, AsyncWrite};
12#[cfg(feature = "_tls-any")]
13use tokio::task::JoinSet;
14use tokio_stream::Stream;
15#[cfg(feature = "_tls-any")]
16use tokio_stream::StreamExt as _;
17
18use super::service::ServerIo;
19#[cfg(feature = "_tls-any")]
20use super::service::TlsAcceptor;
21
/// TLS bookkeeping carried by `ServerIoStream` when an acceptor is configured.
#[cfg(feature = "_tls-any")]
struct State<IO>(
    /// Acceptor that is cloned into every spawned handshake task.
    TlsAcceptor,
    /// Spawned handshake tasks; each one resolves to a ready `ServerIo` or to
    /// the error that ended the handshake.
    JoinSet<Result<ServerIo<IO>, crate::BoxError>>,
);
24
25#[pin_project]
26pub(crate) struct ServerIoStream<S, IO, IE>
27where
28    S: Stream<Item = Result<IO, IE>>,
29{
30    #[pin]
31    inner: S,
32    #[cfg(feature = "_tls-any")]
33    state: Option<State<IO>>,
34}
35
36impl<S, IO, IE> ServerIoStream<S, IO, IE>
37where
38    S: Stream<Item = Result<IO, IE>>,
39{
40    pub(crate) fn new(incoming: S, #[cfg(feature = "_tls-any")] tls: Option<TlsAcceptor>) -> Self {
41        Self {
42            inner: incoming,
43            #[cfg(feature = "_tls-any")]
44            state: tls.map(|tls| State(tls, JoinSet::new())),
45        }
46    }
47
48    fn poll_next_without_tls(
49        mut self: Pin<&mut Self>,
50        cx: &mut Context<'_>,
51    ) -> Poll<Option<Result<ServerIo<IO>, crate::BoxError>>>
52    where
53        IE: Into<crate::BoxError>,
54    {
55        match ready!(self.as_mut().project().inner.poll_next(cx)) {
56            Some(Ok(io)) => Poll::Ready(Some(Ok(ServerIo::new_io(io)))),
57            Some(Err(e)) => match handle_tcp_accept_error(e) {
58                ControlFlow::Continue(()) => {
59                    cx.waker().wake_by_ref();
60                    Poll::Pending
61                }
62                ControlFlow::Break(e) => Poll::Ready(Some(Err(e))),
63            },
64            None => Poll::Ready(None),
65        }
66    }
67}
68
69impl<S, IO, IE> Stream for ServerIoStream<S, IO, IE>
70where
71    S: Stream<Item = Result<IO, IE>>,
72    IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
73    IE: Into<crate::BoxError>,
74{
75    type Item = Result<ServerIo<IO>, crate::BoxError>;
76
77    #[cfg(not(feature = "_tls-any"))]
78    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
79        self.poll_next_without_tls(cx)
80    }
81
82    #[cfg(feature = "_tls-any")]
83    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
84        let mut projected = self.as_mut().project();
85
86        let Some(State(tls, tasks)) = projected.state else {
87            return self.poll_next_without_tls(cx);
88        };
89
90        let select_output = ready!(pin!(select(&mut projected.inner, tasks)).poll(cx));
91
92        match select_output {
93            SelectOutput::Incoming(stream) => {
94                let tls = tls.clone();
95                tasks.spawn(async move {
96                    let io = tls.accept(stream).await?;
97                    Ok(ServerIo::new_tls_io(io))
98                });
99                cx.waker().wake_by_ref();
100                Poll::Pending
101            }
102
103            SelectOutput::Io(io) => Poll::Ready(Some(Ok(io))),
104
105            SelectOutput::TcpErr(e) => match handle_tcp_accept_error(e) {
106                ControlFlow::Continue(()) => {
107                    cx.waker().wake_by_ref();
108                    Poll::Pending
109                }
110                ControlFlow::Break(e) => Poll::Ready(Some(Err(e))),
111            },
112
113            SelectOutput::TlsErr(e) => {
114                tracing::debug!(error = %e, "tls accept error");
115                cx.waker().wake_by_ref();
116                Poll::Pending
117            }
118
119            SelectOutput::Done => Poll::Ready(None),
120        }
121    }
122}
123
124fn handle_tcp_accept_error(e: impl Into<crate::BoxError>) -> ControlFlow<crate::BoxError> {
125    let e = e.into();
126    tracing::debug!(error = %e, "accept loop error");
127    if let Some(e) = e.downcast_ref::<io::Error>() {
128        if matches!(
129            e.kind(),
130            io::ErrorKind::ConnectionAborted
131                | io::ErrorKind::ConnectionReset
132                | io::ErrorKind::BrokenPipe
133                | io::ErrorKind::Interrupted
134                | io::ErrorKind::WouldBlock
135                | io::ErrorKind::TimedOut
136        ) {
137            return ControlFlow::Continue(());
138        }
139    }
140
141    ControlFlow::Break(e)
142}
143
/// Waits for whichever happens first: a new item from `incoming` or the
/// completion of one of the handshake `tasks`.
///
/// When `tasks` is empty only `incoming` is awaited. The result is reported
/// as a `SelectOutput`: a new connection, the end of the incoming stream, an
/// accept error, a finished handshake, or a handshake failure (which also
/// covers a handshake task that could not be joined).
#[cfg(feature = "_tls-any")]
async fn select<IO: 'static, IE>(
    incoming: &mut (impl Stream<Item = Result<IO, IE>> + Unpin),
    tasks: &mut JoinSet<Result<ServerIo<IO>, crate::BoxError>>,
) -> SelectOutput<IO>
where
    IE: Into<crate::BoxError>,
{
    let next_connection = async {
        match incoming.next().await {
            Some(Ok(stream)) => SelectOutput::Incoming(stream),
            Some(Err(e)) => SelectOutput::TcpErr(e.into()),
            None => SelectOutput::Done,
        }
    };

    if tasks.is_empty() {
        // No handshake is in flight, so `join_next` is not polled at all;
        // only the accept stream is awaited.
        next_connection.await
    } else {
        tokio::select! {
            output = next_connection => output,
            joined = tasks.join_next() => {
                // `tasks` was checked to be non-empty just above.
                let finished = joined.expect("JoinSet should never end");
                match finished {
                    Ok(handshake) => match handshake {
                        Ok(io) => SelectOutput::Io(io),
                        Err(e) => SelectOutput::TlsErr(e),
                    },
                    Err(join_err) => SelectOutput::TlsErr(join_err.into()),
                }
            }
        }
    }
}
175
/// Outcome of one round of `select`.
#[cfg(feature = "_tls-any")]
enum SelectOutput<IO> {
    /// The incoming stream produced a new raw connection.
    Incoming(IO),
    /// A handshake task finished successfully.
    Io(ServerIo<IO>),
    /// The incoming stream produced an error.
    TcpErr(crate::BoxError),
    /// A handshake task returned an error or could not be joined.
    TlsErr(crate::BoxError),
    /// The incoming stream has ended.
    Done,
}