pub(crate) mod config;
use core::str;
use core::time::Duration;
use std::collections::{hash_map, HashMap};
use std::string::ToString;
use std::sync::Arc;
use anyhow::{anyhow, bail, Context as _};
use base64::Engine as _;
use bytes::Bytes;
use tokio::sync::{Mutex, RwLock};
use tokio::task::JoinHandle;
use tracing::{debug, error, info, instrument, warn};
use vaultrs::client::{Client as _, VaultClient, VaultClientSettings};
use wasmcloud_provider_sdk::{
get_connection, load_host_data, propagate_trace_for_ctx, run_provider, Context, LinkConfig,
LinkDeleteInfo, Provider,
};
use wasmcloud_provider_sdk::{initialize_observability, serve_provider_exports};
use crate::config::Config;
/// wRPC bindings produced by `wit_bindgen_wrpc::generate!`.
///
/// The `with` map asks the macro to generate code for the `wrpc:keyvalue/store@0.2.0-draft`
/// interface; its exported `Handler` trait is implemented by `KvVaultProvider` further down.
/// NOTE(review): the WIT world itself is presumably picked up from the crate's `wit/`
/// directory — confirm against the build layout.
mod bindings {
wit_bindgen_wrpc::generate!({
with: {
"wrpc:keyvalue/store@0.2.0-draft": generate,
}
});
}
use bindings::exports::wrpc::keyvalue;
/// Result alias used by the store handlers; the error defaults to the keyvalue store error.
type Result<T, E = keyvalue::store::Error> = core::result::Result<T, E>;

/// Value passed as `version` in `VaultClientSettings` when building a client.
const API_VERSION: u8 = 1;

/// Default increment requested on every token renewal (72 hours).
pub const TOKEN_INCREMENT_TTL: &str = "72h";

/// Default period between token renewals (12 hours).
pub const TOKEN_REFRESH_INTERVAL: Duration = Duration::from_secs(12 * 60 * 60);

/// Runs the Vault key-value provider by delegating to [`KvVaultProvider::run`].
///
/// # Errors
///
/// Propagates any error returned by [`KvVaultProvider::run`].
pub async fn run() -> anyhow::Result<()> {
    KvVaultProvider::run().await
}
#[derive(Clone)]
pub struct Client {
inner: Arc<vaultrs::client::VaultClient>,
namespace: String,
token_increment_ttl: String,
token_refresh_interval: Duration,
renew_task: Arc<Mutex<Option<JoinHandle<()>>>>,
}
impl Client {
    /// Builds a Vault client for one link from its parsed [`Config`].
    ///
    /// Missing `token_increment_ttl` / `token_refresh_interval` values fall back to
    /// [`TOKEN_INCREMENT_TTL`] / [`TOKEN_REFRESH_INTERVAL`]. A zero refresh interval is also
    /// replaced by the default, because `tokio::time::interval` panics on a zero period. That
    /// panic would kill the renewal task and let the token silently expire.
    ///
    /// # Errors
    ///
    /// Returns the `vaultrs` error if `VaultClient::new` rejects the settings.
    pub fn new(config: Config) -> Result<Self, vaultrs::error::ClientError> {
        let client = VaultClient::new(VaultClientSettings {
            token: config.token,
            address: config.addr,
            ca_certs: config.certs,
            // NOTE(review): TLS certificate verification is disabled, so the supplied
            // `ca_certs` presumably do not restrict which server certificates are accepted and
            // the token could be intercepted. Kept as-is because existing deployments may rely
            // on it; consider making this configurable.
            verify: false,
            version: API_VERSION,
            wrapping: false,
            timeout: None,
            namespace: None,
            identity: None,
        })?;
        Ok(Self {
            inner: Arc::new(client),
            namespace: config.mount,
            token_increment_ttl: config
                .token_increment_ttl
                .unwrap_or_else(|| TOKEN_INCREMENT_TTL.into()),
            token_refresh_interval: config
                .token_refresh_interval
                .filter(|interval| !interval.is_zero())
                .unwrap_or(TOKEN_REFRESH_INTERVAL),
            renew_task: Arc::default(),
        })
    }

