wasmcloud_provider_http_server/
lib.rs1use core::future::Future;
26use core::pin::Pin;
27use core::str::FromStr as _;
28use core::task::{ready, Context, Poll};
29use core::time::Duration;
30
31use std::net::{SocketAddr, TcpListener};
32
33use anyhow::{anyhow, bail, Context as _};
34use axum::extract;
35use bytes::Bytes;
36use futures::Stream;
37use pin_project_lite::pin_project;
38use tokio::task::JoinHandle;
39use tokio::{spawn, time};
40use tower_http::cors::{self, CorsLayer};
41use tracing::{debug, info, trace};
42use wasmcloud_core::http::{load_settings, ServiceSettings};
43use wasmcloud_provider_sdk::provider::WrpcClient;
44use wasmcloud_provider_sdk::{initialize_observability, load_host_data, run_provider};
45use wrpc_interface_http::InvokeIncomingHandler as _;
46
47mod address;
48mod host;
49mod path;
50
51pub async fn run() -> anyhow::Result<()> {
52 initialize_observability!(
53 "http-server-provider",
54 std::env::var_os("PROVIDER_HTTP_SERVER_FLAMEGRAPH_PATH")
55 );
56
57 let host_data = load_host_data().context("failed to load host data")?;
58 match host_data.config.get("routing_mode").map(String::as_str) {
59 Some("address") | None => run_provider(
61 address::HttpServerProvider::new(host_data).context(
62 "failed to create address-mode HTTP server provider from hostdata configuration",
63 )?,
64 "http-server-provider",
65 )
66 .await?
67 .await,
68 Some("path") => {
70 run_provider(
71 path::HttpServerProvider::new(host_data).await.context(
72 "failed to create path-mode HTTP server provider from hostdata configuration",
73 )?,
74 "http-server-provider",
75 )
76 .await?
77 .await;
78 }
79 Some("host") => {
80 run_provider(
81 host::HttpServerProvider::new(host_data).await.context(
82 "failed to create host-mode HTTP server provider from hostdata configuration",
83 )?,
84 "http-server-provider",
85 )
86 .await?
87 .await;
88 }
89 Some(other) => bail!("unknown routing_mode: {other}"),
90 };
91
92 Ok(())
93}
94
95pub(crate) fn build_request(
97 request: extract::Request,
98 scheme: http::uri::Scheme,
99 authority: String,
100 settings: &ServiceSettings,
101) -> Result<http::Request<axum::body::Body>, Box<axum::response::ErrorResponse>> {
102 let method = request.method();
103 if let Some(readonly_mode) = settings.readonly_mode {
104 if readonly_mode
105 && method != http::method::Method::GET
106 && method != http::method::Method::HEAD
107 {
108 debug!("only GET and HEAD allowed in read-only mode");
109 Err(axum::response::ErrorResponse::from((
110 http::StatusCode::METHOD_NOT_ALLOWED,
111 "only GET and HEAD allowed in read-only mode",
112 )))?;
113 }
114 }
115 let (
116 http::request::Parts {
117 method,
118 uri,
119 headers,
120 ..
121 },
122 body,
123 ) = request.into_parts();
124 let http::uri::Parts { path_and_query, .. } = uri.into_parts();
125
126 let mut uri = http::Uri::builder().scheme(scheme);
127 if !authority.is_empty() {
128 uri = uri.authority(authority);
129 }
130 if let Some(path_and_query) = path_and_query {
131 uri = uri.path_and_query(path_and_query);
132 }
133 let uri = uri.build().map_err(|err| {
134 axum::response::ErrorResponse::from((
135 http::StatusCode::INTERNAL_SERVER_ERROR,
136 err.to_string(),
137 ))
138 })?;
139 let mut req = http::Request::builder();
140 *req.headers_mut().ok_or_else(|| {
141 axum::response::ErrorResponse::from((
142 http::StatusCode::INTERNAL_SERVER_ERROR,
143 "invalid request generated",
144 ))
145 })? = headers;
146 let req = req.uri(uri).method(method).body(body).map_err(|err| {
147 axum::response::ErrorResponse::from((
148 http::StatusCode::INTERNAL_SERVER_ERROR,
149 err.to_string(),
150 ))
151 })?;
152
153 Ok(req)
154}
155
156pub(crate) async fn invoke_component(
158 wrpc: &WrpcClient,
159 target: &str,
160 req: http::Request<axum::body::Body>,
161 timeout: Option<Duration>,
162 cache_control: Option<&String>,
163) -> impl axum::response::IntoResponse {
164 let mut cx = async_nats::HeaderMap::new();
166 for (k, v) in
167 wasmcloud_provider_sdk::wasmcloud_tracing::context::TraceContextInjector::new_with_extractor(
168 &wasmcloud_provider_sdk::wasmcloud_tracing::http::HeaderExtractor(req.headers()),
169 )
170 .iter()
171 {
172 cx.insert(k.as_str(), v.as_str());
173 }
174
175 trace!(?req, component_id = target, "httpserver calling component");
176 let fut = wrpc.invoke_handle_http(Some(cx), req);
177 let res = if let Some(timeout) = timeout {
178 let Ok(res) = time::timeout(timeout, fut).await else {
179 Err(http::StatusCode::REQUEST_TIMEOUT)?
180 };
181 res
182 } else {
183 fut.await
184 };
185 let (res, errors, io) =
186 res.map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, format!("{err:#}")))?;
187 let io = io.map(spawn);
188 let errors: Box<dyn Stream<Item = _> + Send + Unpin> = Box::new(errors);
189 let mut res =
191 res.map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, format!("{err:?}")))?;
192 if let Some(cache_control) = cache_control {
193 let cache_control = http::HeaderValue::from_str(cache_control)
194 .map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?;
195 res.headers_mut().append("Cache-Control", cache_control);
196 };
197 axum::response::Result::<_, axum::response::ErrorResponse>::Ok(res.map(|body| ResponseBody {
198 body,
199 errors,
200 io,
201 }))
202}
203
204pub(crate) fn get_cors_layer(settings: &ServiceSettings) -> anyhow::Result<CorsLayer> {
206 let allow_origin = settings.cors_allowed_origins.as_ref();
207 let allow_origin: Vec<_> = allow_origin
208 .map(|origins| {
209 origins
210 .iter()
211 .map(AsRef::as_ref)
212 .map(http::HeaderValue::from_str)
213 .collect::<Result<_, _>>()
214 .context("failed to parse allowed origins")
215 })
216 .transpose()?
217 .unwrap_or_default();
218 let allow_origin = if allow_origin.is_empty() {
219 cors::AllowOrigin::any()
220 } else {
221 cors::AllowOrigin::list(allow_origin)
222 };
223 let allow_headers = settings.cors_allowed_headers.as_ref();
224 let allow_headers: Vec<_> = allow_headers
225 .map(|headers| {
226 headers
227 .iter()
228 .map(AsRef::as_ref)
229 .map(http::HeaderName::from_str)
230 .collect::<Result<_, _>>()
231 .context("failed to parse allowed header names")
232 })
233 .transpose()?
234 .unwrap_or_default();
235 let allow_headers = if allow_headers.is_empty() {
236 cors::AllowHeaders::any()
237 } else {
238 cors::AllowHeaders::list(allow_headers)
239 };
240 let allow_methods = settings.cors_allowed_methods.as_ref();
241 let allow_methods: Vec<_> = allow_methods
242 .map(|methods| {
243 methods
244 .iter()
245 .map(AsRef::as_ref)
246 .map(http::Method::from_str)
247 .collect::<Result<_, _>>()
248 .context("failed to parse allowed methods")
249 })
250 .transpose()?
251 .unwrap_or_default();
252 let allow_methods = if allow_methods.is_empty() {
253 cors::AllowMethods::any()
254 } else {
255 cors::AllowMethods::list(allow_methods)
256 };
257 let expose_headers = settings.cors_exposed_headers.as_ref();
258 let expose_headers: Vec<_> = expose_headers
259 .map(|headers| {
260 headers
261 .iter()
262 .map(AsRef::as_ref)
263 .map(http::HeaderName::from_str)
264 .collect::<Result<_, _>>()
265 .context("failed to parse exposeed header names")
266 })
267 .transpose()?
268 .unwrap_or_default();
269 let expose_headers = if expose_headers.is_empty() {
270 cors::ExposeHeaders::any()
271 } else {
272 cors::ExposeHeaders::list(expose_headers)
273 };
274 let mut cors = CorsLayer::new()
275 .allow_origin(allow_origin)
276 .allow_headers(allow_headers)
277 .allow_methods(allow_methods)
278 .expose_headers(expose_headers);
279 if let Some(max_age) = settings.cors_max_age_secs {
280 cors = cors.max_age(Duration::from_secs(max_age));
281 }
282
283 Ok(cors)
284}
285
286pub(crate) fn get_tcp_listener(settings: &ServiceSettings) -> anyhow::Result<TcpListener> {
291 let socket = match &settings.address {
292 SocketAddr::V4(_) => tokio::net::TcpSocket::new_v4(),
293 SocketAddr::V6(_) => tokio::net::TcpSocket::new_v6(),
294 }
295 .context("Unable to open socket")?;
296 socket
301 .set_reuseaddr(!cfg!(windows))
302 .context("Error when setting socket to reuseaddr")?;
303 socket
304 .set_nodelay(true)
305 .context("failed to set `TCP_NODELAY`")?;
306
307 match settings.disable_keepalive {
308 Some(false) => {
309 info!("disabling TCP keepalive");
310 socket
311 .set_keepalive(false)
312 .context("failed to disable TCP keepalive")?
313 }
314 None | Some(true) => socket
315 .set_keepalive(true)
316 .context("failed to enable TCP keepalive")?,
317 }
318
319 socket
320 .bind(settings.address)
321 .context("Unable to bind to address")?;
322 let listener = socket.listen(1024).context("unable to listen on socket")?;
323 let listener = listener.into_std().context("Unable to get listener")?;
324
325 Ok(listener)
326}
327
pin_project! {
    /// Response body returned to axum for a component invocation.
    ///
    /// Bundles three things so that polling this body also surfaces failures reported
    /// by the error stream or the I/O task:
    /// - the component's streamed body,
    /// - the invocation's error stream,
    /// - the optional spawned task driving the invocation's async I/O.
    struct ResponseBody {
        // The component's response body; its frames are forwarded unchanged.
        #[pin]
        body: wrpc_interface_http::HttpBody,
        // Errors reported while the body is processed; any item becomes a body error.
        #[pin]
        errors: Box<dyn Stream<Item = wrpc_interface_http::HttpBodyError<axum::Error>> + Send + Unpin>,
        // Spawned async I/O task for the invocation, if any. It is aborted when the body
        // ends or fails, or when an error arrives on `errors`.
        #[pin]
        io: Option<JoinHandle<anyhow::Result<()>>>,
    }
}
338
339impl http_body::Body for ResponseBody {
340 type Data = Bytes;
341 type Error = anyhow::Error;
342
343 fn poll_frame(
344 mut self: Pin<&mut Self>,
345 cx: &mut Context<'_>,
346 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
347 let mut this = self.as_mut().project();
348 if let Some(io) = this.io.as_mut().as_pin_mut() {
349 match io.poll(cx) {
350 Poll::Ready(Ok(Ok(()))) => {
351 this.io.take();
352 }
353 Poll::Ready(Ok(Err(err))) => {
354 return Poll::Ready(Some(Err(
355 anyhow!(err).context("failed to complete async I/O")
356 )))
357 }
358 Poll::Ready(Err(err)) => {
359 return Poll::Ready(Some(Err(anyhow!(err).context("I/O task failed"))))
360 }
361 Poll::Pending => {}
362 }
363 }
364 match this.errors.poll_next(cx) {
365 Poll::Ready(Some(err)) => {
366 if let Some(io) = this.io.as_pin_mut() {
367 io.abort();
368 }
369 return Poll::Ready(Some(Err(anyhow!(err).context("failed to process body"))));
370 }
371 Poll::Ready(None) | Poll::Pending => {}
372 }
373 match ready!(this.body.poll_frame(cx)) {
374 Some(Ok(frame)) => Poll::Ready(Some(Ok(frame))),
375 Some(Err(err)) => {
376 if let Some(io) = this.io.as_pin_mut() {
377 io.abort();
378 }
379 Poll::Ready(Some(Err(err)))
380 }
381 None => {
382 if let Some(io) = this.io.as_pin_mut() {
383 io.abort();
384 }
385 Poll::Ready(None)
386 }
387 }
388 }
389}
390
#[cfg(test)]
mod test {
    use std::collections::HashMap;

