wasmcloud_provider_keyvalue_redis/
lib.rs

//! Redis implementation for `wrpc:keyvalue`.
//!
//! This implementation is multi-threaded, and operations from different components (or different
//! links) generally use different connections and can run in parallel. Connections are tracked
//! per (component id, link name), so all instances of the same component on the same link share
//! a connection; links that supply no URL fall back to the provider's default connection, and
//! links can opt in to sharing one connection per URL. Where a connection is shared, there may be
//! some brief contention if several instances are simultaneously communicating with Redis through
//! it. See the documentation on `KvRedisProvider::exec_cmd` for how commands are executed.
9
10use core::num::NonZeroU64;
11
12use std::collections::{HashMap, HashSet};
13use std::sync::Arc;
14use std::time::Duration;
15
16use anyhow::{bail, Context as _};
17use bytes::Bytes;
18use redis::aio::{ConnectionManager, ConnectionManagerConfig};
19use redis::{Cmd, FromRedisValue};
20use sha2::{Digest as _, Sha256};
21use tokio::sync::RwLock;
22use tokio::task::JoinHandle;
23use tracing::{debug, error, info, instrument, warn};
24use unicase::UniCase;
25use wasmcloud_provider_sdk::core::secrets::SecretValue;
26use wasmcloud_provider_sdk::provider::WrpcClient;
27use wasmcloud_provider_sdk::{
28    get_connection, load_host_data, propagate_trace_for_ctx, run_provider, Context, HostData,
29    LinkConfig, LinkDeleteInfo, Provider,
30};
31use wasmcloud_provider_sdk::{initialize_observability, serve_provider_exports};
32
/// Bindings generated by `wit_bindgen_wrpc` for the `wrpc:keyvalue` 0.2.0-draft interfaces.
///
/// The store, atomics and batch interfaces are served by this provider (their handler traits
/// come from `bindings::exports`), while the watcher interface is invoked on linked components
/// (via `bindings::wrpc::keyvalue::watcher`).
mod bindings {
    wit_bindgen_wrpc::generate!({
        with: {
            "wrpc:keyvalue/atomics@0.2.0-draft": generate,
            "wrpc:keyvalue/batch@0.2.0-draft": generate,
            "wrpc:keyvalue/store@0.2.0-draft": generate,
            "wrpc:keyvalue/watcher@0.2.0-draft": generate,
        }
    });
}
43use bindings::exports::wrpc::keyvalue;
44use wit_bindgen_wrpc::futures::StreamExt;
45
/// Default URL to use to connect to Redis
const DEFAULT_CONNECT_URL: &str = "redis://127.0.0.1:6379/";

/// Key holding the Redis connection URL (matched case-insensitively in link secrets or config)
const CONFIG_REDIS_URL_KEY: &str = "URL";

// NOTE(review): the backend tuning keys/defaults below are consumed by
// `build_connection_mgr_config` (defined outside this view); their descriptions follow the key
// names — confirm against that function.

/// Key that configures the number of reconnection retries to the Redis backend
const CONFIG_REDIS_BACKEND_RECONNECT_NUM_RETRIES_KEY: &str = "BACKEND_RECONNECT_NUM_RETRIES";

/// Default number of reconnection retries to the Redis backend
const DEFAULT_REDIS_BACKEND_RECONNECT_NUM_RETRIES: usize = 3;

/// Key that configures the maximum amount of time (in milliseconds) to wait between
/// reconnection attempts
const CONFIG_REDIS_BACKEND_RECONNECT_MAX_DELAY_MS_KEY: &str = "BACKEND_RECONNECT_MAX_DELAY_MS";

/// Default maximum amount of time (in milliseconds) to wait between reconnection attempts
const DEFAULT_REDIS_BACKEND_RECONNECT_MAX_DELAY_MS: u64 = 300;

/// Key that configures the timeout (in milliseconds) for establishing a connection to the
/// Redis backend
const CONFIG_REDIS_BACKEND_CONNECTION_TIMEOUT_MS_KEY: &str = "BACKEND_CONNECTION_TIMEOUT_MS";

/// Default timeout (in milliseconds) for establishing a connection to the Redis backend
const DEFAULT_REDIS_BACKEND_CONNECTION_TIMEOUT_MS: u64 = 3000;

/// Key that configures how long (in milliseconds) to wait for the Redis backend to respond to
/// a command
const CONFIG_REDIS_BACKEND_RESPONSE_TIMEOUT_MS_KEY: &str = "BACKEND_RESPONSE_TIMEOUT_MS";

/// Default time (in milliseconds) to wait for the Redis backend to respond to a command
const DEFAULT_REDIS_BACKEND_RESPONSE_TIMEOUT_MS: u64 = 1000;

/// Key whose mere presence (compared case-insensitively, value ignored) disables the default
/// connection
const CONFIG_DISABLE_DEFAULT_CONNECTION_KEY: &str = "DISABLE_DEFAULT_CONNECTION";

/// Key whose mere presence (compared case-insensitively, value ignored) enables sharing
/// connections by URL
///
/// When set on links that supply a connection URL, all such links with an identical URL
/// share/reuse a single Redis connection instead of each opening their own.
const CONFIG_SHARE_CONNECTIONS_BY_URL_KEY: &str = "SHARE_CONNECTIONS_BY_URL";
84
/// Result type used for the inner (component-facing) result of the keyvalue handlers; the error
/// type defaults to `keyvalue::store::Error`.
type Result<T, E = keyvalue::store::Error> = core::result::Result<T, E>;

/// The default connection available for the redis client
///
/// The provider is constructed holding `ClientConfig` (unless the default connection is
/// disabled). `KvRedisProvider::get_default_connection` replaces it with `Conn` the first time
/// the default connection is requested, so the connection itself is created lazily.
#[derive(Clone)]
pub enum DefaultConnection {
    /// Client configuration (config and, when built from host data, secrets) from which the
    /// default connection URL is derived
    ClientConfig {
        config: HashMap<String, String>,
        secrets: Option<HashMap<String, SecretValue>>,
    },
    /// An already-initialized connection
    Conn(ConnectionManager),
}
101
/// A single watch registration for a key
#[derive(Clone, PartialEq, Eq, Hash)]
struct WatchedKeyInfo {
    /// Which keyspace event this registration wants to be notified about
    event_type: WatchEventType,
    // NOTE(review): presumably the id of the component that registered the watch (entries are
    // produced by `parse_watch_config`, defined outside this view) — confirm.
    target: String,
}

