1use std::str::FromStr;
36
37use crate::bundle::jwt::{JwtBundle, JwtBundleSet};
38use crate::bundle::x509::{X509Bundle, X509BundleSet};
39use crate::endpoint::{get_default_socket_path, validate_socket_path};
40use crate::spiffe_id::{SpiffeId, TrustDomain};
41use crate::svid::jwt::JwtSvid;
42use crate::svid::x509::X509Svid;
43use crate::workload_api::x509_context::X509Context;
44use hyper_util::rt::TokioIo;
45use std::convert::TryFrom;
46use std::sync::Arc;
47use tokio::net::UnixStream;
48use tokio_stream::{Stream, StreamExt};
49
50use crate::constants::DEFAULT_SVID;
51use crate::error::GrpcClientError;
52use crate::workload_api::pb::spiffe_workload_api_client::SpiffeWorkloadApiClient;
53use crate::workload_api::pb::{
54 JwtBundlesRequest, JwtBundlesResponse, JwtsvidRequest, JwtsvidResponse, ValidateJwtsvidRequest,
55 ValidateJwtsvidResponse, X509BundlesRequest, X509BundlesResponse, X509svidRequest,
56 X509svidResponse,
57};
58use tonic::transport::{Endpoint, Uri};
59use tower::service_fn;
60
/// Metadata key inserted into every outgoing request by `MetadataAdder`.
const SPIFFE_HEADER_KEY: &str = "workload.spiffe.io";

