wasmcloud_provider_http_server/
lib.rs

1//! The httpserver capability provider allows wasmcloud components to receive
2//! and process http(s) messages from web browsers, command-line tools
3//! such as curl, and other http clients. The server is fully asynchronous,
4//! and built on Rust's high-performance axum library, which is in turn based
5//! on hyper, and can process a large number of simultaneous connections.
6//!
7//! ## Features:
8//!
9//! - HTTP/1 and HTTP/2
10//! - TLS
//! - CORS support (configurable via `allowed_origins`, `allowed_methods`, and
//!   `allowed_headers`). CORS has sensible defaults, so it should
//!   work as-is for development purposes, but may need refinement
//!   for production if a more secure configuration is required.
15//! - All settings can be specified at runtime, using per-component link settings:
16//!   - bind path/address
17//!   - TLS
//!   - CORS
//! - Flexible configuration loading: from the host, or from a local TOML or JSON file.
20//! - Fully asynchronous, using tokio lightweight "green" threads
21//! - Thread pool (for managing a pool of OS threads). The default
22//!   thread pool has one thread per cpu core.
23//!
24
25use core::future::Future;
26use core::pin::Pin;
27use core::str::FromStr as _;
28use core::task::{ready, Context, Poll};
29use core::time::Duration;
30
31use std::net::{SocketAddr, TcpListener};
32
33use anyhow::{anyhow, bail, Context as _};
34use axum::extract;
35use bytes::Bytes;
36use futures::Stream;
37use pin_project_lite::pin_project;
38use tokio::task::JoinHandle;
39use tokio::{spawn, time};
40use tower_http::cors::{self, CorsLayer};
41use tracing::{debug, info, trace};
42use wasmcloud_core::http::{load_settings, ServiceSettings};
43use wasmcloud_provider_sdk::provider::WrpcClient;
44use wasmcloud_provider_sdk::{initialize_observability, load_host_data, run_provider};
45use wrpc_interface_http::InvokeIncomingHandler as _;
46
47mod address;
48mod host;
49mod path;
50
51pub async fn run() -> anyhow::Result<()> {
52    initialize_observability!(
53        "http-server-provider",
54        std::env::var_os("PROVIDER_HTTP_SERVER_FLAMEGRAPH_PATH")
55    );
56
57    let host_data = load_host_data().context("failed to load host data")?;
58    match host_data.config.get("routing_mode").map(String::as_str) {
59        // Run provider in address mode by default
60        Some("address") | None => run_provider(
61            address::HttpServerProvider::new(host_data).context(
62                "failed to create address-mode HTTP server provider from hostdata configuration",
63            )?,
64            "http-server-provider",
65        )
66        .await?
67        .await,
68        // Run provider in path mode
69        Some("path") => {
70            run_provider(
71                path::HttpServerProvider::new(host_data).await.context(
72                    "failed to create path-mode HTTP server provider from hostdata configuration",
73                )?,
74                "http-server-provider",
75            )
76            .await?
77            .await;
78        }
79        Some("host") => {
80            run_provider(
81                host::HttpServerProvider::new(host_data).await.context(
82                    "failed to create host-mode HTTP server provider from hostdata configuration",
83                )?,
84                "http-server-provider",
85            )
86            .await?
87            .await;
88        }
89        Some(other) => bail!("unknown routing_mode: {other}"),
90    };
91
92    Ok(())
93}
94
95/// Build a request to send to the component from the incoming request
96pub(crate) fn build_request(
97    request: extract::Request,
98    scheme: http::uri::Scheme,
99    authority: String,
100    settings: &ServiceSettings,
101) -> Result<http::Request<axum::body::Body>, Box<axum::response::ErrorResponse>> {
102    let method = request.method();
103    if let Some(readonly_mode) = settings.readonly_mode {
104        if readonly_mode
105            && method != http::method::Method::GET
106            && method != http::method::Method::HEAD
107        {
108            debug!("only GET and HEAD allowed in read-only mode");
109            Err(axum::response::ErrorResponse::from((
110                http::StatusCode::METHOD_NOT_ALLOWED,
111                "only GET and HEAD allowed in read-only mode",
112            )))?;
113        }
114    }
115    let (
116        http::request::Parts {
117            method,
118            uri,
119            headers,
120            ..
121        },
122        body,
123    ) = request.into_parts();
124    let http::uri::Parts { path_and_query, .. } = uri.into_parts();
125
126    let mut uri = http::Uri::builder().scheme(scheme);
127    if !authority.is_empty() {
128        uri = uri.authority(authority);
129    }
130    if let Some(path_and_query) = path_and_query {
131        uri = uri.path_and_query(path_and_query);
132    }
133    let uri = uri.build().map_err(|err| {
134        axum::response::ErrorResponse::from((
135            http::StatusCode::INTERNAL_SERVER_ERROR,
136            err.to_string(),
137        ))
138    })?;
139    let mut req = http::Request::builder();
140    *req.headers_mut().ok_or_else(|| {
141        axum::response::ErrorResponse::from((
142            http::StatusCode::INTERNAL_SERVER_ERROR,
143            "invalid request generated",
144        ))
145    })? = headers;
146    let req = req.uri(uri).method(method).body(body).map_err(|err| {
147        axum::response::ErrorResponse::from((
148            http::StatusCode::INTERNAL_SERVER_ERROR,
149            err.to_string(),
150        ))
151    })?;
152
153    Ok(req)
154}
155
/// Invoke a component with the given request
///
/// Sends `req` to the component identified by `target` using the wRPC client's
/// `invoke_handle_http`. The component's response is returned with its body wrapped in a
/// [`ResponseBody`], which also carries the invocation's error stream and background I/O task.
///
/// Error mapping:
///
/// - `408 Request Timeout` if `timeout` is set and elapses before the invocation returns.
/// - `500 Internal Server Error` if the invocation itself fails, if the component returns an
///   error instead of a response, or if `cache_control` is not a valid header value.
///
/// When `cache_control` is set, it is appended as a `Cache-Control` response header.
pub(crate) async fn invoke_component(
    wrpc: &WrpcClient,
    target: &str,
    req: http::Request<axum::body::Body>,
    timeout: Option<Duration>,
    cache_control: Option<&String>,
) -> impl axum::response::IntoResponse {
    // Build the NATS header map sent with the invocation, filled with the trace-context headers
    // produced by `TraceContextInjector` from the incoming request's headers
    let mut cx = async_nats::HeaderMap::new();
    for (k, v) in
        wasmcloud_provider_sdk::wasmcloud_tracing::context::TraceContextInjector::new_with_extractor(
            &wasmcloud_provider_sdk::wasmcloud_tracing::http::HeaderExtractor(req.headers()),
        )
        .iter()
    {
        cx.insert(k.as_str(), v.as_str());
    }

    trace!(?req, component_id = target, "httpserver calling component");
    let fut = wrpc.invoke_handle_http(Some(cx), req);
    // Bound the invocation by the optional timeout; an elapsed timer becomes a 408 response
    let res = if let Some(timeout) = timeout {
        let Ok(res) = time::timeout(timeout, fut).await else {
            Err(http::StatusCode::REQUEST_TIMEOUT)?
        };
        res
    } else {
        fut.await
    };
    // Transport-level failure of the invocation
    let (res, errors, io) =
        res.map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, format!("{err:#}")))?;
    // Drive any outstanding invocation I/O on a background task; `ResponseBody` polls and
    // aborts it
    let io = io.map(spawn);
    // Type-erase the error stream so it fits `ResponseBody::errors`
    let errors: Box<dyn Stream<Item = _> + Send + Unpin> = Box::new(errors);
    // Error returned by the component instead of a response
    // TODO: Convert this to http status code
    let mut res =
        res.map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, format!("{err:?}")))?;
    if let Some(cache_control) = cache_control {
        let cache_control = http::HeaderValue::from_str(cache_control)
            .map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?;
        // `append` keeps any `Cache-Control` value the component already set and adds this
        // one alongside it
        res.headers_mut().append("Cache-Control", cache_control);
    };
    // The explicit `Result<_, ErrorResponse>` type here is what the `?` conversions above
    // (status codes and `(StatusCode, String)` tuples) infer their target error type from
    axum::response::Result::<_, axum::response::ErrorResponse>::Ok(res.map(|body| ResponseBody {
        body,
        errors,
        io,
    }))
}
203
204/// Helper function to construct a [`CorsLayer`] according to the [`ServiceSettings`].
205pub(crate) fn get_cors_layer(settings: &ServiceSettings) -> anyhow::Result<CorsLayer> {
206    let allow_origin = settings.cors_allowed_origins.as_ref();
207    let allow_origin: Vec<_> = allow_origin
208        .map(|origins| {
209            origins
210                .iter()
211                .map(AsRef::as_ref)
212                .map(http::HeaderValue::from_str)
213                .collect::<Result<_, _>>()
214                .context("failed to parse allowed origins")
215        })
216        .transpose()?
217        .unwrap_or_default();
218    let allow_origin = if allow_origin.is_empty() {
219        cors::AllowOrigin::any()
220    } else {
221        cors::AllowOrigin::list(allow_origin)
222    };
223    let allow_headers = settings.cors_allowed_headers.as_ref();
224    let allow_headers: Vec<_> = allow_headers
225        .map(|headers| {
226            headers
227                .iter()
228                .map(AsRef::as_ref)
229                .map(http::HeaderName::from_str)
230                .collect::<Result<_, _>>()
231                .context("failed to parse allowed header names")
232        })
233        .transpose()?
234        .unwrap_or_default();
235    let allow_headers = if allow_headers.is_empty() {
236        cors::AllowHeaders::any()
237    } else {
238        cors::AllowHeaders::list(allow_headers)
239    };
240    let allow_methods = settings.cors_allowed_methods.as_ref();
241    let allow_methods: Vec<_> = allow_methods
242        .map(|methods| {
243            methods
244                .iter()
245                .map(AsRef::as_ref)
246                .map(http::Method::from_str)
247                .collect::<Result<_, _>>()
248                .context("failed to parse allowed methods")
249        })
250        .transpose()?
251        .unwrap_or_default();
252    let allow_methods = if allow_methods.is_empty() {
253        cors::AllowMethods::any()
254    } else {
255        cors::AllowMethods::list(allow_methods)
256    };
257    let expose_headers = settings.cors_exposed_headers.as_ref();
258    let expose_headers: Vec<_> = expose_headers
259        .map(|headers| {
260            headers
261                .iter()
262                .map(AsRef::as_ref)
263                .map(http::HeaderName::from_str)
264                .collect::<Result<_, _>>()
265                .context("failed to parse exposeed header names")
266        })
267        .transpose()?
268        .unwrap_or_default();
269    let expose_headers = if expose_headers.is_empty() {
270        cors::ExposeHeaders::any()
271    } else {
272        cors::ExposeHeaders::list(expose_headers)
273    };
274    let mut cors = CorsLayer::new()
275        .allow_origin(allow_origin)
276        .allow_headers(allow_headers)
277        .allow_methods(allow_methods)
278        .expose_headers(expose_headers);
279    if let Some(max_age) = settings.cors_max_age_secs {
280        cors = cors.max_age(Duration::from_secs(max_age));
281    }
282
283    Ok(cors)
284}
285
286/// Helper function to create and listen on a [`TcpListener`] from the given [`ServiceSettings`].
287///
288/// Note that this function actually calls the `bind` method on the [`TcpSocket`], it's up to the
289/// caller to ensure that the address is not already in use (or to handle the error if it is).
290pub(crate) fn get_tcp_listener(settings: &ServiceSettings) -> anyhow::Result<TcpListener> {
291    let socket = match &settings.address {
292        SocketAddr::V4(_) => tokio::net::TcpSocket::new_v4(),
293        SocketAddr::V6(_) => tokio::net::TcpSocket::new_v6(),
294    }
295    .context("Unable to open socket")?;
296    // Copied this option from
297    // https://github.com/bytecodealliance/wasmtime/blob/05095c18680927ce0cf6c7b468f9569ec4d11bd7/src/commands/serve.rs#L319.
298    // This does increase throughput by 10-15% which is why we're creating the socket. We're
299    // using the tokio one because it exposes the `reuseaddr` option.
300    socket
301        .set_reuseaddr(!cfg!(windows))
302        .context("Error when setting socket to reuseaddr")?;
303    socket
304        .set_nodelay(true)
305        .context("failed to set `TCP_NODELAY`")?;
306
307    match settings.disable_keepalive {
308        Some(false) => {
309            info!("disabling TCP keepalive");
310            socket
311                .set_keepalive(false)
312                .context("failed to disable TCP keepalive")?
313        }
314        None | Some(true) => socket
315            .set_keepalive(true)
316            .context("failed to enable TCP keepalive")?,
317    }
318
319    socket
320        .bind(settings.address)
321        .context("Unable to bind to address")?;
322    let listener = socket.listen(1024).context("unable to listen on socket")?;
323    let listener = listener.into_std().context("Unable to get listener")?;
324
325    Ok(listener)
326}
327
pin_project! {
    // HTTP response body handed back to the client for a component invocation.
    //
    // Bundles the component's body with the error stream and the optional background I/O task
    // returned by the wRPC invocation, so that `poll_frame` can report failures from either as
    // body errors.
    struct ResponseBody {
        // The component's response body
        #[pin]
        body: wrpc_interface_http::HttpBody,
        // Body-processing errors produced alongside the response by the wRPC invocation
        #[pin]
        errors: Box<dyn Stream<Item = wrpc_interface_http::HttpBodyError<axum::Error>> + Send + Unpin>,
        // Handle to the spawned task driving the invocation's async I/O, if there is one
        #[pin]
        io: Option<JoinHandle<anyhow::Result<()>>>,
    }
}
338
339impl http_body::Body for ResponseBody {
340    type Data = Bytes;
341    type Error = anyhow::Error;
342
343    fn poll_frame(
344        mut self: Pin<&mut Self>,
345        cx: &mut Context<'_>,
346    ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
347        let mut this = self.as_mut().project();
348        if let Some(io) = this.io.as_mut().as_pin_mut() {
349            match io.poll(cx) {
350                Poll::Ready(Ok(Ok(()))) => {
351                    this.io.take();
352                }
353                Poll::Ready(Ok(Err(err))) => {
354                    return Poll::Ready(Some(Err(
355                        anyhow!(err).context("failed to complete async I/O")
356                    )))
357                }
358                Poll::Ready(Err(err)) => {
359                    return Poll::Ready(Some(Err(anyhow!(err).context("I/O task failed"))))
360                }
361                Poll::Pending => {}
362            }
363        }
364        match this.errors.poll_next(cx) {
365            Poll::Ready(Some(err)) => {
366                if let Some(io) = this.io.as_pin_mut() {
367                    io.abort();
368                }
369                return Poll::Ready(Some(Err(anyhow!(err).context("failed to process body"))));
370            }
371            Poll::Ready(None) | Poll::Pending => {}
372        }
373        match ready!(this.body.poll_frame(cx)) {
374            Some(Ok(frame)) => Poll::Ready(Some(Ok(frame))),
375            Some(Err(err)) => {
376                if let Some(io) = this.io.as_pin_mut() {
377                    io.abort();
378                }
379                Poll::Ready(Some(Err(err)))
380            }
381            None => {
382                if let Some(io) = this.io.as_pin_mut() {
383                    io.abort();
384                }
385                Poll::Ready(None)
386            }
387        }
388    }
389}
390
#[cfg(test)]
mod test {
    use std::collections::HashMap;