    /// Reads the KV v2 secret at `path` under this client's mount.
    ///
    /// Returns `Ok(None)` when Vault answers with HTTP 404. Any other failure is logged and
    /// returned as [`keyvalue::store::Error::Other`].
    pub async fn read_secret(&self, path: &str) -> Result<Option<HashMap<String, String>>> {
        match vaultrs::kv2::read(self.inner.as_ref(), &self.namespace, path).await {
            Err(vaultrs::error::ClientError::APIError {
                code: 404,
                errors: _,
            }) => Ok(None),
            Err(err) => {
                error!(error = %err, "failed to read secret");
                Err(keyvalue::store::Error::Other(format!(
                    "{:#}",
                    anyhow!(err).context("failed to read secret")
                )))
            }
            Ok(val) => Ok(val),
        }
    }

    /// Writes `data` to the KV v2 secret at `path` under this client's mount.
    ///
    /// Callers in this file pass the complete key/value map of the secret. Failures are logged
    /// and returned as [`keyvalue::store::Error::Other`].
    pub async fn write_secret(&self, path: &str, data: &HashMap<String, String>) -> Result<()> {
        let md = vaultrs::kv2::set(self.inner.as_ref(), &self.namespace, path, data)
            .await
            .map_err(|err| {
                error!(error = %err, "failed to write secret");
                keyvalue::store::Error::Other(format!(
                    "{:#}",
                    anyhow!(err).context("failed to write secret")
                ))
            })?;
        debug!(?md, "set returned metadata");
        Ok(())
    }

