// wasmcloud_host/wasmbus/providers/http_client/provider.rs

1use bytes::Bytes;
2use core::convert::Infallible;
3use core::time::Duration;
4use futures::StreamExt as _;
5use http::uri::Scheme;
6use http_body::Frame;
7use http_body_util::{BodyExt as _, StreamBody};
8use std::sync::Arc;
9use tokio::spawn;
10use tokio::task::JoinSet;
11use tracing::{debug, error, info, trace, warn, Instrument as _};
12use wasmcloud_provider_sdk::Context;
13
14use wrpc_interface_http::{
15    bindings::wrpc::http::types::{ErrorCode, RequestOptions},
16    split_outgoing_http_body, try_fields_to_header_map, ServeOutgoingHandlerHttp,
17};
18
19// Import shared connection pooling infrastructure
20use wasmcloud_core::http_client::{
21    hyper_request_error, Cacheable, ConnPool, DEFAULT_CONNECT_TIMEOUT, DEFAULT_FIRST_BYTE_TIMEOUT,
22    DEFAULT_USER_AGENT,
23};
24
25/// Internal HTTP client provider implementation that handles outgoing HTTP requests
26/// from components. Maintains connection pools for both HTTP and HTTPS connections
27/// and manages TLS connections for secure requests.
28///
29/// This provider is built into the wasmCloud host and provides HTTP client capabilities
30/// to components without requiring an external provider.
31#[derive(Clone)]
32pub struct HttpClientProvider {
33    /// TLS connector for establishing secure HTTPS connections
34    tls: tokio_rustls::TlsConnector,
35    /// Connection pools for HTTP and HTTPS connections
36    conns: ConnPool<wrpc_interface_http::HttpBody>,
37    /// Background tasks for connection management
38    #[allow(unused)]
39    tasks: Arc<JoinSet<()>>,
40}
41
42impl HttpClientProvider {
43    /// Creates a new HTTP client provider with the specified configuration
44    ///
45    /// # Arguments
46    ///
47    /// * `tls` - TLS connector for HTTPS connections
48    /// * `idle_timeout` - Duration after which idle connections are closed
49    ///
50    /// # Returns
51    ///
52    /// A new HTTP client provider or an error if initialization fails
53    pub(crate) async fn new(
54        tls: tokio_rustls::TlsConnector,
55        idle_timeout: Duration,
56    ) -> anyhow::Result<Self> {
57        debug!(
58            target: "http_client::handle",
59            "Creating new HTTP client provider"
60        );
61
62        let conns = ConnPool::<wrpc_interface_http::HttpBody>::default();
63        let mut tasks = JoinSet::new();
64
65        debug!(
66            target: "http_client::handle",
67            "Starting connection eviction task with timeout: {:?}",
68            idle_timeout
69        );
70        tasks.spawn({
71            let conns = conns.clone();
72            async move {
73                loop {
74                    tokio::time::sleep(idle_timeout).await;
75                    trace!("Evicting idle connections");
76                    conns.evict(idle_timeout).await;
77                }
78            }
79        });
80
81        let provider = Self {
82            tls,
83            conns,
84            tasks: Arc::new(tasks),
85        };
86
87        debug!(
88            target: "http_client::handle",
89            "HTTP client provider created successfully"
90        );
91        Ok(provider)
92    }
93}
94
95// Leverages the default implementation of the `wasmcloud_provider_sdk::Provider` trait.
96/// This trait is required by the wasmCloud runtime to interact with the provider.
97impl wasmcloud_provider_sdk::Provider for HttpClientProvider {
98    /// Provider should perform any operations needed for a new link,
99    /// including setting up per-component resources, and checking authorization.
100    /// If the link is allowed, return true, otherwise return false to deny the link.
101    async fn receive_link_config_as_target(
102        &self,
103        link_config: wasmcloud_provider_sdk::LinkConfig<'_>,
104    ) -> anyhow::Result<()> {
105        debug!(
106            target: "http_client::handle",
107            target_id = %link_config.target_id,
108            source_id = %link_config.source_id,
109            link_name = %link_config.link_name,
110            wit_namespace = %link_config.wit_metadata.0,
111            wit_package = %link_config.wit_metadata.1,
112            interfaces = ?link_config.wit_metadata.2,
113            "Received link config as target"
114        );
115        Ok(())
116    }
117}
118
119/// `ServeOutgoingHandlerHttp` trait implementation for the HTTP client provider.
120///
121/// This trait is implemented for the HTTP client provider to handle outgoing HTTP requests.
122/// It provides a method to handle HTTP requests with optional context and request options.
123impl ServeOutgoingHandlerHttp<Option<Context>> for HttpClientProvider {
124    /// Handles an outgoing HTTP request with optional context and request options.
125    ///
126    /// # Arguments
127    ///
128    /// * `cx` - Optional context for the request
129    /// * `request` - The HTTP request to handle
130    /// * `options` - Optional request options
131    ///
132    /// # Returns
133    ///
134    /// A result indicating the success or failure of the operation
135    #[tracing::instrument(level = "debug", skip_all)]
136    async fn handle(
137        &self,
138        cx: Option<Context>,
139        mut request: http::Request<wrpc_interface_http::HttpBody>,
140        options: Option<RequestOptions>,
141    ) -> anyhow::Result<
142        Result<
143            http::Response<impl http_body::Body<Data = Bytes, Error = Infallible> + Send + 'static>,
144            ErrorCode,
145        >,
146    > {
147        // Extract tracing context if available
148        if let Some(ctx) = &cx {
149            if let Some(traceparent) = ctx.tracing.get("traceparent") {
150                // Add traceparent to request headers to propagate tracing context
151                request.headers_mut().insert(
152                    "traceparent",
153                    http::HeaderValue::from_str(traceparent)
154                        .map_err(|e| ErrorCode::InternalError(Some(e.to_string())))
155                        .expect("Failed to propagate trace context"),
156                );
157            }
158        }
159
160        info!(
161            method = %request.method(),
162            uri = %request.uri(),
163            "Handling outgoing HTTP request"
164        );
165
166        debug!(headers = ?request.headers(), "Request headers");
167
168        let connect_timeout = options
169            .and_then(|opts| opts.connect_timeout.map(Duration::from_nanos))
170            .unwrap_or(DEFAULT_CONNECT_TIMEOUT);
171
172        let first_byte_timeout = options
173            .and_then(|opts| opts.first_byte_timeout.map(Duration::from_nanos))
174            .unwrap_or(DEFAULT_FIRST_BYTE_TIMEOUT);
175
176        debug!(
177            ?connect_timeout,
178            ?first_byte_timeout,
179            "Request timeouts configured"
180        );
181
182        Ok(async {
183            let authority = request
184                .uri()
185                .authority()
186                .ok_or(ErrorCode::HttpRequestUriInvalid)?;
187
188            debug!(%authority, "Request authority extracted");
189
190            let use_tls = match request.uri().scheme() {
191                None => true,
192                Some(scheme) if *scheme == Scheme::HTTPS => true,
193                Some(..) => false,
194            };
195            let authority = if authority.port().is_some() {
196                authority.to_string()
197            } else {
198                let port = if use_tls { 443 } else { 80 };
199                format!("{authority}:{port}")
200            };
201
202            debug!(%authority, use_tls, "Using authority with TLS setting");
203
204            // Remove scheme and authority from request URI
205            *request.uri_mut() = http::Uri::builder()
206                .path_and_query(
207                    request
208                        .uri()
209                        .path_and_query()
210                        .map(|p| p.as_str())
211                        .unwrap_or("/"),
212                )
213                .build()
214                .map_err(|err| ErrorCode::InternalError(Some(err.to_string())))?;
215
216            // Ensure User-Agent header is set
217            request
218                .headers_mut()
219                .entry(http::header::USER_AGENT)
220                .or_insert(http::header::HeaderValue::from_static(DEFAULT_USER_AGENT));
221
222            debug!(path = %request.uri().path(), "Request URI prepared for sending");
223
224            loop {
225                let mut sender = if use_tls {
226                    debug!(%authority, "Establishing HTTPS connection");
227                    tokio::time::timeout(
228                        connect_timeout,
229                        self.conns.connect_https(&self.tls, &authority),
230                    )
231                    .await
232                } else {
233                    debug!(%authority, "Establishing HTTP connection");
234                    tokio::time::timeout(connect_timeout, self.conns.connect_http(&authority)).await
235                }
236                .map_err(|_| ErrorCode::ConnectionTimeout)??;
237
238                debug!(
239                    uri = ?request.uri(),
240                    method = %request.method(),
241                    connection_type = if use_tls { "HTTPS" } else { "HTTP" },
242                    is_cached = matches!(sender, Cacheable::Hit(..)),
243                    "Sending HTTP request"
244                );
245
246                match tokio::time::timeout(first_byte_timeout, sender.try_send_request(request))
247                    .instrument(tracing::debug_span!("http_request"))
248                    .await
249                    .map_err(|_| ErrorCode::ConnectionReadTimeout)?
250                {
251                    Err(mut err) => {
252                        let req = err.take_message();
253                        let err = err.into_error();
254                        if let Some(req) = req {
255                            if err.is_closed() && matches!(sender, Cacheable::Hit(..)) {
256                                debug!(%authority, "Cached connection closed, retrying with a different connection");
257                                request = req;
258                                continue;
259                            }
260                        }
261                        warn!(?err, %authority, "HTTP request error");
262                        return Err(hyper_request_error(err));
263                    }
264                    Ok(res) => {
265                        debug!(%authority, status = %res.status(), "HTTP response received");
266
267                        let authority = authority.into_boxed_str();
268                        let mut sender = sender.unwrap();
269                        if use_tls {
270                            let mut https = self.conns.https.write().await;
271                            sender.last_seen = std::time::Instant::now();
272                            if let Ok(conns) = https.entry(authority).or_default().get_mut() {
273                                debug!("Caching HTTPS connection for future use");
274                                conns.push_front(sender);
275                            }
276                        } else {
277                            let mut http = self.conns.http.write().await;
278                            sender.last_seen = std::time::Instant::now();
279                            if let Ok(conns) = http.entry(authority).or_default().get_mut() {
280                                debug!("Caching HTTP connection for future use");
281                                conns.push_front(sender);
282                            }
283                        }
284
285                        return Ok(res.map(|body| {
286                            let (data, trailers, mut errs) = split_outgoing_http_body(body);
287                            spawn(
288                                async move {
289                                    while let Some(err) = errs.next().await {
290                                        error!(?err, "Body error encountered");
291                                    }
292                                    trace!("Body processing finished");
293                                }
294                                .in_current_span(),
295                            );
296                            StreamBody::new(data.map(Frame::data).map(Ok)).with_trailers(async {
297                                trace!("Awaiting trailers");
298                                if let Some(trailers) = trailers.await {
299                                    trace!("Trailers received");
300                                    match try_fields_to_header_map(trailers) {
301                                        Ok(headers) => Some(Ok(headers)),
302                                        Err(err) => {
303                                            error!(?err, "Failed to parse trailers");
304                                            None
305                                        }
306                                    }
307                                } else {
308                                    trace!("No trailers received");
309                                    None
310                                }
311                            })
312                        }));
313                    }
314                }
315            }
316        }
317        .await)
318    }
319}
320
#[cfg(test)]
mod tests {
    use core::net::{Ipv4Addr, SocketAddr};
    use core::sync::atomic::{AtomicUsize, Ordering};

