spiffe/workload_api/
x509_source.rs

1//! Live X.509 SVID and bundle source backed by the SPIFFE Workload API.
2//!
3//! `X509Source` performs an initial sync before becoming usable, then watches the Workload API
4//! for rotations. Transient failures are handled by reconnecting with exponential backoff.
5//!
6//! Use [`X509Source::updated`] to subscribe to change notifications, and [`X509Source::shutdown`]
7//! to stop background tasks.
8//!
9//! # Example
10//!
11//! ```no_run
12//! use spiffe::{BundleSource, TrustDomain, X509Source};
13//!
14//! # async fn example() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
15//! let source = X509Source::new().await?;
16//!
17//! let svid = source.svid()?;
18//! let td = TrustDomain::new("example.org")?;
19//! let bundle = source
20//!     .get_bundle_for_trust_domain(&td)?
21//!     .ok_or("missing bundle")?;
22//!
23//! # Ok(())
24//! # }
25//! ```
26use crate::error::GrpcClientError;
27use crate::{
28    BundleSource, SvidSource, TrustDomain, WorkloadApiClient, X509Bundle, X509BundleSet,
29    X509Context, X509Svid,
30};
31use arc_swap::ArcSwap;
32use log::{debug, error, info, warn};
33use std::error::Error as StdError;
34use std::fmt::Debug;
35use std::future::Future;
36use std::pin::Pin;
37use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
38use std::sync::Arc;
39use thiserror::Error;
40use tokio::sync::{watch, Mutex};
41use tokio::task::JoinHandle;
42use tokio::time::{sleep, Duration};
43use tokio_stream::StreamExt;
44use tokio_util::sync::CancellationToken;
45
46type ClientFuture =
47    Pin<Box<dyn Future<Output = Result<WorkloadApiClient, GrpcClientError>> + Send + 'static>>;
48
49type ClientFactory = Arc<dyn Fn() -> ClientFuture + Send + Sync + 'static>;
50
51/// Strategy for selecting an X.509 SVID when multiple SVIDs are available.
52pub trait SvidPicker: Debug + Send + Sync {
53    /// Selects an SVID from the provided slice.
54    ///
55    /// Returning `None` indicates that no suitable SVID could be selected.
56    fn pick_svid<'a>(&self, svids: &'a [X509Svid]) -> Option<&'a X509Svid>;
57}
58
/// Bounds for the retry delay used while (re)connecting to the Workload API.
#[derive(Clone, Copy, Debug)]
pub struct ReconnectConfig {
    /// Delay used for the first retry; later delays grow from this value.
    pub min_backoff: Duration,
    /// Upper bound that growing retry delays are capped at.
    pub max_backoff: Duration,
}

