1use core::num::NonZeroU64;
11
12use std::collections::{HashMap, HashSet};
13use std::sync::Arc;
14use std::time::Duration;
15
16use anyhow::{bail, Context as _};
17use bytes::Bytes;
18use redis::aio::{ConnectionManager, ConnectionManagerConfig};
19use redis::{Cmd, FromRedisValue};
20use sha2::{Digest as _, Sha256};
21use tokio::sync::RwLock;
22use tokio::task::JoinHandle;
23use tracing::{debug, error, info, instrument, warn};
24use unicase::UniCase;
25use wasmcloud_provider_sdk::core::secrets::SecretValue;
26use wasmcloud_provider_sdk::provider::WrpcClient;
27use wasmcloud_provider_sdk::{
28 get_connection, load_host_data, propagate_trace_for_ctx, run_provider, Context, HostData,
29 LinkConfig, LinkDeleteInfo, Provider,
30};
31use wasmcloud_provider_sdk::{initialize_observability, serve_provider_exports};
32
// Generated wRPC bindings for `wrpc:keyvalue`. The `store`, `atomics` and `batch`
// interfaces are served by this provider (via `bindings::exports`), while
// `watcher` is invoked on components (`bindings::wrpc::keyvalue::watcher`).
mod bindings {
    wit_bindgen_wrpc::generate!({
        with: {
            "wrpc:keyvalue/atomics@0.2.0-draft": generate,
            "wrpc:keyvalue/batch@0.2.0-draft": generate,
            "wrpc:keyvalue/store@0.2.0-draft": generate,
            "wrpc:keyvalue/watcher@0.2.0-draft": generate,
        }
    });
}
43use bindings::exports::wrpc::keyvalue;
44use wit_bindgen_wrpc::futures::StreamExt;
45
/// Redis URL used when neither secrets nor config supply one.
const DEFAULT_CONNECT_URL: &str = "redis://127.0.0.1:6379/";

/// Secret/config key (matched ignoring ASCII case) that holds the Redis connection URL.
const CONFIG_REDIS_URL_KEY: &str = "URL";

/// Link config key overriding the connection manager's number of reconnect retries.
const CONFIG_REDIS_BACKEND_RECONNECT_NUM_RETRIES_KEY: &str = "BACKEND_RECONNECT_NUM_RETRIES";

/// Reconnect retry count used when the override above is absent or unparsable.
const DEFAULT_REDIS_BACKEND_RECONNECT_NUM_RETRIES: usize = 3;

/// Link config key overriding the maximum reconnect delay (passed to `set_max_delay`;
/// milliseconds per the `_MS` suffix).
const CONFIG_REDIS_BACKEND_RECONNECT_MAX_DELAY_MS_KEY: &str = "BACKEND_RECONNECT_MAX_DELAY_MS";

/// Maximum reconnect delay used when the override above is absent or unparsable.
const DEFAULT_REDIS_BACKEND_RECONNECT_MAX_DELAY_MS: u64 = 300;

/// Link config key overriding the connection timeout, in milliseconds.
const CONFIG_REDIS_BACKEND_CONNECTION_TIMEOUT_MS_KEY: &str = "BACKEND_CONNECTION_TIMEOUT_MS";

/// Connection timeout (milliseconds) used when the override above is absent or unparsable.
const DEFAULT_REDIS_BACKEND_CONNECTION_TIMEOUT_MS: u64 = 3000;

/// Link config key overriding the response timeout, in milliseconds.
const CONFIG_REDIS_BACKEND_RESPONSE_TIMEOUT_MS_KEY: &str = "BACKEND_RESPONSE_TIMEOUT_MS";

/// Response timeout (milliseconds) used when the override above is absent or unparsable.
const DEFAULT_REDIS_BACKEND_RESPONSE_TIMEOUT_MS: u64 = 1000;

/// The mere presence of this key (any value, any ASCII casing) disables falling back to
/// the provider-wide default connection.
const CONFIG_DISABLE_DEFAULT_CONNECTION_KEY: &str = "DISABLE_DEFAULT_CONNECTION";

/// The mere presence of this key (any value, any ASCII casing) makes target links that
/// use the same URL share a single connection manager.
const CONFIG_SHARE_CONNECTIONS_BY_URL_KEY: &str = "SHARE_CONNECTIONS_BY_URL";
84
/// Result type used by the keyvalue handlers; the error defaults to the wRPC keyvalue
/// store error.
type Result<T, E = keyvalue::store::Error> = core::result::Result<T, E>;

/// Provider-wide fallback Redis connection, created lazily.
///
/// Starts as `ClientConfig` (the inputs used to resolve the URL) and is replaced by
/// `Conn` the first time `get_default_connection` successfully connects.
#[derive(Clone)]
pub enum DefaultConnection {
    /// Not yet connected: config (and, when loaded from host data, secrets) used to
    /// resolve the Redis URL.
    ClientConfig {
        config: HashMap<String, String>,
        secrets: Option<HashMap<String, SecretValue>>,
    },
    /// An established connection manager, cloned out for each use.
    Conn(ConnectionManager),
}
101
/// One watch registration for a Redis key: which event to forward and to which component.
#[derive(Clone, PartialEq, Eq, Hash)]
struct WatchedKeyInfo {
    event_type: WatchEventType,
    /// Id of the component that asked to be notified.
    target: String,
}

