1use super::{AsyncPushSender, HandleContainer, RedisFuture};
2#[cfg(feature = "cache-aio")]
3use crate::caching::CacheManager;
4use crate::{
5 aio::{check_resp3, ConnectionLike, MultiplexedConnection, Runtime},
6 cmd,
7 subscription_tracker::{SubscriptionAction, SubscriptionTracker},
8 types::{RedisError, RedisResult, Value},
9 AsyncConnectionConfig, Client, Cmd, Pipeline, PushInfo, PushKind, ToRedisArgs,
10};
11use arc_swap::ArcSwap;
12use backon::{ExponentialBuilder, Retryable};
13use futures_channel::oneshot;
14use futures_util::future::{self, BoxFuture, FutureExt, Shared};
15use std::sync::Arc;
16use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver};
17use tokio::sync::Mutex;
18
/// Configuration for a [`ConnectionManager`], assembled with chained `set_*` calls.
#[derive(Clone)]
pub struct ConnectionManagerConfig {
    /// Exponent base of the reconnection backoff (`set_exponent_base`, default 2).
    exponent_base: u64,
    /// Factor of the reconnection backoff (`set_factor`, default 100).
    factor: u64,
    /// Maximum number of reconnection retries (`set_number_of_retries`, default 6).
    number_of_retries: usize,
    /// Upper bound for a single reconnection delay, in milliseconds; `None` sets no
    /// explicit cap.
    max_delay: Option<u64>,
    /// Response timeout applied to every connection; `None` keeps the connection default.
    response_timeout: Option<std::time::Duration>,
    /// Timeout for establishing every connection; `None` keeps the connection default.
    connection_timeout: Option<std::time::Duration>,
    /// Receives push messages forwarded from the connection (only allowed with RESP3).
    push_sender: Option<Arc<dyn AsyncPushSender>>,
    /// Re-issue tracked subscriptions after a reconnect; requires `push_sender`.
    resubscribe_automatically: bool,
    /// TCP settings used for every connection.
    tcp_settings: crate::io::tcp::TcpSettings,
    /// Client-side caching configuration; when set, a cache manager is created.
    #[cfg(feature = "cache-aio")]
    pub(crate) cache_config: Option<crate::caching::CacheConfig>,
}
45
46impl std::fmt::Debug for ConnectionManagerConfig {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
48 let &Self {
49 exponent_base,
50 factor,
51 number_of_retries,
52 max_delay,
53 response_timeout,
54 connection_timeout,
55 push_sender,
56 resubscribe_automatically,
57 tcp_settings,
58 #[cfg(feature = "cache-aio")]
59 cache_config,
60 } = &self;
61 let mut str = f.debug_struct("ConnectionManagerConfig");
62 str.field("exponent_base", &exponent_base)
63 .field("factor", &factor)
64 .field("number_of_retries", &number_of_retries)
65 .field("max_delay", &max_delay)
66 .field("response_timeout", &response_timeout)
67 .field("connection_timeout", &connection_timeout)
68 .field("resubscribe_automatically", &resubscribe_automatically)
69 .field(
70 "push_sender",
71 if push_sender.is_some() {
72 &"set"
73 } else {
74 &"not set"
75 },
76 )
77 .field("tcp_settings", &tcp_settings);
78
79 #[cfg(feature = "cache-aio")]
80 str.field("cache_config", &cache_config);
81
82 str.finish()
83 }
84}
85
86impl ConnectionManagerConfig {
87 const DEFAULT_CONNECTION_RETRY_EXPONENT_BASE: u64 = 2;
88 const DEFAULT_CONNECTION_RETRY_FACTOR: u64 = 100;
89 const DEFAULT_NUMBER_OF_CONNECTION_RETRIES: usize = 6;
90 const DEFAULT_RESPONSE_TIMEOUT: Option<std::time::Duration> = None;
91 const DEFAULT_CONNECTION_TIMEOUT: Option<std::time::Duration> = None;
92
93 pub fn new() -> Self {
95 Self::default()
96 }
97
98 pub fn set_factor(mut self, factor: u64) -> ConnectionManagerConfig {
102 self.factor = factor;
103 self
104 }
105
106 pub fn set_max_delay(mut self, time: u64) -> ConnectionManagerConfig {
108 self.max_delay = Some(time);
109 self
110 }
111
112 pub fn set_exponent_base(mut self, base: u64) -> ConnectionManagerConfig {
115 self.exponent_base = base;
116 self
117 }
118
119 pub fn set_number_of_retries(mut self, amount: usize) -> ConnectionManagerConfig {
121 self.number_of_retries = amount;
122 self
123 }
124
125 pub fn set_response_timeout(
127 mut self,
128 duration: std::time::Duration,
129 ) -> ConnectionManagerConfig {
130 self.response_timeout = Some(duration);
131 self
132 }
133
134 pub fn set_connection_timeout(
136 mut self,
137 duration: std::time::Duration,
138 ) -> ConnectionManagerConfig {
139 self.connection_timeout = Some(duration);
140 self
141 }
142
143 pub fn set_push_sender(mut self, sender: impl AsyncPushSender) -> Self {
172 self.push_sender = Some(Arc::new(sender));
173 self
174 }
175
176 pub fn set_automatic_resubscription(mut self) -> Self {
178 self.resubscribe_automatically = true;
179 self
180 }
181
182 pub fn set_tcp_settings(self, tcp_settings: crate::io::tcp::TcpSettings) -> Self {
184 Self {
185 tcp_settings,
186 ..self
187 }
188 }
189
190 #[cfg(feature = "cache-aio")]
192 pub fn set_cache_config(self, cache_config: crate::caching::CacheConfig) -> Self {
193 Self {
194 cache_config: Some(cache_config),
195 ..self
196 }
197 }
198}
199
200impl Default for ConnectionManagerConfig {
201 fn default() -> Self {
202 Self {
203 exponent_base: Self::DEFAULT_CONNECTION_RETRY_EXPONENT_BASE,
204 factor: Self::DEFAULT_CONNECTION_RETRY_FACTOR,
205 number_of_retries: Self::DEFAULT_NUMBER_OF_CONNECTION_RETRIES,
206 response_timeout: Self::DEFAULT_RESPONSE_TIMEOUT,
207 connection_timeout: Self::DEFAULT_CONNECTION_TIMEOUT,
208 max_delay: None,
209 push_sender: None,
210 resubscribe_automatically: false,
211 tcp_settings: Default::default(),
212 #[cfg(feature = "cache-aio")]
213 cache_config: None,
214 }
215 }
216}
217
/// State shared by every clone of a `ConnectionManager`.
struct Internals {
    /// Client used to open (re)connections.
    client: Client,
    /// The current connection, or the in-flight (re)connection attempt producing it.
    /// `reconnect` replaces it via `compare_and_swap`, so concurrent failures observed on
    /// the same connection start only one reconnection.
    connection: ArcSwap<SharedRedisFuture<MultiplexedConnection>>,

