1use std::collections::HashMap;
11use std::sync::Arc;
12
13use anyhow::{anyhow, bail, Context as _};
14use bytes::Bytes;
15use futures::{StreamExt as _, TryStreamExt as _};
16use tokio::fs;
17use tokio::sync::RwLock;
18use tracing::{debug, error, info, instrument, warn};
19use wascap::prelude::KeyPair;
20use wasmcloud_provider_sdk::core::HostData;
21use wasmcloud_provider_sdk::{
22 get_connection, initialize_observability, load_host_data, propagate_trace_for_ctx,
23 run_provider, serve_provider_exports, Context, LinkConfig, LinkDeleteInfo, Provider,
24};
25
26mod config;
27use config::NatsConnectionConfig;
28
/// wRPC bindings produced by `wit_bindgen_wrpc::generate!` for the
/// `wrpc:keyvalue` `atomics`, `batch` and `store` interfaces (`0.2.0-draft`).
mod bindings {
    wit_bindgen_wrpc::generate!({
        with: {
            "wrpc:keyvalue/atomics@0.2.0-draft": generate,
            "wrpc:keyvalue/batch@0.2.0-draft": generate,
            "wrpc:keyvalue/store@0.2.0-draft": generate,
        }
    });
}
38use bindings::exports::wrpc::keyvalue;
39
/// Result alias whose error type defaults to the key-value store error.
type Result<T, E = keyvalue::store::Error> = core::result::Result<T, E>;
41
/// Crate-level entry point; delegates to [`KvNatsProvider::run`].
pub async fn run() -> anyhow::Result<()> {
    KvNatsProvider::run().await
}
45
/// Base interval for the backoff delay between `increment` retries. The
/// computed delay is passed to `Duration::from_millis`, so this is in milliseconds.
const EXPONENTIAL_BACKOFF_BASE_INTERVAL: u64 = 5;