    use anyhow::Result;
    use futures::StreamExt;
    use wasmcloud_provider_sdk::{
        provider::initialize_host_data, run_provider, HostData, InterfaceLinkDefinition,
    };
    use wasmcloud_test_util::testcontainers::{AsyncRunner, NatsServer};

    use crate::{address, path};

    /// Provider key, link source ID and provider name shared by both tests.
    const PROVIDER_ID: &str = "http-server-provider-test";

    /// Builds a `wasi:http/incoming-handler` link from the test provider to `target`.
    fn incoming_handler_link(
        target: &str,
        source_config: HashMap<String, String>,
    ) -> InterfaceLinkDefinition {
        InterfaceLinkDefinition {
            source_id: PROVIDER_ID.to_string(),
            target: target.to_string(),
            name: "default".to_string(),
            wit_namespace: "wasi".to_string(),
            wit_package: "http".to_string(),
            interfaces: vec!["incoming-handler".to_string()],
            source_config,
            target_config: HashMap::new(),
            source_secrets: None,
            target_secrets: None,
        }
    }

    /// The address-mode provider forwards a request to the linked component and replies
    /// `408` when no answer arrives within the link's `timeout_ms`.
    #[ignore]
    #[tokio::test]
    async fn can_listen_and_invoke_with_timeout() -> Result<()> {
        let nats = NatsServer::default()
            .start()
            .await
            .expect("failed to start nats-server container");
        let nats_port = nats
            .get_host_port_ipv4(4222)
            .await
            .expect("should be able to find the NATS port");
        let nats_url = format!("nats://127.0.0.1:{nats_port}");

        let host_data = HostData {
            lattice_rpc_url: nats_url.clone(),
            lattice_rpc_prefix: "lattice".to_string(),
            provider_key: PROVIDER_ID.to_string(),
            config: HashMap::from([
                ("default_address".to_string(), "0.0.0.0:8080".to_string()),
                ("routing_mode".to_string(), "address".to_string()),
            ]),
            link_definitions: vec![incoming_handler_link(
                "test-component",
                HashMap::from([("timeout_ms".to_string(), "100".to_string())]),
            )],
            ..Default::default()
        };
        initialize_host_data(host_data.clone()).expect("should be able to initialize host data");

        let provider = address::HttpServerProvider::new(&host_data)
            .expect("should be able to create provider");
        let provider = run_provider(provider, PROVIDER_ID)
            .await
            .expect("should be able to run provider");

        let nats_client = async_nats::connect(nats_url)
            .await
            .expect("should be able to connect");
        let mut invocations = nats_client
            .subscribe("lattice.test-component.wrpc.>")
            .await
            .expect("should be able to subscribe");

        let provider_task = tokio::spawn(provider);
        // Give the provider a moment to start listening.
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        let resp = reqwest::get("http://127.0.0.1:8080")
            .await
            .expect("should be able to make request");
        assert_eq!(resp.status(), 408);

        let msg = invocations
            .next()
            .await
            .expect("should be able to get a message");
        assert!(msg.subject.contains("test-component"));

        provider_task.abort();
        let _ = nats.stop().await;

        Ok(())
    }

