wasmcloud_provider_keyvalue_vault/
lib.rs

1pub(crate) mod config;
2
3use core::str;
4use core::time::Duration;
5
6use std::collections::{hash_map, HashMap};
7use std::string::ToString;
8use std::sync::Arc;
9
10use anyhow::{anyhow, bail, Context as _};
11use base64::Engine as _;
12use bytes::Bytes;
13use tokio::sync::{Mutex, RwLock};
14use tokio::task::JoinHandle;
15use tracing::{debug, error, info, instrument, warn};
16use vaultrs::client::{Client as _, VaultClient, VaultClientSettings};
17use wasmcloud_provider_sdk::{
18    get_connection, load_host_data, propagate_trace_for_ctx, run_provider, Context, LinkConfig,
19    LinkDeleteInfo, Provider,
20};
21use wasmcloud_provider_sdk::{initialize_observability, serve_provider_exports};
22
23use crate::config::Config;
24
/// wRPC bindings generated by `wit_bindgen_wrpc`.
///
/// The `with` entry asks the macro to generate the `wrpc:keyvalue/store@0.2.0-draft`
/// interface itself; its exports are what this provider implements and serves.
mod bindings {
    wit_bindgen_wrpc::generate!({
        with: {
            "wrpc:keyvalue/store@0.2.0-draft": generate,
        }
    });
}
32use bindings::exports::wrpc::keyvalue;
33
/// Result type used by the key-value operations; the error type defaults to the generated
/// `wrpc:keyvalue/store` error.
type Result<T, E = keyvalue::store::Error> = core::result::Result<T, E>;
35
/// Version of the Vault HTTP API used for every request. As of Vault 1.9.x (Feb 2022) all
/// HTTP API calls use version 1.
const API_VERSION: u8 = 1;

/// Default amount by which each token renewal extends the token's validity: 72 hours.
pub const TOKEN_INCREMENT_TTL: &str = "72h";

