use crate::bindings::sockets::tcp::ErrorCode;
use crate::host::network;
use crate::network::SocketAddressFamily;
use crate::runtime::{with_ambient_tokio_runtime, AbortOnDropJoinHandle};
use crate::{
HostInputStream, HostOutputStream, InputStream, OutputStream, SocketResult, StreamError,
Subscribe,
};
use anyhow::{Error, Result};
use cap_net_ext::AddressFamily;
use futures::Future;
use io_lifetimes::views::SocketlikeView;
use io_lifetimes::AsSocketlike;
use rustix::io::Errno;
use rustix::net::sockopt;
use std::io;
use std::mem;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Poll;
/// Listen backlog used by `finish_listen` unless changed via
/// `set_listen_backlog_size`; every socket built by `from_state` starts with it.
const DEFAULT_BACKLOG: u32 = 128;
/// Lifecycle state of a `TcpSocket`.
///
/// Transitions are driven by the `start_*` / `finish_*` methods, `accept`, and
/// `Subscribe::ready`. `Closed` is also used as a temporary placeholder while
/// the previous state is moved out with `mem::replace`.
enum TcpState {
    /// Freshly created by `TcpSocket::new`: not bound, listening or connected.
    Default(tokio::net::TcpSocket),
    /// `start_bind` performed the OS bind; waiting for `finish_bind`.
    BindStarted(tokio::net::TcpSocket),
    /// Bound to a local address.
    Bound(tokio::net::TcpSocket),
    /// `start_listen` was called; `finish_listen` performs the OS `listen`.
    ListenStarted(tokio::net::TcpSocket),
    /// Accepting incoming connections.
    Listening {
        listener: tokio::net::TcpListener,
        /// Accept outcome obtained by `Subscribe::ready`, consumed by `accept`.
        pending_accept: Option<io::Result<tokio::net::TcpStream>>,
    },
    /// Outgoing connect in progress: the future created by `start_connect`.
    Connecting(Pin<Box<dyn Future<Output = io::Result<tokio::net::TcpStream>> + Send>>),
    /// The connect future completed inside `Subscribe::ready`; its outcome
    /// waits for `finish_connect`.
    ConnectReady(io::Result<tokio::net::TcpStream>),
    /// Connected; the stream is shared with the input/output stream objects.
    Connected(Arc<tokio::net::TcpStream>),
    /// No usable socket: transition placeholder, or the state after a failed
    /// connect or listen.
    Closed,
}
impl std::fmt::Debug for TcpState {
    /// Prints the variant name only; socket, future and stream payloads are not
    /// printed. `Listening` additionally shows its stashed `pending_accept`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let variant = match self {
            Self::Listening { pending_accept, .. } => {
                return f
                    .debug_struct("Listening")
                    .field("pending_accept", pending_accept)
                    .finish();
            }
            Self::Default(_) => "Default",
            Self::BindStarted(_) => "BindStarted",
            Self::Bound(_) => "Bound",
            Self::ListenStarted(_) => "ListenStarted",
            Self::Connecting(_) => "Connecting",
            Self::ConnectReady(_) => "ConnectReady",
            Self::Connected(_) => "Connected",
            Self::Closed => "Closed",
        };
        // A field-less tuple prints exactly its name.
        f.debug_tuple(variant).finish()
    }
}
/// A WASI TCP socket backed by tokio, plus the settings that must survive
/// state transitions.
pub struct TcpSocket {
    /// Current lifecycle state.
    tcp_state: TcpState,
    /// Backlog handed to `listen` by `finish_listen`; starts at `DEFAULT_BACKLOG`.
    listen_backlog_size: u32,
    /// Address family chosen at creation; addresses are validated against it.
    family: SocketAddressFamily,
    /// macOS only: receive buffer size explicitly set on this socket, re-applied
    /// to sockets returned by `accept`.
    #[cfg(target_os = "macos")]
    receive_buffer_size: Option<usize>,
    /// macOS only: send buffer size explicitly set on this socket, re-applied
    /// to sockets returned by `accept`.
    #[cfg(target_os = "macos")]
    send_buffer_size: Option<usize>,
    /// macOS only: hop limit explicitly set on this socket; re-applied to
    /// accepted IPv6 sockets.
    #[cfg(target_os = "macos")]
    hop_limit: Option<u8>,
    /// macOS only: keep-alive idle time explicitly set on this socket,
    /// re-applied to sockets returned by `accept`.
    #[cfg(target_os = "macos")]
    keep_alive_idle_time: Option<std::time::Duration>,
}
impl TcpSocket {
    /// Creates an unbound socket of the requested address family.
    ///
    /// The socket is created inside the ambient tokio runtime. IPv6 sockets
    /// get `IPV6_V6ONLY` enabled before they are wrapped.
    pub fn new(family: AddressFamily) -> io::Result<Self> {
        with_ambient_tokio_runtime(|| match family {
            AddressFamily::Ipv4 => {
                let v4 = tokio::net::TcpSocket::new_v4()?;
                Self::from_state(TcpState::Default(v4), SocketAddressFamily::Ipv4)
            }
            AddressFamily::Ipv6 => {
                let v6 = tokio::net::TcpSocket::new_v6()?;
                sockopt::set_ipv6_v6only(&v6, true)?;
                Self::from_state(TcpState::Default(v6), SocketAddressFamily::Ipv6)
            }
        })
    }

