wasmcloud_provider_http_client/
lib.rs

1use core::convert::Infallible;
2use core::pin::pin;
3use core::time::Duration;
4
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use anyhow::Context as _;
9use bytes::Bytes;
10use futures::StreamExt as _;
11use http::uri::Scheme;
12use http_body::Frame;
13use http_body_util::{BodyExt as _, StreamBody};
14use tokio::task::JoinSet;
15use tokio::time::sleep;
16use tokio::{select, spawn};
17use tracing::{debug, error, info, trace, warn, Instrument as _};
18
19use wasmcloud_provider_sdk::core::tls;
20use wasmcloud_provider_sdk::{
21    get_connection, initialize_observability, load_host_data, propagate_trace_for_ctx,
22    run_provider, Context, Provider,
23};
24use wrpc_interface_http::{
25    bindings::wrpc::http::types::{ErrorCode, RequestOptions},
26    split_outgoing_http_body, try_fields_to_header_map, ServeHttp, ServeOutgoingHandlerHttp,
27};
28
29// Import shared connection pooling infrastructure from the internal provider
30use wasmcloud_core::http_client::{
31    hyper_request_error, Cacheable, ConnPool, DEFAULT_CONNECT_TIMEOUT, DEFAULT_FIRST_BYTE_TIMEOUT,
32    DEFAULT_IDLE_TIMEOUT, DEFAULT_USER_AGENT, LOAD_NATIVE_CERTS, LOAD_WEBPKI_CERTS, SSL_CERTS_FILE,
33};
34
35/// HTTP client capability provider implementation struct
36#[derive(Clone)]
37pub struct HttpClientProvider {
38    /// TLS connector for establishing secure HTTPS connections
39    tls: tokio_rustls::TlsConnector,
40    /// Connection pools for HTTP and HTTPS connections
41    conns: ConnPool<wrpc_interface_http::HttpBody>,
42    /// Background tasks for connection management
43    #[allow(unused)]
44    tasks: Arc<JoinSet<()>>,
45}
46
47pub async fn run() -> anyhow::Result<()> {
48    info!("Starting HTTP client provider");
49    initialize_observability!(
50        "http-client-provider",
51        std::env::var_os("PROVIDER_HTTP_CLIENT_FLAMEGRAPH_PATH")
52    );
53
54    let host_data = load_host_data()?;
55    let provider = HttpClientProvider::new(&host_data.config, DEFAULT_IDLE_TIMEOUT).await?;
56
57    debug!("Initializing provider runtime");
58    let shutdown = run_provider(provider.clone(), "http-client-provider")
59        .await
60        .context("failed to run provider")?;
61
62    let connection = get_connection();
63    let wrpc = connection
64        .get_wrpc_client(connection.provider_key())
65        .await?;
66
67    debug!("Setting up wrpc interface");
68    let [(_, _, mut invocations)] =
69        wrpc_interface_http::bindings::exports::wrpc::http::outgoing_handler::serve_interface(
70            &wrpc,
71            ServeHttp(provider),
72        )
73        .await
74        .context("failed to serve exports")?;
75
76    info!("HTTP client provider ready to handle requests");
77    let mut shutdown = pin!(shutdown);
78    let mut tasks = JoinSet::new();
79
80    loop {
81        select! {
82            Some(res) = invocations.next() => {
83                match res {
84                    Ok(fut) => {
85                        tasks.spawn(async move {
86                            if let Err(err) = fut.await {
87                                warn!(?err, "failed to serve invocation");
88                            }
89                        });
90                    },
91                    Err(err) => {
92                        warn!(?err, "failed to accept invocation");
93                    }
94                }
95            },
96            () = &mut shutdown => {
97                info!("Received shutdown signal");
98                return Ok(())
99            }
100        }
101    }
102}
103
104impl HttpClientProvider {
105    pub async fn new(
106        config: &HashMap<String, String>,
107        idle_timeout: Duration,
108    ) -> anyhow::Result<Self> {
109        debug!("Creating new HTTP client provider");
110
111        // Initialize TLS configuration
112        let tls = if config.is_empty() {
113            debug!("Using default TLS connector");
114            tls::DEFAULT_RUSTLS_CONNECTOR.clone()
115        } else {
116            debug!("Configuring custom TLS connector");
117            let mut ca = rustls::RootCertStore::empty();
118
119            // Load native certificates
120            if config
121                .get(LOAD_NATIVE_CERTS)
122                .map(|v| v.eq_ignore_ascii_case("true"))
123                .unwrap_or(true)
124            {
125                let (added, ignored) =
126                    ca.add_parsable_certificates(tls::NATIVE_ROOTS.iter().cloned());
127                debug!(added, ignored, "loaded native root certificate store");
128            }
129
130            // Load Mozilla trusted root certificates
131            if config
132                .get(LOAD_WEBPKI_CERTS)
133                .map(|v| v.eq_ignore_ascii_case("true"))
134                .unwrap_or(true)
135            {
136                ca.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
137                debug!("loaded webpki root certificate store");
138            }
139
140            // Load root certificates from a file
141            if let Some(file_path) = config.get(SSL_CERTS_FILE) {
142                let f = std::fs::File::open(file_path)?;
143                let mut reader = std::io::BufReader::new(f);
144                let certs = rustls_pemfile::certs(&mut reader).collect::<Result<Vec<_>, _>>()?;
145                let (added, ignored) = ca.add_parsable_certificates(certs);
146                debug!(
147                    added,
148                    ignored, "added additional root certificates from file"
149                );
150            }
151            tokio_rustls::TlsConnector::from(Arc::new(
152                rustls::ClientConfig::builder()
153                    .with_root_certificates(ca)
154                    .with_no_client_auth(),
155            ))
156        };
157
158        // Initialize connection pool and eviction task
159        let conns = ConnPool::default();
160        let mut tasks = JoinSet::new();
161
162        debug!(
163            "Starting connection eviction task with timeout: {:?}",
164            idle_timeout
165        );
166        tasks.spawn({
167            let conns = conns.clone();
168            async move {
169                loop {
170                    sleep(idle_timeout).await;
171                    trace!("Evicting idle connections");
172                    conns.evict(idle_timeout).await;
173                }
174            }
175        });
176
177        debug!("HTTP client provider initialization complete");
178        Ok(Self {
179            tls,
180            conns,
181            tasks: Arc::new(tasks),
182        })
183    }
184}
185
186impl ServeOutgoingHandlerHttp<Option<Context>> for HttpClientProvider {
187    #[tracing::instrument(level = "debug", skip_all)]
188    async fn handle(
189        &self,
190        cx: Option<Context>,
191        mut request: http::Request<wrpc_interface_http::HttpBody>,
192        options: Option<RequestOptions>,
193    ) -> anyhow::Result<
194        Result<
195            http::Response<impl http_body::Body<Data = Bytes, Error = Infallible> + Send + 'static>,
196            ErrorCode,
197        >,
198    > {
199        info!(
200            method = %request.method(),
201            uri = %request.uri(),
202            "Handling outgoing HTTP request"
203        );
204
205        propagate_trace_for_ctx!(cx);
206        wasmcloud_provider_sdk::wasmcloud_tracing::http::HeaderInjector(request.headers_mut())
207            .inject_context();
208
209        debug!(headers = ?request.headers(), "Request headers");
210
211        let connect_timeout = options
212            .and_then(|opts| opts.connect_timeout.map(Duration::from_nanos))
213            .unwrap_or(DEFAULT_CONNECT_TIMEOUT);
214
215        let first_byte_timeout = options
216            .and_then(|opts| opts.first_byte_timeout.map(Duration::from_nanos))
217            .unwrap_or(DEFAULT_FIRST_BYTE_TIMEOUT);
218
219        debug!(
220            ?connect_timeout,
221            ?first_byte_timeout,
222            "Request timeouts configured"
223        );
224
225        Ok(async {
226            let authority = request
227                .uri()
228                .authority()
229                .ok_or(ErrorCode::HttpRequestUriInvalid)?;
230
231            debug!(%authority, "Request authority extracted");
232
233            let use_tls = match request.uri().scheme() {
234                None => true,
235                Some(scheme) if *scheme == Scheme::HTTPS => true,
236                Some(..) => false,
237            };
238            let authority = if authority.port().is_some() {
239                authority.to_string()
240            } else {
241                let port = if use_tls { 443 } else { 80 };
242                format!("{authority}:{port}")
243            };
244
245            debug!(%authority, use_tls, "Using authority with TLS setting");
246
247            // Remove scheme and authority from request URI
248            *request.uri_mut() = http::Uri::builder()
249                .path_and_query(
250                    request
251                        .uri()
252                        .path_and_query()
253                        .map(|p| p.as_str())
254                        .unwrap_or("/"),
255                )
256                .build()
257                .map_err(|err| ErrorCode::InternalError(Some(err.to_string())))?;
258
259            // Ensure User-Agent header is set
260            request
261                .headers_mut()
262                .entry(http::header::USER_AGENT)
263                .or_insert(http::header::HeaderValue::from_static(DEFAULT_USER_AGENT));
264
265            debug!(path = %request.uri().path(), "Request URI prepared for sending");
266
267            loop {
268                let mut sender = if use_tls {
269                    debug!(%authority, "Establishing HTTPS connection");
270                    tokio::time::timeout(
271                        connect_timeout,
272                        self.conns.connect_https(&self.tls, &authority),
273                    )
274                    .await
275                } else {
276                    debug!(%authority, "Establishing HTTP connection");
277                    tokio::time::timeout(connect_timeout, self.conns.connect_http(&authority)).await
278                }
279                .map_err(|_| ErrorCode::ConnectionTimeout)??;
280
281                debug!(
282                    uri = ?request.uri(),
283                    method = %request.method(),
284                    connection_type = if use_tls { "HTTPS" } else { "HTTP" },
285                    is_cached = matches!(sender, Cacheable::Hit(..)),
286                    "Sending HTTP request"
287                );
288
289                match tokio::time::timeout(first_byte_timeout, sender.try_send_request(request))
290                    .instrument(tracing::debug_span!("http_request"))
291                    .await
292                    .map_err(|_| ErrorCode::ConnectionReadTimeout)?
293                {
294                    Err(mut err) => {
295                        let req = err.take_message();
296                        let err = err.into_error();
297                        if let Some(req) = req {
298                            if err.is_closed() && matches!(sender, Cacheable::Hit(..)) {
299                                debug!(%authority, "Cached connection closed, retrying with a different connection");
300                                request = req;
301                                continue;
302                            }
303                        }
304                        warn!(?err, %authority, "HTTP request error");
305                        return Err(hyper_request_error(err));
306                    }
307                    Ok(res) => {
308                        debug!(%authority, status = %res.status(), "HTTP response received");
309
310                        let authority = authority.into_boxed_str();
311                        let mut sender = sender.unwrap();
312                        if use_tls {
313                            let mut https = self.conns.https.write().await;
314                            sender.last_seen = std::time::Instant::now();
315                            if let Ok(conns) = https.entry(authority).or_default().get_mut() {
316                                debug!("Caching HTTPS connection for future use");
317                                conns.push_front(sender);
318                            }
319                        } else {
320                            let mut http = self.conns.http.write().await;
321                            sender.last_seen = std::time::Instant::now();
322                            if let Ok(conns) = http.entry(authority).or_default().get_mut() {
323                                debug!("Caching HTTP connection for future use");
324                                conns.push_front(sender);
325                            }
326                        }
327
328                        return Ok(res.map(|body| {
329                            let (data, trailers, mut errs) = split_outgoing_http_body(body);
330                            spawn(
331                                async move {
332                                    while let Some(err) = errs.next().await {
333                                        error!(?err, "Body error encountered");
334                                    }
335                                    trace!("Body processing finished");
336                                }
337                                .in_current_span(),
338                            );
339                            StreamBody::new(data.map(Frame::data).map(Ok)).with_trailers(async {
340                                trace!("Awaiting trailers");
341                                if let Some(trailers) = trailers.await {
342                                    trace!("Trailers received");
343                                    match try_fields_to_header_map(trailers) {
344                                        Ok(headers) => Some(Ok(headers)),
345                                        Err(err) => {
346                                            error!(?err, "Failed to parse trailers");
347                                            None
348                                        }
349                                    }
350                                } else {
351                                    trace!("No trailers received");
352                                    None
353                                }
354                            })
355                        }));
356                    }
357                }
358            }
359        }
360        .await)
361    }
362}
363
364impl Provider for HttpClientProvider {}
365
#[cfg(test)]
mod tests {
    use core::net::{Ipv4Addr, SocketAddr};
    use core::sync::atomic::{AtomicUsize, Ordering};

