1use std::collections::HashMap;
8use std::fmt;
9use std::io::{Read, Write};
10use std::mem;
11use std::net::{Shutdown, TcpStream};
12use std::time::{Duration, Instant};
13
14#[cfg(feature = "security")]
15use openssl::ssl::SslConnector;
16
17use crate::error::Result;
18
/// TLS-related settings used when opening broker connections.
///
/// Holds an OpenSSL connector together with a flag saying whether the peer's
/// hostname should be verified.
#[cfg(feature = "security")]
pub struct SecurityConfig {
    connector: SslConnector,
    verify_hostname: bool,
}

#[cfg(feature = "security")]
impl SecurityConfig {
    /// Wraps `connector`; hostname verification starts out enabled.
    pub fn new(connector: SslConnector) -> Self {
        Self {
            verify_hostname: true,
            connector,
        }
    }

    /// Consumes the configuration and returns it with the hostname
    /// verification flag set to `verify_hostname`.
    pub fn with_hostname_verification(mut self, verify_hostname: bool) -> SecurityConfig {
        self.verify_hostname = verify_hostname;
        self
    }
}

#[cfg(feature = "security")]
impl fmt::Debug for SecurityConfig {
    /// Prints only the verification flag; the connector is left out.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let verify = self.verify_hostname;
        write!(f, "SecurityConfig {{ verify_hostname: {} }}", verify)
    }
}
58
/// A pooled item paired with the instant it was last checked out.
struct Pooled<T> {
    last_checkout: Instant,
    item: T,
}

impl<T> Pooled<T> {
    /// Pairs `item` with its most recent checkout time.
    fn new(last_checkout: Instant, item: T) -> Self {
        Self {
            item,
            last_checkout,
        }
    }
}

impl<T: fmt::Debug> fmt::Debug for Pooled<T> {
    /// Renders both fields using their own `Debug` output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Pooled {
            last_checkout,
            item,
        } = self;
        write!(
            f,
            "Pooled {{ last_checkout: {:?}, item: {:?} }}",
            last_checkout, item
        )
    }
}
84
85#[derive(Debug)]
86pub struct Config {
87 rw_timeout: Option<Duration>,
88 idle_timeout: Duration,
89 #[cfg(feature = "security")]
90 security_config: Option<SecurityConfig>,
91}
92
93impl Config {
94 #[cfg(not(feature = "security"))]
95 fn new_conn(&self, id: u32, host: &str) -> Result<KafkaConnection> {
96 KafkaConnection::new(id, host, self.rw_timeout).map(|c| {
97 debug!("Established: {:?}", c);
98 c
99 })
100 }
101
102 #[cfg(feature = "security")]
103 fn new_conn(&self, id: u32, host: &str) -> Result<KafkaConnection> {
104 KafkaConnection::new(
105 id,
106 host,
107 self.rw_timeout,
108 self.security_config
109 .as_ref()
110 .map(|c| (c.connector.clone(), c.verify_hostname)),
111 )
112 .map(|c| {
113 debug!("Established: {:?}", c);
114 c
115 })
116 }
117}
118
/// Counter from which connection ids are drawn.
#[derive(Debug)]
struct State {
    num_conns: u32,
}

impl State {
    /// Starts the counter at zero.
    fn new() -> State {
        Self { num_conns: 0 }
    }

