1use core::fmt;
2use core::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
3use core::str::FromStr as _;
4use core::time::Duration;
5
6use cap_net_ext::{AddressFamily, Blocking, UdpSocketExt};
7use rustix::fd::AsFd;
8use rustix::io::Errno;
9use rustix::net::{bind, connect_unspec, sockopt};
10use tracing::debug;
11
12use crate::sockets::SocketAddressFamily;
13
/// Socket-level error codes produced by the helpers in this module.
///
/// Host I/O errors ([`std::io::Error`]) and raw OS errors ([`Errno`]) are
/// mapped onto these variants by the `From` conversions in this module; codes
/// without a dedicated mapping become [`ErrorCode::Unknown`].
///
/// The type is a plain fieldless enum, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ErrorCode {
    Unknown,
    AccessDenied,
    NotSupported,
    InvalidArgument,
    OutOfMemory,
    Timeout,
    InvalidState,
    AddressNotBindable,
    AddressInUse,
    RemoteUnreachable,
    ConnectionRefused,
    ConnectionReset,
    ConnectionAborted,
    DatagramTooLarge,
    NotInProgress,
    ConcurrencyConflict,
}

/// Formats the error as its variant name (identical to the `Debug` output).
impl fmt::Display for ErrorCode {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(self, f)
    }
}