/// Default period between two token renewals: 12 hours.
pub const TOKEN_REFRESH_INTERVAL: Duration = Duration::from_secs(12 * 60 * 60);
42
/// Runs the Vault key-value provider by delegating to [`KvVaultProvider::run`].
pub async fn run() -> anyhow::Result<()> {
    KvVaultProvider::run().await
}
46
/// Vault client connection information.
///
/// Clones share the underlying [`VaultClient`] and the renewal-task slot, since both are
/// held behind `Arc`s.
#[derive(Clone)]
pub struct Client {
    /// Underlying Vault HTTP client.
    inner: Arc<vaultrs::client::VaultClient>,
    /// Set from `config.mount`; passed to the `vaultrs::kv2` calls together with the secret
    /// path.
    namespace: String,
    /// Amount each renewal extends the token by (defaults to `"72h"`).
    token_increment_ttl: String,
    /// Period between token renewals performed by the background task.
    token_refresh_interval: Duration,
    /// Handle of the background renewal task, once one has been spawned.
    renew_task: Arc<Mutex<Option<JoinHandle<()>>>>,
}
56
57impl Client {
58    /// Creates a new Vault client. See [config](./config.rs) for explanation of parameters.
59    ///
60    /// Note that this constructor does not attempt to connect to the vault server,
61    /// so the vault server does not need to be running at the time a `LinkDefinition` to this provider is created.
62    pub fn new(config: Config) -> Result<Self, vaultrs::error::ClientError> {
63        let client = VaultClient::new(VaultClientSettings {
64            token: config.token,
65            address: config.addr,
66            ca_certs: config.certs,
67            verify: false,
68            version: API_VERSION,
69            wrapping: false,
70            timeout: None,
71            namespace: None,
72            identity: None,
73        })?;
74        Ok(Self {
75            inner: Arc::new(client),
76            namespace: config.mount,
77            token_increment_ttl: config
78                .token_increment_ttl
79                .unwrap_or(TOKEN_INCREMENT_TTL.into()),
80            token_refresh_interval: config
81                .token_refresh_interval
82                .unwrap_or(TOKEN_REFRESH_INTERVAL),
83            renew_task: Arc::default(),
84        })
85    }
86
87    /// Reads value of secret using namespace and key path
88    pub async fn read_secret(&self, path: &str) -> Result<Option<HashMap<String, String>>> {
89        match vaultrs::kv2::read(self.inner.as_ref(), &self.namespace, path).await {
90            Err(vaultrs::error::ClientError::APIError {
91                code: 404,
92                errors: _,
93            }) => Ok(None),
94            Err(err) => {
95                error!(error = %err, "failed to read secret");
96                Err(keyvalue::store::Error::Other(format!(
97                    "{:#}",
98                    anyhow!(err).context("failed to read secret")
99                )))
100            }
101            Ok(val) => Ok(val),
102        }
103    }
104
105    /// Writes value of secret using namespace and key path
106    pub async fn write_secret(&self, path: &str, data: &HashMap<String, String>) -> Result<()> {
107        let md = vaultrs::kv2::set(self.inner.as_ref(), &self.namespace, path, data)
108            .await
109            .map_err(|err| {
110                error!(error = %err, "failed to write secret");
111                keyvalue::store::Error::Other(format!(
112                    "{:#}",
113                    anyhow!(err).context("failed to write secret")
114                ))
115            })?;
116        debug!(?md, "set returned metadata");
117        Ok(())
118    }
119
120    /// Sets up a background task to renew the token at the configured interval. This function
121    /// attempts to lock the `renew_task` mutex and will deadlock if called without first ensuring
122    /// the lock is available.
123    pub async fn set_renewal(&self) {
124        let mut renew_task = self.renew_task.lock().await;
125        if let Some(handle) = renew_task.take() {
126            handle.abort();
127        }
128        let client = self.inner.clone();
129        let interval = self.token_refresh_interval;
130        let ttl = self.token_increment_ttl.clone();
131
132        *renew_task = Some(tokio::spawn(async move {
133            let mut next_interval = tokio::time::interval(interval);
134            loop {
135                next_interval.tick().await;
136                // NOTE(brooksmtownsend): Errors are appropriately logged in the function
137                let _ = renew_self(&client, ttl.as_str()).await;
138            }
139        }));
140    }
141}
142
143impl Drop for Client {
144    fn drop(&mut self) {
145        // NOTE(brooksmtownsend): We're trying to lock here so we don't deadlock on dropping.
146        if let Ok(mut renew_task) = self.renew_task.try_lock() {
147            if let Some(handle) = renew_task.take() {
148                handle.abort();
149            }
150        }
151    }
152}
153
154/// Helper function to renew a client's token, incrementing the validity by `increment`
155async fn renew_self(
156    client: &VaultClient,
157    increment: &str,
158) -> Result<(), vaultrs::error::ClientError> {
159    debug!("renewing token");
160    client.renew(Some(increment)).await.map_err(|e| {
161        error!("error renewing self token: {}", e);
162        e
163    })?;
164
165    let info = client.lookup().await.map_err(|e| {
166        error!("error looking up self token: {}", e);
167        e
168    })?;
169
170    let expire_time = info.expire_time.unwrap_or_else(|| "None".to_string());
171    info!(%expire_time, accessor = %info.accessor, "renewed token");
172    Ok(())
173}
174
/// Key-value provider implementation backed by the KV v2 secrets engine of
/// [Hashicorp Vault](https://developer.hashicorp.com/vault/docs).
#[derive(Default, Clone)]
pub struct KvVaultProvider {
    // Vault client per linked component, keyed by the component's source ID. Clones of the
    // provider share this map through the `Arc`.
    components: Arc<RwLock<HashMap<String, Arc<Client>>>>,
}
181
182impl KvVaultProvider {
183    pub fn name() -> &'static str {
184        "keyvalue-vault-provider"
185    }
186
187    pub async fn run() -> anyhow::Result<()> {
188        let host_data = load_host_data().context("failed to load host data")?;
189        let flamegraph_path = host_data
190            .config
191            .get("FLAMEGRAPH_PATH")
192            .map(String::from)
193            .or_else(|| std::env::var("PROVIDER_KEYVALUE_VAULT_FLAMEGRAPH_PATH").ok());
194        initialize_observability!(Self::name(), flamegraph_path);
195        let provider = Self::default();
196        let shutdown = run_provider(provider.clone(), KvVaultProvider::name())
197            .await
198            .context("failed to run provider")?;
199        let connection = get_connection();
200        let wrpc = connection
201            .get_wrpc_client(connection.provider_key())
202            .await?;
203        serve_provider_exports(&wrpc, provider, shutdown, bindings::serve)
204            .await
205            .context("failed to serve provider exports")
206    }
207
208    /// Retrieve a client for a given context (determined by `source_id`)
209    async fn get_client(&self, ctx: Option<Context>) -> Result<Arc<Client>> {
210        let ctx = ctx.ok_or_else(|| {
211            warn!("invocation context missing");
212            keyvalue::store::Error::Other("invocation context missing".into())
213        })?;
214        let source_id = ctx.component.as_ref().ok_or_else(|| {
215            warn!("source ID missing");
216            keyvalue::store::Error::Other("source ID missing".into())
217        })?;
218        let links = self.components.read().await;
219        links.get(source_id).cloned().ok_or_else(|| {
220            warn!(source_id, "source ID not linked");
221            keyvalue::store::Error::Other("source ID not linked".into())
222        })
223    }
224
225    /// Gets a value for a specified key. Deserialize the value as json
226    /// If it's any other map, the entire map is returned as a serialized json string
227    /// If the stored value is a plain string, returns the plain value
228    /// All other values are returned as serialized json
229    #[instrument(level = "debug", skip(ctx, self))]
230    async fn get(&self, ctx: Option<Context>, path: String, key: String) -> Result<Option<Bytes>> {
231        propagate_trace_for_ctx!(ctx);
232        let client = self.get_client(ctx).await?;
233        if let Some(mut secret) = client.read_secret(&path).await? {
234            match secret.remove(&key) {
235                Some(value) => {
236                    let value = base64::engine::general_purpose::STANDARD_NO_PAD
237                        .decode(value)
238                        .map_err(|err| {
239                            error!(?err, "failed to decode secret value");
240                            keyvalue::store::Error::Other(format!(
241                                "{:#}",
242                                anyhow!(err).context("failed to decode secret value")
243                            ))
244                        })?;
245                    Ok(Some(value.into()))
246                }
247                None => Ok(None),
248            }
249        } else {
250            Ok(None)
251        }
252    }
253
254    /// Returns true if the store contains the key
255    #[instrument(level = "debug", skip(ctx, self))]
256    async fn contains(&self, ctx: Option<Context>, path: String, key: String) -> Result<bool> {
257        propagate_trace_for_ctx!(ctx);
258        let client = self.get_client(ctx).await?;
259        let secret = client.read_secret(&path).await?;
260        Ok(secret.is_some_and(|secret| secret.contains_key(&key)))
261    }
262
263    /// Deletes a key from a secret
264    #[instrument(level = "debug", skip(ctx, self))]
265    async fn del(&self, ctx: Option<Context>, path: String, key: String) -> Result<()> {
266        propagate_trace_for_ctx!(ctx);
267        let client = self.get_client(ctx).await?;
268        let secret = client.read_secret(&path).await?;
269        let secret = if let Some(mut secret) = secret {
270            if secret.remove(&key).is_none() {
271                debug!("key does not exist in the secret");
272                return Ok(());
273            }
274            secret
275        } else {
276            debug!("secret not found");
277            return Ok(());
278        };
279        client.write_secret(&path, &secret).await
280    }
281
282    /// Sets the value of a key.
283    #[instrument(level = "debug", skip(ctx, self))]
284    async fn set(
285        &self,
286        ctx: Option<Context>,
287        path: String,
288        key: String,
289        value: Bytes,
290    ) -> Result<()> {
291        propagate_trace_for_ctx!(ctx);
292        let client = self.get_client(ctx).await?;
293        let value = base64::engine::general_purpose::STANDARD_NO_PAD.encode(value);
294        let secret = client.read_secret(&path).await?;
295        let secret = if let Some(mut secret) = secret {
296            match secret.entry(key) {
297                hash_map::Entry::Vacant(e) => {
298                    e.insert(value);
299                }
300                hash_map::Entry::Occupied(mut e) => {
301                    if *e.get() == value {
302                        return Ok(());
303                    }
304                    e.insert(value);
305                }
306            }
307            secret
308        } else {
309            HashMap::from([(key, value)])
310        };
311        client.write_secret(&path, &secret).await
312    }
313
314    #[instrument(level = "debug", skip(ctx, self))]
315    async fn list_keys(
316        &self,
317        ctx: Option<Context>,
318        path: String,
319        skip: u64,
320    ) -> Result<keyvalue::store::KeyResponse> {
321        propagate_trace_for_ctx!(ctx);
322        let client = self.get_client(ctx).await?;
323        let secret = client.read_secret(&path).await?;
324        Ok(keyvalue::store::KeyResponse {
325            cursor: None,
326            keys: secret
327                .map(|secret| {
328                    secret
329                        .into_keys()
330                        .skip(skip.try_into().unwrap_or(usize::MAX))
331                        .collect()
332                })
333                .unwrap_or_default(),
334        })
335    }
336}
337
338impl keyvalue::store::Handler<Option<Context>> for KvVaultProvider {
339    #[instrument(level = "debug", skip(self))]
340    async fn delete(
341        &self,
342        context: Option<Context>,
343        bucket: String,
344        key: String,
345    ) -> anyhow::Result<Result<()>> {
346        propagate_trace_for_ctx!(context);
347        Ok(self.del(context, bucket, key).await)
348    }
349
350    #[instrument(level = "debug", skip(self))]
351    async fn exists(
352        &self,
353        context: Option<Context>,
354        bucket: String,
355        key: String,
356    ) -> anyhow::Result<Result<bool>> {
357        propagate_trace_for_ctx!(context);
358        Ok(self.contains(context, bucket, key).await)
359    }
360
361    #[instrument(level = "debug", skip(self))]
362    async fn get(
363        &self,
364        context: Option<Context>,
365        bucket: String,
366        key: String,
367    ) -> anyhow::Result<Result<Option<Bytes>>> {
368        propagate_trace_for_ctx!(context);
369        Ok(self.get(context, bucket, key).await)
370    }
371
372    #[instrument(level = "debug", skip(self))]
373    async fn set(
374        &self,
375        context: Option<Context>,
376        bucket: String,
377        key: String,
378        value: Bytes,
379    ) -> anyhow::Result<Result<()>> {
380        propagate_trace_for_ctx!(context);
381        Ok(self.set(context, bucket, key, value).await)
382    }
383
384    #[instrument(level = "debug", skip(self))]
385    async fn list_keys(
386        &self,
387        context: Option<Context>,
388        bucket: String,
389        cursor: Option<u64>,
390    ) -> anyhow::Result<Result<keyvalue::store::KeyResponse>> {
391        propagate_trace_for_ctx!(context);
392        Ok(self
393            .list_keys(context, bucket, cursor.unwrap_or_default())
394            .await)
395    }
396}
397
398/// Handle provider control commands, the minimum required of any provider on
399/// a wasmcloud lattice
400impl Provider for KvVaultProvider {
401    /// Provider should perform any operations needed for a new link,
402    /// including setting up per-component resources, and checking authorization.
403    /// If the link is allowed, return true, otherwise return false to deny the link.
404    #[instrument(level = "debug", skip_all, fields(source_id))]
405    async fn receive_link_config_as_target(
406        &self,
407        link_config: LinkConfig<'_>,
408    ) -> anyhow::Result<()> {
409        let LinkConfig {
410            source_id,
411            link_name,
412            ..
413        } = link_config;
414        debug!(
415           %source_id,
416           %link_name,
417            "adding link for component",
418        );
419
420        let config = match Config::from_link_config(&link_config) {
421            Ok(config) => config,
422            Err(e) => {
423                error!(
424                    %source_id,
425                    %link_name,
426                    "failed to parse config: {e}",
427                );
428                bail!(anyhow!(e).context("failed to parse config"))
429            }
430        };
431
432        let client = match Client::new(config.clone()) {
433            Ok(client) => client,
434            Err(e) => {
435                error!(
436                    %source_id,
437                    %link_name,
438                    "failed to create new client config: {e}",
439                );
440                return Err(anyhow!(e).context("failed to create new client config"));
441            }
442        };
443        client.set_renewal().await;
444
445        let mut update_map = self.components.write().await;
446        update_map.insert(source_id.to_string(), Arc::new(client));
447
448        Ok(())
449    }
450
451    /// Handle notification that a link is dropped - close the connection
452    #[instrument(level = "info", skip_all, fields(source_id = info.get_source_id()))]
453    async fn delete_link_as_target(&self, info: impl LinkDeleteInfo) -> anyhow::Result<()> {
454        let component_id = info.get_source_id();
455        let mut aw = self.components.write().await;
456        if let Some(client) = aw.remove(component_id) {
457            debug!(component_id, "deleting link for component");
458            drop(client);
459        }
460        Ok(())
461    }
462
463    /// Handle shutdown request by closing all connections
464    async fn shutdown(&self) -> anyhow::Result<()> {
465        let mut aw = self.components.write().await;
466        // Empty the component link data and stop all servers
467        for (_, client) in aw.drain() {
468            drop(client);
469        }
470        Ok(())
471    }
472}