async_nats/
connector.rs

1// Copyright 2020-2022 The NATS Authors
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14use crate::auth::Auth;
15use crate::client::Statistics;
16use crate::connection::Connection;
17use crate::connection::State;
18#[cfg(feature = "websockets")]
19use crate::connection::WebSocketAdapter;
20use crate::options::CallbackArg1;
21use crate::tls;
22use crate::AuthError;
23use crate::ClientError;
24use crate::ClientOp;
25use crate::ConnectError;
26use crate::ConnectErrorKind;
27use crate::ConnectInfo;
28use crate::Event;
29use crate::Protocol;
30use crate::ServerAddr;
31use crate::ServerError;
32use crate::ServerInfo;
33use crate::ServerOp;
34use crate::SocketAddr;
35use crate::ToServerAddrs;
36use crate::LANG;
37use crate::VERSION;
38use base64::engine::general_purpose::URL_SAFE_NO_PAD;
39use base64::engine::Engine;
40use rand::seq::SliceRandom;
41use rand::thread_rng;
42use std::cmp;
43use std::io;
44use std::path::PathBuf;
45use std::sync::atomic::AtomicUsize;
46use std::sync::atomic::Ordering;
47use std::sync::Arc;
48use std::time::Duration;
49use tokio::net::TcpStream;
50use tokio::time::sleep;
51use tokio_rustls::rustls;
52
/// Settings the `Connector` uses when dialing servers and performing the NATS
/// handshake.
pub(crate) struct ConnectorOptions {
    /// Forces a TLS upgrade (after INFO) of non-websocket connections even when
    /// neither the server URL nor the server's INFO requires it; also sent as
    /// `tls_required` in CONNECT.
    pub(crate) tls_required: bool,
    // The four TLS fields below are not read directly in this file; the whole
    // options struct is handed to `tls::config_tls`, which presumably builds the
    // rustls client config from them. NOTE(review): confirm in `tls.rs`.
    /// Extra root certificate files.
    pub(crate) certificates: Vec<PathBuf>,
    /// Client certificate file for mutual TLS.
    pub(crate) client_cert: Option<PathBuf>,
    /// Client private key file for mutual TLS.
    pub(crate) client_key: Option<PathBuf>,
    /// Pre-built rustls client configuration.
    pub(crate) tls_client_config: Option<rustls::ClientConfig>,
    /// Perform the TLS handshake before reading INFO (not applied to websocket
    /// connections).
    pub(crate) tls_first: bool,
    /// Static credentials. Username, password and token are copied into CONNECT.
    /// The nkey seed signs the server nonce. The JWT is sent together with the
    /// signature returned by `signature_callback`.
    pub(crate) auth: Auth,
    /// When true, CONNECT is sent with `echo: false`.
    pub(crate) no_echo: bool,
    /// Upper bound for establishing the TCP or websocket connection.
    pub(crate) connection_timeout: Duration,
    /// Client name sent in CONNECT.
    pub(crate) name: Option<String>,
    /// When true, servers advertised in INFO `connect_urls` are not added to the
    /// server list.
    pub(crate) ignore_discovered_servers: bool,
    /// When true, servers are tried in list order instead of being shuffled and
    /// sorted by attempt count.
    pub(crate) retain_servers_order: bool,
    /// Read buffer capacity for plain TCP connections; websocket and TLS
    /// connections are created with a capacity of 0.
    pub(crate) read_buffer_capacity: u16,
    /// Maps the running attempt count to the delay slept before each connection
    /// attempt.
    pub(crate) reconnect_delay_callback: Box<dyn Fn(usize) -> Duration + Send + Sync + 'static>,
    /// Called with the server nonce; the credentials it returns overwrite the
    /// user, password, JWT, signature, token and nkey fields of CONNECT.
    pub(crate) auth_callback: Option<CallbackArg1<Vec<u8>, Result<Auth, AuthError>>>,
    /// Maximum number of connection attempts since the last successful
    /// connection before giving up with `MaxReconnects`; `None` means no limit.
    pub(crate) max_reconnects: Option<usize>,
}
71
/// Maintains a list of servers and establishes connections.
pub(crate) struct Connector {
    /// Known servers, each paired with a connect-attempt count. Unless
    /// `retain_servers_order` is set, `try_connect` tries servers in ascending
    /// order of this count, with ties in random order.
    servers: Vec<(ServerAddr, usize)>,
    /// Dialing, TLS, authentication and retry settings.
    options: ConnectorOptions,
    /// Statistics shared with every `Connection` this connector creates;
    /// `connects` is incremented after each successful handshake.
    pub(crate) connect_stats: Arc<Statistics>,
    /// Connection attempts since the last successful connection. It is compared
    /// against `max_reconnects` and passed to `reconnect_delay_callback`.
    attempts: usize,
    /// Channel on which this connector publishes `Event::Connected` and
    /// `Event::ClientError` events.
    pub(crate) events_tx: tokio::sync::mpsc::Sender<Event>,
    /// This connector sends `State::Connected` on it after a successful handshake.
    pub(crate) state_tx: tokio::sync::watch::Sender<State>,
    /// Overwritten with the server's advertised `max_payload` on every
    /// successful connection.
    pub(crate) max_payload: Arc<AtomicUsize>,
}
83
/// Default delay to wait before connection attempt number `attempts`.
///
/// Attempts `0` and `1` are immediate. From the second attempt on, the delay is
/// `2^(attempts - 1)` milliseconds (2ms, 4ms, 8ms, ...), never exceeding 4 seconds.
pub(crate) fn reconnect_delay_callback_default(attempts: usize) -> Duration {
    const MAX_DELAY: Duration = Duration::from_secs(4);

    match attempts {
        0 | 1 => Duration::ZERO,
        n => {
            // Exponents beyond `u32` are clamped, and `saturating_pow` pins huge
            // results at `u64::MAX` ms, so the cap below always applies.
            let exponent = u32::try_from(n - 1).unwrap_or(u32::MAX);
            let millis = 2_u64.saturating_pow(exponent);
            cmp::min(Duration::from_millis(millis), MAX_DELAY)
        }
    }
}
93
94impl Connector {
95    pub(crate) fn new<A: ToServerAddrs>(
96        addrs: A,
97        options: ConnectorOptions,
98        events_tx: tokio::sync::mpsc::Sender<Event>,
99        state_tx: tokio::sync::watch::Sender<State>,
100        max_payload: Arc<AtomicUsize>,
101        connect_stats: Arc<Statistics>,
102    ) -> Result<Connector, io::Error> {
103        let servers = addrs.to_server_addrs()?.map(|addr| (addr, 0)).collect();
104
105        Ok(Connector {
106            attempts: 0,
107            servers,
108            options,
109            events_tx,
110            state_tx,
111            max_payload,
112            connect_stats,
113        })
114    }
115
116    pub(crate) async fn connect(&mut self) -> Result<(ServerInfo, Connection), ConnectError> {
117        loop {
118            match self.try_connect().await {
119                Ok(inner) => {
120                    return Ok(inner);
121                }
122                Err(error) => match error.kind() {
123                    ConnectErrorKind::MaxReconnects => {
124                        return Err(ConnectError::with_source(
125                            crate::ConnectErrorKind::MaxReconnects,
126                            error,
127                        ))
128                    }
129                    other => {
130                        self.events_tx
131                            .send(Event::ClientError(ClientError::Other(other.to_string())))
132                            .await
133                            .ok();
134                    }
135                },
136            }
137        }
138    }
139
140    pub(crate) async fn try_connect(&mut self) -> Result<(ServerInfo, Connection), ConnectError> {
141        tracing::debug!("connecting");
142        let mut error = None;
143
144        let mut servers = self.servers.clone();
145        if !self.options.retain_servers_order {
146            servers.shuffle(&mut thread_rng());
147            // sort_by is stable, meaning it will retain the order for equal elements.
148            servers.sort_by(|a, b| a.1.cmp(&b.1));
149        }
150
151        for (server_addr, _) in servers {
152            self.attempts += 1;
153            if let Some(max_reconnects) = self.options.max_reconnects {
154                if self.attempts > max_reconnects {
155                    self.events_tx
156                        .send(Event::ClientError(ClientError::MaxReconnects))
157                        .await
158                        .ok();
159                    return Err(ConnectError::new(crate::ConnectErrorKind::MaxReconnects));
160                }
161            }
162
163            let duration = (self.options.reconnect_delay_callback)(self.attempts);
164
165            sleep(duration).await;
166
167            let socket_addrs = server_addr
168                .socket_addrs()
169                .await
170                .map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Dns, err))?;
171            for socket_addr in socket_addrs {
172                match self
173                    .try_connect_to(
174                        &socket_addr,
175                        server_addr.tls_required(),
176                        server_addr.clone(),
177                    )
178                    .await
179                {
180                    Ok((server_info, mut connection)) => {
181                        if !self.options.ignore_discovered_servers {
182                            for url in &server_info.connect_urls {
183                                let server_addr = url.parse::<ServerAddr>().map_err(|err| {
184                                    ConnectError::with_source(
185                                        crate::ConnectErrorKind::ServerParse,
186                                        err,
187                                    )
188                                })?;
189                                if !self.servers.iter().any(|(addr, _)| addr == &server_addr) {
190                                    self.servers.push((server_addr, 0));
191                                }
192                            }
193                        }
194
195                        let tls_required = self.options.tls_required || server_addr.tls_required();
196                        let mut connect_info = ConnectInfo {
197                            tls_required,
198                            name: self.options.name.clone(),
199                            pedantic: false,
200                            verbose: false,
201                            lang: LANG.to_string(),
202                            version: VERSION.to_string(),
203                            protocol: Protocol::Dynamic,
204                            user: self.options.auth.username.to_owned(),
205                            pass: self.options.auth.password.to_owned(),
206                            auth_token: self.options.auth.token.to_owned(),
207                            user_jwt: None,
208                            nkey: None,
209                            signature: None,
210                            echo: !self.options.no_echo,
211                            headers: true,
212                            no_responders: true,
213                        };
214
215                        if let Some(nkey) = self.options.auth.nkey.as_ref() {
216                            match nkeys::KeyPair::from_seed(nkey.as_str()) {
217                                Ok(key_pair) => {
218                                    let nonce = server_info.nonce.clone();
219                                    match key_pair.sign(nonce.as_bytes()) {
220                                        Ok(signed) => {
221                                            connect_info.nkey = Some(key_pair.public_key());
222                                            connect_info.signature =
223                                                Some(URL_SAFE_NO_PAD.encode(signed));
224                                        }
225                                        Err(_) => {
226                                            return Err(ConnectError::new(
227                                                crate::ConnectErrorKind::Authentication,
228                                            ))
229                                        }
230                                    };
231                                }
232                                Err(_) => {
233                                    return Err(ConnectError::new(
234                                        crate::ConnectErrorKind::Authentication,
235                                    ))
236                                }
237                            }
238                        }
239
240                        if let Some(jwt) = self.options.auth.jwt.as_ref() {
241                            if let Some(sign_fn) = self.options.auth.signature_callback.as_ref() {
242                                match sign_fn.call(server_info.nonce.clone()).await {
243                                    Ok(sig) => {
244                                        connect_info.user_jwt = Some(jwt.clone());
245                                        connect_info.signature = Some(sig);
246                                    }
247                                    Err(_) => {
248                                        return Err(ConnectError::new(
249                                            crate::ConnectErrorKind::Authentication,
250                                        ))
251                                    }
252                                }
253                            }
254                        }
255
256                        if let Some(callback) = self.options.auth_callback.as_ref() {
257                            let auth = callback
258                                .call(server_info.nonce.as_bytes().to_vec())
259                                .await
260                                .map_err(|err| {
261                                ConnectError::with_source(
262                                    crate::ConnectErrorKind::Authentication,
263                                    err,
264                                )
265                            })?;
266                            connect_info.user = auth.username;
267                            connect_info.pass = auth.password;
268                            connect_info.user_jwt = auth.jwt;
269                            connect_info.signature = auth
270                                .signature
271                                .map(|signature| URL_SAFE_NO_PAD.encode(signature));
272                            connect_info.auth_token = auth.token;
273                            connect_info.nkey = auth.nkey;
274                        }
275
276                        connection
277                            .easy_write_and_flush(
278                                [ClientOp::Connect(connect_info), ClientOp::Ping].iter(),
279                            )
280                            .await?;
281
282                        match connection.read_op().await? {
283                            Some(ServerOp::Error(err)) => match err {
284                                ServerError::AuthorizationViolation => {
285                                    return Err(ConnectError::with_source(
286                                        crate::ConnectErrorKind::AuthorizationViolation,
287                                        err,
288                                    ));
289                                }
290                                err => {
291                                    return Err(ConnectError::with_source(
292                                        crate::ConnectErrorKind::Io,
293                                        err,
294                                    ));
295                                }
296                            },
297                            Some(_) => {
298                                tracing::debug!("connected to {}", server_info.port);
299                                self.attempts = 0;
300                                self.connect_stats.connects.add(1, Ordering::Relaxed);
301                                self.events_tx.send(Event::Connected).await.ok();
302                                self.state_tx.send(State::Connected).ok();
303                                self.max_payload.store(
304                                    server_info.max_payload,
305                                    std::sync::atomic::Ordering::Relaxed,
306                                );
307                                return Ok((server_info, connection));
308                            }
309                            None => {
310                                return Err(ConnectError::with_source(
311                                    crate::ConnectErrorKind::Io,
312                                    "broken pipe",
313                                ))
314                            }
315                        }
316                    }
317
318                    Err(inner) => error.replace(inner),
319                };
320            }
321        }
322
323        Err(error.unwrap())
324    }
325
326    pub(crate) async fn try_connect_to(
327        &self,
328        socket_addr: &SocketAddr,
329        tls_required: bool,
330        server_addr: ServerAddr,
331    ) -> Result<(ServerInfo, Connection), ConnectError> {
332        let mut connection = match server_addr.scheme() {
333            #[cfg(feature = "websockets")]
334            "ws" => {
335                let ws = tokio::time::timeout(
336                    self.options.connection_timeout,
337                    tokio_websockets::client::Builder::new()
338                        .uri(format!("{}://{}", server_addr.scheme(), socket_addr).as_str())
339                        .map_err(|err| {
340                            ConnectError::with_source(crate::ConnectErrorKind::ServerParse, err)
341                        })?
342                        .connect(),
343                )
344                .await
345                .map_err(|_| ConnectError::new(crate::ConnectErrorKind::TimedOut))?
346                .map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Io, err))?;
347
348                let con = WebSocketAdapter::new(ws.0);
349                Connection::new(Box::new(con), 0, self.connect_stats.clone())
350            }
351            #[cfg(feature = "websockets")]
352            "wss" => {
353                let domain = rustls_webpki::types::ServerName::try_from(server_addr.host())
354                    .map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Tls, err))?;
355                let tls_config =
356                    Arc::new(tls::config_tls(&self.options).await.map_err(|err| {
357                        ConnectError::with_source(crate::ConnectErrorKind::Tls, err)
358                    })?);
359                let tls_connector = tokio_rustls::TlsConnector::from(tls_config);
360                let ws = tokio::time::timeout(
361                    self.options.connection_timeout,
362                    tokio_websockets::client::Builder::new()
363                        .connector(&tokio_websockets::Connector::Rustls(tls_connector))
364                        .uri(
365                            format!(
366                                "{}://{}:{}",
367                                server_addr.scheme(),
368                                domain.to_str(),
369                                server_addr.port()
370                            )
371                            .as_str(),
372                        )
373                        .map_err(|err| {
374                            ConnectError::with_source(crate::ConnectErrorKind::ServerParse, err)
375                        })?
376                        .connect(),
377                )
378                .await
379                .map_err(|_| ConnectError::new(crate::ConnectErrorKind::TimedOut))?
380                .map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Io, err))?;
381                let con = WebSocketAdapter::new(ws.0);
382                Connection::new(Box::new(con), 0, self.connect_stats.clone())
383            }
384            _ => {
385                let tcp_stream = tokio::time::timeout(
386                    self.options.connection_timeout,
387                    TcpStream::connect(socket_addr),
388                )
389                .await
390                .map_err(|_| ConnectError::new(crate::ConnectErrorKind::TimedOut))??;
391                tcp_stream.set_nodelay(true)?;
392
393                Connection::new(
394                    Box::new(tcp_stream),
395                    self.options.read_buffer_capacity.into(),
396                    self.connect_stats.clone(),
397                )
398            }
399        };
400
401        let tls_connection = |connection: Connection| async {
402            let tls_config = Arc::new(
403                tls::config_tls(&self.options)
404                    .await
405                    .map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Tls, err))?,
406            );
407            let tls_connector = tokio_rustls::TlsConnector::from(tls_config);
408
409            let domain = rustls_webpki::types::ServerName::try_from(server_addr.host())
410                .map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Tls, err))?;
411
412            let tls_stream = tls_connector
413                .connect(domain.to_owned(), connection.stream)
414                .await?;
415
416            Ok::<Connection, ConnectError>(Connection::new(
417                Box::new(tls_stream),
418                0,
419                self.connect_stats.clone(),
420            ))
421        };
422
423        // If `tls_first` was set, establish TLS connection before getting INFO.
424        // There is no point in  checking if tls is required, because
425        // the connection has to be be upgraded to TLS anyway as it's different flow.
426        if self.options.tls_first && !server_addr.is_websocket() {
427            connection = tls_connection(connection).await?;
428        }
429
430        let op = connection.read_op().await?;
431        let info = match op {
432            Some(ServerOp::Info(info)) => info,
433            Some(op) => {
434                return Err(ConnectError::with_source(
435                    crate::ConnectErrorKind::Io,
436                    format!("expected INFO, got {:?}", op),
437                ))
438            }
439            None => {
440                return Err(ConnectError::with_source(
441                    crate::ConnectErrorKind::Io,
442                    "expected INFO, got nothing",
443                ))
444            }
445        };
446
447        // If `tls_first` was not set, establish TLS connection if it is required.
448        if !self.options.tls_first
449            && !server_addr.is_websocket()
450            && (self.options.tls_required || info.tls_required || tls_required)
451        {
452            connection = tls_connection(connection).await?;
453        };
454
455        Ok((*info, connection))
456    }
457}
458
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the default backoff: immediate for the first attempt, doubling from
    /// 2ms afterwards, and capped at 4 seconds even when the exponent overflows.
    #[test]
    fn reconnect_delay_callback_duration() {
        let cases: &[(usize, u128)] = &[
            (0, 0),
            (1, 0),
            // First attempt that actually waits.
            (2, 2),
            (4, 8),
            (12, 2048),
            // 2^12 = 4096ms exceeds the 4s cap.
            (13, 4000),
            // The max (4s) was reached and we shouldn't exceed it, regardless of the number of attempts.
            (50, 4000),
            // The exponent no longer fits in u32 on 64-bit targets: exercises the saturating path.
            (usize::MAX, 4000),
        ];

        for &(attempts, expected_ms) in cases {
            assert_eq!(
                reconnect_delay_callback_default(attempts).as_millis(),
                expected_ms,
                "attempts = {attempts}"
            );
        }
    }
}