    /// (Re)starts the background task that periodically renews this client's token.
    ///
    /// Any previously spawned renewal task is aborted first. The new task renews every
    /// `token_refresh_interval`, requesting `token_increment_ttl`. The first tick of a tokio
    /// interval completes immediately, so the first renewal happens right away. Renewal errors
    /// are logged by `renew_self` and otherwise ignored so the loop keeps retrying.
    pub async fn set_renewal(&self) {
        let mut renew_task = self.renew_task.lock().await;
        if let Some(handle) = renew_task.take() {
            handle.abort();
        }
        let client = self.inner.clone();
        let interval = self.token_refresh_interval;
        let ttl = self.token_increment_ttl.clone();
        *renew_task = Some(tokio::spawn(async move {
            let mut next_interval = tokio::time::interval(interval);
            loop {
                next_interval.tick().await;
                let _ = renew_self(&client, ttl.as_str()).await;
            }
        }));
    }
}
impl Drop for Client {
fn drop(&mut self) {
if let Ok(mut renew_task) = self.renew_task.try_lock() {
if let Some(handle) = renew_task.take() {
handle.abort();
}
}
}
}
async fn renew_self(
client: &VaultClient,
increment: &str,
) -> Result<(), vaultrs::error::ClientError> {
debug!("renewing token");
client.renew(Some(increment)).await.map_err(|e| {
error!("error renewing self token: {}", e);
e
})?;
let info = client.lookup().await.map_err(|e| {
error!("error looking up self token: {}", e);
e
})?;
let expire_time = info.expire_time.unwrap_or_else(|| "None".to_string());
info!(%expire_time, accessor = %info.accessor, "renewed token");
Ok(())
}
/// Vault-backed implementation of the `wrpc:keyvalue/store` provider.
///
/// Holds one [`Client`] per linked component, keyed by the component's source ID. The map sits
/// behind an `Arc`, so every clone of the provider sees the same set of links.
#[derive(Clone, Default)]
pub struct KvVaultProvider {
    /// Source ID of the linked component -> Vault client for that link.
    components: Arc<RwLock<HashMap<String, Arc<Client>>>>,
}
impl KvVaultProvider {
    /// Name under which the provider identifies itself to the host and to observability.
    pub fn name() -> &'static str {
        "keyvalue-vault-provider"
    }

    /// Starts the provider and serves the `wrpc:keyvalue/store` exports until shutdown.
    ///
    /// The flamegraph path comes from the `FLAMEGRAPH_PATH` host config entry. When that is
    /// absent, the `PROVIDER_KEYVALUE_VAULT_FLAMEGRAPH_PATH` environment variable is used.
    ///
    /// # Errors
    ///
    /// Fails if host data cannot be loaded, the provider cannot be started, a wRPC client
    /// cannot be obtained, or serving the exports fails.
    pub async fn run() -> anyhow::Result<()> {
        let host_data = load_host_data().context("failed to load host data")?;
        let flamegraph_path = host_data
            .config
            .get("FLAMEGRAPH_PATH")
            .map(String::from)
            .or_else(|| std::env::var("PROVIDER_KEYVALUE_VAULT_FLAMEGRAPH_PATH").ok());
        initialize_observability!(Self::name(), flamegraph_path);
        let provider = Self::default();
        let shutdown = run_provider(provider.clone(), KvVaultProvider::name())
            .await
            .context("failed to run provider")?;
        let connection = get_connection();
        let wrpc = connection
            .get_wrpc_client(connection.provider_key())
            .await?;
        serve_provider_exports(&wrpc, provider, shutdown, bindings::serve)
            .await
            .context("failed to serve provider exports")
    }

    /// Looks up the [`Client`] registered for the component that issued this invocation.
    ///
    /// # Errors
    ///
    /// Returns [`keyvalue::store::Error::Other`] in three cases: the invocation context is
    /// missing, its component ID is missing, or no link exists for that component.
    async fn get_client(&self, ctx: Option<Context>) -> Result<Arc<Client>> {
        let ctx = ctx.ok_or_else(|| {
            warn!("invocation context missing");
            keyvalue::store::Error::Other("invocation context missing".into())
        })?;
        let source_id = ctx.component.as_ref().ok_or_else(|| {
            warn!("source ID missing");
            keyvalue::store::Error::Other("source ID missing".into())
        })?;
        let links = self.components.read().await;
        links.get(source_id).cloned().ok_or_else(|| {
            warn!(source_id, "source ID not linked");
            keyvalue::store::Error::Other("source ID not linked".into())
        })
    }

    /// Reads `key` from the secret at `path` and base64-decodes it (no padding, as `set`
    /// encodes it).
    ///
    /// Returns `Ok(None)` when the secret or the key does not exist.
    #[instrument(level = "debug", skip(ctx, self))]
    async fn get(&self, ctx: Option<Context>, path: String, key: String) -> Result<Option<Bytes>> {
        propagate_trace_for_ctx!(ctx);
        let client = self.get_client(ctx).await?;
        let Some(mut secret) = client.read_secret(&path).await? else {
            return Ok(None);
        };
        let Some(value) = secret.remove(&key) else {
            return Ok(None);
        };
        let value = base64::engine::general_purpose::STANDARD_NO_PAD
            .decode(value)
            .map_err(|err| {
                error!(?err, "failed to decode secret value");
                keyvalue::store::Error::Other(format!(
                    "{:#}",
                    anyhow!(err).context("failed to decode secret value")
                ))
            })?;
        Ok(Some(value.into()))
    }

    /// Reports whether the secret at `path` exists and contains `key`.
    #[instrument(level = "debug", skip(ctx, self))]
    async fn contains(&self, ctx: Option<Context>, path: String, key: String) -> Result<bool> {
        propagate_trace_for_ctx!(ctx);
        let client = self.get_client(ctx).await?;
        let secret = client.read_secret(&path).await?;
        Ok(secret.is_some_and(|secret| secret.contains_key(&key)))
    }

    /// Removes `key` from the secret at `path` and writes the remaining entries back.
    ///
    /// A missing secret or key counts as success, and nothing is written.
    /// NOTE(review): this read-modify-write has no visible version check, so concurrent writers
    /// to the same path may overwrite each other's changes.
    #[instrument(level = "debug", skip(ctx, self))]
    async fn del(&self, ctx: Option<Context>, path: String, key: String) -> Result<()> {
        propagate_trace_for_ctx!(ctx);
        let client = self.get_client(ctx).await?;
        let Some(mut secret) = client.read_secret(&path).await? else {
            debug!("secret not found");
            return Ok(());
        };
        if secret.remove(&key).is_none() {
            debug!("key does not exist in the secret");
            return Ok(());
        }
        client.write_secret(&path, &secret).await
    }

    /// Stores `value`, base64-encoded without padding, under `key` in the secret at `path`,
    /// creating the secret if it does not exist yet.
    ///
    /// The write is skipped when the stored value already matches. `value` is excluded from
    /// the tracing span: it is the secret payload and must not end up in logs or traces.
    #[instrument(level = "debug", skip(ctx, self, value))]
    async fn set(
        &self,
        ctx: Option<Context>,
        path: String,
        key: String,
        value: Bytes,
    ) -> Result<()> {
        propagate_trace_for_ctx!(ctx);
        let client = self.get_client(ctx).await?;
        let value = base64::engine::general_purpose::STANDARD_NO_PAD.encode(value);
        let secret = match client.read_secret(&path).await? {
            Some(mut secret) => {
                match secret.entry(key) {
                    hash_map::Entry::Vacant(e) => {
                        e.insert(value);
                    }
                    hash_map::Entry::Occupied(mut e) => {
                        if *e.get() == value {
                            return Ok(());
                        }
                        e.insert(value);
                    }
                }
                secret
            }
            None => HashMap::from([(key, value)]),
        };
        client.write_secret(&path, &secret).await
    }

    /// Lists the keys of the secret at `path` in sorted order, omitting the first `skip` keys.
    ///
    /// A missing secret yields an empty list. The returned cursor is always `None`.
    #[instrument(level = "debug", skip(ctx, self))]
    async fn list_keys(
        &self,
        ctx: Option<Context>,
        path: String,
        skip: u64,
    ) -> Result<keyvalue::store::KeyResponse> {
        propagate_trace_for_ctx!(ctx);
        let client = self.get_client(ctx).await?;
        let mut keys: Vec<String> = client
            .read_secret(&path)
            .await?
            .map(|secret| secret.into_keys().collect())
            .unwrap_or_default();
        // `HashMap` iteration order is unspecified. Without sorting, skipping the first `skip`
        // keys would drop an arbitrary subset that can differ from one call to the next.
        keys.sort_unstable();
        Ok(keyvalue::store::KeyResponse {
            cursor: None,
            keys: keys
                .into_iter()
                .skip(usize::try_from(skip).unwrap_or(usize::MAX))
                .collect(),
        })
    }
}
impl keyvalue::store::Handler<Option<Context>> for KvVaultProvider {
    // Each handler forwards to an inherent `KvVaultProvider` method. Inherent methods take
    // precedence over trait methods, so `self.get`, `self.set` and `self.list_keys` below call
    // the inherent implementations rather than recursing. The outer `anyhow::Result` is always
    // `Ok` here; store failures are carried in the inner `Result`.