/// Keyspace events a watcher can register for: in the notification task, `Set` entries trigger
/// `on_set` and `Delete` entries trigger `on_delete`
#[derive(Clone, PartialEq, Eq, Hash)]
enum WatchEventType {
    Set,
    Delete,
}
/// Represents a unique identifier for a link (target_id, link_name)
#[derive(Eq, Hash, PartialEq)]
struct LinkId {
    pub target_id: String,
    pub link_name: String,
}

/// Type for storing the background watch tasks spawned for each link, keyed by [`LinkId`]
type WatchTaskMap = HashMap<LinkId, JoinHandle<()>>;

/// Shared connection keys are keys that identify shared connections
///
/// This is the SHA-256 digest of the connection URL, formatted as uppercase hex.
///
/// Note this key should *not* be the URL of the connection directly,
/// to avoid printing it inadvertently.
type SharedConnectionKey = String;
130
131/// URL of a redis connection
132#[derive(Clone)]
133enum RedisConnection {
134    /// Direct connection
135    Direct(ConnectionManager),
136    /// Shared connection, identified by the hash of the connection URL
137    Shared(String),
138}
139
/// Redis `wrpc:keyvalue` provider implementation.
///
/// Every field is an `Arc` (or `Option<Arc<_>>`), so clones are cheap and share the same state.
#[derive(Clone)]
pub struct KvRedisProvider {
    /// Connections used to serve invocations, keyed by (source component id, link name)
    sources: Arc<RwLock<HashMap<(String, String), RedisConnection>>>,

    /// Connections shared across links, keyed by the hashed connection URL
    /// (see [`SharedConnectionKey`])
    shared_connections: Arc<RwLock<HashMap<SharedConnectionKey, ConnectionManager>>>,

