1use crate::error::GrpcClientError;
27use crate::{
28 BundleSource, SvidSource, TrustDomain, WorkloadApiClient, X509Bundle, X509BundleSet,
29 X509Context, X509Svid,
30};
31use arc_swap::ArcSwap;
32use log::{debug, error, info, warn};
33use std::error::Error as StdError;
34use std::fmt::Debug;
35use std::future::Future;
36use std::pin::Pin;
37use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
38use std::sync::Arc;
39use thiserror::Error;
40use tokio::sync::{watch, Mutex};
41use tokio::task::JoinHandle;
42use tokio::time::{sleep, Duration};
43use tokio_stream::StreamExt;
44use tokio_util::sync::CancellationToken;
45
46type ClientFuture =
47 Pin<Box<dyn Future<Output = Result<WorkloadApiClient, GrpcClientError>> + Send + 'static>>;
48
49type ClientFactory = Arc<dyn Fn() -> ClientFuture + Send + Sync + 'static>;
50
51pub trait SvidPicker: Debug + Send + Sync {
53 fn pick_svid<'a>(&self, svids: &'a [X509Svid]) -> Option<&'a X509Svid>;
57}
58
/// Exponential-backoff bounds used when (re)connecting to the Workload API.
///
/// After each failed attempt the delay doubles, starting at `min_backoff` and never
/// exceeding `max_backoff`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ReconnectConfig {
    /// Initial retry delay; the delay returns to this value after a successful reconnect.
    pub min_backoff: Duration,
    /// Upper bound on the retry delay.
    pub max_backoff: Duration,
}

impl Default for ReconnectConfig {
    /// 200 ms minimum delay, 10 s maximum delay.
    fn default() -> Self {
        Self {
            min_backoff: Duration::from_millis(200),
            max_backoff: Duration::from_secs(10),
        }
    }
}
76
/// Errors returned by [`X509Source`] and [`X509SourceBuilder`].
#[derive(Debug, Error)]
pub enum X509SourceError {
    /// Creating the Workload API client or its X.509 context stream failed, or the stream
    /// yielded an error.
    #[error("grpc client error: {0}")]
    Grpc(#[from] GrpcClientError),

    /// The configured picker (or the context's default SVID when no picker is set) did not
    /// produce an SVID for an update.
    #[error("no suitable svid found")]
    NoSuitableSvid,

    /// The source has been shut down or its cancellation token has fired; also returned by a
    /// second call to `shutdown`.
    #[error("source is closed")]
    Closed,

    /// The Workload API stream ended before delivering a context during a one-shot sync.
    #[error("workload api stream ended")]
    StreamEnded,
}
96
/// Cached X.509 SVID and trust bundle set, kept up to date by a background task that streams
/// X.509 contexts from the Workload API.
///
/// Obtain one through [`X509Source::new`] or [`X509SourceBuilder::build`]; both wait for an
/// initial context before returning.
pub struct X509Source {
    // Currently selected SVID; replaced atomically on each successful update.
    svid: ArcSwap<X509Svid>,
    // Bundle set from the most recent successful update.
    bundles: ArcSwap<X509BundleSet>,

    // Custom SVID selection; `None` means the context's default SVID is used.
    svid_picker: Option<Box<dyn SvidPicker>>,
    // Backoff bounds for reconnect attempts.
    reconnect: ReconnectConfig,
    // Creates a fresh Workload API client for each connection attempt.
    make_client: ClientFactory,

    // Set once by `shutdown`; every accessor checks it via `assert_open`.
    closed: AtomicBool,
    // Shared with the background update task; cancelled by `shutdown` and `Drop`.
    cancel: CancellationToken,

    // Monotonically increasing update counter published through `update_tx`.
    update_seq: AtomicU64,
    update_tx: watch::Sender<u64>,
    // Template receiver cloned out by `updated()`.
    update_rx: watch::Receiver<u64>,

