1use std::io::{self, Error};
2
3use rustls::pki_types::pem::PemObject;
4use rustls::pki_types::{CertificateDer, PrivateKeyDer};
5use rustls::RootCertStore;
6
7use crate::connection::TlsConnParams;
8use crate::{Client, ConnectionAddr, ConnectionInfo, ErrorKind, RedisError, RedisResult};
9
/// Client-side credentials for mutual TLS, supplied as raw bytes.
///
/// Both fields are later decoded as PEM when the TLS connection parameters
/// are built from a [`TlsCertificates`] value.
#[derive(Clone)]
pub struct ClientTlsConfig {
    /// PEM-encoded client certificate (may contain a whole certificate chain).
    pub client_cert: Vec<u8>,
    /// PEM-encoded private key belonging to `client_cert`.
    pub client_key: Vec<u8>,
}
19
20#[derive(Clone)]
25pub struct TlsCertificates {
26 pub client_tls: Option<ClientTlsConfig>,
28 pub root_cert: Option<Vec<u8>>,
30}
31
32pub(crate) fn inner_build_with_tls(
33 mut connection_info: ConnectionInfo,
34 certificates: &TlsCertificates,
35) -> RedisResult<Client> {
36 let tls_params = retrieve_tls_certificates(certificates)?;
37
38 connection_info.addr = if let ConnectionAddr::TcpTls {
39 host,
40 port,
41 insecure,
42 ..
43 } = connection_info.addr
44 {
45 ConnectionAddr::TcpTls {
46 host,
47 port,
48 insecure,
49 tls_params: Some(tls_params),
50 }
51 } else {
52 return Err(RedisError::from((
53 ErrorKind::InvalidClientConfig,
54 "Constructing a TLS client requires a URL with the `rediss://` scheme",
55 )));
56 };
57
58 Ok(Client { connection_info })
59}
60
61pub(crate) fn retrieve_tls_certificates(
62 certificates: &TlsCertificates,
63) -> RedisResult<TlsConnParams> {
64 let TlsCertificates {
65 client_tls,
66 root_cert,
67 } = certificates;
68
69 let client_tls_params = if let Some(ClientTlsConfig {
70 client_cert,
71 client_key,
72 }) = client_tls
73 {
74 let client_cert_chain = CertificateDer::pem_slice_iter(client_cert)
75 .collect::<Result<Vec<_>, _>>()
76 .map_err(|err| {
77 Error::new(
78 io::ErrorKind::Other,
79 format!("Unable to parse client certificate chain PEM: {err}"),
80 )
81 })?;
82
83 let client_key = PrivateKeyDer::from_pem_slice(client_key).map_err(|err| {
84 Error::new(
85 io::ErrorKind::Other,
86 format!("Unable to extract private key from PEM file: {err}"),
87 )
88 })?;
89
90 Some(ClientTlsParams {
91 client_cert_chain,
92 client_key,
93 })
94 } else {
95 None
96 };
97
98 let root_cert_store = if let Some(root_cert) = root_cert {
99 let mut root_cert_store = RootCertStore::empty();
100 for result in CertificateDer::pem_slice_iter(root_cert) {
101 let cert = result.map_err(|err| {
102 Error::new(
103 io::ErrorKind::Other,
104 format!("Unable to parse root certificate PEM: {err}"),
105 )
106 })?;
107
108 if root_cert_store.add(cert).is_err() {
109 return Err(
110 Error::new(io::ErrorKind::Other, "Unable to parse TLS trust anchors").into(),
111 );
112 }
113 }
114
115 Some(root_cert_store)
116 } else {
117 None
118 };
119
120 Ok(TlsConnParams {
121 client_tls_params,
122 root_cert_store,
123 #[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))]
124 danger_accept_invalid_hostnames: false,
125 })
126}
127
128#[derive(Debug)]
129pub struct ClientTlsParams {
130 pub(crate) client_cert_chain: Vec<CertificateDer<'static>>,
131 pub(crate) client_key: PrivateKeyDer<'static>,
132}
133
134impl Clone for ClientTlsParams {
136 fn clone(&self) -> Self {
137 use PrivateKeyDer::*;
138 Self {
139 client_cert_chain: self.client_cert_chain.clone(),
140 client_key: match &self.client_key {
141 Pkcs1(key) => Pkcs1(key.secret_pkcs1_der().to_vec().into()),
142 Pkcs8(key) => Pkcs8(key.secret_pkcs8_der().to_vec().into()),
143 Sec1(key) => Sec1(key.secret_sec1_der().to_vec().into()),
144 _ => unreachable!(),
145 },
146 }
147 }
148}