1pub(crate) mod config;
2
3use core::str;
4use core::time::Duration;
5
6use std::collections::{hash_map, HashMap};
7use std::string::ToString;
8use std::sync::Arc;
9
10use anyhow::{anyhow, bail, Context as _};
11use base64::Engine as _;
12use bytes::Bytes;
13use tokio::sync::{Mutex, RwLock};
14use tokio::task::JoinHandle;
15use tracing::{debug, error, info, instrument, warn};
16use vaultrs::client::{Client as _, VaultClient, VaultClientSettings};
17use wasmcloud_provider_sdk::{
18 get_connection, load_host_data, propagate_trace_for_ctx, run_provider, Context, LinkConfig,
19 LinkDeleteInfo, Provider,
20};
21use wasmcloud_provider_sdk::{initialize_observability, serve_provider_exports};
22
23use crate::config::Config;
24
/// wRPC bindings produced by `wit_bindgen_wrpc`.
///
/// The `with` map requests code generation for the `wrpc:keyvalue/store@0.2.0-draft`
/// interface; the generated `exports::wrpc::keyvalue` module and `serve` function are
/// what the rest of this file uses. The WIT source location is not shown here
/// (NOTE(review): presumably the crate's default WIT directory — confirm).
mod bindings {
    wit_bindgen_wrpc::generate!({
        with: {
            "wrpc:keyvalue/store@0.2.0-draft": generate,
        }
    });
}
32use bindings::exports::wrpc::keyvalue;
33
34type Result<T, E = keyvalue::store::Error> = core::result::Result<T, E>;
35
36const API_VERSION: u8 = 1;
38
39pub const TOKEN_INCREMENT_TTL: &str = "72h";
41pub const TOKEN_REFRESH_INTERVAL: Duration = Duration::from_secs(60 * 60 * 12); pub async fn run() -> anyhow::Result<()> {
44 KvVaultProvider::run().await
45}
46
47#[derive(Clone)]
49pub struct Client {
50 inner: Arc<vaultrs::client::VaultClient>,
51 namespace: String,
52 token_increment_ttl: String,
53 token_refresh_interval: Duration,
54 renew_task: Arc<Mutex<Option<JoinHandle<()>>>>,
55}
56
57impl Client {
58 pub fn new(config: Config) -> Result<Self, vaultrs::error::ClientError> {
63 let client = VaultClient::new(VaultClientSettings {
64 token: config.token,
65 address: config.addr,
66 ca_certs: config.certs,
67 verify: false,
68 version: API_VERSION,
69 wrapping: false,
70 timeout: None,
71 namespace: None,
72 identity: None,
73 })?;
74 Ok(Self {
75 inner: Arc::new(client),
76 namespace: config.mount,
77 token_increment_ttl: config
78 .token_increment_ttl
79 .unwrap_or(TOKEN_INCREMENT_TTL.into()),
80 token_refresh_interval: config
81 .token_refresh_interval
82 .unwrap_or(TOKEN_REFRESH_INTERVAL),
83 renew_task: Arc::default(),
84 })
85 }
86
87 pub async fn read_secret(&self, path: &str) -> Result<Option<HashMap<String, String>>> {
89 match vaultrs::kv2::read(self.inner.as_ref(), &self.namespace, path).await {
90 Err(vaultrs::error::ClientError::APIError {
91 code: 404,
92 errors: _,
93 }) => Ok(None),
94 Err(err) => {
95 error!(error = %err, "failed to read secret");
96 Err(keyvalue::store::Error::Other(format!(
97 "{:#}",
98 anyhow!(err).context("failed to read secret")
99 )))
100 }
101 Ok(val) => Ok(val),
102 }
103 }
104
105 pub async fn write_secret(&self, path: &str, data: &HashMap<String, String>) -> Result<()> {
107 let md = vaultrs::kv2::set(self.inner.as_ref(), &self.namespace, path, data)
108 .await
109 .map_err(|err| {
110 error!(error = %err, "failed to write secret");
111 keyvalue::store::Error::Other(format!(
112 "{:#}",
113 anyhow!(err).context("failed to write secret")
114 ))
115 })?;
116 debug!(?md, "set returned metadata");
117 Ok(())
118 }
119
120 pub async fn set_renewal(&self) {
124 let mut renew_task = self.renew_task.lock().await;
125 if let Some(handle) = renew_task.take() {
126 handle.abort();
127 }
128 let client = self.inner.clone();
129 let interval = self.token_refresh_interval;
130 let ttl = self.token_increment_ttl.clone();
131
132 *renew_task = Some(tokio::spawn(async move {
133 let mut next_interval = tokio::time::interval(interval);
134 loop {
135 next_interval.tick().await;
136 let _ = renew_self(&client, ttl.as_str()).await;
138 }
139 }));
140 }
141}
142
143impl Drop for Client {
144 fn drop(&mut self) {
145 if let Ok(mut renew_task) = self.renew_task.try_lock() {
147 if let Some(handle) = renew_task.take() {
148 handle.abort();
149 }
150 }
151 }
152}
153
154async fn renew_self(
156 client: &VaultClient,
157 increment: &str,
158) -> Result<(), vaultrs::error::ClientError> {
159 debug!("renewing token");
160 client.renew(Some(increment)).await.map_err(|e| {
161 error!("error renewing self token: {}", e);
162 e
163 })?;
164
165 let info = client.lookup().await.map_err(|e| {
166 error!("error looking up self token: {}", e);
167 e
168 })?;
169
170 let expire_time = info.expire_time.unwrap_or_else(|| "None".to_string());
171 info!(%expire_time, accessor = %info.accessor, "renewed token");
172 Ok(())
173}
174
175#[derive(Default, Clone)]
177pub struct KvVaultProvider {
178 components: Arc<RwLock<HashMap<String, Arc<Client>>>>,
180}
181
182impl KvVaultProvider {
183 pub fn name() -> &'static str {
184 "keyvalue-vault-provider"
185 }
186
187 pub async fn run() -> anyhow::Result<()> {
188 let host_data = load_host_data().context("failed to load host data")?;
189 let flamegraph_path = host_data
190 .config
191 .get("FLAMEGRAPH_PATH")
192 .map(String::from)
193 .or_else(|| std::env::var("PROVIDER_KEYVALUE_VAULT_FLAMEGRAPH_PATH").ok());
194 initialize_observability!(Self::name(), flamegraph_path);
195 let provider = Self::default();
196 let shutdown = run_provider(provider.clone(), KvVaultProvider::name())
197 .await
198 .context("failed to run provider")?;
199 let connection = get_connection();
200 let wrpc = connection
201 .get_wrpc_client(connection.provider_key())
202 .await?;
203 serve_provider_exports(&wrpc, provider, shutdown, bindings::serve)
204 .await
205 .context("failed to serve provider exports")
206 }
207
208 async fn get_client(&self, ctx: Option<Context>) -> Result<Arc<Client>> {
210 let ctx = ctx.ok_or_else(|| {
211 warn!("invocation context missing");
212 keyvalue::store::Error::Other("invocation context missing".into())
213 })?;
214 let source_id = ctx.component.as_ref().ok_or_else(|| {
215 warn!("source ID missing");
216 keyvalue::store::Error::Other("source ID missing".into())
217 })?;
218 let links = self.components.read().await;
219 links.get(source_id).cloned().ok_or_else(|| {
220 warn!(source_id, "source ID not linked");
221 keyvalue::store::Error::Other("source ID not linked".into())
222 })
223 }
224
225 #[instrument(level = "debug", skip(ctx, self))]
230 async fn get(&self, ctx: Option<Context>, path: String, key: String) -> Result<Option<Bytes>> {
231 propagate_trace_for_ctx!(ctx);
232 let client = self.get_client(ctx).await?;
233 if let Some(mut secret) = client.read_secret(&path).await? {
234 match secret.remove(&key) {
235 Some(value) => {
236 let value = base64::engine::general_purpose::STANDARD_NO_PAD
237 .decode(value)
238 .map_err(|err| {
239 error!(?err, "failed to decode secret value");
240 keyvalue::store::Error::Other(format!(
241 "{:#}",
242 anyhow!(err).context("failed to decode secret value")
243 ))
244 })?;
245 Ok(Some(value.into()))
246 }
247 None => Ok(None),
248 }
249 } else {
250 Ok(None)
251 }
252 }
253
254 #[instrument(level = "debug", skip(ctx, self))]
256 async fn contains(&self, ctx: Option<Context>, path: String, key: String) -> Result<bool> {
257 propagate_trace_for_ctx!(ctx);
258 let client = self.get_client(ctx).await?;
259 let secret = client.read_secret(&path).await?;
260 Ok(secret.is_some_and(|secret| secret.contains_key(&key)))
261 }
262
263 #[instrument(level = "debug", skip(ctx, self))]
265 async fn del(&self, ctx: Option<Context>, path: String, key: String) -> Result<()> {
266 propagate_trace_for_ctx!(ctx);
267 let client = self.get_client(ctx).await?;
268 let secret = client.read_secret(&path).await?;
269 let secret = if let Some(mut secret) = secret {
270 if secret.remove(&key).is_none() {
271 debug!("key does not exist in the secret");
272 return Ok(());
273 }
274 secret
275 } else {
276 debug!("secret not found");
277 return Ok(());
278 };
279 client.write_secret(&path, &secret).await
280 }
281
282 #[instrument(level = "debug", skip(ctx, self))]
284 async fn set(
285 &self,
286 ctx: Option<Context>,
287 path: String,
288 key: String,
289 value: Bytes,
290 ) -> Result<()> {
291 propagate_trace_for_ctx!(ctx);
292 let client = self.get_client(ctx).await?;
293 let value = base64::engine::general_purpose::STANDARD_NO_PAD.encode(value);
294 let secret = client.read_secret(&path).await?;
295 let secret = if let Some(mut secret) = secret {
296 match secret.entry(key) {
297 hash_map::Entry::Vacant(e) => {
298 e.insert(value);
299 }
300 hash_map::Entry::Occupied(mut e) => {
301 if *e.get() == value {
302 return Ok(());
303 }
304 e.insert(value);
305 }
306 }
307 secret
308 } else {
309 HashMap::from([(key, value)])
310 };
311 client.write_secret(&path, &secret).await
312 }
313
314 #[instrument(level = "debug", skip(ctx, self))]
315 async fn list_keys(
316 &self,
317 ctx: Option<Context>,
318 path: String,
319 skip: u64,
320 ) -> Result<keyvalue::store::KeyResponse> {
321 propagate_trace_for_ctx!(ctx);
322 let client = self.get_client(ctx).await?;
323 let secret = client.read_secret(&path).await?;
324 Ok(keyvalue::store::KeyResponse {
325 cursor: None,
326 keys: secret
327 .map(|secret| {
328 secret
329 .into_keys()
330 .skip(skip.try_into().unwrap_or(usize::MAX))
331 .collect()
332 })
333 .unwrap_or_default(),
334 })
335 }
336}
337
338impl keyvalue::store::Handler<Option<Context>> for KvVaultProvider {
339 #[instrument(level = "debug", skip(self))]
340 async fn delete(
341 &self,
342 context: Option<Context>,
343 bucket: String,
344 key: String,
345 ) -> anyhow::Result<Result<()>> {
346 propagate_trace_for_ctx!(context);
347 Ok(self.del(context, bucket, key).await)
348 }
349
350 #[instrument(level = "debug", skip(self))]
351 async fn exists(
352 &self,
353 context: Option<Context>,
354 bucket: String,
355 key: String,
356 ) -> anyhow::Result<Result<bool>> {
357 propagate_trace_for_ctx!(context);
358 Ok(self.contains(context, bucket, key).await)
359 }
360
361 #[instrument(level = "debug", skip(self))]
362 async fn get(
363 &self,
364 context: Option<Context>,
365 bucket: String,
366 key: String,
367 ) -> anyhow::Result<Result<Option<Bytes>>> {
368 propagate_trace_for_ctx!(context);
369 Ok(self.get(context, bucket, key).await)
370 }
371
372 #[instrument(level = "debug", skip(self))]
373 async fn set(
374 &self,
375 context: Option<Context>,
376 bucket: String,
377 key: String,
378 value: Bytes,
379 ) -> anyhow::Result<Result<()>> {
380 propagate_trace_for_ctx!(context);
381 Ok(self.set(context, bucket, key, value).await)
382 }
383
384 #[instrument(level = "debug", skip(self))]
385 async fn list_keys(
386 &self,
387 context: Option<Context>,
388 bucket: String,
389 cursor: Option<u64>,
390 ) -> anyhow::Result<Result<keyvalue::store::KeyResponse>> {
391 propagate_trace_for_ctx!(context);
392 Ok(self
393 .list_keys(context, bucket, cursor.unwrap_or_default())
394 .await)
395 }
396}
397
398impl Provider for KvVaultProvider {
401 #[instrument(level = "debug", skip_all, fields(source_id))]
405 async fn receive_link_config_as_target(
406 &self,
407 link_config: LinkConfig<'_>,
408 ) -> anyhow::Result<()> {
409 let LinkConfig {
410 source_id,
411 link_name,
412 ..
413 } = link_config;
414 debug!(
415 %source_id,
416 %link_name,
417 "adding link for component",
418 );
419
420 let config = match Config::from_link_config(&link_config) {
421 Ok(config) => config,
422 Err(e) => {
423 error!(
424 %source_id,
425 %link_name,
426 "failed to parse config: {e}",
427 );
428 bail!(anyhow!(e).context("failed to parse config"))
429 }
430 };
431
432 let client = match Client::new(config.clone()) {
433 Ok(client) => client,
434 Err(e) => {
435 error!(
436 %source_id,
437 %link_name,
438 "failed to create new client config: {e}",
439 );
440 return Err(anyhow!(e).context("failed to create new client config"));
441 }
442 };
443 client.set_renewal().await;
444
445 let mut update_map = self.components.write().await;
446 update_map.insert(source_id.to_string(), Arc::new(client));
447
448 Ok(())
449 }
450
451 #[instrument(level = "info", skip_all, fields(source_id = info.get_source_id()))]
453 async fn delete_link_as_target(&self, info: impl LinkDeleteInfo) -> anyhow::Result<()> {
454 let component_id = info.get_source_id();
455 let mut aw = self.components.write().await;
456 if let Some(client) = aw.remove(component_id) {
457 debug!(component_id, "deleting link for component");
458 drop(client);
459 }
460 Ok(())
461 }
462
463 async fn shutdown(&self) -> anyhow::Result<()> {
465 let mut aw = self.components.write().await;
466 for (_, client) in aw.drain() {
468 drop(client);
469 }
470 Ok(())
471 }
472}