    /// Default connection, which may be uninitialized; `None` when disabled via the
    /// `DISABLE_DEFAULT_CONNECTION` config key
    default_connection: Option<Arc<RwLock<DefaultConnection>>>,
    /// Stores information about watched keys for keyspace notifications
    /// The outer HashMap uses the key as its key, and the HashSet contains
    /// WatchedKeyInfo structs for each watcher of that key, allowing multiple
    /// components to watch the same key for different event types.
    watched_keys: Arc<RwLock<HashMap<String, HashSet<WatchedKeyInfo>>>>,
    /// Stores background tasks that handle keyspace notifications for each link
    watch_tasks: Arc<RwLock<WatchTaskMap>>,
}
159
160pub async fn run() -> anyhow::Result<()> {
161    KvRedisProvider::run().await
162}
163
164impl KvRedisProvider {
165    pub fn name() -> &'static str {
166        "keyvalue-redis-provider"
167    }
168
169    pub async fn run() -> anyhow::Result<()> {
170        let host_data = load_host_data().context("failed to load host data")?;
171        let flamegraph_path = host_data
172            .config
173            .get("FLAMEGRAPH_PATH")
174            .map(String::from)
175            .or_else(|| std::env::var("PROVIDER_KEYVALUE_REDIS_FLAMEGRAPH_PATH").ok());
176        initialize_observability!(Self::name(), flamegraph_path);
177        let provider = KvRedisProvider::from_host_data(host_data);
178        let shutdown = run_provider(provider.clone(), KvRedisProvider::name())
179            .await
180            .context("failed to run provider")?;
181        let connection = get_connection();
182        let wrpc = connection
183            .get_wrpc_client(connection.provider_key())
184            .await?;
185        serve_provider_exports(&wrpc, provider, shutdown, bindings::serve)
186            .await
187            .context("failed to serve provider exports")
188    }
189
190    #[must_use]
191    pub fn from_config(config: HashMap<String, String>) -> Self {
192        let default_connection_disabled = config
193            .keys()
194            .any(|k| k.eq_ignore_ascii_case(CONFIG_DISABLE_DEFAULT_CONNECTION_KEY));
195
196        KvRedisProvider {
197            sources: Arc::default(),
198            default_connection: if default_connection_disabled {
199                None
200            } else {
201                Some(Arc::new(RwLock::new(DefaultConnection::ClientConfig {
202                    config,
203                    secrets: None,
204                })))
205            },
206            shared_connections: Arc::new(RwLock::new(HashMap::new())),
207            watched_keys: Arc::new(RwLock::new(HashMap::new())),
208            watch_tasks: Arc::new(RwLock::new(HashMap::new())),
209        }
210    }
211
212    #[must_use]
213    pub fn from_host_data(host_data: &HostData) -> Self {
214        let default_connection_disabled = host_data
215            .config
216            .keys()
217            .any(|k| k.eq_ignore_ascii_case(CONFIG_DISABLE_DEFAULT_CONNECTION_KEY));
218
219        KvRedisProvider {
220            sources: Arc::default(),
221            default_connection: if default_connection_disabled {
222                None
223            } else {
224                Some(Arc::new(RwLock::new(DefaultConnection::ClientConfig {
225                    config: host_data.config.clone(),
226                    secrets: Some(host_data.secrets.clone()),
227                })))
228            },
229            shared_connections: Arc::new(RwLock::new(HashMap::new())),
230            watched_keys: Arc::new(RwLock::new(HashMap::new())),
231            watch_tasks: Arc::new(RwLock::new(HashMap::new())),
232        }
233    }
234
235    #[instrument(level = "trace", skip_all)]
236    async fn get_default_connection(&self) -> anyhow::Result<ConnectionManager> {
237        let Some(ref default_connection) = self.default_connection else {
238            bail!("default connection is disabled via config, please provide valid configuration");
239        };
240
241        // NOTE: The read lock is only held for the duration of the `if let` block so we can acquire
242        // the write lock to update the default connection if needed.
243        if let DefaultConnection::Conn(conn) = &*default_connection.read().await {
244            return Ok(conn.clone());
245        }
246
247        // Build the default connection
248        let mut default_conn = default_connection.write().await;
249        match &mut *default_conn {
250            DefaultConnection::Conn(conn) => Ok(conn.clone()),
251            DefaultConnection::ClientConfig { config, secrets } => {
252                let conn = redis::Client::open(retrieve_default_url(config, secrets))
253                    .context("failed to construct default Redis client")?
254                    .get_connection_manager()
255                    .await
256                    .context("failed to construct Redis connection manager")?;
257                *default_conn = DefaultConnection::Conn(conn.clone());
258                Ok(conn)
259            }
260        }
261    }
262
263    #[instrument(level = "debug", skip(self))]
264    async fn invocation_conn(&self, context: Option<Context>) -> anyhow::Result<ConnectionManager> {
265        let ctx = context.context("unexpectedly missing context")?;
266
267        let Some(ref source_id) = ctx.component else {
268            return self.get_default_connection().await.map_err(|err| {
269                error!(error = ?err, "failed to get default connection for invocation");
270                err
271            });
272        };
273
274        let sources = self.sources.read().await;
275        let Some(conn) = sources.get(&(source_id.into(), ctx.link_name().into())) else {
276            error!(source_id, "no Redis connection found for component");
277            bail!("No Redis connection found for component [{source_id}]. Please ensure the URL supplied in the link definition is a valid Redis URL")
278        };
279
280        // Resolve the connection as a direct or shared one
281        match conn {
282            RedisConnection::Direct(c) => Ok(c.clone()),
283            RedisConnection::Shared(key) => {
284                let shared = self.shared_connections.read().await;
285                match shared.get(key) {
286                    Some(c) => Ok(c.clone()),
287                    None => {
288                        error!(key, "no shared Redis connection found with given key");
289                        bail!("No shared Redis connection found with key [{key}]");
290                    }
291                }
292            }
293        }
294    }
295
296    /// Execute Redis async command
297    #[instrument(level = "debug", skip(self, context, cmd))]
298    async fn exec_cmd<T: FromRedisValue>(
299        &self,
300        context: Option<Context>,
301        cmd: &mut Cmd,
302    ) -> Result<T, keyvalue::store::Error> {
303        let mut conn = self
304            .invocation_conn(context)
305            .await
306            .map_err(|err| keyvalue::store::Error::Other(format!("{err:#}")))?;
307        match cmd.query_async(&mut conn).await {
308            Ok(v) => Ok(v),
309            Err(e) => {
310                error!("failed to execute Redis command: {e}");
311                Err(keyvalue::store::Error::Other(format!(
312                    "failed to execute Redis command: {e}"
313                )))
314            }
315        }
316    }
317}
318#[instrument(level = "info", skip(wrpc))]
319async fn invoke_on_set(wrpc: &WrpcClient, bucket: &str, key: &str, value: &Bytes) {
320    let mut cx: async_nats::HeaderMap = async_nats::HeaderMap::new();
321    for (k, v) in
322        wasmcloud_provider_sdk::wasmcloud_tracing::context::TraceContextInjector::default_with_span(
323        )
324        .iter()
325    {
326        cx.insert(k.as_str(), v.as_str())
327    }
328    match bindings::wrpc::keyvalue::watcher::on_set(wrpc, Some(cx), bucket, key, value).await {
329        Ok(_) => {
330            debug!("successfully invoked on_set");
331        }
332        Err(err) => {
333            error!(?err, "failed to invoke on_set");
334        }
335    }
336    debug!("key set");
337}
338#[instrument(level = "info", skip(wrpc))]
339async fn invoke_on_delete(wrpc: &WrpcClient, bucket: &str, key: &str) {
340    let mut cx: async_nats::HeaderMap = async_nats::HeaderMap::new();
341    for (k, v) in
342        wasmcloud_provider_sdk::wasmcloud_tracing::context::TraceContextInjector::default_with_span(
343        )
344        .iter()
345    {
346        cx.insert(k.as_str(), v.as_str())
347    }
348    match bindings::wrpc::keyvalue::watcher::on_delete(wrpc, Some(cx), bucket, key).await {
349        Ok(_) => {
350            debug!("successfully invoked on_delete");
351        }
352        Err(err) => {
353            error!(?err, "failed to invoke on_delete");
354        }
355    }
356    debug!("key deleted");
357}
358
359impl keyvalue::store::Handler<Option<Context>> for KvRedisProvider {
360    #[instrument(level = "debug", skip(self))]
361    async fn delete(
362        &self,
363        context: Option<Context>,
364        bucket: String,
365        key: String,
366    ) -> anyhow::Result<Result<()>> {
367        propagate_trace_for_ctx!(context);
368        check_bucket_name(&bucket);
369        Ok(self.exec_cmd(context, &mut Cmd::del(key)).await)
370    }
371
372    #[instrument(level = "debug", skip(self))]
373    async fn exists(
374        &self,
375        context: Option<Context>,
376        bucket: String,
377        key: String,
378    ) -> anyhow::Result<Result<bool>> {
379        propagate_trace_for_ctx!(context);
380        check_bucket_name(&bucket);
381        Ok(self.exec_cmd(context, &mut Cmd::exists(key)).await)
382    }
383
384    #[instrument(level = "debug", skip(self))]
385    async fn get(
386        &self,
387        context: Option<Context>,
388        bucket: String,
389        key: String,
390    ) -> anyhow::Result<Result<Option<Bytes>>> {
391        propagate_trace_for_ctx!(context);
392        check_bucket_name(&bucket);
393        match self
394            .exec_cmd::<redis::Value>(context, &mut Cmd::get(key))
395            .await
396        {
397            Ok(redis::Value::Nil) => Ok(Ok(None)),
398            Ok(redis::Value::BulkString(buf)) => Ok(Ok(Some(buf.into()))),
399            Ok(_) => Ok(Err(keyvalue::store::Error::Other(
400                "invalid data type returned by Redis".into(),
401            ))),
402            Err(err) => Ok(Err(err)),
403        }
404    }
405
406    #[instrument(level = "debug", skip(self))]
407    async fn set(
408        &self,
409        context: Option<Context>,
410        bucket: String,
411        key: String,
412        value: Bytes,
413    ) -> anyhow::Result<Result<()>> {
414        propagate_trace_for_ctx!(context);
415        check_bucket_name(&bucket);
416        Ok(self
417            .exec_cmd(context, &mut Cmd::set(key, value.to_vec()))
418            .await)
419    }
420
421    #[instrument(level = "debug", skip(self))]
422    async fn list_keys(
423        &self,
424        context: Option<Context>,
425        bucket: String,
426        cursor: Option<u64>,
427    ) -> anyhow::Result<Result<keyvalue::store::KeyResponse>> {
428        propagate_trace_for_ctx!(context);
429        check_bucket_name(&bucket);
430        match self
431            .exec_cmd(
432                context,
433                redis::cmd("SCAN").cursor_arg(cursor.unwrap_or_default()),
434            )
435            .await
436        {
437            Ok((cursor, keys)) => Ok(Ok(keyvalue::store::KeyResponse {
438                keys,
439                cursor: NonZeroU64::new(cursor).map(Into::into),
440            })),
441            Err(err) => Ok(Err(err)),
442        }
443    }
444}
445
446impl keyvalue::atomics::Handler<Option<Context>> for KvRedisProvider {
447    /// Increments a numeric value, returning the new value
448    #[instrument(level = "debug", skip(self))]
449    async fn increment(
450        &self,
451        context: Option<Context>,
452        bucket: String,
453        key: String,
454        delta: u64,
455    ) -> anyhow::Result<Result<u64, keyvalue::store::Error>> {
456        propagate_trace_for_ctx!(context);
457        check_bucket_name(&bucket);
458        Ok(self
459            .exec_cmd::<u64>(context, &mut Cmd::incr(key, delta))
460            .await)
461    }
462}
463
464impl keyvalue::batch::Handler<Option<Context>> for KvRedisProvider {
465    async fn get_many(
466        &self,
467        ctx: Option<Context>,
468        bucket: String,
469        keys: Vec<String>,
470    ) -> anyhow::Result<Result<Vec<Option<(String, Bytes)>>>> {
471        check_bucket_name(&bucket);
472        let data = match self
473            .exec_cmd::<Vec<Option<Bytes>>>(ctx, &mut Cmd::mget(&keys))
474            .await
475        {
476            Ok(v) => v
477                .into_iter()
478                .zip(keys.into_iter())
479                .map(|(val, key)| val.map(|b| (key, b)))
480                .collect::<Vec<_>>(),
481            Err(err) => {
482                return Ok(Err(err));
483            }
484        };
485        Ok(Ok(data))
486    }
487
488    async fn set_many(
489        &self,
490        ctx: Option<Context>,
491        bucket: String,
492        items: Vec<(String, Bytes)>,
493    ) -> anyhow::Result<Result<()>> {
494        check_bucket_name(&bucket);
495        let items = items
496            .into_iter()
497            .map(|(name, buf)| (name, buf.to_vec()))
498            .collect::<Vec<_>>();
499        Ok(self.exec_cmd(ctx, &mut Cmd::mset(&items)).await)
500    }
501
502    async fn delete_many(
503        &self,
504        ctx: Option<Context>,
505        bucket: String,
506        keys: Vec<String>,
507    ) -> anyhow::Result<Result<()>> {
508        check_bucket_name(&bucket);
509        Ok(self.exec_cmd(ctx, &mut Cmd::del(keys)).await)
510    }
511}
512
513/// Handle provider control commands
514impl Provider for KvRedisProvider {
515    /// Provider should perform any operations needed for a new link,
516    /// including setting up per-component resources, and checking authorization.
517    /// If the link is allowed, return true, otherwise return false to deny the link.
518    #[instrument(level = "debug", skip(self, config))]
519    async fn receive_link_config_as_target(
520        &self,
521        LinkConfig {
522            source_id,
523            config,
524            secrets,
525            link_name,
526            ..
527        }: LinkConfig<'_>,
528    ) -> anyhow::Result<()> {
529        let url = secrets
530            .keys()
531            .find(|k| k.eq_ignore_ascii_case(CONFIG_REDIS_URL_KEY))
532            .and_then(|url_key| config.get(url_key))
533            .or_else(|| {
534                warn!("redis connection URLs can be sensitive. Please consider using secrets to pass this value");
535                config
536                    .keys()
537                    .find(|k| k.eq_ignore_ascii_case(CONFIG_REDIS_URL_KEY))
538                    .and_then(|url_key| config.get(url_key))
539            });
540
541        let default_connection_disabled = secrets
542            .keys()
543            .any(|k| k.eq_ignore_ascii_case(CONFIG_DISABLE_DEFAULT_CONNECTION_KEY))
544            || config
545                .keys()
546                .any(|k| k.eq_ignore_ascii_case(CONFIG_DISABLE_DEFAULT_CONNECTION_KEY));
547
548        let share_connections_by_url = secrets
549            .keys()
550            .any(|k| k.eq_ignore_ascii_case(CONFIG_SHARE_CONNECTIONS_BY_URL_KEY))
551            || config
552                .keys()
553                .any(|k| k.eq_ignore_ascii_case(CONFIG_SHARE_CONNECTIONS_BY_URL_KEY));
554
555        let key = (source_id.to_string(), link_name.to_string());
556
557        // If the shared connection is already present with the given URL (hashed)
558        // make the association and exit early.
559        {
560            if let (Some(url), true) = (url, share_connections_by_url) {
561                let shared_connections = self.shared_connections.read().await;
562                let shared_key = format!("{:X}", Sha256::digest(url));
563                if shared_connections.contains_key(&shared_key) {
564                    // SAFETY: shared_connections should always be locked first
565                    let mut sources = self.sources.write().await;
566                    sources.insert(key, RedisConnection::Shared(shared_key));
567                    return Ok(());
568                }
569            }
570        }
571
572        // Create initial configuration for the connection that is intended to fail fast
573        let cfg = build_connection_mgr_config(config);
574        let conn = if let Some(url) = url {
575            match redis::Client::open(url.to_string()) {
576                Ok(client) => match ConnectionManager::new_with_config(client, cfg).await {
577                    Ok(conn) => {
578                        info!(url, "established link");
579                        conn
580                    }
581                    Err(err) => {
582                        warn!(
583                            url,
584                            ?err,
585                        "Could not create Redis connection manager for source [{source_id}], keyvalue operations will fail",
586                    );
587                        bail!("failed to create redis connection manager");
588                    }
589                },
590                Err(err) => {
591                    warn!(
592                        ?err,
593                        "Could not create Redis client for source [{source_id}], keyvalue operations will fail",
594                    );
595                    bail!("failed to create redis client");
596                }
597            }
598        } else {
599            // Disallow default connections if disabled via link config
600            if default_connection_disabled {
601                error!(
602                    component = source_id,
603                    "using the default connection is disabled via link configuration"
604                );
605                bail!(
606                    "using the default connection is disabled via link configuration for component [{source_id}]"
607                );
608            }
609
610            self.get_default_connection().await.map_err(|err| {
611                error!(error = ?err, "failed to get default connection for link");
612                err
613            })?
614        };
615
616        match (url, share_connections_by_url) {
617            // If there was a URL (non-default connection) and connections should be shared by URL,
618            // update both shared connections and sources
619            (Some(url), true) => {
620                let shared_key = format!("{:X}", Sha256::digest(url));
621
622                // SAFETY: shared_connections should always be locked first
623                let mut shared_connections = self.shared_connections.write().await;
624                shared_connections.insert(shared_key.clone(), conn);
625                drop(shared_connections);
626
627                let mut sources = self.sources.write().await;
628                sources.insert(key, RedisConnection::Shared(shared_key));
629                drop(sources);
630            }
631            // In the case of a default connection in use (implicitly shared) or if share connections is turned off,
632            // save the direct connection.
633            _ => {
634                let mut sources = self.sources.write().await;
635                sources.insert(key, RedisConnection::Direct(conn));
636                drop(sources);
637            }
638        }
639
640        Ok(())
641    }
642
643    async fn receive_link_config_as_source(
644        &self,
645        LinkConfig {
646            target_id,
647            config,
648            secrets,
649            link_name,
650            wit_metadata: (_, _, interfaces),
651            ..
652        }: LinkConfig<'_>,
653    ) -> anyhow::Result<()> {
654        let url = secrets
655            .keys()
656            .find(|k| k.eq_ignore_ascii_case(CONFIG_REDIS_URL_KEY))
657            .and_then(|url_key| config.get(url_key))
658            .or_else(|| {
659                warn!("Redis connection URLs can be sensitive. Consider using secrets to pass this value.");
660                config.keys()
661                    .find(|k| k.eq_ignore_ascii_case(CONFIG_REDIS_URL_KEY))
662                    .and_then(|url_key| config.get(url_key))
663            })
664            .map_or(DEFAULT_CONNECT_URL, |v| v);
665
666        let client = match redis::Client::open(url.to_string()) {
667            Ok(client) => {
668                info!(url, "Established link at receive_link_config_as_source");
669                client
670            }
671            Err(err) => {
672                warn!(target_id = %target_id, err = ?err, "Failed to create Redis client");
673                bail!("Failed to create Redis client");
674            }
675        };
676        let mut conn = client.get_connection_manager().await.map_err(|e| {
677            error!(err = ?e, "Failed to get async connection");
678            anyhow::anyhow!("Failed to get async connection: {}", e)
679        })?;
680
681        let component_id: Arc<str> = target_id.into();
682        let wrpc = get_connection()
683            .get_wrpc_client(&component_id)
684            .await
685            .context("failed to construct wRPC client")?;
686        if interfaces.contains(&"watcher".to_string()) {
687            let config_response: Vec<String> = redis::cmd("CONFIG")
688                .arg("GET")
689                .arg("notify-keyspace-events")
690                .query_async(&mut conn)
691                .await
692                .map_err(|e| {
693                    error!(err = %e, "Failed to get keyspace notifications config");
694                    anyhow::anyhow!("Failed to get keyspace notifications config: {}", e)
695                })?;
696
697            let current_config = config_response.get(1).ok_or_else(|| {
698                error!("Unexpected response format from Redis CONFIG GET");
699                anyhow::anyhow!("Unexpected response format from Redis CONFIG GET")
700            })?;
701
702            if !current_config.contains('K')
703                || !current_config.contains('$')
704                || !current_config.contains('g')
705            {
706                error!(
707                    current_config = %current_config,
708                    "Redis keyspace-notifications not properly configured"
709                );
710                return Err(anyhow::anyhow!(
711                    "Redis keyspace-notifications not properly configured! \
712                        Expected 'K$g' in settings, but got '{}'. \
713                        Please run: CONFIG SET notify-keyspace-events K$g",
714                    current_config
715                ));
716            }
717
718            let wrpc = Arc::new(wrpc);
719            let wrpc_for_task = wrpc.clone();
720
721            let config_watch_entries = parse_watch_config(config, target_id);
722
723            // Update watched keys
724            let mut watched_keys = self.watched_keys.write().await;
725            for (key, key_info_set) in config_watch_entries {
726                watched_keys
727                    .entry(key)
728                    .or_insert_with(HashSet::new)
729                    .extend(key_info_set);
730            }
731
732            let client_clone = client.clone();
733            let self_clone = self.clone();
734            let mut conn_clone = conn.clone();
735            let task = tokio::spawn(async move {
736                let mut pubsub = match client_clone.get_async_pubsub().await {
737                    Ok(pubsub) => pubsub,
738                    Err(e) => {
739                        error!(err = %e, "Failed to get pubsub connection");
740                        return;
741                    }
742                };
743                let watched_keys = self_clone.watched_keys.read().await;
744                for key in watched_keys.keys() {
745                    let channel = format!("__keyspace@0__:{key}");
746                    let _ = pubsub
747                        .psubscribe(&channel)
748                        .await
749                        .context("Failed to subscribe to SET/DEL events for key");
750                }
751                let stream = pubsub.on_message();
752                tokio::pin!(stream);
753                while let Some(msg) = stream.next().await {
754                    let channel: String = msg.get_channel_name().to_string();
755                    let event: String = match msg.get_payload() {
756                        Ok(event) => event,
757                        Err(e) => {
758                            error!(err = %e, "Failed to get payload");
759                            continue;
760                        }
761                    };
762                    // The Channel is in the format __keyspace@0__:key
763                    // While the payload is the event (ie set | del)
764                    let mkey = match channel.split(':').next_back() {
765                        Some(key) => key,
766                        None => {
767                            error!(channel = %channel, "Malformed Redis channel name: expected '__keyspace@0__:key' format");
768                            continue;
769                        }
770                    };
771                    // Check if the key is being watched by any component
772                    let watched_keys = self_clone.watched_keys.read().await;
773                    if let Some(key_info_set) = watched_keys.get(mkey) {
774                        if event == "set" || event == "SET" {
775                            // Perform a GET operation to retrieve the current value of the key since redis doesn't have a
776                            // native way to get the value of the key from the notification
777                            let value: wit_bindgen_wrpc::bytes::Bytes = match redis::cmd("GET")
778                                .arg(mkey)
779                                .query_async::<Option<Vec<u8>>>(&mut conn_clone)
780                                .await
781                            {
782                                Ok(Some(v)) => v.into(),
783                                Ok(None) => {
784                                    debug!(key = %mkey, "Key not found or was deleted");
785                                    continue;
786                                }
787                                Err(e) => {
788                                    error!(key = %mkey, err = %e, "Failed to get value for key");
789                                    continue;
790                                }
791                            };
792                            for key_info in key_info_set {
793                                if key_info.event_type == WatchEventType::Set {
794                                    invoke_on_set(&wrpc_for_task, "0", mkey, &value).await;
795                                }
796                            }
797                        } else if event == "del" || event == "DEL" {
798                            for key_info in key_info_set {
799                                if key_info.event_type == WatchEventType::Delete {
800                                    invoke_on_delete(&wrpc_for_task, "0", mkey).await;
801                                }
802                            }
803                        }
804                    }
805                }
806            });
807            let mut tasks = self.watch_tasks.write().await;
808            tasks.insert(
809                LinkId {
810                    target_id: target_id.to_string(),
811                    link_name: link_name.to_string(),
812                },
813                task,
814            );
815        }
816
817        let mut sources = self.sources.write().await;
818        sources.insert(
819            (target_id.to_string(), link_name.to_string()),
820            RedisConnection::Direct(conn),
821        );
822
823        Ok(())
824    }
825
826    /// Handle notification that a link is dropped - close the connection
827    #[instrument(level = "info", skip_all, fields(source_id = info.get_source_id()))]
828    async fn delete_link_as_target(&self, info: impl LinkDeleteInfo) -> anyhow::Result<()> {
829        let component_id = info.get_source_id();
830        let mut aw = self.sources.write().await;
831        // NOTE: ideally we should *not* get rid of all links for a given source here,
832        // but delete_link actually does not tell us enough about the link to know whether
833        // we're dealing with one link or the other.
834        aw.retain(|(src_id, _link_name), _| src_id != component_id);
835        debug!(component_id, "closing all redis connections for component");
836        Ok(())
837    }
838
839    #[instrument(level = "info", skip_all, fields(target_id = info.get_target_id()))]
840    async fn delete_link_as_source(&self, info: impl LinkDeleteInfo) -> anyhow::Result<()> {
841        let component_id = info.get_target_id();
842        let link_name = info.get_link_name();
843
844        let mut sources = self.sources.write().await;
845        sources.remove(&(component_id.to_string(), link_name.to_string()));
846
847        let mut watch_tasks = self.watch_tasks.write().await;
848
849        // If there's a watch task for this link, abort it and remove from map
850        if let Some(task) = watch_tasks.remove(&LinkId {
851            target_id: component_id.to_string(),
852            link_name: link_name.to_string(),
853        }) {
854            task.abort();
855            let _ = task.await;
856        }
857
858        // Clean up watched keys for this target
859        let mut watched_keys = self.watched_keys.write().await;
860        for key_watchers in watched_keys.values_mut() {
861            key_watchers.retain(|key_info| key_info.target != component_id);
862        }
863
864        // Remove any empty watch sets
865        watched_keys.retain(|_, watchers| !watchers.is_empty());
866
867        debug!(
868            component_id,
869            link_name, "cleaned up redis connection and watch tasks for link"
870        );
871        Ok(())
872    }
873
874    /// Handle shutdown request by closing all connections
875    async fn shutdown(&self) -> anyhow::Result<()> {
876        info!("shutting down");
877        let mut aw = self.sources.write().await;
878        // empty the component link data and stop all servers
879        for (_, conn) in aw.drain() {
880            drop(conn);
881        }
882        Ok(())
883    }
884}
885
886/// Fetch the default URL to use for connecting to Redis from the configuration, defaulting
887/// to `DEFAULT_CONNECT_URL` if no URL is found in the configuration.
888fn retrieve_default_url(
889    config: &HashMap<String, String>,
890    secrets: &Option<HashMap<String, SecretValue>>,
891) -> String {
892    // Use connect URL provided by secrets first, if present
893    if let Some(secrets) = secrets {
894        if let Some(url) = secrets
895            .keys()
896            .find(|sk| sk.eq_ignore_ascii_case(CONFIG_REDIS_URL_KEY))
897            .and_then(|k| secrets.get(k))
898        {
899            if let Some(s) = url.as_string() {
900                debug!(
901                    url = ?url, // NOTE: this is the SecretValue redacted output
902                    "using Redis URL from secrets"
903                );
904                return s.into();
905            } else {
906                warn!("invalid secret value for URL (expected string, found bytes). Falling back to config");
907            }
908        }
909    }
910
911    // To aid in user experience, find the URL key in the config that matches "URL" in a case-insensitive manner
912    let config_supplied_url = config
913        .keys()
914        .find(|k| k.eq_ignore_ascii_case(CONFIG_REDIS_URL_KEY))
915        .and_then(|url_key| config.get(url_key));
916
917    if let Some(url) = config_supplied_url {
918        debug!(url, "using Redis URL from config");
919        url.to_string()
920    } else {
921        debug!(DEFAULT_CONNECT_URL, "using default Redis URL");
922        DEFAULT_CONNECT_URL.to_string()
923    }
924}
925/// Parse watch configuration from the link configuration and return watch entries
926///
927/// Watch configuration is expected in the format "SET@key,DEL@key" where:
928/// - SET: Watch for set operations on the specified key
929/// - DEL: Watch for delete operations on the specified key
930///
931/// Returns a map of keys to sets of WatchedKeyInfo indicating which operations to watch for each key
932#[instrument(level = "debug", skip(config))]
933fn parse_watch_config(
934    config: &HashMap<String, String>,
935    target_id: &str,
936) -> HashMap<String, HashSet<WatchedKeyInfo>> {
937    let mut watched_keys = HashMap::new();
938
939    // Convert config keys to case-insensitive map
940    let config_map: HashMap<UniCase<&str>, &String> = config
941        .iter()
942        .map(|(k, v)| (UniCase::new(k.as_str()), v))
943        .collect();
944
945    // Look for watch configuration in the format "watch: SET@key,DEL@key"
946    if let Some(watch_config) = config_map.get(&UniCase::new("watch")) {
947        for watch_entry in watch_config.split(',') {
948            let watch_entry = watch_entry.trim();
949            if watch_entry.is_empty() {
950                continue;
951            }
952
953            let parts: Vec<&str> = watch_entry.split('@').collect();
954            if parts.len() != 2 {
955                error!(watch_entry = %watch_entry, "Invalid watch entry format. Expected FORMAT@KEY");
956                continue;
957            }
958
959            let operation = parts[0].trim().to_uppercase();
960            let key_value = parts[1].trim();
961
962            if key_value.contains(':') {
963                error!(key = %key_value, "Invalid SET watch format. SET expects only KEY");
964                continue;
965            }
966            if key_value.is_empty() {
967                error!(watch_entry = %watch_entry, "Invalid watch entry: Missing key.");
968                continue;
969            }
970
971            match operation.as_str() {
972                "SET" => {
973                    watched_keys
974                        .entry(key_value.to_string())
975                        .or_insert_with(HashSet::new)
976                        .insert(WatchedKeyInfo {
977                            event_type: WatchEventType::Set,
978                            target: target_id.to_string(),
979                        });
980                }
981                "DEL" => {
982                    watched_keys
983                        .entry(key_value.to_string())
984                        .or_insert_with(HashSet::new)
985                        .insert(WatchedKeyInfo {
986                            event_type: WatchEventType::Delete,
987                            target: target_id.to_string(),
988                        });
989                }
990                _ => {
991                    error!(operation = %operation, "Unsupported watch operation. Expected SET or DEL");
992                }
993            }
994        }
995    }
996
997    watched_keys
998}
999
1000/// Check for unsupported bucket names,
1001/// primarily warning on non-empty bucket names, since this provider does not yet properly support named buckets
1002fn check_bucket_name(bucket: &str) {
1003    if !bucket.is_empty() {
1004        warn!(bucket, "non-empty bucket names are not yet supported; ignoring non-empty bucket name (using a non-empty bucket name may become an error in the future).")
1005    }
1006}
1007
1008/// Build configuration for a backend redis connection from existing config
1009fn build_connection_mgr_config(config: &HashMap<String, String>) -> ConnectionManagerConfig {
1010    let mut cfg = ConnectionManagerConfig::new();
1011
1012    // Set default values for the connection manager configuration
1013    cfg = cfg
1014        .set_number_of_retries(DEFAULT_REDIS_BACKEND_RECONNECT_NUM_RETRIES)
1015        .set_max_delay(DEFAULT_REDIS_BACKEND_RECONNECT_MAX_DELAY_MS)
1016        .set_connection_timeout(Duration::from_millis(
1017            DEFAULT_REDIS_BACKEND_CONNECTION_TIMEOUT_MS,
1018        ))
1019        .set_response_timeout(Duration::from_millis(
1020            DEFAULT_REDIS_BACKEND_RESPONSE_TIMEOUT_MS,
1021        ));
1022
1023    // Override defaults with values from the config if they are present
1024    for (k, v) in config.iter() {
1025        if k.eq_ignore_ascii_case(CONFIG_REDIS_BACKEND_RECONNECT_NUM_RETRIES_KEY) {
1026            if let Ok(val) = v.parse::<usize>() {
1027                cfg = cfg.set_number_of_retries(val);
1028            } else {
1029                warn!(
1030                    key = %CONFIG_REDIS_BACKEND_RECONNECT_NUM_RETRIES_KEY,
1031                    value = %v,
1032                    "Invalid value for number of retries, using default"
1033                );
1034            }
1035        }
1036
1037        if let Some(max_delay) = if k
1038            .eq_ignore_ascii_case(CONFIG_REDIS_BACKEND_RECONNECT_MAX_DELAY_MS_KEY)
1039        {
1040            match v.parse() {
1041                Ok(val) => Some(val),
1042                Err(_) => {
1043                    warn!(key = %CONFIG_REDIS_BACKEND_RECONNECT_MAX_DELAY_MS_KEY, value = %v, "Invalid value for max delay, using default");
1044                    Some(DEFAULT_REDIS_BACKEND_RECONNECT_MAX_DELAY_MS)
1045                }
1046            }
1047        } else {
1048            None
1049        } {
1050            cfg = cfg.set_max_delay(max_delay);
1051        }
1052
1053        if let Some(timeout) = if k
1054            .eq_ignore_ascii_case(CONFIG_REDIS_BACKEND_CONNECTION_TIMEOUT_MS_KEY)
1055        {
1056            match v.parse() {
1057                Ok(val) => Some(val),
1058                Err(_) => {
1059                    warn!(key = %CONFIG_REDIS_BACKEND_CONNECTION_TIMEOUT_MS_KEY,value = %v,"Invalid value for connection timeout, using default");
1060                    Some(DEFAULT_REDIS_BACKEND_CONNECTION_TIMEOUT_MS)
1061                }
1062            }
1063        } else {
1064            None
1065        } {
1066            cfg = cfg.set_connection_timeout(Duration::from_millis(timeout));
1067        }
1068
1069        if let Some(timeout) = if k
1070            .eq_ignore_ascii_case(CONFIG_REDIS_BACKEND_RESPONSE_TIMEOUT_MS_KEY)
1071        {
1072            match v.parse() {
1073                Ok(val) => Some(val),
1074                Err(_) => {
1075                    warn!(key = %CONFIG_REDIS_BACKEND_RESPONSE_TIMEOUT_MS_KEY,value = %v,"Invalid value for response timeout, using default");
1076                    Some(DEFAULT_REDIS_BACKEND_RESPONSE_TIMEOUT_MS)
1077                }
1078            }
1079        } else {
1080            None
1081        } {
1082            cfg = cfg.set_response_timeout(Duration::from_millis(timeout));
1083        }
1084    }
1085
1086    cfg
1087}
1088
#[cfg(test)]
mod test {
    use super::*;
    use std::collections::HashMap;

