use crate::StatusCode;
use std::borrow::Cow;
use std::fmt::{Debug, Display};
mod http_error;
mod macros;
use crate::headers::{self, Headers};
pub use http_error::HttpError;
use self::http_error::get_error_code_from_header;
/// Shorthand for `std::result::Result` whose error type is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
/// Broad category of an [`Error`], independent of any message or wrapped source error.
///
/// Every `Error` carries exactly one `ErrorKind`, retrievable with `Error::kind`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ErrorKind {
/// A failed HTTP response.
HttpResponse {
/// Status code of the response.
status: StatusCode,
/// Service-specific error code, when one could be extracted
/// (see `ErrorKind::http_response_from_parts`).
error_code: Option<String>,
},
/// I/O failure; the `From<std::io::Error>` conversion produces this kind.
Io,
/// Data conversion failure; the `From` conversions for base64, serde_json, UTF-8 and
/// URL parse errors produce this kind.
DataConversion,
/// Credential-related failure. NOTE(review): not constructed in this file; meaning
/// assumed from the name.
Credential,
/// Mock-framework failure. NOTE(review): not constructed in this file; meaning assumed
/// from the name.
MockFramework,
/// Catch-all kind. NOTE(review): not constructed in this file; meaning assumed from
/// the name.
Other,
}
impl ErrorKind {
pub fn into_error(self) -> Error {
Error {
context: Context::Simple(self),
}
}
pub fn http_response(status: StatusCode, error_code: Option<String>) -> Self {
Self::HttpResponse { status, error_code }
}
pub fn http_response_from_parts(status: StatusCode, headers: &Headers, body: &[u8]) -> Self {
if let Some(header_err_code) = get_error_code_from_header(headers) {
Self::HttpResponse {
status,
error_code: Some(header_err_code),
}
} else {
let (error_code, _) = http_error::get_error_code_message_from_body(
body,
headers.get_optional_str(&headers::CONTENT_TYPE),
);
Self::HttpResponse { status, error_code }
}
}
}
impl Display for ErrorKind {
    /// Prints the variant name. HTTP responses also show the status and the error code,
    /// or `unknown` when no code is available.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            ErrorKind::HttpResponse { status, error_code } => {
                let code = error_code.as_deref().unwrap_or("unknown");
                return write!(f, "HttpResponse({},{})", status, code);
            }
            ErrorKind::Io => "Io",
            ErrorKind::DataConversion => "DataConversion",
            ErrorKind::Credential => "Credential",
            ErrorKind::MockFramework => "MockFramework",
            ErrorKind::Other => "Other",
        };
        f.write_str(name)
    }
}
/// This module's error type: an [`ErrorKind`] plus, optionally, a message and/or a
/// wrapped source error.
///
/// Build one with `Error::new`, `Error::full`, `Error::message`, `Error::with_message`,
/// `ErrorKind::into_error`, or one of the `From` conversions.
#[derive(Debug)]
pub struct Error {
/// How the error was built; determines what `Display` prints and what `source()` returns.
context: Context,
}
impl Error {
pub fn new<E>(kind: ErrorKind, error: E) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
Self {
context: Context::Custom(Custom {
kind,
error: error.into(),
}),
}
}
#[must_use]
pub fn full<E, C>(kind: ErrorKind, error: E, message: C) -> Self
where
E: Into<Box<dyn std::error::Error + Send + Sync>>,
C: Into<Cow<'static, str>>,
{
Self {
context: Context::Full(
Custom {
kind,
error: error.into(),
},
message.into(),
),
}
}
#[must_use]
pub fn message<C>(kind: ErrorKind, message: C) -> Self
where
C: Into<Cow<'static, str>>,
{
Self {
context: Context::Message {
kind,
message: message.into(),
},
}
}
#[must_use]
pub fn with_message<F, C>(kind: ErrorKind, message: F) -> Self
where
Self: Sized,
F: FnOnce() -> C,
C: Into<Cow<'static, str>>,
{
Self {
context: Context::Message {
kind,
message: message().into(),
},
}
}
#[must_use]
pub fn context<C>(self, message: C) -> Self
where
C: Into<Cow<'static, str>>,
{
Self::full(self.kind().clone(), self, message)
}
#[must_use]
pub fn with_context<F, C>(self, f: F) -> Self
where
F: FnOnce() -> C,
C: Into<Cow<'static, str>>,
{
self.context(f())
}
pub fn kind(&self) -> &ErrorKind {
match &self.context {
Context::Simple(kind)
| Context::Message { kind, .. }
| Context::Custom(Custom { kind, .. })
| Context::Full(Custom { kind, .. }, _) => kind,
}
}
pub fn into_inner(self) -> std::result::Result<Box<dyn std::error::Error + Send + Sync>, Self> {
match self.context {
Context::Custom(Custom { error, .. }) | Context::Full(Custom { error, .. }, _) => {
Ok(error)
}
_ => Err(self),
}
}
pub fn into_downcast<T: std::error::Error + 'static>(self) -> std::result::Result<T, Self> {
if self.downcast_ref::<T>().is_none() {
return Err(self);
}
Ok(*self
.into_inner()?
.downcast()
.expect("failed to unwrap downcast"))
}
pub fn get_ref(&self) -> Option<&(dyn std::error::Error + Send + Sync + 'static)> {
match &self.context {
Context::Custom(Custom { error, .. }) | Context::Full(Custom { error, .. }, _) => {
Some(error.as_ref())
}
_ => None,
}
}
pub fn as_http_error(&self) -> Option<&HttpError> {
let mut error = self.get_ref()? as &(dyn std::error::Error);
loop {
match error.downcast_ref::<HttpError>() {
Some(e) => return Some(e),
None => error = error.source()?,
}
}
}
pub fn downcast_ref<T: std::error::Error + 'static>(&self) -> Option<&T> {
self.get_ref()?.downcast_ref()
}
pub fn get_mut(&mut self) -> Option<&mut (dyn std::error::Error + Send + Sync + 'static)> {
match &mut self.context {
Context::Custom(Custom { error, .. }) | Context::Full(Custom { error, .. }, _) => {
Some(error.as_mut())
}
_ => None,
}
}
pub fn downcast_mut<T: std::error::Error + 'static>(&mut self) -> Option<&mut T> {
self.get_mut()?.downcast_mut()
}
}
impl std::error::Error for Error {
    /// Returns the wrapped error for errors built around one, and `None` otherwise.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let Context::Custom(custom) | Context::Full(custom, _) = &self.context {
            Some(&*custom.error)
        } else {
            None
        }
    }
}
impl From<ErrorKind> for Error {
fn from(kind: ErrorKind) -> Self {
Self {
context: Context::Simple(kind),
}
}
}
impl From<std::io::Error> for Error {
fn from(error: std::io::Error) -> Self {
Self::new(ErrorKind::Io, error)
}
}
impl From<base64::DecodeError> for Error {
fn from(error: base64::DecodeError) -> Self {
Self::new(ErrorKind::DataConversion, error)
}
}
impl From<serde_json::Error> for Error {
fn from(error: serde_json::Error) -> Self {
Self::new(ErrorKind::DataConversion, error)
}
}
impl From<std::string::FromUtf8Error> for Error {
fn from(error: std::string::FromUtf8Error) -> Self {
Self::new(ErrorKind::DataConversion, error)
}
}
impl From<std::str::Utf8Error> for Error {
fn from(error: std::str::Utf8Error) -> Self {
Self::new(ErrorKind::DataConversion, error)
}
}
impl From<url::ParseError> for Error {
fn from(error: url::ParseError) -> Self {
Self::new(ErrorKind::DataConversion, error)
}
}
impl Display for Error {
    /// Prints the explicit message when there is one. Otherwise it prints the wrapped
    /// error, or, failing that, the kind.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match &self.context {
            Context::Simple(kind) => write!(f, "{kind}"),
            Context::Custom(Custom { error, .. }) => write!(f, "{error}"),
            // An explicit message always wins over a wrapped error.
            Context::Message { message, .. } | Context::Full(_, message) => f.write_str(message),
        }
    }
}
/// Extension methods that convert a `Result<T, E>` into this module's [`Result`] while
/// attaching an [`ErrorKind`], for any `E: std::error::Error + Send + Sync + 'static`.
///
/// The trait is sealed via `private::Sealed`, so it cannot be implemented for other types.
pub trait ResultExt<T>: private::Sealed {
/// Wraps the error, if any, in an [`Error`] of the given `kind`; `Ok` values pass through.
fn map_kind(self, kind: ErrorKind) -> Result<T>
where
Self: Sized;
/// Wraps the error, if any, in an [`Error`] of the given `kind` that displays `message`
/// and keeps the original error as its source; `Ok` values pass through.
fn context<C>(self, kind: ErrorKind, message: C) -> Result<T>
where
Self: Sized,
C: Into<Cow<'static, str>>;
/// Like `context`, but the message is produced by calling `f`.
fn with_context<F, C>(self, kind: ErrorKind, f: F) -> Result<T>
where
Self: Sized,
F: FnOnce() -> C,
C: Into<Cow<'static, str>>;
}
/// Holds the sealing supertrait of [`ResultExt`].
///
/// `private` is not `pub`, so code outside this module cannot name `Sealed`. That limits
/// `ResultExt` to the implementations listed here.
mod private {
/// Marker supertrait of `ResultExt`.
pub trait Sealed {}
// The only implementation: results whose error is a `Send + Sync + 'static` std error.
impl<T, E> Sealed for std::result::Result<T, E> where E: std::error::Error + Send + Sync + 'static {}
}
/// Implements [`ResultExt`] for every `Result` whose error is a thread-safe,
/// `'static` std error.
impl<T, E> ResultExt<T> for std::result::Result<T, E>
where
    E: std::error::Error + Send + Sync + 'static,
{
    /// Wraps an `Err` in an [`Error`] of `kind`; `Ok` passes through untouched.
    fn map_kind(self, kind: ErrorKind) -> Result<T>
    where
        Self: Sized,
    {
        self.map_err(|e| Error::new(kind, e))
    }

    /// Wraps an `Err` in an [`Error`] of `kind` that displays `message` and keeps the
    /// original error as its source; `Ok` passes through untouched.
    fn context<C>(self, kind: ErrorKind, message: C) -> Result<T>
    where
        Self: Sized,
        C: Into<Cow<'static, str>>,
    {
        self.map_err(|e| Error {
            context: Context::Full(
                Custom {
                    error: Box::new(e),
                    kind,
                },
                message.into(),
            ),
        })
    }

    /// Like `context`, but the message is produced by `f`, which is called only when
    /// `self` is an `Err`.
    fn with_context<F, C>(self, kind: ErrorKind, f: F) -> Result<T>
    where
        Self: Sized,
        F: FnOnce() -> C,
        C: Into<Cow<'static, str>>,
    {
        // Call `f` only on the error path. The point of taking a closure is to avoid
        // building a (possibly expensive, e.g. `format!`-ed) message for successful results.
        self.map_err(|e| Error {
            context: Context::Full(
                Custom {
                    error: Box::new(e),
                    kind,
                },
                f().into(),
            ),
        })
    }
}
/// Internal record of how an [`Error`] was constructed; `Display` and `source()` branch on it.
#[derive(Debug)]
enum Context {
/// Only a kind. `Display` prints the kind, and there is no source.
Simple(ErrorKind),
/// A kind plus a message. `Display` prints the message, and there is no source.
Message {
kind: ErrorKind,
message: Cow<'static, str>,
},
/// A kind plus a wrapped error. `Display` prints the wrapped error, which is also the source.
Custom(Custom),
/// A kind, a wrapped error and a message. `Display` prints the message, and the source
/// is the wrapped error.
Full(Custom, Cow<'static, str>),
}
/// An error kind paired with the boxed error it wraps.
#[derive(Debug)]
struct Custom {
kind: ErrorKind,
/// The wrapped error, exposed through `Error::get_ref`, `Error::get_mut`,
/// `Error::into_inner` and `source()`.
error: Box<dyn std::error::Error + Send + Sync>,
}
// Unit tests for `Error`, `ErrorKind` and `ResultExt`.
#[cfg(test)]
mod tests {
use super::*;
use std::io;
// Compile-time check that `Error: Send`. The call `ensure_send::<Error>()` only
// type-checks if `Error` satisfies the `Send` bound. The function is never run, hence
// the allowances for dead code and the intentional self-recursion.
#[allow(
dead_code,
unconditional_recursion,
clippy::extra_unused_type_parameters
)]
fn ensure_send<T: Send>() {
ensure_send::<Error>();
}
// Error type whose `Display` ("second error") differs from the io::Error it wraps,
// so the tests can tell the layers of the chain apart.
#[derive(thiserror::Error, Debug)]
enum IntermediateError {
#[error("second error")]
Io(#[from] std::io::Error),
}
// Builds a layered error:
// Error(Io) wrapping io::Error(ConnectionAborted) wrapping IntermediateError wrapping
// io::Error(BrokenPipe, "third error").
fn create_error() -> Error {
let inner = io::Error::new(io::ErrorKind::BrokenPipe, "third error");
let inner: IntermediateError = inner.into();
let inner = io::Error::new(io::ErrorKind::ConnectionAborted, inner);
Error::new(ErrorKind::Io, inner)
}
// The outer error displays its wrapped error. Walking `source()` from the outer error
// yields the messages pinned below. A `context` message replaces the display text.
#[test]
fn errors_display_properly() {
let error = create_error();
let mut error: &dyn std::error::Error = &error;
let display = format!("{error}");
let mut errors = vec![];
while let Some(cause) = error.source() {
errors.push(format!("{cause}"));
error = cause;
}
assert_eq!(display, "second error");
assert_eq!(errors.join(","), "second error,third error");
let inner = io::Error::new(io::ErrorKind::BrokenPipe, "third error");
let error: Result<()> = std::result::Result::<(), std::io::Error>::Err(inner)
.context(ErrorKind::Io, "oh no broken pipe!");
assert_eq!(format!("{}", error.unwrap_err()), "oh no broken pipe!");
}
// The `source()` of an error built with `Error::new` is the wrapped error itself,
// which can be downcast back to its concrete type.
#[test]
fn downcasting_works() {
let error = &create_error() as &dyn std::error::Error;
assert!(error.is::<Error>());
let downcasted = error
.source()
.unwrap()
.downcast_ref::<std::io::Error>()
.unwrap();
assert_eq!(format!("{downcasted}"), "second error");
}
// `into_inner`, `get_ref` and `get_mut` all expose the same wrapped io::Error.
#[test]
fn turn_into_inner_error() {
let error = create_error();
let inner = error.into_inner().unwrap();
let inner = inner.downcast_ref::<std::io::Error>().unwrap();
assert_eq!(format!("{inner}"), "second error");
let error = create_error();
let inner = error.get_ref().unwrap();
let inner = inner.downcast_ref::<std::io::Error>().unwrap();
assert_eq!(format!("{inner}"), "second error");
let mut error = create_error();
let inner = error.get_mut().unwrap();
let inner = inner.downcast_ref::<std::io::Error>().unwrap();
assert_eq!(format!("{inner}"), "second error");
}
// `http_response_from_parts`:
// - an empty JSON body with no headers yields no error code;
// - an `error.code` in the JSON body is picked up;
// - an `ERROR_CODE` header is picked up.
#[test]
fn matching_against_http_error() {
let kind =
ErrorKind::http_response_from_parts(StatusCode::ImATeapot, &Headers::new(), b"{}");
assert!(matches!(
kind,
ErrorKind::HttpResponse {
status: StatusCode::ImATeapot,
error_code: None
}
));
let kind = ErrorKind::http_response_from_parts(
StatusCode::ImATeapot,
&Headers::new(),
br#"{"error": {"code":"teepot"}}"#,
);
assert!(matches!(
kind,
ErrorKind::HttpResponse {
status: StatusCode::ImATeapot,
error_code
}
if error_code.as_deref() == Some("teepot")
));
let mut headers = Headers::new();
headers.insert(headers::ERROR_CODE, "teapot");
let kind = ErrorKind::http_response_from_parts(StatusCode::ImATeapot, &headers, br#"{}"#);
assert!(matches!(
kind,
ErrorKind::HttpResponse {
status: StatusCode::ImATeapot,
error_code
}
if error_code.as_deref() == Some("teapot")
));
}
// `map_kind` on a `Result` whose error is already an `Error` wraps it in a new `Error`
// carrying the requested kind.
#[test]
fn set_result_kind() {
let result = std::result::Result::<(), _>::Err(create_error());
let result = result.map_kind(ErrorKind::Io);
assert_eq!(&ErrorKind::Io, result.unwrap_err().kind());
}
}