// wasmcloud_host/wasmbus/providers/http_server/host.rs
use core::net::SocketAddr;

use std::collections::HashMap;
use std::sync::Arc;

use anyhow::{bail, Context as _, Result};
use http::header::HOST;
use http::uri::Scheme;
use http::Uri;
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt as _;
use tokio::time::Instant;
use tokio::{sync::RwLock, task::JoinSet};
use tracing::{debug, error, info_span, instrument, trace_span, warn, Instrument as _, Span};
use wasmcloud_provider_sdk::{LinkConfig, LinkDeleteInfo};
use wasmcloud_tracing::KeyValue;
use wasmtime_wasi_http::bindings::http::types::ErrorCode;
use wrpc_interface_http::ServeIncomingHandlerWasmtime as _;

use crate::wasmbus::{Component, InvocationContext};

use super::listen;

/// Routing table for the host-based HTTP server.
///
/// `hosts` and `components` are written together by `Provider::put_link` and
/// `Provider::delete_link`.
#[derive(Default)]
pub(crate) struct Router {
    /// HTTP host name -> ID of the component that serves requests for it.
    hosts: HashMap<Arc<str>, Arc<str>>,
    /// `(component ID, link name)` -> host name registered by that link.
    components: HashMap<(Arc<str>, Arc<str>), Arc<str>>,
    /// Name of the request header whose value is looked up in `hosts`.
    /// `Provider::new` sets this to the standard `Host` header unless overridden.
    header: String,
}

36pub(crate) struct Provider {
37 #[allow(unused)]
40 pub(crate) handle: JoinSet<()>,
41 pub(crate) host_router: Arc<RwLock<Router>>,
43}
44
45impl wasmcloud_provider_sdk::Provider for Provider {
47 #[instrument(level = "debug", skip_all)]
48 async fn receive_link_config_as_source(&self, link: LinkConfig<'_>) -> Result<()> {
49 self.put_link(link.target_id, link.link_name, link.config)
50 .await
51 }
52
53 #[instrument(level = "debug", skip_all)]
54 async fn delete_link_as_source(&self, info: impl LinkDeleteInfo) -> Result<()> {
55 self.delete_link(
56 info.get_source_id(),
57 info.get_target_id(),
58 info.get_link_name(),
59 )
60 .await
61 }
62}
63
64impl Provider {
65 #[instrument(level = "debug", skip(self))]
66 async fn put_link(
67 &self,
68 target_id: &str,
69 link_name: &str,
70 config: &HashMap<String, String>,
71 ) -> Result<()> {
72 let Some(host) = config.get("host") else {
73 error!(
74 ?config,
75 ?target_id,
76 "host not found in link config, cannot register host"
77 );
78 bail!("host not found in link config, cannot register host for component {target_id}");
79 };
80
81 let target = Arc::from(target_id);
82 let name = Arc::from(link_name);
83 let key = (Arc::clone(&target), Arc::clone(&name));
84
85 let mut router = self.host_router.write().await;
86 if router.components.contains_key(&key) {
87 if router
89 .components
90 .get(&key)
91 .map(|val| **val != *host)
92 .unwrap_or(false)
93 {
94 bail!("Component {target_id} already has a host registered with link name {name}");
96 }
97 }
98 if router.hosts.contains_key(host.as_str()) {
99 if router
101 .hosts
102 .get(host.as_str())
103 .map(|val| *val != target)
104 .unwrap_or(false)
105 {
106 bail!("Host {host} already in use by a different component");
108 }
109 }
110
111 let host = Arc::from(host.clone());
112 router.components.insert(key, Arc::clone(&host));
114 router.hosts.insert(host, target);
115
116 Ok(())
117 }
118
119 #[instrument(level = "debug", skip(self))]
120 async fn delete_link(&self, source_id: &str, target_id: &str, link_name: &str) -> Result<()> {
121 debug!(
122 source = source_id,
123 target = target_id,
124 link = link_name,
125 "deleting http host link"
126 );
127
128 let mut router = self.host_router.write().await;
129 let host = router
130 .components
131 .remove(&(Arc::from(target_id), Arc::from(link_name)));
132 if let Some(host) = host {
133 router.hosts.remove(&host);
134 }
135
136 Ok(())
137 }
138}
139
140impl Provider {
141 pub(crate) async fn new(
142 address: SocketAddr,
143 components: Arc<RwLock<HashMap<String, Arc<Component>>>>,
144 lattice_id: Arc<str>,
145 host_id: Arc<str>,
146 host_header: Option<String>,
147 ) -> Result<Self> {
148 let host_router = Arc::new(RwLock::new(Router {
149 hosts: HashMap::new(),
150 components: HashMap::new(),
151 header: host_header.unwrap_or_else(|| HOST.to_string()),
152 }));
153 let handle = listen(address, {
154 let host_router = Arc::clone(&host_router);
155 move |req: hyper::Request<hyper::body::Incoming>| {
156 let lattice_id = Arc::clone(&lattice_id);
157 let host_id = Arc::clone(&host_id);
158 let components = Arc::clone(&components);
159 let host_router = Arc::clone(&host_router);
160 async move {
161 let (
162 http::request::Parts {
163 method,
164 uri,
165 headers,
166 ..
167 },
168 body,
169 ) = req.into_parts();
170 let http::uri::Parts {
171 scheme,
172 authority,
173 path_and_query,
174 ..
175 } = uri.into_parts();
176
177 let Some(host_header) = headers.get(host_router.read().await.header.as_str())
178 else {
179 warn!("received request with no host header");
180 return build_bad_request_error("missing host header");
181 };
182
183 let Ok(lookup_host) = host_header.to_str() else {
184 warn!("received request with invalid host header");
185 return build_bad_request_error("invalid host header");
186 };
187
188 let mut uri = Uri::builder().scheme(scheme.unwrap_or(Scheme::HTTP));
190 let component = {
191 let component_id = {
192 let router = host_router.read().await;
193 let Some(component_id) = router.hosts.get(lookup_host) else {
194 warn!(host = lookup_host, "received request for unregistered host");
195 return http::Response::builder()
196 .status(404)
197 .body(wasmtime_wasi_http::body::HyperOutgoingBody::new(
198 BoxBody::new(
199 http_body_util::Empty::new()
200 .map_err(|_| ErrorCode::InternalError(None)),
201 ),
202 ))
203 .context("failed to construct missing host error response");
204 };
205 component_id.to_string()
206 };
207
208 let components = components.read().await;
209 let component = components
210 .get(&component_id)
211 .context("linked component not found")?;
212 Arc::clone(component)
213 };
214
215 if let Some(path_and_query) = path_and_query {
216 uri = uri.path_and_query(path_and_query);
217 }
218
219 if let Some(authority) = authority {
220 uri = uri.authority(authority);
221 } else if let Some(authority) = headers.get("X-Forwarded-Host") {
222 uri = uri.authority(authority.as_bytes());
223 } else if let Some(authority) = headers.get(HOST) {
224 uri = uri.authority(authority.as_bytes());
225 }
226
227 let uri = uri.build().context("invalid URI")?;
228 let mut req = http::Request::builder().method(method);
229 *req.headers_mut().expect("headers missing") = headers;
230 let req = req
231 .uri(uri)
232 .body(
233 body.map_err(wasmtime_wasi_http::hyper_response_error)
234 .boxed(),
235 )
236 .context("invalid request")?;
237 let _permit = component
238 .permits
239 .acquire()
240 .instrument(trace_span!("acquire_permit"))
241 .await
242 .context("failed to acquire execution permit")?;
243 let res = component
244 .instantiate(component.handler.copy_for_new(), component.events.clone())
245 .handle(
246 InvocationContext {
247 span: Span::current(),
248 start_at: Instant::now(),
249 attributes: vec![
250 KeyValue::new(
251 "component.ref",
252 Arc::clone(&component.image_reference),
253 ),
254 KeyValue::new("lattice", Arc::clone(&lattice_id)),
255 KeyValue::new("host", Arc::clone(&host_id)),
256 ],
257 },
258 req,
259 )
260 .await?;
261 let res = res?;
262 Ok(res)
263 }
264 .instrument(info_span!("handle"))
265 }
266 })
267 .await
268 .context("failed to listen on address for host based http server")?;
269
270 Ok(Provider {
271 handle,
272 host_router,
273 })
274 }
275}
276
277fn build_bad_request_error(
279 message: &str,
280) -> Result<http::Response<wasmtime_wasi_http::body::HyperOutgoingBody>> {
281 http::Response::builder()
282 .status(http::StatusCode::BAD_REQUEST)
283 .body(wasmtime_wasi_http::body::HyperOutgoingBody::new(
284 BoxBody::new(
285 http_body_util::Full::new(bytes::Bytes::copy_from_slice(message.as_bytes()))
286 .map_err(|_| ErrorCode::InternalError(None)),
287 ),
288 ))
289 .with_context(|| format!("failed to construct host error response: {message}"))
290}
291
#[cfg(test)]
mod test {
    use std::{collections::HashMap, sync::Arc};

