1use core::time::Duration;
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use anyhow::{anyhow, bail, ensure, Context as _};
7use async_nats::subject::ToSubject;
8use bytes::Bytes;
9use futures::StreamExt as _;
10use opentelemetry_nats::{attach_span_context, NatsHeaderInjector};
11use tokio::fs;
12use tokio::sync::RwLock;
13use tokio::task::JoinHandle;
14use tracing::{debug, error, instrument, warn};
15use tracing_futures::Instrument;
16use wascap::prelude::KeyPair;
17use wasmcloud_core::messaging::ConnectionConfig;
18use wasmcloud_provider_sdk::core::HostData;
19use wasmcloud_provider_sdk::provider::WrpcClient;
20use wasmcloud_provider_sdk::wasmcloud_tracing::context::TraceContextInjector;
21use wasmcloud_provider_sdk::{
22 get_connection, initialize_observability, load_host_data, propagate_trace_for_ctx,
23 run_provider, serve_provider_exports, Context, LinkConfig, LinkDeleteInfo, Provider,
24};
25
26mod connection;
27
/// Bindings generated by `wit_bindgen_wrpc` for the `wasmcloud:messaging@0.2.0` interfaces
/// (`consumer`, `handler`, `types`). This file uses them for the exported consumer handler,
/// `serve`, outgoing `handler::handle_message` calls, and the `BrokerMessage` type.
mod bindings {
    wit_bindgen_wrpc::generate!({
        with: {
            "wasmcloud:messaging/consumer@0.2.0": generate,
            "wasmcloud:messaging/handler@0.2.0": generate,
            "wasmcloud:messaging/types@0.2.0": generate,
        },
    });
}
37use bindings::wasmcloud::messaging::types::BrokerMessage;
38
39pub async fn run() -> anyhow::Result<()> {
40 NatsMessagingProvider::run().await
41}
42
43#[derive(Debug)]
49struct NatsClientBundle {
50 pub client: async_nats::Client,
51 pub sub_handles: Vec<(String, JoinHandle<()>)>,
52}
53
54impl Drop for NatsClientBundle {
55 fn drop(&mut self) {
56 for handle in &self.sub_handles {
57 handle.1.abort();
58 }
59 }
60}
61
62#[derive(Default, Clone)]
64pub struct NatsMessagingProvider {
65 handler_components: Arc<RwLock<HashMap<String, NatsClientBundle>>>,
66 consumer_components: Arc<RwLock<HashMap<String, NatsClientBundle>>>,
67 default_config: ConnectionConfig,
68}
69
70impl NatsMessagingProvider {
71 pub async fn run() -> anyhow::Result<()> {
72 initialize_observability!(
73 "nats-messaging-provider",
74 std::env::var_os("PROVIDER_NATS_MESSAGING_FLAMEGRAPH_PATH")
75 );
76
77 let host_data = load_host_data().context("failed to load host data")?;
78 let provider = Self::from_host_data(host_data);
79 let shutdown = run_provider(provider.clone(), "messaging-nats-provider")
80 .await
81 .context("failed to run provider")?;
82 let connection = get_connection();
83 let wrpc = connection
84 .get_wrpc_client(connection.provider_key())
85 .await?;
86 serve_provider_exports(&wrpc, provider, shutdown, bindings::serve)
87 .await
88 .context("failed to serve provider exports")
89 }
90
91 pub fn from_host_data(host_data: &HostData) -> NatsMessagingProvider {
93 let config = ConnectionConfig::from_map(&host_data.config);
94 if let Ok(config) = config {
95 NatsMessagingProvider {
96 default_config: config,
97 ..Default::default()
98 }
99 } else {
100 warn!("Failed to build connection configuration, falling back to default");
101 NatsMessagingProvider::default()
102 }
103 }
104
105 async fn connect(
107 &self,
108 cfg: ConnectionConfig,
109 component_id: &str,
110 ) -> anyhow::Result<NatsClientBundle> {
111 ensure!(
112 cfg.consumers.is_empty(),
113 "JetStream consumers not supported by this provider"
114 );
115 let mut opts = match (cfg.auth_jwt, cfg.auth_seed) {
116 (Some(jwt), Some(seed)) => {
117 let seed = KeyPair::from_seed(&seed).context("failed to parse seed key pair")?;
118 let seed = Arc::new(seed);
119 async_nats::ConnectOptions::with_jwt(jwt.into_string(), move |nonce| {
120 let seed = seed.clone();
121 async move { seed.sign(&nonce).map_err(async_nats::AuthError::new) }
122 })
123 }
124 (None, None) => async_nats::ConnectOptions::default(),
125 _ => bail!("must provide both jwt and seed for jwt authentication"),
126 };
127 if let Some(tls_ca) = cfg.tls_ca.as_deref() {
128 opts = add_tls_ca(tls_ca, opts)?;
129 } else if let Some(tls_ca_file) = cfg.tls_ca_file.as_deref() {
130 let ca = fs::read_to_string(tls_ca_file)
131 .await
132 .context("failed to read TLS CA file")?;
133 opts = add_tls_ca(&ca, opts)?;
134 }
135
136 let url = cfg.cluster_uris.first().unwrap();
138
139 if let Some(prefix) = cfg.custom_inbox_prefix {
141 opts = opts.custom_inbox_prefix(prefix);
142 }
143
144 let client = opts
145 .name("NATS Messaging Provider") .connect(url.as_ref())
147 .await?;
148
149 let mut sub_handles = Vec::new();
151 for sub in cfg.subscriptions.iter().filter(|s| !s.is_empty()) {
152 let (sub, queue) = match sub.split_once('|') {
153 Some((sub, queue)) => (sub, Some(queue.into())),
154 None => (sub.as_str(), None),
155 };
156
157 sub_handles.push((
158 sub.into(),
159 self.subscribe(&client, component_id, sub.to_string(), queue)
160 .await?,
161 ));
162 }
163
164 Ok(NatsClientBundle {
165 client,
166 sub_handles,
167 })
168 }
169
170 async fn subscribe(
172 &self,
173 client: &async_nats::Client,
174 component_id: &str,
175 sub: impl ToSubject,
176 queue: Option<String>,
177 ) -> anyhow::Result<JoinHandle<()>> {
178 let mut subscriber = match queue {
179 Some(queue) => client.queue_subscribe(sub, queue).await,
180 None => client.subscribe(sub).await,
181 }?;
182
183 debug!(?component_id, "spawning listener for component");
184
185 let component_id = Arc::from(component_id);
186 let join_handle = tokio::spawn(async move {
189 let wrpc = match get_connection()
190 .get_wrpc_client_custom(&component_id, None)
191 .await
192 {
193 Ok(wrpc) => Arc::new(wrpc),
194 Err(err) => {
195 error!(?err, "failed to construct wRPC client");
196 return;
197 }
198 };
199 while let Some(msg) = subscriber.next().await {
201 debug!(?msg, ?component_id, "received message");
202 let span = tracing::debug_span!("handle_message", ?component_id);
204
205 let component_id = Arc::clone(&component_id);
206 let wrpc = Arc::clone(&wrpc);
207 tokio::spawn(async move {
208 dispatch_msg(&wrpc, &component_id, msg)
209 .instrument(span)
210 .await;
211 });
212 }
213 });
214
215 Ok(join_handle)
216 }
217}
218
219#[instrument(level = "debug", skip_all, fields(component_id = %component_id, subject = %nats_msg.subject, reply_to = ?nats_msg.reply))]
220async fn dispatch_msg(wrpc: &WrpcClient, component_id: &str, nats_msg: async_nats::Message) {
221 match nats_msg.headers {
222 Some(ref h) if !h.is_empty() => {
225 attach_span_context(&nats_msg);
226 }
227 _ => (),
229 };
230
231 let msg = BrokerMessage {
232 body: nats_msg.payload,
233 reply_to: nats_msg.reply.map(|s| s.into_string()),
234 subject: nats_msg.subject.into_string(),
235 };
236 debug!(
237 subject = msg.subject,
238 reply_to = ?msg.reply_to,
239 component_id = component_id,
240 "sending message to component",
241 );
242 let mut cx = async_nats::HeaderMap::new();
243 for (k, v) in TraceContextInjector::default_with_span().iter() {
244 cx.insert(k.as_str(), v.as_str())
245 }
246 if let Err(e) =
247 bindings::wasmcloud::messaging::handler::handle_message(wrpc, Some(cx), &msg).await
248 {
249 error!(
250 error = %e,
251 "Unable to send message"
252 );
253 }
254}
255
256impl Provider for NatsMessagingProvider {
259 #[instrument(level = "debug", skip_all, fields(source_id))]
263 async fn receive_link_config_as_target(
264 &self,
265 link_config: LinkConfig<'_>,
266 ) -> anyhow::Result<()> {
267 let LinkConfig { source_id, .. } = link_config;
268 let config = if link_config.config.is_empty() {
269 self.default_config.clone()
270 } else {
271 match connection::from_link_config(&link_config) {
273 Ok(cc) => self.default_config.merge(&ConnectionConfig {
274 subscriptions: Box::default(),
275 ..cc
276 }),
277 Err(e) => {
278 error!("Failed to build connection configuration: {e:?}");
279 return Err(anyhow!(e).context("failed to build connection config"));
280 }
281 }
282 };
283
284 let mut update_map = self.consumer_components.write().await;
285 let bundle = match self.connect(config, source_id).await {
286 Ok(b) => b,
287 Err(e) => {
288 error!("Failed to connect to NATS: {e:?}");
289 bail!(anyhow!(e).context("failed to connect to NATS"))
290 }
291 };
292 update_map.insert(source_id.into(), bundle);
293
294 Ok(())
295 }
296
297 #[instrument(level = "debug", skip_all, fields(target_id))]
298 async fn receive_link_config_as_source(
299 &self,
300 link_config: LinkConfig<'_>,
301 ) -> anyhow::Result<()> {
302 let target_id = link_config.target_id;
303 let config = if link_config.config.is_empty() {
304 self.default_config.clone()
305 } else {
306 match connection::from_link_config(&link_config) {
308 Ok(cc) => self.default_config.merge(&cc),
309 Err(e) => {
310 error!("Failed to build connection configuration: {e:?}");
311 return Err(anyhow!(e).context("failed to build connection config"));
312 }
313 }
314 };
315
316 let mut update_map = self.handler_components.write().await;
317 let bundle = match self.connect(config, target_id).await {
318 Ok(b) => b,
319 Err(e) => {
320 error!("Failed to connect to NATS: {e:?}");
321 bail!(anyhow!(e).context("failed to connect to NATS"))
322 }
323 };
324 update_map.insert(target_id.into(), bundle);
325
326 Ok(())
327 }
328
329 #[instrument(level = "info", skip_all, fields(source_id = info.get_source_id()))]
331 async fn delete_link_as_target(&self, info: impl LinkDeleteInfo) -> anyhow::Result<()> {
332 let component_id = info.get_source_id();
333 let mut links = self.consumer_components.write().await;
334 if let Some(bundle) = links.remove(component_id) {
335 let client = &bundle.client;
336 debug!(
337 component_id,
338 "dropping NATS client [{}] for (consumer) component",
339 format!(
340 "{}:{}",
341 client.server_info().server_id,
342 client.server_info().client_id
343 ),
344 );
345 }
346
347 debug!(
348 component_id,
349 "finished processing (consumer) link deletion for component",
350 );
351
352 Ok(())
353 }
354
355 #[instrument(level = "info", skip_all, fields(target_id = info.get_target_id()))]
356 async fn delete_link_as_source(&self, info: impl LinkDeleteInfo) -> anyhow::Result<()> {
357 let component_id = info.get_target_id();
359 let mut links = self.handler_components.write().await;
360 if let Some(bundle) = links.remove(component_id) {
361 let client = &bundle.client;
363 debug!(
364 component_id,
365 "dropping NATS client [{}] and associated subscriptions [{}] for (handler) component",
366 format!(
367 "{}:{}",
368 client.server_info().server_id,
369 client.server_info().client_id
370 ),
371 &bundle.sub_handles.len(),
372 );
373 }
374
375 debug!(
376 component_id,
377 "finished processing (handler) link deletion for component",
378 );
379
380 Ok(())
381 }
382
383 async fn shutdown(&self) -> anyhow::Result<()> {
385 let mut handlers = self.handler_components.write().await;
387 handlers.clear();
388
389 let mut consumers = self.consumer_components.write().await;
391 consumers.clear();
392
393 Ok(())
396 }
397}
398
399impl bindings::exports::wasmcloud::messaging::consumer::Handler<Option<Context>>
401 for NatsMessagingProvider
402{
403 #[instrument(level = "debug", skip(self, ctx, msg), fields(subject = %msg.subject, reply_to = ?msg.reply_to, body_len = %msg.body.len()))]
404 async fn publish(
405 &self,
406 ctx: Option<Context>,
407 msg: BrokerMessage,
408 ) -> anyhow::Result<Result<(), String>> {
409 propagate_trace_for_ctx!(ctx);
410
411 let nats_client =
412 if let Some(ref source_id) = ctx.and_then(|Context { component, .. }| component) {
413 let actors = self.consumer_components.read().await;
414 let nats_bundle = match actors.get(source_id) {
415 Some(nats_bundle) => nats_bundle,
416 None => {
417 error!("component not linked: {source_id}");
418 bail!("component not linked: {source_id}")
419 }
420 };
421 nats_bundle.client.clone()
422 } else {
423 error!("no component in request");
424 bail!("no component in request")
425 };
426
427 let headers = NatsHeaderInjector::default_with_span().into();
428
429 let body = msg.body;
430 let res = match msg.reply_to.clone() {
431 Some(reply_to) => if should_strip_headers(&msg.subject) {
432 nats_client
433 .publish_with_reply(msg.subject, reply_to, body)
434 .await
435 } else {
436 nats_client
437 .publish_with_reply_and_headers(msg.subject, reply_to, headers, body)
438 .await
439 }
440 .map_err(|e| e.to_string()),
441 None => nats_client
442 .publish_with_headers(msg.subject, headers, body)
443 .await
444 .map_err(|e| e.to_string()),
445 };
446 let _ = nats_client.flush().await;
447 Ok(res)
448 }
449
450 #[instrument(level = "debug", skip(self, ctx), fields(subject = %subject))]
451 async fn request(
452 &self,
453 ctx: Option<Context>,
454 subject: String,
455 body: Bytes,
456 timeout_ms: u32,
457 ) -> anyhow::Result<Result<BrokerMessage, String>> {
458 let nats_client =
459 if let Some(ref source_id) = ctx.and_then(|Context { component, .. }| component) {
460 let actors = self.consumer_components.read().await;
461 let nats_bundle = match actors.get(source_id) {
462 Some(nats_bundle) => nats_bundle,
463 None => {
464 error!("component not linked: {source_id}");
465 bail!("component not linked: {source_id}")
466 }
467 };
468 nats_bundle.client.clone()
469 } else {
470 error!("no component in request");
471 bail!("no component in request")
472 };
473
474 let headers = NatsHeaderInjector::default_with_span().into();
476
477 let timeout = Duration::from_millis(timeout_ms.into());
478 let request_with_timeout = if should_strip_headers(&subject) {
480 tokio::time::timeout(timeout, nats_client.request(subject, body)).await
481 } else {
482 tokio::time::timeout(
483 timeout,
484 nats_client.request_with_headers(subject, headers, body),
485 )
486 .await
487 };
488
489 match request_with_timeout {
491 Err(timeout_err) => {
492 error!("nats request timed out: {timeout_err}");
493 return Ok(Err(format!("nats request timed out: {timeout_err}")));
494 }
495 Ok(Err(send_err)) => {
496 error!("nats send error: {send_err}");
497 return Ok(Err(format!("nats send error: {send_err}")));
498 }
499 Ok(Ok(resp)) => Ok(Ok(BrokerMessage {
500 body: resp.payload,
501 reply_to: resp.reply.map(|s| s.into_string()),
502 subject: resp.subject.into_string(),
503 })),
504 }
505 }
506}
507
/// Reports whether `topic` lies in the NATS system namespace, i.e. begins with `$SYS`.
///
/// Callers use this to decide whether to send a message without trace-context headers.
fn should_strip_headers(topic: &str) -> bool {
    topic.strip_prefix("$SYS").is_some()
}
513
514pub fn add_tls_ca(
515 tls_ca: &str,
516 opts: async_nats::ConnectOptions,
517) -> anyhow::Result<async_nats::ConnectOptions> {
518 let ca = rustls_pemfile::read_one(&mut tls_ca.as_bytes()).context("failed to read CA")?;
519 let mut roots = async_nats::rustls::RootCertStore::empty();
520 if let Some(rustls_pemfile::Item::X509Certificate(ca)) = ca {
521 roots.add_parsable_certificates([ca]);
522 } else {
523 bail!("tls ca: invalid certificate type, must be a DER encoded PEM file")
524 };
525 let tls_client = async_nats::rustls::ClientConfig::builder()
526 .with_root_certificates(roots)
527 .with_no_client_auth();
528 Ok(opts.tls_client_config(tls_client).require_tls(true))
529}
530
#[cfg(test)]
mod test {
    use super::*;
    use std::collections::HashMap;

