spiffe/workload_api/
client.rs

1//! Workload API client for fetching SPIFFE X.509 and JWT material.
2//!
3//! `WorkloadApiClient` provides one-shot RPCs (fetch SVIDs/bundles) and streaming RPCs for
4//! receiving updates as material rotates.
5//!
6//! Most users should prefer higher-level types like `X509Source`, which handle reconnection and
7//! provide an always-up-to-date view of the X.509 context.
8//!
9//! # Example
10//!
11//! ```no_run
12//! use spiffe::{SpiffeId, WorkloadApiClient};
13//! use tokio_stream::StreamExt;
14//!
15//! # async fn example() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
16//! let mut client =
17//!     WorkloadApiClient::new_from_path("unix:/tmp/spire-agent/public/api.sock").await?;
18//!
19//! let jwt = client.fetch_jwt_token(&["service1"], None).await?;
20//! let _jwt_svid = client.fetch_jwt_svid(&["service1"], None).await?;
21//! let _jwt_bundles = client.fetch_jwt_bundles().await?;
22//!
23//! let _x509_svid = client.fetch_x509_svid().await?;
24//! let _x509_bundles = client.fetch_x509_bundles().await?;
25//! let _x509_ctx = client.fetch_x509_context().await?;
26//!
27//! let mut updates = client.stream_x509_contexts().await?;
28//! while let Some(update) = updates.next().await {
29//!     let _ctx = update?;
30//! }
31//! # Ok(())
32//! # }
33//! ```
34
35use std::str::FromStr;
36
37use crate::bundle::jwt::{JwtBundle, JwtBundleSet};
38use crate::bundle::x509::{X509Bundle, X509BundleSet};
39use crate::endpoint::{get_default_socket_path, validate_socket_path};
40use crate::spiffe_id::{SpiffeId, TrustDomain};
41use crate::svid::jwt::JwtSvid;
42use crate::svid::x509::X509Svid;
43use crate::workload_api::x509_context::X509Context;
44use hyper_util::rt::TokioIo;
45use std::convert::TryFrom;
46use std::sync::Arc;
47use tokio::net::UnixStream;
48use tokio_stream::{Stream, StreamExt};
49
50use crate::constants::DEFAULT_SVID;
51use crate::error::GrpcClientError;
52use crate::workload_api::pb::spiffe_workload_api_client::SpiffeWorkloadApiClient;
53use crate::workload_api::pb::{
54    JwtBundlesRequest, JwtBundlesResponse, JwtsvidRequest, JwtsvidResponse, ValidateJwtsvidRequest,
55    ValidateJwtsvidResponse, X509BundlesRequest, X509BundlesResponse, X509svidRequest,
56    X509svidResponse,
57};
58use tonic::transport::{Endpoint, Uri};
59use tower::service_fn;
60
/// Metadata key that `MetadataAdder` inserts into every outgoing Workload API request.
const SPIFFE_HEADER_KEY: &str = "workload.spiffe.io";
/// Value sent under `SPIFFE_HEADER_KEY`.
const SPIFFE_HEADER_VALUE: &str = "true";
63
/// Client for the SPIFFE Workload API.
///
/// Provides one-shot calls and streaming updates for X.509 and JWT SVIDs and bundles.
/// For an always-up-to-date, shareable source of X.509 material with automatic reconnection,
/// see [`crate::X509Source`].
#[derive(Debug, Clone)]
pub struct WorkloadApiClient {
    /// Socket path exactly as supplied at construction; returned by `socket_path()` and
    /// dialed again by `reconnect()`.
    socket_path: Arc<str>,
    /// Generated gRPC client whose requests all pass through the `MetadataAdder` interceptor.
    client: SpiffeWorkloadApiClient<
        tonic::service::interceptor::InterceptedService<tonic::transport::Channel, MetadataAdder>,
    >,
}
76
/// Tonic interceptor that adds the Workload API metadata header required by SPIRE.
///
/// The type carries no state; it only exists to hang the `Interceptor` implementation on.
#[derive(Clone)]
struct MetadataAdder;
80
81impl tonic::service::Interceptor for MetadataAdder {
82    fn call(
83        &mut self,
84        mut request: tonic::Request<()>,
85    ) -> Result<tonic::Request<()>, tonic::Status> {
86        let parsed_header = SPIFFE_HEADER_VALUE
87            .parse()
88            .map_err(|e| tonic::Status::internal(format!("Failed to parse header: {e}")))?;
89        request
90            .metadata_mut()
91            .insert(SPIFFE_HEADER_KEY, parsed_header);
92        Ok(request)
93    }
94}
95
96impl WorkloadApiClient {
    /// Optional scheme prefix that `connect_channel` strips from a socket path before dialing.
    const UNIX_PREFIX: &'static str = "unix:";
    /// URI handed to `Endpoint::try_from`; the connector built in `connect_channel` ignores the
    /// URI it receives and always dials the UNIX socket instead.
    const TONIC_DEFAULT_URI: &'static str = "http://[::]:50051";
99
100    /// Returns the configured Workload API socket path.
101    pub fn socket_path(&self) -> &str {
102        &self.socket_path
103    }
104
105    async fn connect_channel(
106        socket_path: &str,
107    ) -> Result<tonic::transport::Channel, GrpcClientError> {
108        validate_socket_path(socket_path)?;
109
110        // Strip the 'unix:' prefix for tonic compatibility.
111        let stripped = socket_path
112            .strip_prefix(Self::UNIX_PREFIX)
113            .unwrap_or(socket_path)
114            .to_string();
115
116        let channel = Endpoint::try_from(Self::TONIC_DEFAULT_URI)?
117            .connect_with_connector(service_fn(move |_: Uri| {
118                let stripped = stripped.clone();
119                async { UnixStream::connect(stripped).await.map(TokioIo::new) }
120            }))
121            .await?;
122
123        Ok(channel)
124    }
125
126    /// Connects to the Workload API using the given UNIX domain socket path.
127    ///
128    /// The path may optionally be prefixed with `unix:` (e.g. `unix:/tmp/spire-agent/public/api.sock`).
129    pub async fn new_from_path(path: impl AsRef<str>) -> Result<Self, GrpcClientError> {
130        let path = path.as_ref();
131        validate_socket_path(path)?;
132
133        let socket_path: Arc<str> = Arc::from(path);
134        let channel = Self::connect_channel(path).await?;
135
136        Ok(WorkloadApiClient {
137            socket_path,
138            client: SpiffeWorkloadApiClient::with_interceptor(channel, MetadataAdder {}),
139        })
140    }
141
142    /// Rebuilds the underlying gRPC channel.
143    ///
144    /// This is intended for manual recovery scenarios. Higher-level abstractions such as `X509Source`
145    /// typically create fresh clients and manage reconnection automatically.
146    pub async fn reconnect(&mut self) -> Result<(), GrpcClientError> {
147        let channel = Self::connect_channel(&self.socket_path).await?;
148        self.client = SpiffeWorkloadApiClient::with_interceptor(channel, MetadataAdder {});
149        Ok(())
150    }
151
152    /// Connects to the Workload API using `SPIFFE_ENDPOINT_SOCKET`.
153    pub async fn default() -> Result<Self, GrpcClientError> {
154        let socket_path =
155            get_default_socket_path().ok_or(GrpcClientError::MissingEndpointSocketPath)?;
156        Self::new_from_path(socket_path.as_str()).await
157    }
158
159    /// Creates a client from an existing gRPC channel.
160    ///
161    /// This is primarily useful for tests or advanced transport customization.
162    pub fn new(
163        socket_path: impl AsRef<str>,
164        conn: tonic::transport::Channel,
165    ) -> Result<Self, GrpcClientError> {
166        Ok(WorkloadApiClient {
167            socket_path: Arc::from(socket_path.as_ref()),
168            client: SpiffeWorkloadApiClient::with_interceptor(conn, MetadataAdder {}),
169        })
170    }
171
172    /// Fetches the default X.509 SVID for the calling workload.
173    pub async fn fetch_x509_svid(&mut self) -> Result<X509Svid, GrpcClientError> {
174        let request = X509svidRequest::default();
175
176        let grpc_stream_response: tonic::Response<tonic::Streaming<X509svidResponse>> =
177            self.client.fetch_x509svid(request).await?;
178
179        let response = grpc_stream_response
180            .into_inner()
181            .message()
182            .await?
183            .ok_or(GrpcClientError::EmptyResponse)?;
184        WorkloadApiClient::parse_x509_svid_from_grpc_response(response)
185    }
186
187    /// Fetches all X.509 SVIDs available to the calling workload.
188    pub async fn fetch_all_x509_svids(&mut self) -> Result<Vec<X509Svid>, GrpcClientError> {
189        let request = X509svidRequest::default();
190
191        let grpc_stream_response: tonic::Response<tonic::Streaming<X509svidResponse>> =
192            self.client.fetch_x509svid(request).await?;
193
194        let response = grpc_stream_response
195            .into_inner()
196            .message()
197            .await?
198            .ok_or(GrpcClientError::EmptyResponse)?;
199        WorkloadApiClient::parse_x509_svids_from_grpc_response(response)
200    }
201
202    /// Fetches the current X.509 bundle set.
203    pub async fn fetch_x509_bundles(&mut self) -> Result<X509BundleSet, GrpcClientError> {
204        let request = X509BundlesRequest::default();
205
206        let grpc_stream_response: tonic::Response<tonic::Streaming<X509BundlesResponse>> =
207            self.client.fetch_x509_bundles(request).await?;
208
209        let response = grpc_stream_response
210            .into_inner()
211            .message()
212            .await?
213            .ok_or(GrpcClientError::EmptyResponse)?;
214        WorkloadApiClient::parse_x509_bundle_set_from_grpc_response(response)
215    }
216
217    /// Fetches the current JWT bundle set.
218    pub async fn fetch_jwt_bundles(&mut self) -> Result<JwtBundleSet, GrpcClientError> {
219        let request = JwtBundlesRequest::default();
220
221        let grpc_stream_response: tonic::Response<tonic::Streaming<JwtBundlesResponse>> =
222            self.client.fetch_jwt_bundles(request).await?;
223
224        let response = grpc_stream_response
225            .into_inner()
226            .message()
227            .await?
228            .ok_or(GrpcClientError::EmptyResponse)?;
229        WorkloadApiClient::parse_jwt_bundle_set_from_grpc_response(response)
230    }
231
232    /// Fetches the current X.509 context (SVIDs and bundles).
233    pub async fn fetch_x509_context(&mut self) -> Result<X509Context, GrpcClientError> {
234        let request = X509svidRequest::default();
235
236        let grpc_stream_response: tonic::Response<tonic::Streaming<X509svidResponse>> =
237            self.client.fetch_x509svid(request).await?;
238
239        let response = grpc_stream_response
240            .into_inner()
241            .message()
242            .await?
243            .ok_or(GrpcClientError::EmptyResponse)?;
244        WorkloadApiClient::parse_x509_context_from_grpc_response(response)
245    }
246
247    /// Fetches a JWT-SVID for the given audience and optional SPIFFE ID.
248    ///
249    /// If `spiffe_id` is `None`, the Workload API returns the default identity.
250    pub async fn fetch_jwt_svid<T: AsRef<str> + ToString>(
251        &mut self,
252        audience: &[T],
253        spiffe_id: Option<&SpiffeId>,
254    ) -> Result<JwtSvid, GrpcClientError> {
255        let response = self.fetch_jwt(audience, spiffe_id).await?;
256        response
257            .svids
258            .get(DEFAULT_SVID)
259            .ok_or(GrpcClientError::EmptyResponse)
260            .and_then(|r| JwtSvid::from_str(&r.svid).map_err(GrpcClientError::JwtSvid))
261    }
262
263    /// Fetches a JWT-SVID token string for the given audience and optional SPIFFE ID.
264    ///
265    /// If `spiffe_id` is `None`, the Workload API returns the default identity.
266    pub async fn fetch_jwt_token<T: AsRef<str> + ToString>(
267        &mut self,
268        audience: &[T],
269        spiffe_id: Option<&SpiffeId>,
270    ) -> Result<String, GrpcClientError> {
271        let response = self.fetch_jwt(audience, spiffe_id).await?;
272        response
273            .svids
274            .get(DEFAULT_SVID)
275            .map(|r| r.svid.to_string())
276            .ok_or(GrpcClientError::EmptyResponse)
277    }
278
279    /// Validates a JWT-SVID token for the given audience and returns the parsed `JwtSvid`.
280    pub async fn validate_jwt_token<T: AsRef<str> + ToString>(
281        &mut self,
282        audience: T,
283        jwt_token: &str,
284    ) -> Result<JwtSvid, GrpcClientError> {
285        // validate token with Workload API, the returned claims and spiffe_id are ignored as
286        // they are parsed from token when the `JwtSvid` object is created, this way we avoid having
287        // to validate that the response from the Workload API contains correct claims.
288        let _ = self.validate_jwt(audience, jwt_token).await?;
289        let jwt_svid = JwtSvid::parse_insecure(jwt_token)?;
290        Ok(jwt_svid)
291    }
292
293    /// Streams X.509 context updates from the Workload API.
294    pub async fn stream_x509_contexts(
295        &mut self,
296    ) -> Result<impl Stream<Item = Result<X509Context, GrpcClientError>> + use<>, GrpcClientError>
297    {
298        let request = X509svidRequest::default();
299        let response = self.client.fetch_x509svid(request).await?;
300        let stream = response.into_inner().map(|message| {
301            message
302                .map_err(GrpcClientError::from)
303                .and_then(WorkloadApiClient::parse_x509_context_from_grpc_response)
304        });
305        Ok(stream)
306    }
307
308    /// Streams X.509 SVID updates from the Workload API.
309    pub async fn stream_x509_svids(
310        &mut self,
311    ) -> Result<impl Stream<Item = Result<X509Svid, GrpcClientError>> + use<>, GrpcClientError>
312    {
313        let request = X509svidRequest::default();
314        let response = self.client.fetch_x509svid(request).await?;
315        let stream = response.into_inner().map(|message| {
316            message
317                .map_err(GrpcClientError::from)
318                .and_then(WorkloadApiClient::parse_x509_svid_from_grpc_response)
319        });
320        Ok(stream)
321    }
322
323    /// Streams X.509 bundle set updates from the Workload API.
324    pub async fn stream_x509_bundles(
325        &mut self,
326    ) -> Result<impl Stream<Item = Result<X509BundleSet, GrpcClientError>> + use<>, GrpcClientError>
327    {
328        let request = X509BundlesRequest::default();
329        let response = self.client.fetch_x509_bundles(request).await?;
330        let stream = response.into_inner().map(|message| {
331            message
332                .map_err(GrpcClientError::from)
333                .and_then(WorkloadApiClient::parse_x509_bundle_set_from_grpc_response)
334        });
335        Ok(stream)
336    }
337
338    /// Streams JWT bundle set updates from the Workload API.
339    pub async fn stream_jwt_bundles(
340        &mut self,
341    ) -> Result<impl Stream<Item = Result<JwtBundleSet, GrpcClientError>> + use<>, GrpcClientError>
342    {
343        let request = JwtBundlesRequest::default();
344        let response = self.client.fetch_jwt_bundles(request).await?;
345        let stream = response.into_inner().map(|message| {
346            message
347                .map_err(GrpcClientError::from)
348                .and_then(WorkloadApiClient::parse_jwt_bundle_set_from_grpc_response)
349        });
350        Ok(stream)
351    }
352}
353
/// Private helpers: request construction and conversion of gRPC responses into SPIFFE types.
355impl WorkloadApiClient {
356    async fn fetch_jwt<T: AsRef<str> + ToString>(
357        &mut self,
358        audience: &[T],
359        spiffe_id: Option<&SpiffeId>,
360    ) -> Result<JwtsvidResponse, GrpcClientError> {
361        let request = JwtsvidRequest {
362            spiffe_id: spiffe_id.map(ToString::to_string).unwrap_or_default(),
363            audience: audience.iter().map(|s| s.to_string()).collect(),
364        };
365
366        Ok(self.client.fetch_jwtsvid(request).await?.into_inner())
367    }
368
369    async fn validate_jwt<T: AsRef<str>>(
370        &mut self,
371        audience: T,
372        jwt_svid: &str,
373    ) -> Result<ValidateJwtsvidResponse, GrpcClientError> {
374        let request = ValidateJwtsvidRequest {
375            audience: audience.as_ref().into(),
376            svid: jwt_svid.into(),
377        };
378
379        Ok(self.client.validate_jwtsvid(request).await?.into_inner())
380    }
381
382    fn parse_x509_svid_from_grpc_response(
383        response: X509svidResponse,
384    ) -> Result<X509Svid, GrpcClientError> {
385        let svid = response
386            .svids
387            .get(DEFAULT_SVID)
388            .ok_or(GrpcClientError::EmptyResponse)?;
389
390        X509Svid::parse_from_der(svid.x509_svid.as_ref(), svid.x509_svid_key.as_ref())
391            .map_err(GrpcClientError::from)
392    }
393
394    fn parse_x509_svids_from_grpc_response(
395        response: X509svidResponse,
396    ) -> Result<Vec<X509Svid>, GrpcClientError> {
397        let mut svids_vec = Vec::new();
398
399        for svid in response.svids.iter() {
400            let parsed_svid =
401                X509Svid::parse_from_der(svid.x509_svid.as_ref(), svid.x509_svid_key.as_ref())
402                    .map_err(GrpcClientError::from)?;
403
404            svids_vec.push(parsed_svid);
405        }
406
407        Ok(svids_vec)
408    }
409
410    fn parse_x509_bundle_set_from_grpc_response(
411        response: X509BundlesResponse,
412    ) -> Result<X509BundleSet, GrpcClientError> {
413        let bundles: Result<Vec<_>, _> = response
414            .bundles
415            .into_iter()
416            .map(|(td, bundle_data)| {
417                let trust_domain = TrustDomain::try_from(td)?;
418                X509Bundle::parse_from_der(trust_domain, &bundle_data)
419                    .map_err(GrpcClientError::from)
420            })
421            .collect();
422
423        let mut bundle_set = X509BundleSet::new();
424        for bundle in bundles? {
425            bundle_set.add_bundle(bundle);
426        }
427
428        Ok(bundle_set)
429    }
430
431    fn parse_jwt_bundle_set_from_grpc_response(
432        response: JwtBundlesResponse,
433    ) -> Result<JwtBundleSet, GrpcClientError> {
434        let mut bundle_set = JwtBundleSet::new();
435
436        for (td, bundle_data) in response.bundles.into_iter() {
437            let trust_domain = TrustDomain::try_from(td)?;
438            let bundle = JwtBundle::from_jwt_authorities(trust_domain, &bundle_data)
439                .map_err(GrpcClientError::from)?;
440
441            bundle_set.add_bundle(bundle);
442        }
443
444        Ok(bundle_set)
445    }
446
447    fn parse_x509_context_from_grpc_response(
448        response: X509svidResponse,
449    ) -> Result<X509Context, GrpcClientError> {
450        let mut svids = Vec::new();
451        let mut bundle_set = X509BundleSet::new();
452
453        for svid in response.svids.into_iter() {
454            let x509_svid =
455                X509Svid::parse_from_der(svid.x509_svid.as_ref(), svid.x509_svid_key.as_ref())
456                    .map_err(GrpcClientError::from)?;
457
458            let trust_domain = x509_svid.spiffe_id().trust_domain().clone();
459            svids.push(x509_svid);
460
461            let bundle = X509Bundle::parse_from_der(trust_domain, svid.bundle.as_ref())
462                .map_err(GrpcClientError::from)?;
463            bundle_set.add_bundle(bundle);
464        }
465
466        for (trust_domain, bundle) in response.federated_bundles.into_iter() {
467            let trust_domain = TrustDomain::try_from(trust_domain)?;
468            let x509_bundle = X509Bundle::parse_from_der(trust_domain, bundle.as_ref())
469                .map_err(GrpcClientError::from)?;
470            bundle_set.add_bundle(x509_bundle);
471        }
472
473        Ok(X509Context::new(svids, bundle_set))
474    }
475}