// spiffe/workload_api/x509_source.rs

//! # X509Source Module
//!
//! This module provides a source of X.509 SVIDs and X.509 bundles backed by a Workload API client
//! that streams X.509 contexts (SVIDs and bundles) in the background. The `X509Source` is refreshed
//! whenever a new X.509 context is received.
//!
//! Users can use the `X509Source` to obtain SVIDs and bundles, listen for updates, and manage the
//! lifecycle of the source (including closing it).
10//!
11//! ## Usage
12//!
//! Create an `X509Source` with [`X509Source::default`], or use [`X509SourceBuilder`] to supply a
//! custom `WorkloadApiClient` and/or `SvidPicker`. Creation waits until the first X.509 context has
//! been received, so the returned source can serve SVIDs and bundles immediately.
15//!
16//! ### Example
17//!
18//! ```no_run
19//! use spiffe::{BundleSource, SvidSource, TrustDomain, X509Source};
20//!
21//! # async fn example() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
22//! let source = X509Source::default().await?;
23//! let svid = source.get_svid()?;
24//! let trust_domain = TrustDomain::new("example.org").unwrap();
25//! let bundle = source
26//!     .get_bundle_for_trust_domain(&trust_domain)
27//!     .map_err(|e| {
28//!         format!(
29//!             "Failed to get bundle for trust domain {}: {}",
30//!             trust_domain, e
31//!         )
32//!     })?;
33//!
34//! # Ok(())
35//! # }
36//! ```
37//!
38//! ## Error Handling
39//!
//! The `X509SourceError` enum describes what went wrong: gRPC client failures, no suitable SVID
//! being selected, internal (lock-related) failures, and other non-specific errors such as using a
//! closed source.
42//!
43//! ## Update Handling
44//!
//! The `updated` method returns a `tokio::sync::watch::Receiver<()>` that is notified each time the
//! `X509Source` applies a new X.509 context, allowing other parts of your system to react to changes.
47//!
48//! ## Closing the Source
49//!
//! The `close` method closes the `X509Source` and cancels its background update task. After closing,
//! requests for SVIDs and bundles return an error, as does a second call to `close`.
51use crate::error::GrpcClientError;
52use crate::{
53    BundleSource, SvidSource, TrustDomain, WorkloadApiClient, X509Bundle, X509BundleSet,
54    X509Context, X509Svid,
55};
56use log::{debug, error, info};
57use std::error::Error;
58use std::fmt::Debug;
59use std::sync::{Arc, PoisonError, RwLock};
60use thiserror::Error;
61use tokio::sync::watch;
62use tokio_stream::StreamExt;
63use tokio_util::sync::CancellationToken;
64
65/// `SvidPicker` is a trait defining the behavior for selecting an `X509Svid`.
66///
67/// Implementors of this trait must provide a concrete implementation of the `pick_svid` method, which
68/// takes a reference to a slice of `X509Svid` and returns an `Option<&X509Svid>`.
69///
70/// The trait requires that implementing types are both `Send` and `Sync`, ensuring that they can be
71/// sent between threads and accessed concurrently.
72///
73/// # Example
74///
75/// ```
76/// use spiffe::workload_api::x509_source::SvidPicker;
77/// use spiffe::X509Svid;
78///
79/// #[derive(Debug)]
80/// struct SecondSvidPicker;
81///
82/// impl SvidPicker for SecondSvidPicker {
83///     fn pick_svid<'a>(&self, svids: &'a [X509Svid]) -> Option<&'a X509Svid> {
84///         svids.get(1) // return second svid
85///     }
86/// }
87/// ```
88pub trait SvidPicker: Debug + Send + Sync {
89    /// Selects an `X509Svid` from the provided slice of `X509Svid`.
90    ///
91    /// # Parameters
92    /// * `svids`: A reference to a slice of `X509Svid` from which the `X509Svid` should be selected.
93    ///
94    /// # Returns
95    /// * An `Option<&X509Svid>`, where a `None` value indicates that no suitable `X509Svid` was found.
96    fn pick_svid<'a>(&self, svids: &'a [X509Svid]) -> Option<&'a X509Svid>;
97}
98
99/// Enumerates errors that can occur within the X509Source.
100#[derive(Debug, Error)]
101pub enum X509SourceError {
102    /// Error when a GRPC client fails to fetch an X509Context.
103    #[error("GRPC client error: {0}")]
104    GrpcError(GrpcClientError),
105
106    /// Error when no suitable SVID is found by the picker.
107    #[error("No suitable SVID found by picker")]
108    NoSuitableSvid,
109
110    /// Error related to internal operations.
111    #[error("Internal error while {0}: {1}")]
112    InternalError(String, String),
113
114    /// Other non-specific error.
115    #[error("{0}")]
116    Other(String),
117}
118
119impl X509SourceError {
120    fn from_lock_err<T>(err: PoisonError<T>, action: &str) -> Self {
121        X509SourceError::InternalError(action.to_string(), err.to_string())
122    }
123}
124
125/// Represents a source of X.509 SVIDs and X.509 bundles.
126///
127///
128/// `X509Source` implements the [`BundleSource`] and [`SvidSource`] traits.
129///
130/// The methods return cloned instances of the underlying objects.
131#[derive(Debug)]
132pub struct X509Source {
133    svid: RwLock<Option<X509Svid>>,
134    bundles: RwLock<Option<X509BundleSet>>,
135    svid_picker: Option<Box<dyn SvidPicker>>,
136    workload_api_client: WorkloadApiClient,
137    closed: RwLock<bool>,
138    cancellation_token: CancellationToken,
139    update_notifier: watch::Sender<()>,
140    updated: watch::Receiver<()>,
141}
142
143/// Builder for `X509Source`.
144#[derive(Debug)]
145pub struct X509SourceBuilder {
146    client: Option<WorkloadApiClient>,
147    svid_picker: Option<Box<dyn SvidPicker>>,
148}
149
150/// A builder for creating a new `X509Source` with optional client and svid_picker configurations.
151///
152/// Allows for customization by accepting a client and/or svid_picker.
153///
154/// # Example
155///
156/// ```no_run
157/// use spiffe::workload_api::x509_source::{SvidPicker, X509SourceBuilder};
158/// use spiffe::{WorkloadApiClient, X509Svid};
159/// use std::error::Error;
160///
161/// #[derive(Debug)]
162/// struct SecondSvidPicker;
163///
164/// impl SvidPicker for SecondSvidPicker {
165///     fn pick_svid<'a>(&self, svids: &'a [X509Svid]) -> Option<&'a X509Svid> {
166///         svids.get(1) // return second svid
167///     }
168/// }
169///
170/// # async fn example() -> Result<(), Box< dyn Error>> {
171/// let client = WorkloadApiClient::default().await?;
172/// let source = X509SourceBuilder::new()
173///     .with_client(client)
174///     .with_picker(Box::new(SecondSvidPicker))
175///     .build()
176///     .await?;
177///
178/// # Ok(())
179/// # }
180/// ```
181///
182/// # Returns
183/// A `Result` containing an `Arc<X509Source>` or an `X509SourceError` if an error occurs.
184impl X509SourceBuilder {
185    /// Creates a new `X509SourceBuilder`.
186    pub fn new() -> Self {
187        Self {
188            client: None,
189            svid_picker: None,
190        }
191    }
192
193    /// Sets the Workload API client to be used by the X509Source.
194    pub fn with_client(mut self, client: WorkloadApiClient) -> Self {
195        self.client = Some(client);
196        self
197    }
198
199    /// Sets the svid_picker to be used by the X509Source.
200    pub fn with_picker(mut self, svid_picker: Box<dyn SvidPicker>) -> Self {
201        self.svid_picker = Some(svid_picker);
202        self
203    }
204
205    /// Builds an `X509Source` using the provided configuration.
206    pub async fn build(self) -> Result<Arc<X509Source>, X509SourceError> {
207        let client = match self.client {
208            Some(client) => client,
209            None => WorkloadApiClient::default()
210                .await
211                .map_err(X509SourceError::GrpcError)?,
212        };
213
214        X509Source::new(client, self.svid_picker).await
215    }
216}
217
218impl Default for X509SourceBuilder {
219    fn default() -> Self {
220        Self::new()
221    }
222}
223
224impl SvidSource for X509Source {
225    type Item = X509Svid;
226
227    /// Retrieves the X.509 SVID from the source.
228    ///
229    /// # Returns
230    ///
231    /// An `Result<Option<X509Svid>, Box<dyn Error + Send + Sync + 'static>>` containing the X.509 SVID if available.
232    /// Returns `Ok(None)` if no SVID is found.
233    /// Returns an error if the source is closed or if there's an issue fetching the SVID.
234    fn get_svid(&self) -> Result<Option<Self::Item>, Box<dyn Error + Send + Sync + 'static>> {
235        self.assert_not_closed().map_err(Box::new)?;
236
237        let svid_option = self
238            .svid
239            .read()
240            .map_err(|e| X509SourceError::from_lock_err(e, "reading SVIDs from source"))?;
241
242        Ok(svid_option.clone())
243    }
244}
245
246impl BundleSource for X509Source {
247    type Item = X509Bundle;
248
249    /// Retrieves the X.509 bundle for the given trust domain.
250    ///
251    /// # Arguments
252    /// * `trust_domain` - The trust domain for which the X.509 bundle is to be retrieved.
253    ///
254    /// # Returns
255    /// A `Result` containing an `Option<X509Bundle>` for the given trust domain. If the bundle is not found, returns `Ok(None)`.
256    ///
257    /// # Errors
258    /// Returns a boxed error if the source is closed or if there is an issue accessing the bundle.
259    fn get_bundle_for_trust_domain(
260        &self,
261        trust_domain: &TrustDomain,
262    ) -> Result<Option<Self::Item>, Box<dyn Error + Send + Sync + 'static>> {
263        self.assert_not_closed().map_err(Box::new)?;
264
265        // Read the bundles
266        let bundles_option = self
267            .bundles
268            .read()
269            .map_err(|e| X509SourceError::from_lock_err(e, "reading bundles from source"))?;
270        let bundle_set = match bundles_option.as_ref() {
271            Some(set) => set,
272            None => return Ok(None),
273        };
274
275        // Get the bundle for the trust domain
276        let bundle = bundle_set.get_bundle(trust_domain);
277
278        // Return the bundle if found, or Ok(None) if not found
279        Ok(bundle.cloned())
280    }
281}
282
283// public methods
284impl X509Source {
285    /// Builds a new `X509Source` using a default [`WorkloadApiClient`] and no SVID picker.
286    /// Since no SVID picker is provided, the `get_svid` method will return the default SVID.
287    ///
288    /// This method is asynchronous and may return an error if the initialization fails.
289    pub async fn default() -> Result<Arc<Self>, X509SourceError> {
290        X509SourceBuilder::new().build().await
291    }
292
293    /// Returns a `watch::Receiver<()>` that can be used to listen for notifications when the X509Source is updated.
294    ///
295    /// # Example
296    ///
297    /// ``no_run
298    /// let mut update_channel = source.updated(); // Get the watch receiver for the source
299    ///
300    /// // Asynchronously handle updates in a loop
301    /// tokio::spawn(async move {
302    ///     loop {
303    ///         match update_channel.changed().await {
304    ///             Ok(_) => {
305    ///                 println!("X509Source was updated!");
306    ///             },
307    ///             Err(_) => {
308    ///                 println!("Watch channel closed; exiting update loop");
309    ///                 break;
310    ///             }
311    ///         }
312    ///     }
313    /// });
314    /// ```
315    pub fn updated(&self) -> watch::Receiver<()> {
316        self.updated.clone()
317    }
318
319    /// Closes the X509Source cancelling all spawned tasks.
320    pub fn close(&self) -> Result<(), X509SourceError> {
321        self.assert_not_closed()?;
322
323        let mut closed = self
324            .closed
325            .write()
326            .map_err(|e| X509SourceError::from_lock_err(e, "closing source"))?;
327        *closed = true;
328
329        self.cancellation_token.cancel();
330
331        info!("X509Source has been closed.");
332        Ok(())
333    }
334}
335
336// private methods
337impl X509Source {
338    async fn new(
339        client: WorkloadApiClient,
340        svid_picker: Option<Box<dyn SvidPicker>>,
341    ) -> Result<Arc<X509Source>, X509SourceError> {
342        let (update_notifier, updated) = watch::channel(());
343        let cancellation_token = CancellationToken::new();
344        let cancellation_token_clone = cancellation_token.clone();
345
346        let source = Arc::new(X509Source {
347            svid: RwLock::new(None),
348            bundles: RwLock::new(None),
349            workload_api_client: client,
350            closed: RwLock::new(false),
351            svid_picker,
352            cancellation_token,
353            updated,
354            update_notifier,
355        });
356
357        let source_clone = Arc::clone(&source);
358        let mut client_clone = source_clone.workload_api_client.clone();
359        let mut stream = client_clone
360            .stream_x509_contexts()
361            .await
362            .map_err(X509SourceError::GrpcError)?;
363
364        // Block until the first X509Context is fetched.
365        if let Some(update) = stream.next().await {
366            match update {
367                Ok(x509_context) => source_clone.set_x509_context(x509_context).map_err(|e| {
368                    X509SourceError::Other(format!("Failed to set X509Context: {e}"))
369                })?,
370                Err(e) => return Err(X509SourceError::GrpcError(e)),
371            }
372        } else {
373            return Err(X509SourceError::Other(
374                "Stream ended without an update".to_string(),
375            ));
376        }
377
378        // Spawn a task to handle subsequent updates
379        tokio::spawn(async move {
380            loop {
381                if cancellation_token_clone.is_cancelled() {
382                    debug!("Cancellation signal received; stopping updates.");
383                    break;
384                }
385
386                match stream.next().await {
387                    Some(update) => match update {
388                        Ok(x509_context) => {
389                            if let Err(e) = source_clone.set_x509_context(x509_context) {
390                                error!("Error updating X509 context: {e}");
391                            } else {
392                                info!("X509 context updated successfully.");
393                            }
394                        }
395                        Err(e) => error!("GRPC client error: {e}"),
396                    },
397                    None => {
398                        error!("Stream ended; no more updates will be received.");
399                        break;
400                    }
401                }
402            }
403        });
404
405        Ok(source)
406    }
407
408    fn set_x509_context(&self, x509_context: X509Context) -> Result<(), X509SourceError> {
409        let svid = if let Some(ref svid_picker) = self.svid_picker {
410            svid_picker
411                .pick_svid(x509_context.svids())
412                .ok_or(X509SourceError::NoSuitableSvid)?
413        } else {
414            x509_context
415                .default_svid()
416                .ok_or(X509SourceError::NoSuitableSvid)?
417        };
418
419        self.set_svid(svid)?;
420
421        self.bundles
422            .write()
423            .map_err(|e| {
424                X509SourceError::InternalError(
425                    "writing bundles to source".to_string(),
426                    e.to_string(),
427                )
428            })?
429            .replace(x509_context.bundle_set().clone());
430
431        self.notify_update();
432        Ok(())
433    }
434
435    fn set_svid(&self, svid: &X509Svid) -> Result<(), X509SourceError> {
436        self.svid
437            .write()
438            .map_err(|e| {
439                X509SourceError::InternalError("writing SVID to source".to_string(), e.to_string())
440            })?
441            .replace(svid.clone());
442        Ok(())
443    }
444
445    fn notify_update(&self) {
446        let _ = self.update_notifier.send(());
447    }
448
449    fn assert_not_closed(&self) -> Result<(), X509SourceError> {
450        let closed = self.closed.read().map_err(|e| {
451            X509SourceError::InternalError(
452                "reading closed state from source".to_string(),
453                e.to_string(),
454            )
455        })?;
456        if *closed {
457            return Err(X509SourceError::Other("X509Source is closed".into()));
458        }
459        Ok(())
460    }
461}