wasmcloud_provider_http_client/
lib.rs1use core::convert::Infallible;
2use core::pin::pin;
3use core::time::Duration;
4
5use std::collections::HashMap;
6use std::sync::Arc;
7
8use anyhow::Context as _;
9use bytes::Bytes;
10use futures::StreamExt as _;
11use http::uri::Scheme;
12use http_body::Frame;
13use http_body_util::{BodyExt as _, StreamBody};
14use tokio::task::JoinSet;
15use tokio::time::sleep;
16use tokio::{select, spawn};
17use tracing::{debug, error, info, trace, warn, Instrument as _};
18
19use wasmcloud_provider_sdk::core::tls;
20use wasmcloud_provider_sdk::{
21 get_connection, initialize_observability, load_host_data, propagate_trace_for_ctx,
22 run_provider, Context, Provider,
23};
24use wrpc_interface_http::{
25 bindings::wrpc::http::types::{ErrorCode, RequestOptions},
26 split_outgoing_http_body, try_fields_to_header_map, ServeHttp, ServeOutgoingHandlerHttp,
27};
28
29use wasmcloud_core::http_client::{
31 hyper_request_error, Cacheable, ConnPool, DEFAULT_CONNECT_TIMEOUT, DEFAULT_FIRST_BYTE_TIMEOUT,
32 DEFAULT_IDLE_TIMEOUT, DEFAULT_USER_AGENT, LOAD_NATIVE_CERTS, LOAD_WEBPKI_CERTS, SSL_CERTS_FILE,
33};
34
35#[derive(Clone)]
37pub struct HttpClientProvider {
38 tls: tokio_rustls::TlsConnector,
40 conns: ConnPool<wrpc_interface_http::HttpBody>,
42 #[allow(unused)]
44 tasks: Arc<JoinSet<()>>,
45}
46
47pub async fn run() -> anyhow::Result<()> {
48 info!("Starting HTTP client provider");
49 initialize_observability!(
50 "http-client-provider",
51 std::env::var_os("PROVIDER_HTTP_CLIENT_FLAMEGRAPH_PATH")
52 );
53
54 let host_data = load_host_data()?;
55 let provider = HttpClientProvider::new(&host_data.config, DEFAULT_IDLE_TIMEOUT).await?;
56
57 debug!("Initializing provider runtime");
58 let shutdown = run_provider(provider.clone(), "http-client-provider")
59 .await
60 .context("failed to run provider")?;
61
62 let connection = get_connection();
63 let wrpc = connection
64 .get_wrpc_client(connection.provider_key())
65 .await?;
66
67 debug!("Setting up wrpc interface");
68 let [(_, _, mut invocations)] =
69 wrpc_interface_http::bindings::exports::wrpc::http::outgoing_handler::serve_interface(
70 &wrpc,
71 ServeHttp(provider),
72 )
73 .await
74 .context("failed to serve exports")?;
75
76 info!("HTTP client provider ready to handle requests");
77 let mut shutdown = pin!(shutdown);
78 let mut tasks = JoinSet::new();
79
80 loop {
81 select! {
82 Some(res) = invocations.next() => {
83 match res {
84 Ok(fut) => {
85 tasks.spawn(async move {
86 if let Err(err) = fut.await {
87 warn!(?err, "failed to serve invocation");
88 }
89 });
90 },
91 Err(err) => {
92 warn!(?err, "failed to accept invocation");
93 }
94 }
95 },
96 () = &mut shutdown => {
97 info!("Received shutdown signal");
98 return Ok(())
99 }
100 }
101 }
102}
103
104impl HttpClientProvider {
105 pub async fn new(
106 config: &HashMap<String, String>,
107 idle_timeout: Duration,
108 ) -> anyhow::Result<Self> {
109 debug!("Creating new HTTP client provider");
110
111 let tls = if config.is_empty() {
113 debug!("Using default TLS connector");
114 tls::DEFAULT_RUSTLS_CONNECTOR.clone()
115 } else {
116 debug!("Configuring custom TLS connector");
117 let mut ca = rustls::RootCertStore::empty();
118
119 if config
121 .get(LOAD_NATIVE_CERTS)
122 .map(|v| v.eq_ignore_ascii_case("true"))
123 .unwrap_or(true)
124 {
125 let (added, ignored) =
126 ca.add_parsable_certificates(tls::NATIVE_ROOTS.iter().cloned());
127 debug!(added, ignored, "loaded native root certificate store");
128 }
129
130 if config
132 .get(LOAD_WEBPKI_CERTS)
133 .map(|v| v.eq_ignore_ascii_case("true"))
134 .unwrap_or(true)
135 {
136 ca.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
137 debug!("loaded webpki root certificate store");
138 }
139
140 if let Some(file_path) = config.get(SSL_CERTS_FILE) {
142 let f = std::fs::File::open(file_path)?;
143 let mut reader = std::io::BufReader::new(f);
144 let certs = rustls_pemfile::certs(&mut reader).collect::<Result<Vec<_>, _>>()?;
145 let (added, ignored) = ca.add_parsable_certificates(certs);
146 debug!(
147 added,
148 ignored, "added additional root certificates from file"
149 );
150 }
151 tokio_rustls::TlsConnector::from(Arc::new(
152 rustls::ClientConfig::builder()
153 .with_root_certificates(ca)
154 .with_no_client_auth(),
155 ))
156 };
157
158 let conns = ConnPool::default();
160 let mut tasks = JoinSet::new();
161
162 debug!(
163 "Starting connection eviction task with timeout: {:?}",
164 idle_timeout
165 );
166 tasks.spawn({
167 let conns = conns.clone();
168 async move {
169 loop {
170 sleep(idle_timeout).await;
171 trace!("Evicting idle connections");
172 conns.evict(idle_timeout).await;
173 }
174 }
175 });
176
177 debug!("HTTP client provider initialization complete");
178 Ok(Self {
179 tls,
180 conns,
181 tasks: Arc::new(tasks),
182 })
183 }
184}
185
186impl ServeOutgoingHandlerHttp<Option<Context>> for HttpClientProvider {
187 #[tracing::instrument(level = "debug", skip_all)]
188 async fn handle(
189 &self,
190 cx: Option<Context>,
191 mut request: http::Request<wrpc_interface_http::HttpBody>,
192 options: Option<RequestOptions>,
193 ) -> anyhow::Result<
194 Result<
195 http::Response<impl http_body::Body<Data = Bytes, Error = Infallible> + Send + 'static>,
196 ErrorCode,
197 >,
198 > {
199 info!(
200 method = %request.method(),
201 uri = %request.uri(),
202 "Handling outgoing HTTP request"
203 );
204
205 propagate_trace_for_ctx!(cx);
206 wasmcloud_provider_sdk::wasmcloud_tracing::http::HeaderInjector(request.headers_mut())
207 .inject_context();
208
209 debug!(headers = ?request.headers(), "Request headers");
210
211 let connect_timeout = options
212 .and_then(|opts| opts.connect_timeout.map(Duration::from_nanos))
213 .unwrap_or(DEFAULT_CONNECT_TIMEOUT);
214
215 let first_byte_timeout = options
216 .and_then(|opts| opts.first_byte_timeout.map(Duration::from_nanos))
217 .unwrap_or(DEFAULT_FIRST_BYTE_TIMEOUT);
218
219 debug!(
220 ?connect_timeout,
221 ?first_byte_timeout,
222 "Request timeouts configured"
223 );
224
225 Ok(async {
226 let authority = request
227 .uri()
228 .authority()
229 .ok_or(ErrorCode::HttpRequestUriInvalid)?;
230
231 debug!(%authority, "Request authority extracted");
232
233 let use_tls = match request.uri().scheme() {
234 None => true,
235 Some(scheme) if *scheme == Scheme::HTTPS => true,
236 Some(..) => false,
237 };
238 let authority = if authority.port().is_some() {
239 authority.to_string()
240 } else {
241 let port = if use_tls { 443 } else { 80 };
242 format!("{authority}:{port}")
243 };
244
245 debug!(%authority, use_tls, "Using authority with TLS setting");
246
247 *request.uri_mut() = http::Uri::builder()
249 .path_and_query(
250 request
251 .uri()
252 .path_and_query()
253 .map(|p| p.as_str())
254 .unwrap_or("/"),
255 )
256 .build()
257 .map_err(|err| ErrorCode::InternalError(Some(err.to_string())))?;
258
259 request
261 .headers_mut()
262 .entry(http::header::USER_AGENT)
263 .or_insert(http::header::HeaderValue::from_static(DEFAULT_USER_AGENT));
264
265 debug!(path = %request.uri().path(), "Request URI prepared for sending");
266
267 loop {
268 let mut sender = if use_tls {
269 debug!(%authority, "Establishing HTTPS connection");
270 tokio::time::timeout(
271 connect_timeout,
272 self.conns.connect_https(&self.tls, &authority),
273 )
274 .await
275 } else {
276 debug!(%authority, "Establishing HTTP connection");
277 tokio::time::timeout(connect_timeout, self.conns.connect_http(&authority)).await
278 }
279 .map_err(|_| ErrorCode::ConnectionTimeout)??;
280
281 debug!(
282 uri = ?request.uri(),
283 method = %request.method(),
284 connection_type = if use_tls { "HTTPS" } else { "HTTP" },
285 is_cached = matches!(sender, Cacheable::Hit(..)),
286 "Sending HTTP request"
287 );
288
289 match tokio::time::timeout(first_byte_timeout, sender.try_send_request(request))
290 .instrument(tracing::debug_span!("http_request"))
291 .await
292 .map_err(|_| ErrorCode::ConnectionReadTimeout)?
293 {
294 Err(mut err) => {
295 let req = err.take_message();
296 let err = err.into_error();
297 if let Some(req) = req {
298 if err.is_closed() && matches!(sender, Cacheable::Hit(..)) {
299 debug!(%authority, "Cached connection closed, retrying with a different connection");
300 request = req;
301 continue;
302 }
303 }
304 warn!(?err, %authority, "HTTP request error");
305 return Err(hyper_request_error(err));
306 }
307 Ok(res) => {
308 debug!(%authority, status = %res.status(), "HTTP response received");
309
310 let authority = authority.into_boxed_str();
311 let mut sender = sender.unwrap();
312 if use_tls {
313 let mut https = self.conns.https.write().await;
314 sender.last_seen = std::time::Instant::now();
315 if let Ok(conns) = https.entry(authority).or_default().get_mut() {
316 debug!("Caching HTTPS connection for future use");
317 conns.push_front(sender);
318 }
319 } else {
320 let mut http = self.conns.http.write().await;
321 sender.last_seen = std::time::Instant::now();
322 if let Ok(conns) = http.entry(authority).or_default().get_mut() {
323 debug!("Caching HTTP connection for future use");
324 conns.push_front(sender);
325 }
326 }
327
328 return Ok(res.map(|body| {
329 let (data, trailers, mut errs) = split_outgoing_http_body(body);
330 spawn(
331 async move {
332 while let Some(err) = errs.next().await {
333 error!(?err, "Body error encountered");
334 }
335 trace!("Body processing finished");
336 }
337 .in_current_span(),
338 );
339 StreamBody::new(data.map(Frame::data).map(Ok)).with_trailers(async {
340 trace!("Awaiting trailers");
341 if let Some(trailers) = trailers.await {
342 trace!("Trailers received");
343 match try_fields_to_header_map(trailers) {
344 Ok(headers) => Some(Ok(headers)),
345 Err(err) => {
346 error!(?err, "Failed to parse trailers");
347 None
348 }
349 }
350 } else {
351 trace!("No trailers received");
352 None
353 }
354 })
355 }));
356 }
357 }
358 }
359 }
360 .await)
361 }
362}
363
364impl Provider for HttpClientProvider {}
365
#[cfg(test)]
mod tests {
    use core::net::{Ipv4Addr, SocketAddr};
    use core::sync::atomic::{AtomicUsize, Ordering};