    use std::collections::HashMap;

    use anyhow::{ensure, Context as _};
    use bytes::Bytes;
    use hyper_util::rt::TokioIo;
    use tokio::net::TcpListener;
    use tokio::spawn;
    use tokio::try_join;
    use tracing::info;

    use super::*;
    use wrpc_interface_http::HttpBody;

    /// Number of requests issued by the reuse and concurrency tests.
    const N: usize = 20;

    /// Builds a `POST http://{addr}` request with an empty body and no trailers.
    fn new_request(addr: SocketAddr) -> http::Request<HttpBody> {
        let empty_body = HttpBody {
            body: Box::pin(futures::stream::empty()),
            trailers: Box::pin(async { None }),
        };
        http::Request::builder()
            .method(http::Method::POST)
            .uri(format!("http://{addr}"))
            .body(empty_body)
            .expect("failed to construct HTTP POST request")
    }

    /// Serves HTTP/1 on `stream`, answering every request with an empty `200 OK`, until the
    /// connection is closed.
    async fn serve_empty_responses(stream: tokio::net::TcpStream) -> anyhow::Result<()> {
        hyper::server::conn::http1::Builder::new()
            .serve_connection(
                TokioIo::new(stream),
                hyper::service::service_fn(move |_| async {
                    anyhow::Ok(http::Response::new(http_body_util::Empty::<Bytes>::new()))
                }),
            )
            .await
            .context("failed to serve connection")
    }

