wasmcloud_host/wasmbus/providers/http_server/host.rs

1use core::net::SocketAddr;
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use anyhow::{bail, Context as _, Result};
7use http::header::HOST;
8use http::uri::Scheme;
9use http::Uri;
10use http_body_util::combinators::BoxBody;
11use http_body_util::BodyExt as _;
12use tokio::time::Instant;
13use tokio::{sync::RwLock, task::JoinSet};
14use tracing::{debug, error, info_span, instrument, trace_span, warn, Instrument as _, Span};
15use wasmcloud_provider_sdk::{LinkConfig, LinkDeleteInfo};
16use wasmcloud_tracing::KeyValue;
17use wasmtime_wasi_http::bindings::http::types::ErrorCode;
18use wrpc_interface_http::ServeIncomingHandlerWasmtime as _;
19
20use crate::wasmbus::{Component, InvocationContext};
21
22use super::listen;
23
/// Routing table used to dispatch incoming HTTP requests to components by host.
///
/// Both the forward (host -> component) and reverse ((component, link name) -> host)
/// mappings live in this one struct so they can be modified together while holding the
/// single lock in [`Provider::host_router`].
#[derive(Default)]
pub(crate) struct Router {
    /// Forward lookup: host value -> ID of the component that handles that host
    hosts: HashMap<Arc<str>, Arc<str>>,
    /// Reverse lookup: (component ID, link name) -> host registered by that link
    components: HashMap<(Arc<str>, Arc<str>), Arc<str>>,
    /// Name of the request header whose value is matched against `hosts`
    header: String,
}
35
36pub(crate) struct Provider {
37    /// Handle to the server task. The use of the [`JoinSet`] allows for the server to be
38    /// gracefully shutdown when the provider is shutdown
39    #[allow(unused)]
40    pub(crate) handle: JoinSet<()>,
41    /// Struct that holds the routing information based on host/component_id
42    pub(crate) host_router: Arc<RwLock<Router>>,
43}
44
45// Implementations of put and delete link are done in the `impl Provider` block to aid in testing
46impl wasmcloud_provider_sdk::Provider for Provider {
47    #[instrument(level = "debug", skip_all)]
48    async fn receive_link_config_as_source(&self, link: LinkConfig<'_>) -> Result<()> {
49        self.put_link(link.target_id, link.link_name, link.config)
50            .await
51    }
52
53    #[instrument(level = "debug", skip_all)]
54    async fn delete_link_as_source(&self, info: impl LinkDeleteInfo) -> Result<()> {
55        self.delete_link(
56            info.get_source_id(),
57            info.get_target_id(),
58            info.get_link_name(),
59        )
60        .await
61    }
62}
63
64impl Provider {
65    #[instrument(level = "debug", skip(self))]
66    async fn put_link(
67        &self,
68        target_id: &str,
69        link_name: &str,
70        config: &HashMap<String, String>,
71    ) -> Result<()> {
72        let Some(host) = config.get("host") else {
73            error!(
74                ?config,
75                ?target_id,
76                "host not found in link config, cannot register host"
77            );
78            bail!("host not found in link config, cannot register host for component {target_id}");
79        };
80
81        let target = Arc::from(target_id);
82        let name = Arc::from(link_name);
83        let key = (Arc::clone(&target), Arc::clone(&name));
84
85        let mut router = self.host_router.write().await;
86        if router.components.contains_key(&key) {
87            // Ensure the current host doesn't differ for the given component
88            if router
89                .components
90                .get(&key)
91                .map(|val| **val != *host)
92                .unwrap_or(false)
93            {
94                // When we can return errors from links, tell the host this was invalid
95                bail!("Component {target_id} already has a host registered with link name {name}");
96            }
97        }
98        if router.hosts.contains_key(host.as_str()) {
99            // Ensure the current component doesn't differ for the given host
100            if router
101                .hosts
102                .get(host.as_str())
103                .map(|val| *val != target)
104                .unwrap_or(false)
105            {
106                // When we can return errors from links, tell the host this was invalid
107                bail!("Host {host} already in use by a different component");
108            }
109        }
110
111        let host = Arc::from(host.clone());
112        // Insert the host into the hosts map for future lookups
113        router.components.insert(key, Arc::clone(&host));
114        router.hosts.insert(host, target);
115
116        Ok(())
117    }
118
119    #[instrument(level = "debug", skip(self))]
120    async fn delete_link(&self, source_id: &str, target_id: &str, link_name: &str) -> Result<()> {
121        debug!(
122            source = source_id,
123            target = target_id,
124            link = link_name,
125            "deleting http host link"
126        );
127
128        let mut router = self.host_router.write().await;
129        let host = router
130            .components
131            .remove(&(Arc::from(target_id), Arc::from(link_name)));
132        if let Some(host) = host {
133            router.hosts.remove(&host);
134        }
135
136        Ok(())
137    }
138}
139
impl Provider {
    /// Starts the host-routed HTTP server and returns a [`Provider`] that owns it.
    ///
    /// # Arguments
    ///
    /// - `address`: socket address handed to `listen` for the server
    /// - `components`: shared map of component ID -> [`Component`], read on every request to
    ///   resolve the component ID chosen by the router
    /// - `lattice_id`, `host_id`: attached as the `lattice` / `host` attributes of each invocation
    /// - `host_header`: name of the request header used for routing; when `None`, the standard
    ///   `host` header is used
    ///
    /// For each request, the routing header value is looked up in the [`Router`]:
    /// - header missing -> `400 Bad Request` with body `missing host header`
    /// - header value not convertible to a string (`HeaderValue::to_str` fails) ->
    ///   `400 Bad Request` with body `invalid host header`
    /// - header value not registered -> `404` with an empty body
    /// - otherwise the request is forwarded to the registered component's HTTP handler
    ///
    /// # Errors
    ///
    /// Returns an error if `listen` fails for `address`.
    pub(crate) async fn new(
        address: SocketAddr,
        components: Arc<RwLock<HashMap<String, Arc<Component>>>>,
        lattice_id: Arc<str>,
        host_id: Arc<str>,
        host_header: Option<String>,
    ) -> Result<Self> {
        let host_router = Arc::new(RwLock::new(Router {
            hosts: HashMap::new(),
            components: HashMap::new(),
            header: host_header.unwrap_or_else(|| HOST.to_string()),
        }));
        let handle = listen(address, {
            // The request handler gets its own handle to the router; link updates made
            // through `Provider::host_router` are visible to it.
            let host_router = Arc::clone(&host_router);
            move |req: hyper::Request<hyper::body::Incoming>| {
                // Clone the shared handles per request so the returned future owns them.
                let lattice_id = Arc::clone(&lattice_id);
                let host_id = Arc::clone(&host_id);
                let components = Arc::clone(&components);
                let host_router = Arc::clone(&host_router);
                async move {
                    let (
                        http::request::Parts {
                            method,
                            uri,
                            headers,
                            ..
                        },
                        body,
                    ) = req.into_parts();
                    let http::uri::Parts {
                        scheme,
                        authority,
                        path_and_query,
                        ..
                    } = uri.into_parts();

                    // Route on the configured header (defaults to `host`), not necessarily `Host`.
                    let Some(host_header) = headers.get(host_router.read().await.header.as_str())
                    else {
                        warn!("received request with no host header");
                        return build_bad_request_error("missing host header");
                    };

                    let Ok(lookup_host) = host_header.to_str() else {
                        warn!("received request with invalid host header");
                        return build_bad_request_error("invalid host header");
                    };

                    // TODO(#3705): Propagate trace context from headers
                    // Incoming request URIs may carry no scheme; default to plain `http`.
                    let mut uri = Uri::builder().scheme(scheme.unwrap_or(Scheme::HTTP));
                    let component = {
                        // Copy the ID out in an inner scope so the router read lock is
                        // released before the components lock is taken.
                        let component_id = {
                            let router = host_router.read().await;
                            let Some(component_id) = router.hosts.get(lookup_host) else {
                                warn!(host = lookup_host, "received request for unregistered host");
                                // No link registered this host: reply 404 with an empty body.
                                return http::Response::builder()
                                    .status(404)
                                    .body(wasmtime_wasi_http::body::HyperOutgoingBody::new(
                                        BoxBody::new(
                                            http_body_util::Empty::new()
                                                .map_err(|_| ErrorCode::InternalError(None)),
                                        ),
                                    ))
                                    .context("failed to construct missing host error response");
                            };
                            component_id.to_string()
                        };

                        // A routed host whose component is not in `components` yields an
                        // `Err` from the handler (not an HTTP response).
                        // NOTE(review): how `listen` reports that to the client is not visible here.
                        let components = components.read().await;
                        let component = components
                            .get(&component_id)
                            .context("linked component not found")?;
                        Arc::clone(component)
                    };

                    if let Some(path_and_query) = path_and_query {
                        uri = uri.path_and_query(path_and_query);
                    }

                    // Authority precedence: request-target authority, then `X-Forwarded-Host`,
                    // then the standard `Host` header (not the configured routing header).
                    // NOTE(review): if none is present, `build()` below should fail because a
                    // scheme is always set without an authority — confirm this is intended.
                    if let Some(authority) = authority {
                        uri = uri.authority(authority);
                    } else if let Some(authority) = headers.get("X-Forwarded-Host") {
                        uri = uri.authority(authority.as_bytes());
                    } else if let Some(authority) = headers.get(HOST) {
                        uri = uri.authority(authority.as_bytes());
                    }

                    let uri = uri.build().context("invalid URI")?;
                    let mut req = http::Request::builder().method(method);
                    // `headers_mut` is `None` only if the builder already holds an error;
                    // `method` is already a `Method`, so no error can have occurred yet.
                    *req.headers_mut().expect("headers missing") = headers;
                    let req = req
                        .uri(uri)
                        .body(
                            body.map_err(wasmtime_wasi_http::hyper_response_error)
                                .boxed(),
                        )
                        .context("invalid request")?;
                    // Bound to `_permit` (not `_`) so it stays held until this future completes.
                    // NOTE(review): presumably a semaphore limiting concurrent invocations
                    // of the component — confirm against `Component::permits`.
                    let _permit = component
                        .permits
                        .acquire()
                        .instrument(trace_span!("acquire_permit"))
                        .await
                        .context("failed to acquire execution permit")?;
                    // The first `?` propagates the error from `handle` itself; the second
                    // unwraps the nested result it yields.
                    let res = component
                        .instantiate(component.handler.copy_for_new(), component.events.clone())
                        .handle(
                            InvocationContext {
                                span: Span::current(),
                                start_at: Instant::now(),
                                attributes: vec![
                                    KeyValue::new(
                                        "component.ref",
                                        Arc::clone(&component.image_reference),
                                    ),
                                    KeyValue::new("lattice", Arc::clone(&lattice_id)),
                                    KeyValue::new("host", Arc::clone(&host_id)),
                                ],
                            },
                            req,
                        )
                        .await?;
                    let res = res?;
                    Ok(res)
                }
                .instrument(info_span!("handle"))
            }
        })
        .await
        .context("failed to listen on address for host based http server")?;

        Ok(Provider {
            handle,
            host_router,
        })
    }
}
276
277/// Build a bad request error
278fn build_bad_request_error(
279    message: &str,
280) -> Result<http::Response<wasmtime_wasi_http::body::HyperOutgoingBody>> {
281    http::Response::builder()
282        .status(http::StatusCode::BAD_REQUEST)
283        .body(wasmtime_wasi_http::body::HyperOutgoingBody::new(
284            BoxBody::new(
285                http_body_util::Full::new(bytes::Bytes::copy_from_slice(message.as_bytes()))
286                    .map_err(|_| ErrorCode::InternalError(None)),
287            ),
288        ))
289        .with_context(|| format!("failed to construct host error response: {message}"))
290}
291
#[cfg(test)]
mod test {
    use std::{collections::HashMap, sync::Arc};

