use crate::io::TokioIo;
use crate::{
bindings::http::types::{self, Method, Scheme},
body::{HostIncomingBody, HyperIncomingBody, HyperOutgoingBody},
error::dns_error,
hyper_request_error,
};
use anyhow::bail;
use bytes::Bytes;
use http_body_util::BodyExt;
use hyper::body::Body;
use hyper::header::HeaderName;
use std::any::Any;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::time::timeout;
use wasmtime::component::{Resource, ResourceTable};
use wasmtime_wasi::{runtime::AbortOnDropJoinHandle, Subscribe};
/// Per-store state for the WASI HTTP implementation.
///
/// Carries no data at present. Because `_priv` is private, code outside this
/// module cannot build it with a struct literal and must use
/// [`WasiHttpCtx::new`].
#[derive(Debug)]
pub struct WasiHttpCtx {
// Private placeholder field; exists only to restrict construction.
_priv: (),
}
impl WasiHttpCtx {
pub fn new() -> Self {
Self { _priv: () }
}
}
/// Access to the state WASI HTTP needs from an embedder's store data.
///
/// Implementors must supply the [`WasiHttpCtx`] and the [`ResourceTable`].
/// The remaining methods have default bodies that embedders may override to
/// customize how requests are created, sent, and filtered.
pub trait WasiHttpView: Send {
    /// Returns the WASI HTTP context.
    fn ctx(&mut self) -> &mut WasiHttpCtx;

    /// Returns the resource table that HTTP resources are pushed into.
    fn table(&mut self) -> &mut ResourceTable;

    /// Turns an incoming `hyper` request into a [`HostIncomingRequest`]
    /// resource stored in [`WasiHttpView::table`].
    ///
    /// Body errors are mapped through `crate::hyper_response_error`. The boxed
    /// body is wrapped in a [`HostIncomingBody`] with a fixed 600-second
    /// duration (presumably the between-bytes timeout; confirm against
    /// `HostIncomingBody::new`). Header filtering happens in
    /// [`HostIncomingRequest::new`].
    ///
    /// # Errors
    /// Fails if [`HostIncomingRequest::new`] rejects the request head, or if
    /// the resource table push fails.
    fn new_incoming_request<B>(
        &mut self,
        scheme: Scheme,
        req: hyper::Request<B>,
    ) -> wasmtime::Result<Resource<HostIncomingRequest>>
    where
        B: Body<Data = Bytes, Error = hyper::Error> + Send + Sync + 'static,
        Self: Sized,
    {
        let (parts, raw_body) = req.into_parts();
        let incoming_body = HostIncomingBody::new(
            raw_body.map_err(crate::hyper_response_error).boxed(),
            Duration::from_secs(600),
        );
        let request = HostIncomingRequest::new(self, parts, scheme, Some(incoming_body))?;
        Ok(self.table().push(request)?)
    }

    /// Stores the sending half of a response channel as a
    /// [`HostResponseOutparam`] resource.
    ///
    /// # Errors
    /// Fails if the resource table push fails.
    fn new_response_outparam(
        &mut self,
        result: tokio::sync::oneshot::Sender<
            Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
        >,
    ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
        Ok(self.table().push(HostResponseOutparam { result })?)
    }

    /// Starts sending an outgoing request.
    ///
    /// The default spawns [`default_send_request`] and never fails
    /// synchronously.
    fn send_request(
        &mut self,
        request: hyper::Request<HyperOutgoingBody>,
        config: OutgoingRequestConfig,
    ) -> crate::HttpResult<HostFutureIncomingResponse> {
        let pending = default_send_request(request, config);
        Ok(pending)
    }

    /// Embedder hook that forbids header names beyond the built-in list.
    ///
    /// The default forbids nothing extra.
    fn is_forbidden_header(&mut self, _name: &HeaderName) -> bool {
        false
    }
}
/// Lets a mutable borrow of a view be used wherever a view is expected.
///
/// Each overridable method forwards to the referenced `T`.
/// `new_incoming_request` is not forwarded: it requires `Self: Sized`, so the
/// trait's default body is used.
impl<T: ?Sized + WasiHttpView> WasiHttpView for &mut T {
    fn ctx(&mut self) -> &mut WasiHttpCtx {
        (**self).ctx()
    }

    fn table(&mut self) -> &mut ResourceTable {
        (**self).table()
    }

    fn new_response_outparam(
        &mut self,
        result: tokio::sync::oneshot::Sender<
            Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
        >,
    ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
        (**self).new_response_outparam(result)
    }