/// Metadata value stored under `SPIFFE_HEADER_KEY`.
const SPIFFE_HEADER_VALUE: &str = "true";
63
64#[derive(Debug, Clone)]
70pub struct WorkloadApiClient {
71 socket_path: Arc<str>,
72 client: SpiffeWorkloadApiClient<
73 tonic::service::interceptor::InterceptedService<tonic::transport::Channel, MetadataAdder>,
74 >,
75}
76
/// Stateless request interceptor that stamps the SPIFFE security header
/// (`SPIFFE_HEADER_KEY` / `SPIFFE_HEADER_VALUE`) onto outgoing requests.
#[derive(Clone)]
struct MetadataAdder;
80
81impl tonic::service::Interceptor for MetadataAdder {
82 fn call(
83 &mut self,
84 mut request: tonic::Request<()>,
85 ) -> Result<tonic::Request<()>, tonic::Status> {
86 let parsed_header = SPIFFE_HEADER_VALUE
87 .parse()
88 .map_err(|e| tonic::Status::internal(format!("Failed to parse header: {e}")))?;
89 request
90 .metadata_mut()
91 .insert(SPIFFE_HEADER_KEY, parsed_header);
92 Ok(request)
93 }
94}
95
96impl WorkloadApiClient {
97 const UNIX_PREFIX: &'static str = "unix:";
98 const TONIC_DEFAULT_URI: &'static str = "http://[::]:50051";
99
100 pub fn socket_path(&self) -> &str {
102 &self.socket_path
103 }
104
105 async fn connect_channel(
106 socket_path: &str,
107 ) -> Result<tonic::transport::Channel, GrpcClientError> {
108 validate_socket_path(socket_path)?;
109
110 let stripped = socket_path
112 .strip_prefix(Self::UNIX_PREFIX)
113 .unwrap_or(socket_path)
114 .to_string();
115
116 let channel = Endpoint::try_from(Self::TONIC_DEFAULT_URI)?
117 .connect_with_connector(service_fn(move |_: Uri| {
118 let stripped = stripped.clone();
119 async { UnixStream::connect(stripped).await.map(TokioIo::new) }
120 }))
121 .await?;
122
123 Ok(channel)
124 }
125
126 pub async fn new_from_path(path: impl AsRef<str>) -> Result<Self, GrpcClientError> {
130 let path = path.as_ref();
131 validate_socket_path(path)?;
132
133 let socket_path: Arc<str> = Arc::from(path);
134 let channel = Self::connect_channel(path).await?;
135
136 Ok(WorkloadApiClient {
137 socket_path,
138 client: SpiffeWorkloadApiClient::with_interceptor(channel, MetadataAdder {}),
139 })
140 }
141
142 pub async fn reconnect(&mut self) -> Result<(), GrpcClientError> {
147 let channel = Self::connect_channel(&self.socket_path).await?;
148 self.client = SpiffeWorkloadApiClient::with_interceptor(channel, MetadataAdder {});
149 Ok(())
150 }
151
152 pub async fn default() -> Result<Self, GrpcClientError> {
154 let socket_path =
155 get_default_socket_path().ok_or(GrpcClientError::MissingEndpointSocketPath)?;
156 Self::new_from_path(socket_path.as_str()).await
157 }
158
159 pub fn new(
163 socket_path: impl AsRef<str>,
164 conn: tonic::transport::Channel,
165 ) -> Result<Self, GrpcClientError> {
166 Ok(WorkloadApiClient {
167 socket_path: Arc::from(socket_path.as_ref()),
168 client: SpiffeWorkloadApiClient::with_interceptor(conn, MetadataAdder {}),
169 })
170 }
171
172 pub async fn fetch_x509_svid(&mut self) -> Result<X509Svid, GrpcClientError> {
174 let request = X509svidRequest::default();
175
176 let grpc_stream_response: tonic::Response<tonic::Streaming<X509svidResponse>> =
177 self.client.fetch_x509svid(request).await?;
178
179 let response = grpc_stream_response
180 .into_inner()
181 .message()
182 .await?
183 .ok_or(GrpcClientError::EmptyResponse)?;
184 WorkloadApiClient::parse_x509_svid_from_grpc_response(response)
185 }
186
187 pub async fn fetch_all_x509_svids(&mut self) -> Result<Vec<X509Svid>, GrpcClientError> {
189 let request = X509svidRequest::default();
190
191 let grpc_stream_response: tonic::Response<tonic::Streaming<X509svidResponse>> =
192 self.client.fetch_x509svid(request).await?;
193
194 let response = grpc_stream_response
195 .into_inner()
196 .message()
197 .await?
198 .ok_or(GrpcClientError::EmptyResponse)?;
199 WorkloadApiClient::parse_x509_svids_from_grpc_response(response)
200 }
201
202 pub async fn fetch_x509_bundles(&mut self) -> Result<X509BundleSet, GrpcClientError> {
204 let request = X509BundlesRequest::default();
205
206 let grpc_stream_response: tonic::Response<tonic::Streaming<X509BundlesResponse>> =
207 self.client.fetch_x509_bundles(request).await?;
208
209 let response = grpc_stream_response
210 .into_inner()
211 .message()
212 .await?
213 .ok_or(GrpcClientError::EmptyResponse)?;
214 WorkloadApiClient::parse_x509_bundle_set_from_grpc_response(response)
215 }
216
217 pub async fn fetch_jwt_bundles(&mut self) -> Result<JwtBundleSet, GrpcClientError> {
219 let request = JwtBundlesRequest::default();
220
221 let grpc_stream_response: tonic::Response<tonic::Streaming<JwtBundlesResponse>> =
222 self.client.fetch_jwt_bundles(request).await?;
223
224 let response = grpc_stream_response
225 .into_inner()
226 .message()
227 .await?
228 .ok_or(GrpcClientError::EmptyResponse)?;
229 WorkloadApiClient::parse_jwt_bundle_set_from_grpc_response(response)
230 }
231
232 pub async fn fetch_x509_context(&mut self) -> Result<X509Context, GrpcClientError> {
234 let request = X509svidRequest::default();
235
236 let grpc_stream_response: tonic::Response<tonic::Streaming<X509svidResponse>> =
237 self.client.fetch_x509svid(request).await?;
238
239 let response = grpc_stream_response
240 .into_inner()
241 .message()
242 .await?
243 .ok_or(GrpcClientError::EmptyResponse)?;
244 WorkloadApiClient::parse_x509_context_from_grpc_response(response)
245 }
246
247 pub async fn fetch_jwt_svid<T: AsRef<str> + ToString>(
251 &mut self,
252 audience: &[T],
253 spiffe_id: Option<&SpiffeId>,
254 ) -> Result<JwtSvid, GrpcClientError> {
255 let response = self.fetch_jwt(audience, spiffe_id).await?;
256 response
257 .svids
258 .get(DEFAULT_SVID)
259 .ok_or(GrpcClientError::EmptyResponse)
260 .and_then(|r| JwtSvid::from_str(&r.svid).map_err(GrpcClientError::JwtSvid))
261 }
262
263 pub async fn fetch_jwt_token<T: AsRef<str> + ToString>(
267 &mut self,
268 audience: &[T],
269 spiffe_id: Option<&SpiffeId>,
270 ) -> Result<String, GrpcClientError> {
271 let response = self.fetch_jwt(audience, spiffe_id).await?;
272 response
273 .svids
274 .get(DEFAULT_SVID)
275 .map(|r| r.svid.to_string())
276 .ok_or(GrpcClientError::EmptyResponse)
277 }
278
279 pub async fn validate_jwt_token<T: AsRef<str> + ToString>(
281 &mut self,
282 audience: T,
283 jwt_token: &str,
284 ) -> Result<JwtSvid, GrpcClientError> {
285 let _ = self.validate_jwt(audience, jwt_token).await?;
289 let jwt_svid = JwtSvid::parse_insecure(jwt_token)?;
290 Ok(jwt_svid)
291 }
292
293 pub async fn stream_x509_contexts(
295 &mut self,
296 ) -> Result<impl Stream<Item = Result<X509Context, GrpcClientError>> + use<>, GrpcClientError>
297 {
298 let request = X509svidRequest::default();
299 let response = self.client.fetch_x509svid(request).await?;
300 let stream = response.into_inner().map(|message| {
301 message
302 .map_err(GrpcClientError::from)
303 .and_then(WorkloadApiClient::parse_x509_context_from_grpc_response)
304 });
305 Ok(stream)
306 }
307
308 pub async fn stream_x509_svids(
310 &mut self,
311 ) -> Result<impl Stream<Item = Result<X509Svid, GrpcClientError>> + use<>, GrpcClientError>
312 {
313 let request = X509svidRequest::default();
314 let response = self.client.fetch_x509svid(request).await?;
315 let stream = response.into_inner().map(|message| {
316 message
317 .map_err(GrpcClientError::from)
318 .and_then(WorkloadApiClient::parse_x509_svid_from_grpc_response)
319 });
320 Ok(stream)
321 }
322
323 pub async fn stream_x509_bundles(
325 &mut self,
326 ) -> Result<impl Stream<Item = Result<X509BundleSet, GrpcClientError>> + use<>, GrpcClientError>
327 {
328 let request = X509BundlesRequest::default();
329 let response = self.client.fetch_x509_bundles(request).await?;
330 let stream = response.into_inner().map(|message| {
331 message
332 .map_err(GrpcClientError::from)
333 .and_then(WorkloadApiClient::parse_x509_bundle_set_from_grpc_response)
334 });
335 Ok(stream)
336 }
337
338 pub async fn stream_jwt_bundles(
340 &mut self,
341 ) -> Result<impl Stream<Item = Result<JwtBundleSet, GrpcClientError>> + use<>, GrpcClientError>
342 {
343 let request = JwtBundlesRequest::default();
344 let response = self.client.fetch_jwt_bundles(request).await?;
345 let stream = response.into_inner().map(|message| {
346 message
347 .map_err(GrpcClientError::from)
348 .and_then(WorkloadApiClient::parse_jwt_bundle_set_from_grpc_response)
349 });
350 Ok(stream)
351 }
352}
353
354impl WorkloadApiClient {
356 async fn fetch_jwt<T: AsRef<str> + ToString>(
357 &mut self,
358 audience: &[T],
359 spiffe_id: Option<&SpiffeId>,
360 ) -> Result<JwtsvidResponse, GrpcClientError> {
361 let request = JwtsvidRequest {
362 spiffe_id: spiffe_id.map(ToString::to_string).unwrap_or_default(),
363 audience: audience.iter().map(|s| s.to_string()).collect(),
364 };
365
366 Ok(self.client.fetch_jwtsvid(request).await?.into_inner())
367 }
368
369 async fn validate_jwt<T: AsRef<str>>(
370 &mut self,
371 audience: T,
372 jwt_svid: &str,
373 ) -> Result<ValidateJwtsvidResponse, GrpcClientError> {
374 let request = ValidateJwtsvidRequest {
375 audience: audience.as_ref().into(),
376 svid: jwt_svid.into(),
377 };
378
379 Ok(self.client.validate_jwtsvid(request).await?.into_inner())
380 }
381
382 fn parse_x509_svid_from_grpc_response(
383 response: X509svidResponse,
384 ) -> Result<X509Svid, GrpcClientError> {
385 let svid = response
386 .svids
387 .get(DEFAULT_SVID)
388 .ok_or(GrpcClientError::EmptyResponse)?;
389
390 X509Svid::parse_from_der(svid.x509_svid.as_ref(), svid.x509_svid_key.as_ref())
391 .map_err(GrpcClientError::from)
392 }
393
394 fn parse_x509_svids_from_grpc_response(
395 response: X509svidResponse,
396 ) -> Result<Vec<X509Svid>, GrpcClientError> {
397 let mut svids_vec = Vec::new();
398
399 for svid in response.svids.iter() {
400 let parsed_svid =
401 X509Svid::parse_from_der(svid.x509_svid.as_ref(), svid.x509_svid_key.as_ref())
402 .map_err(GrpcClientError::from)?;
403
404 svids_vec.push(parsed_svid);
405 }
406
407 Ok(svids_vec)
408 }
409
410 fn parse_x509_bundle_set_from_grpc_response(
411 response: X509BundlesResponse,
412 ) -> Result<X509BundleSet, GrpcClientError> {
413 let bundles: Result<Vec<_>, _> = response
414 .bundles
415 .into_iter()
416 .map(|(td, bundle_data)| {
417 let trust_domain = TrustDomain::try_from(td)?;
418 X509Bundle::parse_from_der(trust_domain, &bundle_data)
419 .map_err(GrpcClientError::from)
420 })
421 .collect();
422
423 let mut bundle_set = X509BundleSet::new();
424 for bundle in bundles? {
425 bundle_set.add_bundle(bundle);
426 }
427
428 Ok(bundle_set)
429 }
430
431 fn parse_jwt_bundle_set_from_grpc_response(
432 response: JwtBundlesResponse,
433 ) -> Result<JwtBundleSet, GrpcClientError> {
434 let mut bundle_set = JwtBundleSet::new();
435
436 for (td, bundle_data) in response.bundles.into_iter() {
437 let trust_domain = TrustDomain::try_from(td)?;
438 let bundle = JwtBundle::from_jwt_authorities(trust_domain, &bundle_data)
439 .map_err(GrpcClientError::from)?;
440
441 bundle_set.add_bundle(bundle);
442 }
443
444 Ok(bundle_set)
445 }
446
447 fn parse_x509_context_from_grpc_response(
448 response: X509svidResponse,
449 ) -> Result<X509Context, GrpcClientError> {
450 let mut svids = Vec::new();
451 let mut bundle_set = X509BundleSet::new();
452
453 for svid in response.svids.into_iter() {
454 let x509_svid =
455 X509Svid::parse_from_der(svid.x509_svid.as_ref(), svid.x509_svid_key.as_ref())
456 .map_err(GrpcClientError::from)?;
457
458 let trust_domain = x509_svid.spiffe_id().trust_domain().clone();
459 svids.push(x509_svid);
460
461 let bundle = X509Bundle::parse_from_der(trust_domain, svid.bundle.as_ref())
462 .map_err(GrpcClientError::from)?;
463 bundle_set.add_bundle(bundle);
464 }
465
466 for (trust_domain, bundle) in response.federated_bundles.into_iter() {
467 let trust_domain = TrustDomain::try_from(trust_domain)?;
468 let x509_bundle = X509Bundle::parse_from_der(trust_domain, bundle.as_ref())
469 .map_err(GrpcClientError::from)?;
470 bundle_set.add_bundle(x509_bundle);
471 }
472
473 Ok(X509Context::new(svids, bundle_set))
474 }
475}