    /// A JSON config with only URIs and JWT credentials deserializes with the remaining
    /// optional fields unset.
    #[test]
    fn test_default_connection_serialize() {
        let json = r#"
{
    "cluster_uris": ["nats://soyvuh"],
    "auth_jwt": "authy",
    "auth_seed": "seedy"
}
"#;

        let parsed: ConnectionConfig = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.cluster_uris, [Box::from("nats://soyvuh")].into());
        assert!(parsed.subscriptions.is_empty());
        assert!(parsed.ping_interval_sec.is_none());
        assert_eq!(parsed.custom_inbox_prefix, None);
        assert_eq!(parsed.auth_jwt.unwrap().as_ref(), "authy");
        assert_eq!(parsed.auth_seed.unwrap().as_ref(), "seedy");
    }

    /// Fields set on the overlay win; fields it leaves empty keep the base's values.
    #[test]
    fn test_connectionconfig_merge() {
        let base = ConnectionConfig {
            cluster_uris: ["old_server".into()].into(),
            subscriptions: ["topic1".into()].into(),
            custom_inbox_prefix: Some("_NOPE.>".into()),
            ..Default::default()
        };
        let overlay = ConnectionConfig {
            cluster_uris: ["server1".into(), "server2".into()].into(),
            auth_jwt: Some("jawty".into()),
            ..Default::default()
        };

        let merged = base.merge(&overlay);
        assert_eq!(merged.cluster_uris, overlay.cluster_uris);
        assert_eq!(merged.auth_jwt, Some("jawty".into()));
        assert_eq!(merged.subscriptions, base.subscriptions);
        assert_eq!(merged.custom_inbox_prefix, Some("_NOPE.>".into()));
    }

    /// `from_map` picks up `custom_inbox_prefix` from a plain string map.
    #[test]
    fn test_from_map() -> anyhow::Result<()> {
        let entries = HashMap::from([("custom_inbox_prefix".into(), "_TEST.>".into())]);
        let config = ConnectionConfig::from_map(&entries)?;
        assert_eq!(config.custom_inbox_prefix, Some("_TEST.>".into()));
        Ok(())
    }
}