    use anyhow::Result;
    use futures::StreamExt;
    use wasmcloud_provider_sdk::{
        provider::initialize_host_data, run_provider, HostData, InterfaceLinkDefinition,
    };
    use wasmcloud_test_util::testcontainers::{AsyncRunner, NatsServer};

    use crate::{address, path};

    /// Address-mode routing: a request to the provider's listen address is forwarded over NATS
    /// to the linked component. Nothing replies on the subscription, so the link's 100ms
    /// `timeout_ms` elapses and the client receives a 408.
    // This test is ignored by default as it requires a container runtime to be installed
    // to run the testcontainer. In GitHub Actions CI, this only works on `linux`
    #[ignore]
    #[tokio::test]
    async fn can_listen_and_invoke_with_timeout() -> Result<()> {
        let nats_container = NatsServer::default()
            .start()
            .await
            .expect("failed to start nats-server container");
        let nats_port = nats_container
            .get_host_port_ipv4(4222)
            .await
            .expect("should be able to find the NATS port");
        let nats_address = format!("nats://127.0.0.1:{nats_port}");

        let default_address = "0.0.0.0:8080";
        let host_data = HostData {
            lattice_rpc_url: nats_address.clone(),
            lattice_rpc_prefix: "lattice".to_string(),
            provider_key: "http-server-provider-test".to_string(),
            config: std::collections::HashMap::from([
                ("default_address".to_string(), default_address.to_string()),
                ("routing_mode".to_string(), "address".to_string()),
            ]),
            link_definitions: vec![InterfaceLinkDefinition {
                source_id: "http-server-provider-test".to_string(),
                target: "test-component".to_string(),
                name: "default".to_string(),
                wit_namespace: "wasi".to_string(),
                wit_package: "http".to_string(),
                interfaces: vec!["incoming-handler".to_string()],
                // Per-link timeout, short so the un-answered invocation fails quickly
                source_config: std::collections::HashMap::from([(
                    "timeout_ms".to_string(),
                    "100".to_string(),
                )]),
                target_config: HashMap::new(),
                source_secrets: None,
                target_secrets: None,
            }],
            ..Default::default()
        };
        initialize_host_data(host_data.clone()).expect("should be able to initialize host data");

        let provider = run_provider(
            address::HttpServerProvider::new(&host_data)
                .expect("should be able to create provider"),
            "http-server-provider-test",
        )
        .await
        .expect("should be able to run provider");

        // Subscribe to the component's wRPC subjects (without replying) so we can observe
        // that the invocation was sent
        let conn = async_nats::connect(nats_address)
            .await
            .expect("should be able to connect");
        let mut subscriber = conn
            .subscribe("lattice.test-component.wrpc.>")
            .await
            .expect("should be able to subscribe");

        let provider_handle = tokio::spawn(provider);

        // Let the provider have a second to setup the listener
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        let resp = reqwest::get("http://127.0.0.1:8080")
            .await
            .expect("should be able to make request");

        // Should have timed out
        assert_eq!(resp.status(), 408);
        // Ensure component received the message
        let msg = subscriber
            .next()
            .await
            .expect("should be able to get a message");
        assert!(msg.subject.contains("test-component"));
        provider_handle.abort();
        let _ = nats_container.stop().await;

        Ok(())
    }