impl std::error::Error for ErrorCode {}
41
/// Reports whether `addr` is a deprecated "IPv4-compatible" IPv6 address,
/// i.e. one whose upper 96 bits are all zero (`::a.b.c.d`).
///
/// The unspecified address `::` and the loopback address `::1` also have zero
/// upper bits but are excluded: they are ordinary IPv6 addresses.
fn is_deprecated_ipv4_compatible(addr: Ipv6Addr) -> bool {
    let [s0, s1, s2, s3, s4, s5, _, _] = addr.segments();
    let upper_96_bits_zero = [s0, s1, s2, s3, s4, s5].iter().all(|&seg| seg == 0);
    upper_96_bits_zero && !addr.is_unspecified() && !addr.is_loopback()
}
47
48pub fn is_valid_address_family(addr: IpAddr, socket_family: SocketAddressFamily) -> bool {
49 match (socket_family, addr) {
50 (SocketAddressFamily::Ipv4, IpAddr::V4(..)) => true,
51 (SocketAddressFamily::Ipv6, IpAddr::V6(ipv6)) => {
52 !is_deprecated_ipv4_compatible(ipv6) && ipv6.to_ipv4_mapped().is_none()
57 }
58 _ => false,
59 }
60}
61
/// Returns `true` if `addr` is usable as a remote peer address.
///
/// The port must be non-zero, and the IP must not be unspecified. An
/// IPv4-mapped IPv6 address is canonicalized to IPv4 first, so
/// `::ffff:0.0.0.0` is rejected too.
pub fn is_valid_remote_address(addr: SocketAddr) -> bool {
    if addr.port() == 0 {
        return false;
    }
    let canonical_ip = addr.ip().to_canonical();
    !canonical_ip.is_unspecified()
}
65
/// Returns `true` unless `addr` is a multicast address or the IPv4 limited
/// broadcast address (`255.255.255.255`).
///
/// IPv4-mapped IPv6 addresses are canonicalized to IPv4 before the check.
pub fn is_valid_unicast_address(addr: IpAddr) -> bool {
    let is_group_address = match addr.to_canonical() {
        IpAddr::V4(v4) => v4.is_multicast() || v4.is_broadcast(),
        IpAddr::V6(v6) => v6.is_multicast(),
    };
    !is_group_address
}
71}
72
/// Builds an [`Ipv4Addr`] from a tuple of its four octets, most significant first.
pub fn to_ipv4_addr(addr: (u8, u8, u8, u8)) -> Ipv4Addr {
    Ipv4Addr::from([addr.0, addr.1, addr.2, addr.3])
}
77
/// Splits an [`Ipv4Addr`] into a tuple of its four octets, most significant first.
pub fn from_ipv4_addr(addr: Ipv4Addr) -> (u8, u8, u8, u8) {
    let octets = addr.octets();
    (octets[0], octets[1], octets[2], octets[3])
}
82
/// Builds an [`Ipv6Addr`] from a tuple of its eight 16-bit segments, most
/// significant first.
pub fn to_ipv6_addr(addr: (u16, u16, u16, u16, u16, u16, u16, u16)) -> Ipv6Addr {
    Ipv6Addr::from([
        addr.0, addr.1, addr.2, addr.3, addr.4, addr.5, addr.6, addr.7,
    ])
}
87
/// Splits an [`Ipv6Addr`] into a tuple of its eight 16-bit segments, most
/// significant first.
pub fn from_ipv6_addr(addr: Ipv6Addr) -> (u16, u16, u16, u16, u16, u16, u16, u16) {
    let seg = addr.segments();
    (seg[0], seg[1], seg[2], seg[3], seg[4], seg[5], seg[6], seg[7])
}
92
/// Converts a socket buffer size read back from the OS into the units the
/// application used when setting it.
///
/// On Linux the reported value is halved; on every other target it is returned
/// unchanged. NOTE: per the Linux socket(7) man page, the kernel doubles
/// `SO_RCVBUF`/`SO_SNDBUF` on set and reports the doubled value on get.
pub fn normalize_get_buffer_size(value: usize) -> usize {
    let kernel_scale: usize = if cfg!(target_os = "linux") { 2 } else { 1 };
    value / kernel_scale
}
109
/// Clamps a requested socket buffer size into `1..=i32::MAX`.
///
/// The upper bound is presumably there because the socket option is carried as a
/// C `int` — TODO confirm.
pub fn normalize_set_buffer_size(value: usize) -> usize {
    const LARGEST: usize = i32::MAX as usize;
    value.max(1).min(LARGEST)
}
113
114impl From<std::io::Error> for ErrorCode {
115 fn from(value: std::io::Error) -> Self {
116 (&value).into()
117 }
118}
119
120impl From<&std::io::Error> for ErrorCode {
121 fn from(value: &std::io::Error) -> Self {
122 if let Some(errno) = Errno::from_io_error(value) {
124 return errno.into();
125 }
126
127 match value.kind() {
128 std::io::ErrorKind::AddrInUse => Self::AddressInUse,
129 std::io::ErrorKind::AddrNotAvailable => Self::AddressNotBindable,
130 std::io::ErrorKind::ConnectionAborted => Self::ConnectionAborted,
131 std::io::ErrorKind::ConnectionRefused => Self::ConnectionRefused,
132 std::io::ErrorKind::ConnectionReset => Self::ConnectionReset,
133 std::io::ErrorKind::InvalidInput => Self::InvalidArgument,
134 std::io::ErrorKind::NotConnected => Self::InvalidState,
135 std::io::ErrorKind::OutOfMemory => Self::OutOfMemory,
136 std::io::ErrorKind::PermissionDenied => Self::AccessDenied,
137 std::io::ErrorKind::TimedOut => Self::Timeout,
138 std::io::ErrorKind::Unsupported => Self::NotSupported,
139 _ => {
140 debug!("unknown I/O error: {value}");
141 Self::Unknown
142 }
143 }
144 }
145}
146
147impl From<Errno> for ErrorCode {
148 fn from(value: Errno) -> Self {
149 (&value).into()
150 }
151}
152
153impl From<&Errno> for ErrorCode {
154 fn from(value: &Errno) -> Self {
155 match *value {
156 #[cfg(not(windows))]
157 Errno::PERM => Self::AccessDenied,
158 Errno::ACCESS => Self::AccessDenied,
159 Errno::ADDRINUSE => Self::AddressInUse,
160 Errno::ADDRNOTAVAIL => Self::AddressNotBindable,
161 Errno::TIMEDOUT => Self::Timeout,
162 Errno::CONNREFUSED => Self::ConnectionRefused,
163 Errno::CONNRESET => Self::ConnectionReset,
164 Errno::CONNABORTED => Self::ConnectionAborted,
165 Errno::INVAL => Self::InvalidArgument,
166 Errno::HOSTUNREACH => Self::RemoteUnreachable,
167 Errno::HOSTDOWN => Self::RemoteUnreachable,
168 Errno::NETDOWN => Self::RemoteUnreachable,
169 Errno::NETUNREACH => Self::RemoteUnreachable,
170 #[cfg(target_os = "linux")]
171 Errno::NONET => Self::RemoteUnreachable,
172 Errno::ISCONN => Self::InvalidState,
173 Errno::NOTCONN => Self::InvalidState,
174 Errno::DESTADDRREQ => Self::InvalidState,
175 Errno::MSGSIZE => Self::DatagramTooLarge,
176 #[cfg(not(windows))]
177 Errno::NOMEM => Self::OutOfMemory,
178 Errno::NOBUFS => Self::OutOfMemory,
179 Errno::OPNOTSUPP => Self::NotSupported,
180 Errno::NOPROTOOPT => Self::NotSupported,
181 Errno::PFNOSUPPORT => Self::NotSupported,
182 Errno::PROTONOSUPPORT => Self::NotSupported,
183 Errno::PROTOTYPE => Self::NotSupported,
184 Errno::SOCKTNOSUPPORT => Self::NotSupported,
185 Errno::AFNOSUPPORT => Self::NotSupported,
186
187 _ => {
189 debug!("unknown I/O error: {value}");
190 Self::Unknown
191 }
192 }
193 }
194}
195
196pub fn get_ip_ttl(fd: impl AsFd) -> Result<u8, ErrorCode> {
197 let v = sockopt::ip_ttl(fd)?;
198 let Ok(v) = v.try_into() else {
199 return Err(ErrorCode::NotSupported);
200 };
201 Ok(v)
202}
203
204pub fn get_ipv6_unicast_hops(fd: impl AsFd) -> Result<u8, ErrorCode> {
205 let v = sockopt::ipv6_unicast_hops(fd)?;
206 Ok(v)
207}
208
209pub fn get_unicast_hop_limit(fd: impl AsFd, family: SocketAddressFamily) -> Result<u8, ErrorCode> {
210 match family {
211 SocketAddressFamily::Ipv4 => get_ip_ttl(fd),
212 SocketAddressFamily::Ipv6 => get_ipv6_unicast_hops(fd),
213 }
214}
215
216pub fn set_unicast_hop_limit(
217 fd: impl AsFd,
218 family: SocketAddressFamily,
219 value: u8,
220) -> Result<(), ErrorCode> {
221 if value == 0 {
222 return Err(ErrorCode::InvalidArgument);
228 }
229 match family {
230 SocketAddressFamily::Ipv4 => {
231 sockopt::set_ip_ttl(fd, value.into())?;
232 }
233 SocketAddressFamily::Ipv6 => {
234 sockopt::set_ipv6_unicast_hops(fd, Some(value))?;
235 }
236 }
237 Ok(())
238}
239
240pub fn receive_buffer_size(fd: impl AsFd) -> Result<u64, ErrorCode> {
241 let v = sockopt::socket_recv_buffer_size(fd)?;
242 Ok(normalize_get_buffer_size(v).try_into().unwrap_or(u64::MAX))
243}
244
245pub fn set_receive_buffer_size(fd: impl AsFd, value: u64) -> Result<usize, ErrorCode> {
246 if value == 0 {
247 return Err(ErrorCode::InvalidArgument);
249 }
250 let value = value.try_into().unwrap_or(usize::MAX);
251 let value = normalize_set_buffer_size(value);
252 match sockopt::set_socket_recv_buffer_size(fd, value) {
253 Err(Errno::NOBUFS) => {}
264 Err(err) => return Err(err.into()),
265 _ => {}
266 };
267 Ok(value)
268}
269
270pub fn send_buffer_size(fd: impl AsFd) -> Result<u64, ErrorCode> {
271 let v = sockopt::socket_send_buffer_size(fd)?;
272 Ok(normalize_get_buffer_size(v).try_into().unwrap_or(u64::MAX))
273}
274
275pub fn set_send_buffer_size(fd: impl AsFd, value: u64) -> Result<usize, ErrorCode> {
276 if value == 0 {
277 return Err(ErrorCode::InvalidArgument);
279 }
280 let value = value.try_into().unwrap_or(usize::MAX);
281 let value = normalize_set_buffer_size(value);
282 match sockopt::set_socket_send_buffer_size(fd, value) {
283 Err(Errno::NOBUFS) => {}
284 Err(err) => return Err(err.into()),
285 _ => {}
286 };
287 Ok(value)
288}
289
290pub fn set_keep_alive_idle_time(fd: impl AsFd, value: u64) -> Result<u64, ErrorCode> {
291 const NANOS_PER_SEC: u64 = 1_000_000_000;
292
293 const MIN: u64 = NANOS_PER_SEC;
295
296 const MAX: u64 = (i16::MAX as u64) * NANOS_PER_SEC;
298
299 if value <= 0 {
300 return Err(ErrorCode::InvalidArgument);
302 }
303 let value = value.clamp(MIN, MAX);
304 sockopt::set_tcp_keepidle(fd, Duration::from_nanos(value))?;
305 Ok(value)
306}
307
308pub fn set_keep_alive_interval(fd: impl AsFd, value: Duration) -> Result<(), ErrorCode> {
309 const MIN: Duration = Duration::from_secs(1);
311
312 const MAX: Duration = Duration::from_secs(i16::MAX as u64);
314
315 if value <= Duration::ZERO {
316 return Err(ErrorCode::InvalidArgument);
318 }
319 sockopt::set_tcp_keepintvl(fd, value.clamp(MIN, MAX))?;
320 Ok(())
321}
322
323pub fn set_keep_alive_count(fd: impl AsFd, value: u32) -> Result<(), ErrorCode> {
324 const MIN_CNT: u32 = 1;
325 const MAX_CNT: u32 = i8::MAX as u32;
327
328 if value == 0 {
329 return Err(ErrorCode::InvalidArgument);
331 }
332 sockopt::set_tcp_keepcnt(fd, value.clamp(MIN_CNT, MAX_CNT))?;
333 Ok(())
334}
335
336pub fn tcp_bind(
337 socket: &tokio::net::TcpSocket,
338 local_address: SocketAddr,
339) -> Result<(), ErrorCode> {
340 #[cfg(not(windows))]
345 if let Err(err) = sockopt::set_socket_reuseaddr(&socket, local_address.port() > 0) {
346 return Err(err.into());
347 }
348
349 socket
351 .bind(local_address)
352 .map_err(|err| match Errno::from_io_error(&err) {
353 Some(Errno::AFNOSUPPORT) => ErrorCode::InvalidArgument,
361 #[cfg(windows)]
364 Some(Errno::NOBUFS) => ErrorCode::AddressInUse,
365 _ => err.into(),
366 })
367}
368
369pub fn udp_socket(family: AddressFamily) -> std::io::Result<cap_std::net::UdpSocket> {
370 let socket = cap_std::net::UdpSocket::new(family, Blocking::No)?;
376 Ok(socket)
377}
378
379pub fn udp_bind(sockfd: impl AsFd, addr: SocketAddr) -> Result<(), ErrorCode> {
380 bind(sockfd, &addr).map_err(|err| match err {
381 #[cfg(windows)]
384 Errno::NOBUFS => ErrorCode::AddressInUse,
385 Errno::AFNOSUPPORT => ErrorCode::InvalidArgument,
393 _ => err.into(),
394 })
395}
396
397pub fn udp_disconnect(sockfd: impl AsFd) -> Result<(), ErrorCode> {
398 match connect_unspec(sockfd) {
399 #[cfg(target_os = "macos")]
411 Err(Errno::INVAL | Errno::AFNOSUPPORT) => Ok(()),
412 Err(err) => Err(err.into()),
413 Ok(()) => Ok(()),
414 }
415}
416
417pub fn parse_host(name: &str) -> Result<url::Host, ErrorCode> {
418 match url::Host::parse(&name) {
422 Ok(host) => Ok(host),
423
424 Err(_) => {
426 if let Ok(addr) = Ipv6Addr::from_str(name) {
427 Ok(url::Host::Ipv6(addr))
428 } else {
429 Err(ErrorCode::InvalidArgument)
430 }
431 }
432 }
433}