wasmcloud_provider_http_server/
address.rs

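//! Implementation of the `wrpc:http/incoming-handler` provider in address mode
//!
//! This provider listens on the address configured for each link it receives. Links that resolve
//! to an address that is already being listened on share that listener, and requests on it are
//! routed to the first component registered for the address.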
4
5use core::str::FromStr as _;
6use core::time::Duration;
7
8use std::collections::HashMap;
9use std::net::SocketAddr;
10use std::sync::Arc;
11
12use anyhow::{bail, Context as _};
13use axum::extract;
14use axum::handler::Handler;
15use axum_server::tls_rustls::RustlsConfig;
16use tokio::sync::RwLock;
17use tracing::{debug, error, info, instrument};
18use wasmcloud_core::http::{default_listen_address, load_settings, ServiceSettings};
19use wasmcloud_provider_sdk::core::LinkName;
20use wasmcloud_provider_sdk::provider::WrpcClient;
21use wasmcloud_provider_sdk::{get_connection, HostData, LinkConfig, LinkDeleteInfo, Provider};
22
23use crate::{build_request, get_cors_layer, get_tcp_listener, invoke_component};
24
/// Lookup for handlers by socket
///
/// Indexed first by socket address to more easily detect duplicates,
/// with the http server stored, along with a list (order matters) of components that were registered
///
/// Each registered component is stored as `(component id, link name, wRPC client)`.
/// Requests arriving on an address are served by the first component in its list.
type HandlerLookup =
    HashMap<SocketAddr, (Arc<HttpServerCore>, Vec<(Arc<str>, Arc<str>, WrpcClient)>)>;
31
32/// `wrpc:http/incoming-handler` provider implementation in address mode
33#[derive(Clone)]
34pub struct HttpServerProvider {
35    default_address: SocketAddr,
36
37    /// Lookup of components that handle requests {addr -> (server, (component id, link name))}
38    handlers_by_socket: Arc<RwLock<HandlerLookup>>,
39
40    /// Sockets that are relevant to a given link name
41    ///
42    /// This structure is generally used as a look up into `handlers_by_socket`
43    sockets_by_link_name: Arc<RwLock<HashMap<LinkName, SocketAddr>>>,
44}
45
46impl Default for HttpServerProvider {
47    fn default() -> Self {
48        Self {
49            default_address: default_listen_address(),
50            handlers_by_socket: Arc::default(),
51            sockets_by_link_name: Arc::default(),
52        }
53    }
54}
55
56impl HttpServerProvider {
57    /// Create a new instance of the HTTP server provider
58    pub fn new(host_data: &HostData) -> anyhow::Result<Self> {
59        let default_address = host_data
60            .config
61            .get("default_address")
62            .map(|s| SocketAddr::from_str(s))
63            .transpose()
64            .context("failed to parse default_address")?
65            .unwrap_or_else(default_listen_address);
66
67        Ok(Self {
68            default_address,
69            handlers_by_socket: Arc::default(),
70            sockets_by_link_name: Arc::default(),
71        })
72    }
73}
74
75impl Provider for HttpServerProvider {
76    /// This is called when the HTTP server provider is linked to a component
77    ///
78    /// This HTTP server mode will listen on a new address for each component that it links to.
79    async fn receive_link_config_as_source(
80        &self,
81        link_config: LinkConfig<'_>,
82    ) -> anyhow::Result<()> {
83        let settings = match load_settings(Some(self.default_address), link_config.config)
84            .context("httpserver failed to load settings for component")
85        {
86            Ok(settings) => settings,
87            Err(e) => {
88                error!(
89                    config = ?link_config.config,
90                    "httpserver failed to load settings for component: {}", e.to_string()
91                );
92                bail!(e);
93            }
94        };
95
96        let wrpc = get_connection()
97            .get_wrpc_client(link_config.target_id)
98            .await
99            .context("failed to construct wRPC client")?;
100        let component_meta = (
101            Arc::from(link_config.target_id),
102            Arc::from(link_config.link_name),
103            wrpc,
104        );
105        let mut sockets_by_link_name = self.sockets_by_link_name.write().await;
106        let mut handlers_by_socket = self.handlers_by_socket.write().await;
107
108        match sockets_by_link_name.entry(link_config.link_name.to_string()) {
109            // If a mapping already exists, and the stored address is different, disallow overwriting
110            std::collections::hash_map::Entry::Occupied(v) => {
111                bail!(
112                    "an address mapping for address [{}] the link [{}] already exists, overwriting links is not currently supported",
113                    v.get().ip(),
114                    link_config.link_name,
115                )
116            }
117            // If a mapping does exist, we can create a new mapping for the address
118            std::collections::hash_map::Entry::Vacant(v) => {
119                v.insert(settings.address);
120            }
121        }
122
123        match handlers_by_socket.entry(settings.address) {
124            // If handlers already exist for the address, add the newly linked component
125            //
126            // NOTE: only components at the head of the list are served requests
127            std::collections::hash_map::Entry::Occupied(mut v) => {
128                v.get_mut().1.push(component_meta);
129            }
130            // If a handler does not already exist, make a new server and insert
131            std::collections::hash_map::Entry::Vacant(v) => {
132                // Start a server instance that calls the given component
133                let http_server = match HttpServerCore::new(
134                    Arc::new(settings),
135                    link_config.target_id,
136                    self.handlers_by_socket.clone(),
137                )
138                .await
139                {
140                    Ok(s) => s,
141                    Err(e) => {
142                        error!("failed to start listener for component: {e:?}");
143                        bail!(e);
144                    }
145                };
146                v.insert((Arc::new(http_server), vec![component_meta]));
147            }
148        }
149
150        Ok(())
151    }
152
153    /// Handle notification that a link is dropped - stop the http listener
154    #[instrument(level = "info", skip_all, fields(target_id = info.get_target_id()))]
155    async fn delete_link_as_source(&self, info: impl LinkDeleteInfo) -> anyhow::Result<()> {
156        let component_id = info.get_target_id();
157        let link_name = info.get_link_name();
158
159        // Retrieve the thing by link name
160        let mut sockets_by_link_name = self.sockets_by_link_name.write().await;
161        if let Some(addr) = sockets_by_link_name.get(link_name) {
162            let mut handlers_by_socket = self.handlers_by_socket.write().await;
163            if let Some((server, component_metas)) = handlers_by_socket.get_mut(addr) {
164                // If the component id & link name pair is present, remove it
165                if let Some(idx) = component_metas
166                    .iter()
167                    .position(|(c, l, ..)| c.as_ref() == component_id && l.as_ref() == link_name)
168                {
169                    component_metas.remove(idx);
170                }
171
172                // If the component was the last one, we can remove the server
173                if component_metas.is_empty() {
174                    info!(
175                        address = addr.to_string(),
176                        "last component removed for address, shutting down server"
177                    );
178                    server.handle.shutdown();
179                    handlers_by_socket.remove(addr);
180                    sockets_by_link_name.remove(link_name);
181                }
182            }
183        }
184
185        Ok(())
186    }
187
188    /// Handle shutdown request by shutting down all the http server threads
189    async fn shutdown(&self) -> anyhow::Result<()> {
190        // Empty the component link data and stop all servers
191        self.sockets_by_link_name.write().await.clear();
192        self.handlers_by_socket.write().await.clear();
193        Ok(())
194    }
195}
196
197#[derive(Clone)]
198struct RequestContext {
199    /// Address of the server, used for handler lookup
200    server_address: SocketAddr,
201    /// Settings that can be
202    settings: Arc<ServiceSettings>,
203    /// HTTP scheme
204    scheme: http::uri::Scheme,
205    /// Handlers for components
206    handlers_by_socket: Arc<RwLock<HandlerLookup>>,
207}
208
209/// Handle an HTTP request by invoking the target component as configured in the listener
210#[instrument(level = "debug", skip(settings, handlers_by_socket))]
211async fn handle_request(
212    extract::State(RequestContext {
213        server_address,
214        settings,
215        scheme,
216        handlers_by_socket,
217    }): extract::State<RequestContext>,
218    axum_extra::extract::Host(authority): axum_extra::extract::Host,
219    request: extract::Request,
220) -> impl axum::response::IntoResponse {
221    let (component_id, wrpc) = {
222        let Some((component_id, wrpc)) = handlers_by_socket
223            .read()
224            .await
225            .get(&server_address)
226            .and_then(|v| v.1.first())
227            .map(|(component_id, _, wrpc)| (Arc::clone(component_id), wrpc.clone()))
228        else {
229            return Err((
230                http::StatusCode::INTERNAL_SERVER_ERROR,
231                "no targets for HTTP request",
232            ))?;
233        };
234        (component_id, wrpc)
235    };
236
237    let timeout = settings.timeout_ms.map(Duration::from_millis);
238    let req = build_request(request, scheme, authority, &settings).map_err(|err| *err)?;
239    axum::response::Result::<_, axum::response::ErrorResponse>::Ok(
240        invoke_component(
241            &wrpc,
242            &component_id,
243            req,
244            timeout,
245            settings.cache_control.as_ref(),
246        )
247        .await,
248    )
249}
250
/// An asynchronous `wrpc:http/incoming-handler` with support for CORS and TLS
///
/// Owns a single listening server; dropping a value shuts that server down (see its
/// `Drop` implementation).
#[derive(Debug)]
pub struct HttpServerCore {
    /// The handle to the server handling incoming requests, used to signal shutdown
    handle: axum_server::Handle,
    /// The asynchronous task running the server, aborted when this value is dropped
    task: tokio::task::JoinHandle<()>,
}
259
260impl HttpServerCore {
261    #[instrument(skip(handlers_by_socket))]
262    pub async fn new(
263        settings: Arc<ServiceSettings>,
264        target: &str,
265        handlers_by_socket: Arc<RwLock<HandlerLookup>>,
266    ) -> anyhow::Result<Self> {
267        let addr = settings.address;
268        info!(
269            %addr,
270            component_id = target,
271            "httpserver starting listener for target",
272        );
273        let cors = get_cors_layer(&settings)?;
274        let service = handle_request.layer(cors);
275        let handle = axum_server::Handle::new();
276        let listener = get_tcp_listener(&settings)
277            .with_context(|| format!("failed to create listener (is [{addr}] already in use?)"))?;
278
279        let target = target.to_owned();
280        let task_handle = handle.clone();
281        let task = if let (Some(crt), Some(key)) =
282            (&settings.tls_cert_file, &settings.tls_priv_key_file)
283        {
284            debug!(?addr, "bind HTTPS listener");
285            let tls = RustlsConfig::from_pem_file(crt, key)
286                .await
287                .context("failed to construct TLS config")?;
288
289            let srv = axum_server::from_tcp_rustls(listener, tls);
290            tokio::spawn(async move {
291                if let Err(e) = srv
292                    .handle(task_handle)
293                    .serve(
294                        service
295                            .with_state(RequestContext {
296                                server_address: addr,
297                                settings,
298                                scheme: http::uri::Scheme::HTTPS,
299                                handlers_by_socket,
300                            })
301                            .into_make_service(),
302                    )
303                    .await
304                {
305                    error!(error = %e, component_id = target, "failed to serve HTTPS for component");
306                }
307            })
308        } else {
309            debug!(?addr, "bind HTTP listener");
310
311            let mut srv = axum_server::from_tcp(listener);
312            srv.http_builder().http1().keep_alive(false);
313            tokio::spawn(async move {
314                if let Err(e) = srv
315                    .handle(task_handle)
316                    .serve(
317                        service
318                            .with_state(RequestContext {
319                                server_address: addr,
320                                settings,
321                                scheme: http::uri::Scheme::HTTP,
322                                handlers_by_socket,
323                            })
324                            .into_make_service(),
325                    )
326                    .await
327                {
328                    error!(error = %e, component_id = target, "failed to serve HTTP for component");
329                }
330            })
331        };
332
333        Ok(Self { handle, task })
334    }
335}
336
337impl Drop for HttpServerCore {
338    /// Drop the client connection. Does not block or fail if the client has already been closed.
339    fn drop(&mut self) {
340        self.handle.shutdown();
341        self.task.abort();
342    }
343}