    /// Runtime on which reconnection futures are spawned.
    runtime: Runtime,
    /// Backoff applied while (re)connecting.
    retry_strategy: ExponentialBuilder,
    /// Settings (timeouts, TCP settings, internal push sender, cache) for new connections.
    connection_config: AsyncConnectionConfig,
    /// Tracks active subscriptions; `Some` only when automatic resubscription is enabled.
    subscription_tracker: Option<Mutex<SubscriptionTracker>>,
    /// Client-side cache; invalidated whenever `reconnect` is called.
    #[cfg(feature = "cache-aio")]
    cache_manager: Option<CacheManager>,
    /// Handle of the push-monitoring task spawned in `new_with_config`, held so the task's
    /// lifetime is tied to this struct (presumably `HandleContainer` aborts the task on
    /// drop; its definition is not visible here).
    _task_handle: HandleContainer,
}
235
/// A [`MultiplexedConnection`] wrapper that replaces its connection in the background
/// after I/O or unrecoverable errors (and, with a push sender, on disconnection pushes).
///
/// Cloning is cheap: all clones share the same state behind an `Arc`, and therefore the
/// same underlying connection and reconnection attempt.
#[derive(Clone)]
pub struct ConnectionManager(Arc<Internals>);
263
/// A `RedisResult` whose error is behind an `Arc`, making the whole result `Clone` so it
/// can be the output of a `Shared` future.
type CloneableRedisResult<T> = Result<T, Arc<RedisError>>;

