use super::{ConnectionLike, Runtime};
use crate::aio::setup_connection;
use crate::cmd::Cmd;
use crate::connection::RedisConnectionInfo;
#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
use crate::parser::ValueCodec;
use crate::types::{RedisError, RedisFuture, RedisResult, Value};
use ::tokio::{
io::{AsyncRead, AsyncWrite},
sync::{mpsc, oneshot},
};
use futures_util::{
future::{Future, FutureExt},
ready,
sink::Sink,
stream::{self, Stream, StreamExt, TryStreamExt as _},
};
use pin_project_lite::pin_project;
use std::collections::VecDeque;
use std::fmt;
use std::fmt::Debug;
use std::io;
use std::pin::Pin;
use std::task::{self, Poll};
use std::time::Duration;
#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
use tokio_util::codec::Decoder;
type PipelineOutput = oneshot::Sender<RedisResult<Value>>;
enum ResponseAggregate {
SingleCommand,
Pipeline {
expected_response_count: usize,
current_response_count: usize,
buffer: Vec<Value>,
first_err: Option<RedisError>,
},
}
impl ResponseAggregate {
fn new(pipeline_response_count: Option<usize>) -> Self {
match pipeline_response_count {
Some(response_count) => ResponseAggregate::Pipeline {
expected_response_count: response_count,
current_response_count: 0,
buffer: Vec::new(),
first_err: None,
},
None => ResponseAggregate::SingleCommand,
}
}
}
struct InFlight {
output: PipelineOutput,
response_aggregate: ResponseAggregate,
}
struct PipelineMessage<S> {
input: S,
output: PipelineOutput,
pipeline_response_count: Option<usize>,
}
struct Pipeline<SinkItem>(mpsc::Sender<PipelineMessage<SinkItem>>);
impl<SinkItem> Clone for Pipeline<SinkItem> {
fn clone(&self) -> Self {
Pipeline(self.0.clone())
}
}
impl<SinkItem> Debug for Pipeline<SinkItem>
where
SinkItem: Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("Pipeline").field(&self.0).finish()
}
}
pin_project! {
struct PipelineSink<T> {
#[pin]
sink_stream: T,
in_flight: VecDeque<InFlight>,
error: Option<RedisError>,
}
}
impl<T> PipelineSink<T>
where
T: Stream<Item = RedisResult<Value>> + 'static,
{
fn new<SinkItem>(sink_stream: T) -> Self
where
T: Sink<SinkItem, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
{
PipelineSink {
sink_stream,
in_flight: VecDeque::new(),
error: None,
}
}
fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Result<(), ()>> {
loop {
if self.in_flight.is_empty() {
return Poll::Ready(Ok(()));
}
let item = match ready!(self.as_mut().project().sink_stream.poll_next(cx)) {
Some(result) => result,
None => return Poll::Ready(Err(())),
};
self.as_mut().send_result(item);
}
}
fn send_result(self: Pin<&mut Self>, result: RedisResult<Value>) {
let self_ = self.project();
{
let mut entry = match self_.in_flight.pop_front() {
Some(entry) => entry,
None => return,
};
match &mut entry.response_aggregate {
ResponseAggregate::SingleCommand => {
entry.output.send(result).ok();
}
ResponseAggregate::Pipeline {
expected_response_count,
current_response_count,
buffer,
first_err,
} => {
match result {
Ok(item) => {
buffer.push(item);
}
Err(err) => {
if first_err.is_none() {
*first_err = Some(err);
}
}
}
*current_response_count += 1;
if current_response_count < expected_response_count {
self_.in_flight.push_front(entry);
return;
}
let response = match first_err.take() {
Some(err) => Err(err),
None => Ok(Value::Bulk(std::mem::take(buffer))),
};
entry.output.send(response).ok();
}
}
}
}
}
impl<SinkItem, T> Sink<PipelineMessage<SinkItem>> for PipelineSink<T>
where
T: Sink<SinkItem, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
{
type Error = ();
fn poll_ready(
mut self: Pin<&mut Self>,
cx: &mut task::Context,
) -> Poll<Result<(), Self::Error>> {
match ready!(self.as_mut().project().sink_stream.poll_ready(cx)) {
Ok(()) => Ok(()).into(),
Err(err) => {
*self.project().error = Some(err);
Ok(()).into()
}
}
}
fn start_send(
mut self: Pin<&mut Self>,
PipelineMessage {
input,
output,
pipeline_response_count,
}: PipelineMessage<SinkItem>,
) -> Result<(), Self::Error> {
if output.is_closed() {
return Ok(());
}
let self_ = self.as_mut().project();
if let Some(err) = self_.error.take() {
let _ = output.send(Err(err));
return Err(());
}
match self_.sink_stream.start_send(input) {
Ok(()) => {
let response_aggregate = ResponseAggregate::new(pipeline_response_count);
let entry = InFlight {
output,
response_aggregate,
};
self_.in_flight.push_back(entry);
Ok(())
}
Err(err) => {
let _ = output.send(Err(err));
Err(())
}
}
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut task::Context,
) -> Poll<Result<(), Self::Error>> {
ready!(self
.as_mut()
.project()
.sink_stream
.poll_flush(cx)
.map_err(|err| {
self.as_mut().send_result(Err(err));
}))?;
self.poll_read(cx)
}
fn poll_close(
mut self: Pin<&mut Self>,
cx: &mut task::Context,
) -> Poll<Result<(), Self::Error>> {
if !self.in_flight.is_empty() {
ready!(self.as_mut().poll_flush(cx))?;
}
let this = self.as_mut().project();
this.sink_stream.poll_close(cx).map_err(|err| {
self.send_result(Err(err));
})
}
}
impl<SinkItem> Pipeline<SinkItem>
where
SinkItem: Send + 'static,
{
fn new<T>(sink_stream: T) -> (Self, impl Future<Output = ()>)
where
T: Sink<SinkItem, Error = RedisError> + Stream<Item = RedisResult<Value>> + 'static,
T: Send + 'static,
T::Item: Send,
T::Error: Send,
T::Error: ::std::fmt::Debug,
{
const BUFFER_SIZE: usize = 50;
let (sender, mut receiver) = mpsc::channel(BUFFER_SIZE);
let f = stream::poll_fn(move |cx| receiver.poll_recv(cx))
.map(Ok)
.forward(PipelineSink::new::<SinkItem>(sink_stream))
.map(|_| ());
(Pipeline(sender), f)
}
async fn send_single(
&mut self,
item: SinkItem,
timeout: Duration,
) -> Result<Value, Option<RedisError>> {
self.send_recv(item, None, timeout).await
}
async fn send_recv(
&mut self,
input: SinkItem,
pipeline_response_count: Option<usize>,
timeout: Duration,
) -> Result<Value, Option<RedisError>> {
let (sender, receiver) = oneshot::channel();
self.0
.send(PipelineMessage {
input,
pipeline_response_count,
output: sender,
})
.await
.map_err(|_| None)?;
match Runtime::locate().timeout(timeout, receiver).await {
Ok(Ok(result)) => result.map_err(Some),
Ok(Err(_)) => {
Err(None)
}
Err(elapsed) => Err(Some(elapsed.into())),
}
}
}
#[derive(Clone)]
pub struct MultiplexedConnection {
pipeline: Pipeline<Vec<u8>>,
db: i64,
response_timeout: Duration,
}
impl Debug for MultiplexedConnection {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("MultiplexedConnection")
.field("pipeline", &self.pipeline)
.field("db", &self.db)
.finish()
}
}
impl MultiplexedConnection {
pub async fn new<C>(
connection_info: &RedisConnectionInfo,
stream: C,
) -> RedisResult<(Self, impl Future<Output = ()>)>
where
C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
{
Self::new_with_response_timeout(connection_info, stream, std::time::Duration::MAX).await
}
pub async fn new_with_response_timeout<C>(
connection_info: &RedisConnectionInfo,
stream: C,
response_timeout: std::time::Duration,
) -> RedisResult<(Self, impl Future<Output = ()>)>
where
C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
{
fn boxed(
f: impl Future<Output = ()> + Send + 'static,
) -> Pin<Box<dyn Future<Output = ()> + Send>> {
Box::pin(f)
}
#[cfg(all(not(feature = "tokio-comp"), not(feature = "async-std-comp")))]
compile_error!("tokio-comp or async-std-comp features required for aio feature");
let codec = ValueCodec::default()
.framed(stream)
.and_then(|msg| async move { msg });
let (pipeline, driver) = Pipeline::new(codec);
let driver = boxed(driver);
let mut con = MultiplexedConnection {
pipeline,
db: connection_info.db,
response_timeout,
};
let driver = {
let auth = setup_connection(connection_info, &mut con);
futures_util::pin_mut!(auth);
match futures_util::future::select(auth, driver).await {
futures_util::future::Either::Left((result, driver)) => {
result?;
driver
}
futures_util::future::Either::Right(((), _)) => {
return Err(RedisError::from((
crate::ErrorKind::IoError,
"Multiplexed connection driver unexpectedly terminated",
)));
}
}
};
Ok((con, driver))
}
pub fn set_response_timeout(&mut self, timeout: std::time::Duration) {
self.response_timeout = timeout;
}
pub async fn send_packed_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
self.pipeline
.send_single(cmd.get_packed_command(), self.response_timeout)
.await
.map_err(|err| {
err.unwrap_or_else(|| RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe)))
})
}
pub async fn send_packed_commands(
&mut self,
cmd: &crate::Pipeline,
offset: usize,
count: usize,
) -> RedisResult<Vec<Value>> {
let value = self
.pipeline
.send_recv(
cmd.get_packed_pipeline(),
Some(offset + count),
self.response_timeout,
)
.await
.map_err(|err| {
err.unwrap_or_else(|| RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe)))
})?;
match value {
Value::Bulk(mut values) => {
values.drain(..offset);
Ok(values)
}
_ => Ok(vec![value]),
}
}
}
impl ConnectionLike for MultiplexedConnection {
fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
(async move { self.send_packed_command(cmd).await }).boxed()
}
fn req_packed_commands<'a>(
&'a mut self,
cmd: &'a crate::Pipeline,
offset: usize,
count: usize,
) -> RedisFuture<'a, Vec<Value>> {
(async move { self.send_packed_commands(cmd, offset, count).await }).boxed()
}
fn get_db(&self) -> i64 {
self.db
}
}