use super::encode::BUFFER_SIZE;
use crate::{metadata::MetadataValue, Status};
use bytes::{Buf, BytesMut};
#[cfg(feature = "gzip")]
use flate2::read::{GzDecoder, GzEncoder};
use std::fmt;
#[cfg(feature = "zstd")]
use zstd::stream::read::{Decoder, Encoder};
/// Header naming the encoding the sender compressed message payloads with.
pub(crate) const ENCODING_HEADER: &str = "grpc-encoding";
/// Header listing the encodings a peer is willing to receive.
pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding";
/// The set of compression encodings that have been switched on.
///
/// Each flag only exists when the matching cargo feature is compiled in, and every
/// flag starts out `false` via the derived `Default`.
#[derive(Clone, Copy, Debug, Default)]
pub struct EnabledCompressionEncodings {
    /// Whether gzip has been enabled.
    #[cfg(feature = "gzip")]
    pub(crate) gzip: bool,
    /// Whether zstd has been enabled.
    #[cfg(feature = "zstd")]
    pub(crate) zstd: bool,
}
impl EnabledCompressionEncodings {
    /// Returns `true` if `encoding` has been enabled in this set.
    pub fn is_enabled(&self, encoding: CompressionEncoding) -> bool {
        match encoding {
            #[cfg(feature = "gzip")]
            CompressionEncoding::Gzip => self.is_gzip_enabled(),
            #[cfg(feature = "zstd")]
            CompressionEncoding::Zstd => self.is_zstd_enabled(),
        }
    }

    /// Switches `encoding` on. Enabling an already-enabled encoding has no further effect.
    pub fn enable(&mut self, encoding: CompressionEncoding) {
        match encoding {
            #[cfg(feature = "gzip")]
            CompressionEncoding::Gzip => {
                self.gzip = true;
            }
            #[cfg(feature = "zstd")]
            CompressionEncoding::Zstd => {
                self.zstd = true;
            }
        }
    }

    /// Builds the comma-separated value advertising the enabled encodings, always
    /// terminated by `identity`.
    ///
    /// Returns `None` when no encoding is enabled.
    pub(crate) fn into_accept_encoding_header_value(self) -> Option<http::HeaderValue> {
        let advertised = match (self.is_gzip_enabled(), self.is_zstd_enabled()) {
            (false, false) => return None,
            (true, false) => "gzip,identity",
            (false, true) => "zstd,identity",
            (true, true) => "gzip,zstd,identity",
        };
        Some(http::HeaderValue::from_static(advertised))
    }

    /// Whether gzip is enabled (feature compiled in).
    #[cfg(feature = "gzip")]
    const fn is_gzip_enabled(&self) -> bool {
        self.gzip
    }

    /// Always `false`: gzip support is not compiled in.
    #[cfg(not(feature = "gzip"))]
    const fn is_gzip_enabled(&self) -> bool {
        false
    }

    /// Whether zstd is enabled (feature compiled in).
    #[cfg(feature = "zstd")]
    const fn is_zstd_enabled(&self) -> bool {
        self.zstd
    }

    /// Always `false`: zstd support is not compiled in.
    #[cfg(not(feature = "zstd"))]
    const fn is_zstd_enabled(&self) -> bool {
        false
    }
}
/// A compression algorithm that can be applied to gRPC message payloads.
///
/// Each variant is only present when its cargo feature is enabled.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum CompressionEncoding {
    /// gzip compression, identified on the wire as `gzip`.
    #[cfg(feature = "gzip")]
    #[cfg_attr(docsrs, doc(cfg(feature = "gzip")))]
    Gzip,
    /// zstd compression, identified on the wire as `zstd`.
    #[cfg(feature = "zstd")]
    #[cfg_attr(docsrs, doc(cfg(feature = "zstd")))]
    Zstd,
}
impl CompressionEncoding {
    /// Selects an encoding from the peer's `grpc-accept-encoding` header.
    ///
    /// The header is read as a comma-separated list. The first entry naming an
    /// encoding that is both compiled in and enabled in `enabled_encodings` is returned.
    ///
    /// Returns `None` when:
    /// - no encoding is enabled;
    /// - the header is absent or cannot be read as a string;
    /// - no listed entry is usable.
    pub(crate) fn from_accept_encoding_header(
        map: &http::HeaderMap,
        enabled_encodings: EnabledCompressionEncodings,
    ) -> Option<Self> {
        // Fast path: nothing can be negotiated if every encoding is disabled.
        if !enabled_encodings.is_gzip_enabled() && !enabled_encodings.is_zstd_enabled() {
            return None;
        }

        let header_value = map.get(ACCEPT_ENCODING_HEADER)?;
        let header_value_str = header_value.to_str().ok()?;

        // Check each candidate against `enabled_encodings`, not just the compiled-in
        // features. When both features are built but only one encoding is enabled,
        // a peer listing the disabled one first must not select it.
        split_by_comma(header_value_str).find_map(|value| match value {
            #[cfg(feature = "gzip")]
            "gzip" if enabled_encodings.is_enabled(CompressionEncoding::Gzip) => {
                Some(CompressionEncoding::Gzip)
            }
            #[cfg(feature = "zstd")]
            "zstd" if enabled_encodings.is_enabled(CompressionEncoding::Zstd) => {
                Some(CompressionEncoding::Zstd)
            }
            _ => None,
        })
    }