    use std::time::Duration;

    use anyhow::{ensure, Context as _};
    use bytes::Bytes;
    use hyper_util::rt::TokioIo;
    use tokio::net::TcpListener;
    use tokio::spawn;
    use tokio::try_join;
    use tracing::info;

    use super::*;
    use wasmcloud_core::http_client::DEFAULT_IDLE_TIMEOUT;
    use wasmcloud_provider_sdk::core::tls::DEFAULT_RUSTLS_CONNECTOR;
    use wrpc_interface_http::HttpBody;

    /// Number of requests each test issues.
    const N: usize = 20;

    /// Builds a `POST http://{addr}` request with an empty body and no trailers.
    fn new_request(addr: SocketAddr) -> http::Request<HttpBody> {
        let body = HttpBody {
            body: Box::pin(futures::stream::empty()),
            trailers: Box::pin(async { None }),
        };
        http::Request::post(format!("http://{addr}"))
            .body(body)
            .expect("failed to construct HTTP POST request")
    }

    /// Sends `N` sequential requests through one provider.
    ///
    /// Checks that the server saw fewer connections than requests, i.e. pooled connections
    /// were reused.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    #[test_log(default_log_filter = "trace")]
    async fn test_reuse_conn() -> anyhow::Result<()> {
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await?;
        let addr = listener.local_addr()?;
        // Requests completed by the client so far. The server stops accepting once it hits `N`.
        let requests = AtomicUsize::default();

        let server = async {
            let mut conns: usize = 0;
            while requests.load(Ordering::Relaxed) != N {
                info!("accepting stream...");
                let (stream, _) = listener
                    .accept()
                    .await
                    .context("failed to accept connection")?;
                info!(i = conns, "serving connection...");
                // Serves until the client closes this connection.
                hyper::server::conn::http1::Builder::new()
                    .serve_connection(
                        TokioIo::new(stream),
                        hyper::service::service_fn(move |_| async {
                            anyhow::Ok(http::Response::new(http_body_util::Empty::<Bytes>::new()))
                        }),
                    )
                    .await
                    .context("failed to serve connection")?;
                info!(i = conns, "done serving connection");
                conns = conns.saturating_add(1);
            }
            let reqs = requests.load(Ordering::Relaxed);
            info!(connections = conns, requests = reqs, "server finished");
            ensure!(conns < reqs, "connections: {conns}, requests: {reqs}");
            anyhow::Ok(())
        };

        let client = async {
            let provider =
                HttpClientProvider::new(DEFAULT_RUSTLS_CONNECTOR.clone(), DEFAULT_IDLE_TIMEOUT)
                    .await?;
            let timeout = Some(Duration::from_secs(10).as_nanos() as u64);
            for i in 0..N {
                info!(i, "sending request...");
                let options = RequestOptions {
                    connect_timeout: timeout,
                    first_byte_timeout: timeout,
                    between_bytes_timeout: timeout,
                };
                let res = provider
                    .handle(None, new_request(addr), Some(options))
                    .await
                    .with_context(|| format!("failed to invoke `handle` for request {i}"))?
                    .with_context(|| format!("failed to handle request {i}"))?;
                requests.store(i.saturating_add(1), Ordering::Relaxed);
                info!(i, "reading response body...");
                let body = res.collect().await?;
                assert_eq!(body.to_bytes(), Bytes::default());
            }
            // Dropping `provider` here closes the pooled connections, which lets the server
            // loop finish.
            anyhow::Ok(())
        };

        try_join!(server, client)?;
        Ok(())
    }

