wasmcloud_provider_sqldb_postgres/
lib.rs1#![cfg(not(doctest))]
2
3use std::collections::HashMap;
11use std::sync::Arc;
12
13use anyhow::{Context as _, Result};
14use deadpool_postgres::Pool;
15use futures::TryStreamExt as _;
16use sha2::{Digest as _, Sha256};
17use tokio::sync::RwLock;
18use tokio_postgres::types::Type as PgType;
19use tracing::{error, instrument, warn};
20use ulid::Ulid;
21
22use wasmcloud_provider_sdk::{
23 get_connection, propagate_trace_for_ctx, run_provider, LinkConfig, LinkDeleteInfo, Provider,
24};
25use wasmcloud_provider_sdk::{initialize_observability, serve_provider_exports};
26
27mod bindings;
28use bindings::{
29 into_result_row, PgValue, PreparedStatementExecError, PreparedStatementToken, QueryError,
30 ResultRow, StatementPrepareError,
31};
32
33mod config;
34use config::{extract_prefixed_conn_config, ConnectionCreateOptions};
35
36use wasmcloud_provider_sdk::Context;
37
/// Link config (or secret) key that opts a link into sharing one connection pool
/// with every other link whose connection options are identical.
/// Recognised truthy values are `true` and `yes`, case-insensitively.
const CONFIG_SHARE_CONNECTIONS_BY_URL_KEY: &str = "POSTGRES_SHARE_CONNECTIONS_BY_URL";
43
44type SourceId = String;
46
47type PreparedStatementQuery = String;
49
50type StatementParams = Vec<PgType>;
55
56type PreparedStatementInfo = (PreparedStatementQuery, StatementParams, SourceId);
58
59type SharedConnectionKey = String;
63
64#[derive(Clone)]
66enum PostgresConnection {
67 Direct(Pool),
69 Shared(String),
71}
72
73#[derive(Clone, Default)]
74pub struct PostgresProvider {
75 connections: Arc<RwLock<HashMap<SourceId, PostgresConnection>>>,
77 shared_connections: Arc<RwLock<HashMap<SharedConnectionKey, Pool>>>,
79 prepared_statements: Arc<RwLock<HashMap<PreparedStatementToken, PreparedStatementInfo>>>,
81}
82
83impl PostgresProvider {
84 fn name() -> &'static str {
85 "sqldb-postgres-provider"
86 }
87
88 fn connection_string_for_hashing(opts: &ConnectionCreateOptions) -> String {
90 format!(
91 "postgres://{}:{}@{}:{}/{}?tls_required={}&pool_size={:?}",
92 opts.username,
93 opts.password,
94 opts.host,
95 opts.port,
96 opts.database,
97 opts.tls_required,
98 opts.pool_size
99 )
100 }
101
102 async fn get_pool(&self, source_id: &str) -> Result<Pool, String> {
104 let connections = self.connections.read().await;
105 let connection = connections
106 .get(source_id)
107 .ok_or_else(|| format!("missing connection pool for source [{source_id}]"))?;
108
109 match connection {
110 PostgresConnection::Direct(pool) => Ok(pool.clone()),
111 PostgresConnection::Shared(key) => {
112 let shared = self.shared_connections.read().await;
113 shared
114 .get(key)
115 .cloned()
116 .ok_or_else(|| format!("no shared connection found with key [{key}]"))
117 }
118 }
119 }
120
121 pub async fn run() -> anyhow::Result<()> {
123 initialize_observability!(
124 PostgresProvider::name(),
125 std::env::var_os("PROVIDER_SQLDB_POSTGRES_FLAMEGRAPH_PATH")
126 );
127 let provider = PostgresProvider::default();
128 let shutdown = run_provider(provider.clone(), PostgresProvider::name())
129 .await
130 .context("failed to run provider")?;
131 let connection = get_connection();
132 let wrpc = connection
133 .get_wrpc_client(connection.provider_key())
134 .await?;
135 serve_provider_exports(&wrpc, provider, shutdown, bindings::serve)
136 .await
137 .context("failed to serve provider exports")
138 }
139
140 async fn ensure_pool(
142 &self,
143 source_id: &str,
144 create_opts: ConnectionCreateOptions,
145 share_connections: bool,
146 ) -> Result<()> {
147 if share_connections {
149 let connection_string = Self::connection_string_for_hashing(&create_opts);
150 let shared_key = format!("{:X}", Sha256::digest(&connection_string));
151
152 {
154 let shared_connections = self.shared_connections.read().await;
155 if shared_connections.contains_key(&shared_key) {
156 let mut connections = self.connections.write().await;
157 connections.insert(source_id.into(), PostgresConnection::Shared(shared_key));
158 return Ok(());
159 }
160 }
161 }
162
163 {
165 let connections = self.connections.read().await;
166 if connections.get(source_id).is_some() {
167 return Ok(());
168 }
169 }
170
171 let runtime = Some(deadpool_postgres::Runtime::Tokio1);
173 let tls_required = create_opts.tls_required;
174 let cfg = deadpool_postgres::Config::from(create_opts.clone());
175 let pool = if tls_required {
176 create_tls_pool(cfg, runtime)
177 } else {
178 cfg.create_pool(runtime, tokio_postgres::NoTls)
179 .context("failed to create non-TLS postgres pool")
180 }?;
181
182 if share_connections {
183 let connection_string = Self::connection_string_for_hashing(&create_opts);
185 let shared_key = format!("{:X}", Sha256::digest(&connection_string));
186
187 let mut shared_connections = self.shared_connections.write().await;
189 shared_connections.insert(shared_key.clone(), pool);
190 drop(shared_connections);
191
192 let mut connections = self.connections.write().await;
193 connections.insert(source_id.into(), PostgresConnection::Shared(shared_key));
194 } else {
195 let mut connections = self.connections.write().await;
197 connections.insert(source_id.into(), PostgresConnection::Direct(pool));
198 }
199
200 Ok(())
201 }
202
203 async fn do_query(
205 &self,
206 source_id: &str,
207 query: &str,
208 params: Vec<PgValue>,
209 ) -> Result<Vec<ResultRow>, QueryError> {
210 let pool = self.get_pool(source_id).await.map_err(|e| {
211 QueryError::Unexpected(format!(
212 "missing connection pool for source [{source_id}] while querying: {e}"
213 ))
214 })?;
215
216 let client = pool.get().await.map_err(|e| {
217 QueryError::Unexpected(format!("failed to build client from pool: {e}"))
218 })?;
219
220 let rows = client
221 .query_raw(query, params)
222 .await
223 .map_err(|e| QueryError::Unexpected(format!("failed to perform query: {e}")))?;
224
225 rows.map_ok(into_result_row)
228 .try_collect::<Vec<_>>()
229 .await
230 .map_err(|e| QueryError::Unexpected(format!("failed to evaluate full row: {e}")))
231 }
232
233 async fn do_query_batch(&self, source_id: &str, query: &str) -> Result<(), QueryError> {
235 let pool = self.get_pool(source_id).await.map_err(|e| {
236 QueryError::Unexpected(format!(
237 "missing connection pool for source [{source_id}] while querying: {e}"
238 ))
239 })?;
240
241 let client = pool.get().await.map_err(|e| {
242 QueryError::Unexpected(format!("failed to build client from pool: {e}"))
243 })?;
244
245 client
246 .batch_execute(query)
247 .await
248 .map_err(|e| QueryError::Unexpected(format!("failed to perform query: {e}")))?;
249
250 Ok(())
251 }
252
253 async fn do_statement_prepare(
255 &self,
256 source_id: &str,
257 query: &str,
258 ) -> Result<PreparedStatementToken, StatementPrepareError> {
259 let pool = self.get_pool(source_id).await.map_err(|e| {
260 StatementPrepareError::Unexpected(format!(
261 "failed to find connection pool for token [{source_id}]: {e}"
262 ))
263 })?;
264
265 let client = pool.get().await.map_err(|e| {
266 StatementPrepareError::Unexpected(format!("failed to build client from pool: {e}"))
267 })?;
268
269 let statement = client.prepare(query).await.map_err(|e| {
270 StatementPrepareError::Unexpected(format!("failed to prepare query: {e}"))
271 })?;
272
273 let statement_token = format!("prepared-statement-{}", Ulid::new().to_string());
274
275 let mut prepared_statements = self.prepared_statements.write().await;
276 prepared_statements.insert(
277 statement_token.clone(),
278 (query.into(), statement.params().into(), source_id.into()),
279 );
280
281 Ok(statement_token)
282 }
283
284 async fn do_statement_execute(
286 &self,
287 statement_token: &str,
288 params: Vec<PgValue>,
289 ) -> Result<u64, PreparedStatementExecError> {
290 let statements = self.prepared_statements.read().await;
291 let (query, types, source_id) = statements.get(statement_token).ok_or_else(|| {
292 PreparedStatementExecError::Unexpected(format!(
293 "missing prepared statement with statement ID [{statement_token}]"
294 ))
295 })?;
296
297 let pool = self.get_pool(source_id).await.map_err(|e| {
298 PreparedStatementExecError::Unexpected(format!(
299 "missing connection pool for token [{source_id}], statement ID [{statement_token}]: {e}"
300 ))
301 })?;
302 let client = pool.get().await.map_err(|e| {
303 PreparedStatementExecError::Unexpected(format!("failed to build client from pool: {e}"))
304 })?;
305
306 let statement = client
310 .statement_cache
311 .prepare_typed(&client, query, types)
312 .await
313 .map_err(|e| {
314 PreparedStatementExecError::Unexpected(format!(
315 "failed to prepare statement for client in pool: {e}"
316 ))
317 })?;
318
319 let rows_affected = client.execute_raw(&statement, params).await.map_err(|e| {
320 PreparedStatementExecError::Unexpected(format!(
321 "failed to execute prepared statement with token [{statement_token}]: {e}"
322 ))
323 })?;
324
325 Ok(rows_affected)
326 }
327}
328
329impl Provider for PostgresProvider {
330 #[instrument(level = "debug", skip_all, fields(source_id))]
335 async fn receive_link_config_as_target(
336 &self,
337 link_config @ LinkConfig { source_id, .. }: LinkConfig<'_>,
338 ) -> anyhow::Result<()> {
339 let Some(db_cfg) = extract_prefixed_conn_config("POSTGRES_", &link_config) else {
341 warn!(source_id, "no link-level DB configuration");
343 return Ok(());
344 };
345
346 let share_connections = if let Some(value) =
348 link_config.config.get(CONFIG_SHARE_CONNECTIONS_BY_URL_KEY)
349 {
350 matches!(value.to_lowercase().as_str(), "true" | "yes")
351 } else if let Some(secret) = link_config.secrets.get(CONFIG_SHARE_CONNECTIONS_BY_URL_KEY) {
352 if let Some(value) = secret.as_string() {
353 matches!(value.to_lowercase().as_str(), "true" | "yes")
354 } else {
355 false
356 }
357 } else {
358 false
359 };
360
361 if let Err(error) = self.ensure_pool(source_id, db_cfg, share_connections).await {
363 error!(?error, source_id, "failed to create connection");
364 };
365
366 Ok(())
367 }
368
369 #[instrument(level = "info", skip_all, fields(source_id = info.get_source_id()))]
373 async fn delete_link_as_target(&self, info: impl LinkDeleteInfo) -> anyhow::Result<()> {
374 let source_id = info.get_source_id();
375 let mut prepared_statements = self.prepared_statements.write().await;
376 prepared_statements.retain(|_stmt_token, (_query, _statement, src_id)| src_id != source_id);
377 drop(prepared_statements);
378 let mut connections = self.connections.write().await;
379 connections.remove(source_id);
380 drop(connections);
381 Ok(())
382 }
383
384 #[instrument(level = "debug", skip_all)]
386 async fn shutdown(&self) -> anyhow::Result<()> {
387 let mut prepared_statements = self.prepared_statements.write().await;
388 prepared_statements.drain();
389 let mut connections = self.connections.write().await;
390 connections.drain();
391 Ok(())
392 }
393}
394
395impl bindings::query::Handler<Option<Context>> for PostgresProvider {
397 #[instrument(level = "debug", skip_all, fields(query))]
398 async fn query(
399 &self,
400 ctx: Option<Context>,
401 query: String,
402 params: Vec<PgValue>,
403 ) -> Result<Result<Vec<ResultRow>, QueryError>> {
404 propagate_trace_for_ctx!(ctx);
405 let Some(Context {
406 component: Some(source_id),
407 ..
408 }) = ctx
409 else {
410 return Ok(Err(QueryError::Unexpected(
411 "unexpectedly missing source ID".into(),
412 )));
413 };
414
415 Ok(self.do_query(&source_id, &query, params).await)
416 }
417
418 #[instrument(level = "debug", skip_all, fields(query))]
419 async fn query_batch(
420 &self,
421 ctx: Option<Context>,
422 query: String,
423 ) -> Result<Result<(), QueryError>> {
424 propagate_trace_for_ctx!(ctx);
425 let Some(Context {
426 component: Some(source_id),
427 ..
428 }) = ctx
429 else {
430 return Ok(Err(QueryError::Unexpected(
431 "unexpectedly missing source ID".into(),
432 )));
433 };
434
435 Ok(self.do_query_batch(&source_id, &query).await)
436 }
437}
438
439impl bindings::prepared::Handler<Option<Context>> for PostgresProvider {
441 #[instrument(level = "debug", skip_all, fields(query))]
442 async fn prepare(
443 &self,
444 ctx: Option<Context>,
445 query: String,
446 ) -> Result<Result<PreparedStatementToken, StatementPrepareError>> {
447 propagate_trace_for_ctx!(ctx);
448 let Some(Context {
449 component: Some(source_id),
450 ..
451 }) = ctx
452 else {
453 return Ok(Err(StatementPrepareError::Unexpected(
454 "unexpectedly missing source ID".into(),
455 )));
456 };
457 Ok(self.do_statement_prepare(&source_id, &query).await)
458 }
459
460 #[instrument(level = "debug", skip_all, fields(statement_token))]
461 async fn exec(
462 &self,
463 ctx: Option<Context>,
464 statement_token: PreparedStatementToken,
465 params: Vec<PgValue>,
466 ) -> Result<Result<u64, PreparedStatementExecError>> {
467 propagate_trace_for_ctx!(ctx);
468 Ok(self.do_statement_execute(&statement_token, params).await)
469 }
470}
471
472fn create_tls_pool(
473 cfg: deadpool_postgres::Config,
474 runtime: Option<deadpool_postgres::Runtime>,
475) -> Result<Pool> {
476 let mut store = rustls::RootCertStore::empty();
477 store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
478 cfg.create_pool(
479 runtime,
480 tokio_postgres_rustls::MakeRustlsConnect::new(
481 rustls::ClientConfig::builder()
482 .with_root_certificates(store)
483 .with_no_client_auth(),
484 ),
485 )
486 .context("failed to create TLS-enabled connection pool")
487}