    /// Returns the current counter value and advances the counter by one,
    /// wrapping around on overflow.
    fn next_conn_id(&mut self) -> u32 {
        let following = self.num_conns.wrapping_add(1);
        mem::replace(&mut self.num_conns, following)
    }
}
135
136#[derive(Debug)]
137pub struct Connections {
138 conns: HashMap<String, Pooled<KafkaConnection>>,
139 state: State,
140 config: Config,
141}
142
143impl Connections {
144 #[cfg(not(feature = "security"))]
145 pub fn new(rw_timeout: Option<Duration>, idle_timeout: Duration) -> Connections {
146 Connections {
147 conns: HashMap::new(),
148 state: State::new(),
149 config: Config {
150 rw_timeout,
151 idle_timeout,
152 },
153 }
154 }
155
156 #[cfg(feature = "security")]
157 pub fn new(rw_timeout: Option<Duration>, idle_timeout: Duration) -> Connections {
158 Self::new_with_security(rw_timeout, idle_timeout, None)
159 }
160
161 #[cfg(feature = "security")]
162 pub fn new_with_security(
163 rw_timeout: Option<Duration>,
164 idle_timeout: Duration,
165 security: Option<SecurityConfig>,
166 ) -> Connections {
167 Connections {
168 conns: HashMap::new(),
169 state: State::new(),
170 config: Config {
171 rw_timeout,
172 idle_timeout,
173 security_config: security,
174 },
175 }
176 }
177
178 pub fn set_idle_timeout(&mut self, idle_timeout: Duration) {
179 self.config.idle_timeout = idle_timeout;
180 }
181
182 pub fn idle_timeout(&self) -> Duration {
183 self.config.idle_timeout
184 }
185
186 pub fn get_conn<'a>(&'a mut self, host: &str, now: Instant) -> Result<&'a mut KafkaConnection> {
187 if let Some(conn) = self.conns.get_mut(host) {
188 if now.duration_since(conn.last_checkout) >= self.config.idle_timeout {
189 debug!("Idle timeout reached: {:?}", conn.item);
190 let new_conn = self.config.new_conn(self.state.next_conn_id(), host)?;
191 let _ = conn.item.shutdown();
192 conn.item = new_conn;
193 }
194 conn.last_checkout = now;
195 let kconn: &mut KafkaConnection = &mut conn.item;
196 return Ok(unsafe { mem::transmute(kconn) });
201 }
202 let cid = self.state.next_conn_id();
203 self.conns.insert(
204 host.to_owned(),
205 Pooled::new(now, self.config.new_conn(cid, host)?),
206 );
207 Ok(&mut self.conns.get_mut(host).unwrap().item)
208 }
209
210 pub fn get_conn_any(&mut self, now: Instant) -> Option<&mut KafkaConnection> {
211 for (host, conn) in &mut self.conns {
212 if now.duration_since(conn.last_checkout) >= self.config.idle_timeout {
213 debug!("Idle timeout reached: {:?}", conn.item);
214 let new_conn_id = self.state.next_conn_id();
215 let new_conn = match self.config.new_conn(new_conn_id, host.as_str()) {
216 Ok(new_conn) => {
217 let _ = conn.item.shutdown();
218 new_conn
219 }
220 Err(e) => {
221 warn!("Failed to establish connection to {}: {:?}", host, e);
222 continue;
223 }
224 };
225 conn.item = new_conn;
226 }
227 conn.last_checkout = now;
228 let kconn: &mut KafkaConnection = &mut conn.item;
229 return Some(kconn);
230 }
231 None
232 }
233}
234
/// Lets connection code ask a stream whether it is TLS-protected.
trait IsSecured {
    /// Returns `true` if the stream is secured, `false` otherwise.
    fn is_secured(&self) -> bool;
}

/// Without the `security` feature a Kafka stream is simply a TCP stream.
#[cfg(not(feature = "security"))]
type KafkaStream = TcpStream;

#[cfg(not(feature = "security"))]
impl IsSecured for KafkaStream {
    /// A plain TCP stream is never secured.
    fn is_secured(&self) -> bool {
        false
    }
}
250
251#[cfg(feature = "security")]
252use self::openssled::KafkaStream;
253
/// Stream type used when the `security` feature is enabled: either a plain
/// TCP stream or a TLS session layered over one.
#[cfg(feature = "security")]
mod openssled {
    use std::io::{self, Read, Write};
    use std::net::{Shutdown, TcpStream};
    use std::time::Duration;

    use openssl::ssl::SslStream;

    use super::IsSecured;

    /// A broker stream that is either unencrypted or wrapped in TLS.
    pub enum KafkaStream {
        Plain(TcpStream),
        Ssl(SslStream<TcpStream>),
    }

    impl IsSecured for KafkaStream {
        /// Only the `Ssl` variant counts as secured.
        fn is_secured(&self) -> bool {
            match self {
                KafkaStream::Plain(_) => false,
                KafkaStream::Ssl(_) => true,
            }
        }
    }

    impl KafkaStream {
        /// Borrows the TCP stream underneath either variant.
        fn get_ref(&self) -> &TcpStream {
            match self {
                KafkaStream::Plain(tcp) => tcp,
                KafkaStream::Ssl(tls) => tls.get_ref(),
            }
        }

        /// Sets the read timeout on the underlying TCP stream.
        pub fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
            let tcp = self.get_ref();
            tcp.set_read_timeout(dur)
        }

