wasmcloud_provider_http_server/
host.rs1use core::time::Duration;
7
8use std::collections::HashMap;
9use std::net::SocketAddr;
10use std::str::FromStr;
11use std::sync::Arc;
12
13use anyhow::{bail, Context as _};
14use axum::extract;
15use axum::handler::Handler;
16use axum_server::tls_rustls::RustlsConfig;
17use axum_server::Handle;
18use tokio::sync::RwLock;
19use tokio::task::JoinHandle;
20use tracing::{debug, error, info, instrument};
21use wasmcloud_provider_sdk::provider::WrpcClient;
22use wasmcloud_provider_sdk::{get_connection, HostData, LinkConfig, LinkDeleteInfo, Provider};
23
24use crate::{
25 build_request, get_cors_layer, get_tcp_listener, invoke_component, load_settings,
26 ServiceSettings,
27};
28
29#[derive(Default)]
32struct Router {
33 hosts: HashMap<Arc<str>, (Arc<str>, WrpcClient)>,
35 components: HashMap<(Arc<str>, Arc<str>), Arc<str>>,
37 header: String,
39}
40
41#[derive(Clone)]
43pub struct HttpServerProvider {
44 router: Arc<RwLock<Router>>,
46 handle: Handle,
48 task: Arc<JoinHandle<()>>,
50}
51
52impl Drop for HttpServerProvider {
53 fn drop(&mut self) {
54 self.handle.shutdown();
55 self.task.abort();
56 }
57}
58
59impl HttpServerProvider {
60 pub(crate) async fn new(host_data: &HostData) -> anyhow::Result<Self> {
61 let default_address = host_data
62 .config
63 .get("default_address")
64 .map(|s| SocketAddr::from_str(s))
65 .transpose()
66 .context("failed to parse default_address")?;
67
68 let header = host_data
69 .config
70 .get("header")
71 .map(String::as_str)
72 .unwrap_or("host")
73 .to_lowercase();
74
75 let settings = load_settings(default_address, &host_data.config)
76 .context("failed to load settings in host mode")?;
77 let settings = Arc::new(settings);
78
79 let router = Arc::new(RwLock::new(Router {
80 header: header.to_string(),
81 ..Default::default()
82 }));
83
84 let addr = settings.address;
85 info!(
86 %addr,
87 "httpserver starting listener in host-based mode",
88 );
89 let cors = get_cors_layer(&settings)?;
90 let listener = get_tcp_listener(&settings)?;
91 let service = handle_request.layer(cors);
92
93 let handle = axum_server::Handle::new();
94 let task_handle = handle.clone();
95 let task_router = Arc::clone(&router);
96 let task = if let (Some(crt), Some(key)) =
97 (&settings.tls_cert_file, &settings.tls_priv_key_file)
98 {
99 debug!(?addr, "bind HTTPS listener");
100 let tls = RustlsConfig::from_pem_file(crt, key)
101 .await
102 .context("failed to construct TLS config")?;
103
104 tokio::spawn(async move {
105 if let Err(e) = axum_server::from_tcp_rustls(listener, tls)
106 .handle(task_handle)
107 .serve(
108 service
109 .with_state(RequestContext {
110 router: task_router,
111 scheme: http::uri::Scheme::HTTPS,
112 settings: Arc::clone(&settings),
113 })
114 .into_make_service(),
115 )
116 .await
117 {
118 error!(error = %e, "failed to serve HTTPS for host-based mode");
119 }
120 })
121 } else {
122 debug!(?addr, "bind HTTP listener");
123
124 tokio::spawn(async move {
125 if let Err(e) = axum_server::from_tcp(listener)
126 .handle(task_handle)
127 .serve(
128 service
129 .with_state(RequestContext {
130 router: task_router,
131 scheme: http::uri::Scheme::HTTP,
132 settings: Arc::clone(&settings),
133 })
134 .into_make_service(),
135 )
136 .await
137 {
138 error!(error = %e, "failed to serve HTTP for host-based mode");
139 }
140 })
141 };
142
143 Ok(Self {
144 router,
145 handle,
146 task: Arc::new(task),
147 })
148 }
149}
150
151impl Provider for HttpServerProvider {
152 async fn receive_link_config_as_source(
157 &self,
158 link_config: LinkConfig<'_>,
159 ) -> anyhow::Result<()> {
160 let Some(host) = link_config.config.get("host") else {
161 error!(?link_config.config, ?link_config.target_id, "host not found in link config, cannot register host");
162 bail!(
163 "host not found in link config, cannot register host for component {}",
164 link_config.target_id
165 );
166 };
167
168 let target = Arc::from(link_config.target_id);
169 let name = Arc::from(link_config.link_name);
170
171 let key = (Arc::clone(&target), Arc::clone(&name));
172
173 let mut router = self.router.write().await;
174 if router.components.contains_key(&key) {
175 bail!("Component {target} already has a host registered with link name {name}");
177 }
178 if router.hosts.contains_key(host.as_str()) {
179 bail!("Host {host} already in use by a different component");
181 }
182
183 let wrpc = get_connection()
184 .get_wrpc_client(link_config.target_id)
185 .await
186 .context("failed to construct wRPC client")?;
187
188 let host = Arc::from(host.clone());
189 router.components.insert(key, Arc::clone(&host));
191 router.hosts.insert(host, (target, wrpc));
192
193 Ok(())
194 }
195
196 #[instrument(level = "debug", skip_all, fields(target_id = info.get_target_id()))]
198 async fn delete_link_as_source(&self, info: impl LinkDeleteInfo) -> anyhow::Result<()> {
199 debug!(
200 source = info.get_source_id(),
201 target = info.get_target_id(),
202 link = info.get_link_name(),
203 "deleting http host link"
204 );
205 let component_id = info.get_target_id();
206 let link_name = info.get_link_name();
207
208 let mut router = self.router.write().await;
209 let host = router
210 .components
211 .remove(&(Arc::from(component_id), Arc::from(link_name)));
212 if let Some(host) = host {
213 router.hosts.remove(&host);
214 }
215
216 Ok(())
217 }
218
219 async fn shutdown(&self) -> anyhow::Result<()> {
221 self.handle.shutdown();
222 self.task.abort();
223
224 Ok(())
225 }
226}
227
228#[derive(Clone)]
229struct RequestContext {
230 router: Arc<RwLock<Router>>,
231 scheme: http::uri::Scheme,
232 settings: Arc<ServiceSettings>,
233}
234
235#[instrument(level = "debug", skip(router, settings))]
237async fn handle_request(
238 extract::State(RequestContext {
239 router,
240 scheme,
241 settings,
242 }): extract::State<RequestContext>,
243 axum_extra::extract::Host(authority): axum_extra::extract::Host,
244 request: extract::Request,
245) -> impl axum::response::IntoResponse {
246 let timeout = settings.timeout_ms.map(Duration::from_millis);
247 let req = build_request(request, scheme, authority, &settings).map_err(|err| *err)?;
248
249 let Some(host_header) = req.headers().get(router.read().await.header.as_str()) else {
250 Err((http::StatusCode::BAD_REQUEST, "missing host header"))?
251 };
252
253 let lookup_host = host_header
254 .to_str()
255 .map_err(|_| (http::StatusCode::BAD_REQUEST, "invalid host header"))?;
256
257 let Some((target_component, wrpc)) = router.read().await.hosts.get(lookup_host).cloned() else {
258 Err((http::StatusCode::NOT_FOUND, "host not found"))?
259 };
260
261 axum::response::Result::<_, axum::response::ErrorResponse>::Ok(
262 invoke_component(
263 &wrpc,
264 &target_component,
265 req,
266 timeout,
267 settings.cache_control.as_ref(),
268 )
269 .await,
270 )
271}