use super::{Connected, Server};
use crate::transport::service::ServerIo;
use hyper::server::{
accept::Accept,
conn::{AddrIncoming, AddrStream},
};
use std::{
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use tokio::{
io::{AsyncRead, AsyncWrite},
net::TcpListener,
};
use tokio_stream::{Stream, StreamExt};
#[cfg(not(feature = "tls"))]
/// Wraps each connection produced by `incoming` as a plain (non-TLS) `ServerIo`.
///
/// TLS support is compiled out, so `_server` is not consulted. The first error
/// produced by `incoming` is converted into `crate::Error`, yielded, and ends
/// the stream (the `?` semantics of `async_stream::try_stream!`).
pub(crate) fn tcp_incoming<IO, IE, L>(
    incoming: impl Stream<Item = Result<IO, IE>>,
    _server: Server<L>,
) -> impl Stream<Item = Result<ServerIo<IO>, crate::Error>>
where
    IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
    IE: Into<crate::Error>,
{
    async_stream::try_stream! {
        tokio::pin!(incoming);
        while let Some(conn) = incoming.try_next().await? {
            yield ServerIo::new_io(conn);
        }
    }
}
/// Turns a stream of accepted connections into a stream of `ServerIo`s,
/// performing the TLS handshake when `server.tls` is configured.
///
/// Handshakes run concurrently on a `tokio::task::JoinSet`, so a slow client
/// does not stall the accept loop. Errors from the listener, from a handshake,
/// or from a handshake task are logged at `debug` level and the loop continues.
///
/// When `incoming` ends, handshakes that are still in flight are driven to
/// completion, and their connections are yielded before this stream ends.
/// Without TLS configured, connections are yielded directly as plain IO.
#[cfg(feature = "tls")]
pub(crate) fn tcp_incoming<IO, IE, L>(
    incoming: impl Stream<Item = Result<IO, IE>>,
    server: Server<L>,
) -> impl Stream<Item = Result<ServerIo<IO>, crate::Error>>
where
    IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
    IE: Into<crate::Error>,
{
    async_stream::try_stream! {
        tokio::pin!(incoming);
        let mut tasks = tokio::task::JoinSet::new();
        loop {
            match select(&mut incoming, &mut tasks).await {
                SelectOutput::Incoming(stream) => {
                    if let Some(tls) = &server.tls {
                        let tls = tls.clone();
                        // Handshake off the accept loop so other clients keep being accepted.
                        tasks.spawn(async move {
                            let io = tls.accept(stream).await?;
                            Ok(ServerIo::new_tls_io(io))
                        });
                    } else {
                        yield ServerIo::new_io(stream);
                    }
                }
                SelectOutput::Io(io) => {
                    yield io;
                }
                SelectOutput::Err(e) => {
                    tracing::debug!(message = "Accept loop error.", error = %e);
                }
                SelectOutput::Done => {
                    // The listener is exhausted, but connections it already
                    // accepted may still be mid-handshake. Dropping `tasks`
                    // would abort them, so finish them first.
                    while let Some(joined) = tasks.join_next().await {
                        match joined {
                            Ok(Ok(io)) => {
                                yield io;
                            }
                            Ok(Err(e)) => {
                                tracing::debug!(message = "Accept loop error.", error = %e);
                            }
                            Err(e) => {
                                tracing::debug!(message = "Accept loop error.", error = %e);
                            }
                        }
                    }
                    break;
                }
            }
        }
    }
}
/// Waits for the next event of the TLS accept loop: a new connection from
/// `incoming`, or a finished handshake task in `tasks`.
///
/// With no handshakes pending, only the listener is awaited. Otherwise both
/// sources are raced with `tokio::select!`.
#[cfg(feature = "tls")]
async fn select<IO: 'static, IE>(
    incoming: &mut (impl Stream<Item = Result<IO, IE>> + Unpin),
    tasks: &mut tokio::task::JoinSet<Result<ServerIo<IO>, crate::Error>>,
) -> SelectOutput<IO>
where
    IE: Into<crate::Error>,
{
    /// Maps one poll result of the listener onto the loop's event type.
    fn from_incoming<IO, E: Into<crate::Error>>(next: Result<Option<IO>, E>) -> SelectOutput<IO> {
        match next {
            Ok(Some(stream)) => SelectOutput::Incoming(stream),
            Ok(None) => SelectOutput::Done,
            Err(e) => SelectOutput::Err(e.into()),
        }
    }

    if tasks.is_empty() {
        // `join_next` would resolve to `None` right away on an empty set, so
        // only the listener can make progress here.
        from_incoming(incoming.try_next().await)
    } else {
        tokio::select! {
            next = incoming.try_next() => from_incoming(next),
            joined = tasks.join_next() => {
                match joined.expect("JoinSet should never end") {
                    Ok(Ok(io)) => SelectOutput::Io(io),
                    Ok(Err(e)) => SelectOutput::Err(e),
                    Err(e) => SelectOutput::Err(e.into()),
                }
            }
        }
    }
}
/// One event of the TLS accept loop, as produced by `select`.
#[cfg(feature = "tls")]
enum SelectOutput<IO> {
    /// A raw connection just taken from the listener, not yet wrapped in `ServerIo`.
    Incoming(IO),
    /// A connection whose spawned handshake task finished successfully.
    Io(ServerIo<IO>),
    /// An error from the listener, a handshake, or a handshake task.
    Err(crate::Error),
    /// The listener stream has ended.
    Done,
}
/// A stream of accepted TCP connections backed by hyper's `AddrIncoming`.
///
/// The `nodelay` and `keepalive` values given at construction are forwarded to
/// `AddrIncoming::set_nodelay` and `AddrIncoming::set_keepalive`.
#[derive(Debug)]
pub struct TcpIncoming {
    inner: AddrIncoming,
}

