use std::future::Future;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::time::timeout;
use tokio_io_timeout::TimeoutStream;
use hyper::client::connect::{Connected, Connection};
use hyper::{service::Service, Uri};
mod stream;
use stream::TimeoutConnectorStream;
/// Type-erased, heap-allocated error that is `Send + Sync`; used as this
/// connector's `Service::Error`.
type BoxError = Box<dyn std::error::Error + Send + Sync>;
/// A connector wrapper that pairs an inner connector with optional
/// connect, read and write timeouts.
///
/// Every timeout is an `Option<Duration>`; `None` means that particular
/// timeout is not configured.
#[derive(Clone, Debug)]
pub struct TimeoutConnector<T> {
    /// The wrapped connector that is asked to establish connections.
    connector: T,
    /// Upper bound on how long establishing a connection may take.
    connect_timeout: Option<Duration>,
    /// Read timeout handed to every stream this connector produces.
    read_timeout: Option<Duration>,
    /// Write timeout handed to every stream this connector produces.
    write_timeout: Option<Duration>,
}
impl<T> TimeoutConnector<T>
where
    T: Service<Uri> + Send,
    T::Response: AsyncRead + AsyncWrite + Send + Unpin,
    T::Future: Send + 'static,
    T::Error: Into<BoxError>,
{
    /// Wraps `connector` in a `TimeoutConnector` with no timeouts configured.
    ///
    /// The connect, read and write timeouts all start out as `None`; use the
    /// `set_*_timeout` methods to configure them afterwards.
    pub fn new(connector: T) -> Self {
        Self {
            connect_timeout: None,
            read_timeout: None,
            write_timeout: None,
            connector,
        }
    }
}
/// `Service<Uri>` implementation: connects through the inner connector and
/// wraps the resulting I/O object in a pinned, boxed `TimeoutConnectorStream`.
impl<T> Service<Uri> for TimeoutConnector<T>
where
    T: Service<Uri> + Send,
    T::Response: AsyncRead + AsyncWrite + Connection + Send + Unpin,
    T::Future: Send + 'static,
    T::Error: Into<BoxError>,
{
    type Response = Pin<Box<TimeoutConnectorStream<T::Response>>>;
    type Error = BoxError;
    #[allow(clippy::type_complexity)]
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Reports the readiness of the inner connector, converting its error
    /// into a `BoxError`.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        match self.connector.poll_ready(cx) {
            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
            Poll::Ready(Err(err)) => Poll::Ready(Err(err.into())),
            Poll::Pending => Poll::Pending,
        }
    }

    /// Starts connecting to `dst` and returns a future resolving to the
    /// wrapped stream.
    ///
    /// When a connect timeout is configured, the inner connector's future is
    /// wrapped in `tokio::time::timeout`; if that expires, the error is an
    /// `io::Error` of kind `TimedOut`. Errors from the inner connector are
    /// converted into `BoxError`. The configured read and write timeouts are
    /// passed to the resulting `TimeoutConnectorStream`.
    fn call(&mut self, dst: Uri) -> Self::Future {
        // Copy the settings out of `self` so the returned future owns them.
        let (connect_limit, read_limit, write_limit) =
            (self.connect_timeout, self.read_timeout, self.write_timeout);
        // The inner connector is invoked right away; only awaiting is deferred.
        let pending = self.connector.call(dst);

        Box::pin(async move {
            let io = match connect_limit {
                Some(limit) => timeout(limit, pending)
                    .await
                    // Outer error: the timer elapsed before the connector finished.
                    .map_err(|elapsed| io::Error::new(io::ErrorKind::TimedOut, elapsed))?
                    // Inner error: the connector itself failed.
                    .map_err(Into::into)?,
                None => pending.await.map_err(Into::into)?,
            };

            let mut stream = TimeoutConnectorStream::new(TimeoutStream::new(io));
            stream.set_read_timeout(read_limit);
            stream.set_write_timeout(write_limit);
            Ok(Box::pin(stream))
        })
    }
}
/// Timeout configuration. These setters carry no trait bounds on `T`.
impl<T> TimeoutConnector<T> {
    /// Sets the timeout applied while establishing a connection.
    ///
    /// With `None`, the connection attempt is awaited without a time limit.
    #[inline]
    pub fn set_connect_timeout(&mut self, val: Option<Duration>) {
        self.connect_timeout = val;
    }

    /// Sets the read timeout that is passed to each stream created by this
    /// connector.
    #[inline]
    pub fn set_read_timeout(&mut self, val: Option<Duration>) {
        self.read_timeout = val;
    }

    /// Sets the write timeout that is passed to each stream created by this
    /// connector.
    #[inline]
    pub fn set_write_timeout(&mut self, val: Option<Duration>) {
        self.write_timeout = val;
    }
}
impl<T> Connection for TimeoutConnector<T>
where
T: AsyncRead + AsyncWrite + Connection + Service<Uri> + Send + Unpin,
T::Response: AsyncRead + AsyncWrite + Send + Unpin,
T::Future: Send + 'static,
T::Error: Into<BoxError>,
{
fn connected(&self) -> Connected {
self.connector.connected()
}
}
#[cfg(test)]
mod tests {
    use std::error::Error;
    use std::io;
    use std::time::Duration;

    use hyper::client::HttpConnector;
    use hyper::Client;

    use super::TimeoutConnector;

    /// Panics unless `outcome` is an error whose source is an `io::Error`
    /// of kind `TimedOut`.
    fn expect_timed_out<R>(outcome: Result<R, hyper::Error>) {
        let err = match outcome {
            Ok(_) => panic!("Expected a timeout"),
            Err(err) => err,
        };
        match err.source().unwrap().downcast_ref::<io::Error>() {
            Some(io_err) => assert_eq!(io_err.kind(), io::ErrorKind::TimedOut),
            None => panic!("Expected timeout error"),
        }
    }

    /// A 1 ms connect timeout should surface as a `TimedOut` I/O error.
    // NOTE(review): presumably 10.255.255.1 is expected not to accept a
    // connection within 1 ms; the test depends on the network environment.
    #[tokio::test]
    async fn test_timeout_connector() {
        let url = "http://10.255.255.1".parse().unwrap();

        let mut connector = TimeoutConnector::new(HttpConnector::new());
        connector.set_connect_timeout(Some(Duration::from_millis(1)));

        let client = Client::builder().build::<_, hyper::Body>(connector);
        expect_timed_out(client.get(url).await);
    }

    /// A 1 ms read timeout should surface as a `TimedOut` I/O error.
    // NOTE(review): this issues a real request to example.com, so it needs
    // network access and assumes the response takes longer than 1 ms.
    #[tokio::test]
    async fn test_read_timeout() {
        let url = "http://example.com".parse().unwrap();

        let mut connector = TimeoutConnector::new(HttpConnector::new());
        connector.set_read_timeout(Some(Duration::from_millis(1)));

        let client = Client::builder().build::<_, hyper::Body>(connector);
        expect_timed_out(client.get(url).await);
    }
}