    /// Reads the `grpc-encoding` header to find which encoding the peer used.
    ///
    /// Returns:
    /// - `Ok(None)` when the header is absent, cannot be read as a string, or is `identity`;
    /// - `Ok(Some(_))` for an encoding that is compiled in and enabled in `enabled_encodings`.
    ///
    /// # Errors
    ///
    /// Any other value yields a `Status::unimplemented`. The status carries a
    /// `grpc-accept-encoding` metadata entry listing the enabled encodings, or
    /// `identity` when none are enabled.
    pub(crate) fn from_encoding_header(
        map: &http::HeaderMap,
        enabled_encodings: EnabledCompressionEncodings,
    ) -> Result<Option<Self>, Status> {
        let header_value_str = match map.get(ENCODING_HEADER).map(|value| value.to_str()) {
            Some(Ok(value)) => value,
            // A missing header means the payload is not compressed.
            // NOTE(review): a header that is present but unreadable as a string is also
            // treated as uncompressed rather than rejected; confirm this is intended.
            _ => return Ok(None),
        };

        match header_value_str {
            #[cfg(feature = "gzip")]
            "gzip" if enabled_encodings.is_enabled(CompressionEncoding::Gzip) => {
                Ok(Some(CompressionEncoding::Gzip))
            }
            #[cfg(feature = "zstd")]
            "zstd" if enabled_encodings.is_enabled(CompressionEncoding::Zstd) => {
                Ok(Some(CompressionEncoding::Zstd))
            }
            "identity" => Ok(None),
            other => {
                let mut status = Status::unimplemented(format!(
                    "Content is compressed with `{}` which isn't supported",
                    other
                ));

                // Tell the peer what this side can decode, falling back to `identity`.
                let header_value = enabled_encodings
                    .into_accept_encoding_header_value()
                    .map(MetadataValue::unchecked_from_header_value)
                    .unwrap_or_else(|| MetadataValue::from_static("identity"));
                status
                    .metadata_mut()
                    .insert(ACCEPT_ENCODING_HEADER, header_value);

                Err(status)
            }
        }
    }

