1use core::error::Error;
8use core::ops::{Deref, DerefMut};
9use core::time::Duration;
10use hyper::client::conn::http1;
11use hyper_util::rt::TokioIo;
12use std::collections::{HashMap, VecDeque};
13use std::sync::{Arc, LazyLock};
14use std::time::Instant;
15use tokio::join;
16use tokio::net::{TcpStream, ToSocketAddrs};
17use tokio::sync::{Mutex, RwLock};
18use tokio::task::{AbortHandle, JoinSet};
19use tracing::{trace, warn};
20
21use wrpc_interface_http::bindings::{
22 wasi::http::types::DnsErrorPayload, wrpc::http::types::ErrorCode,
23};
24
/// Idle timeout of 90 seconds.
///
/// NOTE(review): presumably the default age passed to `ConnPool::evict` for
/// pooled connections — the caller is not visible here, confirm.
pub const DEFAULT_IDLE_TIMEOUT: Duration = Duration::from_secs(90);

/// `<crate name>/<crate version>`, assembled at compile time from the
/// `CARGO_PKG_NAME` and `CARGO_PKG_VERSION` environment variables.
///
/// NOTE(review): presumably sent as the default `User-Agent` header — confirm
/// against the request-building code.
pub const DEFAULT_USER_AGENT: &str =
    concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));

/// Connect timeout of 600 seconds.
///
/// NOTE(review): not referenced in the visible part of this file; presumably
/// applied by the caller around connection establishment — confirm.
pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(600);

/// First-byte timeout of 600 seconds.
///
/// NOTE(review): not referenced in the visible part of this file; presumably
/// bounds the wait for the first response byte — confirm.
pub const DEFAULT_FIRST_BYTE_TIMEOUT: Duration = Duration::from_secs(600);

/// The string key `load_native_certs`.
///
/// NOTE(review): presumably a configuration key that toggles loading of the
/// platform's native root certificates — confirm against the config parser.
pub const LOAD_NATIVE_CERTS: &str = "load_native_certs";

/// The string key `load_webpki_certs`.
///
/// NOTE(review): presumably a configuration key that toggles loading of the
/// bundled webpki root certificates — confirm against the config parser.
pub const LOAD_WEBPKI_CERTS: &str = "load_webpki_certs";

/// The string key `ssl_certs_file`.
///
/// NOTE(review): presumably a configuration key naming a file with additional
/// certificates — confirm against the config parser.
pub const SSL_CERTS_FILE: &str = "ssl_certs_file";

/// Reference instant captured by `Instant::now` the first time this static is
/// dereferenced (it is lazily initialised, not set at process start).
///
/// `PooledConn::new` uses it as the initial `last_seen` value.
pub static ZERO_INSTANT: LazyLock<Instant> = LazyLock::new(Instant::now);
54
/// A pooled value (`sender`) bundled with the abort handle of an associated
/// task and the timestamp used for idle eviction.
///
/// Dropping a `PooledConn` calls `abort()` on its `abort` handle (see the
/// `Drop` impl in this file).
///
/// NOTE(review): the type derives `Clone` while `Drop` aborts the task, so
/// dropping any clone presumably aborts the task the other copies rely on —
/// confirm that clones are only used where that is acceptable (e.g. tests).
#[derive(Clone, Debug)]
pub struct PooledConn<T> {
    /// The wrapped value; `Deref`/`DerefMut` forward to this field.
    pub sender: T,
    /// Handle that is aborted when this value is dropped.
    pub abort: AbortHandle,
    /// Timestamp compared against the cutoff in `evict_conns`.
    pub last_seen: Instant,
}
67
68impl<T> Deref for PooledConn<T> {
69 type Target = T;
70
71 fn deref(&self) -> &Self::Target {
72 &self.sender
73 }
74}
75
76impl<T> DerefMut for PooledConn<T> {
77 fn deref_mut(&mut self) -> &mut Self::Target {
78 &mut self.sender
79 }
80}
81
82impl<T: PartialEq> PartialEq for PooledConn<T> {
83 fn eq(
84 &self,
85 Self {
86 sender,
87 abort,
88 last_seen,
89 }: &Self,
90 ) -> bool {
91 self.sender == *sender && self.abort.id() == abort.id() && self.last_seen == *last_seen
92 }
93}
94
95impl<T> Drop for PooledConn<T> {
96 fn drop(&mut self) {
97 self.abort.abort();
98 }
99}
100
101impl<T> PooledConn<T> {
102 pub fn new(sender: T, abort: AbortHandle) -> Self {
109 Self {
110 sender,
111 abort,
112 last_seen: *ZERO_INSTANT,
113 }
114 }
115}
116
/// Connection table: a `tokio` [`RwLock`] around a map from a string key to a
/// `std::sync::Mutex`-guarded queue of pooled HTTP/1 request senders.
///
/// `ConnPool::connect_http` and `ConnPool::connect_https` look entries up by
/// the request authority.
pub type ConnPoolTable<T> =
    RwLock<HashMap<Box<str>, std::sync::Mutex<VecDeque<PooledConn<http1::SendRequest<T>>>>>>;
