use std::{
io::{self, Read},
str,
};
use crate::types::{
ErrorKind, InternalValue, RedisError, RedisResult, ServerError, ServerErrorKind, Value,
};
use combine::{
any,
error::StreamError,
opaque,
parser::{
byte::{crlf, take_until_bytes},
combinator::{any_send_sync_partial_state, AnySendSyncPartialState},
range::{recognize, take},
},
stream::{PointerOffset, RangeStream, StreamErrorFor},
ParseError, Parser as _,
};
const MAX_RECURSE_DEPTH: usize = 100;
fn value<'a, I>(
count: Option<usize>,
) -> impl combine::Parser<I, Output = InternalValue, PartialState = AnySendSyncPartialState>
where
I: RangeStream<Token = u8, Range = &'a [u8]>,
I::Error: combine::ParseError<u8, &'a [u8], I::Position>,
{
let count = count.unwrap_or(1);
opaque!(any_send_sync_partial_state(
any()
.then_partial(move |&mut b| {
if b == b'*' && count > MAX_RECURSE_DEPTH {
combine::unexpected_any("Maximum recursion depth exceeded").left()
} else {
combine::value(b).right()
}
})
.then_partial(move |&mut b| {
let line = || {
recognize(take_until_bytes(&b"\r\n"[..]).with(take(2).map(|_| ()))).and_then(
|line: &[u8]| {
str::from_utf8(&line[..line.len() - 2])
.map_err(StreamErrorFor::<I>::other)
},
)
};
let status = || {
line().map(|line| {
if line == "OK" {
InternalValue::Okay
} else {
InternalValue::Status(line.into())
}
})
};
let int = || {
line().and_then(|line| match line.trim().parse::<i64>() {
Err(_) => Err(StreamErrorFor::<I>::message_static_message(
"Expected integer, got garbage",
)),
Ok(value) => Ok(value),
})
};
let data = || {
int().then_partial(move |size| {
if *size < 0 {
combine::produce(|| InternalValue::Nil).left()
} else {
take(*size as usize)
.map(|bs: &[u8]| InternalValue::Data(bs.to_vec()))
.skip(crlf())
.right()
}
})
};
let bulk = || {
int().then_partial(move |&mut length| {
if length < 0 {
combine::produce(|| InternalValue::Nil).left()
} else {
let length = length as usize;
combine::count_min_max(length, length, value(Some(count + 1)))
.map(InternalValue::Bulk)
.right()
}
})
};
let error = || {
line().map(|line: &str| {
let mut pieces = line.splitn(2, ' ');
let kind = match pieces.next().unwrap() {
"ERR" => ServerErrorKind::ResponseError,
"EXECABORT" => ServerErrorKind::ExecAbortError,
"LOADING" => ServerErrorKind::BusyLoadingError,
"NOSCRIPT" => ServerErrorKind::NoScriptError,
"MOVED" => ServerErrorKind::Moved,
"ASK" => ServerErrorKind::Ask,
"TRYAGAIN" => ServerErrorKind::TryAgain,
"CLUSTERDOWN" => ServerErrorKind::ClusterDown,
"CROSSSLOT" => ServerErrorKind::CrossSlot,
"MASTERDOWN" => ServerErrorKind::MasterDown,
"READONLY" => ServerErrorKind::ReadOnly,
"NOTBUSY" => ServerErrorKind::NotBusy,
code => {
return ServerError::ExtensionError {
code: code.to_string(),
detail: pieces.next().map(|str| str.to_string()),
}
}
};
let detail = pieces.next().map(|str| str.to_string());
ServerError::KnownError { kind, detail }
})
};
combine::dispatch!(b;
b'+' => status(),
b':' => int().map(InternalValue::Int),
b'$' => data(),
b'*' => bulk(),
b'-' => error().map(InternalValue::ServerError),
b => combine::unexpected_any(combine::error::Token(b))
)
})
))
}
#[cfg(feature = "aio")]
mod aio_support {
use super::*;
use bytes::{Buf, BytesMut};
use tokio::io::AsyncRead;
use tokio_util::codec::{Decoder, Encoder};
#[derive(Default)]
pub struct ValueCodec {
state: AnySendSyncPartialState,
}
impl ValueCodec {
fn decode_stream(
&mut self,
bytes: &mut BytesMut,
eof: bool,
) -> RedisResult<Option<RedisResult<Value>>> {
let (opt, removed_len) = {
let buffer = &bytes[..];
let mut stream =
combine::easy::Stream(combine::stream::MaybePartialStream(buffer, !eof));
match combine::stream::decode_tokio(value(None), &mut stream, &mut self.state) {
Ok(x) => x,
Err(err) => {
let err = err
.map_position(|pos| pos.translate_position(buffer))
.map_range(|range| format!("{range:?}"))
.to_string();
return Err(RedisError::from((
ErrorKind::ParseError,
"parse error",
err,
)));
}
}
};
bytes.advance(removed_len);
match opt {
Some(result) => Ok(Some(result.into())),
None => Ok(None),
}
}
}
impl Encoder<Vec<u8>> for ValueCodec {
type Error = RedisError;
fn encode(&mut self, item: Vec<u8>, dst: &mut BytesMut) -> Result<(), Self::Error> {
dst.extend_from_slice(item.as_ref());
Ok(())
}
}
impl Decoder for ValueCodec {
type Item = RedisResult<Value>;
type Error = RedisError;
fn decode(&mut self, bytes: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
self.decode_stream(bytes, false)
}
fn decode_eof(&mut self, bytes: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
self.decode_stream(bytes, true)
}
}
pub async fn parse_redis_value_async<R>(
decoder: &mut combine::stream::Decoder<AnySendSyncPartialState, PointerOffset<[u8]>>,
read: &mut R,
) -> RedisResult<Value>
where
R: AsyncRead + std::marker::Unpin,
{
let result = combine::decode_tokio!(*decoder, *read, value(None), |input, _| {
combine::stream::easy::Stream::from(input)
});
match result {
Err(err) => Err(match err {
combine::stream::decoder::Error::Io { error, .. } => error.into(),
combine::stream::decoder::Error::Parse(err) => {
if err.is_unexpected_end_of_input() {
RedisError::from(io::Error::from(io::ErrorKind::UnexpectedEof))
} else {
let err = err
.map_range(|range| format!("{range:?}"))
.map_position(|pos| pos.translate_position(decoder.buffer()))
.to_string();
RedisError::from((ErrorKind::ParseError, "parse error", err))
}
}
}),
Ok(result) => result.into(),
}
}
}
#[cfg(feature = "aio")]
#[cfg_attr(docsrs, doc(cfg(feature = "aio")))]
pub use self::aio_support::*;
pub struct Parser {
decoder: combine::stream::decoder::Decoder<AnySendSyncPartialState, PointerOffset<[u8]>>,
}
impl Default for Parser {
fn default() -> Self {
Parser::new()
}
}
impl Parser {
pub fn new() -> Parser {
Parser {
decoder: combine::stream::decoder::Decoder::new(),
}
}
pub fn parse_value<T: Read>(&mut self, mut reader: T) -> RedisResult<Value> {
let mut decoder = &mut self.decoder;
let result = combine::decode!(decoder, reader, value(None), |input, _| {
combine::stream::easy::Stream::from(input)
});
match result {
Err(err) => Err(match err {
combine::stream::decoder::Error::Io { error, .. } => error.into(),
combine::stream::decoder::Error::Parse(err) => {
if err.is_unexpected_end_of_input() {
RedisError::from(io::Error::from(io::ErrorKind::UnexpectedEof))
} else {
let err = err
.map_range(|range| format!("{range:?}"))
.map_position(|pos| pos.translate_position(decoder.buffer()))
.to_string();
RedisError::from((ErrorKind::ParseError, "parse error", err))
}
}
}),
Ok(result) => result.into(),
}
}
}
pub fn parse_redis_value(bytes: &[u8]) -> RedisResult<Value> {
let mut parser = Parser::new();
parser.parse_value(bytes)
}
#[cfg(test)]
mod tests {
use super::*;
#[cfg(feature = "aio")]
#[test]
fn decode_eof_returns_none_at_eof() {
use tokio_util::codec::Decoder;
let mut codec = ValueCodec::default();
let mut bytes = bytes::BytesMut::from(&b"+GET 123\r\n"[..]);
assert_eq!(
codec.decode_eof(&mut bytes),
Ok(Some(Ok(parse_redis_value(b"+GET 123\r\n").unwrap())))
);
assert_eq!(codec.decode_eof(&mut bytes), Ok(None));
assert_eq!(codec.decode_eof(&mut bytes), Ok(None));
}
#[cfg(feature = "aio")]
#[test]
fn decode_eof_returns_error_inside_array_and_can_parse_more_inputs() {
use tokio_util::codec::Decoder;
let mut codec = ValueCodec::default();
let mut bytes =
bytes::BytesMut::from(b"*3\r\n+OK\r\n-LOADING server is loading\r\n+OK\r\n".as_slice());
let result = codec.decode_eof(&mut bytes).unwrap().unwrap();
assert_eq!(
result,
Err(RedisError::from((
ErrorKind::BusyLoadingError,
"An error was signalled by the server",
"server is loading".to_string()
)))
);
let mut bytes = bytes::BytesMut::from(b"+OK\r\n".as_slice());
let result = codec.decode_eof(&mut bytes).unwrap().unwrap();
assert_eq!(result, Ok(Value::Okay));
}
#[test]
fn parse_nested_error_and_handle_more_inputs() {
let bytes = b"*3\r\n+OK\r\n-LOADING server is loading\r\n+OK\r\n";
let result = parse_redis_value(bytes);
assert_eq!(
result,
Err(RedisError::from((
ErrorKind::BusyLoadingError,
"An error was signalled by the server",
"server is loading".to_string()
)))
);
let result = parse_redis_value(b"+OK\r\n").unwrap();
assert_eq!(result, Value::Okay);
}
#[test]
fn test_max_recursion_depth() {
let bytes = b"*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n";
match parse_redis_value(bytes) {
Ok(_) => panic!("Expected Err"),
Err(e) => assert!(matches!(e.kind(), ErrorKind::ParseError)),
}
}
}