    fn send_request(
        &mut self,
        request: hyper::Request<HyperOutgoingBody>,
        config: OutgoingRequestConfig,
    ) -> crate::HttpResult<HostFutureIncomingResponse> {
        (**self).send_request(request, config)
    }

    fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
        (**self).is_forbidden_header(name)
    }
}
/// Lets a boxed view (including `Box<dyn WasiHttpView>`) act as a view.
///
/// Each overridable method forwards to the boxed `T`. `new_incoming_request`
/// keeps the trait's default body.
impl<T: ?Sized + WasiHttpView> WasiHttpView for Box<T> {
    fn ctx(&mut self) -> &mut WasiHttpCtx {
        let inner: &mut T = &mut **self;
        inner.ctx()
    }

    fn table(&mut self) -> &mut ResourceTable {
        let inner: &mut T = &mut **self;
        inner.table()
    }

    fn new_response_outparam(
        &mut self,
        result: tokio::sync::oneshot::Sender<
            Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
        >,
    ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
        let inner: &mut T = &mut **self;
        inner.new_response_outparam(result)
    }

    fn send_request(
        &mut self,
        request: hyper::Request<HyperOutgoingBody>,
        config: OutgoingRequestConfig,
    ) -> crate::HttpResult<HostFutureIncomingResponse> {
        let inner: &mut T = &mut **self;
        inner.send_request(request, config)
    }

    fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
        let inner: &mut T = &mut **self;
        inner.is_forbidden_header(name)
    }
}
/// Newtype wrapper whose [`WasiHttpView`] impl delegates every overridable
/// method to the wrapped `T`.
///
/// `#[repr(transparent)]` gives it the same layout as `T`.
#[repr(transparent)]
pub struct WasiHttpImpl<T>(pub T);
/// Delegates each overridable view method to the wrapped view in field `.0`.
impl<T: WasiHttpView> WasiHttpView for WasiHttpImpl<T> {
    fn ctx(&mut self) -> &mut WasiHttpCtx {
        T::ctx(&mut self.0)
    }

    fn table(&mut self) -> &mut ResourceTable {
        T::table(&mut self.0)
    }

    fn new_response_outparam(
        &mut self,
        result: tokio::sync::oneshot::Sender<
            Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
        >,
    ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
        T::new_response_outparam(&mut self.0, result)
    }

    fn send_request(
        &mut self,
        request: hyper::Request<HyperOutgoingBody>,
        config: OutgoingRequestConfig,
    ) -> crate::HttpResult<HostFutureIncomingResponse> {
        T::send_request(&mut self.0, request, config)
    }

    fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
        T::is_forbidden_header(&mut self.0, name)
    }
}
/// Returns `true` if `name` must not be passed through.
///
/// A name is forbidden if it is in the fixed list below. Otherwise the
/// embedder's [`WasiHttpView::is_forbidden_header`] hook decides. The hook is
/// only consulted for names not already in the fixed list.
pub(crate) fn is_forbidden_header(view: &mut dyn WasiHttpView, name: &HeaderName) -> bool {
    static FORBIDDEN_HEADERS: [HeaderName; 10] = [
        hyper::header::CONNECTION,
        HeaderName::from_static("keep-alive"),
        hyper::header::PROXY_AUTHENTICATE,
        hyper::header::PROXY_AUTHORIZATION,
        HeaderName::from_static("proxy-connection"),
        hyper::header::TE,
        hyper::header::TRANSFER_ENCODING,
        hyper::header::UPGRADE,
        hyper::header::HOST,
        HeaderName::from_static("http2-settings"),
    ];
    if FORBIDDEN_HEADERS.iter().any(|forbidden| forbidden == name) {
        return true;
    }
    view.is_forbidden_header(name)
}
pub(crate) fn remove_forbidden_headers(
view: &mut dyn WasiHttpView,
headers: &mut hyper::HeaderMap,
) {
let forbidden_keys = Vec::from_iter(headers.keys().filter_map(|name| {
if is_forbidden_header(view, name) {
Some(name.clone())
} else {
None
}
}));
for name in forbidden_keys {
headers.remove(name);
}
}
/// Connection settings for an outgoing request, as consumed by
/// [`default_send_request_handler`].
pub struct OutgoingRequestConfig {
/// Connect over TLS. Also selects 443 instead of 80 as the default port.
pub use_tls: bool,
/// Limit applied to each connection-setup step (e.g. the TCP connect and the
/// HTTP/1 handshake).
pub connect_timeout: Duration,
/// Limit on waiting for the response head after the request is sent.
pub first_byte_timeout: Duration,
/// Not enforced by the default handler itself; it is passed through to the
/// returned `IncomingResponse`.
pub between_bytes_timeout: Duration,
}
pub fn default_send_request(
request: hyper::Request<HyperOutgoingBody>,
config: OutgoingRequestConfig,
) -> HostFutureIncomingResponse {
let handle = wasmtime_wasi::runtime::spawn(async move {
Ok(default_send_request_handler(request, config).await)
});
HostFutureIncomingResponse::pending(handle)
}
pub async fn default_send_request_handler(
mut request: hyper::Request<HyperOutgoingBody>,
OutgoingRequestConfig {
use_tls,
connect_timeout,
first_byte_timeout,
between_bytes_timeout,
}: OutgoingRequestConfig,
) -> Result<IncomingResponse, types::ErrorCode> {
let authority = if let Some(authority) = request.uri().authority() {
if authority.port().is_some() {
authority.to_string()
} else {
let port = if use_tls { 443 } else { 80 };
format!("{}:{port}", authority.to_string())
}
} else {
return Err(types::ErrorCode::HttpRequestUriInvalid);
};
let tcp_stream = timeout(connect_timeout, TcpStream::connect(&authority))
.await
.map_err(|_| types::ErrorCode::ConnectionTimeout)?
.map_err(|e| match e.kind() {
std::io::ErrorKind::AddrNotAvailable => {
dns_error("address not available".to_string(), 0)
}
_ => {
if e.to_string()
.starts_with("failed to lookup address information")
{
dns_error("address not available".to_string(), 0)
} else {
types::ErrorCode::ConnectionRefused
}
}
})?;
let (mut sender, worker) = if use_tls {
#[cfg(any(target_arch = "riscv64", target_arch = "s390x"))]
{
return Err(crate::bindings::http::types::ErrorCode::InternalError(
Some("unsupported architecture for SSL".to_string()),
));
}
#[cfg(not(any(target_arch = "riscv64", target_arch = "s390x")))]
{
use rustls::pki_types::ServerName;
let root_cert_store = rustls::RootCertStore {
roots: webpki_roots::TLS_SERVER_ROOTS.into(),
};
let config = rustls::ClientConfig::builder()
.with_root_certificates(root_cert_store)
.with_no_client_auth();
let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));
let mut parts = authority.split(":");
let host = parts.next().unwrap_or(&authority);
let domain = ServerName::try_from(host)
.map_err(|e| {
tracing::warn!("dns lookup error: {e:?}");
dns_error("invalid dns name".to_string(), 0)
})?
.to_owned();
let stream = connector.connect(domain, tcp_stream).await.map_err(|e| {
tracing::warn!("tls protocol error: {e:?}");
types::ErrorCode::TlsProtocolError
})?;
let stream = TokioIo::new(stream);
let (sender, conn) = timeout(
connect_timeout,
hyper::client::conn::http1::handshake(stream),
)
.await
.map_err(|_| types::ErrorCode::ConnectionTimeout)?
.map_err(hyper_request_error)?;
let worker = wasmtime_wasi::runtime::spawn(async move {
match conn.await {
Ok(()) => {}
Err(e) => tracing::warn!("dropping error {e}"),
}
});
(sender, worker)
}
} else {
let tcp_stream = TokioIo::new(tcp_stream);
let (sender, conn) = timeout(
connect_timeout,
hyper::client::conn::http1::handshake(tcp_stream),
)
.await
.map_err(|_| types::ErrorCode::ConnectionTimeout)?
.map_err(hyper_request_error)?;
let worker = wasmtime_wasi::runtime::spawn(async move {
match conn.await {
Ok(()) => {}
Err(e) => tracing::warn!("dropping error {e}"),
}
});
(sender, worker)
};
*request.uri_mut() = http::Uri::builder()
.path_and_query(
request
.uri()
.path_and_query()
.map(|p| p.as_str())
.unwrap_or("/"),
)
.build()
.expect("comes from valid request");
let resp = timeout(first_byte_timeout, sender.send_request(request))
.await
.map_err(|_| types::ErrorCode::ConnectionReadTimeout)?
.map_err(hyper_request_error)?
.map(|body| body.map_err(hyper_request_error).boxed());
Ok(IncomingResponse {
resp,
worker: Some(worker),
between_bytes_timeout,
})
}
/// Maps an [`http::Method`] onto the WASI HTTP method enum.
///
/// The nine standard methods get dedicated variants. Any other method is
/// carried verbatim in `Method::Other`.
impl From<http::Method> for types::Method {
    fn from(method: http::Method) -> Self {
        match method.as_str() {
            "GET" => types::Method::Get,
            "HEAD" => types::Method::Head,
            "POST" => types::Method::Post,
            "PUT" => types::Method::Put,
            "DELETE" => types::Method::Delete,
            "CONNECT" => types::Method::Connect,
            "OPTIONS" => types::Method::Options,
            "TRACE" => types::Method::Trace,
            "PATCH" => types::Method::Patch,
            other => types::Method::Other(other.to_string()),
        }
    }
}
impl TryInto<http::Method> for types::Method {
type Error = http::method::InvalidMethod;
fn try_into(self) -> Result<http::Method, Self::Error> {
match self {
Method::Get => Ok(http::Method::GET),
Method::Head => Ok(http::Method::HEAD),
Method::Post => Ok(http::Method::POST),
Method::Put => Ok(http::Method::PUT),
Method::Delete => Ok(http::Method::DELETE),
Method::Connect => Ok(http::Method::CONNECT),
Method::Options => Ok(http::Method::OPTIONS),
Method::Trace => Ok(http::Method::TRACE),
Method::Patch => Ok(http::Method::PATCH),
Method::Other(s) => http::Method::from_bytes(s.as_bytes()),
}
}
}
/// Host-side state of an incoming request resource.
#[derive(Debug)]
pub struct HostIncomingRequest {
// Request head. `HostIncomingRequest::new` strips forbidden headers from it.
pub(crate) parts: http::request::Parts,
// Scheme supplied by the caller of `HostIncomingRequest::new`.
pub(crate) scheme: Scheme,
// The URI authority, or the `Host` header value when the URI has none.
pub(crate) authority: String,
// Request body; presumably `None` once it has been taken. TODO confirm.
pub body: Option<HostIncomingBody>,
}
impl HostIncomingRequest {
    /// Builds the host-side state for an incoming request.
    ///
    /// The authority comes from the URI when present, otherwise from the
    /// `Host` header. Forbidden headers (see [`remove_forbidden_headers`]) are
    /// then removed from `parts`.
    ///
    /// # Errors
    /// Fails if the URI has no authority and there is no `Host` header, or if
    /// the `Host` header is not visible ASCII.
    pub fn new(
        view: &mut dyn WasiHttpView,
        mut parts: http::request::Parts,
        scheme: Scheme,
        body: Option<HostIncomingBody>,
    ) -> anyhow::Result<Self> {
        let authority = if let Some(from_uri) = parts.uri.authority() {
            from_uri.to_string()
        } else if let Some(host_header) = parts.headers.get(http::header::HOST) {
            host_header.to_str()?.to_string()
        } else {
            bail!("invalid HTTP request missing authority in URI and host header");
        };

        remove_forbidden_headers(view, &mut parts.headers);

        Ok(Self {
            parts,
            scheme,
            authority,
            body,
        })
    }
}
/// Host-side state of a response-outparam resource.
///
/// Holds the sending half of a oneshot channel through which either a
/// finished response or an error code is sent.
pub struct HostResponseOutparam {
pub result:
tokio::sync::oneshot::Sender<Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>>,
}
/// The parts of an outgoing response, convertible into a `hyper::Response`
/// via `TryFrom`.
pub struct HostOutgoingResponse {
// Response status code.
pub status: http::StatusCode,
// Response headers.
pub headers: FieldMap,
// Response body. When `None`, the `TryFrom` conversion substitutes an empty body.
pub body: Option<HyperOutgoingBody>,
}
impl TryFrom<HostOutgoingResponse> for hyper::Response<HyperOutgoingBody> {
type Error = http::Error;
fn try_from(
resp: HostOutgoingResponse,
) -> Result<hyper::Response<HyperOutgoingBody>, Self::Error> {
use http_body_util::Empty;
let mut builder = hyper::Response::builder().status(resp.status);
*builder.headers_mut().unwrap() = resp.headers;
match resp.body {
Some(body) => builder.body(body),
None => builder.body(
Empty::<bytes::Bytes>::new()
.map_err(|_| unreachable!("Infallible error"))
.boxed(),
),
}
}
}
/// Host-side parts of an outgoing request under construction.
///
/// NOTE(review): how the optional fields are defaulted when `None` is decided
/// outside this file. Confirm against the request sender.
#[derive(Debug)]
pub struct HostOutgoingRequest {
// Request method.
pub method: Method,
// Request scheme, if set.
pub scheme: Option<Scheme>,
// Request authority, if set.
pub authority: Option<String>,
// Path plus optional query, if set.
pub path_with_query: Option<String>,
// Request headers.
pub headers: FieldMap,
// Request body, if any.
pub body: Option<HyperOutgoingBody>,
}
/// Optional per-request timeouts. The derived `Default` leaves every timeout
/// unset (`None`).
///
/// NOTE(review): what value an unset timeout falls back to is decided outside
/// this file. Confirm.
#[derive(Debug, Default)]
pub struct HostRequestOptions {
pub connect_timeout: Option<std::time::Duration>,
pub first_byte_timeout: Option<std::time::Duration>,
pub between_bytes_timeout: Option<std::time::Duration>,
}
/// Host-side state of an incoming response resource.
#[derive(Debug)]
pub struct HostIncomingResponse {
// Numeric HTTP status code.
pub status: u16,
// Response headers.
pub headers: FieldMap,
// Response body; presumably `None` once it has been taken. TODO confirm.
pub body: Option<HostIncomingBody>,
}
/// A header/trailer field set, either owned outright or borrowed from a
/// parent resource.
#[derive(Debug)]
pub enum HostFields {
// Fields that live inside another resource.
Ref {
// Presumably the resource-table index of the parent. TODO confirm.
parent: u32,
// Projects the `FieldMap` out of the (type-erased) parent.
get_fields: for<'a> fn(elem: &'a mut (dyn Any + 'static)) -> &'a mut FieldMap,
},
// Fields owned by this resource.
Owned {
fields: FieldMap,
},
}
/// Header map used for request/response fields.
pub type FieldMap = hyper::HeaderMap;
/// Join handle of a spawned outgoing-request task (e.g. the one spawned by
/// `default_send_request`).
///
/// The outer `anyhow::Result` reports host-level failures; the inner `Result`
/// carries the HTTP outcome.
pub type FutureIncomingResponseHandle =
AbortOnDropJoinHandle<anyhow::Result<Result<IncomingResponse, types::ErrorCode>>>;
/// A received response head plus the state needed to keep reading its body.
#[derive(Debug)]
pub struct IncomingResponse {
// The response, with body errors already mapped to `types::ErrorCode`.
pub resp: hyper::Response<HyperIncomingBody>,
// Task driving the underlying connection; `default_send_request_handler`
// always sets it to `Some`. Per its type name it presumably aborts the task
// when dropped. TODO confirm.
pub worker: Option<AbortOnDropJoinHandle<()>>,
// Timeout handed back to the consumer of the body (not enforced by the sender).
pub between_bytes_timeout: std::time::Duration,
}
/// State of a future-incoming-response resource.
#[derive(Debug)]
pub enum HostFutureIncomingResponse {
// The request task is still running.
Pending(FutureIncomingResponseHandle),
// The request task finished with this result.
Ready(anyhow::Result<Result<IncomingResponse, types::ErrorCode>>),
// Presumably set once the ready result has been taken out. TODO confirm.
Consumed,
}
impl HostFutureIncomingResponse {
    /// Wraps a spawned request task whose result is not yet available.
    pub fn pending(handle: FutureIncomingResponseHandle) -> Self {
        Self::Pending(handle)
    }