impl Default for ReconnectConfig {
    /// Starts at 200 milliseconds and caps at 10 seconds.
    fn default() -> Self {
        let min_backoff = Duration::from_millis(200);
        let max_backoff = Duration::from_secs(10);
        Self {
            min_backoff,
            max_backoff,
        }
    }
}
76
77/// Errors returned by `X509Source`.
78#[derive(Debug, Error)]
79pub enum X509SourceError {
80    /// Workload API client error.
81    #[error("grpc client error: {0}")]
82    Grpc(#[from] GrpcClientError),
83
84    /// No SVID could be selected from the received context.
85    #[error("no suitable svid found")]
86    NoSuitableSvid,
87
88    /// The source was closed.
89    #[error("source is closed")]
90    Closed,
91
92    /// The workload API stream ended.
93    #[error("workload api stream ended")]
94    StreamEnded,
95}
96
97/// Live source of X.509 SVIDs and bundles from the SPIFFE Workload API.
98///
99/// `X509Source` performs an initial sync before returning from [`X509Source::new`].
100/// Updates are applied atomically and can be observed via [`X509Source::updated`].
101pub struct X509Source {
102    svid: ArcSwap<X509Svid>,
103    bundles: ArcSwap<X509BundleSet>,
104
105    svid_picker: Option<Box<dyn SvidPicker>>,
106    reconnect: ReconnectConfig,
107    make_client: ClientFactory,
108
109    closed: AtomicBool,
110    cancel: CancellationToken,
111
112    update_seq: AtomicU64,
113    update_tx: watch::Sender<u64>,
114    update_rx: watch::Receiver<u64>,
115
116    supervisor: Mutex<Option<JoinHandle<()>>>,
117}
118
119impl Debug for X509Source {
120    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121        f.debug_struct("X509Source")
122            .field("svid", &"<ArcSwap<X509Svid>>")
123            .field("bundles", &"<ArcSwap<X509BundleSet>>")
124            .field(
125                "svid_picker",
126                &self.svid_picker.as_ref().map(|_| "<SvidPicker>"),
127            )
128            .field("reconnect", &self.reconnect)
129            .field("make_client", &"<ClientFactory>")
130            .field("closed", &self.closed.load(Ordering::Relaxed))
131            .field("cancel", &self.cancel)
132            .field("update_seq", &self.update_seq)
133            .field("update_tx", &"<watch::Sender<u64>>")
134            .field("update_rx", &"<watch::Receiver<u64>>")
135            .finish()
136    }
137}
138
139/// Builder for [`X509Source`].
140///
141/// Use this when you need explicit configuration (socket path, picker, backoff).
142pub struct X509SourceBuilder {
143    svid_picker: Option<Box<dyn SvidPicker>>,
144    reconnect: ReconnectConfig,
145    make_client: Option<ClientFactory>,
146}
147
148impl Debug for X509SourceBuilder {
149    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150        f.debug_struct("X509SourceBuilder")
151            .field(
152                "svid_picker",
153                &self.svid_picker.as_ref().map(|_| "<SvidPicker>"),
154            )
155            .field("reconnect", &self.reconnect)
156            .field(
157                "make_client",
158                &self.make_client.as_ref().map(|_| "<ClientFactory>"),
159            )
160            .finish()
161    }
162}
163
164impl Default for X509SourceBuilder {
165    fn default() -> Self {
166        Self::new()
167    }
168}
169
170impl X509SourceBuilder {
171    /// Creates a new `X509SourceBuilder`.
172    pub fn new() -> Self {
173        Self {
174            svid_picker: None,
175            reconnect: ReconnectConfig::default(),
176            make_client: None,
177        }
178    }
179
180    /// Sets the Workload API socket path.
181    ///
182    /// Accepts either a filesystem path (e.g. `/tmp/spire-agent/public/api.sock`)
183    /// or a full URI (e.g. `unix:///tmp/spire-agent/public/api.sock`).
184    pub fn with_socket_path(mut self, socket_path: impl Into<Arc<str>>) -> Self {
185        let socket_path = socket_path.into();
186
187        let factory: ClientFactory = Arc::new(move || {
188            let socket_path = socket_path.clone();
189            Box::pin(async move { WorkloadApiClient::new_from_path(socket_path).await })
190        });
191
192        self.make_client = Some(factory);
193        self
194    }
195
196    /// Sets a custom client factory.
197    pub fn with_client_factory(mut self, factory: ClientFactory) -> Self {
198        self.make_client = Some(factory);
199        self
200    }
201
202    /// Sets a custom SVID selection strategy.
203    pub fn with_picker(mut self, svid_picker: Box<dyn SvidPicker>) -> Self {
204        self.svid_picker = Some(svid_picker);
205        self
206    }
207
208    /// Sets the reconnect backoff range.
209    pub fn with_reconnect_backoff(mut self, min_backoff: Duration, max_backoff: Duration) -> Self {
210        self.reconnect = ReconnectConfig {
211            min_backoff,
212            max_backoff,
213        };
214        self
215    }
216
217    /// Builds a ready-to-use [`X509Source`].
218    pub async fn build(self) -> Result<Arc<X509Source>, X509SourceError> {
219        let make_client = self
220            .make_client
221            .unwrap_or_else(|| Arc::new(|| Box::pin(async { WorkloadApiClient::default().await })));
222
223        X509Source::new_with(make_client, self.svid_picker, self.reconnect).await
224    }
225}
226
227impl X509Source {
228    /// Creates an `X509Source` using the default Workload API endpoint.
229    ///
230    /// The endpoint is resolved from `SPIFFE_ENDPOINT_SOCKET`. The source selects the default
231    /// X.509 SVID when multiple SVIDs are available.
232    ///
233    /// On success, the returned source is already synchronized with the agent and will keep
234    /// updating in the background until it is closed.
235    pub async fn new() -> Result<Arc<Self>, X509SourceError> {
236        X509SourceBuilder::new().build().await
237    }
238
239    /// Cancels background tasks and waits for termination.
240    pub async fn shutdown(&self) -> Result<(), X509SourceError> {
241        if self.closed.swap(true, Ordering::AcqRel) {
242            return Err(X509SourceError::Closed);
243        }
244        self.cancel.cancel();
245
246        if let Some(handle) = self.supervisor.lock().await.take() {
247            let _ = handle.await;
248        }
249
250        Ok(())
251    }
252
253    /// Returns a receiver that is notified on each successful update.
254    ///
255    /// The received value is a monotonically increasing counter.
256    pub fn updated(&self) -> watch::Receiver<u64> {
257        self.update_rx.clone()
258    }
259
260    /// Returns the current X.509 SVID.
261    pub fn svid(&self) -> Result<X509Svid, X509SourceError> {
262        self.assert_open()?;
263        Ok((**self.svid.load()).clone())
264    }
265}
266
267impl Drop for X509Source {
268    fn drop(&mut self) {
269        // best-effort cancellation
270        self.cancel.cancel();
271    }
272}
273
274impl SvidSource for X509Source {
275    type Item = X509Svid;
276
277    fn get_svid(&self) -> Result<Option<Self::Item>, Box<dyn StdError + Send + Sync + 'static>> {
278        self.assert_open().map_err(Box::new)?;
279        Ok(Some((**self.svid.load()).clone()))
280    }
281}
282
283impl BundleSource for X509Source {
284    type Item = X509Bundle;
285
286    fn get_bundle_for_trust_domain(
287        &self,
288        trust_domain: &TrustDomain,
289    ) -> Result<Option<Self::Item>, Box<dyn StdError + Send + Sync + 'static>> {
290        self.assert_open().map_err(Box::new)?;
291        Ok(self.bundles.load().get_bundle(trust_domain).cloned())
292    }
293}
294
295impl X509Source {
296    /// Returns the current X.509 bundle set.
297    pub fn bundle_set(&self) -> Result<X509BundleSet, X509SourceError> {
298        self.assert_open()?;
299        Ok((**self.bundles.load()).clone())
300    }
301
302    /// Returns the current X.509 context (SVID + bundles) as a single value.
303    pub fn x509_context(&self) -> Result<X509Context, X509SourceError> {
304        self.assert_open()?;
305
306        let svid = (**self.svid.load()).clone();
307        let bundles = (**self.bundles.load()).clone();
308
309        Ok(X509Context::new(vec![svid], bundles))
310    }
311}
312
313// private/internal
314impl X509Source {
315    async fn new_with(
316        make_client: ClientFactory,
317        svid_picker: Option<Box<dyn SvidPicker>>,
318        reconnect: ReconnectConfig,
319    ) -> Result<Arc<X509Source>, X509SourceError> {
320        let (update_tx, update_rx) = watch::channel(0u64);
321        let cancel = CancellationToken::new();
322
323        let (initial_svid, initial_bundles) =
324            initial_sync_with_retry(&make_client, svid_picker.as_deref(), &cancel, reconnect)
325                .await?;
326
327        let src = Arc::new(Self {
328            svid: ArcSwap::from_pointee(initial_svid),
329            bundles: ArcSwap::from_pointee(initial_bundles),
330            svid_picker,
331            reconnect,
332            make_client,
333            closed: AtomicBool::new(false),
334            cancel,
335            update_seq: AtomicU64::new(0),
336            update_tx,
337            update_rx,
338            supervisor: Mutex::new(None),
339        });
340
341        let cloned = Arc::clone(&src);
342        let token = cloned.cancel.clone();
343        let handle = tokio::spawn(async move { cloned.run_update_supervisor(token).await });
344        *src.supervisor.lock().await = Some(handle);
345
346        Ok(src)
347    }
348
349    fn assert_open(&self) -> Result<(), X509SourceError> {
350        if self.closed.load(Ordering::Acquire) || self.cancel.is_cancelled() {
351            return Err(X509SourceError::Closed);
352        }
353        Ok(())
354    }
355
356    fn notify_update(&self) {
357        let next = self.update_seq.fetch_add(1, Ordering::Relaxed) + 1;
358        let _ = self.update_tx.send(next);
359    }
360
361    fn set_x509_context(&self, x509_context: X509Context) -> Result<(), X509SourceError> {
362        let picked = if let Some(ref picker) = self.svid_picker {
363            picker
364                .pick_svid(x509_context.svids())
365                .ok_or(X509SourceError::NoSuitableSvid)?
366        } else {
367            x509_context
368                .default_svid()
369                .ok_or(X509SourceError::NoSuitableSvid)?
370        };
371
372        self.svid.store(Arc::new(picked.clone()));
373        self.bundles
374            .store(Arc::new(x509_context.bundle_set().clone()));
375
376        self.notify_update();
377        Ok(())
378    }
379
380    async fn run_update_supervisor(&self, cancellation_token: CancellationToken) {
381        let mut backoff = self.reconnect.min_backoff;
382
383        loop {
384            if cancellation_token.is_cancelled() {
385                debug!("Cancellation signal received; stopping updates.");
386                return;
387            }
388
389            let mut client = match (self.make_client)().await {
390                Ok(c) => {
391                    backoff = self.reconnect.min_backoff;
392                    c
393                }
394                Err(e) => {
395                    warn!("Failed to create WorkloadApiClient: {e}. Retrying in {backoff:?}.");
396                    if sleep_or_cancel(&cancellation_token, backoff).await {
397                        return;
398                    }
399                    backoff = next_backoff(backoff, self.reconnect.max_backoff);
400                    continue;
401                }
402            };
403
404            let mut stream = match client.stream_x509_contexts().await {
405                Ok(s) => {
406                    info!("Connected to Workload API X509 context stream.");
407                    backoff = self.reconnect.min_backoff;
408                    s
409                }
410                Err(e) => {
411                    warn!(
412                        "Failed to connect to Workload API stream: {e}. Retrying in {backoff:?}."
413                    );
414                    if sleep_or_cancel(&cancellation_token, backoff).await {
415                        return;
416                    }
417                    backoff = next_backoff(backoff, self.reconnect.max_backoff);
418                    continue;
419                }
420            };
421
422            loop {
423                if cancellation_token.is_cancelled() {
424                    debug!("Cancellation signal received; stopping update loop.");
425                    return;
426                }
427
428                match stream.next().await {
429                    Some(Ok(ctx)) => match self.set_x509_context(ctx) {
430                        Err(e) => {
431                            error!("Error updating X509 context: {e}");
432                        }
433                        _ => {
434                            debug!("X509 context updated.");
435                        }
436                    },
437                    Some(Err(e)) => {
438                        warn!("Workload API stream error: {e}. Reconnecting...");
439                        break;
440                    }
441                    None => {
442                        warn!("Workload API stream ended. Reconnecting...");
443                        break;
444                    }
445                }
446            }
447
448            if sleep_or_cancel(&cancellation_token, backoff).await {
449                return;
450            }
451            backoff = next_backoff(backoff, self.reconnect.max_backoff);
452        }
453    }
454}
455
456async fn initial_sync_with_retry(
457    make_client: &ClientFactory,
458    picker: Option<&dyn SvidPicker>,
459    cancel: &CancellationToken,
460    reconnect: ReconnectConfig,
461) -> Result<(X509Svid, X509BundleSet), X509SourceError> {
462    let mut backoff = reconnect.min_backoff;
463
464    loop {
465        if cancel.is_cancelled() {
466            return Err(X509SourceError::Closed);
467        }
468
469        match try_sync_once(make_client, picker).await {
470            Ok(v) => return Ok(v),
471            Err(e) => {
472                warn!("Initial sync failed: {e}. Retrying in {backoff:?}.");
473                if sleep_or_cancel(cancel, backoff).await {
474                    return Err(X509SourceError::Closed);
475                }
476                backoff = next_backoff(backoff, reconnect.max_backoff);
477            }
478        }
479    }
480}
481
482async fn try_sync_once(
483    make_client: &ClientFactory,
484    picker: Option<&dyn SvidPicker>,
485) -> Result<(X509Svid, X509BundleSet), X509SourceError> {
486    let mut client = (make_client)().await.map_err(X509SourceError::Grpc)?;
487    let mut stream = client
488        .stream_x509_contexts()
489        .await
490        .map_err(X509SourceError::Grpc)?;
491
492    match stream.next().await {
493        Some(Ok(ctx)) => {
494            let picked = if let Some(p) = picker {
495                p.pick_svid(ctx.svids())
496                    .ok_or(X509SourceError::NoSuitableSvid)?
497            } else {
498                ctx.default_svid().ok_or(X509SourceError::NoSuitableSvid)?
499            };
500            Ok((picked.clone(), ctx.bundle_set().clone()))
501        }
502        Some(Err(e)) => Err(X509SourceError::Grpc(e)),
503        None => Err(X509SourceError::StreamEnded),
504    }
505}
506
507async fn sleep_or_cancel(token: &CancellationToken, dur: Duration) -> bool {
508    tokio::select! {
509        _ = token.cancelled() => true,
510        _ = sleep(dur) => false,
511    }
512}
513
/// Returns the delay to use after `current`: twice `current`, capped at `max`.
///
/// The doubling saturates instead of overflowing.
fn next_backoff(current: Duration, max: Duration) -> Duration {
    current.saturating_mul(2).min(max)
}