/// A boxed `'static` future that many callers can await concurrently; each caller gets a
/// clone of the same result.
type SharedRedisFuture<T> = Shared<BoxFuture<'static, CloneableRedisResult<T>>>;
269
/// Starts a reconnect when `$result` (a reference to a `RedisResult`) holds an
/// unrecoverable error, handing `$current` (the connection guard the failed call used)
/// to `$self.reconnect`. The result is left untouched for the caller to return.
macro_rules! reconnect_if_dropped {
    ($self:expr, $result:expr, $current:expr) => {
        match $result {
            Err(ref err) if err.is_unrecoverable_error() => $self.reconnect($current),
            _ => {}
        }
    };
}
280
/// Early-returns `Err` from the enclosing function when `$result` (an owned
/// `RedisResult`) is an error, first starting a reconnect through `$self.reconnect($current)`
/// if it is an I/O error. On `Ok`, `$result` is not moved and stays usable afterwards.
macro_rules! reconnect_if_io_error {
    ($self:expr, $result:expr, $current:expr) => {
        match $result {
            Ok(_) => {}
            Err(err) => {
                if err.is_io_error() {
                    $self.reconnect($current);
                }
                return Err(err);
            }
        }
    };
}
293
294impl ConnectionManager {
295 pub async fn new(client: Client) -> RedisResult<Self> {
300 let config = ConnectionManagerConfig::new();
301
302 Self::new_with_config(client, config).await
303 }
304
305 #[deprecated(note = "Use `new_with_config`")]
314 pub async fn new_with_backoff(
315 client: Client,
316 exponent_base: u64,
317 factor: u64,
318 number_of_retries: usize,
319 ) -> RedisResult<Self> {
320 let config = ConnectionManagerConfig::new()
321 .set_exponent_base(exponent_base)
322 .set_factor(factor)
323 .set_number_of_retries(number_of_retries);
324 Self::new_with_config(client, config).await
325 }
326
327 #[deprecated(note = "Use `new_with_config`")]
339 pub async fn new_with_backoff_and_timeouts(
340 client: Client,
341 exponent_base: u64,
342 factor: u64,
343 number_of_retries: usize,
344 response_timeout: std::time::Duration,
345 connection_timeout: std::time::Duration,
346 ) -> RedisResult<Self> {
347 let config = ConnectionManagerConfig::new()
348 .set_exponent_base(exponent_base)
349 .set_factor(factor)
350 .set_number_of_retries(number_of_retries)
351 .set_response_timeout(response_timeout)
352 .set_connection_timeout(connection_timeout);
353
354 Self::new_with_config(client, config).await
355 }
356
357 pub async fn new_with_config(
371 client: Client,
372 config: ConnectionManagerConfig,
373 ) -> RedisResult<Self> {
374 let runtime = Runtime::locate();
376
377 if config.resubscribe_automatically && config.push_sender.is_none() {
378 return Err((crate::ErrorKind::ClientError, "Cannot set resubscribe_automatically without setting a push sender to receive messages.").into());
379 }
380
381 let mut retry_strategy = ExponentialBuilder::default()
382 .with_factor(config.factor as f32)
383 .with_max_times(config.number_of_retries)
384 .with_jitter();
385 if let Some(max_delay) = config.max_delay {
386 retry_strategy =
387 retry_strategy.with_max_delay(std::time::Duration::from_millis(max_delay));
388 }
389
390 let mut connection_config = AsyncConnectionConfig::new();
391 if let Some(connection_timeout) = config.connection_timeout {
392 connection_config = connection_config.set_connection_timeout(connection_timeout);
393 }
394 if let Some(response_timeout) = config.response_timeout {
395 connection_config = connection_config.set_response_timeout(response_timeout);
396 }
397 connection_config = connection_config.set_tcp_settings(config.tcp_settings);
398 #[cfg(feature = "cache-aio")]
399 let cache_manager = config
400 .cache_config
401 .as_ref()
402 .map(|cache_config| CacheManager::new(*cache_config));
403 #[cfg(feature = "cache-aio")]
404 if let Some(cache_manager) = cache_manager.as_ref() {
405 connection_config = connection_config.set_cache_manager(cache_manager.clone());
406 }
407
408 let (oneshot_sender, oneshot_receiver) = oneshot::channel();
409 let _task_handle = HandleContainer::new(
410 runtime.spawn(Self::check_for_disconnect_pushes(oneshot_receiver)),
411 );
412
413 let mut components_for_reconnection_on_push = None;
414 if let Some(push_sender) = config.push_sender.clone() {
415 check_resp3!(
416 client.connection_info.redis.protocol,
417 "Can only pass push sender to a connection using RESP3"
418 );
419
420 let (internal_sender, internal_receiver) = unbounded_channel();
421 components_for_reconnection_on_push = Some((internal_receiver, push_sender));
422
423 connection_config =
424 connection_config.set_push_sender_internal(Arc::new(internal_sender));
425 }
426
427 let connection =
428 Self::new_connection(&client, retry_strategy, &connection_config, None).await?;
429 let subscription_tracker = if config.resubscribe_automatically {
430 Some(Mutex::new(SubscriptionTracker::default()))
431 } else {
432 None
433 };
434
435 let new_self = Self(Arc::new(Internals {
436 client,
437 connection: ArcSwap::from_pointee(future::ok(connection).boxed().shared()),
438 runtime,
439 retry_strategy,
440 connection_config,
441 subscription_tracker,
442 #[cfg(feature = "cache-aio")]
443 cache_manager,
444 _task_handle,
445 }));
446
447 if let Some((internal_receiver, external_sender)) = components_for_reconnection_on_push {
448 oneshot_sender
449 .send((new_self.clone(), internal_receiver, external_sender))
450 .map_err(|_| {
451 crate::RedisError::from((
452 crate::ErrorKind::ClientError,
453 "Failed to set automatic resubscription",
454 ))
455 })?;
456 }
457
458 Ok(new_self)
459 }
460
461 async fn new_connection(
462 client: &Client,
463 exponential_backoff: ExponentialBuilder,
464 connection_config: &AsyncConnectionConfig,
465 additional_commands: Option<Pipeline>,
466 ) -> RedisResult<MultiplexedConnection> {
467 let connection_config = connection_config.clone();
468 let get_conn = || async {
469 client
470 .get_multiplexed_async_connection_with_config(&connection_config)
471 .await
472 };
473 let mut conn = get_conn
474 .retry(exponential_backoff)
475 .sleep(|duration| async move { Runtime::locate().sleep(duration).await })
476 .await?;
477 if let Some(pipeline) = additional_commands {
478 let _ = pipeline.exec_async(&mut conn).await;
480 }
481 Ok(conn)
482 }
483
484 fn reconnect(&self, current: arc_swap::Guard<Arc<SharedRedisFuture<MultiplexedConnection>>>) {
489 #[cfg(feature = "cache-aio")]
490 if let Some(manager) = self.0.cache_manager.as_ref() {
491 manager.invalidate_all();
492 }
493 let self_clone = self.clone();
494 let new_connection: SharedRedisFuture<MultiplexedConnection> = async move {
495 let additional_commands = match &self_clone.0.subscription_tracker {
496 Some(subscription_tracker) => Some(
497 subscription_tracker
498 .lock()
499 .await
500 .get_subscription_pipeline(),
501 ),
502 None => None,
503 };
504 let con = Self::new_connection(
505 &self_clone.0.client,
506 self_clone.0.retry_strategy,
507 &self_clone.0.connection_config,
508 additional_commands,
509 )
510 .await?;
511 Ok(con)
512 }
513 .boxed()
514 .shared();
515
516 let new_connection_arc = Arc::new(new_connection.clone());
518 let prev = self
519 .0
520 .connection
521 .compare_and_swap(¤t, new_connection_arc);
522
523 if Arc::ptr_eq(&prev, ¤t) {
525 self.0.runtime.spawn(new_connection.map(|_| ()));
527 }
528 }
529
530 async fn check_for_disconnect_pushes(
531 receiver: oneshot::Receiver<(
532 ConnectionManager,
533 UnboundedReceiver<PushInfo>,
534 Arc<dyn AsyncPushSender>,
535 )>,
536 ) {
537 let Ok((this, mut internal_receiver, external_sender)) = receiver.await else {
538 return;
539 };
540 while let Some(push_info) = internal_receiver.recv().await {
541 if push_info.kind == PushKind::Disconnection {
542 this.reconnect(this.0.connection.load());
543 }
544 if external_sender.send(push_info).is_err() {
545 return;
546 }
547 }
548 }
549
550 pub async fn send_packed_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
553 let guard = self.0.connection.load();
555 let connection_result = (**guard)
556 .clone()
557 .await
558 .map_err(|e| e.clone_mostly("Reconnecting failed"));
559 reconnect_if_io_error!(self, connection_result, guard);
560 let result = connection_result?.send_packed_command(cmd).await;
561 reconnect_if_dropped!(self, &result, guard);
562 result
563 }
564
565 pub async fn send_packed_commands(
569 &mut self,
570 cmd: &crate::Pipeline,
571 offset: usize,
572 count: usize,
573 ) -> RedisResult<Vec<Value>> {
574 let guard = self.0.connection.load();
576 let connection_result = (**guard)
577 .clone()
578 .await
579 .map_err(|e| e.clone_mostly("Reconnecting failed"));
580 reconnect_if_io_error!(self, connection_result, guard);
581 let result = connection_result?
582 .send_packed_commands(cmd, offset, count)
583 .await;
584 reconnect_if_dropped!(self, &result, guard);
585 result
586 }
587
588 async fn update_subscription_tracker(
589 &self,
590 action: SubscriptionAction,
591 args: impl ToRedisArgs,
592 ) {
593 let Some(subscription_tracker) = &self.0.subscription_tracker else {
594 return;
595 };
596 let mut guard = subscription_tracker.lock().await;
597 guard.update_with_request(action, args.to_redis_args().into_iter());
598 }
599
600 pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
609 check_resp3!(self.0.client.connection_info.redis.protocol);
610 let mut cmd = cmd("SUBSCRIBE");
611 cmd.arg(&channel_name);
612 cmd.exec_async(self).await?;
613 self.update_subscription_tracker(SubscriptionAction::Subscribe, channel_name)
614 .await;
615
616 Ok(())
617 }
618
619 pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
623 check_resp3!(self.0.client.connection_info.redis.protocol);
624 let mut cmd = cmd("UNSUBSCRIBE");
625 cmd.arg(&channel_name);
626 cmd.exec_async(self).await?;
627 self.update_subscription_tracker(SubscriptionAction::Unsubscribe, channel_name)
628 .await;
629 Ok(())
630 }
631
632 pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
641 check_resp3!(self.0.client.connection_info.redis.protocol);
642 let mut cmd = cmd("PSUBSCRIBE");
643 cmd.arg(&channel_pattern);
644 cmd.exec_async(self).await?;
645 self.update_subscription_tracker(SubscriptionAction::PSubscribe, channel_pattern)
646 .await;
647 Ok(())
648 }
649
650 pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
654 check_resp3!(self.0.client.connection_info.redis.protocol);
655 let mut cmd = cmd("PUNSUBSCRIBE");
656 cmd.arg(&channel_pattern);
657 cmd.exec_async(self).await?;
658 self.update_subscription_tracker(SubscriptionAction::PUnsubscribe, channel_pattern)
659 .await;
660 Ok(())
661 }
662
663 #[cfg(feature = "cache-aio")]
665 #[cfg_attr(docsrs, doc(cfg(feature = "cache-aio")))]
666 pub fn get_cache_statistics(&self) -> Option<crate::caching::CacheStatistics> {
667 self.0.cache_manager.as_ref().map(|cm| cm.statistics())
668 }
669}
670
671impl ConnectionLike for ConnectionManager {
672 fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
673 (async move { self.send_packed_command(cmd).await }).boxed()
674 }
675
676 fn req_packed_commands<'a>(
677 &'a mut self,
678 cmd: &'a crate::Pipeline,
679 offset: usize,
680 count: usize,
681 ) -> RedisFuture<'a, Vec<Value>> {
682 (async move { self.send_packed_commands(cmd, offset, count).await }).boxed()
683 }
684
685 fn get_db(&self) -> i64 {
686 self.0.client.connection_info().redis.db
687 }
688}