    use std::collections::HashMap;

    use anyhow::{ensure, Context as _};
    use bytes::Bytes;
    use hyper_util::rt::TokioIo;
    use tokio::net::{TcpListener, TcpStream};
    use tokio::spawn;
    use tokio::try_join;
    use tracing::info;

    use super::*;
    use wrpc_interface_http::HttpBody;

    /// Number of requests each test issues.
    const N: usize = 20;

    /// Builds a body-less `POST http://{addr}` request.
    fn new_request(addr: SocketAddr) -> http::Request<HttpBody> {
        let body = HttpBody {
            body: Box::pin(futures::stream::empty()),
            trailers: Box::pin(async { None }),
        };
        http::Request::builder()
            .method(http::Method::POST)
            .uri(format!("http://{addr}"))
            .body(body)
            .expect("failed to construct HTTP POST request")
    }

    /// Request options that use the same `timeout` for every phase.
    fn request_options(timeout: Duration) -> RequestOptions {
        let nanos = timeout.as_nanos() as u64;
        RequestOptions {
            connect_timeout: Some(nanos),
            first_byte_timeout: Some(nanos),
            between_bytes_timeout: Some(nanos),
        }
    }

    /// Serves HTTP/1 on `stream`, answering every request with an empty `200 OK`, until
    /// the connection is closed.
    async fn serve_empty(stream: TcpStream) -> anyhow::Result<()> {
        hyper::server::conn::http1::Builder::new()
            .serve_connection(
                TokioIo::new(stream),
                hyper::service::service_fn(move |_| async {
                    anyhow::Ok(http::Response::new(http_body_util::Empty::<Bytes>::new()))
                }),
            )
            .await
            .context("failed to serve connection")
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    #[test_log(default_log_filter = "trace")]
    async fn test_reuse_conn() -> anyhow::Result<()> {
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await?;
        let addr = listener.local_addr()?;
        // Requests the client has completed; the server stops accepting once it reaches `N`.
        let requests = AtomicUsize::default();

        let server = async {
            let mut conns: usize = 0;
            while requests.load(Ordering::Relaxed) != N {
                info!("accepting stream...");
                let (stream, _) = listener
                    .accept()
                    .await
                    .context("failed to accept connection")?;
                info!(i = conns, "serving connection...");
                serve_empty(stream).await?;
                info!(i = conns, "done serving connection");
                conns = conns.saturating_add(1);
            }
            let reqs = requests.load(Ordering::Relaxed);
            info!(connections = conns, requests = reqs, "server finished");
            // Fewer connections than requests proves at least one connection was reused.
            ensure!(conns < reqs, "connections: {conns}, requests: {reqs}");
            anyhow::Ok(())
        };

        let client = async {
            let provider =
                HttpClientProvider::new(&HashMap::default(), DEFAULT_IDLE_TIMEOUT).await?;
            let options = request_options(Duration::from_secs(10));
            for i in 0..N {
                info!(i, "sending request...");
                let res = provider
                    .handle(None, new_request(addr), Some(options))
                    .await
                    .with_context(|| format!("failed to invoke `handle` for request {i}"))?
                    .with_context(|| format!("failed to handle request {i}"))?;
                requests.store(i.saturating_add(1), Ordering::Relaxed);
                info!(i, "reading response body...");
                let body = res.collect().await?;
                assert_eq!(body.to_bytes(), Bytes::default());
            }
            anyhow::Ok(())
        };

        try_join!(server, client)?;
        Ok(())
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_concurrent_conn() -> anyhow::Result<()> {
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await?;
        let addr = listener.local_addr()?;
        let provider = HttpClientProvider::new(&HashMap::default(), DEFAULT_IDLE_TIMEOUT).await?;

        // Fire every request at once. The server below accepts `N` connections before
        // serving any, so each request has to open its own connection.
        let mut clt = JoinSet::new();
        for i in 0..N {
            let provider = provider.clone();
            clt.spawn(async move {
                info!(i, "sending request...");
                let res = provider
                    .handle(None, new_request(addr), None)
                    .await
                    .with_context(|| format!("failed to invoke `handle` for request {i}"))?
                    .with_context(|| format!("failed to handle request {i}"))?;
                info!(i, "reading response body...");
                let body = res.collect().await?;
                assert_eq!(body.to_bytes(), Bytes::default());
                anyhow::Ok(())
            });
        }

        let mut streams = Vec::with_capacity(N);
        for i in 0..N {
            info!(i, "accepting stream...");
            let (stream, _) = listener
                .accept()
                .await
                .with_context(|| format!("failed to accept connection {i}"))?;
            streams.push(stream);
        }

        let mut srv = JoinSet::new();
        for stream in streams {
            srv.spawn(async move {
                info!("serving connection...");
                serve_empty(stream).await
            });
        }

        while let Some(joined) = clt.join_next().await {
            joined??;
        }
        Ok(())
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_http_error_handling() -> anyhow::Result<()> {
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0)).await?;
        let addr = listener.local_addr()?;
        let provider = HttpClientProvider::new(&HashMap::default(), DEFAULT_IDLE_TIMEOUT).await?;
        let request = new_request(addr);

        // Server: answer the single connection's requests with `500 Internal Server Error`.
        spawn(async move {
            let (stream, _) = listener.accept().await?;
            hyper::server::conn::http1::Builder::new()
                .serve_connection(
                    TokioIo::new(stream),
                    hyper::service::service_fn(move |_| async {
                        anyhow::Ok(
                            http::Response::builder()
                                .status(http::StatusCode::INTERNAL_SERVER_ERROR)
                                .body(http_body_util::Empty::<Bytes>::new())?,
                        )
                    }),
                )
                .await?;
            Ok::<_, anyhow::Error>(())
        });

        // A non-2xx status is still a successful exchange from the client's point of view.
        let outcome = provider.handle(None, request, None).await?;
        assert!(outcome.is_ok());
        let response = outcome?;
        assert_eq!(response.status(), http::StatusCode::INTERNAL_SERVER_ERROR);

        Ok(())
    }
}