    /// Fires `N` requests concurrently.
    ///
    /// All `N` connections are accepted before any is served, so the requests cannot share a
    /// pooled connection and each gets its own.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_concurrent_conn() -> anyhow::Result<()> {
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await?;
        let addr = listener.local_addr()?;
        let provider =
            HttpClientProvider::new(DEFAULT_RUSTLS_CONNECTOR.clone(), DEFAULT_IDLE_TIMEOUT).await?;

        let mut clients = JoinSet::new();
        for i in 0..N {
            let provider = provider.clone();
            clients.spawn(async move {
                info!(i, "sending request...");
                let res = provider
                    .handle(None, new_request(addr), None)
                    .await
                    .with_context(|| format!("failed to invoke `handle` for request {i}"))?
                    .with_context(|| format!("failed to handle request {i}"))?;
                info!(i, "reading response body...");
                let body = res.collect().await?;
                assert_eq!(body.to_bytes(), Bytes::default());
                anyhow::Ok(())
            });
        }

        // Accept every connection first; serving any earlier would let a finished request's
        // connection be reused.
        let mut streams = Vec::with_capacity(N);
        for i in 0..N {
            info!(i, "accepting stream...");
            let (stream, _) = listener
                .accept()
                .await
                .with_context(|| format!("failed to accept connection {i}"))?;
            streams.push(stream);
        }

