wasmtime_wasi/p2/
tcp.rs

1use crate::p2::{
2    DynInputStream, DynOutputStream, InputStream, OutputStream, Pollable, SocketError,
3    SocketResult, StreamError,
4};
5use crate::runtime::AbortOnDropJoinHandle;
6use crate::sockets::TcpSocket;
7use anyhow::Result;
8use io_lifetimes::AsSocketlike;
9use rustix::io::Errno;
10use std::io;
11use std::mem;
12use std::net::Shutdown;
13use std::sync::Arc;
14use tokio::sync::Mutex;
15
16impl TcpSocket {
17    pub(crate) fn p2_streams(&mut self) -> SocketResult<(DynInputStream, DynOutputStream)> {
18        let client = self.tcp_stream_arc()?;
19        let reader = Arc::new(Mutex::new(TcpReader::new(client.clone())));
20        let writer = Arc::new(Mutex::new(TcpWriter::new(client.clone())));
21        self.set_p2_streaming_state(P2TcpStreamingState {
22            stream: client.clone(),
23            reader: reader.clone(),
24            writer: writer.clone(),
25        })?;
26        let input: DynInputStream = Box::new(TcpReadStream(reader));
27        let output: DynOutputStream = Box::new(TcpWriteStream(writer));
28        Ok((input, output))
29    }
30}
31
32pub(crate) struct P2TcpStreamingState {
33    pub(crate) stream: Arc<tokio::net::TcpStream>,
34    reader: Arc<Mutex<TcpReader>>,
35    writer: Arc<Mutex<TcpWriter>>,
36}
37
38impl P2TcpStreamingState {
39    pub(crate) fn shutdown(&self, how: Shutdown) -> SocketResult<()> {
40        if let Shutdown::Both | Shutdown::Read = how {
41            try_lock_for_socket(&self.reader)?.shutdown();
42        }
43
44        if let Shutdown::Both | Shutdown::Write = how {
45            try_lock_for_socket(&self.writer)?.shutdown();
46        }
47
48        Ok(())
49    }
50}
51
52struct TcpReader {
53    stream: Arc<tokio::net::TcpStream>,
54    closed: bool,
55}
56
57impl TcpReader {
58    fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
59        Self {
60            stream,
61            closed: false,
62        }
63    }
64    fn read(&mut self, size: usize) -> Result<bytes::Bytes, StreamError> {
65        if self.closed {
66            return Err(StreamError::Closed);
67        }
68        if size == 0 {
69            return Ok(bytes::Bytes::new());
70        }
71
72        let mut buf = bytes::BytesMut::with_capacity(size);
73        let n = match self.stream.try_read_buf(&mut buf) {
74            // A 0-byte read indicates that the stream has closed.
75            Ok(0) => {
76                self.closed = true;
77                return Err(StreamError::Closed);
78            }
79            Ok(n) => n,
80
81            // Failing with `EWOULDBLOCK` is how we differentiate between a closed channel and no
82            // data to read right now.
83            Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => 0,
84
85            Err(e) => {
86                self.closed = true;
87                return Err(StreamError::LastOperationFailed(e.into()));
88            }
89        };
90
91        buf.truncate(n);
92        Ok(buf.freeze())
93    }
94
95    fn shutdown(&mut self) {
96        native_shutdown(&self.stream, Shutdown::Read);
97        self.closed = true;
98    }
99
100    async fn ready(&mut self) {
101        if self.closed {
102            return;
103        }
104
105        self.stream.readable().await.unwrap();
106    }
107}
108
109struct TcpReadStream(Arc<Mutex<TcpReader>>);
110
111#[async_trait::async_trait]
112impl InputStream for TcpReadStream {
113    fn read(&mut self, size: usize) -> Result<bytes::Bytes, StreamError> {
114        try_lock_for_stream(&self.0)?.read(size)
115    }
116}
117
118#[async_trait::async_trait]
119impl Pollable for TcpReadStream {
120    async fn ready(&mut self) {
121        self.0.lock().await.ready().await
122    }
123}
124
/// Byte count `TcpWriter::check_write` reports when the socket is writable: 1 GiB.
const SOCKET_READY_SIZE: usize = 1 << 30;
126
127struct TcpWriter {
128    stream: Arc<tokio::net::TcpStream>,
129    state: WriteState,
130}
131
132enum WriteState {
133    Ready,
134    Writing(AbortOnDropJoinHandle<io::Result<()>>),
135    Closing(AbortOnDropJoinHandle<io::Result<()>>),
136    Closed,
137    Error(io::Error),
138}
139
140impl TcpWriter {
141    fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
142        Self {
143            stream,
144            state: WriteState::Ready,
145        }
146    }
147
148    fn try_write_portable(stream: &tokio::net::TcpStream, buf: &[u8]) -> io::Result<usize> {
149        stream.try_write(buf).map_err(|error| {
150            match Errno::from_io_error(&error) {
151                // Windows returns `WSAESHUTDOWN` when writing to a shut down socket.
152                // We normalize this to EPIPE, because that is what the other platforms return.
153                // See: https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-send#:~:text=WSAESHUTDOWN
154                #[cfg(windows)]
155                Some(Errno::SHUTDOWN) => io::Error::new(io::ErrorKind::BrokenPipe, error),
156
157                _ => error,
158            }
159        })
160    }
161
162    /// Write `bytes` in a background task, remembering the task handle for use in a future call to
163    /// `write_ready`
164    fn background_write(&mut self, mut bytes: bytes::Bytes) {
165        assert!(matches!(self.state, WriteState::Ready));
166
167        let stream = self.stream.clone();
168        self.state = WriteState::Writing(crate::runtime::spawn(async move {
169            // Note: we are not using the AsyncWrite impl here, and instead using the TcpStream
170            // primitive try_write, which goes directly to attempt a write with mio. This has
171            // two advantages: 1. this operation takes a &TcpStream instead of a &mut TcpStream
172            // required to AsyncWrite, and 2. it eliminates any buffering in tokio we may need
173            // to flush.
174            while !bytes.is_empty() {
175                stream.writable().await?;
176                match Self::try_write_portable(&stream, &bytes) {
177                    Ok(n) => {
178                        let _ = bytes.split_to(n);
179                    }
180                    Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => continue,
181                    Err(e) => return Err(e),
182                }
183            }
184
185            Ok(())
186        }));
187    }
188
189    fn write(&mut self, mut bytes: bytes::Bytes) -> Result<(), StreamError> {
190        match self.state {
191            WriteState::Ready => {}
192            WriteState::Closed => return Err(StreamError::Closed),
193            WriteState::Writing(_) | WriteState::Closing(_) | WriteState::Error(_) => {
194                return Err(StreamError::Trap(anyhow::anyhow!(
195                    "unpermitted: must call check_write first"
196                )));
197            }
198        }
199        while !bytes.is_empty() {
200            match Self::try_write_portable(&self.stream, &bytes) {
201                Ok(n) => {
202                    let _ = bytes.split_to(n);
203                }
204
205                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
206                    // As `try_write` indicated that it would have blocked, we'll perform the write
207                    // in the background to allow us to return immediately.
208                    self.background_write(bytes);
209
210                    return Ok(());
211                }
212
213                Err(e) if e.kind() == std::io::ErrorKind::BrokenPipe => {
214                    self.state = WriteState::Closed;
215                    return Err(StreamError::Closed);
216                }
217
218                Err(e) => return Err(StreamError::LastOperationFailed(e.into())),
219            }
220        }
221
222        Ok(())
223    }
224
225    fn flush(&mut self) -> Result<(), StreamError> {
226        // `flush` is a no-op here, as we're not managing any internal buffer. Additionally,
227        // `write_ready` will join the background write task if it's active, so following `flush`
228        // with `write_ready` will have the desired effect.
229        match self.state {
230            WriteState::Ready
231            | WriteState::Writing(_)
232            | WriteState::Closing(_)
233            | WriteState::Error(_) => Ok(()),
234            WriteState::Closed => Err(StreamError::Closed),
235        }
236    }
237
238    fn check_write(&mut self) -> Result<usize, StreamError> {
239        match mem::replace(&mut self.state, WriteState::Closed) {
240            WriteState::Writing(task) => {
241                self.state = WriteState::Writing(task);
242                return Ok(0);
243            }
244            WriteState::Closing(task) => {
245                self.state = WriteState::Closing(task);
246                return Ok(0);
247            }
248            WriteState::Ready => {
249                self.state = WriteState::Ready;
250            }
251            WriteState::Closed => return Err(StreamError::Closed),
252            WriteState::Error(e) => return Err(StreamError::LastOperationFailed(e.into())),
253        }
254
255        let writable = self.stream.writable();
256        futures::pin_mut!(writable);
257        if crate::runtime::poll_noop(writable).is_none() {
258            return Ok(0);
259        }
260        Ok(SOCKET_READY_SIZE)
261    }
262
263    fn shutdown(&mut self) {
264        self.state = match mem::replace(&mut self.state, WriteState::Closed) {
265            // No write in progress, immediately shut down:
266            WriteState::Ready => {
267                native_shutdown(&self.stream, Shutdown::Write);
268                WriteState::Closed
269            }
270
271            // Schedule the shutdown after the current write has finished:
272            WriteState::Writing(write) => {
273                let stream = self.stream.clone();
274                WriteState::Closing(crate::runtime::spawn(async move {
275                    let result = write.await;
276                    native_shutdown(&stream, Shutdown::Write);
277                    result
278                }))
279            }
280
281            s => s,
282        };
283    }
284
285    async fn cancel(&mut self) {
286        match mem::replace(&mut self.state, WriteState::Closed) {
287            WriteState::Writing(task) | WriteState::Closing(task) => _ = task.cancel().await,
288            _ => {}
289        }
290    }
291
292    async fn ready(&mut self) {
293        match &mut self.state {
294            WriteState::Writing(task) => {
295                self.state = match task.await {
296                    Ok(()) => WriteState::Ready,
297                    Err(e) => WriteState::Error(e),
298                }
299            }
300            WriteState::Closing(task) => {
301                self.state = match task.await {
302                    Ok(()) => WriteState::Closed,
303                    Err(e) => WriteState::Error(e),
304                }
305            }
306            _ => {}
307        }
308
309        if let WriteState::Ready = self.state {
310            self.stream.writable().await.unwrap();
311        }
312    }
313}
314
315struct TcpWriteStream(Arc<Mutex<TcpWriter>>);
316
317#[async_trait::async_trait]
318impl OutputStream for TcpWriteStream {
319    fn write(&mut self, bytes: bytes::Bytes) -> Result<(), StreamError> {
320        try_lock_for_stream(&self.0)?.write(bytes)
321    }
322
323    fn flush(&mut self) -> Result<(), StreamError> {
324        try_lock_for_stream(&self.0)?.flush()
325    }
326
327    fn check_write(&mut self) -> Result<usize, StreamError> {
328        try_lock_for_stream(&self.0)?.check_write()
329    }
330
331    async fn cancel(&mut self) {
332        self.0.lock().await.cancel().await
333    }
334}
335
336#[async_trait::async_trait]
337impl Pollable for TcpWriteStream {
338    async fn ready(&mut self) {
339        self.0.lock().await.ready().await
340    }
341}
342
343fn native_shutdown(stream: &tokio::net::TcpStream, how: Shutdown) {
344    _ = stream
345        .as_socketlike_view::<std::net::TcpStream>()
346        .shutdown(how);
347}
348
349fn try_lock_for_stream<T>(mutex: &Mutex<T>) -> Result<tokio::sync::MutexGuard<'_, T>, StreamError> {
350    mutex
351        .try_lock()
352        .map_err(|_| StreamError::trap("concurrent access to resource not supported"))
353}
354
355fn try_lock_for_socket<T>(mutex: &Mutex<T>) -> SocketResult<tokio::sync::MutexGuard<'_, T>> {
356    mutex.try_lock().map_err(|_| {
357        SocketError::trap(anyhow::anyhow!(
358            "concurrent access to resource not supported"
359        ))
360    })
361}