    /// Path-mode routing: `/foo` goes to `test-component-one` and `/bar` (with or without a
    /// query string) to `test-component-two`. Each un-answered invocation times out with a
    /// 408 (provider-wide `timeout_ms` of 100ms). An unknown path returns 404 without
    /// invoking either component.
    // This test is ignored by default as it requires a container runtime to be installed
    // to run the testcontainer. In GitHub Actions CI, this only works on `linux`
    #[ignore]
    #[tokio::test]
    async fn can_support_path_based_routing() -> Result<()> {
        let nats_container = NatsServer::default()
            .start()
            .await
            .expect("failed to start nats-server container");
        let nats_port = nats_container
            .get_host_port_ipv4(4222)
            .await
            .expect("should be able to find the NATS port");
        let nats_address = format!("nats://127.0.0.1:{nats_port}");

        let default_address = "0.0.0.0:8081";
        let host_data = HostData {
            lattice_rpc_url: nats_address.clone(),
            lattice_rpc_prefix: "lattice".to_string(),
            provider_key: "http-server-provider-test".to_string(),
            config: std::collections::HashMap::from([
                ("default_address".to_string(), default_address.to_string()),
                ("routing_mode".to_string(), "path".to_string()),
                ("timeout_ms".to_string(), "100".to_string()),
            ]),
            link_definitions: vec![
                InterfaceLinkDefinition {
                    source_id: "http-server-provider-test".to_string(),
                    target: "test-component-one".to_string(),
                    name: "default".to_string(),
                    wit_namespace: "wasi".to_string(),
                    wit_package: "http".to_string(),
                    interfaces: vec!["incoming-handler".to_string()],
                    source_config: std::collections::HashMap::from([(
                        "path".to_string(),
                        "/foo".to_string(),
                    )]),
                    target_config: HashMap::new(),
                    source_secrets: None,
                    target_secrets: None,
                },
                InterfaceLinkDefinition {
                    source_id: "http-server-provider-test".to_string(),
                    target: "test-component-two".to_string(),
                    name: "default".to_string(),
                    wit_namespace: "wasi".to_string(),
                    wit_package: "http".to_string(),
                    interfaces: vec!["incoming-handler".to_string()],
                    source_config: std::collections::HashMap::from([(
                        "path".to_string(),
                        "/bar".to_string(),
                    )]),
                    target_config: HashMap::new(),
                    source_secrets: None,
                    target_secrets: None,
                },
            ],
            ..Default::default()
        };
        initialize_host_data(host_data.clone()).expect("should be able to initialize host data");

        let provider = run_provider(
            path::HttpServerProvider::new(&host_data)
                .await
                .expect("should be able to create provider"),
            "http-server-provider-test",
        )
        .await
        .expect("should be able to run provider");

        // Subscribe to each component's wRPC subjects (without replying) so we can observe
        // which component each request was routed to
        let conn = async_nats::connect(nats_address)
            .await
            .expect("should be able to connect");
        let mut subscriber_one = conn
            .subscribe("lattice.test-component-one.wrpc.>")
            .await
            .expect("should be able to subscribe");
        let mut subscriber_two = conn
            .subscribe("lattice.test-component-two.wrpc.>")
            .await
            .expect("should be able to subscribe");

        let provider_handle = tokio::spawn(provider);
        // Let the provider have a second to setup the listeners
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        // Invoke component one
        let resp = reqwest::get("http://127.0.0.1:8081/foo")
            .await
            .expect("should be able to make request");
        // Should have timed out
        assert_eq!(resp.status(), 408);
        let msg = subscriber_one
            .next()
            .await
            .expect("should be able to get a message");
        assert!(msg.subject.contains("test-component-one"));

        // Invoke component two
        let resp = reqwest::get("http://127.0.0.1:8081/bar")
            .await
            .expect("should be able to make request");
        // Should have timed out
        assert_eq!(resp.status(), 408);
        let msg = subscriber_two
            .next()
            .await
            .expect("should be able to get a message");
        assert!(msg.subject.contains("test-component-two"));

        // Invoke component two with a query parameter
        let resp = reqwest::get("http://127.0.0.1:8081/bar?someparam=foo")
            .await
            .expect("should be able to make request");
        // Should have timed out
        assert_eq!(resp.status(), 408);
        let msg = subscriber_two
            .next()
            .await
            .expect("should be able to get a message");
        assert!(msg.subject.contains("test-component-two"));

        // Unknown path should return 404
        let resp = reqwest::get("http://127.0.0.1:8081/some/other/route/idk")
            .await
            .expect("should be able to make request");
        assert_eq!(resp.status(), 404);

        // No other messages should have been received
        // (the assertion is that the operation timed out)
        assert!(
            tokio::time::timeout(tokio::time::Duration::from_secs(1), subscriber_one.next())
                .await
                .is_err(),
        );
        assert!(
            tokio::time::timeout(tokio::time::Duration::from_secs(1), subscriber_two.next())
                .await
                .is_err(),
        );

        provider_handle.abort();
        let _ = nats_container.stop().await;

        Ok(())
    }
}