        // Kept alive until the end of the test: dropping the set would abort the servers.
        let mut servers = JoinSet::new();
        for stream in streams {
            servers.spawn(async move {
                info!("serving connection...");
                hyper::server::conn::http1::Builder::new()
                    .serve_connection(
                        TokioIo::new(stream),
                        hyper::service::service_fn(move |_| async {
                            anyhow::Ok(http::Response::new(http_body_util::Empty::<Bytes>::new()))
                        }),
                    )
                    .await
                    .context("failed to serve connection")
            });
        }

        while let Some(joined) = clients.join_next().await {
            joined??;
        }
        Ok(())
    }

    /// Checks that an HTTP error status from the server (500) reaches the caller as an ordinary
    /// response rather than as an `ErrorCode`.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_http_error_handling() -> anyhow::Result<()> {
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await?;
        let addr = listener.local_addr()?;
        let provider =
            HttpClientProvider::new(DEFAULT_RUSTLS_CONNECTOR.clone(), DEFAULT_IDLE_TIMEOUT).await?;
        let request = new_request(addr);

        // Serve one connection that answers every request with 500 Internal Server Error.
        spawn(async move {
            let (stream, _) = listener.accept().await?;
            hyper::server::conn::http1::Builder::new()
                .serve_connection(
                    TokioIo::new(stream),
                    hyper::service::service_fn(move |_| async {
                        anyhow::Ok(
                            http::Response::builder()
                                .status(http::StatusCode::INTERNAL_SERVER_ERROR)
                                .body(http_body_util::Empty::<Bytes>::new())?,
                        )
                    }),
                )
                .await?;
            Ok::<_, anyhow::Error>(())
        });

        let response = provider.handle(None, request, None).await?;
        assert!(response.is_ok());
        let response = response?;
        assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR);

        Ok(())
    }
}