122
/// Shared HTTP/1 connection pool.
///
/// All three fields are `Arc`s, and the `Clone` impl in this file clones the
/// `Arc`s, so clones of a pool share the same tables and task set.
#[derive(Debug)]
pub struct ConnPool<T> {
    /// Table consulted by `connect_http`.
    pub http: Arc<ConnPoolTable<T>>,
    /// Table consulted by `connect_https`.
    pub https: Arc<ConnPoolTable<T>>,
    /// Task set on which `connect_http`/`connect_https` spawn the futures that
    /// drive newly established connections.
    pub tasks: Arc<Mutex<JoinSet<()>>>,
}
138
139impl<T> Default for ConnPool<T> {
142 fn default() -> Self {
143 Self {
144 http: Arc::default(),
145 https: Arc::default(),
146 tasks: Arc::default(),
147 }
148 }
149}
150
151impl<T> Clone for ConnPool<T> {
152 fn clone(&self) -> Self {
153 Self {
154 http: self.http.clone(),
155 https: self.https.clone(),
156 tasks: self.tasks.clone(),
157 }
158 }
159}
160
161pub fn evict_conns<T>(
168 cutoff: Instant,
169 conns: &mut HashMap<Box<str>, std::sync::Mutex<VecDeque<PooledConn<T>>>>,
170) {
171 trace!(target: "http_client::evict", ?cutoff, total_authorities=conns.len(), "evicting connections older than cutoff");
172 let mut total_evicted = 0;
173 conns.retain(|authority, conns| {
174 let Ok(conns) = conns.get_mut() else {
175 trace!(target: "http_client::evict", %authority, "skipping locked connection pool");
176 return true;
177 };
178 let total_conns = conns.len();
179 let idx = conns.partition_point(|&PooledConn { last_seen, .. }| last_seen <= cutoff);
180 if idx == conns.len() {
181 trace!(target: "http_client::evict", %authority, evicted=total_conns, "evicting all connections");
182 total_evicted += total_conns;
183 false
184 } else if idx == 0 {
185 trace!(target: "http_client::evict", %authority, total=total_conns, "no connections to evict");
186 true
187 } else {
188 trace!(target: "http_client::evict", %authority, evicted=idx, remaining=(total_conns - idx), "partially evicting connections");
189 conns.rotate_left(idx);
190 conns.truncate(total_conns - idx);
191 total_evicted += idx;
192 true
193 }
194 });
195 trace!(target: "http_client::evict", total_evicted, remaining_authorities=conns.len(), "connection eviction complete");
196}
197
198impl<T> ConnPool<T> {
199 pub async fn evict(&self, timeout: Duration) {
206 let Some(cutoff) = Instant::now().checked_sub(timeout) else {
207 return;
208 };
209 join!(
210 async {
211 let mut conns = self.http.write().await;
212 evict_conns(cutoff, &mut conns);
213 },
214 async {
215 let mut conns = self.https.write().await;
216 evict_conns(cutoff, &mut conns);
217 }
218 );
219 }
220
221 #[allow(dead_code)]
233 pub async fn connect_http(
234 &self,
235 authority: &str,
236 ) -> Result<Cacheable<PooledConn<http1::SendRequest<T>>>, ErrorCode>
237 where
238 T: http_body::Body + Send + 'static,
239 T::Data: Send,
240 T::Error: Into<Box<dyn Error + Send + Sync>>,
241 {
242 trace!(target: "http_client::connect_http", authority, "attempting HTTP connection");
243 {
244 let http = self.http.read().await;
245 if let Some(conns) = http.get(authority) {
246 if let Ok(mut conns) = conns.lock() {
247 trace!(target: "http_client::connect_http", authority, cached_connections=conns.len(), "checking cached HTTP connections");
248 while let Some(conn) = conns.pop_front() {
249 trace!(target: "http_client::connect_http", authority, "found cached HTTP connection");
250 if !conn.is_closed() && conn.is_ready() {
251 trace!(target: "http_client::connect_http", authority, "returning HTTP connection cache hit");
252 return Ok(Cacheable::Hit(conn));
253 } else {
254 trace!(target: "http_client::connect_http", authority, is_closed=conn.is_closed(), is_ready=conn.is_ready(), "discarding unusable cached HTTP connection");
255 }
256 }
257 }
258 }
259 }
260 trace!(target: "http_client::connect_http", authority, "establishing new TCP connection");
261 let stream = connect(authority).await?;
262 trace!(target: "http_client::connect_http", authority, "starting HTTP handshake");
263 let (sender, conn) = http1::handshake(TokioIo::new(stream))
264 .await
265 .map_err(|err| {
266 warn!(target: "http_client::connect_http", error=?err, authority, "HTTP handshake failed");
267 hyper_request_error(err)
268 })?;
269 let tasks = Arc::clone(&self.tasks);
270 let authority_clone = authority.to_string();
271 let abort = tasks.lock().await.spawn(async move {
272 match conn.await {
273 Ok(()) => trace!(target: "http_client::connect_http", authority=authority_clone, "HTTP connection closed successfully"),
274 Err(err) => warn!(target: "http_client::connect_http", ?err, authority=authority_clone, "HTTP connection closed with error"),
275 }
276 });
277 trace!(target: "http_client::connect_http", authority, "returning HTTP connection cache miss");
278 Ok(Cacheable::Miss(PooledConn::new(sender, abort)))
279 }
280
281 #[cfg(any(target_arch = "riscv64", target_arch = "s390x"))]
282 pub async fn connect_https(
283 &self,
284 _tls: &tokio_rustls::TlsConnector,
285 _authority: &str,
286 ) -> Result<Cacheable<PooledConn<http1::SendRequest<T>>>, ErrorCode>
287 where
288 T: http_body::Body + Send + 'static,
289 T::Data: Send,
290 T::Error: Into<Box<dyn Error + Send + Sync>>,
291 {
292 Err(ErrorCode::InternalError(Some(
293 "HTTPS connections are not supported on this architecture".to_string(),
294 )))
295 }
296
297 #[cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))]
310 pub async fn connect_https(
311 &self,
312 tls: &tokio_rustls::TlsConnector,
313 authority: &str,
314 ) -> Result<Cacheable<PooledConn<http1::SendRequest<T>>>, ErrorCode>
315 where
316 T: http_body::Body + Send + 'static,
317 T::Data: Send,
318 T::Error: Into<Box<dyn Error + Send + Sync>>,
319 {
320 use rustls::pki_types::ServerName;
321
322 trace!(target: "http_client::connect_https", authority, "attempting HTTPS connection");
323 {
324 let https = self.https.read().await;
325 if let Some(conns) = https.get(authority) {
326 if let Ok(mut conns) = conns.lock() {
327 trace!(target: "http_client::connect_https", authority, cached_connections=conns.len(), "checking cached HTTPS connections");
328 while let Some(conn) = conns.pop_front() {
329 trace!(target: "http_client::connect_https", authority, "found cached HTTPS connection");
330 if !conn.is_closed() && conn.is_ready() {
331 trace!(target: "http_client::connect_https", authority, "returning HTTPS connection cache hit");
332 return Ok(Cacheable::Hit(conn));
333 } else {
334 trace!(target: "http_client::connect_https", authority, is_closed=conn.is_closed(), is_ready=conn.is_ready(), "discarding unusable cached HTTPS connection");
335 }
336 }
337 }
338 }
339 }
340 trace!(target: "http_client::connect_https", authority, "establishing new TCP connection");
341 let stream = connect(authority).await?;
342
343 let mut parts = authority.split(":");
344 let host = parts.next().unwrap_or(authority);
345 trace!(target: "http_client::connect_https", authority, host, "resolving server name for TLS");
346 let domain = ServerName::try_from(host)
347 .map_err(|err| {
348 warn!(target: "http_client::connect_https", ?err, authority, host, "invalid DNS name for TLS");
349 dns_error("invalid DNS name".to_string(), 0)
350 })?
351 .to_owned();
352 trace!(target: "http_client::connect_https", authority, host, "starting TLS handshake");
353 let stream = tls.connect(domain, stream).await.map_err(|err| {
354 warn!(target: "http_client::connect_https", ?err, authority, host, "TLS handshake failed");
355 ErrorCode::TlsProtocolError
356 })?;
357 trace!(target: "http_client::connect_https", authority, "starting HTTP handshake over TLS");
358 let (sender, conn) = http1::handshake(TokioIo::new(stream))
359 .await
360 .map_err(|err| {
361 warn!(target: "http_client::connect_https", error=?err, authority, "HTTP handshake failed over TLS");
362 hyper_request_error(err)
363 })?;
364 let tasks = Arc::clone(&self.tasks);
365 let authority_clone = authority.to_string();
366 let abort = tasks.lock().await.spawn(async move {
367 match conn.await {
368 Ok(()) => trace!(target: "http_client::connect_https", authority=authority_clone, "HTTPS connection closed successfully"),
369 Err(err) => warn!(target: "http_client::connect_https", ?err, authority=authority_clone, "HTTPS connection closed with error"),
370 }
371 });
372 trace!(target: "http_client::connect_https", authority, "returning HTTPS connection cache miss");
373 Ok(Cacheable::Miss(PooledConn::new(sender, abort)))
374 }
375}
376
/// A value tagged with whether it came from a cache.
///
/// `ConnPool::connect_http` and `ConnPool::connect_https` return `Hit` for a
/// connection taken from the pool and `Miss` for a newly established one.
pub enum Cacheable<T> {
    /// The value was newly created.
    Miss(T),
    /// The value was taken from the cache.
    Hit(T),
}
388
389impl<T> Deref for Cacheable<T> {
390 type Target = T;
391
392 fn deref(&self) -> &Self::Target {
393 match self {
394 Self::Miss(v) | Self::Hit(v) => v,
395 }
396 }
397}
398
399impl<T> DerefMut for Cacheable<T> {
400 fn deref_mut(&mut self) -> &mut Self::Target {
401 match self {
402 Self::Miss(v) | Self::Hit(v) => v,
403 }
404 }
405}
406
407impl<T> Cacheable<T> {
408 #[allow(dead_code)]
410 pub fn unwrap(self) -> T {
411 match self {
412 Self::Miss(v) => {
413 trace!(target: "http_client::cache", "unwrapping cache miss");
414 v
415 }
416 Self::Hit(v) => {
417 trace!(target: "http_client::cache", "unwrapping cache hit");
418 v
419 }
420 }
421 }
422}
423
424fn dns_error(rcode: String, info_code: u16) -> ErrorCode {
435 ErrorCode::DnsError(DnsErrorPayload {
436 rcode: Some(rcode),
437 info_code: Some(info_code),
438 })
439}
440
441async fn connect(addr: impl ToSocketAddrs) -> Result<TcpStream, ErrorCode> {
451 trace!(target: "http_client::connect", "attempting TCP connection");
452 match TcpStream::connect(addr).await {
453 Ok(stream) => {
454 trace!(target: "http_client::connect", "TCP connection established successfully");
455 Ok(stream)
456 }
457 Err(err) if err.kind() == std::io::ErrorKind::AddrNotAvailable => {
458 warn!(target: "http_client::connect", error=?err, "address not available");
459 Err(dns_error("address not available".to_string(), 0))
460 }
461 Err(err) => {
462 if err
463 .to_string()
464 .starts_with("failed to lookup address information")
465 {
466 warn!(target: "http_client::connect", error=?err, "DNS lookup failed");
467 Err(dns_error("address not available".to_string(), 0))
468 } else {
469 warn!(target: "http_client::connect", error=?err, "connection refused");
470 Err(ErrorCode::ConnectionRefused)
471 }
472 }
473 }
474}
475
476pub fn hyper_request_error(err: hyper::Error) -> ErrorCode {
486 if let Some(cause) = err.source() {
488 if let Some(io_err) = cause.downcast_ref::<std::io::Error>() {
490 match io_err.kind() {
491 std::io::ErrorKind::ConnectionRefused => return ErrorCode::ConnectionRefused,
492 std::io::ErrorKind::ConnectionReset => return ErrorCode::ConnectionTerminated,
493 std::io::ErrorKind::TimedOut => return ErrorCode::ConnectionTimeout,
494 _ => {}
495 }
496 }
497
498 warn!(
500 target: "http_client::error",
501 error=?err,
502 cause=?cause,
503 "HTTP request failed with underlying cause"
504 );
505 return ErrorCode::HttpProtocolError;
506 }
507
508 warn!(
510 target: "http_client::error",
511 error=?err,
512 "HTTP request failed"
513 );
514
515 ErrorCode::HttpProtocolError
516}
517
#[cfg(test)]
mod tests {
    use core::net::Ipv4Addr;