    /// Wraps an existing state with default settings: the default backlog and,
    /// on macOS, no remembered socket options. Currently never fails.
    fn from_state(state: TcpState, family: SocketAddressFamily) -> io::Result<Self> {
        let socket = Self {
            tcp_state: state,
            listen_backlog_size: DEFAULT_BACKLOG,
            family,
            #[cfg(target_os = "macos")]
            receive_buffer_size: None,
            #[cfg(target_os = "macos")]
            send_buffer_size: None,
            #[cfg(target_os = "macos")]
            hop_limit: None,
            #[cfg(target_os = "macos")]
            keep_alive_idle_time: None,
        };
        Ok(socket)
    }

    /// Borrows the underlying OS socket as a `std::net::TcpStream` view so
    /// generic socket-option helpers can operate on it.
    ///
    /// Returns `InvalidState` while an operation is in flight
    /// (`BindStarted`, `ListenStarted`, `Connecting`, `ConnectReady`) or when
    /// the socket is `Closed`.
    fn as_std_view(&self) -> SocketResult<SocketlikeView<'_, std::net::TcpStream>> {
        // The network-level error code is used here, not the TCP one.
        use crate::bindings::sockets::network::ErrorCode;

        let view = match &self.tcp_state {
            TcpState::BindStarted(..)
            | TcpState::ListenStarted(..)
            | TcpState::Connecting(..)
            | TcpState::ConnectReady(..)
            | TcpState::Closed => return Err(ErrorCode::InvalidState.into()),
            TcpState::Default(socket) | TcpState::Bound(socket) => {
                socket.as_socketlike_view::<std::net::TcpStream>()
            }
            TcpState::Listening { listener, .. } => {
                listener.as_socketlike_view::<std::net::TcpStream>()
            }
            TcpState::Connected(stream) => stream.as_socketlike_view::<std::net::TcpStream>(),
        };
        Ok(view)
    }
}
impl TcpSocket {
    /// Starts binding this socket to `local_address`.
    ///
    /// Only valid in the `Default` state. The OS `bind` happens here. On
    /// success the socket moves to `BindStarted`, and
    /// [`finish_bind`](Self::finish_bind) completes the move to `Bound`.
    ///
    /// # Errors
    ///
    /// - `EALREADY` if a bind is already in progress (`BindStarted`).
    /// - `EISCONN` in any other non-`Default` state.
    /// - Errors from the unicast / address-family validation helpers.
    /// - The OS `bind` error. `EAFNOSUPPORT` is rewritten to `InvalidInput`
    ///   and, on Windows, `ENOBUFS` to `AddrInUse`.
    pub fn start_bind(&mut self, local_address: SocketAddr) -> io::Result<()> {
        let tokio_socket = match &self.tcp_state {
            TcpState::Default(socket) => socket,
            TcpState::BindStarted(..) => return Err(Errno::ALREADY.into()),
            _ => return Err(Errno::ISCONN.into()),
        };

        network::util::validate_unicast(&local_address)?;
        network::util::validate_address_family(&local_address, &self.family)?;

        // SO_REUSEADDR is enabled only when a specific (non-zero) port is
        // requested. It is written unconditionally, including `false`.
        let reuse_addr = local_address.port() > 0;
        network::util::set_tcp_reuseaddr(&tokio_socket, reuse_addr)?;

        tokio_socket
            .bind(local_address)
            .map_err(|error| match Errno::from_io_error(&error) {
                Some(Errno::AFNOSUPPORT) => io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "The specified address is not a valid address for the address family of the specified socket",
                ),
                #[cfg(windows)]
                Some(Errno::NOBUFS) => {
                    io::Error::new(io::ErrorKind::AddrInUse, "no more free local ports")
                }
                _ => error,
            })?;