    use anyhow::Context as _;
    use tokio::task::JoinSet;

    /// Builds a link config whose only entry is `host`.
    fn host_config(host: &str) -> HashMap<String, String> {
        HashMap::from([("host".to_string(), host.to_string())])
    }

    #[tokio::test]
    async fn can_manage_hosts() -> anyhow::Result<()> {
        let provider = super::Provider {
            handle: JoinSet::new(),
            host_router: Arc::default(),
        };

        // (component ID, host, context attached if registration fails)
        let registrations = [
            ("foo", "foo.com", "should register foo host"),
            ("bar", "bar.com", "should register bar host"),
            ("baz", "baz.com", "should register baz host"),
        ];

        for (component, host, why) in registrations {
            provider
                .put_link(component, "default", &host_config(host))
                .await
                .context(why)?;
        }

        {
            let router = provider.host_router.read().await;
            assert_eq!(router.hosts.len(), 3);
            assert_eq!(router.components.len(), 3);
            for (component, host, _) in registrations {
                assert!(router
                    .hosts
                    .get(host)
                    .is_some_and(|target| &**target == component));
                assert!(router
                    .components
                    .get(&(Arc::from(component), Arc::from("default")))
                    .is_some_and(|h| &**h == host));
            }
        }

        assert!(
            provider
                .put_link("notbaz", "default", &host_config("baz.com"))
                .await
                .is_err(),
            "should fail to register a host that's already registered"
        );
        assert!(
            provider
                .put_link("baz", "default", &host_config("notbaz.com"))
                .await
                .is_err(),
            "should fail to register a host to a component that already has a host"
        );

        for component in ["foo", "bar", "baz"] {
            provider
                .delete_link("builtin", component, "default")
                .await
                .context("should delete link")?;
        }

        {
            let router = provider.host_router.read().await;
            assert!(router.hosts.is_empty());
            assert!(router.components.is_empty());
        }

        Ok(())
    }
}