    // Join handle of the background update task; taken and awaited by `shutdown`.
    supervisor: Mutex<Option<JoinHandle<()>>>,
}
118
119impl Debug for X509Source {
120 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121 f.debug_struct("X509Source")
122 .field("svid", &"<ArcSwap<X509Svid>>")
123 .field("bundles", &"<ArcSwap<X509BundleSet>>")
124 .field(
125 "svid_picker",
126 &self.svid_picker.as_ref().map(|_| "<SvidPicker>"),
127 )
128 .field("reconnect", &self.reconnect)
129 .field("make_client", &"<ClientFactory>")
130 .field("closed", &self.closed.load(Ordering::Relaxed))
131 .field("cancel", &self.cancel)
132 .field("update_seq", &self.update_seq)
133 .field("update_tx", &"<watch::Sender<u64>>")
134 .field("update_rx", &"<watch::Receiver<u64>>")
135 .finish()
136 }
137}
138
/// Builder for [`X509Source`] that configures the client factory, SVID picker and reconnect
/// backoff.
pub struct X509SourceBuilder {
    // Custom SVID selection; `None` means the context's default SVID is used.
    svid_picker: Option<Box<dyn SvidPicker>>,
    // Backoff bounds used for initial-sync retries and reconnects.
    reconnect: ReconnectConfig,
    // Client factory; `None` makes `build` fall back to `WorkloadApiClient::default()`.
    make_client: Option<ClientFactory>,
}
147
148impl Debug for X509SourceBuilder {
149 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150 f.debug_struct("X509SourceBuilder")
151 .field(
152 "svid_picker",
153 &self.svid_picker.as_ref().map(|_| "<SvidPicker>"),
154 )
155 .field("reconnect", &self.reconnect)
156 .field(
157 "make_client",
158 &self.make_client.as_ref().map(|_| "<ClientFactory>"),
159 )
160 .finish()
161 }
162}
163
164impl Default for X509SourceBuilder {
165 fn default() -> Self {
166 Self::new()
167 }
168}
169
170impl X509SourceBuilder {
171 pub fn new() -> Self {
173 Self {
174 svid_picker: None,
175 reconnect: ReconnectConfig::default(),
176 make_client: None,
177 }
178 }
179
180 pub fn with_socket_path(mut self, socket_path: impl Into<Arc<str>>) -> Self {
185 let socket_path = socket_path.into();
186
187 let factory: ClientFactory = Arc::new(move || {
188 let socket_path = socket_path.clone();
189 Box::pin(async move { WorkloadApiClient::new_from_path(socket_path).await })
190 });
191
192 self.make_client = Some(factory);
193 self
194 }
195
196 pub fn with_client_factory(mut self, factory: ClientFactory) -> Self {
198 self.make_client = Some(factory);
199 self
200 }
201
202 pub fn with_picker(mut self, svid_picker: Box<dyn SvidPicker>) -> Self {
204 self.svid_picker = Some(svid_picker);
205 self
206 }
207
208 pub fn with_reconnect_backoff(mut self, min_backoff: Duration, max_backoff: Duration) -> Self {
210 self.reconnect = ReconnectConfig {
211 min_backoff,
212 max_backoff,
213 };
214 self
215 }
216
217 pub async fn build(self) -> Result<Arc<X509Source>, X509SourceError> {
219 let make_client = self
220 .make_client
221 .unwrap_or_else(|| Arc::new(|| Box::pin(async { WorkloadApiClient::default().await })));
222
223 X509Source::new_with(make_client, self.svid_picker, self.reconnect).await
224 }
225}
226
227impl X509Source {
228 pub async fn new() -> Result<Arc<Self>, X509SourceError> {
236 X509SourceBuilder::new().build().await
237 }
238
239 pub async fn shutdown(&self) -> Result<(), X509SourceError> {
241 if self.closed.swap(true, Ordering::AcqRel) {
242 return Err(X509SourceError::Closed);
243 }
244 self.cancel.cancel();
245
246 if let Some(handle) = self.supervisor.lock().await.take() {
247 let _ = handle.await;
248 }
249
250 Ok(())
251 }
252
253 pub fn updated(&self) -> watch::Receiver<u64> {
257 self.update_rx.clone()
258 }
259
260 pub fn svid(&self) -> Result<X509Svid, X509SourceError> {
262 self.assert_open()?;
263 Ok((**self.svid.load()).clone())
264 }
265}
266
267impl Drop for X509Source {
268 fn drop(&mut self) {
269 self.cancel.cancel();
271 }
272}
273
274impl SvidSource for X509Source {
275 type Item = X509Svid;
276
277 fn get_svid(&self) -> Result<Option<Self::Item>, Box<dyn StdError + Send + Sync + 'static>> {
278 self.assert_open().map_err(Box::new)?;
279 Ok(Some((**self.svid.load()).clone()))
280 }
281}
282
283impl BundleSource for X509Source {
284 type Item = X509Bundle;
285
286 fn get_bundle_for_trust_domain(
287 &self,
288 trust_domain: &TrustDomain,
289 ) -> Result<Option<Self::Item>, Box<dyn StdError + Send + Sync + 'static>> {
290 self.assert_open().map_err(Box::new)?;
291 Ok(self.bundles.load().get_bundle(trust_domain).cloned())
292 }
293}
294
295impl X509Source {
296 pub fn bundle_set(&self) -> Result<X509BundleSet, X509SourceError> {
298 self.assert_open()?;
299 Ok((**self.bundles.load()).clone())
300 }
301
302 pub fn x509_context(&self) -> Result<X509Context, X509SourceError> {
304 self.assert_open()?;
305
306 let svid = (**self.svid.load()).clone();
307 let bundles = (**self.bundles.load()).clone();
308
309 Ok(X509Context::new(vec![svid], bundles))
310 }
311}
312
313impl X509Source {
315 async fn new_with(
316 make_client: ClientFactory,
317 svid_picker: Option<Box<dyn SvidPicker>>,
318 reconnect: ReconnectConfig,
319 ) -> Result<Arc<X509Source>, X509SourceError> {
320 let (update_tx, update_rx) = watch::channel(0u64);
321 let cancel = CancellationToken::new();
322
323 let (initial_svid, initial_bundles) =
324 initial_sync_with_retry(&make_client, svid_picker.as_deref(), &cancel, reconnect)
325 .await?;
326
327 let src = Arc::new(Self {
328 svid: ArcSwap::from_pointee(initial_svid),
329 bundles: ArcSwap::from_pointee(initial_bundles),
330 svid_picker,
331 reconnect,
332 make_client,
333 closed: AtomicBool::new(false),
334 cancel,
335 update_seq: AtomicU64::new(0),
336 update_tx,
337 update_rx,
338 supervisor: Mutex::new(None),
339 });
340
341 let cloned = Arc::clone(&src);
342 let token = cloned.cancel.clone();
343 let handle = tokio::spawn(async move { cloned.run_update_supervisor(token).await });
344 *src.supervisor.lock().await = Some(handle);
345
346 Ok(src)
347 }
348
349 fn assert_open(&self) -> Result<(), X509SourceError> {
350 if self.closed.load(Ordering::Acquire) || self.cancel.is_cancelled() {
351 return Err(X509SourceError::Closed);
352 }
353 Ok(())
354 }
355
356 fn notify_update(&self) {
357 let next = self.update_seq.fetch_add(1, Ordering::Relaxed) + 1;
358 let _ = self.update_tx.send(next);
359 }
360
361 fn set_x509_context(&self, x509_context: X509Context) -> Result<(), X509SourceError> {
362 let picked = if let Some(ref picker) = self.svid_picker {
363 picker
364 .pick_svid(x509_context.svids())
365 .ok_or(X509SourceError::NoSuitableSvid)?
366 } else {
367 x509_context
368 .default_svid()
369 .ok_or(X509SourceError::NoSuitableSvid)?
370 };
371
372 self.svid.store(Arc::new(picked.clone()));
373 self.bundles
374 .store(Arc::new(x509_context.bundle_set().clone()));
375
376 self.notify_update();
377 Ok(())
378 }
379
380 async fn run_update_supervisor(&self, cancellation_token: CancellationToken) {
381 let mut backoff = self.reconnect.min_backoff;
382
383 loop {
384 if cancellation_token.is_cancelled() {
385 debug!("Cancellation signal received; stopping updates.");
386 return;
387 }
388
389 let mut client = match (self.make_client)().await {
390 Ok(c) => {
391 backoff = self.reconnect.min_backoff;
392 c
393 }
394 Err(e) => {
395 warn!("Failed to create WorkloadApiClient: {e}. Retrying in {backoff:?}.");
396 if sleep_or_cancel(&cancellation_token, backoff).await {
397 return;
398 }
399 backoff = next_backoff(backoff, self.reconnect.max_backoff);
400 continue;
401 }
402 };
403
404 let mut stream = match client.stream_x509_contexts().await {
405 Ok(s) => {
406 info!("Connected to Workload API X509 context stream.");
407 backoff = self.reconnect.min_backoff;
408 s
409 }
410 Err(e) => {
411 warn!(
412 "Failed to connect to Workload API stream: {e}. Retrying in {backoff:?}."
413 );
414 if sleep_or_cancel(&cancellation_token, backoff).await {
415 return;
416 }
417 backoff = next_backoff(backoff, self.reconnect.max_backoff);
418 continue;
419 }
420 };
421
422 loop {
423 if cancellation_token.is_cancelled() {
424 debug!("Cancellation signal received; stopping update loop.");
425 return;
426 }
427
428 match stream.next().await {
429 Some(Ok(ctx)) => match self.set_x509_context(ctx) {
430 Err(e) => {
431 error!("Error updating X509 context: {e}");
432 }
433 _ => {
434 debug!("X509 context updated.");
435 }
436 },
437 Some(Err(e)) => {
438 warn!("Workload API stream error: {e}. Reconnecting...");
439 break;
440 }
441 None => {
442 warn!("Workload API stream ended. Reconnecting...");
443 break;
444 }
445 }
446 }
447
448 if sleep_or_cancel(&cancellation_token, backoff).await {
449 return;
450 }
451 backoff = next_backoff(backoff, self.reconnect.max_backoff);
452 }
453 }
454}
455
456async fn initial_sync_with_retry(
457 make_client: &ClientFactory,
458 picker: Option<&dyn SvidPicker>,
459 cancel: &CancellationToken,
460 reconnect: ReconnectConfig,
461) -> Result<(X509Svid, X509BundleSet), X509SourceError> {
462 let mut backoff = reconnect.min_backoff;
463
464 loop {
465 if cancel.is_cancelled() {
466 return Err(X509SourceError::Closed);
467 }
468
469 match try_sync_once(make_client, picker).await {
470 Ok(v) => return Ok(v),
471 Err(e) => {
472 warn!("Initial sync failed: {e}. Retrying in {backoff:?}.");
473 if sleep_or_cancel(cancel, backoff).await {
474 return Err(X509SourceError::Closed);
475 }
476 backoff = next_backoff(backoff, reconnect.max_backoff);
477 }
478 }
479 }
480}
481
482async fn try_sync_once(
483 make_client: &ClientFactory,
484 picker: Option<&dyn SvidPicker>,
485) -> Result<(X509Svid, X509BundleSet), X509SourceError> {
486 let mut client = (make_client)().await.map_err(X509SourceError::Grpc)?;
487 let mut stream = client
488 .stream_x509_contexts()
489 .await
490 .map_err(X509SourceError::Grpc)?;
491
492 match stream.next().await {
493 Some(Ok(ctx)) => {
494 let picked = if let Some(p) = picker {
495 p.pick_svid(ctx.svids())
496 .ok_or(X509SourceError::NoSuitableSvid)?
497 } else {
498 ctx.default_svid().ok_or(X509SourceError::NoSuitableSvid)?
499 };
500 Ok((picked.clone(), ctx.bundle_set().clone()))
501 }
502 Some(Err(e)) => Err(X509SourceError::Grpc(e)),
503 None => Err(X509SourceError::StreamEnded),
504 }
505}
506
507async fn sleep_or_cancel(token: &CancellationToken, dur: Duration) -> bool {
508 tokio::select! {
509 _ = token.cancelled() => true,
510 _ = sleep(dur) => false,
511 }
512}
513
/// Computes the retry delay that follows `current`: double it, capped at `max`.
///
/// Doubling a zero delay stays zero, which would turn the retry loops into busy loops
/// (e.g. with `min_backoff = 0`). A zero `current` therefore steps up to a 1 ms floor,
/// still capped at `max`.
fn next_backoff(current: Duration, max: Duration) -> Duration {
    const ZERO_BACKOFF_FLOOR: Duration = Duration::from_millis(1);

    let grown = if current.is_zero() {
        ZERO_BACKOFF_FLOOR
    } else {
        current.saturating_mul(2)
    };
    grown.min(max)
}
521}