1use crate::p2::{
2 DynInputStream, DynOutputStream, InputStream, OutputStream, Pollable, SocketError,
3 SocketResult, StreamError,
4};
5use crate::runtime::AbortOnDropJoinHandle;
6use crate::sockets::TcpSocket;
7use anyhow::Result;
8use io_lifetimes::AsSocketlike;
9use rustix::io::Errno;
10use std::io;
11use std::mem;
12use std::net::Shutdown;
13use std::sync::Arc;
14use tokio::sync::Mutex;
15
16impl TcpSocket {
17 pub(crate) fn p2_streams(&mut self) -> SocketResult<(DynInputStream, DynOutputStream)> {
18 let client = self.tcp_stream_arc()?;
19 let reader = Arc::new(Mutex::new(TcpReader::new(client.clone())));
20 let writer = Arc::new(Mutex::new(TcpWriter::new(client.clone())));
21 self.set_p2_streaming_state(P2TcpStreamingState {
22 stream: client.clone(),
23 reader: reader.clone(),
24 writer: writer.clone(),
25 })?;
26 let input: DynInputStream = Box::new(TcpReadStream(reader));
27 let output: DynOutputStream = Box::new(TcpWriteStream(writer));
28 Ok((input, output))
29 }
30}
31
/// State recorded on a `TcpSocket` once its preview 2 streams have been created.
///
/// Keeps handles to the same reader/writer halves that back the returned
/// `TcpReadStream`/`TcpWriteStream`, so the socket can shut them down directly.
pub(crate) struct P2TcpStreamingState {
    /// The connected tokio stream shared by the reader and writer halves.
    pub(crate) stream: Arc<tokio::net::TcpStream>,
    /// Reader half, shared with the `TcpReadStream` handed out by `p2_streams`.
    reader: Arc<Mutex<TcpReader>>,
    /// Writer half, shared with the `TcpWriteStream` handed out by `p2_streams`.
    writer: Arc<Mutex<TcpWriter>>,
}
37
38impl P2TcpStreamingState {
39 pub(crate) fn shutdown(&self, how: Shutdown) -> SocketResult<()> {
40 if let Shutdown::Both | Shutdown::Read = how {
41 try_lock_for_socket(&self.reader)?.shutdown();
42 }
43
44 if let Shutdown::Both | Shutdown::Write = how {
45 try_lock_for_socket(&self.writer)?.shutdown();
46 }
47
48 Ok(())
49 }
50}
51
/// Read half of a TCP connection exposed as a preview 2 input stream.
struct TcpReader {
    /// The connected stream, shared with the writer half.
    stream: Arc<tokio::net::TcpStream>,
    /// Set once EOF, a read error, or a read-side shutdown has been observed; all later
    /// reads then report `StreamError::Closed`.
    closed: bool,
}
56
57impl TcpReader {
58 fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
59 Self {
60 stream,
61 closed: false,
62 }
63 }
64 fn read(&mut self, size: usize) -> Result<bytes::Bytes, StreamError> {
65 if self.closed {
66 return Err(StreamError::Closed);
67 }
68 if size == 0 {
69 return Ok(bytes::Bytes::new());
70 }
71
72 let mut buf = bytes::BytesMut::with_capacity(size);
73 let n = match self.stream.try_read_buf(&mut buf) {
74 Ok(0) => {
76 self.closed = true;
77 return Err(StreamError::Closed);
78 }
79 Ok(n) => n,
80
81 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => 0,
84
85 Err(e) => {
86 self.closed = true;
87 return Err(StreamError::LastOperationFailed(e.into()));
88 }
89 };
90
91 buf.truncate(n);
92 Ok(buf.freeze())
93 }
94
95 fn shutdown(&mut self) {
96 native_shutdown(&self.stream, Shutdown::Read);
97 self.closed = true;
98 }
99
100 async fn ready(&mut self) {
101 if self.closed {
102 return;
103 }
104
105 self.stream.readable().await.unwrap();
106 }
107}
108
109struct TcpReadStream(Arc<Mutex<TcpReader>>);
110
111#[async_trait::async_trait]
112impl InputStream for TcpReadStream {
113 fn read(&mut self, size: usize) -> Result<bytes::Bytes, StreamError> {
114 try_lock_for_stream(&self.0)?.read(size)
115 }
116}
117
118#[async_trait::async_trait]
119impl Pollable for TcpReadStream {
120 async fn ready(&mut self) {
121 self.0.lock().await.ready().await
122 }
123}
124
/// Write permit (1 GiB) granted by `TcpWriter::check_write` when the socket is idle and
/// writable; bytes that do not fit in the socket buffer are finished by a background task.
const SOCKET_READY_SIZE: usize = 1024 * 1024 * 1024;
126
/// Write half of a TCP connection exposed as a preview 2 output stream.
struct TcpWriter {
    /// The connected stream, shared with the reader half.
    stream: Arc<tokio::net::TcpStream>,
    /// Current position in the write state machine.
    state: WriteState,
}
131
/// State machine driven by `TcpWriter`.
enum WriteState {
    /// Idle; `check_write` may grant a new write permit.
    Ready,
    /// A background task is writing bytes that did not fit in the socket buffer.
    Writing(AbortOnDropJoinHandle<io::Result<()>>),
    /// A background write is finishing; the task shuts down the write half afterwards.
    Closing(AbortOnDropJoinHandle<io::Result<()>>),
    /// The write half is closed; writes report `StreamError::Closed`.
    Closed,
    /// A background write failed; the next `check_write` reports it and closes the writer.
    Error(io::Error),
}
139
140impl TcpWriter {
141 fn new(stream: Arc<tokio::net::TcpStream>) -> Self {
142 Self {
143 stream,
144 state: WriteState::Ready,
145 }
146 }
147
148 fn try_write_portable(stream: &tokio::net::TcpStream, buf: &[u8]) -> io::Result<usize> {
149 stream.try_write(buf).map_err(|error| {
150 match Errno::from_io_error(&error) {
151 #[cfg(windows)]
155 Some(Errno::SHUTDOWN) => io::Error::new(io::ErrorKind::BrokenPipe, error),
156
157 _ => error,
158 }
159 })
160 }
161
162 fn background_write(&mut self, mut bytes: bytes::Bytes) {
165 assert!(matches!(self.state, WriteState::Ready));
166
167 let stream = self.stream.clone();
168 self.state = WriteState::Writing(crate::runtime::spawn(async move {
169 while !bytes.is_empty() {
175 stream.writable().await?;
176 match Self::try_write_portable(&stream, &bytes) {
177 Ok(n) => {
178 let _ = bytes.split_to(n);
179 }
180 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => continue,
181 Err(e) => return Err(e),
182 }
183 }
184
185 Ok(())
186 }));
187 }
188
189 fn write(&mut self, mut bytes: bytes::Bytes) -> Result<(), StreamError> {
190 match self.state {
191 WriteState::Ready => {}
192 WriteState::Closed => return Err(StreamError::Closed),
193 WriteState::Writing(_) | WriteState::Closing(_) | WriteState::Error(_) => {
194 return Err(StreamError::Trap(anyhow::anyhow!(
195 "unpermitted: must call check_write first"
196 )));
197 }
198 }
199 while !bytes.is_empty() {
200 match Self::try_write_portable(&self.stream, &bytes) {
201 Ok(n) => {
202 let _ = bytes.split_to(n);
203 }
204
205 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
206 self.background_write(bytes);
209
210 return Ok(());
211 }
212
213 Err(e) if e.kind() == std::io::ErrorKind::BrokenPipe => {
214 self.state = WriteState::Closed;
215 return Err(StreamError::Closed);
216 }
217
218 Err(e) => return Err(StreamError::LastOperationFailed(e.into())),
219 }
220 }
221
222 Ok(())
223 }
224
225 fn flush(&mut self) -> Result<(), StreamError> {
226 match self.state {
230 WriteState::Ready
231 | WriteState::Writing(_)
232 | WriteState::Closing(_)
233 | WriteState::Error(_) => Ok(()),
234 WriteState::Closed => Err(StreamError::Closed),
235 }
236 }
237
238 fn check_write(&mut self) -> Result<usize, StreamError> {
239 match mem::replace(&mut self.state, WriteState::Closed) {
240 WriteState::Writing(task) => {
241 self.state = WriteState::Writing(task);
242 return Ok(0);
243 }
244 WriteState::Closing(task) => {
245 self.state = WriteState::Closing(task);
246 return Ok(0);
247 }
248 WriteState::Ready => {
249 self.state = WriteState::Ready;
250 }
251 WriteState::Closed => return Err(StreamError::Closed),
252 WriteState::Error(e) => return Err(StreamError::LastOperationFailed(e.into())),
253 }
254
255 let writable = self.stream.writable();
256 futures::pin_mut!(writable);
257 if crate::runtime::poll_noop(writable).is_none() {
258 return Ok(0);
259 }
260 Ok(SOCKET_READY_SIZE)
261 }
262
263 fn shutdown(&mut self) {
264 self.state = match mem::replace(&mut self.state, WriteState::Closed) {
265 WriteState::Ready => {
267 native_shutdown(&self.stream, Shutdown::Write);
268 WriteState::Closed
269 }
270
271 WriteState::Writing(write) => {
273 let stream = self.stream.clone();
274 WriteState::Closing(crate::runtime::spawn(async move {
275 let result = write.await;
276 native_shutdown(&stream, Shutdown::Write);
277 result
278 }))
279 }
280
281 s => s,
282 };
283 }
284
285 async fn cancel(&mut self) {
286 match mem::replace(&mut self.state, WriteState::Closed) {
287 WriteState::Writing(task) | WriteState::Closing(task) => _ = task.cancel().await,
288 _ => {}
289 }
290 }
291
292 async fn ready(&mut self) {
293 match &mut self.state {
294 WriteState::Writing(task) => {
295 self.state = match task.await {
296 Ok(()) => WriteState::Ready,
297 Err(e) => WriteState::Error(e),
298 }
299 }
300 WriteState::Closing(task) => {
301 self.state = match task.await {
302 Ok(()) => WriteState::Closed,
303 Err(e) => WriteState::Error(e),
304 }
305 }
306 _ => {}
307 }
308
309 if let WriteState::Ready = self.state {
310 self.stream.writable().await.unwrap();
311 }
312 }
313}
314
315struct TcpWriteStream(Arc<Mutex<TcpWriter>>);
316
317#[async_trait::async_trait]
318impl OutputStream for TcpWriteStream {
319 fn write(&mut self, bytes: bytes::Bytes) -> Result<(), StreamError> {
320 try_lock_for_stream(&self.0)?.write(bytes)
321 }
322
323 fn flush(&mut self) -> Result<(), StreamError> {
324 try_lock_for_stream(&self.0)?.flush()
325 }
326
327 fn check_write(&mut self) -> Result<usize, StreamError> {
328 try_lock_for_stream(&self.0)?.check_write()
329 }
330
331 async fn cancel(&mut self) {
332 self.0.lock().await.cancel().await
333 }
334}
335
336#[async_trait::async_trait]
337impl Pollable for TcpWriteStream {
338 async fn ready(&mut self) {
339 self.0.lock().await.ready().await
340 }
341}
342
343fn native_shutdown(stream: &tokio::net::TcpStream, how: Shutdown) {
344 _ = stream
345 .as_socketlike_view::<std::net::TcpStream>()
346 .shutdown(how);
347}
348
349fn try_lock_for_stream<T>(mutex: &Mutex<T>) -> Result<tokio::sync::MutexGuard<'_, T>, StreamError> {
350 mutex
351 .try_lock()
352 .map_err(|_| StreamError::trap("concurrent access to resource not supported"))
353}
354
355fn try_lock_for_socket<T>(mutex: &Mutex<T>) -> SocketResult<tokio::sync::MutexGuard<'_, T>> {
356 mutex.try_lock().map_err(|_| {
357 SocketError::trap(anyhow::anyhow!(
358 "concurrent access to resource not supported"
359 ))
360 })
361}