    /// Sends `N` sequential requests and checks the server saw fewer connections than
    /// requests, i.e. pooled connections were reused.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    #[test_log(default_log_filter = "trace")]
    async fn test_reuse_conn() -> anyhow::Result<()> {
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await?;
        let addr = listener.local_addr()?;
        let completed = AtomicUsize::default();

        let server = async {
            let mut accepted: usize = 0;
            // Keep accepting until the client has completed all `N` requests. The last
            // connection only closes once the client drops its provider.
            while completed.load(Ordering::Relaxed) != N {
                info!("accepting stream...");
                let (stream, _) = listener
                    .accept()
                    .await
                    .context("failed to accept connection")?;
                info!(i = accepted, "serving connection...");
                serve_empty_responses(stream).await?;
                info!(i = accepted, "done serving connection");
                accepted = accepted.saturating_add(1);
            }
            let total = completed.load(Ordering::Relaxed);
            info!(connections = accepted, requests = total, "server finished");
            ensure!(accepted < total, "connections: {accepted}, requests: {total}");
            anyhow::Ok(())
        };

        let client = async {
            let provider =
                HttpClientProvider::new(&HashMap::default(), DEFAULT_IDLE_TIMEOUT).await?;
            // Every timeout is set to ten seconds, expressed in nanoseconds.
            let request_options = || {
                let ten_seconds = Duration::from_secs(10).as_nanos();
                Some(RequestOptions {
                    connect_timeout: Some(ten_seconds as _),
                    first_byte_timeout: Some(ten_seconds as _),
                    between_bytes_timeout: Some(ten_seconds as _),
                })
            };
            for i in 0..N {
                info!(i, "sending request...");
                let response = provider
                    .handle(None, new_request(addr), request_options())
                    .await
                    .with_context(|| format!("failed to invoke `handle` for request {i}"))?
                    .with_context(|| format!("failed to handle request {i}"))?;
                completed.store(i.saturating_add(1), Ordering::Relaxed);
                info!(i, "reading response body...");
                let collected = response.collect().await?;
                assert_eq!(collected.to_bytes(), Bytes::default());
            }
            anyhow::Ok(())
        };

        try_join!(server, client)?;
        Ok(())
    }

