wasmcloud_provider_sqldb_postgres/
lib.rs

1#![cfg(not(doctest))]
2
3//! SQL-powered database access provider implementing `wasmcloud:postgres` for connecting
4//! to Postgres clusters.
5//!
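//! This implementation is multi-threaded: operations from different components use separate
//! connection pools and can run in parallel. When `POSTGRES_SHARE_CONNECTIONS_BY_URL` is enabled
//! on a link, components linked with an identical connection configuration share a single pool.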
8//!
9
10use std::collections::HashMap;
11use std::sync::Arc;
12
13use anyhow::{Context as _, Result};
14use deadpool_postgres::Pool;
15use futures::TryStreamExt as _;
16use sha2::{Digest as _, Sha256};
17use tokio::sync::RwLock;
18use tokio_postgres::types::Type as PgType;
19use tracing::{error, instrument, warn};
20use ulid::Ulid;
21
22use wasmcloud_provider_sdk::{
23    get_connection, propagate_trace_for_ctx, run_provider, LinkConfig, LinkDeleteInfo, Provider,
24};
25use wasmcloud_provider_sdk::{initialize_observability, serve_provider_exports};
26
27mod bindings;
28use bindings::{
29    into_result_row, PgValue, PreparedStatementExecError, PreparedStatementToken, QueryError,
30    ResultRow, StatementPrepareError,
31};
32
33mod config;
34use config::{extract_prefixed_conn_config, ConnectionCreateOptions};
35
36use wasmcloud_provider_sdk::Context;
37
38/// Whether to share connections by URL
39///
40/// This option indicates that URLs with identical connection configurations will be shared/reused by
41/// components that are linked with the same configurations
42const CONFIG_SHARE_CONNECTIONS_BY_URL_KEY: &str = "POSTGRES_SHARE_CONNECTIONS_BY_URL";
43
44/// A unique identifier for a created connection
45type SourceId = String;
46
47/// A query used in the process of creating a prepared statement
48type PreparedStatementQuery = String;
49
50/// Parameters determined to be used in a statement
51///
52/// This value is usually constructed after running a prepare against a given
53/// client from a given pool, and saving the relevant type information.
54type StatementParams = Vec<PgType>;
55
56/// Information about a given prepared statement
57type PreparedStatementInfo = (PreparedStatementQuery, StatementParams, SourceId);
58
59/// Shared connection keys are keys that identify shared connections
60///
61/// This is the hash of the connection configuration, to avoid printing credentials inadvertently.
62type SharedConnectionKey = String;
63
64/// Type of postgres connection - either direct or shared
65#[derive(Clone)]
66enum PostgresConnection {
67    /// Direct connection to a pool
68    Direct(Pool),
69    /// Shared connection, identified by the hash of the connection configuration
70    Shared(String),
71}
72
73#[derive(Clone, Default)]
74pub struct PostgresProvider {
75    /// Database connections indexed by source ID name
76    connections: Arc<RwLock<HashMap<SourceId, PostgresConnection>>>,
77    /// Shared connection pools indexed by configuration hash
78    shared_connections: Arc<RwLock<HashMap<SharedConnectionKey, Pool>>>,
79    /// Lookup of prepared statements to the statement and the source ID that prepared them
80    prepared_statements: Arc<RwLock<HashMap<PreparedStatementToken, PreparedStatementInfo>>>,
81}
82
83impl PostgresProvider {
84    fn name() -> &'static str {
85        "sqldb-postgres-provider"
86    }
87
88    /// Generate a connection string from ConnectionCreateOptions for hashing
89    fn connection_string_for_hashing(opts: &ConnectionCreateOptions) -> String {
90        format!(
91            "postgres://{}:{}@{}:{}/{}?tls_required={}&pool_size={:?}",
92            opts.username,
93            opts.password,
94            opts.host,
95            opts.port,
96            opts.database,
97            opts.tls_required,
98            opts.pool_size
99        )
100    }
101
102    /// Get a pool for the given source_id, resolving shared connections if necessary
103    async fn get_pool(&self, source_id: &str) -> Result<Pool, String> {
104        let connections = self.connections.read().await;
105        let connection = connections
106            .get(source_id)
107            .ok_or_else(|| format!("missing connection pool for source [{source_id}]"))?;
108
109        match connection {
110            PostgresConnection::Direct(pool) => Ok(pool.clone()),
111            PostgresConnection::Shared(key) => {
112                let shared = self.shared_connections.read().await;
113                shared
114                    .get(key)
115                    .cloned()
116                    .ok_or_else(|| format!("no shared connection found with key [{key}]"))
117            }
118        }
119    }
120
121    /// Run [`PostgresProvider`] as a wasmCloud provider
122    pub async fn run() -> anyhow::Result<()> {
123        initialize_observability!(
124            PostgresProvider::name(),
125            std::env::var_os("PROVIDER_SQLDB_POSTGRES_FLAMEGRAPH_PATH")
126        );
127        let provider = PostgresProvider::default();
128        let shutdown = run_provider(provider.clone(), PostgresProvider::name())
129            .await
130            .context("failed to run provider")?;
131        let connection = get_connection();
132        let wrpc = connection
133            .get_wrpc_client(connection.provider_key())
134            .await?;
135        serve_provider_exports(&wrpc, provider, shutdown, bindings::serve)
136            .await
137            .context("failed to serve provider exports")
138    }
139
140    /// Create and store a connection pool, if not already present
141    async fn ensure_pool(
142        &self,
143        source_id: &str,
144        create_opts: ConnectionCreateOptions,
145        share_connections: bool,
146    ) -> Result<()> {
147        // If sharing is enabled, check if we already have a shared connection for this configuration
148        if share_connections {
149            let connection_string = Self::connection_string_for_hashing(&create_opts);
150            let shared_key = format!("{:X}", Sha256::digest(&connection_string));
151
152            // Check if we already have this shared connection
153            {
154                let shared_connections = self.shared_connections.read().await;
155                if shared_connections.contains_key(&shared_key) {
156                    let mut connections = self.connections.write().await;
157                    connections.insert(source_id.into(), PostgresConnection::Shared(shared_key));
158                    return Ok(());
159                }
160            }
161        }
162
163        // Exit early if a pool with the given source ID is already present
164        {
165            let connections = self.connections.read().await;
166            if connections.get(source_id).is_some() {
167                return Ok(());
168            }
169        }
170
171        // Build the new connection pool
172        let runtime = Some(deadpool_postgres::Runtime::Tokio1);
173        let tls_required = create_opts.tls_required;
174        let cfg = deadpool_postgres::Config::from(create_opts.clone());
175        let pool = if tls_required {
176            create_tls_pool(cfg, runtime)
177        } else {
178            cfg.create_pool(runtime, tokio_postgres::NoTls)
179                .context("failed to create non-TLS postgres pool")
180        }?;
181
182        if share_connections {
183            // Store as shared connection
184            let connection_string = Self::connection_string_for_hashing(&create_opts);
185            let shared_key = format!("{:X}", Sha256::digest(&connection_string));
186
187            // Store the shared connection first, then reference it
188            let mut shared_connections = self.shared_connections.write().await;
189            shared_connections.insert(shared_key.clone(), pool);
190            drop(shared_connections);
191
192            let mut connections = self.connections.write().await;
193            connections.insert(source_id.into(), PostgresConnection::Shared(shared_key));
194        } else {
195            // Store as direct connection
196            let mut connections = self.connections.write().await;
197            connections.insert(source_id.into(), PostgresConnection::Direct(pool));
198        }
199
200        Ok(())
201    }
202
203    /// Perform a query
204    async fn do_query(
205        &self,
206        source_id: &str,
207        query: &str,
208        params: Vec<PgValue>,
209    ) -> Result<Vec<ResultRow>, QueryError> {
210        let pool = self.get_pool(source_id).await.map_err(|e| {
211            QueryError::Unexpected(format!(
212                "missing connection pool for source [{source_id}] while querying: {e}"
213            ))
214        })?;
215
216        let client = pool.get().await.map_err(|e| {
217            QueryError::Unexpected(format!("failed to build client from pool: {e}"))
218        })?;
219
220        let rows = client
221            .query_raw(query, params)
222            .await
223            .map_err(|e| QueryError::Unexpected(format!("failed to perform query: {e}")))?;
224
225        // todo(fix): once async stream support is available & in contract
226        // replace this with a mapped stream
227        rows.map_ok(into_result_row)
228            .try_collect::<Vec<_>>()
229            .await
230            .map_err(|e| QueryError::Unexpected(format!("failed to evaluate full row: {e}")))
231    }
232
233    /// Perform a raw query
234    async fn do_query_batch(&self, source_id: &str, query: &str) -> Result<(), QueryError> {
235        let pool = self.get_pool(source_id).await.map_err(|e| {
236            QueryError::Unexpected(format!(
237                "missing connection pool for source [{source_id}] while querying: {e}"
238            ))
239        })?;
240
241        let client = pool.get().await.map_err(|e| {
242            QueryError::Unexpected(format!("failed to build client from pool: {e}"))
243        })?;
244
245        client
246            .batch_execute(query)
247            .await
248            .map_err(|e| QueryError::Unexpected(format!("failed to perform query: {e}")))?;
249
250        Ok(())
251    }
252
253    /// Prepare a statement
254    async fn do_statement_prepare(
255        &self,
256        source_id: &str,
257        query: &str,
258    ) -> Result<PreparedStatementToken, StatementPrepareError> {
259        let pool = self.get_pool(source_id).await.map_err(|e| {
260            StatementPrepareError::Unexpected(format!(
261                "failed to find connection pool for token [{source_id}]: {e}"
262            ))
263        })?;
264
265        let client = pool.get().await.map_err(|e| {
266            StatementPrepareError::Unexpected(format!("failed to build client from pool: {e}"))
267        })?;
268
269        let statement = client.prepare(query).await.map_err(|e| {
270            StatementPrepareError::Unexpected(format!("failed to prepare query: {e}"))
271        })?;
272
273        let statement_token = format!("prepared-statement-{}", Ulid::new().to_string());
274
275        let mut prepared_statements = self.prepared_statements.write().await;
276        prepared_statements.insert(
277            statement_token.clone(),
278            (query.into(), statement.params().into(), source_id.into()),
279        );
280
281        Ok(statement_token)
282    }
283
284    /// Execute a prepared statement, returning the number of rows affected
285    async fn do_statement_execute(
286        &self,
287        statement_token: &str,
288        params: Vec<PgValue>,
289    ) -> Result<u64, PreparedStatementExecError> {
290        let statements = self.prepared_statements.read().await;
291        let (query, types, source_id) = statements.get(statement_token).ok_or_else(|| {
292            PreparedStatementExecError::Unexpected(format!(
293                "missing prepared statement with statement ID [{statement_token}]"
294            ))
295        })?;
296
297        let pool = self.get_pool(source_id).await.map_err(|e| {
298            PreparedStatementExecError::Unexpected(format!(
299                "missing connection pool for token [{source_id}], statement ID [{statement_token}]: {e}"
300            ))
301        })?;
302        let client = pool.get().await.map_err(|e| {
303            PreparedStatementExecError::Unexpected(format!("failed to build client from pool: {e}"))
304        })?;
305
306        // Since the pool is not aware of already created statements managed by tokio_postgres,
307        // we may have pulled a client that has not already has this statement prepared,
308        // so we must prepare, just in case.
309        let statement = client
310            .statement_cache
311            .prepare_typed(&client, query, types)
312            .await
313            .map_err(|e| {
314                PreparedStatementExecError::Unexpected(format!(
315                    "failed to prepare statement for client in pool: {e}"
316                ))
317            })?;
318
319        let rows_affected = client.execute_raw(&statement, params).await.map_err(|e| {
320            PreparedStatementExecError::Unexpected(format!(
321                "failed to execute prepared statement with token [{statement_token}]: {e}"
322            ))
323        })?;
324
325        Ok(rows_affected)
326    }
327}
328
329impl Provider for PostgresProvider {
330    /// Handle being linked to a source (likely a component) as a target
331    ///
332    /// Components are expected to provide references to named configuration via link definitions
333    /// which contain keys named `POSTGRES_*` detailing configuration for connecting to Postgres.
334    #[instrument(level = "debug", skip_all, fields(source_id))]
335    async fn receive_link_config_as_target(
336        &self,
337        link_config @ LinkConfig { source_id, .. }: LinkConfig<'_>,
338    ) -> anyhow::Result<()> {
339        // Attempt to parse a configuration from the map with the prefix POSTGRES_
340        let Some(db_cfg) = extract_prefixed_conn_config("POSTGRES_", &link_config) else {
341            // If we failed to find a config on the link, then we
342            warn!(source_id, "no link-level DB configuration");
343            return Ok(());
344        };
345
346        // Check if connection sharing is enabled
347        let share_connections = if let Some(value) =
348            link_config.config.get(CONFIG_SHARE_CONNECTIONS_BY_URL_KEY)
349        {
350            matches!(value.to_lowercase().as_str(), "true" | "yes")
351        } else if let Some(secret) = link_config.secrets.get(CONFIG_SHARE_CONNECTIONS_BY_URL_KEY) {
352            if let Some(value) = secret.as_string() {
353                matches!(value.to_lowercase().as_str(), "true" | "yes")
354            } else {
355                false
356            }
357        } else {
358            false
359        };
360
361        // Create a pool if one isn't already present for this particular source
362        if let Err(error) = self.ensure_pool(source_id, db_cfg, share_connections).await {
363            error!(?error, source_id, "failed to create connection");
364        };
365
366        Ok(())
367    }
368
369    /// Handle notification that a link is dropped
370    ///
371    /// Generally we can release the resources (connections) associated with the source
372    #[instrument(level = "info", skip_all, fields(source_id = info.get_source_id()))]
373    async fn delete_link_as_target(&self, info: impl LinkDeleteInfo) -> anyhow::Result<()> {
374        let source_id = info.get_source_id();
375        let mut prepared_statements = self.prepared_statements.write().await;
376        prepared_statements.retain(|_stmt_token, (_query, _statement, src_id)| src_id != source_id);
377        drop(prepared_statements);
378        let mut connections = self.connections.write().await;
379        connections.remove(source_id);
380        drop(connections);
381        Ok(())
382    }
383
384    /// Handle shutdown request by closing all connections
385    #[instrument(level = "debug", skip_all)]
386    async fn shutdown(&self) -> anyhow::Result<()> {
387        let mut prepared_statements = self.prepared_statements.write().await;
388        prepared_statements.drain();
389        let mut connections = self.connections.write().await;
390        connections.drain();
391        Ok(())
392    }
393}
394
395/// Implement the `wasmcloud:postgres/query` interface for [`PostgresProvider`]
396impl bindings::query::Handler<Option<Context>> for PostgresProvider {
397    #[instrument(level = "debug", skip_all, fields(query))]
398    async fn query(
399        &self,
400        ctx: Option<Context>,
401        query: String,
402        params: Vec<PgValue>,
403    ) -> Result<Result<Vec<ResultRow>, QueryError>> {
404        propagate_trace_for_ctx!(ctx);
405        let Some(Context {
406            component: Some(source_id),
407            ..
408        }) = ctx
409        else {
410            return Ok(Err(QueryError::Unexpected(
411                "unexpectedly missing source ID".into(),
412            )));
413        };
414
415        Ok(self.do_query(&source_id, &query, params).await)
416    }
417
418    #[instrument(level = "debug", skip_all, fields(query))]
419    async fn query_batch(
420        &self,
421        ctx: Option<Context>,
422        query: String,
423    ) -> Result<Result<(), QueryError>> {
424        propagate_trace_for_ctx!(ctx);
425        let Some(Context {
426            component: Some(source_id),
427            ..
428        }) = ctx
429        else {
430            return Ok(Err(QueryError::Unexpected(
431                "unexpectedly missing source ID".into(),
432            )));
433        };
434
435        Ok(self.do_query_batch(&source_id, &query).await)
436    }
437}
438
439/// Implement the `wasmcloud:postgres/prepared` interface for [`PostgresProvider`]
440impl bindings::prepared::Handler<Option<Context>> for PostgresProvider {
441    #[instrument(level = "debug", skip_all, fields(query))]
442    async fn prepare(
443        &self,
444        ctx: Option<Context>,
445        query: String,
446    ) -> Result<Result<PreparedStatementToken, StatementPrepareError>> {
447        propagate_trace_for_ctx!(ctx);
448        let Some(Context {
449            component: Some(source_id),
450            ..
451        }) = ctx
452        else {
453            return Ok(Err(StatementPrepareError::Unexpected(
454                "unexpectedly missing source ID".into(),
455            )));
456        };
457        Ok(self.do_statement_prepare(&source_id, &query).await)
458    }
459
460    #[instrument(level = "debug", skip_all, fields(statement_token))]
461    async fn exec(
462        &self,
463        ctx: Option<Context>,
464        statement_token: PreparedStatementToken,
465        params: Vec<PgValue>,
466    ) -> Result<Result<u64, PreparedStatementExecError>> {
467        propagate_trace_for_ctx!(ctx);
468        Ok(self.do_statement_execute(&statement_token, params).await)
469    }
470}
471
472fn create_tls_pool(
473    cfg: deadpool_postgres::Config,
474    runtime: Option<deadpool_postgres::Runtime>,
475) -> Result<Pool> {
476    let mut store = rustls::RootCertStore::empty();
477    store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
478    cfg.create_pool(
479        runtime,
480        tokio_postgres_rustls::MakeRustlsConnect::new(
481            rustls::ClientConfig::builder()
482                .with_root_certificates(store)
483                .with_no_client_auth(),
484        ),
485    )
486    .context("failed to create TLS-enabled connection pool")
487}