    use std::collections::{HashMap, VecDeque};
    use std::time::Instant;

    use anyhow::Context as _;
    use bytes::Bytes;
    use hyper_util::rt::TokioIo;
    use tokio::net::TcpListener;
    use tokio::spawn;
    use tokio::try_join;

    use super::*;
    use wrpc_interface_http::HttpBody;

    /// Number of connections exercised by `test_pool_evict`.
    const N: usize = 20;

    /// `evict_conns` removes the stale prefix of each queue, keeps queues with
    /// only fresh connections, and drops authorities whose queue becomes (or
    /// already was) empty.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_conn_evict() -> anyhow::Result<()> {
        let now = Instant::now();

        // Timestamps relative to `now`.
        let before = |secs: u64| {
            now.checked_sub(Duration::from_secs(secs))
                .expect("time subtraction should not overflow")
        };
        let after = |secs: u64| {
            now.checked_add(Duration::from_secs(secs))
                .expect("time addition should not overflow")
        };
        // A dummy pooled connection backed by a no-op task.
        let conn = |last_seen: Instant| PooledConn {
            sender: (),
            abort: spawn(async {}).abort_handle(),
            last_seen,
        };

        let mut foo = VecDeque::from([
            conn(before(10)),
            conn(before(1)),
            conn(now),
            conn(after(1)),
            conn(after(1)),
            conn(after(3)),
        ]);
        let qux = VecDeque::from([conn(after(10)), conn(after(12))]);
        let baz = VecDeque::from([conn(before(10)), conn(before(1))]);

