1pub mod auth;
2pub mod aws;
3pub mod database;
4pub mod identity;
5pub mod kv1;
6pub mod kv2;
7pub mod pki;
8pub mod ssh;
9pub mod sys;
10pub mod token;
11pub mod transit;
12
13use std::collections::HashMap;
14use std::str::FromStr;
15
16use async_trait::async_trait;
17use rustify::endpoint::{Endpoint, MiddleWare};
18use rustify::errors::ClientError as RestClientError;
19use serde::{de::DeserializeOwned, Deserialize};
20
21use crate::sys::wrapping;
22use crate::{client::Client, error::ClientError};
23
24use self::sys::responses::WrappingLookupResponse;
25
26#[derive(Deserialize, Debug)]
35pub struct EndpointResult<T> {
36 pub data: Option<T>,
37 pub auth: Option<AuthInfo>,
38 pub lease_id: String,
39 pub lease_duration: u32,
40 pub renewable: bool,
41 pub request_id: String,
42 pub warnings: Option<Vec<String>>,
43 pub wrap_info: Option<WrapInfo>,
44}
45
46impl<T: DeserializeOwned + Send + Sync> rustify::endpoint::Wrapper for EndpointResult<T> {
47 type Value = T;
48}
49
50#[derive(Deserialize, Debug)]
52pub struct WrapInfo {
53 pub token: String,
54 pub accessor: String,
55 pub ttl: u64,
56 pub creation_time: String,
57 pub creation_path: String,
58}
59
60#[derive(Deserialize, Debug)]
62pub struct AuthInfo {
63 pub client_token: String,
64 pub accessor: String,
65 pub policies: Vec<String>,
66 pub token_policies: Vec<String>,
67 pub metadata: Option<HashMap<String, String>>,
68 pub lease_duration: u64,
69 pub renewable: bool,
70 pub entity_id: String,
71 pub token_type: String,
72 pub orphan: bool,
73}
74
75pub struct WrappedResponse<E: Endpoint> {
83 pub info: WrapInfo,
84 pub endpoint: rustify::endpoint::EndpointResult<E::Response>,
85}
86
87impl<E: Endpoint> WrappedResponse<E> {
88 pub async fn lookup(
90 &self,
91 client: &impl Client,
92 ) -> Result<WrappingLookupResponse, ClientError> {
93 debug!("Looking up wrapped response information");
94 wrapping::lookup(client, self.info.token.as_str())
95 .await
96 .map_err(|e| match &e {
97 ClientError::APIError {
98 code: 400,
99 errors: _,
100 } => ClientError::WrapInvalidError,
101 _ => e,
102 })
103 }
104
105 pub async fn unwrap(&self, client: &impl Client) -> Result<E::Response, ClientError> {
107 wrapping::unwrap(client, Some(self.info.token.as_str())).await
108 }
109}
110
111#[async_trait]
113pub trait ResponseWrapper: Endpoint {
114 async fn wrap(self, client: &impl Client) -> Result<WrappedResponse<Self>, ClientError> {
115 wrap(client, self).await
116 }
117}
118
119impl<E: Endpoint> ResponseWrapper for E {}
120
121#[derive(Deserialize, Debug)]
126pub struct EndpointError {
127 pub errors: Vec<String>,
128}
129
/// Middleware applied to every request sent to Vault.
///
/// Each field controls a piece of the request rewriting done in the
/// [`MiddleWare`] implementation (URL prefix and `X-Vault-*` headers).
#[derive(Clone)]
pub struct EndpointMiddleware {
    /// Token sent in the `X-Vault-Token` header; the header is omitted when empty.
    pub token: String,
    /// Version segment prepended to every request path (for example `v1`).
    pub version: String,
    /// Optional response-wrapping TTL sent in the `X-Vault-Wrap-TTL` header.
    pub wrap: Option<String>,
    /// Optional namespace sent in the `X-Vault-Namespace` header.
    pub namespace: Option<String>,
}

