1use crate::auth::Auth;
15use crate::client::Statistics;
16use crate::connection::Connection;
17use crate::connection::State;
18#[cfg(feature = "websockets")]
19use crate::connection::WebSocketAdapter;
20use crate::options::CallbackArg1;
21use crate::tls;
22use crate::AuthError;
23use crate::ClientError;
24use crate::ClientOp;
25use crate::ConnectError;
26use crate::ConnectErrorKind;
27use crate::ConnectInfo;
28use crate::Event;
29use crate::Protocol;
30use crate::ServerAddr;
31use crate::ServerError;
32use crate::ServerInfo;
33use crate::ServerOp;
34use crate::SocketAddr;
35use crate::ToServerAddrs;
36use crate::LANG;
37use crate::VERSION;
38use base64::engine::general_purpose::URL_SAFE_NO_PAD;
39use base64::engine::Engine;
40use rand::seq::SliceRandom;
41use rand::thread_rng;
42use std::cmp;
43use std::io;
44use std::path::PathBuf;
45use std::sync::atomic::AtomicUsize;
46use std::sync::atomic::Ordering;
47use std::sync::Arc;
48use std::time::Duration;
49use tokio::net::TcpStream;
50use tokio::time::sleep;
51use tokio_rustls::rustls;
52
/// Connection-related settings consumed by [`Connector`].
///
/// The whole struct is also handed to `tls::config_tls` whenever a TLS session is set up.
pub(crate) struct ConnectorOptions {
    /// Forces a TLS upgrade for non-WebSocket connections and is advertised as
    /// `tls_required` in the `CONNECT` payload.
    pub(crate) tls_required: bool,
    // NOTE(review): the next four fields are not read in this file; presumably they are
    // consumed by `tls::config_tls` (CA files, client cert/key, prebuilt config) — confirm.
    pub(crate) certificates: Vec<PathBuf>,
    pub(crate) client_cert: Option<PathBuf>,
    pub(crate) client_key: Option<PathBuf>,
    pub(crate) tls_client_config: Option<rustls::ClientConfig>,
    /// Performs the TLS handshake before reading the server's `INFO` (non-WebSocket only).
    pub(crate) tls_first: bool,
    /// Static credentials copied into `CONNECT`; an nkey seed, or a JWT together with a
    /// `signature_callback`, is used to sign the server nonce.
    pub(crate) auth: Auth,
    /// When `true`, `CONNECT` is sent with `echo: false`.
    pub(crate) no_echo: bool,
    /// Upper bound for opening the TCP or WebSocket transport.
    pub(crate) connection_timeout: Duration,
    /// Client name sent in `CONNECT`.
    pub(crate) name: Option<String>,
    /// When `true`, servers advertised in `INFO.connect_urls` are not added to the pool.
    pub(crate) ignore_discovered_servers: bool,
    /// When `true`, servers are tried in their stored order instead of being shuffled.
    pub(crate) retain_servers_order: bool,
    /// Read-buffer capacity passed to `Connection::new` for plain TCP connections
    /// (TLS and WebSocket connections pass `0`).
    pub(crate) read_buffer_capacity: u16,
    /// Maps the current attempt number to the delay slept before that attempt.
    pub(crate) reconnect_delay_callback: Box<dyn Fn(usize) -> Duration + Send + Sync + 'static>,
    /// Called with the server nonce bytes; its result replaces every credential in `CONNECT`.
    pub(crate) auth_callback: Option<CallbackArg1<Vec<u8>, Result<Auth, AuthError>>>,
    /// Maximum number of attempts since the last successful connection; `None` = unlimited.
    pub(crate) max_reconnects: Option<usize>,
}
71
/// Opens connections to a pool of NATS servers, retrying with a configurable delay.
pub(crate) struct Connector {
    /// Known servers, each with a counter; `try_connect` tries lower counters first when
    /// server order is not retained.
    // NOTE(review): nothing in this file ever increments the counter, so the ordering is
    // effectively just the shuffle — confirm whether failed attempts should bump it.
    servers: Vec<(ServerAddr, usize)>,
    /// Connection settings (timeouts, TLS, auth, reconnect policy).
    options: ConnectorOptions,
    /// Shared statistics; `connects` is incremented after each successful handshake.
    pub(crate) connect_stats: Arc<Statistics>,
    /// Attempts made since the last successful connection. It is compared with
    /// `max_reconnects` and passed to `reconnect_delay_callback`.
    attempts: usize,
    /// Receives `Event::Connected` and `Event::ClientError` notifications.
    pub(crate) events_tx: tokio::sync::mpsc::Sender<Event>,
    /// Set to `State::Connected` after a successful handshake.
    pub(crate) state_tx: tokio::sync::watch::Sender<State>,
    /// Updated with the server's advertised `max_payload` after a successful handshake.
    pub(crate) max_payload: Arc<AtomicUsize>,
}
83
/// Default back-off between connection attempts.
///
/// The first attempt (and a zero attempt count) waits nothing. Attempt `n >= 2` waits
/// `2^(n-1)` milliseconds, saturating on overflow and capped at four seconds.
pub(crate) fn reconnect_delay_callback_default(attempts: usize) -> Duration {
    const MAX_DELAY: Duration = Duration::from_secs(4);

    match attempts.checked_sub(1) {
        None | Some(0) => Duration::ZERO,
        Some(exponent) => {
            // Exponents beyond `u32::MAX` saturate just like the power itself does.
            let exponent = u32::try_from(exponent).unwrap_or(u32::MAX);
            let millis = 2_u64.saturating_pow(exponent);
            Duration::from_millis(millis).min(MAX_DELAY)
        }
    }
}
93
94impl Connector {
95 pub(crate) fn new<A: ToServerAddrs>(
96 addrs: A,
97 options: ConnectorOptions,
98 events_tx: tokio::sync::mpsc::Sender<Event>,
99 state_tx: tokio::sync::watch::Sender<State>,
100 max_payload: Arc<AtomicUsize>,
101 connect_stats: Arc<Statistics>,
102 ) -> Result<Connector, io::Error> {
103 let servers = addrs.to_server_addrs()?.map(|addr| (addr, 0)).collect();
104
105 Ok(Connector {
106 attempts: 0,
107 servers,
108 options,
109 events_tx,
110 state_tx,
111 max_payload,
112 connect_stats,
113 })
114 }
115
116 pub(crate) async fn connect(&mut self) -> Result<(ServerInfo, Connection), ConnectError> {
117 loop {
118 match self.try_connect().await {
119 Ok(inner) => {
120 return Ok(inner);
121 }
122 Err(error) => match error.kind() {
123 ConnectErrorKind::MaxReconnects => {
124 return Err(ConnectError::with_source(
125 crate::ConnectErrorKind::MaxReconnects,
126 error,
127 ))
128 }
129 other => {
130 self.events_tx
131 .send(Event::ClientError(ClientError::Other(other.to_string())))
132 .await
133 .ok();
134 }
135 },
136 }
137 }
138 }
139
140 pub(crate) async fn try_connect(&mut self) -> Result<(ServerInfo, Connection), ConnectError> {
141 tracing::debug!("connecting");
142 let mut error = None;
143
144 let mut servers = self.servers.clone();
145 if !self.options.retain_servers_order {
146 servers.shuffle(&mut thread_rng());
147 servers.sort_by(|a, b| a.1.cmp(&b.1));
149 }
150
151 for (server_addr, _) in servers {
152 self.attempts += 1;
153 if let Some(max_reconnects) = self.options.max_reconnects {
154 if self.attempts > max_reconnects {
155 self.events_tx
156 .send(Event::ClientError(ClientError::MaxReconnects))
157 .await
158 .ok();
159 return Err(ConnectError::new(crate::ConnectErrorKind::MaxReconnects));
160 }
161 }
162
163 let duration = (self.options.reconnect_delay_callback)(self.attempts);
164
165 sleep(duration).await;
166
167 let socket_addrs = server_addr
168 .socket_addrs()
169 .await
170 .map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Dns, err))?;
171 for socket_addr in socket_addrs {
172 match self
173 .try_connect_to(
174 &socket_addr,
175 server_addr.tls_required(),
176 server_addr.clone(),
177 )
178 .await
179 {
180 Ok((server_info, mut connection)) => {
181 if !self.options.ignore_discovered_servers {
182 for url in &server_info.connect_urls {
183 let server_addr = url.parse::<ServerAddr>().map_err(|err| {
184 ConnectError::with_source(
185 crate::ConnectErrorKind::ServerParse,
186 err,
187 )
188 })?;
189 if !self.servers.iter().any(|(addr, _)| addr == &server_addr) {
190 self.servers.push((server_addr, 0));
191 }
192 }
193 }
194
195 let tls_required = self.options.tls_required || server_addr.tls_required();
196 let mut connect_info = ConnectInfo {
197 tls_required,
198 name: self.options.name.clone(),
199 pedantic: false,
200 verbose: false,
201 lang: LANG.to_string(),
202 version: VERSION.to_string(),
203 protocol: Protocol::Dynamic,
204 user: self.options.auth.username.to_owned(),
205 pass: self.options.auth.password.to_owned(),
206 auth_token: self.options.auth.token.to_owned(),
207 user_jwt: None,
208 nkey: None,
209 signature: None,
210 echo: !self.options.no_echo,
211 headers: true,
212 no_responders: true,
213 };
214
215 if let Some(nkey) = self.options.auth.nkey.as_ref() {
216 match nkeys::KeyPair::from_seed(nkey.as_str()) {
217 Ok(key_pair) => {
218 let nonce = server_info.nonce.clone();
219 match key_pair.sign(nonce.as_bytes()) {
220 Ok(signed) => {
221 connect_info.nkey = Some(key_pair.public_key());
222 connect_info.signature =
223 Some(URL_SAFE_NO_PAD.encode(signed));
224 }
225 Err(_) => {
226 return Err(ConnectError::new(
227 crate::ConnectErrorKind::Authentication,
228 ))
229 }
230 };
231 }
232 Err(_) => {
233 return Err(ConnectError::new(
234 crate::ConnectErrorKind::Authentication,
235 ))
236 }
237 }
238 }
239
240 if let Some(jwt) = self.options.auth.jwt.as_ref() {
241 if let Some(sign_fn) = self.options.auth.signature_callback.as_ref() {
242 match sign_fn.call(server_info.nonce.clone()).await {
243 Ok(sig) => {
244 connect_info.user_jwt = Some(jwt.clone());
245 connect_info.signature = Some(sig);
246 }
247 Err(_) => {
248 return Err(ConnectError::new(
249 crate::ConnectErrorKind::Authentication,
250 ))
251 }
252 }
253 }
254 }
255
256 if let Some(callback) = self.options.auth_callback.as_ref() {
257 let auth = callback
258 .call(server_info.nonce.as_bytes().to_vec())
259 .await
260 .map_err(|err| {
261 ConnectError::with_source(
262 crate::ConnectErrorKind::Authentication,
263 err,
264 )
265 })?;
266 connect_info.user = auth.username;
267 connect_info.pass = auth.password;
268 connect_info.user_jwt = auth.jwt;
269 connect_info.signature = auth
270 .signature
271 .map(|signature| URL_SAFE_NO_PAD.encode(signature));
272 connect_info.auth_token = auth.token;
273 connect_info.nkey = auth.nkey;
274 }
275
276 connection
277 .easy_write_and_flush(
278 [ClientOp::Connect(connect_info), ClientOp::Ping].iter(),
279 )
280 .await?;
281
282 match connection.read_op().await? {
283 Some(ServerOp::Error(err)) => match err {
284 ServerError::AuthorizationViolation => {
285 return Err(ConnectError::with_source(
286 crate::ConnectErrorKind::AuthorizationViolation,
287 err,
288 ));
289 }
290 err => {
291 return Err(ConnectError::with_source(
292 crate::ConnectErrorKind::Io,
293 err,
294 ));
295 }
296 },
297 Some(_) => {
298 tracing::debug!("connected to {}", server_info.port);
299 self.attempts = 0;
300 self.connect_stats.connects.add(1, Ordering::Relaxed);
301 self.events_tx.send(Event::Connected).await.ok();
302 self.state_tx.send(State::Connected).ok();
303 self.max_payload.store(
304 server_info.max_payload,
305 std::sync::atomic::Ordering::Relaxed,
306 );
307 return Ok((server_info, connection));
308 }
309 None => {
310 return Err(ConnectError::with_source(
311 crate::ConnectErrorKind::Io,
312 "broken pipe",
313 ))
314 }
315 }
316 }
317
318 Err(inner) => error.replace(inner),
319 };
320 }
321 }
322
323 Err(error.unwrap())
324 }
325
326 pub(crate) async fn try_connect_to(
327 &self,
328 socket_addr: &SocketAddr,
329 tls_required: bool,
330 server_addr: ServerAddr,
331 ) -> Result<(ServerInfo, Connection), ConnectError> {
332 let mut connection = match server_addr.scheme() {
333 #[cfg(feature = "websockets")]
334 "ws" => {
335 let ws = tokio::time::timeout(
336 self.options.connection_timeout,
337 tokio_websockets::client::Builder::new()
338 .uri(format!("{}://{}", server_addr.scheme(), socket_addr).as_str())
339 .map_err(|err| {
340 ConnectError::with_source(crate::ConnectErrorKind::ServerParse, err)
341 })?
342 .connect(),
343 )
344 .await
345 .map_err(|_| ConnectError::new(crate::ConnectErrorKind::TimedOut))?
346 .map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Io, err))?;
347
348 let con = WebSocketAdapter::new(ws.0);
349 Connection::new(Box::new(con), 0, self.connect_stats.clone())
350 }
351 #[cfg(feature = "websockets")]
352 "wss" => {
353 let domain = rustls_webpki::types::ServerName::try_from(server_addr.host())
354 .map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Tls, err))?;
355 let tls_config =
356 Arc::new(tls::config_tls(&self.options).await.map_err(|err| {
357 ConnectError::with_source(crate::ConnectErrorKind::Tls, err)
358 })?);
359 let tls_connector = tokio_rustls::TlsConnector::from(tls_config);
360 let ws = tokio::time::timeout(
361 self.options.connection_timeout,
362 tokio_websockets::client::Builder::new()
363 .connector(&tokio_websockets::Connector::Rustls(tls_connector))
364 .uri(
365 format!(
366 "{}://{}:{}",
367 server_addr.scheme(),
368 domain.to_str(),
369 server_addr.port()
370 )
371 .as_str(),
372 )
373 .map_err(|err| {
374 ConnectError::with_source(crate::ConnectErrorKind::ServerParse, err)
375 })?
376 .connect(),
377 )
378 .await
379 .map_err(|_| ConnectError::new(crate::ConnectErrorKind::TimedOut))?
380 .map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Io, err))?;
381 let con = WebSocketAdapter::new(ws.0);
382 Connection::new(Box::new(con), 0, self.connect_stats.clone())
383 }
384 _ => {
385 let tcp_stream = tokio::time::timeout(
386 self.options.connection_timeout,
387 TcpStream::connect(socket_addr),
388 )
389 .await
390 .map_err(|_| ConnectError::new(crate::ConnectErrorKind::TimedOut))??;
391 tcp_stream.set_nodelay(true)?;
392
393 Connection::new(
394 Box::new(tcp_stream),
395 self.options.read_buffer_capacity.into(),
396 self.connect_stats.clone(),
397 )
398 }
399 };
400
401 let tls_connection = |connection: Connection| async {
402 let tls_config = Arc::new(
403 tls::config_tls(&self.options)
404 .await
405 .map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Tls, err))?,
406 );
407 let tls_connector = tokio_rustls::TlsConnector::from(tls_config);
408
409 let domain = rustls_webpki::types::ServerName::try_from(server_addr.host())
410 .map_err(|err| ConnectError::with_source(crate::ConnectErrorKind::Tls, err))?;
411
412 let tls_stream = tls_connector
413 .connect(domain.to_owned(), connection.stream)
414 .await?;
415
416 Ok::<Connection, ConnectError>(Connection::new(
417 Box::new(tls_stream),
418 0,
419 self.connect_stats.clone(),
420 ))
421 };
422
423 if self.options.tls_first && !server_addr.is_websocket() {
427 connection = tls_connection(connection).await?;
428 }
429
430 let op = connection.read_op().await?;
431 let info = match op {
432 Some(ServerOp::Info(info)) => info,
433 Some(op) => {
434 return Err(ConnectError::with_source(
435 crate::ConnectErrorKind::Io,
436 format!("expected INFO, got {:?}", op),
437 ))
438 }
439 None => {
440 return Err(ConnectError::with_source(
441 crate::ConnectErrorKind::Io,
442 "expected INFO, got nothing",
443 ))
444 }
445 };
446
447 if !self.options.tls_first
449 && !server_addr.is_websocket()
450 && (self.options.tls_required || info.tls_required || tls_required)
451 {
452 connection = tls_connection(connection).await?;
453 };
454
455 Ok((*info, connection))
456 }
457}
458
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn reconnect_delay_callback_duration() {
        // (attempt number, expected delay in milliseconds); the last rows check the cap.
        let expectations: [(usize, u128); 6] = [
            (0, 0),
            (1, 0),
            (4, 8),
            (12, 2048),
            (13, 4000),
            (50, 4000),
        ];

        for &(attempts, expected_ms) in expectations.iter() {
            let actual_ms = reconnect_delay_callback_default(attempts).as_millis();
            // Pairing with the attempt number makes a failure report the offending case.
            assert_eq!((attempts, actual_ms), (attempts, expected_ms));
        }
    }
}