impl TcpIncoming {
    /// Binds a new listener to `addr`.
    ///
    /// # Errors
    ///
    /// Returns an error if `AddrIncoming::bind` fails (e.g. the address is in use).
    pub fn new(
        addr: SocketAddr,
        nodelay: bool,
        keepalive: Option<Duration>,
    ) -> Result<Self, crate::Error> {
        let incoming = AddrIncoming::bind(&addr)?;
        Ok(Self::configured(incoming, nodelay, keepalive))
    }

    /// Wraps an already-bound tokio [`TcpListener`].
    ///
    /// # Errors
    ///
    /// Returns an error if `AddrIncoming::from_listener` fails.
    pub fn from_listener(
        listener: TcpListener,
        nodelay: bool,
        keepalive: Option<Duration>,
    ) -> Result<Self, crate::Error> {
        AddrIncoming::from_listener(listener)
            .map(|incoming| Self::configured(incoming, nodelay, keepalive))
            .map_err(Into::into)
    }

    /// Applies the socket options shared by both constructors.
    fn configured(mut inner: AddrIncoming, nodelay: bool, keepalive: Option<Duration>) -> Self {
        inner.set_nodelay(nodelay);
        inner.set_keepalive(keepalive);
        Self { inner }
    }
}

impl Stream for TcpIncoming {
    type Item = Result<AddrStream, std::io::Error>;

    /// Delegates to `Accept::poll_accept` on the wrapped `AddrIncoming`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_accept(cx)
    }
}
#[cfg(test)]
mod tests {
    use crate::transport::server::TcpIncoming;

    /// Only one `TcpIncoming` may be bound to an address at a time, and the
    /// address becomes available again once that listener is dropped.
    #[tokio::test]
    async fn one_tcpincoming_at_a_time() {
        // Let the OS pick a free port instead of hard-coding one, so the test
        // cannot collide with anything else listening on the machine.
        let ephemeral = "127.0.0.1:0".parse().unwrap();
        let first = TcpIncoming::new(ephemeral, true, None).unwrap();
        let addr = first.inner.local_addr();

        // A second bind to the same address must fail while `first` is alive.
        let _err = TcpIncoming::new(addr, true, None).unwrap_err();

        // After dropping `first`, the address can be bound again.
        drop(first);
        let _again = TcpIncoming::new(addr, true, None).unwrap();
    }
}