/// NATS Kv stores keyed by bucket id (the link name they were opened for).
type NatsKvStores = HashMap<String, async_nats::jetstream::kv::Store>;
51
/// NATS-backed implementation of the `wrpc:keyvalue` interfaces.
#[derive(Default, Clone)]
pub struct KvNatsProvider {
    /// Opened Kv stores per linked component: source id -> (link name -> store).
    consumer_components: Arc<RwLock<HashMap<String, NatsKvStores>>>,
    /// Configuration used as-is when a link supplies no config, and as the
    /// base that link-specific configuration is merged into otherwise.
    default_config: NatsConnectionConfig,
}
58impl KvNatsProvider {
60 pub async fn run() -> anyhow::Result<()> {
61 let host_data = load_host_data().context("failed to load host data")?;
62 let flamegraph_path = host_data
63 .config
64 .get("FLAMEGRAPH_PATH")
65 .map(String::from)
66 .or_else(|| std::env::var("PROVIDER_KEYVALUE_NATS_FLAMEGRAPH_PATH").ok());
67 initialize_observability!("keyvalue-nats-provider", flamegraph_path);
68 let provider = Self::from_host_data(host_data);
69 let shutdown = run_provider(provider.clone(), "keyvalue-nats-provider")
70 .await
71 .context("failed to run provider")?;
72 let connection = get_connection();
73 let wrpc = connection
74 .get_wrpc_client(connection.provider_key())
75 .await?;
76 serve_provider_exports(&wrpc, provider, shutdown, bindings::serve)
77 .await
78 .context("failed to serve provider exports")
79 }
80
81 pub fn from_host_data(host_data: &HostData) -> KvNatsProvider {
83 let config =
84 NatsConnectionConfig::from_config_and_secrets(&host_data.config, &host_data.secrets);
85 if let Ok(config) = config {
86 KvNatsProvider {
87 default_config: config,
88 ..Default::default()
89 }
90 } else {
91 warn!("Failed to build NATS connection configuration, falling back to default");
92 KvNatsProvider::default()
93 }
94 }
95
96 async fn connect(
98 &self,
99 cfg: NatsConnectionConfig,
100 link_cfg: &LinkConfig<'_>,
101 ) -> anyhow::Result<async_nats::jetstream::kv::Store> {
102 let mut opts = match (cfg.auth_jwt, cfg.auth_seed) {
103 (Some(jwt), Some(seed)) => {
104 let seed = KeyPair::from_seed(&seed).context("failed to parse seed key pair")?;
105 let seed = Arc::new(seed);
106 async_nats::ConnectOptions::with_jwt(jwt, move |nonce| {
107 let seed = seed.clone();
108 async move { seed.sign(&nonce).map_err(async_nats::AuthError::new) }
109 })
110 }
111 (None, None) => async_nats::ConnectOptions::default(),
112 _ => bail!("must provide both jwt and seed for jwt authentication"),
113 };
114 if let Some(tls_ca) = &cfg.tls_ca {
115 opts = add_tls_ca(tls_ca, opts)?;
116 } else if let Some(tls_ca_file) = &cfg.tls_ca_file {
117 let ca = fs::read_to_string(tls_ca_file)
118 .await
119 .context("failed to read TLS CA file")?;
120 opts = add_tls_ca(&ca, opts)?;
121 }
122
123 let uri = cfg.cluster_uri.unwrap_or_default();
125
126 let client = opts
128 .name("NATS Key-Value Provider") .connect(uri.clone())
130 .await?;
131
132 let js_context = if let Some(domain) = &cfg.js_domain {
134 async_nats::jetstream::with_domain(client.clone(), domain.clone())
135 } else {
136 async_nats::jetstream::new(client.clone())
137 };
138
139 if link_cfg
142 .config
143 .get("enable_bucket_auto_create")
144 .is_some_and(|v| v.to_lowercase() == "true")
145 {
146 if let Err(e) = js_context
148 .create_key_value(async_nats::jetstream::kv::Config {
149 bucket: cfg.bucket.clone(),
150 ..Default::default()
151 })
152 .await
153 {
154 warn!("failed to auto create bucket [{}]: {e}", cfg.bucket);
155 }
156 };
157
158 let store = js_context.get_key_value(&cfg.bucket).await?;
160 info!(%cfg.bucket, "NATS Kv store opened");
161
162 Ok(store)
164 }
165
166 async fn get_kv_store(
168 &self,
169 context: Option<Context>,
170 bucket_id: String,
171 ) -> Result<async_nats::jetstream::kv::Store, keyvalue::store::Error> {
172 if let Some(ref source_id) = context
173 .as_ref()
174 .and_then(|Context { component, .. }| component.clone())
175 {
176 let components = self.consumer_components.read().await;
177 let kv_stores = match components.get(source_id) {
178 Some(kv_stores) => kv_stores,
179 None => {
180 return Err(keyvalue::store::Error::Other(format!(
181 "consumer component not linked: {source_id}"
182 )));
183 }
184 };
185 kv_stores.get(&bucket_id).cloned().ok_or_else(|| {
186 keyvalue::store::Error::Other(format!(
187 "No NATS Kv store found for bucket id (link name): {bucket_id}"
188 ))
189 })
190 } else {
191 Err(keyvalue::store::Error::Other(
192 "no consumer component in the request".to_string(),
193 ))
194 }
195 }
196
197 #[instrument(level = "debug", skip_all)]
199 async fn get(
200 &self,
201 context: Option<Context>,
202 bucket: String,
203 key: String,
204 ) -> anyhow::Result<Result<Option<Bytes>>> {
205 keyvalue::store::Handler::get(self, context, bucket, key).await
206 }
207
208 async fn set(
210 &self,
211 context: Option<Context>,
212 bucket: String,
213 key: String,
214 value: Bytes,
215 ) -> anyhow::Result<Result<()>> {
216 keyvalue::store::Handler::set(self, context, bucket, key, value).await
217 }
218
219 async fn delete(
221 &self,
222 context: Option<Context>,
223 bucket: String,
224 key: String,
225 ) -> anyhow::Result<Result<()>> {
226 keyvalue::store::Handler::delete(self, context, bucket, key).await
227 }
228}
229
230impl Provider for KvNatsProvider {
232 #[instrument(level = "debug", skip_all, fields(source_id))]
236 async fn receive_link_config_as_target(
237 &self,
238 link_config: LinkConfig<'_>,
239 ) -> anyhow::Result<()> {
240 let nats_config = if link_config.config.is_empty() {
241 self.default_config.clone()
242 } else {
243 match NatsConnectionConfig::from_config_and_secrets(
246 link_config.config,
247 link_config.secrets,
248 ) {
249 Ok(ncc) => self.default_config.merge(&ncc),
250 Err(e) => {
251 error!("Failed to build NATS connection configuration: {e:?}");
252 return Err(anyhow!(e).context("failed to build NATS connection configuration"));
253 }
254 }
255 };
256 println!("NATS Kv configuration: {nats_config:?}");
257
258 let LinkConfig {
259 source_id,
260 link_name,
261 ..
262 }: LinkConfig<'_> = link_config;
263
264 let kv_store = match self.connect(nats_config, &link_config).await {
265 Ok(b) => b,
266 Err(e) => {
267 error!("Failed to connect to NATS: {e:?}");
268 bail!(anyhow!(e).context("failed to connect to NATS"))
269 }
270 };
271
272 let mut consumer_components = self.consumer_components.write().await;
273 if let Some(existing_kv_stores) = consumer_components.get_mut(&source_id.to_string()) {
275 existing_kv_stores.insert(link_name.into(), kv_store);
277 } else {
278 consumer_components.insert(
280 source_id.into(),
281 HashMap::from([(link_name.into(), kv_store)]),
282 );
283 }
284
285 Ok(())
286 }
287
288 #[instrument(level = "info", skip_all, fields(source_id = info.get_source_id()))]
291 async fn delete_link_as_target(&self, info: impl LinkDeleteInfo) -> anyhow::Result<()> {
292 let component_id = info.get_source_id();
293 let mut links = self.consumer_components.write().await;
294 if let Some(kv_store) = links.remove(component_id) {
295 debug!(
296 component_id,
297 "dropping NATS Kv store [{kv_store:?}] for (consumer) component...",
298 );
299 }
300
301 debug!(component_id, "finished processing link deletion");
302
303 Ok(())
304 }
305
306 async fn shutdown(&self) -> anyhow::Result<()> {
308 let mut consumers = self.consumer_components.write().await;
310 consumers.clear();
311
312 Ok(())
313 }
314}
315
316impl keyvalue::store::Handler<Option<Context>> for KvNatsProvider {
318 #[instrument(level = "debug", skip(self))]
320 async fn get(
321 &self,
322 context: Option<Context>,
323 bucket: String,
324 key: String,
325 ) -> anyhow::Result<Result<Option<Bytes>>> {
326 propagate_trace_for_ctx!(context);
327
328 match self.get_kv_store(context, bucket).await {
329 Ok(store) => match store.get(key.clone()).await {
330 Ok(Some(bytes)) => Ok(Ok(Some(bytes))),
331 Ok(None) => Ok(Ok(None)),
332 Err(err) => {
333 error!(%key, "failed to get key value: {err:?}");
334 Ok(Err(keyvalue::store::Error::Other(err.to_string())))
335 }
336 },
337 Err(err) => Ok(Err(err)),
338 }
339 }
340
341 #[instrument(level = "debug", skip(self))]
343 async fn set(
344 &self,
345 context: Option<Context>,
346 bucket: String,
347 key: String,
348 value: Bytes,
349 ) -> anyhow::Result<Result<()>> {
350 propagate_trace_for_ctx!(context);
351
352 match self.get_kv_store(context, bucket).await {
353 Ok(store) => match store.put(key.clone(), value).await {
354 Ok(_) => Ok(Ok(())),
355 Err(err) => {
356 error!(%key, "failed to set key value: {err:?}");
357 Ok(Err(keyvalue::store::Error::Other(err.to_string())))
358 }
359 },
360 Err(err) => Ok(Err(err)),
361 }
362 }
363
364 #[instrument(level = "debug", skip(self))]
366 async fn delete(
367 &self,
368 context: Option<Context>,
369 bucket: String,
370 key: String,
371 ) -> anyhow::Result<Result<()>> {
372 propagate_trace_for_ctx!(context);
373
374 match self.get_kv_store(context, bucket).await {
375 Ok(store) => match store.purge(key.clone()).await {
376 Ok(_) => Ok(Ok(())),
377 Err(err) => {
378 error!(%key, "failed to delete key: {err:?}");
379 Ok(Err(keyvalue::store::Error::Other(err.to_string())))
380 }
381 },
382 Err(err) => Ok(Err(err)),
383 }
384 }
385
386 #[instrument(level = "debug", skip(self))]
388 async fn exists(
389 &self,
390 context: Option<Context>,
391 bucket: String,
392 key: String,
393 ) -> anyhow::Result<Result<bool>> {
394 propagate_trace_for_ctx!(context);
395
396 match self.get(context, bucket, key).await {
397 Ok(Ok(Some(_))) => Ok(Ok(true)),
398 Ok(Ok(None)) => Ok(Ok(false)),
399 Ok(Err(err)) => Ok(Err(err)),
400 Err(err) => Ok(Err(keyvalue::store::Error::Other(err.to_string()))),
401 }
402 }
403
404 #[instrument(level = "debug", skip(self))]
406 async fn list_keys(
407 &self,
408 context: Option<Context>,
409 bucket: String,
410 cursor: Option<u64>,
411 ) -> anyhow::Result<Result<keyvalue::store::KeyResponse>> {
412 propagate_trace_for_ctx!(context);
413
414 match self.get_kv_store(context, bucket).await {
415 Ok(store) => match store.keys().await {
416 Ok(keys) => {
417 match keys
418 .skip(cursor.unwrap_or(0) as usize)
419 .take(usize::MAX)
420 .try_collect()
421 .await
422 {
423 Ok(keys) => Ok(Ok(keyvalue::store::KeyResponse { keys, cursor: None })),
424 Err(err) => {
425 error!("failed to list keys: {err:?}");
426 Ok(Err(keyvalue::store::Error::Other(err.to_string())))
427 }
428 }
429 }
430 Err(err) => {
431 error!("failed to list keys: {err:?}");
432 Ok(Err(keyvalue::store::Error::Other(err.to_string())))
433 }
434 },
435 Err(err) => Ok(Err(err)),
436 }
437 }
438}
439
440impl keyvalue::atomics::Handler<Option<Context>> for KvNatsProvider {
442 #[instrument(level = "debug", skip(self))]
444 async fn increment(
445 &self,
446 context: Option<Context>,
447 bucket: String,
448 key: String,
449 delta: u64,
450 ) -> anyhow::Result<Result<u64, keyvalue::store::Error>> {
451 propagate_trace_for_ctx!(context);
452
453 let kv_store = self.get_kv_store(context.clone(), bucket.clone()).await?;
455
456 let mut new_value = 0;
457 let mut success = false;
458 for attempt in 0..5 {
459 let entry = kv_store.entry(key.clone()).await?;
461
462 let (current_value, revision) = match &entry {
464 Some(entry) if !entry.value.is_empty() => {
465 let value_str = std::str::from_utf8(&entry.value)?;
466 match value_str.parse::<u64>() {
467 Ok(num) => (num, entry.revision),
468 Err(_) => {
469 return Err(keyvalue::store::Error::Other(
470 "Cannot increment a non-numerical value".to_string(),
471 )
472 .into())
473 }
474 }
475 }
476 _ => (0, entry.as_ref().map_or(0, |e| e.revision)),
477 };
478
479 new_value = current_value + delta;
480
481 match kv_store
483 .update(key.clone(), new_value.to_string().into(), revision)
484 .await
485 {
486 Ok(_) => {
487 success = true;
488 break; }
490 Err(_) => {
491 if attempt > 0 {
493 let wait_time = EXPONENTIAL_BACKOFF_BASE_INTERVAL * 2u64.pow(attempt - 1);
494 tokio::time::sleep(std::time::Duration::from_millis(wait_time)).await;
495 }
496 }
497 }
498 }
499
500 if success {
501 Ok(Ok(new_value))
502 } else {
503 Ok(Err(keyvalue::store::Error::Other(
505 "Failed to increment the value after 5 attempts".to_string(),
506 )))
507 }
508 }
509}
510
/// Payload returned by `get_many`: one optional `(key, value)` pair per lookup.
type KvResult = Vec<Option<(String, Bytes)>>;
513
514impl keyvalue::batch::Handler<Option<Context>> for KvNatsProvider {
516 #[instrument(level = "debug", skip(self))]
518 async fn get_many(
519 &self,
520 ctx: Option<Context>,
521 bucket: String,
522 keys: Vec<String>,
523 ) -> anyhow::Result<Result<KvResult>> {
524 let ctx = ctx.clone();
525 let bucket = bucket.clone();
526
527 let results: Result<Vec<_>, _> = keys
529 .into_iter()
530 .map(|key| {
531 let ctx = ctx.clone();
532 let bucket = bucket.clone();
533 async move {
534 self.get(ctx, bucket, key.clone())
535 .await
536 .map(|value| (key, value))
537 }
538 })
539 .collect::<futures::stream::FuturesUnordered<_>>()
540 .try_collect()
541 .await;
542
543 match results {
544 Ok(values) => {
545 let values: Result<Vec<_>, _> = values
546 .into_iter()
547 .map(|(k, res)| match res {
548 Ok(Some(v)) => Ok(Some((k, v))),
549 Ok(None) => Ok(None),
550 Err(err) => {
551 error!("failed to parse key-value pairs: {err:?}");
552 Err(keyvalue::store::Error::Other(err.to_string()))
553 }
554 })
555 .collect();
556 Ok(values)
557 }
558 Err(err) => {
559 error!("failed to get many keys: {err:?}");
560 Ok(Err(keyvalue::store::Error::Other(err.to_string())))
561 }
562 }
563 }
564
565 #[instrument(level = "debug", skip(self))]
567 async fn set_many(
568 &self,
569 ctx: Option<Context>,
570 bucket: String,
571 items: Vec<(String, Bytes)>,
572 ) -> anyhow::Result<Result<()>> {
573 let ctx = ctx.clone();
574 let bucket = bucket.clone();
575
576 let results: Result<Vec<_>, _> = items
578 .into_iter()
579 .map(|(key, value)| {
580 let ctx = ctx.clone();
581 let bucket = bucket.clone();
582 async move { self.set(ctx, bucket, key, value).await }
583 })
584 .collect::<futures::stream::FuturesUnordered<_>>()
585 .try_collect()
586 .await;
587
588 results.map(|_| Ok(()))
590 }
591
592 #[instrument(level = "debug", skip(self))]
594 async fn delete_many(
595 &self,
596 ctx: Option<Context>,
597 bucket: String,
598 keys: Vec<String>,
599 ) -> anyhow::Result<Result<()>> {
600 let ctx = ctx.clone();
601 let bucket = bucket.clone();
602
603 let results: Result<Vec<_>, _> = keys
605 .into_iter()
606 .map(|key| {
607 let ctx = ctx.clone();
608 let bucket = bucket.clone();
609 async move { self.delete(ctx, bucket, key).await }
610 })
611 .collect::<futures::stream::FuturesUnordered<_>>()
612 .try_collect()
613 .await;
614
615 results.map(|_| Ok(()))
617 }
618}
619
620fn add_tls_ca(
622 tls_ca: &str,
623 opts: async_nats::ConnectOptions,
624) -> anyhow::Result<async_nats::ConnectOptions> {
625 let ca = rustls_pemfile::read_one(&mut tls_ca.as_bytes()).context("failed to read CA")?;
626 let mut roots = async_nats::rustls::RootCertStore::empty();
627 if let Some(rustls_pemfile::Item::X509Certificate(ca)) = ca {
628 roots.add_parsable_certificates([ca]);
629 } else {
630 bail!("tls ca: invalid certificate type, must be a DER encoded PEM file")
631 };
632 let tls_client = async_nats::rustls::ClientConfig::builder()
633 .with_root_certificates(roots)
634 .with_no_client_auth();
635 Ok(opts.tls_client_config(tls_client).require_tls(true))
636}
637
#[cfg(test)]
mod test {
    use super::*;

    /// `add_tls_ca` must accept PEM text framed as a certificate block.
    #[test]
    fn test_add_tls_ca() {
        let tls_ca = "-----BEGIN CERTIFICATE-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAwJwz\n-----END CERTIFICATE-----";
        let configured = add_tls_ca(tls_ca, async_nats::ConnectOptions::new());
        assert!(configured.is_ok())
    }
}