wasmcloud_provider_messaging_nats/
lib.rs

1use core::time::Duration;
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use anyhow::{anyhow, bail, ensure, Context as _};
7use async_nats::subject::ToSubject;
8use bytes::Bytes;
9use futures::StreamExt as _;
10use opentelemetry_nats::{attach_span_context, NatsHeaderInjector};
11use tokio::fs;
12use tokio::sync::RwLock;
13use tokio::task::JoinHandle;
14use tracing::{debug, error, instrument, warn};
15use tracing_futures::Instrument;
16use wascap::prelude::KeyPair;
17use wasmcloud_core::messaging::ConnectionConfig;
18use wasmcloud_provider_sdk::core::HostData;
19use wasmcloud_provider_sdk::provider::WrpcClient;
20use wasmcloud_provider_sdk::wasmcloud_tracing::context::TraceContextInjector;
21use wasmcloud_provider_sdk::{
22    get_connection, initialize_observability, load_host_data, propagate_trace_for_ctx,
23    run_provider, serve_provider_exports, Context, LinkConfig, LinkDeleteInfo, Provider,
24};
25
26mod connection;
27
/// wRPC bindings generated for the `wasmcloud:messaging` interfaces used by this provider.
mod bindings {
    wit_bindgen_wrpc::generate!({
        // Generate code for each of the messaging interfaces listed here.
        with: {
            "wasmcloud:messaging/consumer@0.2.0": generate,
            "wasmcloud:messaging/handler@0.2.0": generate,
            "wasmcloud:messaging/types@0.2.0": generate,
        },
    });
}
37use bindings::wasmcloud::messaging::types::BrokerMessage;
38
/// Crate-level entry point: delegates to [`NatsMessagingProvider::run`].
///
/// # Errors
///
/// Propagates whatever error [`NatsMessagingProvider::run`] returns.
pub async fn run() -> anyhow::Result<()> {
    NatsMessagingProvider::run().await
}
42
/// [`NatsClientBundle`]s hold a NATS client and information (subscriptions)
/// related to it.
///
/// This struct is necessary because subscriptions are *not* automatically removed on client drop,
/// meaning that we must keep track of all subscriptions to close once the client is done.
#[derive(Debug)]
struct NatsClientBundle {
    /// The NATS client used for this bundle's component.
    pub client: async_nats::Client,
    /// One entry per subscription: the subscribed subject paired with the handle of the
    /// task that listens on it. The handles are aborted when the bundle is dropped.
    pub sub_handles: Vec<(String, JoinHandle<()>)>,
}
53
54impl Drop for NatsClientBundle {
55    fn drop(&mut self) {
56        for handle in &self.sub_handles {
57            handle.1.abort();
58        }
59    }
60}
61
/// Nats implementation for wasmcloud:messaging
#[derive(Default, Clone)]
pub struct NatsMessagingProvider {
    /// Client bundles for links where this provider is the source, keyed by the
    /// target component ID.
    handler_components: Arc<RwLock<HashMap<String, NatsClientBundle>>>,
    /// Client bundles for links where this provider is the target, keyed by the
    /// source component ID; `publish` and `request` look their client up here.
    consumer_components: Arc<RwLock<HashMap<String, NatsClientBundle>>>,
    /// Connection configuration used when a link supplies no configuration, and as
    /// the base that link-supplied configuration is merged into.
    default_config: ConnectionConfig,
}
69
70impl NatsMessagingProvider {
71    pub async fn run() -> anyhow::Result<()> {
72        initialize_observability!(
73            "nats-messaging-provider",
74            std::env::var_os("PROVIDER_NATS_MESSAGING_FLAMEGRAPH_PATH")
75        );
76
77        let host_data = load_host_data().context("failed to load host data")?;
78        let provider = Self::from_host_data(host_data);
79        let shutdown = run_provider(provider.clone(), "messaging-nats-provider")
80            .await
81            .context("failed to run provider")?;
82        let connection = get_connection();
83        let wrpc = connection
84            .get_wrpc_client(connection.provider_key())
85            .await?;
86        serve_provider_exports(&wrpc, provider, shutdown, bindings::serve)
87            .await
88            .context("failed to serve provider exports")
89    }
90
91    /// Build a [`NatsMessagingProvider`] from [`HostData`]
92    pub fn from_host_data(host_data: &HostData) -> NatsMessagingProvider {
93        let config = ConnectionConfig::from_map(&host_data.config);
94        if let Ok(config) = config {
95            NatsMessagingProvider {
96                default_config: config,
97                ..Default::default()
98            }
99        } else {
100            warn!("Failed to build connection configuration, falling back to default");
101            NatsMessagingProvider::default()
102        }
103    }
104
105    /// Attempt to connect to nats url (with jwt credentials, if provided)
106    async fn connect(
107        &self,
108        cfg: ConnectionConfig,
109        component_id: &str,
110    ) -> anyhow::Result<NatsClientBundle> {
111        ensure!(
112            cfg.consumers.is_empty(),
113            "JetStream consumers not supported by this provider"
114        );
115        let mut opts = match (cfg.auth_jwt, cfg.auth_seed) {
116            (Some(jwt), Some(seed)) => {
117                let seed = KeyPair::from_seed(&seed).context("failed to parse seed key pair")?;
118                let seed = Arc::new(seed);
119                async_nats::ConnectOptions::with_jwt(jwt.into_string(), move |nonce| {
120                    let seed = seed.clone();
121                    async move { seed.sign(&nonce).map_err(async_nats::AuthError::new) }
122                })
123            }
124            (None, None) => async_nats::ConnectOptions::default(),
125            _ => bail!("must provide both jwt and seed for jwt authentication"),
126        };
127        if let Some(tls_ca) = cfg.tls_ca.as_deref() {
128            opts = add_tls_ca(tls_ca, opts)?;
129        } else if let Some(tls_ca_file) = cfg.tls_ca_file.as_deref() {
130            let ca = fs::read_to_string(tls_ca_file)
131                .await
132                .context("failed to read TLS CA file")?;
133            opts = add_tls_ca(&ca, opts)?;
134        }
135
136        // Use the first visible cluster_uri
137        let url = cfg.cluster_uris.first().unwrap();
138
139        // Override inbox prefix if specified
140        if let Some(prefix) = cfg.custom_inbox_prefix {
141            opts = opts.custom_inbox_prefix(prefix);
142        }
143
144        let client = opts
145            .name("NATS Messaging Provider") // allow this to show up uniquely in a NATS connection list
146            .connect(url.as_ref())
147            .await?;
148
149        // Connections
150        let mut sub_handles = Vec::new();
151        for sub in cfg.subscriptions.iter().filter(|s| !s.is_empty()) {
152            let (sub, queue) = match sub.split_once('|') {
153                Some((sub, queue)) => (sub, Some(queue.into())),
154                None => (sub.as_str(), None),
155            };
156
157            sub_handles.push((
158                sub.into(),
159                self.subscribe(&client, component_id, sub.to_string(), queue)
160                    .await?,
161            ));
162        }
163
164        Ok(NatsClientBundle {
165            client,
166            sub_handles,
167        })
168    }
169
170    /// Add a regular or queue subscription
171    async fn subscribe(
172        &self,
173        client: &async_nats::Client,
174        component_id: &str,
175        sub: impl ToSubject,
176        queue: Option<String>,
177    ) -> anyhow::Result<JoinHandle<()>> {
178        let mut subscriber = match queue {
179            Some(queue) => client.queue_subscribe(sub, queue).await,
180            None => client.subscribe(sub).await,
181        }?;
182
183        debug!(?component_id, "spawning listener for component");
184
185        let component_id = Arc::from(component_id);
186        // Spawn a thread that listens for messages coming from NATS
187        // this thread is expected to run the full duration that the provider is available
188        let join_handle = tokio::spawn(async move {
189            let wrpc = match get_connection()
190                .get_wrpc_client_custom(&component_id, None)
191                .await
192            {
193                Ok(wrpc) => Arc::new(wrpc),
194                Err(err) => {
195                    error!(?err, "failed to construct wRPC client");
196                    return;
197                }
198            };
199            // Listen for NATS message(s)
200            while let Some(msg) = subscriber.next().await {
201                debug!(?msg, ?component_id, "received message");
202                // Set up tracing context for the NATS message
203                let span = tracing::debug_span!("handle_message", ?component_id);
204
205                let component_id = Arc::clone(&component_id);
206                let wrpc = Arc::clone(&wrpc);
207                tokio::spawn(async move {
208                    dispatch_msg(&wrpc, &component_id, msg)
209                        .instrument(span)
210                        .await;
211                });
212            }
213        });
214
215        Ok(join_handle)
216    }
217}
218
219#[instrument(level = "debug", skip_all, fields(component_id = %component_id, subject = %nats_msg.subject, reply_to = ?nats_msg.reply))]
220async fn dispatch_msg(wrpc: &WrpcClient, component_id: &str, nats_msg: async_nats::Message) {
221    match nats_msg.headers {
222        // If there are some headers on the message they might contain a span context
223        // so attempt to attach them.
224        Some(ref h) if !h.is_empty() => {
225            attach_span_context(&nats_msg);
226        }
227        // Otherwise, we'll use the existing span context starting with this message
228        _ => (),
229    };
230
231    let msg = BrokerMessage {
232        body: nats_msg.payload,
233        reply_to: nats_msg.reply.map(|s| s.into_string()),
234        subject: nats_msg.subject.into_string(),
235    };
236    debug!(
237        subject = msg.subject,
238        reply_to = ?msg.reply_to,
239        component_id = component_id,
240        "sending message to component",
241    );
242    let mut cx = async_nats::HeaderMap::new();
243    for (k, v) in TraceContextInjector::default_with_span().iter() {
244        cx.insert(k.as_str(), v.as_str())
245    }
246    if let Err(e) =
247        bindings::wasmcloud::messaging::handler::handle_message(wrpc, Some(cx), &msg).await
248    {
249        error!(
250            error = %e,
251            "Unable to send message"
252        );
253    }
254}
255
256/// Handle provider control commands
257/// `put_link` (new component link command), `del_link` (remove link command), and shutdown
258impl Provider for NatsMessagingProvider {
259    /// Provider should perform any operations needed for a new link,
260    /// including setting up per-component resources, and checking authorization.
261    /// If the link is allowed, return true, otherwise return false to deny the link.
262    #[instrument(level = "debug", skip_all, fields(source_id))]
263    async fn receive_link_config_as_target(
264        &self,
265        link_config: LinkConfig<'_>,
266    ) -> anyhow::Result<()> {
267        let LinkConfig { source_id, .. } = link_config;
268        let config = if link_config.config.is_empty() {
269            self.default_config.clone()
270        } else {
271            // create a config from the supplied values and merge that with the existing default
272            match connection::from_link_config(&link_config) {
273                Ok(cc) => self.default_config.merge(&ConnectionConfig {
274                    subscriptions: Box::default(),
275                    ..cc
276                }),
277                Err(e) => {
278                    error!("Failed to build connection configuration: {e:?}");
279                    return Err(anyhow!(e).context("failed to build connection config"));
280                }
281            }
282        };
283
284        let mut update_map = self.consumer_components.write().await;
285        let bundle = match self.connect(config, source_id).await {
286            Ok(b) => b,
287            Err(e) => {
288                error!("Failed to connect to NATS: {e:?}");
289                bail!(anyhow!(e).context("failed to connect to NATS"))
290            }
291        };
292        update_map.insert(source_id.into(), bundle);
293
294        Ok(())
295    }
296
297    #[instrument(level = "debug", skip_all, fields(target_id))]
298    async fn receive_link_config_as_source(
299        &self,
300        link_config: LinkConfig<'_>,
301    ) -> anyhow::Result<()> {
302        let target_id = link_config.target_id;
303        let config = if link_config.config.is_empty() {
304            self.default_config.clone()
305        } else {
306            // create a config from the supplied values and merge that with the existing default
307            match connection::from_link_config(&link_config) {
308                Ok(cc) => self.default_config.merge(&cc),
309                Err(e) => {
310                    error!("Failed to build connection configuration: {e:?}");
311                    return Err(anyhow!(e).context("failed to build connection config"));
312                }
313            }
314        };
315
316        let mut update_map = self.handler_components.write().await;
317        let bundle = match self.connect(config, target_id).await {
318            Ok(b) => b,
319            Err(e) => {
320                error!("Failed to connect to NATS: {e:?}");
321                bail!(anyhow!(e).context("failed to connect to NATS"))
322            }
323        };
324        update_map.insert(target_id.into(), bundle);
325
326        Ok(())
327    }
328
329    /// Handle notification that a link is dropped: close the connection
330    #[instrument(level = "info", skip_all, fields(source_id = info.get_source_id()))]
331    async fn delete_link_as_target(&self, info: impl LinkDeleteInfo) -> anyhow::Result<()> {
332        let component_id = info.get_source_id();
333        let mut links = self.consumer_components.write().await;
334        if let Some(bundle) = links.remove(component_id) {
335            let client = &bundle.client;
336            debug!(
337                component_id,
338                "dropping NATS client [{}] for (consumer) component",
339                format!(
340                    "{}:{}",
341                    client.server_info().server_id,
342                    client.server_info().client_id
343                ),
344            );
345        }
346
347        debug!(
348            component_id,
349            "finished processing (consumer) link deletion for component",
350        );
351
352        Ok(())
353    }
354
355    #[instrument(level = "info", skip_all, fields(target_id = info.get_target_id()))]
356    async fn delete_link_as_source(&self, info: impl LinkDeleteInfo) -> anyhow::Result<()> {
357        // If we were the source, then the component we're invoking is the target
358        let component_id = info.get_target_id();
359        let mut links = self.handler_components.write().await;
360        if let Some(bundle) = links.remove(component_id) {
361            // Note: subscriptions will be closed via Drop on the NatsClientBundle
362            let client = &bundle.client;
363            debug!(
364                component_id,
365                "dropping NATS client [{}] and associated subscriptions [{}] for (handler) component",
366                format!(
367                    "{}:{}",
368                    client.server_info().server_id,
369                    client.server_info().client_id
370                ),
371                &bundle.sub_handles.len(),
372            );
373        }
374
375        debug!(
376            component_id,
377            "finished processing (handler) link deletion for component",
378        );
379
380        Ok(())
381    }
382
383    /// Handle shutdown request by closing all connections
384    async fn shutdown(&self) -> anyhow::Result<()> {
385        // clear the handler components
386        let mut handlers = self.handler_components.write().await;
387        handlers.clear();
388
389        // clear the consumer components
390        let mut consumers = self.consumer_components.write().await;
391        consumers.clear();
392
393        // dropping all connections should send unsubscribes and close the connections, so no need
394        // to handle that here
395        Ok(())
396    }
397}
398
399/// Implement the 'wasmcloud:messaging' capability provider interface
400impl bindings::exports::wasmcloud::messaging::consumer::Handler<Option<Context>>
401    for NatsMessagingProvider
402{
403    #[instrument(level = "debug", skip(self, ctx, msg), fields(subject = %msg.subject, reply_to = ?msg.reply_to, body_len = %msg.body.len()))]
404    async fn publish(
405        &self,
406        ctx: Option<Context>,
407        msg: BrokerMessage,
408    ) -> anyhow::Result<Result<(), String>> {
409        propagate_trace_for_ctx!(ctx);
410
411        let nats_client =
412            if let Some(ref source_id) = ctx.and_then(|Context { component, .. }| component) {
413                let actors = self.consumer_components.read().await;
414                let nats_bundle = match actors.get(source_id) {
415                    Some(nats_bundle) => nats_bundle,
416                    None => {
417                        error!("component not linked: {source_id}");
418                        bail!("component not linked: {source_id}")
419                    }
420                };
421                nats_bundle.client.clone()
422            } else {
423                error!("no component in request");
424                bail!("no component in request")
425            };
426
427        let headers = NatsHeaderInjector::default_with_span().into();
428
429        let body = msg.body;
430        let res = match msg.reply_to.clone() {
431            Some(reply_to) => if should_strip_headers(&msg.subject) {
432                nats_client
433                    .publish_with_reply(msg.subject, reply_to, body)
434                    .await
435            } else {
436                nats_client
437                    .publish_with_reply_and_headers(msg.subject, reply_to, headers, body)
438                    .await
439            }
440            .map_err(|e| e.to_string()),
441            None => nats_client
442                .publish_with_headers(msg.subject, headers, body)
443                .await
444                .map_err(|e| e.to_string()),
445        };
446        let _ = nats_client.flush().await;
447        Ok(res)
448    }
449
450    #[instrument(level = "debug", skip(self, ctx), fields(subject = %subject))]
451    async fn request(
452        &self,
453        ctx: Option<Context>,
454        subject: String,
455        body: Bytes,
456        timeout_ms: u32,
457    ) -> anyhow::Result<Result<BrokerMessage, String>> {
458        let nats_client =
459            if let Some(ref source_id) = ctx.and_then(|Context { component, .. }| component) {
460                let actors = self.consumer_components.read().await;
461                let nats_bundle = match actors.get(source_id) {
462                    Some(nats_bundle) => nats_bundle,
463                    None => {
464                        error!("component not linked: {source_id}");
465                        bail!("component not linked: {source_id}")
466                    }
467                };
468                nats_bundle.client.clone()
469            } else {
470                error!("no component in request");
471                bail!("no component in request")
472            };
473
474        // Inject OTEL headers
475        let headers = NatsHeaderInjector::default_with_span().into();
476
477        let timeout = Duration::from_millis(timeout_ms.into());
478        // Perform the request with a timeout
479        let request_with_timeout = if should_strip_headers(&subject) {
480            tokio::time::timeout(timeout, nats_client.request(subject, body)).await
481        } else {
482            tokio::time::timeout(
483                timeout,
484                nats_client.request_with_headers(subject, headers, body),
485            )
486            .await
487        };
488
489        // Process results of request
490        match request_with_timeout {
491            Err(timeout_err) => {
492                error!("nats request timed out: {timeout_err}");
493                return Ok(Err(format!("nats request timed out: {timeout_err}")));
494            }
495            Ok(Err(send_err)) => {
496                error!("nats send error: {send_err}");
497                return Ok(Err(format!("nats send error: {send_err}")));
498            }
499            Ok(Ok(resp)) => Ok(Ok(BrokerMessage {
500                body: resp.payload,
501                reply_to: resp.reply.map(|s| s.into_string()),
502                subject: resp.subject.into_string(),
503            })),
504        }
505    }
506}
507
/// Report whether headers must be omitted when sending to `topic`.
///
/// In the current version of the NATS server, using headers on certain $SYS.REQ topics will
/// cause server-side parse failures, so every topic beginning with `$SYS` is treated as
/// header-less.
fn should_strip_headers(topic: &str) -> bool {
    topic.strip_prefix("$SYS").is_some()
}
513
514pub fn add_tls_ca(
515    tls_ca: &str,
516    opts: async_nats::ConnectOptions,
517) -> anyhow::Result<async_nats::ConnectOptions> {
518    let ca = rustls_pemfile::read_one(&mut tls_ca.as_bytes()).context("failed to read CA")?;
519    let mut roots = async_nats::rustls::RootCertStore::empty();
520    if let Some(rustls_pemfile::Item::X509Certificate(ca)) = ca {
521        roots.add_parsable_certificates([ca]);
522    } else {
523        bail!("tls ca: invalid certificate type, must be a DER encoded PEM file")
524    };
525    let tls_client = async_nats::rustls::ClientConfig::builder()
526        .with_root_certificates(roots)
527        .with_no_client_auth();
528    Ok(opts.tls_client_config(tls_client).require_tls(true))
529}
530
#[cfg(test)]
mod test {
    use super::*;
    use std::collections::HashMap;