    /// The wire name of this encoding (`"gzip"` or `"zstd"`).
    #[cfg(any(feature = "gzip", feature = "zstd"))]
    pub(crate) fn as_str(&self) -> &'static str {
        match self {
            #[cfg(feature = "gzip")]
            CompressionEncoding::Gzip => "gzip",
            #[cfg(feature = "zstd")]
            CompressionEncoding::Zstd => "zstd",
        }
    }

    /// The wire name of this encoding as a header value.
    #[cfg(any(feature = "gzip", feature = "zstd"))]
    pub(crate) fn into_header_value(self) -> http::HeaderValue {
        http::HeaderValue::from_static(self.as_str())
    }

    /// Every encoding compiled into this build, in declaration order.
    pub(crate) fn encodings() -> &'static [Self] {
        &[
            #[cfg(feature = "gzip")]
            CompressionEncoding::Gzip,
            #[cfg(feature = "zstd")]
            CompressionEncoding::Zstd,
        ]
    }
}
impl fmt::Display for CompressionEncoding {
    /// Writes the encoding's wire name, e.g. `gzip`.
    // With no compression feature enabled the match below has no arms: `f` goes
    // unused and the final write is unreachable, hence the allows.
    #[allow(unused_variables, unreachable_code)]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name: &str = match *self {
            #[cfg(feature = "gzip")]
            CompressionEncoding::Gzip => "gzip",
            #[cfg(feature = "zstd")]
            CompressionEncoding::Zstd => "zstd",
        };
        f.write_str(name)
    }
}
/// Splits `s` on commas, yielding each piece with surrounding whitespace removed.
///
/// Empty pieces (e.g. from `"a,,b"` or an empty input) are yielded as `""`.
fn split_by_comma(s: &str) -> impl Iterator<Item = &str> {
    s.split(',').map(str::trim)
}
/// Compresses the first `len` bytes of `decompressed_buf` with `encoding` and appends
/// the result to `out_buf`.
///
/// Before compressing, `out_buf` reserves the smallest multiple of `BUFFER_SIZE` that
/// is strictly greater than `len`. On success, the `len` input bytes are consumed
/// from `decompressed_buf`.
///
/// # Errors
///
/// Propagates any I/O error raised while constructing the encoder or copying its output.
// With no compression feature enabled the match has no arms, so the parameters go
// unused and everything after the match is unreachable.
#[allow(unused_variables, unreachable_code)]
pub(crate) fn compress(
    encoding: CompressionEncoding,
    decompressed_buf: &mut BytesMut,
    out_buf: &mut BytesMut,
    len: usize,
) -> Result<(), std::io::Error> {
    let block_count = len / BUFFER_SIZE + 1;
    out_buf.reserve(block_count * BUFFER_SIZE);

    #[cfg(any(feature = "gzip", feature = "zstd"))]
    let mut out_writer = bytes::BufMut::writer(out_buf);

    let input = &decompressed_buf[0..len];

    match encoding {
        #[cfg(feature = "gzip")]
        CompressionEncoding::Gzip => {
            let mut encoder = GzEncoder::new(input, flate2::Compression::new(6));
            std::io::copy(&mut encoder, &mut out_writer)?;
        }
        #[cfg(feature = "zstd")]
        CompressionEncoding::Zstd => {
            let mut encoder = Encoder::new(input, zstd::DEFAULT_COMPRESSION_LEVEL)?;
            std::io::copy(&mut encoder, &mut out_writer)?;
        }
    }

    decompressed_buf.advance(len);
    Ok(())
}
/// Decompresses the first `len` bytes of `compressed_buf` with `encoding` and appends
/// the result to `out_buf`.
///
/// `out_buf` first reserves the smallest multiple of `BUFFER_SIZE` strictly greater
/// than `2 * len`, a guess that the payload roughly doubles. On success, the `len`
/// input bytes are consumed from `compressed_buf`.
///
/// NOTE(review): nothing here bounds how large the decompressed output may grow;
/// confirm callers enforce a size limit against highly compressible input.
///
/// # Errors
///
/// Propagates any I/O error raised while constructing the decoder or copying its output.
// With no compression feature enabled the match has no arms, so the parameters go
// unused and everything after the match is unreachable.
#[allow(unused_variables, unreachable_code)]
pub(crate) fn decompress(
    encoding: CompressionEncoding,
    compressed_buf: &mut BytesMut,
    out_buf: &mut BytesMut,
    len: usize,
) -> Result<(), std::io::Error> {
    let expected_len = len * 2;
    let block_count = expected_len / BUFFER_SIZE + 1;
    out_buf.reserve(block_count * BUFFER_SIZE);

    #[cfg(any(feature = "gzip", feature = "zstd"))]
    let mut out_writer = bytes::BufMut::writer(out_buf);

    let input = &compressed_buf[0..len];

    match encoding {
        #[cfg(feature = "gzip")]
        CompressionEncoding::Gzip => {
            let mut decoder = GzDecoder::new(input);
            std::io::copy(&mut decoder, &mut out_writer)?;
        }
        #[cfg(feature = "zstd")]
        CompressionEncoding::Zstd => {
            let mut decoder = Decoder::new(input)?;
            std::io::copy(&mut decoder, &mut out_writer)?;
        }
    }

    compressed_buf.advance(len);
    Ok(())
}
/// A per-message override of the configured compression behavior.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum SingleMessageCompressionOverride {
    /// Keep whatever compression setting would otherwise apply.
    Inherit,
    /// Disable compression for this message (enforced by callers not shown here).
    Disable,
}

impl Default for SingleMessageCompressionOverride {
    /// Defaults to [`SingleMessageCompressionOverride::Inherit`].
    fn default() -> Self {
        SingleMessageCompressionOverride::Inherit
    }
}