// wasmcloud_host/wasmbus/providers/messaging_nats.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use anyhow::{anyhow, bail, Context as _};
5use async_nats::jetstream;
6use futures::StreamExt;
7use nkeys::{KeyPair, XKey};
8use tokio::fs;
9use tokio::sync::{broadcast, Mutex, RwLock};
10use tokio::task::JoinSet;
11use tokio::time::Instant;
12use tracing::{debug, error, instrument, trace_span, warn, Instrument as _, Span};
13use wasmcloud_core::messaging::{add_tls_ca, ConnectionConfig, ConsumerConfig};
14use wasmcloud_core::HostData;
15use wasmcloud_provider_sdk::provider::{
16    handle_provider_commands, receive_link_for_provider, ProviderCommandReceivers,
17};
18use wasmcloud_provider_sdk::{LinkConfig, LinkDeleteInfo, ProviderConnection};
19use wasmcloud_runtime::capability::wrpc;
20use wasmcloud_tracing::KeyValue;
21
22use crate::wasmbus::{Component, InvocationContext};
23
24struct Provider {
25    config: ConnectionConfig,
26    components: Arc<RwLock<HashMap<String, Arc<Component>>>>,
27    messaging_links:
28        Arc<RwLock<HashMap<Arc<str>, Arc<RwLock<HashMap<Box<str>, async_nats::Client>>>>>>,
29    subscriptions: Mutex<HashMap<Arc<str>, HashMap<Box<str>, JoinSet<()>>>>,
30    lattice_id: Arc<str>,
31    host_id: Arc<str>,
32}
33
34impl Provider {
35    async fn connect(
36        &self,
37        config: &HashMap<String, String>,
38    ) -> anyhow::Result<(async_nats::Client, ConnectionConfig)> {
39        // NOTE: Big part of this is copy-pasted from `provider-messaging-nats`
40        let config = if config.is_empty() {
41            self.config.clone()
42        } else {
43            match ConnectionConfig::from_map(config) {
44                Ok(cc) => self.config.merge(&cc),
45                Err(err) => {
46                    error!(?err, "failed to build connection configuration");
47                    return Err(anyhow!(err).context("failed to build connection config"));
48                }
49            }
50        };
51        let mut opts = match (&config.auth_jwt, &config.auth_seed) {
52            (Some(jwt), Some(seed)) => {
53                let seed = KeyPair::from_seed(seed).context("failed to parse seed key pair")?;
54                let seed = Arc::new(seed);
55                async_nats::ConnectOptions::with_jwt(jwt.to_string(), move |nonce| {
56                    let seed = seed.clone();
57                    async move { seed.sign(&nonce).map_err(async_nats::AuthError::new) }
58                })
59            }
60            (None, None) => async_nats::ConnectOptions::default(),
61            _ => bail!("must provide both jwt and seed for jwt authentication"),
62        };
63        if let Some(tls_ca) = config.tls_ca.as_deref() {
64            opts = add_tls_ca(tls_ca, opts)?;
65        } else if let Some(tls_ca_file) = config.tls_ca_file.as_deref() {
66            let ca = fs::read_to_string(tls_ca_file)
67                .await
68                .context("failed to read TLS CA file")?;
69            opts = add_tls_ca(&ca, opts)?;
70        }
71
72        // Use the first visible cluster_uri
73        let url = config.cluster_uris.first().context("invalid address")?;
74
75        // Override inbox prefix if specified
76        if let Some(ref prefix) = config.custom_inbox_prefix {
77            opts = opts.custom_inbox_prefix(prefix);
78        }
79        let nats = opts
80            .name("builtin NATS Messaging Provider")
81            .connect(url.as_ref())
82            .await
83            .context("failed to connect to NATS")?;
84        Ok((nats, config))
85    }
86}
87
88#[instrument(skip_all)]
89async fn handle_message(
90    components: Arc<RwLock<HashMap<String, Arc<Component>>>>,
91    lattice_id: Arc<str>,
92    host_id: Arc<str>,
93    target_id: Arc<str>,
94    msg: async_nats::Message,
95) {
96    use wrpc::exports::wasmcloud::messaging0_2_0::handler::Handler as _;
97
98    opentelemetry_nats::attach_span_context(&msg);
99    let component = {
100        let components = components.read().await;
101        let Some(component) = components.get(target_id.as_ref()) else {
102            warn!(?target_id, "linked component not found");
103            return;
104        };
105        Arc::clone(component)
106    };
107    let _permit = match component
108        .permits
109        .acquire()
110        .instrument(trace_span!("acquire_message_permit"))
111        .await
112    {
113        Ok(permit) => permit,
114        Err(err) => {
115            error!(?err, "failed to acquire execution permit");
116            return;
117        }
118    };
119    match component
120        .instantiate(component.handler.copy_for_new(), component.events.clone())
121        .handle_message(
122            InvocationContext {
123                span: Span::current(),
124                start_at: Instant::now(),
125                attributes: vec![
126                    KeyValue::new("component.ref", Arc::clone(&component.image_reference)),
127                    KeyValue::new("lattice", lattice_id),
128                    KeyValue::new("host", host_id),
129                ],
130            },
131            wrpc::wasmcloud::messaging0_2_0::types::BrokerMessage {
132                subject: msg.subject.into_string(),
133                body: msg.payload,
134                reply_to: msg.reply.map(async_nats::Subject::into_string),
135            },
136        )
137        .await
138    {
139        Ok(Ok(())) => {}
140        Ok(Err(err)) => {
141            warn!(?err, "component failed to handle message")
142        }
143        Err(err) => {
144            warn!(?err, "failed to call component")
145        }
146    }
147}
148
149impl wasmcloud_provider_sdk::Provider for Provider {
150    #[instrument(level = "debug", skip_all)]
151    async fn receive_link_config_as_target(
152        &self,
153        LinkConfig {
154            source_id,
155            link_name,
156            config,
157            ..
158        }: LinkConfig<'_>,
159    ) -> anyhow::Result<()> {
160        let (nats, _) = self.connect(config).await?;
161        let mut links = self.messaging_links.write().await;
162        let mut links = links.entry(source_id.into()).or_default().write().await;
163        links.insert(link_name.into(), nats);
164        Ok(())
165    }
166
167    #[instrument(level = "debug", skip_all)]
168    async fn receive_link_config_as_source(
169        &self,
170        LinkConfig {
171            target_id,
172            config,
173            link_name,
174            ..
175        }: LinkConfig<'_>,
176    ) -> anyhow::Result<()> {
177        let (nats, config) = self.connect(config).await?;
178        let mut tasks = JoinSet::new();
179        let target_id: Arc<str> = Arc::from(target_id);
180        for ConsumerConfig {
181            stream,
182            consumer,
183            max_messages,
184            max_bytes,
185        } in config.consumers
186        {
187            let js = jetstream::new(nats.clone());
188            let stream = js
189                .get_stream(stream)
190                .await
191                .context("failed to get stream")?;
192            let consumer = stream
193                .get_consumer(&consumer)
194                .await
195                .map_err(|err| anyhow!(err).context("failed to get consumer"))?;
196            let sub = consumer.batch();
197            let sub = if let Some(max_messages) = max_messages {
198                sub.max_messages(max_messages)
199            } else {
200                sub
201            };
202            let sub = if let Some(max_bytes) = max_bytes {
203                sub.max_bytes(max_bytes)
204            } else {
205                sub
206            };
207            let mut sub = sub.messages().await.context("failed to subscribe")?;
208
209            let components = Arc::clone(&self.components);
210            let lattice_id = Arc::clone(&self.lattice_id);
211            let host_id = Arc::clone(&self.host_id);
212            let target_id = Arc::clone(&target_id);
213            tasks.spawn(async move {
214                while let Some(msg) = sub.next().await {
215                    let msg = match msg {
216                        Ok(msg) => msg,
217                        Err(err) => {
218                            error!(?err, "failed to receive message");
219                            continue;
220                        }
221                    };
222                    let (msg, ack) = msg.split();
223                    tokio::spawn(async move {
224                        if let Err(err) = ack.ack().await {
225                            error!(?err, "failed to ACK message");
226                        } else {
227                            debug!("successfully ACK'ed message")
228                        }
229                    });
230                    tokio::spawn(handle_message(
231                        Arc::clone(&components),
232                        Arc::clone(&lattice_id),
233                        Arc::clone(&host_id),
234                        Arc::clone(&target_id),
235                        msg,
236                    ));
237                }
238            });
239        }
240        for sub in config.subscriptions {
241            if sub.is_empty() {
242                continue;
243            }
244            let mut sub = if let Some((subject, queue)) = sub.split_once('|') {
245                nats.queue_subscribe(async_nats::Subject::from(subject), queue.into())
246                    .await
247            } else {
248                nats.subscribe(sub).await
249            }
250            .context("failed to subscribe")?;
251            let components = Arc::clone(&self.components);
252            let lattice_id = Arc::clone(&self.lattice_id);
253            let host_id = Arc::clone(&self.host_id);
254            let target_id = Arc::clone(&target_id);
255            tasks.spawn(async move {
256                while let Some(msg) = sub.next().await {
257                    tokio::spawn(handle_message(
258                        Arc::clone(&components),
259                        Arc::clone(&lattice_id),
260                        Arc::clone(&host_id),
261                        Arc::clone(&target_id),
262                        msg,
263                    ));
264                }
265            });
266        }
267        self.subscriptions
268            .lock()
269            .await
270            .entry(target_id)
271            .or_default()
272            .insert(link_name.into(), tasks);
273        Ok(())
274    }
275
276    #[instrument(level = "debug", skip_all)]
277    async fn delete_link_as_source(&self, info: impl LinkDeleteInfo) -> anyhow::Result<()> {
278        let target_id = info.get_target_id();
279        let link_name = info.get_link_name();
280        self.subscriptions
281            .lock()
282            .await
283            .get_mut(target_id)
284            .map(|links| links.remove(link_name));
285        Ok(())
286    }
287}
288
289impl crate::wasmbus::Host {
290    #[instrument(level = "debug", skip_all)]
291    pub(crate) async fn start_messaging_nats_provider(
292        &self,
293        host_data: HostData,
294        provider_xkey: XKey,
295        provider_id: &str,
296    ) -> anyhow::Result<JoinSet<()>> {
297        let host_id = self.host_key.public_key();
298        let config =
299            ConnectionConfig::from_map(&host_data.config).context("failed to parse config")?;
300
301        let (quit_tx, quit_rx) = broadcast::channel(1);
302        let commands = ProviderCommandReceivers::new(
303            Arc::clone(&self.rpc_nats),
304            &quit_tx,
305            &self.host_config.lattice,
306            provider_id,
307            provider_id,
308            &host_id,
309        )
310        .await?;
311        let conn = ProviderConnection::new(
312            Arc::clone(&self.rpc_nats),
313            Arc::from(provider_id),
314            Arc::clone(&self.host_config.lattice),
315            host_id.to_string(),
316            host_data.config,
317            provider_xkey,
318            Arc::clone(&self.secrets_xkey),
319        )
320        .context("failed to establish provider connection")?;
321        let provider = Provider {
322            config,
323            components: Arc::clone(&self.components),
324            messaging_links: Arc::clone(&self.messaging_links),
325            subscriptions: Mutex::default(),
326            host_id: Arc::from(host_id),
327            lattice_id: Arc::clone(&self.host_config.lattice),
328        };
329        for ld in host_data.link_definitions {
330            if let Err(e) = receive_link_for_provider(&provider, &conn, ld).await {
331                error!(
332                    error = %e,
333                    "failed to initialize link during provider startup",
334                );
335            }
336        }
337        let mut tasks = JoinSet::new();
338        tasks.spawn(async move {
339            handle_provider_commands(provider, &conn, quit_rx, quit_tx, commands).await
340        });
341
342        Ok(tasks)
343    }
344}