    /// Removes `key` from the secret named by `bucket`.
    #[instrument(level = "debug", skip(self))]
    async fn delete(
        &self,
        context: Option<Context>,
        bucket: String,
        key: String,
    ) -> anyhow::Result<Result<()>> {
        propagate_trace_for_ctx!(context);
        Ok(self.del(context, bucket, key).await)
    }

    /// Reports whether `key` exists in the secret named by `bucket`.
    #[instrument(level = "debug", skip(self))]
    async fn exists(
        &self,
        context: Option<Context>,
        bucket: String,
        key: String,
    ) -> anyhow::Result<Result<bool>> {
        propagate_trace_for_ctx!(context);
        Ok(self.contains(context, bucket, key).await)
    }

    /// Fetches the value of `key` from the secret named by `bucket`.
    #[instrument(level = "debug", skip(self))]
    async fn get(
        &self,
        context: Option<Context>,
        bucket: String,
        key: String,
    ) -> anyhow::Result<Result<Option<Bytes>>> {
        propagate_trace_for_ctx!(context);
        Ok(self.get(context, bucket, key).await)
    }

    /// Stores `value` under `key` in the secret named by `bucket`.
    ///
    /// `value` is the secret payload, so it is kept out of the tracing span.
    #[instrument(level = "debug", skip(self, value))]
    async fn set(
        &self,
        context: Option<Context>,
        bucket: String,
        key: String,
        value: Bytes,
    ) -> anyhow::Result<Result<()>> {
        propagate_trace_for_ctx!(context);
        Ok(self.set(context, bucket, key, value).await)
    }