    use crate::retrieve_default_url;

    const PROPER_URL: &str = "redis://127.0.0.1:6379";

    #[test]
    fn can_deserialize_config_case_insensitive() {
        let lowercase_config = HashMap::from_iter([("url".to_string(), PROPER_URL.to_string())]);
        let uppercase_config = HashMap::from_iter([("URL".to_string(), PROPER_URL.to_string())]);
        let initial_caps_config = HashMap::from_iter([("Url".to_string(), PROPER_URL.to_string())]);

        assert_eq!(PROPER_URL, retrieve_default_url(&lowercase_config, &None));
        assert_eq!(PROPER_URL, retrieve_default_url(&uppercase_config, &None));
        assert_eq!(
            PROPER_URL,
            retrieve_default_url(&initial_caps_config, &None)
        );
    }

    #[test]
    fn falls_back_to_default_url_when_no_url_configured() {
        let empty_config: HashMap<String, String> = HashMap::new();
        assert_eq!(
            DEFAULT_CONNECT_URL,
            retrieve_default_url(&empty_config, &None)
        );

        // A key that merely starts with "url" must not be mistaken for the URL key
        let unrelated_config =
            HashMap::from_iter([("url2".to_string(), PROPER_URL.to_string())]);
        assert_eq!(
            DEFAULT_CONNECT_URL,
            retrieve_default_url(&unrelated_config, &None)
        );
    }

