wasmcloud_provider_blobstore_nats/
provider.rs

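//! NATS JetStream-backed blobstore capability provider.
//!
//! A separate NATS connection is opened for every link, keyed by the consumer component's id and
//! the link name. Operations for different links therefore use different connections and can run
//! in parallel. All instances (replicas) of a component share the connection of the link they use.
//!
//! Connection handles are kept in a map behind a read/write lock. Request handling takes the read
//! lock only long enough to clone a handle. Link creation/deletion, configuration updates and
//! shutdown take the write lock. Note that a provider configuration update currently replaces
//! every link's connection with one new, shared connection.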
7
8#![allow(clippy::type_complexity)]
9use std::collections::HashMap;
10use std::sync::Arc;
11
12use anyhow::{anyhow, bail, Context as _};
13use tokio::fs;
14use tracing::{debug, error, info, instrument, warn};
15use wascap::prelude::KeyPair;
16use wasmcloud_provider_sdk::{
17    get_connection, initialize_observability, load_host_data, run_provider, serve_provider_exports,
18    Context, HostData, LinkConfig, LinkDeleteInfo, Provider, ProviderConfigUpdate,
19};
20
21use crate::config::{NatsConnectionConfig, DEFAULT_NATS_URI};
22use crate::{NatsBlobstore, NatsBlobstoreProvider};
23// Import the wrpc interface bindings
24use wrpc_interface_blobstore::bindings;
25
26/// Implement the [`NatsBlobstoreProvider`] and [`Provider`] traits
27impl NatsBlobstoreProvider {
28    pub async fn run() -> anyhow::Result<()> {
29        let host_data = load_host_data().context("failed to load host data")?;
30        let provider = Self::from_host_data(host_data);
31        let shutdown = run_provider(provider.clone(), "nats-bucket-provider")
32            .await
33            .context("failed to run provider")?;
34        let connection = get_connection();
35        let flamegraph_path = host_data
36            .config
37            .get("FLAMEGRAPH_PATH")
38            .map(String::from)
39            .or_else(|| std::env::var("PROVIDER_BLOBSTORE_NATS_FLAMEGRAPH_PATH").ok());
40        initialize_observability!("blobstore-nats-provider", flamegraph_path);
41        serve_provider_exports(
42            &connection
43                .get_wrpc_client(connection.provider_key())
44                .await?,
45            provider,
46            shutdown,
47            bindings::serve,
48        )
49        .await
50        .context("failed to serve provider exports")
51    }
52
53    /// Build a [`NatsBlobstoreProvider`] from [`HostData`]
54    pub fn from_host_data(host_data: &HostData) -> NatsBlobstoreProvider {
55        let config = NatsConnectionConfig::from_link_config(&host_data.config, &host_data.secrets);
56        if let Ok(default_config) = config {
57            NatsBlobstoreProvider {
58                default_config,
59                ..Default::default()
60            }
61        } else {
62            warn!("failed to build NATS connection configuration, falling back to default");
63            NatsBlobstoreProvider::default()
64        }
65    }
66
67    /// Attempt to connect to NATS url (with JWT credentials, if provided)
68    async fn connect(
69        &self,
70        cfg: NatsConnectionConfig,
71    ) -> anyhow::Result<async_nats::jetstream::context::Context> {
72        let mut opts = match (cfg.auth_jwt, cfg.auth_seed) {
73            (Some(jwt), Some(seed)) => {
74                let seed = KeyPair::from_seed(&seed).context("failed to parse seed key pair")?;
75                let seed = Arc::new(seed);
76                async_nats::ConnectOptions::with_jwt(jwt, move |nonce| {
77                    let seed = seed.clone();
78                    async move { seed.sign(&nonce).map_err(async_nats::AuthError::new) }
79                })
80            }
81            (None, None) => async_nats::ConnectOptions::default(),
82            _ => bail!("must provide both jwt and seed for jwt authentication"),
83        };
84        if let Some(tls_ca) = &cfg.tls_ca {
85            opts = add_tls_ca(tls_ca, opts)?;
86        } else if let Some(tls_ca_file) = &cfg.tls_ca_file {
87            let ca = fs::read_to_string(tls_ca_file)
88                .await
89                .context("failed to read TLS CA file")?;
90            opts = add_tls_ca(&ca, opts)?;
91        }
92
93        // Get the cluster_uri with proper default
94        let uri = cfg.cluster_uri.unwrap_or(DEFAULT_NATS_URI.to_string());
95
96        // Connect to the NATS server
97        let client = opts
98            .name("NATS Object Store Provider")
99            .connect(uri.clone())
100            .await?;
101
102        // Get the JetStream context based on js_domain
103        let jetstream = if let Some(domain) = &cfg.js_domain {
104            async_nats::jetstream::with_domain(client.clone(), domain.clone())
105        } else {
106            async_nats::jetstream::new(client.clone())
107        };
108
109        debug!("opened NATS JetStream: {:?}", jetstream);
110        debug!("NATS Connection Configuration: {:?}", client);
111
112        // Return the handle to the opened NATS Object store
113        Ok(jetstream)
114    }
115
116    /// Helper function to lookup and return the NATS JetStream connection handle, and container storage
117    /// configuration, using the client component's context.
118    /// This ensures consistent implementation across all functions that need to get the NATS Blobstore.
119    pub(crate) async fn get_blobstore(
120        &self,
121        context: Option<Context>,
122    ) -> anyhow::Result<NatsBlobstore> {
123        if let Some((component_id, link_name)) =
124            context
125                .as_ref()
126                .and_then(|ctx @ Context { component, .. }| {
127                    component
128                        .clone()
129                        .map(|component_id| (component_id, ctx.link_name().to_string()))
130                })
131        {
132            // Acquire a read lock on the consumer components and attempt to find the specified component_id
133            let components = self.consumer_components.read().await;
134            let nats_stores = components
135                .get(&component_id)
136                .ok_or_else(|| anyhow!("consumer component not linked: {}", component_id))?;
137
138            // Get the NATS Object Store handle and its storage configuration
139            nats_stores
140                .get(&link_name)
141                .cloned()
142                .ok_or_else(|| anyhow!("no NATS Object Store found for link name: {}", &link_name))
143        } else {
144            // If the context is None, return an error indicating no consumer component in the request
145            bail!("no consumer component found in the request")
146        }
147    }
148}
149
150/// Handle provider control commands
151impl Provider for NatsBlobstoreProvider {
152    /// Provider should perform any operations needed for a new link,
153    /// including setting up per-component resources, and checking authorization.
154    /// If the link is allowed, return true, otherwise return false to deny the link.
155    #[instrument(level = "debug", skip_all, fields(source_id))]
156    async fn receive_link_config_as_target(
157        &self,
158        link_config: LinkConfig<'_>,
159    ) -> anyhow::Result<()> {
160        let LinkConfig {
161            source_id,
162            link_name,
163            ..
164        } = link_config;
165
166        let config = if link_config.config.is_empty() {
167            self.default_config.clone()
168        } else {
169            // create a config from the supplied values and merge that with the existing default
170            // NATS connection configuration
171            match NatsConnectionConfig::from_link_config(link_config.config, link_config.secrets) {
172                Ok(ncc) => self.default_config.merge(&ncc),
173                Err(e) => {
174                    error!("failed to build NATS connection configuration: {:?}", e);
175                    return Err(anyhow!(e).context("failed to build NATS connection configuration"));
176                }
177            }
178        };
179        debug!("NATS Blobstore provider configuration: {:?}", config);
180
181        let jetstream = match self.connect(config.clone()).await {
182            Ok(b) => b,
183            Err(e) => {
184                error!("failed to connect to NATS: {:?}", e);
185                bail!(anyhow!(e).context("failed to connect to NATS"))
186            }
187        };
188
189        let mut consumer_components = self.consumer_components.write().await;
190        // Check if there's an existing hashmap for the source_id
191        if let Some(existing_nats_stores) = consumer_components.get_mut(&source_id.to_string()) {
192            // If so, insert the new jetstream into it
193            existing_nats_stores.insert(
194                link_name.into(),
195                NatsBlobstore {
196                    jetstream,
197                    storage_config: config.storage_config.unwrap_or_default(),
198                },
199            );
200        } else {
201            // Otherwise, create a new hashmap and insert it
202            consumer_components.insert(
203                source_id.into(),
204                HashMap::from([(
205                    link_name.into(),
206                    NatsBlobstore {
207                        jetstream,
208                        storage_config: config.storage_config.unwrap_or_default(),
209                    },
210                )]),
211            );
212        }
213
214        Ok(())
215    }
216
217    /// Provider should perform any operations needed for a link deletion, including cleaning up
218    /// per-component resources.
219    #[instrument(level = "info", skip_all, fields(source_id = info.get_source_id(), link_name = info.get_link_name()))]
220    async fn delete_link_as_target(&self, info: impl LinkDeleteInfo) -> anyhow::Result<()> {
221        let source_id = info.get_source_id();
222        let link_name = info.get_link_name();
223        let mut links = self.consumer_components.write().await;
224
225        if let Some(nats_stores) = links.get_mut(source_id) {
226            if nats_stores.remove(link_name).is_some() {
227                debug!(
228                    source_id,
229                    link_name, "removed NATS JetStream connection for link name"
230                );
231            }
232
233            // If the inner hashmap is empty, remove the source_id from the outer hashmap
234            if nats_stores.is_empty() {
235                links.remove(source_id);
236                debug!(
237                    source_id,
238                    "removed source_id from consumer components as it has no more link names"
239                );
240            }
241        } else {
242            debug!(source_id, "source_id not found in consumer components");
243        }
244
245        debug!(source_id, "finished processing link deletion");
246
247        Ok(())
248    }
249
250    /// Provider should perform any operations needed for configuration updates, including cleaning up
251    /// invalidated link resources.
252    #[instrument(level = "debug", skip_all, fields(link_name))]
253    async fn on_config_update(&self, update: impl ProviderConfigUpdate) -> anyhow::Result<()> {
254        let values = update.get_values();
255        debug!("Received config update: {:?}", values);
256
257        // Create a new config from the update values
258        let new_config = match NatsConnectionConfig::from_link_config(values, &HashMap::new()) {
259            Ok(config) => config,
260            Err(e) => {
261                error!("Failed to parse configuration update: {}", e);
262                return Ok(());
263            }
264        };
265
266        // Create new NATS connection with updated config
267        let new_jetstream = match self.connect(new_config.clone()).await {
268            Ok(js) => js,
269            Err(e) => {
270                error!("Failed to connect with new configuration: {}", e);
271                return Ok(());
272            }
273        };
274
275        // Update all existing connections with the new configuration
276        let mut components = self.consumer_components.write().await;
277        for stores in components.values_mut() {
278            for store in stores.values_mut() {
279                // Use existing NatsConnectionConfig merge functionality
280                let merged_config = NatsConnectionConfig {
281                    storage_config: Some(store.storage_config.clone()),
282                    ..Default::default()
283                }
284                .merge(&new_config);
285
286                store.storage_config = merged_config.storage_config.unwrap_or_default();
287                store.jetstream = new_jetstream.clone();
288            }
289        }
290
291        info!("Successfully updated all NATS connections with new configuration");
292        Ok(())
293    }
294
295    /// Handle shutdown request by closing all connections
296    async fn shutdown(&self) -> anyhow::Result<()> {
297        // clear the consumer components
298        let mut consumers = self.consumer_components.write().await;
299        consumers.clear();
300
301        Ok(())
302    }
303}
304
305/// Helper function for adding the TLS CA to the NATS connection options
306fn add_tls_ca(
307    tls_ca: &str,
308    opts: async_nats::ConnectOptions,
309) -> anyhow::Result<async_nats::ConnectOptions> {
310    let ca = rustls_pemfile::read_one(&mut tls_ca.as_bytes()).context("failed to read CA")?;
311    let mut roots = async_nats::rustls::RootCertStore::empty();
312    if let Some(rustls_pemfile::Item::X509Certificate(ca)) = ca {
313        roots.add_parsable_certificates([ca]);
314    } else {
315        bail!("tls ca: invalid certificate type, must be a DER encoded PEM file")
316    };
317    let tls_client = async_nats::rustls::ClientConfig::builder()
318        .with_root_certificates(roots)
319        .with_no_client_auth();
320    Ok(opts.tls_client_config(tls_client).require_tls(true))
321}
322
// Tests for building NATS connection options with a TLS CA
#[cfg(test)]
mod test {
    use super::*;

    // A PEM `CERTIFICATE` section is accepted. The payload is not a parsable X.509 certificate,
    // so this only checks that `add_tls_ca` returns `Ok`. It does not check that the CA ends up
    // in the root store.
    #[test]
    fn test_add_tls_ca() {
        let tls_ca = "-----BEGIN CERTIFICATE-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAwJwz\n-----END CERTIFICATE-----";
        let opts = async_nats::ConnectOptions::new();
        let opts = add_tls_ca(tls_ca, opts);
        assert!(opts.is_ok())
    }

    // Input without any PEM section is rejected.
    #[test]
    fn test_add_tls_ca_rejects_empty_input() {
        let opts = add_tls_ca("", async_nats::ConnectOptions::new());
        assert!(opts.is_err());
    }

    // A PEM section that is not a certificate (here a private key) is not accepted as a CA.
    #[test]
    fn test_add_tls_ca_rejects_non_certificate_pem() {
        let tls_ca = "-----BEGIN PRIVATE KEY-----\nMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAwJwz\n-----END PRIVATE KEY-----";
        let opts = add_tls_ca(tls_ca, async_nats::ConnectOptions::new());
        assert!(opts.is_err());
    }
}