use std::io::{BufRead, Error, ErrorKind as IOErrorKind};
use rustls::RootCertStore;
use rustls_pki_types::{CertificateDer, PrivateKeyDer};
use crate::{Client, ConnectionAddr, ConnectionInfo, ErrorKind, RedisError, RedisResult};
/// Client-side TLS credentials (certificate and private key) supplied as raw bytes.
///
/// Both fields are read as PEM data when the TLS parameters are built by
/// `retrieve_tls_certificates`.
#[derive(Clone)]
pub struct ClientTlsConfig {
/// PEM data holding the client certificate(s); every certificate parsed from it
/// becomes part of the client certificate chain.
pub client_cert: Vec<u8>,
/// PEM data holding the client's private key. Parsing fails with
/// "Unable to extract private key from PEM file" if no key is found.
pub client_key: Vec<u8>,
}
/// Raw TLS material used to configure a TLS connection.
///
/// Each part is optional; a part left as `None` produces a `None` in the
/// corresponding field of the parsed `TlsConnParams`.
#[derive(Clone)]
pub struct TlsCertificates {
/// Optional client certificate and key, used to build `ClientTlsParams`.
pub client_tls: Option<ClientTlsConfig>,
/// Optional PEM data holding the root certificate(s) to trust; each certificate
/// parsed from it is added to a fresh `RootCertStore`.
/// NOTE(review): what the connection layer trusts when this is `None` is decided
/// elsewhere — confirm against the connection code.
pub root_cert: Option<Vec<u8>>,
}
/// Builds a [`Client`] whose TLS address carries the parsed certificate material.
///
/// The certificates are parsed first; afterwards the address in
/// `connection_info` is rebuilt with `tls_params` set to the parsed result,
/// keeping the existing `host`, `port` and `insecure` values.
///
/// # Errors
///
/// * Any error produced by `retrieve_tls_certificates` while parsing
///   `certificates`.
/// * [`ErrorKind::InvalidClientConfig`] when the address is not
///   [`ConnectionAddr::TcpTls`].
pub(crate) fn inner_build_with_tls(
    mut connection_info: ConnectionInfo,
    certificates: TlsCertificates,
) -> RedisResult<Client> {
    // Parse before inspecting the address so parsing errors surface first.
    let parsed_params = retrieve_tls_certificates(certificates)?;

    connection_info.addr = match connection_info.addr {
        // Any TLS parameters already present on the address are discarded
        // and replaced with the freshly parsed ones.
        ConnectionAddr::TcpTls {
            host,
            port,
            insecure,
            ..
        } => ConnectionAddr::TcpTls {
            host,
            port,
            insecure,
            tls_params: Some(parsed_params),
        },
        _ => {
            let not_tls = (
                ErrorKind::InvalidClientConfig,
                "Constructing a TLS client requires a URL with the `rediss://` scheme",
            );
            return Err(RedisError::from(not_tls));
        }
    };

    Ok(Client { connection_info })
}
/// Parses raw PEM TLS material into the structures used to configure a connection.
///
/// * `client_tls`, when present, yields [`ClientTlsParams`]: every certificate
///   in `client_cert` forms the chain and the private key is read from
///   `client_key`.
/// * `root_cert`, when present, yields a [`RootCertStore`] containing every
///   certificate parsed from it.
///
/// Parts that are `None` in the input stay `None` in the returned
/// [`TlsConnParams`].
///
/// # Errors
///
/// * An I/O error from the PEM parser for any of the inputs.
/// * "Unable to extract private key from PEM file" when `client_key` contains
///   no private key.
/// * "Unable to parse TLS trust anchors" when a root certificate is rejected
///   by the store.
pub(crate) fn retrieve_tls_certificates(
    certificates: TlsCertificates,
) -> RedisResult<TlsConnParams> {
    let TlsCertificates {
        client_tls,
        root_cert,
    } = certificates;

    let client_tls_params = match client_tls {
        Some(ClientTlsConfig {
            client_cert,
            client_key,
        }) => {
            let mut cert_pem = client_cert.as_slice();
            let client_cert_chain = rustls_pemfile::certs(&mut cert_pem as &mut dyn BufRead)
                .collect::<Result<Vec<_>, _>>()?;

            let mut key_pem = client_key.as_slice();
            let client_key = rustls_pemfile::private_key(&mut key_pem as &mut dyn BufRead)?
                .ok_or_else(|| {
                    Error::new(
                        IOErrorKind::Other,
                        "Unable to extract private key from PEM file",
                    )
                })?;

            Some(ClientTlsParams {
                client_cert_chain,
                client_key,
            })
        }
        None => None,
    };

    let root_cert_store = match root_cert {
        Some(root_cert) => {
            let mut root_pem = root_cert.as_slice();
            let mut store = RootCertStore::empty();
            for cert in rustls_pemfile::certs(&mut root_pem as &mut dyn BufRead) {
                // The parser already yields owned certificates, so each one is
                // handed to the store directly instead of being cloned first.
                store.add(cert?).map_err(|_| {
                    Error::new(IOErrorKind::Other, "Unable to parse TLS trust anchors")
                })?;
            }
            Some(store)
        }
        None => None,
    };

    Ok(TlsConnParams {
        client_tls_params,
        root_cert_store,
    })
}
/// Parsed (DER) client credentials used for TLS client authentication.
///
/// `Clone` is implemented by hand for this type rather than derived.
#[derive(Debug)]
pub struct ClientTlsParams {
/// Client certificate chain, in the order the certificates were parsed.
pub(crate) client_cert_chain: Vec<CertificateDer<'static>>,
/// Private key belonging to the client certificate.
pub(crate) client_key: PrivateKeyDer<'static>,
}
impl Clone for ClientTlsParams {
fn clone(&self) -> Self {
use PrivateKeyDer::*;
Self {
client_cert_chain: self.client_cert_chain.clone(),
client_key: match &self.client_key {
Pkcs1(key) => Pkcs1(key.secret_pkcs1_der().to_vec().into()),
Pkcs8(key) => Pkcs8(key.secret_pkcs8_der().to_vec().into()),
Sec1(key) => Sec1(key.secret_sec1_der().to_vec().into()),
_ => unreachable!(),
},
}
}
}
/// Parsed TLS parameters attached to a `ConnectionAddr::TcpTls` address.
///
/// Produced by `retrieve_tls_certificates` from a `TlsCertificates` value.
#[derive(Debug, Clone)]
pub struct TlsConnParams {
/// Client certificate chain and key, if client credentials were supplied.
pub(crate) client_tls_params: Option<ClientTlsParams>,
/// Store holding the supplied root certificates, if any were supplied.
pub(crate) root_cert_store: Option<RootCertStore>,
}