redis/aio/runtime.rs

1use std::{io, sync::Arc, time::Duration};
2
3use futures_util::Future;
4
5#[cfg(feature = "async-std-comp")]
6use super::async_std as crate_async_std;
7#[cfg(feature = "tokio-comp")]
8use super::tokio as crate_tokio;
9use super::RedisRuntime;
10use crate::types::RedisError;
11
/// The async executor that spawning, timeouts and sleeps are delegated to.
///
/// Which variants exist is decided at compile time by the `tokio-comp` and
/// `async-std-comp` cargo features.
#[derive(Debug, Clone)]
pub(crate) enum Runtime {
    /// The tokio runtime (requires the `tokio-comp` feature).
    #[cfg(feature = "tokio-comp")]
    Tokio,
    /// The async-std runtime (requires the `async-std-comp` feature).
    #[cfg(feature = "async-std-comp")]
    AsyncStd,
}
19
/// Runtime-specific join handle for a spawned task whose output is `()`.
pub(crate) enum TaskHandle {
    /// Handle for a task spawned on tokio.
    #[cfg(feature = "tokio-comp")]
    Tokio(tokio::task::JoinHandle<()>),
    /// Handle for a task spawned on async-std.
    #[cfg(feature = "async-std-comp")]
    AsyncStd(async_std::task::JoinHandle<()>),
}
26
27pub(crate) struct HandleContainer(Option<TaskHandle>);
28
29impl HandleContainer {
30    pub(crate) fn new(handle: TaskHandle) -> Self {
31        Self(Some(handle))
32    }
33}
34
35impl Drop for HandleContainer {
36    fn drop(&mut self) {
37        match self.0.take() {
38            None => {}
39            #[cfg(feature = "tokio-comp")]
40            Some(TaskHandle::Tokio(handle)) => handle.abort(),
41            #[cfg(feature = "async-std-comp")]
42            Some(TaskHandle::AsyncStd(handle)) => {
43                // schedule for cancellation without waiting for result.
44                Runtime::locate().spawn(async move { handle.cancel().await.unwrap_or_default() });
45            }
46        }
47    }
48}
49
50#[derive(Clone)]
51// we allow dead code here because the container isn't used directly, only in the derived drop.
52#[allow(dead_code)]
53pub(crate) struct SharedHandleContainer(Arc<HandleContainer>);
54
55impl SharedHandleContainer {
56    pub(crate) fn new(handle: TaskHandle) -> Self {
57        Self(Arc::new(HandleContainer::new(handle)))
58    }
59}
60
61impl Runtime {
62    pub(crate) fn locate() -> Self {
63        #[cfg(all(feature = "tokio-comp", not(feature = "async-std-comp")))]
64        {
65            Runtime::Tokio
66        }
67
68        #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))]
69        {
70            Runtime::AsyncStd
71        }
72
73        #[cfg(all(feature = "tokio-comp", feature = "async-std-comp"))]
74        {
75            if ::tokio::runtime::Handle::try_current().is_ok() {
76                Runtime::Tokio
77            } else {
78                Runtime::AsyncStd
79            }
80        }
81
82        #[cfg(all(not(feature = "tokio-comp"), not(feature = "async-std-comp")))]
83        {
84            compile_error!("tokio-comp or async-std-comp features required for aio feature")
85        }
86    }
87
88    #[allow(dead_code)]
89    pub(crate) fn spawn(&self, f: impl Future<Output = ()> + Send + 'static) -> TaskHandle {
90        match self {
91            #[cfg(feature = "tokio-comp")]
92            Runtime::Tokio => crate_tokio::Tokio::spawn(f),
93            #[cfg(feature = "async-std-comp")]
94            Runtime::AsyncStd => crate_async_std::AsyncStd::spawn(f),
95        }
96    }
97
98    pub(crate) async fn timeout<F: Future>(
99        &self,
100        duration: Duration,
101        future: F,
102    ) -> Result<F::Output, Elapsed> {
103        match self {
104            #[cfg(feature = "tokio-comp")]
105            Runtime::Tokio => tokio::time::timeout(duration, future)
106                .await
107                .map_err(|_| Elapsed(())),
108            #[cfg(feature = "async-std-comp")]
109            Runtime::AsyncStd => async_std::future::timeout(duration, future)
110                .await
111                .map_err(|_| Elapsed(())),
112        }
113    }
114
115    #[cfg(any(feature = "connection-manager", feature = "cluster-async"))]
116    pub(crate) async fn sleep(&self, duration: Duration) {
117        match self {
118            #[cfg(feature = "tokio-comp")]
119            Runtime::Tokio => {
120                tokio::time::sleep(duration).await;
121            }
122            #[cfg(feature = "async-std-comp")]
123            Runtime::AsyncStd => {
124                async_std::task::sleep(duration).await;
125            }
126        }
127    }
128
129    #[cfg(feature = "cluster-async")]
130    pub(crate) async fn locate_and_sleep(duration: Duration) {
131        Self::locate().sleep(duration).await
132    }
133}
134
/// Returned by `Runtime::timeout` when the deadline passes before the wrapped
/// future finishes. Converts into a `TimedOut` I/O-based `RedisError`.
#[derive(Debug)]
pub(crate) struct Elapsed(());
137
138impl From<Elapsed> for RedisError {
139    fn from(_: Elapsed) -> Self {
140        io::Error::from(io::ErrorKind::TimedOut).into()
141    }
142}