        /// Sets the write timeout on the underlying TCP stream.
        pub fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
            let tcp = self.get_ref();
            tcp.set_write_timeout(dur)
        }

        /// Shuts down the underlying TCP stream in the direction(s) given by `how`.
        pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
            let tcp = self.get_ref();
            tcp.shutdown(how)
        }
    }

    impl Read for KafkaStream {
        /// Delegates to whichever stream the variant holds.
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            match self {
                KafkaStream::Plain(tcp) => tcp.read(buf),
                KafkaStream::Ssl(tls) => tls.read(buf),
            }
        }
    }

    impl Write for KafkaStream {
        /// Delegates to whichever stream the variant holds.
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            match self {
                KafkaStream::Plain(tcp) => tcp.write(buf),
                KafkaStream::Ssl(tls) => tls.write(buf),
            }
        }

        /// Delegates to whichever stream the variant holds.
        fn flush(&mut self) -> io::Result<()> {
            match self {
                KafkaStream::Plain(tcp) => tcp.flush(),
                KafkaStream::Ssl(tls) => tls.flush(),
            }
        }
    }
}
320
321pub struct KafkaConnection {
323 id: u32,
326 host: String,
328 stream: KafkaStream,
330}
331
332impl fmt::Debug for KafkaConnection {
333 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
334 write!(
335 f,
336 "KafkaConnection {{ id: {}, secured: {}, host: \"{}\" }}",
337 self.id,
338 self.stream.is_secured(),
339 self.host
340 )
341 }
342}
343
344impl KafkaConnection {
345 pub fn send(&mut self, msg: &[u8]) -> Result<usize> {
346 let r = self.stream.write(msg).map_err(From::from);
347 trace!("Sent {} bytes to: {:?} => {:?}", msg.len(), self, r);
348 r
349 }
350
351 pub fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
352 let r = (&mut self.stream).read_exact(buf).map_err(From::from);
353 trace!("Read {} bytes from: {:?} => {:?}", buf.len(), self, r);
354 r
355 }
356
357 pub fn read_exact_alloc(&mut self, size: u64) -> Result<Vec<u8>> {
358 let mut buffer = vec![0; size as usize];
359 self.read_exact(buffer.as_mut_slice())?;
360 Ok(buffer)
361 }
362
363 fn shutdown(&mut self) -> Result<()> {
364 let r = self.stream.shutdown(Shutdown::Both);
365 debug!("Shut down: {:?} => {:?}", self, r);
366 r.map_err(From::from)
367 }
368
369 fn from_stream(
370 stream: KafkaStream,
371 id: u32,
372 host: &str,
373 rw_timeout: Option<Duration>,
374 ) -> Result<KafkaConnection> {
375 stream.set_read_timeout(rw_timeout)?;
376 stream.set_write_timeout(rw_timeout)?;
377 Ok(KafkaConnection {
378 id,
379 host: host.to_owned(),
380 stream,
381 })
382 }
383
384 #[cfg(not(feature = "security"))]
385 fn new(id: u32, host: &str, rw_timeout: Option<Duration>) -> Result<KafkaConnection> {
386 KafkaConnection::from_stream(TcpStream::connect(host)?, id, host, rw_timeout)
387 }
388
389 #[cfg(feature = "security")]
390 fn new(
391 id: u32,
392 host: &str,
393 rw_timeout: Option<Duration>,
394 security: Option<(SslConnector, bool)>,
395 ) -> Result<KafkaConnection> {
396 use crate::Error;
397
398 let stream = TcpStream::connect(host)?;
399 let stream = match security {
400 Some((connector, verify_hostname)) => {
401 if !verify_hostname {
402 connector
403 .configure()
404 .map_err(openssl::ssl::Error::from)?
405 .set_verify_hostname(false);
406 }
407 let domain = match host.rfind(':') {
408 None => host,
409 Some(i) => &host[..i],
410 };
411 let connection = connector.connect(domain, stream).map_err(|err| match err {
412 openssl::ssl::HandshakeError::SetupFailure(err) => {
413 Error::from(openssl::ssl::Error::from(err))
414 }
415 openssl::ssl::HandshakeError::Failure(err) => Error::from(err.into_error()),
416 openssl::ssl::HandshakeError::WouldBlock(err) => Error::from(err.into_error()),
417 })?;
418 KafkaStream::Ssl(connection)
419 }
420 None => KafkaStream::Plain(stream),
421 };
422 KafkaConnection::from_stream(stream, id, host, rw_timeout)
423 }
424}