    #[test]
    fn test_parse_watch_config_valid_entries() {
        let mut config = HashMap::new();
        config.insert(
            "watch".to_string(),
            "SET@key1,DEL@key2,SET@key2".to_string(),
        );
        let target_id = "target_1";

        let result = parse_watch_config(&config, target_id);

        assert_eq!(result.len(), 2);
        assert!(result.contains_key("key1"));
        assert!(result.contains_key("key2"));

        assert!(result["key1"].contains(&WatchedKeyInfo {
            event_type: WatchEventType::Set,
            target: target_id.to_string()
        }));
        assert!(result["key2"].contains(&WatchedKeyInfo {
            event_type: WatchEventType::Delete,
            target: target_id.to_string()
        }));
        assert!(result["key2"].contains(&WatchedKeyInfo {
            event_type: WatchEventType::Set,
            target: target_id.to_string()
        }));
    }

    #[test]
    fn test_parse_watch_config_invalid_entries() {
        let mut config = HashMap::new();
        config.insert(
            "watch".to_string(),
            "INVALID@key1,SET@key2,DEL@key3,SET@key4:extra".to_string(),
        );
        let target_id = "target_2";

        let result = parse_watch_config(&config, target_id);

        assert_eq!(result.len(), 2);
        assert!(result.contains_key("key2"));
        assert!(result.contains_key("key3"));

        assert!(result["key2"].contains(&WatchedKeyInfo {
            event_type: WatchEventType::Set,
            target: target_id.to_string()
        }));
        assert!(result["key3"].contains(&WatchedKeyInfo {
            event_type: WatchEventType::Delete,
            target: target_id.to_string()
        }));
    }