/// Formats the middleware without revealing the token, so that logging a
/// client's configuration with `{:?}` never leaks the credential.
impl std::fmt::Debug for EndpointMiddleware {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show whether a token is configured, never its value.
        let token = if self.token.is_empty() { "" } else { "<redacted>" };
        f.debug_struct("EndpointMiddleware")
            .field("token", &token)
            .field("version", &self.version)
            .field("wrap", &self.wrap)
            .field("namespace", &self.namespace)
            .finish()
    }
}
143impl MiddleWare for EndpointMiddleware {
144 fn request<E: Endpoint>(
145 &self,
146 _: &E,
147 req: &mut http::Request<Vec<u8>>,
148 ) -> Result<(), rustify::errors::ClientError> {
149 trace!(
151 "Middleware: prepending {} version to URL",
152 self.version.as_str()
153 );
154 let url = url::Url::parse(req.uri().to_string().as_str()).unwrap();
155 let mut url_c = url.clone();
156 let mut segs: Vec<&str> = url.path_segments().unwrap().collect();
157 segs.insert(0, self.version.as_str());
158 url_c.set_path(format!("{}{}", self.version, url_c.path()).as_str());
159 *req.uri_mut() = http::Uri::from_str(url_c.as_str()).unwrap();
160 trace!("Middleware: final URL is {}", url_c.as_str());
161
162 req.headers_mut().append(
164 "X-Vault-Request",
165 http::HeaderValue::from_str("true").unwrap(),
166 );
167
168 if !self.token.is_empty() {
170 trace!("Middleware: adding token to header");
171 req.headers_mut().append(
172 "X-Vault-Token",
173 http::HeaderValue::from_str(self.token.as_str()).unwrap(),
174 );
175 }
176
177 if let Some(wrap) = &self.wrap {
179 trace!("Middleware: adding wrap header with {} ttl", wrap);
180 req.headers_mut().append(
181 "X-Vault-Wrap-TTL",
182 http::HeaderValue::from_str(wrap.as_str()).unwrap(),
183 );
184 }
185
186 if let Some(namespace) = &self.namespace {
188 trace!("Middleware: adding namespace header {}", namespace);
189 req.headers_mut().append(
190 "X-Vault-Namespace",
191 http::HeaderValue::from_str(namespace.as_str()).unwrap(),
192 );
193 }
194
195 Ok(())
196 }
197
198 fn response<E: Endpoint>(
199 &self,
200 _: &E,
201 _: &mut http::Response<Vec<u8>>,
202 ) -> Result<(), rustify::errors::ClientError> {
203 Ok(())
204 }
205}
206
207#[instrument(name = "request", skip_all, fields(method = ?endpoint.method(), path = %endpoint.path()), err)]
212pub async fn exec_with_empty<E>(client: &impl Client, endpoint: E) -> Result<(), ClientError>
213where
214 E: Endpoint,
215{
216 trace!("start request");
217 endpoint
218 .with_middleware(client.middle())
219 .exec(client.http())
220 .await
221 .map_err(parse_err)
222 .map(|_| ())
223}
224
225#[instrument(name = "request", skip_all, fields(method = ?endpoint.method(), path = %endpoint.path()), err)]
230pub async fn exec_with_empty_result<E>(client: &impl Client, endpoint: E) -> Result<(), ClientError>
231where
232 E: Endpoint,
233{
234 trace!("start request");
235 endpoint
236 .with_middleware(client.middle())
237 .exec(client.http())
238 .await
239 .map_err(ClientError::from)?
240 .wrap::<EndpointResult<_>>()
241 .map_err(parse_err)
242 .map(strip)
243 .map(|_| ())
244}
245
246#[instrument(name = "request", skip_all, fields(method = ?endpoint.method(), path = %endpoint.path()), err)]
251pub async fn exec_with_no_result<E>(
252 client: &impl Client,
253 endpoint: E,
254) -> Result<E::Response, ClientError>
255where
256 E: Endpoint,
257{
258 trace!("start request");
259 endpoint
260 .with_middleware(client.middle())
261 .exec(client.http())
262 .await
263 .map_err(parse_err)?
264 .parse()
265 .map_err(ClientError::from)
266}
267
268#[instrument(name = "request", skip_all, fields(method = ?endpoint.method(), path = %endpoint.path()), err)]
285pub async fn exec_with_result<E>(
286 client: &impl Client,
287 endpoint: E,
288) -> Result<E::Response, ClientError>
289where
290 E: Endpoint,
291{
292 trace!("start request");
293 endpoint
294 .with_middleware(client.middle())
295 .exec(client.http())
296 .await
297 .map_err(parse_err)?
298 .wrap::<EndpointResult<_>>()
299 .map_err(ClientError::from)
300 .map(strip)?
301 .ok_or(ClientError::ResponseDataEmptyError)
302}
303
304pub async fn wrap<E>(client: &impl Client, endpoint: E) -> Result<WrappedResponse<E>, ClientError>
310where
311 E: Endpoint,
312{
313 trace!(
314 "Executing {} and returning a wrapped response",
315 endpoint.path()
316 );
317 let mut m = client.middle().clone();
318 m.wrap = Some("10m".to_string());
319 let resp = endpoint
320 .with_middleware(&m)
321 .exec(client.http())
322 .await
323 .map_err(parse_err)?;
324 let info = resp
325 .wrap::<EndpointResult<_>>()
326 .map_err(ClientError::from)
327 .map(strip_wrap)??;
328 Ok(WrappedResponse {
329 info,
330 endpoint: resp,
331 })
332}
333
334pub async fn auth<E>(client: &impl Client, endpoint: E) -> Result<AuthInfo, ClientError>
335where
336 E: Endpoint<Response = ()>,
337{
338 trace!(
339 "Executing {} and returning authentication info",
340 endpoint.path()
341 );
342 let r: EndpointResult<()> = endpoint
343 .with_middleware(client.middle())
344 .exec(client.http())
345 .await
346 .map_err(parse_err)?
347 .wrap::<EndpointResult<_>>()
348 .map_err(ClientError::from)?;
349 r.auth.ok_or(ClientError::ResponseEmptyError)
350}
351
352fn strip_wrap<T>(result: EndpointResult<T>) -> Result<WrapInfo, ClientError> {
355 trace!("Stripping wrap info from API response");
356 if let Some(w) = &result.warnings {
357 if !w.is_empty() {
358 warn!("Server returned warnings with response: {:#?}", w);
359 }
360 }
361 result.wrap_info.ok_or(ClientError::ResponseWrapError {})
362}
363
364fn strip<T>(result: EndpointResult<T>) -> Option<T>
366where
367 T: DeserializeOwned,
368{
369 trace!("Stripping response wrapper from API response");
370 if let Some(w) = &result.warnings {
371 if !w.is_empty() {
372 warn!("Detected warnings in API response: {:#?}", w);
373 }
374 }
375 result.data
376}
377
378fn parse_err(e: RestClientError) -> ClientError {
383 if let RestClientError::ServerResponseError { code, content } = &e {
384 match content {
385 Some(c) => {
386 let errs: Result<EndpointError, _> = serde_json::from_str(c.as_str());
387 match errs {
388 Ok(err) => {
389 if !err.errors.is_empty() {
390 error!("Detected errors in API response: {:#?}", err.errors);
391 }
392 ClientError::APIError {
393 code: *code,
394 errors: err.errors,
395 }
396 }
397 Err(_) => ClientError::from(e),
398 }
399 }
400 None => ClientError::from(e),
401 }
402 } else {
403 ClientError::from(e)
404 }
405}