        // The state was verified to be `Default` above and none of the calls
        // in between have mutable access to it.
        self.tcp_state = match std::mem::replace(&mut self.tcp_state, TcpState::Closed) {
            TcpState::Default(socket) => TcpState::BindStarted(socket),
            _ => unreachable!(),
        };
        Ok(())
    }

    /// Completes a bind started by [`start_bind`](Self::start_bind), moving
    /// the socket from `BindStarted` to `Bound`.
    ///
    /// # Errors
    ///
    /// `NotInProgress` if no bind is in progress; the state is left unchanged.
    pub fn finish_bind(&mut self) -> SocketResult<()> {
        match std::mem::replace(&mut self.tcp_state, TcpState::Closed) {
            TcpState::BindStarted(socket) => {
                self.tcp_state = TcpState::Bound(socket);
                Ok(())
            }
            current_state => {
                self.tcp_state = current_state;
                Err(ErrorCode::NotInProgress.into())
            }
        }
    }

    /// Starts an outgoing connection to `remote_address`.
    ///
    /// Valid from `Default` (unbound) and from `Bound`, so a caller may choose
    /// the local address with `start_bind`/`finish_bind` before connecting.
    /// On success the pending connect future is stored in `Connecting`.
    ///
    /// # Errors
    ///
    /// - `ConcurrencyConflict` if a connect is already in progress.
    /// - `InvalidState` in any other state (e.g. listening or connected).
    /// - Errors from the address validation helpers.
    ///
    /// On error the socket state is left unchanged.
    pub fn start_connect(&mut self, remote_address: SocketAddr) -> SocketResult<()> {
        match self.tcp_state {
            TcpState::Default(..) | TcpState::Bound(..) => {}
            TcpState::Connecting(..) | TcpState::ConnectReady(..) => {
                return Err(ErrorCode::ConcurrencyConflict.into())
            }
            _ => return Err(ErrorCode::InvalidState.into()),
        };

        network::util::validate_unicast(&remote_address)?;
        network::util::validate_remote_address(&remote_address)?;
        network::util::validate_address_family(&remote_address, &self.family)?;

        let (TcpState::Default(tokio_socket) | TcpState::Bound(tokio_socket)) =
            std::mem::replace(&mut self.tcp_state, TcpState::Closed)
        else {
            unreachable!("state was checked to be Default or Bound above");
        };

        let future = tokio_socket.connect(remote_address);
        self.tcp_state = TcpState::Connecting(Box::pin(future));
        Ok(())
    }

    /// Completes a connect started by [`start_connect`](Self::start_connect).
    ///
    /// Uses the outcome stashed by `Subscribe::ready` (`ConnectReady`) or polls
    /// the pending future once with a no-op waker. On success the socket
    /// becomes `Connected` and the input/output streams sharing the connection
    /// are returned.
    ///
    /// # Errors
    ///
    /// - `WouldBlock` if the connect has not completed yet (still `Connecting`).
    /// - `NotInProgress` if no connect was started; the state is unchanged.
    /// - The connect error itself, after which the socket is `Closed`.
    pub fn finish_connect(&mut self) -> SocketResult<(InputStream, OutputStream)> {
        let previous_state = std::mem::replace(&mut self.tcp_state, TcpState::Closed);
        let result = match previous_state {
            TcpState::ConnectReady(result) => result,
            TcpState::Connecting(mut future) => {
                let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref());
                match with_ambient_tokio_runtime(|| future.as_mut().poll(&mut cx)) {
                    Poll::Ready(result) => result,
                    Poll::Pending => {
                        self.tcp_state = TcpState::Connecting(future);
                        return Err(ErrorCode::WouldBlock.into());
                    }
                }
            }
            previous_state => {
                self.tcp_state = previous_state;
                return Err(ErrorCode::NotInProgress.into());
            }
        };

        match result {
            Ok(stream) => {
                let stream = Arc::new(stream);
                self.tcp_state = TcpState::Connected(stream.clone());
                let input: InputStream = Box::new(TcpReadStream::new(stream.clone()));
                let output: OutputStream = Box::new(TcpWriteStream::new(stream));
                Ok((input, output))
            }
            Err(err) => {
                self.tcp_state = TcpState::Closed;
                Err(err.into())
            }
        }
    }

    /// Starts listening: moves a `Bound` socket to `ListenStarted`. The OS
    /// `listen` call is deferred to [`finish_listen`](Self::finish_listen).
    ///
    /// # Errors
    ///
    /// - `ConcurrencyConflict` if listening was already started.
    /// - `InvalidState` in any other non-`Bound` state.
    ///
    /// On error the state is left unchanged.
    pub fn start_listen(&mut self) -> SocketResult<()> {
        match std::mem::replace(&mut self.tcp_state, TcpState::Closed) {
            TcpState::Bound(tokio_socket) => {
                self.tcp_state = TcpState::ListenStarted(tokio_socket);
                Ok(())
            }
            TcpState::ListenStarted(tokio_socket) => {
                self.tcp_state = TcpState::ListenStarted(tokio_socket);
                Err(ErrorCode::ConcurrencyConflict.into())
            }
            previous_state => {
                self.tcp_state = previous_state;
                Err(ErrorCode::InvalidState.into())
            }
        }
    }

    /// Performs the OS `listen` with the configured backlog and moves the
    /// socket to `Listening`.
    ///
    /// # Errors
    ///
    /// - `NotInProgress` if `start_listen` was not called; state unchanged.
    /// - The OS `listen` error, after which the socket is `Closed`. On Windows
    ///   `EMFILE` is reported as `ENOBUFS`.
    pub fn finish_listen(&mut self) -> SocketResult<()> {
        let tokio_socket = match std::mem::replace(&mut self.tcp_state, TcpState::Closed) {
            TcpState::ListenStarted(tokio_socket) => tokio_socket,
            previous_state => {
                self.tcp_state = previous_state;
                return Err(ErrorCode::NotInProgress.into());
            }
        };

        match with_ambient_tokio_runtime(|| tokio_socket.listen(self.listen_backlog_size)) {
            Ok(listener) => {
                self.tcp_state = TcpState::Listening {
                    listener,
                    pending_accept: None,
                };
                Ok(())
            }
            Err(err) => {
                self.tcp_state = TcpState::Closed;
                Err(match Errno::from_io_error(&err) {
                    #[cfg(windows)]
                    Some(Errno::MFILE) => Errno::NOBUFS.into(),
                    _ => err.into(),
                })
            }
        }
    }

    /// Accepts one incoming connection.
    ///
    /// Uses the result stashed by `Subscribe::ready` if there is one, otherwise
    /// polls the listener once with a no-op waker. Returns the new connected
    /// socket together with its input/output streams.
    ///
    /// # Errors
    ///
    /// - `InvalidState` if the socket is not listening.
    /// - `EWOULDBLOCK` if no connection is pending.
    /// - The accept error. On Windows `EINPROGRESS` becomes `EINTR`; on Linux,
    ///   pending network errors of the new connection become `ECONNABORTED`.
    pub fn accept(&mut self) -> SocketResult<(Self, InputStream, OutputStream)> {
        let TcpState::Listening {
            listener,
            pending_accept,
        } = &mut self.tcp_state
        else {
            return Err(ErrorCode::InvalidState.into());
        };

        let result = match pending_accept.take() {
            Some(result) => result,
            None => {
                let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref());
                match with_ambient_tokio_runtime(|| listener.poll_accept(&mut cx))
                    .map_ok(|(stream, _)| stream)
                {
                    Poll::Ready(result) => result,
                    Poll::Pending => Err(Errno::WOULDBLOCK.into()),
                }
            }
        };

        let client = result.map_err(|err| match Errno::from_io_error(&err) {
            #[cfg(windows)]
            Some(Errno::INPROGRESS) => Errno::INTR.into(),
            #[cfg(target_os = "linux")]
            Some(
                Errno::CONNRESET
                | Errno::NETRESET
                | Errno::HOSTUNREACH
                | Errno::HOSTDOWN
                | Errno::NETDOWN
                | Errno::NETUNREACH
                | Errno::PROTO
                | Errno::NOPROTOOPT
                | Errno::NONET
                | Errno::OPNOTSUPP,
            ) => Errno::CONNABORTED.into(),
            _ => err,
        })?;

        // On macOS, re-apply options explicitly set on this listening socket
        // to the accepted client. Failures are deliberately ignored.
        #[cfg(target_os = "macos")]
        {
            if let Some(size) = self.receive_buffer_size {
                _ = network::util::set_socket_recv_buffer_size(&client, size);
            }
            if let Some(size) = self.send_buffer_size {
                _ = network::util::set_socket_send_buffer_size(&client, size);
            }
            if let (SocketAddressFamily::Ipv6, Some(ttl)) = (self.family, self.hop_limit) {
                _ = network::util::set_ipv6_unicast_hops(&client, ttl);
            }
            if let Some(value) = self.keep_alive_idle_time {
                _ = network::util::set_tcp_keepidle(&client, value);
            }
        }

        let client = Arc::new(client);
        let input: InputStream = Box::new(TcpReadStream::new(client.clone()));
        let output: OutputStream = Box::new(TcpWriteStream::new(client.clone()));
        let tcp_socket = TcpSocket::from_state(TcpState::Connected(client), self.family)?;
        Ok((tcp_socket, input, output))
    }

    /// Returns the bound local address.
    ///
    /// # Errors
    ///
    /// `InvalidState` when unbound (`Default`), `ConcurrencyConflict` while a
    /// bind is in progress, plus any error from `as_std_view` / the OS query.
    pub fn local_address(&self) -> SocketResult<SocketAddr> {
        let view = match self.tcp_state {
            TcpState::Default(..) => return Err(ErrorCode::InvalidState.into()),
            TcpState::BindStarted(..) => return Err(ErrorCode::ConcurrencyConflict.into()),
            _ => self.as_std_view()?,
        };
        Ok(view.local_addr()?)
    }

    /// Returns the peer address of a connected socket.
    ///
    /// # Errors
    ///
    /// `ConcurrencyConflict` while a connect is in progress, `InvalidState` in
    /// any other non-connected state, plus any OS error.
    pub fn remote_address(&self) -> SocketResult<SocketAddr> {
        let view = match self.tcp_state {
            TcpState::Connected(..) => self.as_std_view()?,
            TcpState::Connecting(..) | TcpState::ConnectReady(..) => {
                return Err(ErrorCode::ConcurrencyConflict.into())
            }
            _ => return Err(ErrorCode::InvalidState.into()),
        };
        Ok(view.peer_addr()?)
    }

    /// Whether the socket is in the `Listening` state.
    pub fn is_listening(&self) -> bool {
        matches!(self.tcp_state, TcpState::Listening { .. })
    }

    /// The address family chosen when the socket was created.
    pub fn address_family(&self) -> SocketAddressFamily {
        self.family
    }

    /// Sets the listen backlog.
    ///
    /// Non-zero values are clamped to `1..=i32::MAX`. Before listening
    /// (`Default`/`Bound`) the value is only stored for `finish_listen`. While
    /// `Listening`, `listen` is called again with the new value, and the
    /// stored value is updated only if that succeeds.
    ///
    /// # Errors
    ///
    /// `InvalidArgument` for `0`, `NotSupported` if re-`listen` fails,
    /// `InvalidState` in any other state.
    pub fn set_listen_backlog_size(&mut self, value: u32) -> SocketResult<()> {
        const MIN_BACKLOG: u32 = 1;
        const MAX_BACKLOG: u32 = i32::MAX as u32;

        if value == 0 {
            return Err(ErrorCode::InvalidArgument.into());
        }
        let value = value.clamp(MIN_BACKLOG, MAX_BACKLOG);

        match &self.tcp_state {
            TcpState::Default(..) | TcpState::Bound(..) => {
                // Not listening yet: the value is used by `finish_listen`.
            }
            TcpState::Listening { listener, .. } => {
                rustix::net::listen(
                    &listener,
                    value.try_into().expect("backlog was clamped to i32::MAX"),
                )
                .map_err(|_| ErrorCode::NotSupported)?;
            }
            _ => return Err(ErrorCode::InvalidState.into()),
        }
        self.listen_backlog_size = value;
        Ok(())
    }

    /// Reads `SO_KEEPALIVE`.
    pub fn keep_alive_enabled(&self) -> SocketResult<bool> {
        let view = &*self.as_std_view()?;
        Ok(sockopt::get_socket_keepalive(view)?)
    }

    /// Sets `SO_KEEPALIVE`.
    pub fn set_keep_alive_enabled(&self, value: bool) -> SocketResult<()> {
        let view = &*self.as_std_view()?;
        Ok(sockopt::set_socket_keepalive(view, value)?)
    }

    /// Reads the TCP keep-alive idle time.
    pub fn keep_alive_idle_time(&self) -> SocketResult<std::time::Duration> {
        let view = &*self.as_std_view()?;
        Ok(sockopt::get_tcp_keepidle(view)?)
    }

    /// Sets the TCP keep-alive idle time. On macOS the value is also
    /// remembered so `accept` can re-apply it to accepted sockets.
    pub fn set_keep_alive_idle_time(&mut self, duration: std::time::Duration) -> SocketResult<()> {
        {
            let view = &*self.as_std_view()?;
            network::util::set_tcp_keepidle(view, duration)?;
        }
        #[cfg(target_os = "macos")]
        {
            self.keep_alive_idle_time = Some(duration);
        }
        Ok(())
    }

    /// Reads the TCP keep-alive probe interval.
    pub fn keep_alive_interval(&self) -> SocketResult<std::time::Duration> {
        let view = &*self.as_std_view()?;
        Ok(sockopt::get_tcp_keepintvl(view)?)
    }

    /// Sets the TCP keep-alive probe interval.
    pub fn set_keep_alive_interval(&self, duration: std::time::Duration) -> SocketResult<()> {
        let view = &*self.as_std_view()?;
        Ok(network::util::set_tcp_keepintvl(view, duration)?)
    }

    /// Reads the TCP keep-alive probe count.
    pub fn keep_alive_count(&self) -> SocketResult<u32> {
        let view = &*self.as_std_view()?;
        Ok(sockopt::get_tcp_keepcnt(view)?)
    }

    /// Sets the TCP keep-alive probe count.
    pub fn set_keep_alive_count(&self, value: u32) -> SocketResult<()> {
        let view = &*self.as_std_view()?;
        Ok(network::util::set_tcp_keepcnt(view, value)?)
    }

    /// Reads the hop limit: `IP_TTL` for IPv4, `IPV6_UNICAST_HOPS` for IPv6.
    pub fn hop_limit(&self) -> SocketResult<u8> {
        let view = &*self.as_std_view()?;
        let ttl = match self.family {
            SocketAddressFamily::Ipv4 => network::util::get_ip_ttl(view)?,
            SocketAddressFamily::Ipv6 => network::util::get_ipv6_unicast_hops(view)?,
        };
        Ok(ttl)
    }

    /// Sets the hop limit: `IP_TTL` for IPv4, `IPV6_UNICAST_HOPS` for IPv6.
    /// On macOS the value is also remembered for `accept`.
    pub fn set_hop_limit(&mut self, value: u8) -> SocketResult<()> {
        {
            let view = &*self.as_std_view()?;
            match self.family {
                SocketAddressFamily::Ipv4 => network::util::set_ip_ttl(view, value)?,
                SocketAddressFamily::Ipv6 => network::util::set_ipv6_unicast_hops(view, value)?,
            }
        }
        #[cfg(target_os = "macos")]
        {
            self.hop_limit = Some(value);
        }
        Ok(())
    }

    /// Reads the receive buffer size.
    pub fn receive_buffer_size(&self) -> SocketResult<usize> {
        let view = &*self.as_std_view()?;
        Ok(network::util::get_socket_recv_buffer_size(view)?)
    }

    /// Sets the receive buffer size. On macOS the value is also remembered
    /// for `accept`.
    pub fn set_receive_buffer_size(&mut self, value: usize) -> SocketResult<()> {
        {
            let view = &*self.as_std_view()?;
            network::util::set_socket_recv_buffer_size(view, value)?;
        }
        #[cfg(target_os = "macos")]
        {
            self.receive_buffer_size = Some(value);
        }
        Ok(())
    }

    /// Reads the send buffer size.
    pub fn send_buffer_size(&self) -> SocketResult<usize> {
        let view = &*self.as_std_view()?;
        Ok(network::util::get_socket_send_buffer_size(view)?)
    }

    /// Sets the send buffer size. On macOS the value is also remembered for
    /// `accept`.
    pub fn set_send_buffer_size(&mut self, value: usize) -> SocketResult<()> {
        {
            let view = &*self.as_std_view()?;
            network::util::set_socket_send_buffer_size(view, value)?;
        }
        #[cfg(target_os = "macos")]
        {
            self.send_buffer_size = Some(value);
        }
        Ok(())
    }

    /// Shuts down the read and/or write half of a connected socket.
    ///
    /// # Errors
    ///
    /// `NotConnected` unless the socket is `Connected`, plus any OS error.
    pub fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> {
        let stream = match &self.tcp_state {
            TcpState::Connected(stream) => stream,
            _ => {
                return Err(io::Error::new(
                    io::ErrorKind::NotConnected,
                    "socket not connected",
                ))
            }
        };
        stream
            .as_socketlike_view::<std::net::TcpStream>()
            .shutdown(how)?;
        Ok(())
    }
}
// Readiness of a socket's in-flight asynchronous operation, if any.
#[async_trait::async_trait]
impl Subscribe for TcpSocket {
    /// Waits until the in-flight operation (if any) has a result.
    ///
    /// - `Connecting`: awaits the connect future and stores its outcome as
    ///   `ConnectReady`, to be consumed by `finish_connect`.
    /// - `Listening` without a stashed result: awaits the next `poll_accept`
    ///   outcome and stores it in `pending_accept`, to be consumed by `accept`.
    /// - Every other state (and `Listening` with a stashed result) returns
    ///   immediately.
    async fn ready(&mut self) {
        match &mut self.tcp_state {
            TcpState::Default(..)
            | TcpState::BindStarted(..)
            | TcpState::Bound(..)
            | TcpState::ListenStarted(..)
            | TcpState::ConnectReady(..)
            | TcpState::Closed
            | TcpState::Connected(..) => {
                // No asynchronous operation in progress.
            }
            TcpState::Connecting(future) => {
                self.tcp_state = TcpState::ConnectReady(future.as_mut().await);
            }
            TcpState::Listening {
                listener,
                pending_accept,
            } => match pending_accept {
                // A result is already waiting for `accept`.
                Some(_) => {}
                None => {
                    // Keep only the accepted stream; the peer address is dropped.
                    let result = futures::future::poll_fn(|cx| {
                        listener.poll_accept(cx).map_ok(|(stream, _)| stream)
                    })
                    .await;
                    *pending_accept = Some(result);
                }
            },
        }
    }
}
/// Input half of a connected TCP socket.
struct TcpReadStream {
    /// Connection shared with the owning `TcpSocket` and the write half.
    stream: Arc<tokio::net::TcpStream>,
    /// Set after EOF or a read error; later reads return `StreamError::Closed`.
    closed: bool,
}
impl TcpReadStream {
    /// Wraps a shared, connected stream; the reader starts out open.
    fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
        let closed = false;
        Self { stream, closed }
    }
}
#[async_trait::async_trait]
impl HostInputStream for TcpReadStream {
fn read(&mut self, size: usize) -> Result<bytes::Bytes, StreamError> {
if self.closed {
return Err(StreamError::Closed);
}
if size == 0 {
return Ok(bytes::Bytes::new());
}
let mut buf = bytes::BytesMut::with_capacity(size);
let n = match self.stream.try_read_buf(&mut buf) {
Ok(0) => {
self.closed = true;
return Err(StreamError::Closed);
}
Ok(n) => n,
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => 0,
Err(e) => {
self.closed = true;
return Err(StreamError::LastOperationFailed(e.into()));
}
};
buf.truncate(n);
Ok(buf.freeze())
}
}
#[async_trait::async_trait]
impl Subscribe for TcpReadStream {
    /// Resolves immediately once the stream is closed; otherwise waits for the
    /// socket to become readable. Panics if `readable()` reports an error.
    async fn ready(&mut self) {
        if !self.closed {
            self.stream.readable().await.unwrap();
        }
    }
}
/// Byte count (1 GiB) that `TcpWriteStream::check_write` reports as writable
/// whenever the socket is writable and no background write is pending.
const SOCKET_READY_SIZE: usize = 1024 * 1024 * 1024;
/// Output half of a connected TCP socket.
struct TcpWriteStream {
    /// Connection shared with the owning `TcpSocket` and the read half.
    stream: Arc<tokio::net::TcpStream>,
    /// Outcome of the most recent write; gates whether a new write may start.
    last_write: LastWrite,
}
/// State of the most recent write on a `TcpWriteStream`.
enum LastWrite {
    /// Bytes that could not be written immediately are being written by the
    /// task spawned in `background_write`.
    Waiting(AbortOnDropJoinHandle<Result<()>>),
    /// The background task failed (recorded by `Subscribe::ready`); the error
    /// is reported once by the next `check_write`.
    Error(Error),
    /// No write outstanding; a new write may start.
    Done,
    /// The stream is unusable: the peer closed it (`BrokenPipe`), a background
    /// error was already reported, or the write was cancelled.
    Closed,
}
impl TcpWriteStream {
fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
Self {
stream,
last_write: LastWrite::Done,
}
}
fn try_write_portable(stream: &tokio::net::TcpStream, buf: &[u8]) -> io::Result<usize> {
stream.try_write(buf).map_err(|error| {
match Errno::from_io_error(&error) {
#[cfg(windows)]
Some(Errno::SHUTDOWN) => io::Error::new(io::ErrorKind::BrokenPipe, error),
_ => error,
}
})
}
fn background_write(&mut self, mut bytes: bytes::Bytes) {
assert!(matches!(self.last_write, LastWrite::Done));
let stream = self.stream.clone();
self.last_write = LastWrite::Waiting(crate::runtime::spawn(async move {
while !bytes.is_empty() {
stream.writable().await?;
match Self::try_write_portable(&stream, &bytes) {
Ok(n) => {
let _ = bytes.split_to(n);
}
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => continue,
Err(e) => return Err(e.into()),
}
}
Ok(())
}));
}
}
#[async_trait::async_trait]
impl HostOutputStream for TcpWriteStream {
fn write(&mut self, mut bytes: bytes::Bytes) -> Result<(), StreamError> {
match self.last_write {
LastWrite::Done => {}
LastWrite::Waiting(_) | LastWrite::Error(_) | LastWrite::Closed => {
return Err(StreamError::Trap(anyhow::anyhow!(
"unpermitted: must call check_write first"
)));
}
}
while !bytes.is_empty() {
match Self::try_write_portable(&self.stream, &bytes) {
Ok(n) => {
let _ = bytes.split_to(n);
}
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
self.background_write(bytes);
return Ok(());
}
Err(e) if e.kind() == std::io::ErrorKind::BrokenPipe => {
self.last_write = LastWrite::Closed;
return Err(StreamError::Closed);
}
Err(e) => return Err(StreamError::LastOperationFailed(e.into())),
}
}
Ok(())
}
fn flush(&mut self) -> Result<(), StreamError> {
match self.last_write {
LastWrite::Done | LastWrite::Waiting(_) | LastWrite::Error(_) => Ok(()),
LastWrite::Closed => Err(StreamError::Closed),
}
}
fn check_write(&mut self) -> Result<usize, StreamError> {
match mem::replace(&mut self.last_write, LastWrite::Closed) {
LastWrite::Waiting(task) => {
self.last_write = LastWrite::Waiting(task);
return Ok(0);
}
LastWrite::Done => {
self.last_write = LastWrite::Done;
}
LastWrite::Closed => return Err(StreamError::Closed),
LastWrite::Error(e) => return Err(StreamError::LastOperationFailed(e.into())),
}
let writable = self.stream.writable();
futures::pin_mut!(writable);
if crate::runtime::poll_noop(writable).is_none() {
return Ok(0);
}
Ok(SOCKET_READY_SIZE)
}
async fn cancel(&mut self) {
match mem::replace(&mut self.last_write, LastWrite::Closed) {
LastWrite::Waiting(task) => _ = task.abort_wait().await,
_ => {}
}
}
}
#[async_trait::async_trait]
impl Subscribe for TcpWriteStream {
    /// Waits until a write may be attempted.
    ///
    /// If a background write is pending, waits for it and records its outcome
    /// (`Done` or `Error`). Then, if no write is outstanding, waits for the
    /// socket to become writable.
    async fn ready(&mut self) {
        if let LastWrite::Waiting(task) = &mut self.last_write {
            self.last_write = match task.await {
                Ok(()) => LastWrite::Done,
                Err(e) => LastWrite::Error(e),
            };
        }
        // NOTE(review): this panics if `writable()` returns an error.
        if let LastWrite::Done = self.last_write {
            self.stream.writable().await.unwrap();
        }
    }
}