wasmcloud_provider_keyvalue_nats/
lib.rs

//! NATS implementation for wrpc:keyvalue.
//!
//! This implementation is multi-threaded and operations between different consumer/client
//! components use different connections and can run in parallel.
//!
//! A single connection is shared by all instances of the same consumer component, identified
//! by its id (public key), so there may be some brief lock contention if several instances of
//! the same component are simultaneously attempting to communicate with NATS.
9
10use std::collections::HashMap;
11use std::sync::Arc;
12
13use anyhow::{anyhow, bail, Context as _};
14use bytes::Bytes;
15use futures::{StreamExt as _, TryStreamExt as _};
16use tokio::fs;
17use tokio::sync::RwLock;
18use tracing::{debug, error, info, instrument, warn};
19use wascap::prelude::KeyPair;
20use wasmcloud_provider_sdk::core::HostData;
21use wasmcloud_provider_sdk::{
22    get_connection, initialize_observability, load_host_data, propagate_trace_for_ctx,
23    run_provider, serve_provider_exports, Context, LinkConfig, LinkDeleteInfo, Provider,
24};
25
26mod config;
27use config::NatsConnectionConfig;
28
29mod bindings {
30    wit_bindgen_wrpc::generate!({
31        with: {
32            "wrpc:keyvalue/atomics@0.2.0-draft": generate,
33            "wrpc:keyvalue/batch@0.2.0-draft": generate,
34            "wrpc:keyvalue/store@0.2.0-draft": generate,
35        }
36    });
37}
38use bindings::exports::wrpc::keyvalue;
39
40type Result<T, E = keyvalue::store::Error> = core::result::Result<T, E>;
41
42pub async fn run() -> anyhow::Result<()> {
43    KvNatsProvider::run().await
44}
45
/// Base interval, in milliseconds, of the exponential backoff applied between retries of
/// the `atomics::increment` operation.
const EXPONENTIAL_BACKOFF_BASE_INTERVAL: u64 = 5;
48
49/// [`NatsKvStores`] holds the handles to opened NATS Kv Stores, and their respective identifiers.
50type NatsKvStores = HashMap<String, async_nats::jetstream::kv::Store>;
51
52/// NATS implementation for wasi:keyvalue (via wrpc:keyvalue)
53#[derive(Default, Clone)]
54pub struct KvNatsProvider {
55    consumer_components: Arc<RwLock<HashMap<String, NatsKvStores>>>,
56    default_config: NatsConnectionConfig,
57}
58/// Implement the [`KvNatsProvider`] and [`Provider`] traits
59impl KvNatsProvider {
60    pub async fn run() -> anyhow::Result<()> {
61        let host_data = load_host_data().context("failed to load host data")?;
62        let flamegraph_path = host_data
63            .config
64            .get("FLAMEGRAPH_PATH")
65            .map(String::from)
66            .or_else(|| std::env::var("PROVIDER_KEYVALUE_NATS_FLAMEGRAPH_PATH").ok());
67        initialize_observability!("keyvalue-nats-provider", flamegraph_path);
68        let provider = Self::from_host_data(host_data);
69        let shutdown = run_provider(provider.clone(), "keyvalue-nats-provider")
70            .await
71            .context("failed to run provider")?;
72        let connection = get_connection();
73        let wrpc = connection
74            .get_wrpc_client(connection.provider_key())
75            .await?;
76        serve_provider_exports(&wrpc, provider, shutdown, bindings::serve)
77            .await
78            .context("failed to serve provider exports")
79    }
80
81    /// Build a [`KvNatsProvider`] from [`HostData`]
82    pub fn from_host_data(host_data: &HostData) -> KvNatsProvider {
83        let config =
84            NatsConnectionConfig::from_config_and_secrets(&host_data.config, &host_data.secrets);
85        if let Ok(config) = config {
86            KvNatsProvider {
87                default_config: config,
88                ..Default::default()
89            }
90        } else {
91            warn!("Failed to build NATS connection configuration, falling back to default");
92            KvNatsProvider::default()
93        }
94    }
95
96    /// Attempt to connect to NATS url (with JWT credentials, if provided)
97    async fn connect(
98        &self,
99        cfg: NatsConnectionConfig,
100        link_cfg: &LinkConfig<'_>,
101    ) -> anyhow::Result<async_nats::jetstream::kv::Store> {
102        let mut opts = match (cfg.auth_jwt, cfg.auth_seed) {
103            (Some(jwt), Some(seed)) => {
104                let seed = KeyPair::from_seed(&seed).context("failed to parse seed key pair")?;
105                let seed = Arc::new(seed);
106                async_nats::ConnectOptions::with_jwt(jwt, move |nonce| {
107                    let seed = seed.clone();
108                    async move { seed.sign(&nonce).map_err(async_nats::AuthError::new) }
109                })
110            }
111            (None, None) => async_nats::ConnectOptions::default(),
112            _ => bail!("must provide both jwt and seed for jwt authentication"),
113        };
114        if let Some(tls_ca) = &cfg.tls_ca {
115            opts = add_tls_ca(tls_ca, opts)?;
116        } else if let Some(tls_ca_file) = &cfg.tls_ca_file {
117            let ca = fs::read_to_string(tls_ca_file)
118                .await
119                .context("failed to read TLS CA file")?;
120            opts = add_tls_ca(&ca, opts)?;
121        }
122
123        // Get the cluster_uri
124        let uri = cfg.cluster_uri.unwrap_or_default();
125
126        // Connect to the NATS server
127        let client = opts
128            .name("NATS Key-Value Provider") // allow this to show up uniquely in a NATS connection list
129            .connect(uri.clone())
130            .await?;
131
132        // Get the JetStream context based on js_domain
133        let js_context = if let Some(domain) = &cfg.js_domain {
134            async_nats::jetstream::with_domain(client.clone(), domain.clone())
135        } else {
136            async_nats::jetstream::new(client.clone())
137        };
138
139        // If bucket auto-creation was specified in the link configuration,
140        // create a bucket
141        if link_cfg
142            .config
143            .get("enable_bucket_auto_create")
144            .is_some_and(|v| v.to_lowercase() == "true")
145        {
146            // Get the JetStream context based on js_domain
147            if let Err(e) = js_context
148                .create_key_value(async_nats::jetstream::kv::Config {
149                    bucket: cfg.bucket.clone(),
150                    ..Default::default()
151                })
152                .await
153            {
154                warn!("failed to auto create bucket [{}]: {e}", cfg.bucket);
155            }
156        };
157
158        // Open the key-value store
159        let store = js_context.get_key_value(&cfg.bucket).await?;
160        info!(%cfg.bucket, "NATS Kv store opened");
161
162        // Return the handle to the opened NATS Kv store
163        Ok(store)
164    }
165
166    /// Helper function to lookup and return the NATS Kv store handle, from the client component's context
167    async fn get_kv_store(
168        &self,
169        context: Option<Context>,
170        bucket_id: String,
171    ) -> Result<async_nats::jetstream::kv::Store, keyvalue::store::Error> {
172        if let Some(ref source_id) = context
173            .as_ref()
174            .and_then(|Context { component, .. }| component.clone())
175        {
176            let components = self.consumer_components.read().await;
177            let kv_stores = match components.get(source_id) {
178                Some(kv_stores) => kv_stores,
179                None => {
180                    return Err(keyvalue::store::Error::Other(format!(
181                        "consumer component not linked: {source_id}"
182                    )));
183                }
184            };
185            kv_stores.get(&bucket_id).cloned().ok_or_else(|| {
186                keyvalue::store::Error::Other(format!(
187                    "No NATS Kv store found for bucket id (link name): {bucket_id}"
188                ))
189            })
190        } else {
191            Err(keyvalue::store::Error::Other(
192                "no consumer component in the request".to_string(),
193            ))
194        }
195    }
196
197    /// Helper function to get a value from the key-value store
198    #[instrument(level = "debug", skip_all)]
199    async fn get(
200        &self,
201        context: Option<Context>,
202        bucket: String,
203        key: String,
204    ) -> anyhow::Result<Result<Option<Bytes>>> {
205        keyvalue::store::Handler::get(self, context, bucket, key).await
206    }
207
208    /// Helper function to set a value in the key-value store
209    async fn set(
210        &self,
211        context: Option<Context>,
212        bucket: String,
213        key: String,
214        value: Bytes,
215    ) -> anyhow::Result<Result<()>> {
216        keyvalue::store::Handler::set(self, context, bucket, key, value).await
217    }
218
219    /// Helper function to delete a key-value pair from the key-value store
220    async fn delete(
221        &self,
222        context: Option<Context>,
223        bucket: String,
224        key: String,
225    ) -> anyhow::Result<Result<()>> {
226        keyvalue::store::Handler::delete(self, context, bucket, key).await
227    }
228}
229
230/// Handle provider control commands
231impl Provider for KvNatsProvider {
232    /// Provider should perform any operations needed for a new link,
233    /// including setting up per-component resources, and checking authorization.
234    /// If the link is allowed, return true, otherwise return false to deny the link.
235    #[instrument(level = "debug", skip_all, fields(source_id))]
236    async fn receive_link_config_as_target(
237        &self,
238        link_config: LinkConfig<'_>,
239    ) -> anyhow::Result<()> {
240        let nats_config = if link_config.config.is_empty() {
241            self.default_config.clone()
242        } else {
243            // create a config from the supplied values and merge that with the existing default
244            // NATS connection configuration
245            match NatsConnectionConfig::from_config_and_secrets(
246                link_config.config,
247                link_config.secrets,
248            ) {
249                Ok(ncc) => self.default_config.merge(&ncc),
250                Err(e) => {
251                    error!("Failed to build NATS connection configuration: {e:?}");
252                    return Err(anyhow!(e).context("failed to build NATS connection configuration"));
253                }
254            }
255        };
256        println!("NATS Kv configuration: {nats_config:?}");
257
258        let LinkConfig {
259            source_id,
260            link_name,
261            ..
262        }: LinkConfig<'_> = link_config;
263
264        let kv_store = match self.connect(nats_config, &link_config).await {
265            Ok(b) => b,
266            Err(e) => {
267                error!("Failed to connect to NATS: {e:?}");
268                bail!(anyhow!(e).context("failed to connect to NATS"))
269            }
270        };
271
272        let mut consumer_components = self.consumer_components.write().await;
273        // Check if there's an existing hashmap for the source_id
274        if let Some(existing_kv_stores) = consumer_components.get_mut(&source_id.to_string()) {
275            // If so, insert the new kv_store into it
276            existing_kv_stores.insert(link_name.into(), kv_store);
277        } else {
278            // Otherwise, create a new hashmap and insert it
279            consumer_components.insert(
280                source_id.into(),
281                HashMap::from([(link_name.into(), kv_store)]),
282            );
283        }
284
285        Ok(())
286    }
287
288    /// Provider should perform any operations needed for a link deletion, including cleaning up
289    /// per-component resources.
290    #[instrument(level = "info", skip_all, fields(source_id = info.get_source_id()))]
291    async fn delete_link_as_target(&self, info: impl LinkDeleteInfo) -> anyhow::Result<()> {
292        let component_id = info.get_source_id();
293        let mut links = self.consumer_components.write().await;
294        if let Some(kv_store) = links.remove(component_id) {
295            debug!(
296                component_id,
297                "dropping NATS Kv store [{kv_store:?}] for (consumer) component...",
298            );
299        }
300
301        debug!(component_id, "finished processing link deletion");
302
303        Ok(())
304    }
305
306    /// Handle shutdown request by closing all connections
307    async fn shutdown(&self) -> anyhow::Result<()> {
308        // clear the consumer components
309        let mut consumers = self.consumer_components.write().await;
310        consumers.clear();
311
312        Ok(())
313    }
314}
315
316/// Implement the 'wasi:keyvalue/store' capability provider interface
317impl keyvalue::store::Handler<Option<Context>> for KvNatsProvider {
318    // Get the last revision of a value, for a given key, from the key-value store
319    #[instrument(level = "debug", skip(self))]
320    async fn get(
321        &self,
322        context: Option<Context>,
323        bucket: String,
324        key: String,
325    ) -> anyhow::Result<Result<Option<Bytes>>> {
326        propagate_trace_for_ctx!(context);
327
328        match self.get_kv_store(context, bucket).await {
329            Ok(store) => match store.get(key.clone()).await {
330                Ok(Some(bytes)) => Ok(Ok(Some(bytes))),
331                Ok(None) => Ok(Ok(None)),
332                Err(err) => {
333                    error!(%key, "failed to get key value: {err:?}");
334                    Ok(Err(keyvalue::store::Error::Other(err.to_string())))
335                }
336            },
337            Err(err) => Ok(Err(err)),
338        }
339    }
340
341    // Set new key-value pair in the key-value store. If key didn’t exist, it is created. If it did exist, a new value with a new version is added
342    #[instrument(level = "debug", skip(self))]
343    async fn set(
344        &self,
345        context: Option<Context>,
346        bucket: String,
347        key: String,
348        value: Bytes,
349    ) -> anyhow::Result<Result<()>> {
350        propagate_trace_for_ctx!(context);
351
352        match self.get_kv_store(context, bucket).await {
353            Ok(store) => match store.put(key.clone(), value).await {
354                Ok(_) => Ok(Ok(())),
355                Err(err) => {
356                    error!(%key, "failed to set key value: {err:?}");
357                    Ok(Err(keyvalue::store::Error::Other(err.to_string())))
358                }
359            },
360            Err(err) => Ok(Err(err)),
361        }
362    }
363
364    // Purge all the revisions of a key destructively,  from the key-value store, leaving behind a single purge entry in-place.
365    #[instrument(level = "debug", skip(self))]
366    async fn delete(
367        &self,
368        context: Option<Context>,
369        bucket: String,
370        key: String,
371    ) -> anyhow::Result<Result<()>> {
372        propagate_trace_for_ctx!(context);
373
374        match self.get_kv_store(context, bucket).await {
375            Ok(store) => match store.purge(key.clone()).await {
376                Ok(_) => Ok(Ok(())),
377                Err(err) => {
378                    error!(%key, "failed to delete key: {err:?}");
379                    Ok(Err(keyvalue::store::Error::Other(err.to_string())))
380                }
381            },
382            Err(err) => Ok(Err(err)),
383        }
384    }
385
386    // Check if a key exists in the key-value store
387    #[instrument(level = "debug", skip(self))]
388    async fn exists(
389        &self,
390        context: Option<Context>,
391        bucket: String,
392        key: String,
393    ) -> anyhow::Result<Result<bool>> {
394        propagate_trace_for_ctx!(context);
395
396        match self.get(context, bucket, key).await {
397            Ok(Ok(Some(_))) => Ok(Ok(true)),
398            Ok(Ok(None)) => Ok(Ok(false)),
399            Ok(Err(err)) => Ok(Err(err)),
400            Err(err) => Ok(Err(keyvalue::store::Error::Other(err.to_string()))),
401        }
402    }
403
404    // List all keys in the key-value store
405    #[instrument(level = "debug", skip(self))]
406    async fn list_keys(
407        &self,
408        context: Option<Context>,
409        bucket: String,
410        cursor: Option<u64>,
411    ) -> anyhow::Result<Result<keyvalue::store::KeyResponse>> {
412        propagate_trace_for_ctx!(context);
413
414        match self.get_kv_store(context, bucket).await {
415            Ok(store) => match store.keys().await {
416                Ok(keys) => {
417                    match keys
418                        .skip(cursor.unwrap_or(0) as usize)
419                        .take(usize::MAX)
420                        .try_collect()
421                        .await
422                    {
423                        Ok(keys) => Ok(Ok(keyvalue::store::KeyResponse { keys, cursor: None })),
424                        Err(err) => {
425                            error!("failed to list keys: {err:?}");
426                            Ok(Err(keyvalue::store::Error::Other(err.to_string())))
427                        }
428                    }
429                }
430                Err(err) => {
431                    error!("failed to list keys: {err:?}");
432                    Ok(Err(keyvalue::store::Error::Other(err.to_string())))
433                }
434            },
435            Err(err) => Ok(Err(err)),
436        }
437    }
438}
439
440/// Implement the 'wasi:keyvalue/atomic' capability provider interface
441impl keyvalue::atomics::Handler<Option<Context>> for KvNatsProvider {
442    /// Increments a numeric value, returning the new value
443    #[instrument(level = "debug", skip(self))]
444    async fn increment(
445        &self,
446        context: Option<Context>,
447        bucket: String,
448        key: String,
449        delta: u64,
450    ) -> anyhow::Result<Result<u64, keyvalue::store::Error>> {
451        propagate_trace_for_ctx!(context);
452
453        // Try to increment the value up to 5 times with exponential backoff
454        let kv_store = self.get_kv_store(context.clone(), bucket.clone()).await?;
455
456        let mut new_value = 0;
457        let mut success = false;
458        for attempt in 0..5 {
459            // Get the latest entry from the key-value store
460            let entry = kv_store.entry(key.clone()).await?;
461
462            // Get the current value and revision
463            let (current_value, revision) = match &entry {
464                Some(entry) if !entry.value.is_empty() => {
465                    let value_str = std::str::from_utf8(&entry.value)?;
466                    match value_str.parse::<u64>() {
467                        Ok(num) => (num, entry.revision),
468                        Err(_) => {
469                            return Err(keyvalue::store::Error::Other(
470                                "Cannot increment a non-numerical value".to_string(),
471                            )
472                            .into())
473                        }
474                    }
475                }
476                _ => (0, entry.as_ref().map_or(0, |e| e.revision)),
477            };
478
479            new_value = current_value + delta;
480
481            // Increment the value of the key
482            match kv_store
483                .update(key.clone(), new_value.to_string().into(), revision)
484                .await
485            {
486                Ok(_) => {
487                    success = true;
488                    break; // Exit the loop on success
489                }
490                Err(_) => {
491                    // Apply exponential backoff delay if the revision has changed (i.e. the key has been updated since the last read)
492                    if attempt > 0 {
493                        let wait_time = EXPONENTIAL_BACKOFF_BASE_INTERVAL * 2u64.pow(attempt - 1);
494                        tokio::time::sleep(std::time::Duration::from_millis(wait_time)).await;
495                    }
496                }
497            }
498        }
499
500        if success {
501            Ok(Ok(new_value))
502        } else {
503            // If all attempts fail, let user know
504            Ok(Err(keyvalue::store::Error::Other(
505                "Failed to increment the value after 5 attempts".to_string(),
506            )))
507        }
508    }
509}
510
511/// Reducing type complexity for the `get_many` function of wasi:keyvalue/batch
512type KvResult = Vec<Option<(String, Bytes)>>;
513
514/// Implement the 'wasi:keyvalue/batch' capability provider interface
515impl keyvalue::batch::Handler<Option<Context>> for KvNatsProvider {
516    // Get multiple values from the key-value store
517    #[instrument(level = "debug", skip(self))]
518    async fn get_many(
519        &self,
520        ctx: Option<Context>,
521        bucket: String,
522        keys: Vec<String>,
523    ) -> anyhow::Result<Result<KvResult>> {
524        let ctx = ctx.clone();
525        let bucket = bucket.clone();
526
527        // Get the values for the keys
528        let results: Result<Vec<_>, _> = keys
529            .into_iter()
530            .map(|key| {
531                let ctx = ctx.clone();
532                let bucket = bucket.clone();
533                async move {
534                    self.get(ctx, bucket, key.clone())
535                        .await
536                        .map(|value| (key, value))
537                }
538            })
539            .collect::<futures::stream::FuturesUnordered<_>>()
540            .try_collect()
541            .await;
542
543        match results {
544            Ok(values) => {
545                let values: Result<Vec<_>, _> = values
546                    .into_iter()
547                    .map(|(k, res)| match res {
548                        Ok(Some(v)) => Ok(Some((k, v))),
549                        Ok(None) => Ok(None),
550                        Err(err) => {
551                            error!("failed to parse key-value pairs: {err:?}");
552                            Err(keyvalue::store::Error::Other(err.to_string()))
553                        }
554                    })
555                    .collect();
556                Ok(values)
557            }
558            Err(err) => {
559                error!("failed to get many keys: {err:?}");
560                Ok(Err(keyvalue::store::Error::Other(err.to_string())))
561            }
562        }
563    }
564
565    // Set multiple values in the key-value store
566    #[instrument(level = "debug", skip(self))]
567    async fn set_many(
568        &self,
569        ctx: Option<Context>,
570        bucket: String,
571        items: Vec<(String, Bytes)>,
572    ) -> anyhow::Result<Result<()>> {
573        let ctx = ctx.clone();
574        let bucket = bucket.clone();
575
576        // Set the values for the keys
577        let results: Result<Vec<_>, _> = items
578            .into_iter()
579            .map(|(key, value)| {
580                let ctx = ctx.clone();
581                let bucket = bucket.clone();
582                async move { self.set(ctx, bucket, key, value).await }
583            })
584            .collect::<futures::stream::FuturesUnordered<_>>()
585            .try_collect()
586            .await;
587
588        // If all set operations were successful, return Ok(())
589        results.map(|_| Ok(()))
590    }
591
592    // Delete multiple keys from the key-value store
593    #[instrument(level = "debug", skip(self))]
594    async fn delete_many(
595        &self,
596        ctx: Option<Context>,
597        bucket: String,
598        keys: Vec<String>,
599    ) -> anyhow::Result<Result<()>> {
600        let ctx = ctx.clone();
601        let bucket = bucket.clone();
602
603        // Delete the keys
604        let results: Result<Vec<_>, _> = keys
605            .into_iter()
606            .map(|key| {
607                let ctx = ctx.clone();
608                let bucket = bucket.clone();
609                async move { self.delete(ctx, bucket, key).await }
610            })
611            .collect::<futures::stream::FuturesUnordered<_>>()
612            .try_collect()
613            .await;
614
615        // If all delete operations were successful, return Ok(())
616        results.map(|_| Ok(()))
617    }
618}
619
620/// Helper function for adding the TLS CA to the NATS connection options
621fn add_tls_ca(
622    tls_ca: &str,
623    opts: async_nats::ConnectOptions,
624) -> anyhow::Result<async_nats::ConnectOptions> {
625    let ca = rustls_pemfile::read_one(&mut tls_ca.as_bytes()).context("failed to read CA")?;
626    let mut roots = async_nats::rustls::RootCertStore::empty();
627    if let Some(rustls_pemfile::Item::X509Certificate(ca)) = ca {
628        roots.add_parsable_certificates([ca]);
629    } else {
630        bail!("tls ca: invalid certificate type, must be a DER encoded PEM file")
631    };
632    let tls_client = async_nats::rustls::ClientConfig::builder()
633        .with_root_certificates(roots)
634        .with_no_client_auth();
635    Ok(opts.tls_client_config(tls_client).require_tls(true))
636}
637
// Provider configuration tests
#[cfg(test)]
mod test {
    use super::*;

    /// `add_tls_ca` accepts PEM input whose first item is a `CERTIFICATE` block.
    #[test]
    fn test_add_tls_ca() {
        let pem = "-----BEGIN CERTIFICATE-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAwJwz\n-----END CERTIFICATE-----";
        let configured = add_tls_ca(pem, async_nats::ConnectOptions::new());
        assert!(configured.is_ok())
    }
}