    /// The path-mode provider routes each request to the component linked for its path.
    /// A request to an unlinked path gets `404` and reaches no component.
    #[ignore]
    #[tokio::test]
    async fn can_support_path_based_routing() -> Result<()> {
        let nats = NatsServer::default()
            .start()
            .await
            .expect("failed to start nats-server container");
        let nats_port = nats
            .get_host_port_ipv4(4222)
            .await
            .expect("should be able to find the NATS port");
        let nats_url = format!("nats://127.0.0.1:{nats_port}");

        let host_data = HostData {
            lattice_rpc_url: nats_url.clone(),
            lattice_rpc_prefix: "lattice".to_string(),
            provider_key: PROVIDER_ID.to_string(),
            config: HashMap::from([
                ("default_address".to_string(), "0.0.0.0:8081".to_string()),
                ("routing_mode".to_string(), "path".to_string()),
                ("timeout_ms".to_string(), "100".to_string()),
            ]),
            link_definitions: vec![
                incoming_handler_link(
                    "test-component-one",
                    HashMap::from([("path".to_string(), "/foo".to_string())]),
                ),
                incoming_handler_link(
                    "test-component-two",
                    HashMap::from([("path".to_string(), "/bar".to_string())]),
                ),
            ],
            ..Default::default()
        };
        initialize_host_data(host_data.clone()).expect("should be able to initialize host data");

        let provider = path::HttpServerProvider::new(&host_data)
            .await
            .expect("should be able to create provider");
        let provider = run_provider(provider, PROVIDER_ID)
            .await
            .expect("should be able to run provider");

        let nats_client = async_nats::connect(nats_url)
            .await
            .expect("should be able to connect");
        // Index 0 observes component one, index 1 observes component two.
        let mut subscribers = [
            nats_client
                .subscribe("lattice.test-component-one.wrpc.>")
                .await
                .expect("should be able to subscribe"),
            nats_client
                .subscribe("lattice.test-component-two.wrpc.>")
                .await
                .expect("should be able to subscribe"),
        ];

        let provider_task = tokio::spawn(provider);
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        for (route, idx, component) in [
            ("/foo", 0, "test-component-one"),
            ("/bar", 1, "test-component-two"),
            ("/bar?someparam=foo", 1, "test-component-two"),
        ] {
            let resp = reqwest::get(format!("http://127.0.0.1:8081{route}"))
                .await
                .expect("should be able to make request");
            assert_eq!(resp.status(), 408);
            let msg = subscribers[idx]
                .next()
                .await
                .expect("should be able to get a message");
            assert!(msg.subject.contains(component));
        }

        let resp = reqwest::get("http://127.0.0.1:8081/some/other/route/idk")
            .await
            .expect("should be able to make request");
        assert_eq!(resp.status(), 404);

        // The unrouted request must not have reached either component.
        for subscriber in &mut subscribers {
            assert!(
                tokio::time::timeout(tokio::time::Duration::from_secs(1), subscriber.next())
                    .await
                    .is_err(),
            );
        }

        provider_task.abort();
        let _ = nats.stop().await;

        Ok(())
    }
}