/// Keyspace events a component can watch (`SET@key` / `DEL@key` entries in `WATCH` config).
#[derive(Clone, PartialEq, Eq, Hash)]
enum WatchEventType {
    Set,
    Delete,
}

/// Identifies a link by its target component id and link name; keys the watch task map.
#[derive(Eq, Hash, PartialEq)]
struct LinkId {
    pub target_id: String,
    pub link_name: String,
}

/// Background keyspace-notification listener tasks, one per link.
type WatchTaskMap = HashMap<LinkId, JoinHandle<()>>;

/// Key of a shared connection: the uppercase hex SHA-256 digest of its Redis URL.
type SharedConnectionKey = String;

/// Connection registered for a link.
#[derive(Clone)]
enum RedisConnection {
    /// A connection manager used only by this link.
    Direct(ConnectionManager),
    /// Key into `KvRedisProvider::shared_connections`.
    Shared(String),
}
139
/// Redis-backed wasmCloud keyvalue provider.
///
/// All state sits behind `Arc`s, so clones share the same connections and watch state.
#[derive(Clone)]
pub struct KvRedisProvider {
    /// Connection per link, keyed by (component id, link name).
    sources: Arc<RwLock<HashMap<(String, String), RedisConnection>>>,

    /// Connection managers shared by links that use the same URL, keyed by the
    /// uppercase hex SHA-256 digest of that URL.
    shared_connections: Arc<RwLock<HashMap<SharedConnectionKey, ConnectionManager>>>,

