tokio_postgres/
connect_tls.rs1use crate::config::{SslMode, SslNegotiation};
2use crate::maybe_tls_stream::MaybeTlsStream;
3use crate::tls::private::ForcePrivateApi;
4use crate::tls::TlsConnect;
5use crate::Error;
6use bytes::BytesMut;
7use postgres_protocol::message::frontend;
8use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
9
10pub async fn connect_tls<S, T>(
11 mut stream: S,
12 mode: SslMode,
13 negotiation: SslNegotiation,
14 tls: T,
15 has_hostname: bool,
16) -> Result<MaybeTlsStream<S, T::Stream>, Error>
17where
18 S: AsyncRead + AsyncWrite + Unpin,
19 T: TlsConnect<S>,
20{
21 match mode {
22 SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)),
23 SslMode::Prefer if !tls.can_connect(ForcePrivateApi) => {
24 return Ok(MaybeTlsStream::Raw(stream))
25 }
26 SslMode::Prefer if negotiation == SslNegotiation::Direct => return Err(Error::tls(
27 "weak sslmode \"prefer\" may not be used with sslnegotiation=direct (use \"require\")"
28 .into(),
29 )),
30 SslMode::Prefer | SslMode::Require => {}
31 }
32
33 if negotiation == SslNegotiation::Postgres {
34 let mut buf = BytesMut::new();
35 frontend::ssl_request(&mut buf);
36 stream.write_all(&buf).await.map_err(Error::io)?;
37
38 let mut buf = [0];
39 stream.read_exact(&mut buf).await.map_err(Error::io)?;
40
41 if buf[0] != b'S' {
42 if SslMode::Require == mode {
43 return Err(Error::tls("server does not support TLS".into()));
44 } else {
45 return Ok(MaybeTlsStream::Raw(stream));
46 }
47 }
48 }
49
50 if !has_hostname {
51 return Err(Error::tls("no hostname provided for TLS handshake".into()));
52 }
53
54 let stream = tls
55 .connect(stream)
56 .await
57 .map_err(|e| Error::tls(e.into()))?;
58
59 Ok(MaybeTlsStream::Tls(stream))
60}