    /// Wraps an already-available result.
    pub fn ready(result: anyhow::Result<Result<IncomingResponse, types::ErrorCode>>) -> Self {
        Self::Ready(result)
    }

    /// Returns `true` only in the `Ready` state.
    pub fn is_ready(&self) -> bool {
        matches!(self, Self::Ready(_))
    }

    /// Takes the result out of a `Ready` future.
    ///
    /// # Panics
    /// Panics if the future is `Pending` or `Consumed`. The message names the
    /// actual state so a use-after-consume is not misreported as "pending".
    pub fn unwrap_ready(self) -> anyhow::Result<Result<IncomingResponse, types::ErrorCode>> {
        match self {
            Self::Ready(res) => res,
            Self::Pending(_) => {
                panic!("unwrap_ready called on a pending HostFutureIncomingResponse")
            }
            Self::Consumed => {
                panic!("unwrap_ready called on a consumed HostFutureIncomingResponse")
            }
        }
    }
}
/// Readiness for a future response.
///
/// A `Pending` future awaits its task and stores the outcome as `Ready`.
/// `Ready` and `Consumed` futures are already settled and return at once.
#[async_trait::async_trait]
impl Subscribe for HostFutureIncomingResponse {
    async fn ready(&mut self) {
        match self {
            Self::Pending(handle) => {
                let outcome = handle.await;
                *self = Self::Ready(outcome);
            }
            Self::Ready(_) | Self::Consumed => {}
        }
    }
}