    /// Fallback connection; `None` when disabled via `DISABLE_DEFAULT_CONNECTION`.
    default_connection: Option<Arc<RwLock<DefaultConnection>>>,
    /// Watched Redis key -> the (event, target component) registrations for it.
    watched_keys: Arc<RwLock<HashMap<String, HashSet<WatchedKeyInfo>>>>,
    /// Keyspace-notification listener tasks, keyed by link.
    watch_tasks: Arc<RwLock<WatchTaskMap>>,
}
159
160pub async fn run() -> anyhow::Result<()> {
161 KvRedisProvider::run().await
162}
163
164impl KvRedisProvider {
165 pub fn name() -> &'static str {
166 "keyvalue-redis-provider"
167 }
168
169 pub async fn run() -> anyhow::Result<()> {
170 let host_data = load_host_data().context("failed to load host data")?;
171 let flamegraph_path = host_data
172 .config
173 .get("FLAMEGRAPH_PATH")
174 .map(String::from)
175 .or_else(|| std::env::var("PROVIDER_KEYVALUE_REDIS_FLAMEGRAPH_PATH").ok());
176 initialize_observability!(Self::name(), flamegraph_path);
177 let provider = KvRedisProvider::from_host_data(host_data);
178 let shutdown = run_provider(provider.clone(), KvRedisProvider::name())
179 .await
180 .context("failed to run provider")?;
181 let connection = get_connection();
182 let wrpc = connection
183 .get_wrpc_client(connection.provider_key())
184 .await?;
185 serve_provider_exports(&wrpc, provider, shutdown, bindings::serve)
186 .await
187 .context("failed to serve provider exports")
188 }
189
190 #[must_use]
191 pub fn from_config(config: HashMap<String, String>) -> Self {
192 let default_connection_disabled = config
193 .keys()
194 .any(|k| k.eq_ignore_ascii_case(CONFIG_DISABLE_DEFAULT_CONNECTION_KEY));
195
196 KvRedisProvider {
197 sources: Arc::default(),
198 default_connection: if default_connection_disabled {
199 None
200 } else {
201 Some(Arc::new(RwLock::new(DefaultConnection::ClientConfig {
202 config,
203 secrets: None,
204 })))
205 },
206 shared_connections: Arc::new(RwLock::new(HashMap::new())),
207 watched_keys: Arc::new(RwLock::new(HashMap::new())),
208 watch_tasks: Arc::new(RwLock::new(HashMap::new())),
209 }
210 }
211
212 #[must_use]
213 pub fn from_host_data(host_data: &HostData) -> Self {
214 let default_connection_disabled = host_data
215 .config
216 .keys()
217 .any(|k| k.eq_ignore_ascii_case(CONFIG_DISABLE_DEFAULT_CONNECTION_KEY));
218
219 KvRedisProvider {
220 sources: Arc::default(),
221 default_connection: if default_connection_disabled {
222 None
223 } else {
224 Some(Arc::new(RwLock::new(DefaultConnection::ClientConfig {
225 config: host_data.config.clone(),
226 secrets: Some(host_data.secrets.clone()),
227 })))
228 },
229 shared_connections: Arc::new(RwLock::new(HashMap::new())),
230 watched_keys: Arc::new(RwLock::new(HashMap::new())),
231 watch_tasks: Arc::new(RwLock::new(HashMap::new())),
232 }
233 }
234
235 #[instrument(level = "trace", skip_all)]
236 async fn get_default_connection(&self) -> anyhow::Result<ConnectionManager> {
237 let Some(ref default_connection) = self.default_connection else {
238 bail!("default connection is disabled via config, please provide valid configuration");
239 };
240
241 if let DefaultConnection::Conn(conn) = &*default_connection.read().await {
244 return Ok(conn.clone());
245 }
246
247 let mut default_conn = default_connection.write().await;
249 match &mut *default_conn {
250 DefaultConnection::Conn(conn) => Ok(conn.clone()),
251 DefaultConnection::ClientConfig { config, secrets } => {
252 let conn = redis::Client::open(retrieve_default_url(config, secrets))
253 .context("failed to construct default Redis client")?
254 .get_connection_manager()
255 .await
256 .context("failed to construct Redis connection manager")?;
257 *default_conn = DefaultConnection::Conn(conn.clone());
258 Ok(conn)
259 }
260 }
261 }
262
263 #[instrument(level = "debug", skip(self))]
264 async fn invocation_conn(&self, context: Option<Context>) -> anyhow::Result<ConnectionManager> {
265 let ctx = context.context("unexpectedly missing context")?;
266
267 let Some(ref source_id) = ctx.component else {
268 return self.get_default_connection().await.map_err(|err| {
269 error!(error = ?err, "failed to get default connection for invocation");
270 err
271 });
272 };
273
274 let sources = self.sources.read().await;
275 let Some(conn) = sources.get(&(source_id.into(), ctx.link_name().into())) else {
276 error!(source_id, "no Redis connection found for component");
277 bail!("No Redis connection found for component [{source_id}]. Please ensure the URL supplied in the link definition is a valid Redis URL")
278 };
279
280 match conn {
282 RedisConnection::Direct(c) => Ok(c.clone()),
283 RedisConnection::Shared(key) => {
284 let shared = self.shared_connections.read().await;
285 match shared.get(key) {
286 Some(c) => Ok(c.clone()),
287 None => {
288 error!(key, "no shared Redis connection found with given key");
289 bail!("No shared Redis connection found with key [{key}]");
290 }
291 }
292 }
293 }
294 }
295
296 #[instrument(level = "debug", skip(self, context, cmd))]
298 async fn exec_cmd<T: FromRedisValue>(
299 &self,
300 context: Option<Context>,
301 cmd: &mut Cmd,
302 ) -> Result<T, keyvalue::store::Error> {
303 let mut conn = self
304 .invocation_conn(context)
305 .await
306 .map_err(|err| keyvalue::store::Error::Other(format!("{err:#}")))?;
307 match cmd.query_async(&mut conn).await {
308 Ok(v) => Ok(v),
309 Err(e) => {
310 error!("failed to execute Redis command: {e}");
311 Err(keyvalue::store::Error::Other(format!(
312 "failed to execute Redis command: {e}"
313 )))
314 }
315 }
316 }
317}
318#[instrument(level = "info", skip(wrpc))]
319async fn invoke_on_set(wrpc: &WrpcClient, bucket: &str, key: &str, value: &Bytes) {
320 let mut cx: async_nats::HeaderMap = async_nats::HeaderMap::new();
321 for (k, v) in
322 wasmcloud_provider_sdk::wasmcloud_tracing::context::TraceContextInjector::default_with_span(
323 )
324 .iter()
325 {
326 cx.insert(k.as_str(), v.as_str())
327 }
328 match bindings::wrpc::keyvalue::watcher::on_set(wrpc, Some(cx), bucket, key, value).await {
329 Ok(_) => {
330 debug!("successfully invoked on_set");
331 }
332 Err(err) => {
333 error!(?err, "failed to invoke on_set");
334 }
335 }
336 debug!("key set");
337}
338#[instrument(level = "info", skip(wrpc))]
339async fn invoke_on_delete(wrpc: &WrpcClient, bucket: &str, key: &str) {
340 let mut cx: async_nats::HeaderMap = async_nats::HeaderMap::new();
341 for (k, v) in
342 wasmcloud_provider_sdk::wasmcloud_tracing::context::TraceContextInjector::default_with_span(
343 )
344 .iter()
345 {
346 cx.insert(k.as_str(), v.as_str())
347 }
348 match bindings::wrpc::keyvalue::watcher::on_delete(wrpc, Some(cx), bucket, key).await {
349 Ok(_) => {
350 debug!("successfully invoked on_delete");
351 }
352 Err(err) => {
353 error!(?err, "failed to invoke on_delete");
354 }
355 }
356 debug!("key deleted");
357}
358
359impl keyvalue::store::Handler<Option<Context>> for KvRedisProvider {
360 #[instrument(level = "debug", skip(self))]
361 async fn delete(
362 &self,
363 context: Option<Context>,
364 bucket: String,
365 key: String,
366 ) -> anyhow::Result<Result<()>> {
367 propagate_trace_for_ctx!(context);
368 check_bucket_name(&bucket);
369 Ok(self.exec_cmd(context, &mut Cmd::del(key)).await)
370 }
371
372 #[instrument(level = "debug", skip(self))]
373 async fn exists(
374 &self,
375 context: Option<Context>,
376 bucket: String,
377 key: String,
378 ) -> anyhow::Result<Result<bool>> {
379 propagate_trace_for_ctx!(context);
380 check_bucket_name(&bucket);
381 Ok(self.exec_cmd(context, &mut Cmd::exists(key)).await)
382 }
383
384 #[instrument(level = "debug", skip(self))]
385 async fn get(
386 &self,
387 context: Option<Context>,
388 bucket: String,
389 key: String,
390 ) -> anyhow::Result<Result<Option<Bytes>>> {
391 propagate_trace_for_ctx!(context);
392 check_bucket_name(&bucket);
393 match self
394 .exec_cmd::<redis::Value>(context, &mut Cmd::get(key))
395 .await
396 {
397 Ok(redis::Value::Nil) => Ok(Ok(None)),
398 Ok(redis::Value::BulkString(buf)) => Ok(Ok(Some(buf.into()))),
399 Ok(_) => Ok(Err(keyvalue::store::Error::Other(
400 "invalid data type returned by Redis".into(),
401 ))),
402 Err(err) => Ok(Err(err)),
403 }
404 }
405
406 #[instrument(level = "debug", skip(self))]
407 async fn set(
408 &self,
409 context: Option<Context>,
410 bucket: String,
411 key: String,
412 value: Bytes,
413 ) -> anyhow::Result<Result<()>> {
414 propagate_trace_for_ctx!(context);
415 check_bucket_name(&bucket);
416 Ok(self
417 .exec_cmd(context, &mut Cmd::set(key, value.to_vec()))
418 .await)
419 }
420
421 #[instrument(level = "debug", skip(self))]
422 async fn list_keys(
423 &self,
424 context: Option<Context>,
425 bucket: String,
426 cursor: Option<u64>,
427 ) -> anyhow::Result<Result<keyvalue::store::KeyResponse>> {
428 propagate_trace_for_ctx!(context);
429 check_bucket_name(&bucket);
430 match self
431 .exec_cmd(
432 context,
433 redis::cmd("SCAN").cursor_arg(cursor.unwrap_or_default()),
434 )
435 .await
436 {
437 Ok((cursor, keys)) => Ok(Ok(keyvalue::store::KeyResponse {
438 keys,
439 cursor: NonZeroU64::new(cursor).map(Into::into),
440 })),
441 Err(err) => Ok(Err(err)),
442 }
443 }
444}
445
446impl keyvalue::atomics::Handler<Option<Context>> for KvRedisProvider {
447 #[instrument(level = "debug", skip(self))]
449 async fn increment(
450 &self,
451 context: Option<Context>,
452 bucket: String,
453 key: String,
454 delta: u64,
455 ) -> anyhow::Result<Result<u64, keyvalue::store::Error>> {
456 propagate_trace_for_ctx!(context);
457 check_bucket_name(&bucket);
458 Ok(self
459 .exec_cmd::<u64>(context, &mut Cmd::incr(key, delta))
460 .await)
461 }
462}
463
464impl keyvalue::batch::Handler<Option<Context>> for KvRedisProvider {
465 async fn get_many(
466 &self,
467 ctx: Option<Context>,
468 bucket: String,
469 keys: Vec<String>,
470 ) -> anyhow::Result<Result<Vec<Option<(String, Bytes)>>>> {
471 check_bucket_name(&bucket);
472 let data = match self
473 .exec_cmd::<Vec<Option<Bytes>>>(ctx, &mut Cmd::mget(&keys))
474 .await
475 {
476 Ok(v) => v
477 .into_iter()
478 .zip(keys.into_iter())
479 .map(|(val, key)| val.map(|b| (key, b)))
480 .collect::<Vec<_>>(),
481 Err(err) => {
482 return Ok(Err(err));
483 }
484 };
485 Ok(Ok(data))
486 }
487
488 async fn set_many(
489 &self,
490 ctx: Option<Context>,
491 bucket: String,
492 items: Vec<(String, Bytes)>,
493 ) -> anyhow::Result<Result<()>> {
494 check_bucket_name(&bucket);
495 let items = items
496 .into_iter()
497 .map(|(name, buf)| (name, buf.to_vec()))
498 .collect::<Vec<_>>();
499 Ok(self.exec_cmd(ctx, &mut Cmd::mset(&items)).await)
500 }
501
502 async fn delete_many(
503 &self,
504 ctx: Option<Context>,
505 bucket: String,
506 keys: Vec<String>,
507 ) -> anyhow::Result<Result<()>> {
508 check_bucket_name(&bucket);
509 Ok(self.exec_cmd(ctx, &mut Cmd::del(keys)).await)
510 }
511}
512
513impl Provider for KvRedisProvider {
515 #[instrument(level = "debug", skip(self, config))]
519 async fn receive_link_config_as_target(
520 &self,
521 LinkConfig {
522 source_id,
523 config,
524 secrets,
525 link_name,
526 ..
527 }: LinkConfig<'_>,
528 ) -> anyhow::Result<()> {
529 let url = secrets
530 .keys()
531 .find(|k| k.eq_ignore_ascii_case(CONFIG_REDIS_URL_KEY))
532 .and_then(|url_key| config.get(url_key))
533 .or_else(|| {
534 warn!("redis connection URLs can be sensitive. Please consider using secrets to pass this value");
535 config
536 .keys()
537 .find(|k| k.eq_ignore_ascii_case(CONFIG_REDIS_URL_KEY))
538 .and_then(|url_key| config.get(url_key))
539 });
540
541 let default_connection_disabled = secrets
542 .keys()
543 .any(|k| k.eq_ignore_ascii_case(CONFIG_DISABLE_DEFAULT_CONNECTION_KEY))
544 || config
545 .keys()
546 .any(|k| k.eq_ignore_ascii_case(CONFIG_DISABLE_DEFAULT_CONNECTION_KEY));
547
548 let share_connections_by_url = secrets
549 .keys()
550 .any(|k| k.eq_ignore_ascii_case(CONFIG_SHARE_CONNECTIONS_BY_URL_KEY))
551 || config
552 .keys()
553 .any(|k| k.eq_ignore_ascii_case(CONFIG_SHARE_CONNECTIONS_BY_URL_KEY));
554
555 let key = (source_id.to_string(), link_name.to_string());
556
557 {
560 if let (Some(url), true) = (url, share_connections_by_url) {
561 let shared_connections = self.shared_connections.read().await;
562 let shared_key = format!("{:X}", Sha256::digest(url));
563 if shared_connections.contains_key(&shared_key) {
564 let mut sources = self.sources.write().await;
566 sources.insert(key, RedisConnection::Shared(shared_key));
567 return Ok(());
568 }
569 }
570 }
571
572 let cfg = build_connection_mgr_config(config);
574 let conn = if let Some(url) = url {
575 match redis::Client::open(url.to_string()) {
576 Ok(client) => match ConnectionManager::new_with_config(client, cfg).await {
577 Ok(conn) => {
578 info!(url, "established link");
579 conn
580 }
581 Err(err) => {
582 warn!(
583 url,
584 ?err,
585 "Could not create Redis connection manager for source [{source_id}], keyvalue operations will fail",
586 );
587 bail!("failed to create redis connection manager");
588 }
589 },
590 Err(err) => {
591 warn!(
592 ?err,
593 "Could not create Redis client for source [{source_id}], keyvalue operations will fail",
594 );
595 bail!("failed to create redis client");
596 }
597 }
598 } else {
599 if default_connection_disabled {
601 error!(
602 component = source_id,
603 "using the default connection is disabled via link configuration"
604 );
605 bail!(
606 "using the default connection is disabled via link configuration for component [{source_id}]"
607 );
608 }
609
610 self.get_default_connection().await.map_err(|err| {
611 error!(error = ?err, "failed to get default connection for link");
612 err
613 })?
614 };
615
616 match (url, share_connections_by_url) {
617 (Some(url), true) => {
620 let shared_key = format!("{:X}", Sha256::digest(url));
621
622 let mut shared_connections = self.shared_connections.write().await;
624 shared_connections.insert(shared_key.clone(), conn);
625 drop(shared_connections);
626
627 let mut sources = self.sources.write().await;
628 sources.insert(key, RedisConnection::Shared(shared_key));
629 drop(sources);
630 }
631 _ => {
634 let mut sources = self.sources.write().await;
635 sources.insert(key, RedisConnection::Direct(conn));
636 drop(sources);
637 }
638 }
639
640 Ok(())
641 }
642
643 async fn receive_link_config_as_source(
644 &self,
645 LinkConfig {
646 target_id,
647 config,
648 secrets,
649 link_name,
650 wit_metadata: (_, _, interfaces),
651 ..
652 }: LinkConfig<'_>,
653 ) -> anyhow::Result<()> {
654 let url = secrets
655 .keys()
656 .find(|k| k.eq_ignore_ascii_case(CONFIG_REDIS_URL_KEY))
657 .and_then(|url_key| config.get(url_key))
658 .or_else(|| {
659 warn!("Redis connection URLs can be sensitive. Consider using secrets to pass this value.");
660 config.keys()
661 .find(|k| k.eq_ignore_ascii_case(CONFIG_REDIS_URL_KEY))
662 .and_then(|url_key| config.get(url_key))
663 })
664 .map_or(DEFAULT_CONNECT_URL, |v| v);
665
666 let client = match redis::Client::open(url.to_string()) {
667 Ok(client) => {
668 info!(url, "Established link at receive_link_config_as_source");
669 client
670 }
671 Err(err) => {
672 warn!(target_id = %target_id, err = ?err, "Failed to create Redis client");
673 bail!("Failed to create Redis client");
674 }
675 };
676 let mut conn = client.get_connection_manager().await.map_err(|e| {
677 error!(err = ?e, "Failed to get async connection");
678 anyhow::anyhow!("Failed to get async connection: {}", e)
679 })?;
680
681 let component_id: Arc<str> = target_id.into();
682 let wrpc = get_connection()
683 .get_wrpc_client(&component_id)
684 .await
685 .context("failed to construct wRPC client")?;
686 if interfaces.contains(&"watcher".to_string()) {
687 let config_response: Vec<String> = redis::cmd("CONFIG")
688 .arg("GET")
689 .arg("notify-keyspace-events")
690 .query_async(&mut conn)
691 .await
692 .map_err(|e| {
693 error!(err = %e, "Failed to get keyspace notifications config");
694 anyhow::anyhow!("Failed to get keyspace notifications config: {}", e)
695 })?;
696
697 let current_config = config_response.get(1).ok_or_else(|| {
698 error!("Unexpected response format from Redis CONFIG GET");
699 anyhow::anyhow!("Unexpected response format from Redis CONFIG GET")
700 })?;
701
702 if !current_config.contains('K')
703 || !current_config.contains('$')
704 || !current_config.contains('g')
705 {
706 error!(
707 current_config = %current_config,
708 "Redis keyspace-notifications not properly configured"
709 );
710 return Err(anyhow::anyhow!(
711 "Redis keyspace-notifications not properly configured! \
712 Expected 'K$g' in settings, but got '{}'. \
713 Please run: CONFIG SET notify-keyspace-events K$g",
714 current_config
715 ));
716 }
717
718 let wrpc = Arc::new(wrpc);
719 let wrpc_for_task = wrpc.clone();
720
721 let config_watch_entries = parse_watch_config(config, target_id);
722
723 let mut watched_keys = self.watched_keys.write().await;
725 for (key, key_info_set) in config_watch_entries {
726 watched_keys
727 .entry(key)
728 .or_insert_with(HashSet::new)
729 .extend(key_info_set);
730 }
731
732 let client_clone = client.clone();
733 let self_clone = self.clone();
734 let mut conn_clone = conn.clone();
735 let task = tokio::spawn(async move {
736 let mut pubsub = match client_clone.get_async_pubsub().await {
737 Ok(pubsub) => pubsub,
738 Err(e) => {
739 error!(err = %e, "Failed to get pubsub connection");
740 return;
741 }
742 };
743 let watched_keys = self_clone.watched_keys.read().await;
744 for key in watched_keys.keys() {
745 let channel = format!("__keyspace@0__:{key}");
746 let _ = pubsub
747 .psubscribe(&channel)
748 .await
749 .context("Failed to subscribe to SET/DEL events for key");
750 }
751 let stream = pubsub.on_message();
752 tokio::pin!(stream);
753 while let Some(msg) = stream.next().await {
754 let channel: String = msg.get_channel_name().to_string();
755 let event: String = match msg.get_payload() {
756 Ok(event) => event,
757 Err(e) => {
758 error!(err = %e, "Failed to get payload");
759 continue;
760 }
761 };
762 let mkey = match channel.split(':').next_back() {
765 Some(key) => key,
766 None => {
767 error!(channel = %channel, "Malformed Redis channel name: expected '__keyspace@0__:key' format");
768 continue;
769 }
770 };
771 let watched_keys = self_clone.watched_keys.read().await;
773 if let Some(key_info_set) = watched_keys.get(mkey) {
774 if event == "set" || event == "SET" {
775 let value: wit_bindgen_wrpc::bytes::Bytes = match redis::cmd("GET")
778 .arg(mkey)
779 .query_async::<Option<Vec<u8>>>(&mut conn_clone)
780 .await
781 {
782 Ok(Some(v)) => v.into(),
783 Ok(None) => {
784 debug!(key = %mkey, "Key not found or was deleted");
785 continue;
786 }
787 Err(e) => {
788 error!(key = %mkey, err = %e, "Failed to get value for key");
789 continue;
790 }
791 };
792 for key_info in key_info_set {
793 if key_info.event_type == WatchEventType::Set {
794 invoke_on_set(&wrpc_for_task, "0", mkey, &value).await;
795 }
796 }
797 } else if event == "del" || event == "DEL" {
798 for key_info in key_info_set {
799 if key_info.event_type == WatchEventType::Delete {
800 invoke_on_delete(&wrpc_for_task, "0", mkey).await;
801 }
802 }
803 }
804 }
805 }
806 });
807 let mut tasks = self.watch_tasks.write().await;
808 tasks.insert(
809 LinkId {
810 target_id: target_id.to_string(),
811 link_name: link_name.to_string(),
812 },
813 task,
814 );
815 }
816
817 let mut sources = self.sources.write().await;
818 sources.insert(
819 (target_id.to_string(), link_name.to_string()),
820 RedisConnection::Direct(conn),
821 );
822
823 Ok(())
824 }
825
826 #[instrument(level = "info", skip_all, fields(source_id = info.get_source_id()))]
828 async fn delete_link_as_target(&self, info: impl LinkDeleteInfo) -> anyhow::Result<()> {
829 let component_id = info.get_source_id();
830 let mut aw = self.sources.write().await;
831 aw.retain(|(src_id, _link_name), _| src_id != component_id);
835 debug!(component_id, "closing all redis connections for component");
836 Ok(())
837 }
838
839 #[instrument(level = "info", skip_all, fields(target_id = info.get_target_id()))]
840 async fn delete_link_as_source(&self, info: impl LinkDeleteInfo) -> anyhow::Result<()> {
841 let component_id = info.get_target_id();
842 let link_name = info.get_link_name();
843
844 let mut sources = self.sources.write().await;
845 sources.remove(&(component_id.to_string(), link_name.to_string()));
846
847 let mut watch_tasks = self.watch_tasks.write().await;
848
849 if let Some(task) = watch_tasks.remove(&LinkId {
851 target_id: component_id.to_string(),
852 link_name: link_name.to_string(),
853 }) {
854 task.abort();
855 let _ = task.await;
856 }
857
858 let mut watched_keys = self.watched_keys.write().await;
860 for key_watchers in watched_keys.values_mut() {
861 key_watchers.retain(|key_info| key_info.target != component_id);
862 }
863
864 watched_keys.retain(|_, watchers| !watchers.is_empty());
866
867 debug!(
868 component_id,
869 link_name, "cleaned up redis connection and watch tasks for link"
870 );
871 Ok(())
872 }
873
874 async fn shutdown(&self) -> anyhow::Result<()> {
876 info!("shutting down");
877 let mut aw = self.sources.write().await;
878 for (_, conn) in aw.drain() {
880 drop(conn);
881 }
882 Ok(())
883 }
884}
885
886fn retrieve_default_url(
889 config: &HashMap<String, String>,
890 secrets: &Option<HashMap<String, SecretValue>>,
891) -> String {
892 if let Some(secrets) = secrets {
894 if let Some(url) = secrets
895 .keys()
896 .find(|sk| sk.eq_ignore_ascii_case(CONFIG_REDIS_URL_KEY))
897 .and_then(|k| secrets.get(k))
898 {
899 if let Some(s) = url.as_string() {
900 debug!(
901 url = ?url, "using Redis URL from secrets"
903 );
904 return s.into();
905 } else {
906 warn!("invalid secret value for URL (expected string, found bytes). Falling back to config");
907 }
908 }
909 }
910
911 let config_supplied_url = config
913 .keys()
914 .find(|k| k.eq_ignore_ascii_case(CONFIG_REDIS_URL_KEY))
915 .and_then(|url_key| config.get(url_key));
916
917 if let Some(url) = config_supplied_url {
918 debug!(url, "using Redis URL from config");
919 url.to_string()
920 } else {
921 debug!(DEFAULT_CONNECT_URL, "using default Redis URL");
922 DEFAULT_CONNECT_URL.to_string()
923 }
924}
925#[instrument(level = "debug", skip(config))]
933fn parse_watch_config(
934 config: &HashMap<String, String>,
935 target_id: &str,
936) -> HashMap<String, HashSet<WatchedKeyInfo>> {
937 let mut watched_keys = HashMap::new();
938
939 let config_map: HashMap<UniCase<&str>, &String> = config
941 .iter()
942 .map(|(k, v)| (UniCase::new(k.as_str()), v))
943 .collect();
944
945 if let Some(watch_config) = config_map.get(&UniCase::new("watch")) {
947 for watch_entry in watch_config.split(',') {
948 let watch_entry = watch_entry.trim();
949 if watch_entry.is_empty() {
950 continue;
951 }
952
953 let parts: Vec<&str> = watch_entry.split('@').collect();
954 if parts.len() != 2 {
955 error!(watch_entry = %watch_entry, "Invalid watch entry format. Expected FORMAT@KEY");
956 continue;
957 }
958
959 let operation = parts[0].trim().to_uppercase();
960 let key_value = parts[1].trim();
961
962 if key_value.contains(':') {
963 error!(key = %key_value, "Invalid SET watch format. SET expects only KEY");
964 continue;
965 }
966 if key_value.is_empty() {
967 error!(watch_entry = %watch_entry, "Invalid watch entry: Missing key.");
968 continue;
969 }
970
971 match operation.as_str() {
972 "SET" => {
973 watched_keys
974 .entry(key_value.to_string())
975 .or_insert_with(HashSet::new)
976 .insert(WatchedKeyInfo {
977 event_type: WatchEventType::Set,
978 target: target_id.to_string(),
979 });
980 }
981 "DEL" => {
982 watched_keys
983 .entry(key_value.to_string())
984 .or_insert_with(HashSet::new)
985 .insert(WatchedKeyInfo {
986 event_type: WatchEventType::Delete,
987 target: target_id.to_string(),
988 });
989 }
990 _ => {
991 error!(operation = %operation, "Unsupported watch operation. Expected SET or DEL");
992 }
993 }
994 }
995 }
996
997 watched_keys
998}
999
1000fn check_bucket_name(bucket: &str) {
1003 if !bucket.is_empty() {
1004 warn!(bucket, "non-empty bucket names are not yet supported; ignoring non-empty bucket name (using a non-empty bucket name may become an error in the future).")
1005 }
1006}
1007
1008fn build_connection_mgr_config(config: &HashMap<String, String>) -> ConnectionManagerConfig {
1010 let mut cfg = ConnectionManagerConfig::new();
1011
1012 cfg = cfg
1014 .set_number_of_retries(DEFAULT_REDIS_BACKEND_RECONNECT_NUM_RETRIES)
1015 .set_max_delay(DEFAULT_REDIS_BACKEND_RECONNECT_MAX_DELAY_MS)
1016 .set_connection_timeout(Duration::from_millis(
1017 DEFAULT_REDIS_BACKEND_CONNECTION_TIMEOUT_MS,
1018 ))
1019 .set_response_timeout(Duration::from_millis(
1020 DEFAULT_REDIS_BACKEND_RESPONSE_TIMEOUT_MS,
1021 ));
1022
1023 for (k, v) in config.iter() {
1025 if k.eq_ignore_ascii_case(CONFIG_REDIS_BACKEND_RECONNECT_NUM_RETRIES_KEY) {
1026 if let Ok(val) = v.parse::<usize>() {
1027 cfg = cfg.set_number_of_retries(val);
1028 } else {
1029 warn!(
1030 key = %CONFIG_REDIS_BACKEND_RECONNECT_NUM_RETRIES_KEY,
1031 value = %v,
1032 "Invalid value for number of retries, using default"
1033 );
1034 }
1035 }
1036
1037 if let Some(max_delay) = if k
1038 .eq_ignore_ascii_case(CONFIG_REDIS_BACKEND_RECONNECT_MAX_DELAY_MS_KEY)
1039 {
1040 match v.parse() {
1041 Ok(val) => Some(val),
1042 Err(_) => {
1043 warn!(key = %CONFIG_REDIS_BACKEND_RECONNECT_MAX_DELAY_MS_KEY, value = %v, "Invalid value for max delay, using default");
1044 Some(DEFAULT_REDIS_BACKEND_RECONNECT_MAX_DELAY_MS)
1045 }
1046 }
1047 } else {
1048 None
1049 } {
1050 cfg = cfg.set_max_delay(max_delay);
1051 }
1052
1053 if let Some(timeout) = if k
1054 .eq_ignore_ascii_case(CONFIG_REDIS_BACKEND_CONNECTION_TIMEOUT_MS_KEY)
1055 {
1056 match v.parse() {
1057 Ok(val) => Some(val),
1058 Err(_) => {
1059 warn!(key = %CONFIG_REDIS_BACKEND_CONNECTION_TIMEOUT_MS_KEY,value = %v,"Invalid value for connection timeout, using default");
1060 Some(DEFAULT_REDIS_BACKEND_CONNECTION_TIMEOUT_MS)
1061 }
1062 }
1063 } else {
1064 None
1065 } {
1066 cfg = cfg.set_connection_timeout(Duration::from_millis(timeout));
1067 }
1068
1069 if let Some(timeout) = if k
1070 .eq_ignore_ascii_case(CONFIG_REDIS_BACKEND_RESPONSE_TIMEOUT_MS_KEY)
1071 {
1072 match v.parse() {
1073 Ok(val) => Some(val),
1074 Err(_) => {
1075 warn!(key = %CONFIG_REDIS_BACKEND_RESPONSE_TIMEOUT_MS_KEY,value = %v,"Invalid value for response timeout, using default");
1076 Some(DEFAULT_REDIS_BACKEND_RESPONSE_TIMEOUT_MS)
1077 }
1078 }
1079 } else {
1080 None
1081 } {
1082 cfg = cfg.set_response_timeout(Duration::from_millis(timeout));
1083 }
1084 }
1085
1086 cfg
1087}
1088
// Unit tests for the pure config-parsing helpers (`retrieve_default_url` and
// `parse_watch_config`); none of them contact Redis.
#[cfg(test)]
mod test {
    use super::*;
    use std::collections::HashMap;