    /// A partial JSON document deserializes, with unspecified fields left unset or empty.
    #[test]
    fn test_default_connection_serialize() {
        let input = r#"
{
    "cluster_uris": ["nats://soyvuh"],
    "auth_jwt": "authy",
    "auth_seed": "seedy"
}
"#;

        let parsed: ConnectionConfig = serde_json::from_str(input).unwrap();
        assert_eq!(parsed.auth_jwt.as_deref(), Some("authy"));
        assert_eq!(parsed.auth_seed.as_deref(), Some("seedy"));
        assert_eq!(parsed.cluster_uris, [Box::from("nats://soyvuh")].into());
        assert!(parsed.custom_inbox_prefix.is_none());
        assert!(parsed.subscriptions.is_empty());
        assert!(parsed.ping_interval_sec.is_none());
    }

    /// Values from the second config win over the first; individual vec fields are
    /// replaced, not extended.
    #[test]
    fn test_connectionconfig_merge() {
        let base = ConnectionConfig {
            cluster_uris: ["old_server".into()].into(),
            subscriptions: ["topic1".into()].into(),
            custom_inbox_prefix: Some("_NOPE.>".into()),
            ..Default::default()
        };
        let overlay = ConnectionConfig {
            cluster_uris: ["server1".into(), "server2".into()].into(),
            auth_jwt: Some("jawty".into()),
            ..Default::default()
        };

        let merged = base.merge(&overlay);

        assert_eq!(merged.cluster_uris, overlay.cluster_uris);
        assert_eq!(merged.subscriptions, base.subscriptions);
        assert_eq!(merged.auth_jwt, Some("jawty".into()));
        assert_eq!(merged.custom_inbox_prefix, Some("_NOPE.>".into()));
    }

    /// A single key/value pair is enough to build a config via `from_map`.
    #[test]
    fn test_from_map() -> anyhow::Result<()> {
        let mut values = HashMap::new();
        values.insert("custom_inbox_prefix".to_string(), "_TEST.>".to_string());

        let cc = ConnectionConfig::from_map(&values)?;
        assert_eq!(cc.custom_inbox_prefix, Some("_TEST.>".into()));
        Ok(())
    }
}