        let mut conns = HashMap::from([
            ("foo".into(), std::sync::Mutex::new(foo.clone())),
            ("bar".into(), std::sync::Mutex::default()),
            ("baz".into(), std::sync::Mutex::new(baz)),
            ("qux".into(), std::sync::Mutex::new(qux.clone())),
        ]);
        evict_conns(now, &mut conns);

        // `foo` keeps only the connections stamped after `now`.
        let foo_left = conns
            .remove("foo")
            .expect("foo should exist")
            .into_inner()
            .expect("mutex should be unlocked");
        assert_eq!(foo_left, foo.split_off(3));

        // `qux` is entirely fresh and therefore untouched.
        let qux_left = conns
            .remove("qux")
            .expect("qux should exist")
            .into_inner()
            .expect("mutex should be unlocked");
        assert_eq!(qux_left, qux);

        // `bar` (empty) and `baz` (entirely stale) were removed.
        assert!(conns.is_empty());
        evict_conns(now, &mut conns);
        assert!(conns.is_empty());
        Ok(())
    }

    /// `ConnPool::evict` removes an authority whose connections are all older
    /// than the idle timeout.
    #[cfg(feature = "http")]
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_pool_evict() -> anyhow::Result<()> {
        const IDLE_TIMEOUT: Duration = Duration::from_millis(10);

        eprintln!("Starting test_pool_evict");
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await?;
        let addr = listener.local_addr()?;
        eprintln!("Test server bound to {addr}");

        try_join!(
            // Server: accept and serve N connections, one after another.
            async {
                eprintln!("Server task starting, will accept {N} connections");
                for i in 0..N {
                    let ordinal = i + 1;
                    eprintln!("[{ordinal}/{N}] Waiting to accept connection...");
                    let (stream, _) = listener
                        .accept()
                        .await
                        .with_context(|| format!("failed to accept connection {i}"))?;
                    eprintln!("[{ordinal}/{N}] Connection accepted, serving...");
                    hyper::server::conn::http1::Builder::new()
                        .serve_connection(
                            TokioIo::new(stream),
                            hyper::service::service_fn(move |_| async {
                                anyhow::Ok(http::Response::new(
                                    http_body_util::Empty::<Bytes>::new(),
                                ))
                            }),
                        )
                        .await
                        .with_context(|| format!("failed to serve connection {i}"))?;
                    eprintln!("[{ordinal}/{N}] Connection served and closed");
                }
                eprintln!("Server task completed all {N} connections");
                anyhow::Ok(())
            },
            // Client: fill the pool with stale connections, then evict.
            async {
                eprintln!("Client task starting");
                let pool = ConnPool::<HttpBody>::default();
                let now = Instant::now();

                eprintln!(" Creating {N} connections to server at {addr}");
                {
                    let mut http_conns = pool.http.write().await;
                    let mut connections = VecDeque::new();

                    for i in 0..N {
                        let ordinal = i + 1;
                        eprintln!("[{ordinal}/{N}] Establishing handshake...");
                        let stream = TcpStream::connect(addr).await?;
                        // The connection half is discarded immediately; only
                        // the sender is pooled.
                        let (sender, _) = http1::handshake(TokioIo::new(stream)).await?;
                        eprintln!("[{ordinal}/{N}] Handshake completed");

                        connections.push_back(PooledConn {
                            sender,
                            abort: spawn(async {}).abort_handle(),
                            last_seen: now
                                .checked_sub(Duration::from_secs(10))
                                .expect("time subtraction should not overflow"),
                        });
                    }

                    http_conns.insert(addr.to_string().into(), std::sync::Mutex::new(connections));
                    eprintln!("All {N} connections added to pool");
                }

                eprintln!("Sleeping for a bit to let connections age...");
                tokio::time::sleep(Duration::from_millis(20)).await;

                eprintln!("Starting eviction process...");
                pool.evict(IDLE_TIMEOUT).await;
                eprintln!("Eviction completed");

                eprintln!("Verifying connections were evicted...");
                let http_conns = pool.http.read().await;
                let authority = addr.to_string();
                let result = !http_conns.contains_key(authority.as_str());
                eprintln!("Eviction verification result: authority removed = {result}");
                assert!(result);

                eprintln!("Client task completed successfully");
                Ok(())
            }
        )?;
        Ok(())
    }
}