    #[test]
    fn test_parse_watch_config_empty_or_malformed() {
        let mut config = HashMap::new();
        config.insert("watch".to_string(), "SET@,DEL@ , @key5".to_string());
        let target_id = "target_3";

        let result = parse_watch_config(&config, target_id);

        assert!(result.is_empty());
    }

    #[test]
    fn test_parse_watch_config_rejects_colons_and_extra_separators() {
        // ':' is rejected for every operation, not just SET; a second '@' is also invalid
        let config = HashMap::from_iter([(
            "watch".to_string(),
            "DEL@key1:extra,SET@a:b,SET@key2@extra".to_string(),
        )]);

        let result = parse_watch_config(&config, "target_7");

        assert!(result.is_empty());
    }

    #[test]
    fn test_parse_watch_config_trims_whitespace_and_deduplicates() {
        let config = HashMap::from_iter([(
            "watch".to_string(),
            " set @ key1 , SET@key1 ,, del@ key1 ".to_string(),
        )]);
        let target_id = "target_6";

        let result = parse_watch_config(&config, target_id);

        assert_eq!(result.len(), 1);
        assert_eq!(result["key1"].len(), 2);
        assert!(result["key1"].contains(&WatchedKeyInfo {
            event_type: WatchEventType::Set,
            target: target_id.to_string()
        }));
        assert!(result["key1"].contains(&WatchedKeyInfo {
            event_type: WatchEventType::Delete,
            target: target_id.to_string()
        }));
    }

    #[test]
    fn test_parse_watch_config_case_insensitivity() {
        let mut config = HashMap::new();
        config.insert("WATCH".to_string(), "set@key1,del@key2".to_string());
        let target_id = "target_4";

        let result = parse_watch_config(&config, target_id);

        assert_eq!(result.len(), 2);
        assert!(result.contains_key("key1"));
        assert!(result.contains_key("key2"));

        assert!(result["key1"].contains(&WatchedKeyInfo {
            event_type: WatchEventType::Set,
            target: target_id.to_string()
        }));
        assert!(result["key2"].contains(&WatchedKeyInfo {
            event_type: WatchEventType::Delete,
            target: target_id.to_string()
        }));
    }

    #[test]
    fn test_parse_watch_config_no_watch_key() {
        let config = HashMap::new();
        let target_id = "target_5";

        let result = parse_watch_config(&config, target_id);

        assert!(result.is_empty());
    }
}