wasmcloud_provider_http_server/
host.rs

1//! This module contains the implementation of the `wrpc:http/incoming-handler` provider in host-based mode.
2//!
3//! In host-based mode, the HTTP server listens on a single address and routes requests to different components
4//! based on the host of the request.
5
6use core::time::Duration;
7
8use std::collections::HashMap;
9use std::net::SocketAddr;
10use std::str::FromStr;
11use std::sync::Arc;
12
13use anyhow::{bail, Context as _};
14use axum::extract;
15use axum::handler::Handler;
16use axum_server::tls_rustls::RustlsConfig;
17use axum_server::Handle;
18use tokio::sync::RwLock;
19use tokio::task::JoinHandle;
20use tracing::{debug, error, info, instrument};
21use wasmcloud_provider_sdk::provider::WrpcClient;
22use wasmcloud_provider_sdk::{get_connection, HostData, LinkConfig, LinkDeleteInfo, Provider};
23
24use crate::{
25    build_request, get_cors_layer, get_tcp_listener, invoke_component, load_settings,
26    ServiceSettings,
27};
28
29/// This struct holds both the forward and reverse mappings for host-based routing
30/// so that they can be modified by just acquiring a single lock in the [`HttpServerProvider`]
31#[derive(Default)]
32struct Router {
33    /// Lookup from a host to the component ID that is handling that host
34    hosts: HashMap<Arc<str>, (Arc<str>, WrpcClient)>,
35    /// Reverse lookup to find the host for a (component,link_name) pair
36    components: HashMap<(Arc<str>, Arc<str>), Arc<str>>,
37    /// Header to match for host-based routing
38    header: String,
39}
40
41/// `wrpc:http/incoming-handler` provider implementation with host-based routing
42#[derive(Clone)]
43pub struct HttpServerProvider {
44    /// Struct that holds the routing information based on host/component_id
45    router: Arc<RwLock<Router>>,
46    /// [`Handle`] to the server task
47    handle: Handle,
48    /// Task handle for the server task
49    task: Arc<JoinHandle<()>>,
50}
51
52impl Drop for HttpServerProvider {
53    fn drop(&mut self) {
54        self.handle.shutdown();
55        self.task.abort();
56    }
57}
58
59impl HttpServerProvider {
60    pub(crate) async fn new(host_data: &HostData) -> anyhow::Result<Self> {
61        let default_address = host_data
62            .config
63            .get("default_address")
64            .map(|s| SocketAddr::from_str(s))
65            .transpose()
66            .context("failed to parse default_address")?;
67
68        let header = host_data
69            .config
70            .get("header")
71            .map(String::as_str)
72            .unwrap_or("host")
73            .to_lowercase();
74
75        let settings = load_settings(default_address, &host_data.config)
76            .context("failed to load settings in host mode")?;
77        let settings = Arc::new(settings);
78
79        let router = Arc::new(RwLock::new(Router {
80            header: header.to_string(),
81            ..Default::default()
82        }));
83
84        let addr = settings.address;
85        info!(
86            %addr,
87            "httpserver starting listener in host-based mode",
88        );
89        let cors = get_cors_layer(&settings)?;
90        let listener = get_tcp_listener(&settings)?;
91        let service = handle_request.layer(cors);
92
93        let handle = axum_server::Handle::new();
94        let task_handle = handle.clone();
95        let task_router = Arc::clone(&router);
96        let task = if let (Some(crt), Some(key)) =
97            (&settings.tls_cert_file, &settings.tls_priv_key_file)
98        {
99            debug!(?addr, "bind HTTPS listener");
100            let tls = RustlsConfig::from_pem_file(crt, key)
101                .await
102                .context("failed to construct TLS config")?;
103
104            tokio::spawn(async move {
105                if let Err(e) = axum_server::from_tcp_rustls(listener, tls)
106                    .handle(task_handle)
107                    .serve(
108                        service
109                            .with_state(RequestContext {
110                                router: task_router,
111                                scheme: http::uri::Scheme::HTTPS,
112                                settings: Arc::clone(&settings),
113                            })
114                            .into_make_service(),
115                    )
116                    .await
117                {
118                    error!(error = %e, "failed to serve HTTPS for host-based mode");
119                }
120            })
121        } else {
122            debug!(?addr, "bind HTTP listener");
123
124            tokio::spawn(async move {
125                if let Err(e) = axum_server::from_tcp(listener)
126                    .handle(task_handle)
127                    .serve(
128                        service
129                            .with_state(RequestContext {
130                                router: task_router,
131                                scheme: http::uri::Scheme::HTTP,
132                                settings: Arc::clone(&settings),
133                            })
134                            .into_make_service(),
135                    )
136                    .await
137                {
138                    error!(error = %e, "failed to serve HTTP for host-based mode");
139                }
140            })
141        };
142
143        Ok(Self {
144            router,
145            handle,
146            task: Arc::new(task),
147        })
148    }
149}
150
151impl Provider for HttpServerProvider {
152    /// This is called when the HTTP server provider is linked to a component
153    ///
154    /// This HTTP server mode will register the host in the link for routing to the target
155    /// component when a request is received on the listen address.
156    async fn receive_link_config_as_source(
157        &self,
158        link_config: LinkConfig<'_>,
159    ) -> anyhow::Result<()> {
160        let Some(host) = link_config.config.get("host") else {
161            error!(?link_config.config, ?link_config.target_id, "host not found in link config, cannot register host");
162            bail!(
163                "host not found in link config, cannot register host for component {}",
164                link_config.target_id
165            );
166        };
167
168        let target = Arc::from(link_config.target_id);
169        let name = Arc::from(link_config.link_name);
170
171        let key = (Arc::clone(&target), Arc::clone(&name));
172
173        let mut router = self.router.write().await;
174        if router.components.contains_key(&key) {
175            // When we can return errors from links, tell the host this was invalid
176            bail!("Component {target} already has a host registered with link name {name}");
177        }
178        if router.hosts.contains_key(host.as_str()) {
179            // When we can return errors from links, tell the host this was invalid
180            bail!("Host {host} already in use by a different component");
181        }
182
183        let wrpc = get_connection()
184            .get_wrpc_client(link_config.target_id)
185            .await
186            .context("failed to construct wRPC client")?;
187
188        let host = Arc::from(host.clone());
189        // Insert the host into the hosts map for future lookups
190        router.components.insert(key, Arc::clone(&host));
191        router.hosts.insert(host, (target, wrpc));
192
193        Ok(())
194    }
195
196    /// Remove the host for a particular component/link_name pair
197    #[instrument(level = "debug", skip_all, fields(target_id = info.get_target_id()))]
198    async fn delete_link_as_source(&self, info: impl LinkDeleteInfo) -> anyhow::Result<()> {
199        debug!(
200            source = info.get_source_id(),
201            target = info.get_target_id(),
202            link = info.get_link_name(),
203            "deleting http host link"
204        );
205        let component_id = info.get_target_id();
206        let link_name = info.get_link_name();
207
208        let mut router = self.router.write().await;
209        let host = router
210            .components
211            .remove(&(Arc::from(component_id), Arc::from(link_name)));
212        if let Some(host) = host {
213            router.hosts.remove(&host);
214        }
215
216        Ok(())
217    }
218
219    /// Handle shutdown request by shutting down the http server task
220    async fn shutdown(&self) -> anyhow::Result<()> {
221        self.handle.shutdown();
222        self.task.abort();
223
224        Ok(())
225    }
226}
227
228#[derive(Clone)]
229struct RequestContext {
230    router: Arc<RwLock<Router>>,
231    scheme: http::uri::Scheme,
232    settings: Arc<ServiceSettings>,
233}
234
235/// Handle an HTTP request by looking up the component ID for the host and invoking the component
236#[instrument(level = "debug", skip(router, settings))]
237async fn handle_request(
238    extract::State(RequestContext {
239        router,
240        scheme,
241        settings,
242    }): extract::State<RequestContext>,
243    axum_extra::extract::Host(authority): axum_extra::extract::Host,
244    request: extract::Request,
245) -> impl axum::response::IntoResponse {
246    let timeout = settings.timeout_ms.map(Duration::from_millis);
247    let req = build_request(request, scheme, authority, &settings).map_err(|err| *err)?;
248
249    let Some(host_header) = req.headers().get(router.read().await.header.as_str()) else {
250        Err((http::StatusCode::BAD_REQUEST, "missing host header"))?
251    };
252
253    let lookup_host = host_header
254        .to_str()
255        .map_err(|_| (http::StatusCode::BAD_REQUEST, "invalid host header"))?;
256
257    let Some((target_component, wrpc)) = router.read().await.hosts.get(lookup_host).cloned() else {
258        Err((http::StatusCode::NOT_FOUND, "host not found"))?
259    };
260
261    axum::response::Result::<_, axum::response::ErrorResponse>::Ok(
262        invoke_component(
263            &wrpc,
264            &target_component,
265            req,
266            timeout,
267            settings.cache_control.as_ref(),
268        )
269        .await,
270    )
271}