    use crate::retrieve_default_url;

    const PROPER_URL: &str = "redis://127.0.0.1:6379";

    // The `URL` config key must be found regardless of its casing.
    #[test]
    fn can_deserialize_config_case_insensitive() {
        let lowercase_config = HashMap::from_iter([("url".to_string(), PROPER_URL.to_string())]);
        let uppercase_config = HashMap::from_iter([("URL".to_string(), PROPER_URL.to_string())]);
        let initial_caps_config = HashMap::from_iter([("Url".to_string(), PROPER_URL.to_string())]);

        assert_eq!(PROPER_URL, retrieve_default_url(&lowercase_config, &None));
        assert_eq!(PROPER_URL, retrieve_default_url(&uppercase_config, &None));
        assert_eq!(
            PROPER_URL,
            retrieve_default_url(&initial_caps_config, &None)
        );
    }

    // Well-formed entries are grouped by key; one key may carry both SET and DEL watches.
    #[test]
    fn test_parse_watch_config_valid_entries() {
        let mut config = HashMap::new();
        config.insert(
            "watch".to_string(),
            "SET@key1,DEL@key2,SET@key2".to_string(),
        );
        let target_id = "target_1";

        let result = parse_watch_config(&config, target_id);

        assert_eq!(result.len(), 2);
        assert!(result.contains_key("key1"));
        assert!(result.contains_key("key2"));

        assert!(result["key1"].contains(&WatchedKeyInfo {
            event_type: WatchEventType::Set,
            target: target_id.to_string()
        }));
        assert!(result["key2"].contains(&WatchedKeyInfo {
            event_type: WatchEventType::Delete,
            target: target_id.to_string()
        }));
        assert!(result["key2"].contains(&WatchedKeyInfo {
            event_type: WatchEventType::Set,
            target: target_id.to_string()
        }));
    }

