1#![deny(unsafe_code)]
136#![warn(
137 clippy::all,
138 clippy::await_holding_lock,
139 clippy::dbg_macro,
140 clippy::debug_assert_with_mut_call,
141 clippy::doc_markdown,
142 clippy::empty_enum,
143 clippy::enum_glob_use,
144 clippy::exit,
145 clippy::explicit_into_iter_loop,
146 clippy::filter_map_next,
147 clippy::fn_params_excessive_bools,
148 clippy::if_let_mutex,
149 clippy::imprecise_flops,
150 clippy::inefficient_to_string,
151 clippy::large_types_passed_by_value,
152 clippy::let_unit_value,
153 clippy::linkedlist,
154 clippy::lossy_float_literal,
155 clippy::macro_use_imports,
156 clippy::map_err_ignore,
157 clippy::map_flatten,
158 clippy::map_unwrap_or,
159 clippy::match_on_vec_items,
160 clippy::match_same_arms,
161 clippy::match_wildcard_for_single_variants,
162 clippy::mem_forget,
163 clippy::needless_borrow,
164 clippy::needless_continue,
165 clippy::option_option,
166 clippy::ref_option_ref,
167 clippy::rest_pat_in_fully_bound_structs,
168 clippy::string_add_assign,
169 clippy::string_add,
170 clippy::string_to_string,
171 clippy::suboptimal_flops,
172 clippy::todo,
173 clippy::unimplemented,
174 clippy::unnested_or_patterns,
175 clippy::unused_self,
176 clippy::verbose_file_reads,
177 unexpected_cfgs,
178 future_incompatible,
179 nonstandard_style,
180 rust_2018_idioms
181)]
182#![warn(missing_docs)]
184#![deny(rustdoc::broken_intra_doc_links)]
185
186use backoff_strategies::{
187 BackoffStrategy, ExponentialBackoff, FixedBackoff, LinearBackoff, NoBackoff,
188};
189use pin_project_lite::pin_project;
190use std::time::Duration;
191use std::{
192 fmt,
193 future::Future,
194 pin::Pin,
195 task::{Context, Poll},
196};
197
198mod on_retry;
199
200pub mod backoff_strategies;
201
202pub use on_retry::{NoOnRetry, OnRetry};
203
204pub fn retry_fn<F>(f: F) -> RetryFn<F> {
206 RetryFn { f }
207}
208
/// Builder returned by [`retry_fn`].
///
/// It only stores the closure that creates one future per attempt. Nothing
/// runs until it is turned into a [`RetryFuture`] (via [`RetryFn::retries`]
/// or [`RetryFn::with_config`]) and that future is awaited.
#[derive(Debug)]
#[must_use = "`RetryFn` does nothing on its own; configure it with `retries` or `with_config` and `.await` the result"]
pub struct RetryFn<F> {
    // Closure invoked to create a fresh future for every attempt.
    f: F,
}
214
215impl<F, Fut, T, E> RetryFn<F>
216where
217 F: FnMut() -> Fut,
218 Fut: Future<Output = Result<T, E>>,
219{
220 pub fn retries(self, max_retries: u32) -> RetryFuture<F, Fut, NoBackoff, NoOnRetry> {
222 self.with_config(RetryFutureConfig::new(max_retries))
223 }
224
225 pub fn with_config<BackoffT, OnRetryT>(
227 self,
228 config: RetryFutureConfig<BackoffT, OnRetryT>,
229 ) -> RetryFuture<F, Fut, BackoffT, OnRetryT> {
230 RetryFuture {
231 make_future: self.f,
232 attempts_remaining: config.max_retries,
233 state: RetryState::NotStarted,
234 attempt: 0,
235 config,
236 }
237 }
238}
239
240pin_project! {
241 pub struct RetryFuture<MakeFutureT, FutureT, BackoffT, OnRetryT> {
245 make_future: MakeFutureT,
246 attempts_remaining: u32,
247 #[pin]
248 state: RetryState<FutureT>,
249 attempt: u32,
250 config: RetryFutureConfig<BackoffT, OnRetryT>,
251 }
252}
253
254impl<MakeFutureT, FutureT, BackoffT, T, E, OnRetryT>
255 RetryFuture<MakeFutureT, FutureT, BackoffT, OnRetryT>
256where
257 MakeFutureT: FnMut() -> FutureT,
258 FutureT: Future<Output = Result<T, E>>,
259{
260 #[inline]
262 pub fn max_delay(mut self, delay: Duration) -> Self {
263 self.config = self.config.max_delay(delay);
264 self
265 }
266
267 #[inline]
271 pub fn no_backoff(self) -> RetryFuture<MakeFutureT, FutureT, NoBackoff, OnRetryT> {
272 self.custom_backoff(NoBackoff)
273 }
274
275 #[inline]
279 pub fn exponential_backoff(
280 self,
281 initial_delay: Duration,
282 ) -> RetryFuture<MakeFutureT, FutureT, ExponentialBackoff, OnRetryT> {
283 self.custom_backoff(ExponentialBackoff {
284 delay: initial_delay,
285 })
286 }
287
288 #[inline]
292 pub fn fixed_backoff(
293 self,
294 delay: Duration,
295 ) -> RetryFuture<MakeFutureT, FutureT, FixedBackoff, OnRetryT> {
296 self.custom_backoff(FixedBackoff { delay })
297 }
298
299 #[inline]
303 pub fn linear_backoff(
304 self,
305 delay: Duration,
306 ) -> RetryFuture<MakeFutureT, FutureT, LinearBackoff, OnRetryT> {
307 self.custom_backoff(LinearBackoff { delay })
308 }
309
310 #[inline]
361 pub fn custom_backoff<B>(
362 self,
363 backoff_strategy: B,
364 ) -> RetryFuture<MakeFutureT, FutureT, B, OnRetryT>
365 where
366 for<'a> B: BackoffStrategy<'a, E>,
367 {
368 RetryFuture {
369 make_future: self.make_future,
370 attempts_remaining: self.attempts_remaining,
371 state: self.state,
372 attempt: self.attempt,
373 config: self.config.custom_backoff(backoff_strategy),
374 }
375 }
376
377 #[inline]
416 pub fn on_retry<F, OnRetryFut>(self, f: F) -> RetryFuture<MakeFutureT, FutureT, BackoffT, F>
417 where
418 F: Fn(u32, Option<Duration>, &E) -> OnRetryFut,
419 {
420 RetryFuture {
421 make_future: self.make_future,
422 attempts_remaining: self.attempts_remaining,
423 state: self.state,
424 attempt: self.attempt,
425 config: self.config.on_retry(f),
426 }
427 }
428}
429
/// Reusable retry settings: the retry budget, the backoff strategy, an optional
/// cap on each delay and an optional retry hook.
///
/// A config is applied to an operation with [`RetryFn::with_config`]. When both
/// type parameters are `Copy`, so is the config, which lets one value drive
/// several operations.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct RetryFutureConfig<BackoffT, OnRetryT> {
    // Number of additional attempts allowed after the first one fails.
    max_retries: u32,
    // Strategy consulted after each failure to decide the next delay.
    backoff_strategy: BackoffT,
    // Upper bound applied to every delay produced by `backoff_strategy`.
    max_delay: Option<Duration>,
    // Hook notified about failures; `None` when no hook was registered.
    on_retry: Option<OnRetryT>,
}
440
441impl RetryFutureConfig<NoBackoff, NoOnRetry> {
442 pub fn new(max_retries: u32) -> Self {
444 Self {
445 backoff_strategy: NoBackoff,
446 max_delay: None,
447 on_retry: None::<NoOnRetry>,
448 max_retries,
449 }
450 }
451}
452
453impl<BackoffT, OnRetryT> RetryFutureConfig<BackoffT, OnRetryT> {
454 #[inline]
456 pub fn max_delay(mut self, delay: Duration) -> Self {
457 self.max_delay = Some(delay);
458 self
459 }
460
461 #[inline]
465 pub fn no_backoff(self) -> RetryFutureConfig<NoBackoff, OnRetryT> {
466 self.custom_backoff(NoBackoff)
467 }
468
469 #[inline]
473 pub fn exponential_backoff(
474 self,
475 initial_delay: Duration,
476 ) -> RetryFutureConfig<ExponentialBackoff, OnRetryT> {
477 self.custom_backoff(ExponentialBackoff {
478 delay: initial_delay,
479 })
480 }
481
482 #[inline]
486 pub fn fixed_backoff(self, delay: Duration) -> RetryFutureConfig<FixedBackoff, OnRetryT> {
487 self.custom_backoff(FixedBackoff { delay })
488 }
489
490 #[inline]
494 pub fn linear_backoff(self, delay: Duration) -> RetryFutureConfig<LinearBackoff, OnRetryT> {
495 self.custom_backoff(LinearBackoff { delay })
496 }
497
498 #[inline]
502 pub fn custom_backoff<B>(self, backoff_strategy: B) -> RetryFutureConfig<B, OnRetryT> {
503 RetryFutureConfig {
504 backoff_strategy,
505 max_delay: self.max_delay,
506 max_retries: self.max_retries,
507 on_retry: self.on_retry,
508 }
509 }
510
511 #[inline]
515 pub fn on_retry<F>(self, f: F) -> RetryFutureConfig<BackoffT, F> {
516 RetryFutureConfig {
517 backoff_strategy: self.backoff_strategy,
518 max_delay: self.max_delay,
519 max_retries: self.max_retries,
520 on_retry: Some(f),
521 }
522 }
523}
524
525impl<BackoffT, OnRetryT> fmt::Debug for RetryFutureConfig<BackoffT, OnRetryT>
526where
527 BackoffT: fmt::Debug,
528{
529 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
530 f.debug_struct("RetryFutureConfig")
531 .field("backoff_strategy", &self.backoff_strategy)
532 .field("max_delay", &self.max_delay)
533 .field("max_retries", &self.max_retries)
534 .field(
535 "on_retry",
536 &format_args!("<{}>", std::any::type_name::<OnRetryT>()),
537 )
538 .finish()
539 }
540}
541
pin_project! {
    // State machine driven by `RetryFuture::poll`. It is `NotStarted` until the
    // first poll, `WaitingForFuture` while an attempt is in flight, and
    // `TimerActive` while sleeping before the next attempt.
    #[project = RetryStateProj]
    // NOTE(review): presumably allowed because `tokio::time::Sleep` (or `F`) can be
    // much larger than the other variants, and boxing would allocate on every
    // retry. Confirm this is the intended trade-off.
    #[allow(clippy::large_enum_variant)]
    enum RetryState<F> {
        // Initial state: the first attempt's future has not been created yet.
        NotStarted,
        // An attempt's future is being polled.
        WaitingForFuture { #[pin] future: F },
        // Waiting for the backoff delay to elapse before the next attempt.
        TimerActive { #[pin] sleep: tokio::time::Sleep },
    }
}
551
/// Polling runs the operation, and after each failure either retries after a
/// backoff delay or gives up.
///
/// The output is the first `Ok` produced by the operation. It is the error of
/// the last attempt if the retry budget is exhausted or the backoff strategy
/// returns [`RetryPolicy::Break`].
impl<F, Fut, B, T, E, OnRetryT> Future for RetryFuture<F, Fut, B, OnRetryT>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
    for<'a> B: BackoffStrategy<'a, E>,
    for<'a> <B as BackoffStrategy<'a, E>>::Output: Into<RetryPolicy>,
    OnRetryT: OnRetry<E>,
{
    type Output = Result<T, E>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Keep advancing the state machine until an inner future is pending or a
        // final result is available. After every transition the loop polls the new
        // state right away, so its waker gets registered before returning Pending.
        loop {
            let this = self.as_mut().project();

            let new_state = match this.state.project() {
                // First poll: create the future for the first attempt.
                RetryStateProj::NotStarted => RetryState::WaitingForFuture {
                    future: (this.make_future)(),
                },

                // Backoff delay elapsed: create the future for the next attempt.
                RetryStateProj::TimerActive { sleep } => match sleep.poll(cx) {
                    Poll::Ready(()) => RetryState::WaitingForFuture {
                        future: (this.make_future)(),
                    },
                    Poll::Pending => return Poll::Pending,
                },

                RetryStateProj::WaitingForFuture { future } => match future.poll(cx) {
                    Poll::Pending => return Poll::Pending,
                    Poll::Ready(Ok(value)) => {
                        return Poll::Ready(Ok(value));
                    }
                    Poll::Ready(Err(error)) => {
                        if *this.attempts_remaining == 0 {
                            // Retry budget exhausted: notify the hook with no next delay
                            // and give up with this error.
                            // NOTE(review): `attempt` is not incremented on this path, so the
                            // hook sees the same number as its previous notification. The
                            // `Break` path below reports an incremented number. Confirm
                            // this asymmetry is intended.
                            if let Some(on_retry) = &mut this.config.on_retry {
                                // The hook's future is spawned and its handle dropped, so it is
                                // not awaited and may still be running after this future
                                // resolves. `tokio::spawn` must be called inside a Tokio runtime.
                                tokio::spawn(on_retry.on_retry(*this.attempt, None, &error));
                            }

                            return Poll::Ready(Err(error));
                        } else {
                            *this.attempt += 1;
                            *this.attempts_remaining -= 1;

                            // Ask the strategy what to do for this (1-based) retry number.
                            let delay: RetryPolicy = this
                                .config
                                .backoff_strategy
                                .delay(*this.attempt, &error)
                                .into();
                            let mut delay_duration = match delay {
                                RetryPolicy::Delay(duration) => duration,
                                // The strategy asked to stop: notify the hook with no next
                                // delay and give up with this error.
                                RetryPolicy::Break => {
                                    if let Some(on_retry) = &mut this.config.on_retry {
                                        tokio::spawn(on_retry.on_retry(
                                            *this.attempt,
                                            None,
                                            &error,
                                        ));
                                    }

                                    return Poll::Ready(Err(error));
                                }
                            };

                            // Apply the optional upper bound on any single delay.
                            if let Some(max_delay) = this.config.max_delay {
                                delay_duration = delay_duration.min(max_delay);
                            }

                            // Report the (capped) delay that will actually be used.
                            if let Some(on_retry) = &mut this.config.on_retry {
                                tokio::spawn(on_retry.on_retry(
                                    *this.attempt,
                                    Some(delay_duration),
                                    &error,
                                ));
                            }

                            let sleep = tokio::time::sleep(delay_duration);

                            RetryState::TimerActive { sleep }
                        }
                    }
                },
            };

            // Replace the pinned state in place; `Pin::set` drops the previous
            // state (finished attempt or elapsed timer).
            self.as_mut().project().state.set(new_state);
        }
    }
}
638
/// What should happen after a failed attempt, as decided by a
/// [`BackoffStrategy`].
///
/// A strategy may also return a plain [`Duration`], which converts into
/// [`RetryPolicy::Delay`].
// `Copy` and `Hash` are derived because the only payload is a `Duration`,
// which is itself `Copy` and `Hash`. Adding them is backward compatible.
#[derive(Debug, Eq, PartialEq, Clone, Copy, Hash)]
pub enum RetryPolicy {
    /// Retry after waiting for this duration (capped by `max_delay` if one is
    /// configured).
    Delay(Duration),

    /// Stop retrying and resolve with the error of the last attempt.
    Break,
}
650
651impl From<Duration> for RetryPolicy {
652 fn from(duration: Duration) -> Self {
653 RetryPolicy::Delay(duration)
654 }
655}
656
// Unit tests for the retry builder, the config and the `RetryFuture` state machine.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::{convert::Infallible, time::Instant};

    // An operation that succeeds immediately resolves with its value.
    #[tokio::test]
    async fn succeed() {
        retry_fn(|| async { Ok::<_, Infallible>(true) })
            .retries(10)
            .await
            .unwrap();
    }

    // `retries(10)` runs the operation once plus 10 retries, i.e. 11 times in total.
    #[tokio::test]
    async fn retrying_correct_amount_of_times() {
        let counter = AtomicUsize::new(0);

        let err = retry_fn(|| async {
            counter.fetch_add(1, Ordering::SeqCst);
            Err::<Infallible, _>("error")
        })
        .retries(10)
        .await
        .unwrap_err();

        assert_eq!(err, "error");
        assert_eq!(counter.load(Ordering::Relaxed), 11);
    }

    // With zero retries the operation runs exactly once.
    #[tokio::test]
    async fn retry_0_times() {
        let counter = AtomicUsize::new(0);

        retry_fn(|| async {
            counter.fetch_add(1, Ordering::SeqCst);
            Err::<Infallible, _>("error")
        })
        .retries(0)
        .await
        .unwrap_err();

        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }

    // A fixed backoff should take at least as long as no backoff for the same
    // number of failing attempts.
    #[tokio::test]
    async fn the_backoff_strategy_gets_used() {
        async fn make_future() -> Result<Infallible, &'static str> {
            Err("foo")
        }

        let start = Instant::now();
        retry_fn(make_future)
            .retries(10)
            .no_backoff()
            .await
            .unwrap_err();
        let time_with_none = start.elapsed();

        let start = Instant::now();
        retry_fn(make_future)
            .retries(10)
            .fixed_backoff(Duration::from_millis(10))
            .await
            .unwrap_err();
        let time_with_fixed = start.elapsed();

        assert!(time_with_fixed >= time_with_none);
    }

    // Compile-time check: the future is `Send` when the operation's future is.
    #[test]
    fn is_send() {
        fn assert_send<T: Send>(_: T) {}
        async fn some_future() -> Result<(), Infallible> {
            Ok(())
        }
        assert_send(retry_fn(some_future).retries(10));
    }

    // A custom strategy returning `RetryPolicy::Break` stops retrying early.
    // It breaks on the third retry, so the operation runs 3 times, well below
    // the panic threshold of 8.
    #[tokio::test]
    async fn stop_retrying() {
        let mut n = 0;
        let make_future = || {
            n += 1;
            if n == 8 {
                panic!("retried too many times");
            }
            async { Err::<Infallible, _>("foo") }
        };

        let error = retry_fn(make_future)
            .retries(10)
            .custom_backoff(|n, _: &&'static str| {
                if n >= 3 {
                    RetryPolicy::Break
                } else {
                    RetryPolicy::Delay(Duration::from_nanos(10))
                }
            })
            .await
            .unwrap_err();

        assert_eq!(error, "foo");
    }

    // A custom strategy may return a bare `Duration` instead of a `RetryPolicy`.
    #[tokio::test]
    async fn custom_returning_duration() {
        retry_fn(|| async { Ok::<_, Infallible>(true) })
            .retries(10)
            .custom_backoff(|_, _: &Infallible| Duration::from_nanos(10))
            .await
            .unwrap();
    }

    // The retry hook sees each retry's attempt number, next delay and error.
    // NOTE(review): `poll` also spawns a final hook call (with a `None` delay) when
    // retries run out. The exact count of 10 seems to rely on that last spawned
    // task not having run yet on the test runtime, which may be fragile. Confirm.
    #[tokio::test]
    async fn retry_hook_succeed() {
        use std::sync::Arc;
        use tokio::sync::Mutex;

        let errors = Arc::new(Mutex::new(Vec::new()));

        retry_fn(|| async { Err::<Infallible, String>("error".to_string()) })
            .retries(10)
            .on_retry(|attempt, next_delay, error: &String| {
                let errors = Arc::clone(&errors);
                let error = error.clone();
                async move {
                    errors.lock().await.push((attempt, next_delay, error));
                }
            })
            .await
            .unwrap_err();

        let errors = errors.lock().await;
        assert_eq!(errors.len(), 10);
        for n in 1_u32..=10 {
            assert_eq!(
                &errors[(n - 1) as usize],
                &(n, Some(Duration::new(0, 0)), "error".to_string())
            );
        }
    }

    // A `Copy` config can drive several operations. The hook is not called on
    // success.
    // NOTE(review): like `retry_hook_succeed`, the count of 10 appears to rely on
    // the final spawned hook call not having run yet. Confirm.
    #[tokio::test]
    async fn reusing_the_config() {
        let counter = Arc::new(AtomicUsize::new(0));

        let config = RetryFutureConfig::new(10)
            .linear_backoff(Duration::from_millis(10))
            .on_retry(|_, _, _: &&'static str| {
                let counter = Arc::clone(&counter);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                }
            });

        let ok_value = retry_fn(|| async { Ok::<_, &str>(true) })
            .with_config(config)
            .await
            .unwrap();
        assert!(ok_value);
        assert_eq!(counter.load(Ordering::SeqCst), 0);

        let err_value = retry_fn(|| async { Err::<(), _>("foo") })
            .with_config(config)
            .await
            .unwrap_err();
        assert_eq!(err_value, "foo");
        assert_eq!(counter.load(Ordering::SeqCst), 10);
    }

    // User-defined `BackoffStrategy` and `OnRetry` impls plug into a `Clone`
    // config. The strategy wraps a built-in one.
    #[tokio::test]
    async fn custom_backoff_wrapping_another_strategy() {
        #[derive(Clone)]
        struct MyBackoffStrategy {
            inner: ExponentialBackoff,
        }

        impl<'a> BackoffStrategy<'a, std::io::Error> for MyBackoffStrategy {
            type Output = RetryPolicy;

            fn delay(&mut self, attempt: u32, error: &'a std::io::Error) -> Self::Output {
                if error.kind() == std::io::ErrorKind::NotFound {
                    RetryPolicy::Break
                } else {
                    RetryPolicy::Delay(self.inner.delay(attempt, error))
                }
            }
        }

        #[derive(Clone)]
        struct MyOnRetry;

        impl OnRetry<std::io::Error> for MyOnRetry {
            type Future = futures::future::BoxFuture<'static, ()>;

            fn on_retry(
                &mut self,
                attempt: u32,
                next_delay: Option<Duration>,
                previous_error: &std::io::Error,
            ) -> Self::Future {
                let previous_error = previous_error.to_string();
                Box::pin(async move {
                    println!("{} {:?} {}", attempt, next_delay, previous_error);
                })
            }
        }

        let config: RetryFutureConfig<MyBackoffStrategy, MyOnRetry> = RetryFutureConfig::new(10)
            .custom_backoff(MyBackoffStrategy {
                inner: ExponentialBackoff::new(Duration::from_millis(10)),
            })
            .on_retry(MyOnRetry);

        retry_fn(|| async { Ok::<_, std::io::Error>(true) })
            .with_config(config.clone())
            .await
            .unwrap();

        retry_fn(|| async { Ok::<_, std::io::Error>(true) })
            .with_config(config)
            .await
            .unwrap();
    }

    // Compile-only check: closure parameter types for `on_retry` are inferred
    // from the operation's error type. The async block is never polled.
    #[tokio::test]
    async fn inference_works() {
        std::mem::drop(async {
            let _ = retry_fn(|| async { Result::<_, Infallible>::Ok(()) })
                .retries(0)
                .on_retry(|_, _, _| async {})
                .await;
        });
    }
}