wasmcloud_provider_sqldb_postgres/
lib.rs#![cfg(not(doctest))]
use std::collections::HashMap;
use std::sync::Arc;
use anyhow::{Context as _, Result};
use deadpool_postgres::Pool;
use futures::TryStreamExt as _;
use tokio::sync::RwLock;
use tokio_postgres::Statement;
use tracing::{error, instrument, warn};
use ulid::Ulid;
use wasmcloud_provider_sdk::{
get_connection, propagate_trace_for_ctx, run_provider, LinkConfig, LinkDeleteInfo, Provider,
};
use wasmcloud_provider_sdk::{initialize_observability, serve_provider_exports};
mod bindings;
use bindings::{
into_result_row, PgValue, PreparedStatementExecError, PreparedStatementToken, QueryError,
ResultRow, StatementPrepareError,
};
mod config;
use config::{extract_prefixed_conn_config, ConnectionCreateOptions};
use wasmcloud_provider_sdk::Context;
/// Shared state of the Postgres SQL database capability provider.
///
/// Both maps sit behind `Arc<RwLock<..>>`, so a clone of the provider is a handle
/// to the same underlying state rather than an independent copy.
#[derive(Clone, Default)]
pub struct PostgresProvider {
/// Connection pools, keyed by the `source_id` of the link they were created for.
connections: Arc<RwLock<HashMap<String, Pool>>>,
/// Prepared statements, keyed by statement token. Each entry also stores the key
/// (in `connections`) of the pool the statement was prepared against.
prepared_statements: Arc<RwLock<HashMap<PreparedStatementToken, (Statement, String)>>>,
}
impl PostgresProvider {
/// Name this provider passes to observability setup and to `run_provider`.
fn name() -> &'static str {
    const PROVIDER_NAME: &str = "sqldb-postgres-provider";
    PROVIDER_NAME
}
pub async fn run() -> anyhow::Result<()> {
initialize_observability!(
PostgresProvider::name(),
std::env::var_os("PROVIDER_SQLDB_POSTGRES_FLAMEGRAPH_PATH")
);
let provider = PostgresProvider::default();
let shutdown = run_provider(provider.clone(), PostgresProvider::name())
.await
.context("failed to run provider")?;
let connection = get_connection();
let wrpc = connection
.get_wrpc_client(connection.provider_key())
.await?;
serve_provider_exports(&wrpc, provider, shutdown, bindings::serve)
.await
.context("failed to serve provider exports")
}
/// Register a connection pool for `source_id` unless one already exists.
///
/// The pool is TLS-enabled when `create_opts.tls_required` is set, otherwise it
/// uses `tokio_postgres::NoTls`.
///
/// # Errors
/// Returns an error if the pool cannot be created from `create_opts`.
async fn ensure_pool(
    &self,
    source_id: &str,
    create_opts: ConnectionCreateOptions,
) -> Result<()> {
    // Fast path: only a read lock is needed when the pool is already registered.
    let already_registered = self.connections.read().await.contains_key(source_id);
    if already_registered {
        return Ok(());
    }

    let runtime = Some(deadpool_postgres::Runtime::Tokio1);
    let tls_required = create_opts.tls_required;
    let cfg = deadpool_postgres::Config::from(create_opts);
    let pool = if tls_required {
        create_tls_pool(cfg, runtime)
    } else {
        cfg.create_pool(runtime, tokio_postgres::NoTls)
            .context("failed to create non-TLS postgres pool")
    }?;

    // Another task may have registered a pool for this source between the read
    // check above and this write lock. Keep the pool that is already in place
    // (it may be in use) instead of silently replacing it.
    self.connections
        .write()
        .await
        .entry(source_id.to_owned())
        .or_insert(pool);
    Ok(())
}
/// Run `query` with `params` on the pool registered for `source_id` and collect
/// every returned row.
///
/// # Errors
/// Returns `QueryError::Unexpected` if no pool is registered for `source_id`, if
/// a client cannot be obtained from the pool, if the query fails, or if a row
/// cannot be read from the result stream.
async fn do_query(
    &self,
    source_id: &str,
    query: &str,
    params: Vec<PgValue>,
) -> Result<Vec<ResultRow>, QueryError> {
    // Clone the pool handle and release the map lock right away. Holding the
    // read guard across the awaits below would keep writers (link put/delete)
    // waiting for the whole duration of the query.
    let pool = self
        .connections
        .read()
        .await
        .get(source_id)
        .cloned()
        .ok_or_else(|| {
            QueryError::Unexpected(format!(
                "missing connection pool for source [{source_id}] while querying"
            ))
        })?;

    let client = pool.get().await.map_err(|e| {
        QueryError::Unexpected(format!("failed to build client from pool: {e}"))
    })?;

    let rows = client
        .query_raw(query, params)
        .await
        .map_err(|e| QueryError::Unexpected(format!("failed to perform query: {e}")))?;

    rows.map_ok(into_result_row)
        .try_collect::<Vec<_>>()
        .await
        .map_err(|e| QueryError::Unexpected(format!("failed to evaluate full row: {e}")))
}
/// Execute `query` via `batch_execute` on the pool registered for `source_id`.
/// No rows are returned.
///
/// # Errors
/// Returns `QueryError::Unexpected` if no pool is registered for `source_id`, if
/// a client cannot be obtained from the pool, or if execution fails.
async fn do_query_batch(&self, source_id: &str, query: &str) -> Result<(), QueryError> {
    // Clone the pool handle so the map lock is not held while the batch runs.
    let pool = self
        .connections
        .read()
        .await
        .get(source_id)
        .cloned()
        .ok_or_else(|| {
            QueryError::Unexpected(format!(
                "missing connection pool for source [{source_id}] while querying"
            ))
        })?;

    let client = pool.get().await.map_err(|e| {
        QueryError::Unexpected(format!("failed to build client from pool: {e}"))
    })?;

    client
        .batch_execute(query)
        .await
        .map_err(|e| QueryError::Unexpected(format!("failed to perform query: {e}")))
}
/// Prepare `query` on the pool identified by `connection_token` and return a
/// freshly generated token under which the statement is stored.
///
/// # Errors
/// Returns `StatementPrepareError::Unexpected` if no pool is registered for
/// `connection_token`, if a client cannot be obtained from the pool, or if the
/// database rejects the statement.
async fn do_statement_prepare(
    &self,
    connection_token: &str,
    query: &str,
) -> Result<PreparedStatementToken, StatementPrepareError> {
    // Release the `connections` lock before touching `prepared_statements`.
    // Other methods take the statements lock first and the connections lock
    // second; holding them here in the opposite order can deadlock.
    let pool = self
        .connections
        .read()
        .await
        .get(connection_token)
        .cloned()
        .ok_or_else(|| {
            StatementPrepareError::Unexpected(format!(
                "failed to find connection pool for token [{connection_token}]"
            ))
        })?;

    let client = pool.get().await.map_err(|e| {
        StatementPrepareError::Unexpected(format!("failed to build client from pool: {e}"))
    })?;

    let statement = client.prepare(query).await.map_err(|e| {
        StatementPrepareError::Unexpected(format!("failed to prepare query: {e}"))
    })?;

    let statement_token = format!("prepared-statement-{}", Ulid::new());
    self.prepared_statements.write().await.insert(
        statement_token.clone(),
        (statement, connection_token.into()),
    );
    Ok(statement_token)
}
/// Execute the prepared statement stored under `statement_token` with `params`
/// and return the number of affected rows.
///
/// # Errors
/// Returns `PreparedStatementExecError::Unexpected` if the token is unknown, if
/// the pool the statement was prepared against is no longer registered, if a
/// client cannot be obtained from the pool, or if execution fails.
async fn do_statement_execute(
    &self,
    statement_token: &str,
    params: Vec<PgValue>,
) -> Result<u64, PreparedStatementExecError> {
    // Copy the entry out and drop the statements lock immediately, so it is not
    // held while waiting on the connections lock or while the statement runs.
    let (statement, connection_token) = self
        .prepared_statements
        .read()
        .await
        .get(statement_token)
        .cloned()
        .ok_or_else(|| {
            PreparedStatementExecError::Unexpected(format!(
                "missing prepared statement with statement ID [{statement_token}]"
            ))
        })?;

    let pool = self
        .connections
        .read()
        .await
        .get(&connection_token)
        .cloned()
        .ok_or_else(|| {
            PreparedStatementExecError::Unexpected(format!(
                "missing connection pool for token [{connection_token}], statement ID [{statement_token}]"
            ))
        })?;

    // NOTE(review): tokio_postgres documents a `Statement` as usable only with the
    // connection that prepared it, while the pool may hand out a different
    // connection here — confirm this is handled or intended.
    let client = pool.get().await.map_err(|e| {
        PreparedStatementExecError::Unexpected(format!("failed to build client from pool: {e}"))
    })?;

    client.execute_raw(&statement, params).await.map_err(|e| {
        PreparedStatementExecError::Unexpected(format!(
            "failed to execute prepared statement with token [{statement_token}]: {e}"
        ))
    })
}
}
impl Provider for PostgresProvider {
/// Handle a new link that targets this provider by creating a connection pool
/// from the link's `POSTGRES_`-prefixed configuration.
///
/// A link without such configuration only produces a warning, and a failure to
/// create the pool is logged rather than returned, so this always yields `Ok(())`.
#[instrument(level = "debug", skip_all, fields(source_id))]
async fn receive_link_config_as_target(
    &self,
    link_config @ LinkConfig { source_id, .. }: LinkConfig<'_>,
) -> anyhow::Result<()> {
    let maybe_db_cfg = extract_prefixed_conn_config("POSTGRES_", &link_config);
    match maybe_db_cfg {
        None => {
            warn!(source_id, "no link-level DB configuration");
        }
        Some(db_cfg) => {
            if let Err(error) = self.ensure_pool(source_id, db_cfg).await {
                error!(?error, source_id, "failed to create connection");
            }
        }
    }
    Ok(())
}
/// Handle removal of a link that targeted this provider: forget every prepared
/// statement recorded for the link's source and remove its connection pool.
#[instrument(level = "info", skip_all, fields(source_id = info.get_source_id()))]
async fn delete_link_as_target(&self, info: impl LinkDeleteInfo) -> anyhow::Result<()> {
    let component_id = info.get_source_id();

    // Each lock is confined to its own scope so the two are never held together.
    {
        let mut statements = self.prepared_statements.write().await;
        statements.retain(|_token, (_statement, owner)| component_id != *owner);
    }
    {
        let mut pools = self.connections.write().await;
        pools.remove(component_id);
    }

    Ok(())
}
/// Drop all prepared statements and all connection pools held by the provider.
#[instrument(level = "debug", skip_all)]
async fn shutdown(&self) -> anyhow::Result<()> {
    // Each guard lives only for its own statement. Holding the statements lock
    // while waiting for the connections lock could deadlock against a task that
    // holds `connections` and is waiting for `prepared_statements`.
    self.prepared_statements.write().await.clear();
    self.connections.write().await.clear();
    Ok(())
}
}
impl bindings::query::Handler<Option<Context>> for PostgresProvider {
/// Run a parameterized query on behalf of the calling component and return its rows.
///
/// The outer `Result` is always `Ok`. A missing invocation context or component
/// ID is reported as `QueryError::Unexpected` in the inner result.
#[instrument(level = "debug", skip_all, fields(query))]
async fn query(
    &self,
    ctx: Option<Context>,
    query: String,
    params: Vec<PgValue>,
) -> Result<Result<Vec<ResultRow>, QueryError>> {
    propagate_trace_for_ctx!(ctx);
    match ctx {
        Some(Context {
            component: Some(source_id),
            ..
        }) => {
            let outcome = self.do_query(&source_id, &query, params).await;
            Ok(outcome)
        }
        _ => Ok(Err(QueryError::Unexpected(
            "unexpectedly missing source ID".into(),
        ))),
    }
}
/// Run a batch query (no parameters, no returned rows) on behalf of the calling
/// component.
///
/// The outer `Result` is always `Ok`. A missing invocation context or component
/// ID is reported as `QueryError::Unexpected` in the inner result.
#[instrument(level = "debug", skip_all, fields(query))]
async fn query_batch(
    &self,
    ctx: Option<Context>,
    query: String,
) -> Result<Result<(), QueryError>> {
    propagate_trace_for_ctx!(ctx);
    match ctx {
        Some(Context {
            component: Some(source_id),
            ..
        }) => {
            let outcome = self.do_query_batch(&source_id, &query).await;
            Ok(outcome)
        }
        _ => Ok(Err(QueryError::Unexpected(
            "unexpectedly missing source ID".into(),
        ))),
    }
}
}
impl bindings::prepared::Handler<Option<Context>> for PostgresProvider {
/// Prepare a statement on the calling component's connection pool and return the
/// token that identifies it.
///
/// The outer `Result` is always `Ok`. A missing invocation context or component
/// ID is reported as `StatementPrepareError::Unexpected` in the inner result.
#[instrument(level = "debug", skip_all, fields(query))]
async fn prepare(
    &self,
    ctx: Option<Context>,
    query: String,
) -> Result<Result<PreparedStatementToken, StatementPrepareError>> {
    propagate_trace_for_ctx!(ctx);
    match ctx {
        Some(Context {
            component: Some(source_id),
            ..
        }) => {
            let outcome = self.do_statement_prepare(&source_id, &query).await;
            Ok(outcome)
        }
        _ => Ok(Err(StatementPrepareError::Unexpected(
            "unexpectedly missing source ID".into(),
        ))),
    }
}
/// Execute a previously prepared statement and return the number of affected rows.
///
/// `ctx` is used only for trace propagation; the statement is looked up purely by
/// `statement_token`. The outer `Result` is always `Ok`.
#[instrument(level = "debug", skip_all, fields(statement_token))]
async fn exec(
    &self,
    ctx: Option<Context>,
    statement_token: PreparedStatementToken,
    params: Vec<PgValue>,
) -> Result<Result<u64, PreparedStatementExecError>> {
    propagate_trace_for_ctx!(ctx);
    let outcome = self.do_statement_execute(&statement_token, params).await;
    Ok(outcome)
}
}
/// Build a TLS-enabled connection pool from `cfg`.
///
/// Server certificates are verified against the roots in
/// `webpki_roots::TLS_SERVER_ROOTS`, and no client certificate is presented.
///
/// # Errors
/// Returns an error if the pool cannot be created from `cfg`.
fn create_tls_pool(
    cfg: deadpool_postgres::Config,
    runtime: Option<deadpool_postgres::Runtime>,
) -> Result<Pool> {
    let mut trusted_roots = rustls::RootCertStore::empty();
    trusted_roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());

    let tls_config = rustls::ClientConfig::builder()
        .with_root_certificates(trusted_roots)
        .with_no_client_auth();
    let tls_connector = tokio_postgres_rustls::MakeRustlsConnect::new(tls_config);

    cfg.create_pool(runtime, tls_connector)
        .context("failed to create TLS-enabled connection pool")
}