    // Unknown operations and keys containing ':' are skipped; valid entries still register.
    #[test]
    fn test_parse_watch_config_invalid_entries() {
        let mut config = HashMap::new();
        config.insert(
            "watch".to_string(),
            "INVALID@key1,SET@key2,DEL@key3,SET@key4:extra".to_string(),
        );
        let target_id = "target_2";

        let result = parse_watch_config(&config, target_id);

        assert_eq!(result.len(), 2);
        assert!(result.contains_key("key2"));
        assert!(result.contains_key("key3"));

        assert!(result["key2"].contains(&WatchedKeyInfo {
            event_type: WatchEventType::Set,
            target: target_id.to_string()
        }));
        assert!(result["key3"].contains(&WatchedKeyInfo {
            event_type: WatchEventType::Delete,
            target: target_id.to_string()
        }));
    }

    // Entries with an empty key or an empty operation produce no registrations.
    #[test]
    fn test_parse_watch_config_empty_or_malformed() {
        let mut config = HashMap::new();
        config.insert("watch".to_string(), "SET@,DEL@ , @key5".to_string());
        let target_id = "target_3";

        let result = parse_watch_config(&config, target_id);

        assert!(result.is_empty());
    }

    // Both the config key and the operation names are case-insensitive.
    #[test]
    fn test_parse_watch_config_case_insensitivity() {
        let mut config = HashMap::new();
        config.insert("WATCH".to_string(), "set@key1,del@key2".to_string());
        let target_id = "target_4";

        let result = parse_watch_config(&config, target_id);

        assert_eq!(result.len(), 2);
        assert!(result.contains_key("key1"));
        assert!(result.contains_key("key2"));

        assert!(result["key1"].contains(&WatchedKeyInfo {
            event_type: WatchEventType::Set,
            target: target_id.to_string()
        }));
        assert!(result["key2"].contains(&WatchedKeyInfo {
            event_type: WatchEventType::Delete,
            target: target_id.to_string()
        }));
    }

    // Without a `watch` key nothing is registered.
    #[test]
    fn test_parse_watch_config_no_watch_key() {
        let config = HashMap::new();
        let target_id = "target_5";

        let result = parse_watch_config(&config, target_id);

        assert!(result.is_empty());
    }
}