    use anyhow::Context as _;
    use tokio::task::JoinSet;

    /// Builds a link config that asks for `host` to be routed to the linked component.
    fn host_config(host: &str) -> HashMap<String, String> {
        HashMap::from([("host".to_string(), host.to_string())])
    }

    /// Ensure we can register and deregister a bunch of hosts properly
    #[tokio::test]
    async fn can_manage_hosts() -> anyhow::Result<()> {
        let provider = super::Provider {
            handle: JoinSet::new(),
            host_router: Arc::default(),
        };

        // (component ID, host) pairs registered on the "default" link name
        let registrations = [("foo", "foo.com"), ("bar", "bar.com"), ("baz", "baz.com")];

        for (component, host) in registrations {
            provider
                .put_link(component, "default", &host_config(host))
                .await
                .with_context(|| format!("should register {component} host"))?;
        }

        {
            let router = provider.host_router.read().await;
            assert_eq!(router.hosts.len(), registrations.len());
            assert_eq!(router.components.len(), registrations.len());
            for (component, host) in registrations {
                assert!(router
                    .hosts
                    .get(host)
                    .is_some_and(|target| &**target == component));
                assert!(router
                    .components
                    .get(&(Arc::from(component), Arc::from("default")))
                    .is_some_and(|h| &**h == host));
            }
        }

        // Rejecting reserved hosts / linked components
        assert!(
            provider
                .put_link("notbaz", "default", &host_config("baz.com"))
                .await
                .is_err(),
            "should fail to register a host that's already registered"
        );
        assert!(
            provider
                .put_link("baz", "default", &host_config("notbaz.com"))
                .await
                .is_err(),
            "should fail to register a host to a component that already has a host"
        );

        // Delete every registration; both lookup tables should end up empty
        for (component, _) in registrations {
            provider
                .delete_link("builtin", component, "default")
                .await
                .context("should delete link")?;
        }
        {
            let router = provider.host_router.read().await;
            assert!(router.hosts.is_empty());
            assert!(router.components.is_empty());
        }

        Ok(())
    }
}