    /// Fires `N` requests at once. None can be served from the pool, so each must open its
    /// own connection, and all must complete successfully.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_concurrent_conn() -> anyhow::Result<()> {
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await?;
        let addr = listener.local_addr()?;
        let provider = HttpClientProvider::new(&HashMap::default(), DEFAULT_IDLE_TIMEOUT).await?;

        let mut clients = JoinSet::new();
        for i in 0..N {
            let provider = provider.clone();
            clients.spawn(async move {
                info!(i, "sending request...");
                let response = provider
                    .handle(None, new_request(addr), None)
                    .await
                    .with_context(|| format!("failed to invoke `handle` for request {i}"))?
                    .with_context(|| format!("failed to handle request {i}"))?;
                info!(i, "reading response body...");
                let collected = response.collect().await?;
                assert_eq!(collected.to_bytes(), Bytes::default());
                anyhow::Ok(())
            });
        }

        // Accept all `N` connections before answering any request.
        let mut pending = Vec::with_capacity(N);
        for i in 0..N {
            info!(i, "accepting stream...");
            let (stream, _) = listener
                .accept()
                .await
                .with_context(|| format!("failed to accept connection {i}"))?;
            pending.push(stream);
        }

        let mut servers = JoinSet::new();
        for stream in pending {
            servers.spawn(async move {
                info!("serving connection...");
                serve_empty_responses(stream).await
            });
        }

        while let Some(joined) = clients.join_next().await {
            joined??;
        }
        Ok(())
    }

    /// A `500` status from the server is a successful exchange and must reach the caller
    /// as `Ok(response)`, not as an `ErrorCode`.
    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_http_error_handling() -> anyhow::Result<()> {
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await?;
        let addr = listener.local_addr()?;
        let provider = HttpClientProvider::new(&HashMap::default(), DEFAULT_IDLE_TIMEOUT).await?;
        let request = new_request(addr);

        // Single-connection server that answers everything with `500 Internal Server Error`.
        spawn(async move {
            let (stream, _) = listener.accept().await?;
            hyper::server::conn::http1::Builder::new()
                .serve_connection(
                    TokioIo::new(stream),
                    hyper::service::service_fn(move |_| async {
                        let failure = http::Response::builder()
                            .status(http::StatusCode::INTERNAL_SERVER_ERROR)
                            .body(http_body_util::Empty::<Bytes>::new())?;
                        anyhow::Ok(failure)
                    }),
                )
                .await?;
            anyhow::Ok(())
        });

        let outcome = provider.handle(None, request, None).await?;
        assert!(outcome.is_ok());
        let response = outcome?;
        assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR);

        Ok(())
    }
}