    /// Lists keys of the secret named by `bucket`. `cursor` is used as the number of keys to
    /// skip and defaults to 0.
    #[instrument(level = "debug", skip(self))]
    async fn list_keys(
        &self,
        context: Option<Context>,
        bucket: String,
        cursor: Option<u64>,
    ) -> anyhow::Result<Result<keyvalue::store::KeyResponse>> {
        propagate_trace_for_ctx!(context);
        Ok(self
            .list_keys(context, bucket, cursor.unwrap_or_default())
            .await)
    }
}
impl Provider for KvVaultProvider {
    /// Registers a link from a component to this provider.
    ///
    /// The steps are:
    /// 1. Parse the link config.
    /// 2. Build a [`Client`] and start its token renewal.
    /// 3. Store the client under the component's source ID, replacing any previous client for
    ///    that component.
    ///
    /// # Errors
    ///
    /// Fails if the link config cannot be parsed or the Vault client cannot be created.
    // `fields(source_id)` alone declared an empty span field that was never recorded; record
    // the value up front instead, matching `delete_link_as_target`.
    #[instrument(level = "debug", skip_all, fields(source_id = %link_config.source_id))]
    async fn receive_link_config_as_target(
        &self,
        link_config: LinkConfig<'_>,
    ) -> anyhow::Result<()> {
        let LinkConfig {
            source_id,
            link_name,
            ..
        } = link_config;
        debug!(
            %source_id,
            %link_name,
            "adding link for component",
        );
        let config = match Config::from_link_config(&link_config) {
            Ok(config) => config,
            Err(e) => {
                error!(
                    %source_id,
                    %link_name,
                    "failed to parse config: {e}",
                );
                bail!(anyhow!(e).context("failed to parse config"))
            }
        };
        // `config` is not used afterwards, so hand it over instead of cloning it.
        let client = match Client::new(config) {
            Ok(client) => client,
            Err(e) => {
                error!(
                    %source_id,
                    %link_name,
                    "failed to create new client config: {e}",
                );
                return Err(anyhow!(e).context("failed to create new client config"));
            }
        };
        client.set_renewal().await;
        let mut update_map = self.components.write().await;
        update_map.insert(source_id.to_string(), Arc::new(client));
        Ok(())
    }

    /// Forgets the client registered for the deleted link's source component, if any.
    #[instrument(level = "info", skip_all, fields(source_id = info.get_source_id()))]
    async fn delete_link_as_target(&self, info: impl LinkDeleteInfo) -> anyhow::Result<()> {
        let component_id = info.get_source_id();
        let mut aw = self.components.write().await;
        if let Some(client) = aw.remove(component_id) {
            debug!(component_id, "deleting link for component");
            drop(client);
        }
        Ok(())
    }

    /// Drops every registered client when the provider shuts down.
    async fn shutdown(&self) -> anyhow::Result<()> {
        self.components.write().await.clear();
        Ok(())
    }
}