use crate::api::AuthInfo;
use crate::api::{token::responses::LookupTokenResponse, EndpointMiddleware};
use crate::error::ClientError;
use async_trait::async_trait;
pub use reqwest::Identity;
use rustify::clients::reqwest::Client as HTTPClient;
use std::time::Duration;
use std::{env, fs};
use url::Url;
/// URL schemes accepted for the Vault server address; checked by
/// `VaultClientSettingsBuilder::validate_url` for both explicit and default addresses.
const VALID_SCHEMES: [&str; 2] = ["http", "https"];
/// Common interface for Vault clients.
///
/// Implementors expose their HTTP client, endpoint middleware and settings, and can
/// replace their token. The async helper methods have default implementations that
/// delegate to free functions in `crate::token` and `crate::sys`, passing the client
/// itself.
#[async_trait]
pub trait Client: Send + Sync + Sized {
    /// Returns the HTTP client this Vault client sends requests through.
    fn http(&self) -> &HTTPClient;
    /// Returns the [`EndpointMiddleware`] associated with this client.
    fn middle(&self) -> &EndpointMiddleware;
    /// Returns the settings this client was configured with.
    fn settings(&self) -> &VaultClientSettings;
    /// Sets the token this client authenticates with.
    fn set_token(&mut self, token: &str);
    /// Delegates to `crate::token::lookup_self` with this client.
    async fn lookup(&self) -> Result<LookupTokenResponse, ClientError> {
        crate::token::lookup_self(self).await
    }
    /// Delegates to `crate::token::renew_self`, forwarding `increment` unchanged.
    /// NOTE(review): `increment` is presumably a Vault duration string such as
    /// `"1h"` — confirm in `crate::token`.
    async fn renew(&self, increment: Option<&str>) -> Result<AuthInfo, ClientError> {
        crate::token::renew_self(self, increment).await
    }
    /// Delegates to `crate::token::revoke_self` with this client.
    async fn revoke(&self) -> Result<(), ClientError> {
        crate::token::revoke_self(self).await
    }
    /// Delegates to `crate::sys::status` and returns the reported server status.
    async fn status(&self) -> Result<crate::sys::ServerStatus, ClientError> {
        crate::sys::status(self).await
    }
}
/// Default [`Client`] implementation. Create it with `VaultClient::new`.
pub struct VaultClient {
    /// HTTP client bound to the configured Vault address.
    pub http: HTTPClient,
    /// Middleware holding the token, `v{version}` API prefix and namespace
    /// (filled in by `VaultClient::new`).
    pub middle: EndpointMiddleware,
    /// Settings the client was created from.
    pub settings: VaultClientSettings,
}
#[async_trait]
impl Client for VaultClient {
    /// Borrows the HTTP client owned by this instance.
    fn http(&self) -> &HTTPClient {
        &self.http
    }

    /// Borrows the endpoint middleware owned by this instance.
    fn middle(&self) -> &EndpointMiddleware {
        &self.middle
    }

    /// Borrows the settings owned by this instance.
    fn settings(&self) -> &VaultClientSettings {
        &self.settings
    }

    /// Stores `token` in both the middleware and the settings so the two copies stay
    /// in sync.
    fn set_token(&mut self, token: &str) {
        let owned = token.to_owned();
        self.middle.token.clone_from(&owned);
        self.settings.token = owned;
    }
}
impl VaultClient {
    /// Creates a [`VaultClient`] from `settings`.
    ///
    /// Configures the underlying `reqwest` client (optional timeout, TLS verification,
    /// extra CA roots, optional client identity), prepares the endpoint middleware
    /// (token, `v{version}` API prefix, namespace), and binds the HTTP client to
    /// `settings.address`.
    ///
    /// # Errors
    ///
    /// * [`ClientError::FileReadError`] if a CA certificate file cannot be read.
    /// * [`ClientError::ParseCertificateError`] if a CA certificate file is not valid PEM.
    /// * [`ClientError::RestClientBuildError`] if the `reqwest` client cannot be built.
    #[instrument(skip(settings), err)]
    pub fn new(settings: VaultClientSettings) -> Result<VaultClient, ClientError> {
        let mut builder = reqwest::ClientBuilder::new();
        if let Some(timeout) = settings.timeout {
            builder = builder.timeout(timeout);
        }

        let accept_invalid_certs = !settings.verify;
        if accept_invalid_certs {
            event!(tracing::Level::WARN, "Disabling TLS verification");
        }
        builder = builder.danger_accept_invalid_certs(accept_invalid_certs);

        for path in settings.ca_certs.iter() {
            let certificate = Self::load_ca_certificate(path)?;
            builder = builder.add_root_certificate(certificate);
        }

        if let Some(identity) = settings.identity.as_ref() {
            builder = builder.identity(identity.clone());
        }

        debug!("Using API version {}", settings.version);
        let middle = EndpointMiddleware {
            token: settings.token.clone(),
            version: format!("v{}", settings.version),
            wrap: None,
            namespace: settings.namespace.clone(),
        };

        let reqwest_client = builder
            .build()
            .map_err(|source| ClientError::RestClientBuildError { source })?;
        let http = HTTPClient::new(settings.address.as_str(), reqwest_client);

        Ok(VaultClient {
            settings,
            middle,
            http,
        })
    }

    /// Reads the PEM file at `path` and parses it into a root certificate, logging
    /// the import once parsing succeeds.
    fn load_ca_certificate(path: &str) -> Result<reqwest::Certificate, ClientError> {
        let pem = std::fs::read(path).map_err(|source| ClientError::FileReadError {
            source,
            path: path.to_string(),
        })?;
        let certificate = reqwest::Certificate::from_pem(&pem).map_err(|source| {
            ClientError::ParseCertificateError {
                source,
                path: path.to_string(),
            }
        })?;
        info!("Importing CA certificate from {}", path);
        Ok(certificate)
    }
}
/// Connection settings for a [`VaultClient`].
///
/// Built with the generated `VaultClientSettingsBuilder`; fields left unset fall back
/// to the defaults noted on each field, several of which are read from `VAULT_*`
/// environment variables. `build()` runs `VaultClientSettingsBuilder::validate`.
///
/// NOTE(review): the derived `Debug` output includes `token` in plain text — avoid
/// logging this struct with `{:?}`.
#[derive(Builder, Clone, Debug)]
#[builder(build_fn(validate = "Self::validate"))]
pub struct VaultClientSettings {
    /// Vault server URL. Set through the custom `address` setter; defaults to
    /// `$VAULT_ADDR`, or `http://127.0.0.1:8200`. Must use the `http`/`https` scheme.
    #[builder(setter(custom), default = "self.default_address()?")]
    pub address: Url,
    /// Paths of PEM files added as extra root certificates. Defaults to
    /// `$VAULT_CACERT` plus the entries of the `$VAULT_CAPATH` directory.
    #[builder(default = "self.default_ca_certs()")]
    pub ca_certs: Vec<String>,
    /// Client TLS identity. Defaults to one built from the files named by
    /// `$VAULT_CLIENT_CERT` and `$VAULT_CLIENT_KEY`, when both are set.
    #[builder(default = "self.default_identity()")]
    pub identity: Option<Identity>,
    /// Optional request timeout passed to the HTTP client; defaults to `None`.
    #[builder(default)]
    pub timeout: Option<Duration>,
    /// Vault token. Defaults to `$VAULT_TOKEN`, or an empty string when unset.
    #[builder(setter(into), default = "self.default_token()")]
    pub token: String,
    /// Whether server certificates are verified (`false` disables verification).
    /// Default is derived from `$VAULT_SKIP_VERIFY`, and is `true` when it is unset.
    #[builder(default = "self.default_verify()")]
    pub verify: bool,
    /// API version number; `VaultClient::new` turns it into the `v{version}` prefix.
    /// Defaults to 1.
    #[builder(setter(into, strip_option), default = "1")]
    pub version: u8,
    /// Wrapping flag; defaults to `false`.
    /// NOTE(review): not read by `VaultClient::new` (which sets `wrap: None`) — confirm
    /// where this is consumed.
    #[builder(default = "false")]
    pub wrapping: bool,
    /// Vault namespace copied into the endpoint middleware; defaults to `None`.
    #[builder(default)]
    pub namespace: Option<String>,
}
impl VaultClientSettingsBuilder {
pub fn address<T>(&mut self, address: T) -> &mut Self
where
T: AsRef<str>,
{
let url = Url::parse(address.as_ref())
.map_err(|_| format!("Invalid URL format: {}", address.as_ref()))
.unwrap();
self.address = Some(url);
self
}
pub fn set_namespace(&mut self, str: String) -> &mut Self {
self.namespace = Some(Some(str));
self
}
fn default_address(&self) -> Result<Url, String> {
let address = if let Ok(address) = env::var("VAULT_ADDR") {
info!("Using vault address from $VAULT_ADDR: {address}");
address
} else {
info!("Using default vault address http://127.0.0.1:8200");
String::from("http://127.0.0.1:8200")
};
let url = Url::parse(&address);
let url = url.map_err(|_| format!("Invalid URL format: {}", &address))?;
self.validate_url(&url)?;
Ok(url)
}
fn default_token(&self) -> String {
match env::var("VAULT_TOKEN") {
Ok(s) => {
info!("Using vault token from $VAULT_TOKEN");
s
}
Err(_) => {
info!("Using default empty vault token");
String::from("")
}
}
}
fn default_verify(&self) -> bool {
info!("Checking TLS verification using $VAULT_SKIP_VERIFY");
match env::var("VAULT_SKIP_VERIFY") {
Ok(value) => !matches!(value.to_lowercase().as_str(), "0" | "f" | "false"),
Err(_) => true,
}
}
fn default_ca_certs(&self) -> Vec<String> {
let mut paths: Vec<String> = Vec::new();
if let Ok(s) = env::var("VAULT_CACERT") {
info!("Found CA certificate in $VAULT_CACERT");
paths.push(s);
}
if let Ok(s) = env::var("VAULT_CAPATH") {
info!("Found CA certificate path in $VAULT_CAPATH");
if let Ok(p) = fs::read_dir(s) {
for path in p {
paths.push(path.unwrap().path().to_str().unwrap().to_string())
}
}
}
paths
}
fn default_identity(&self) -> Option<reqwest::Identity> {
let env_client_cert = env::var("VAULT_CLIENT_CERT").unwrap_or_default();
let env_client_key = env::var("VAULT_CLIENT_KEY").unwrap_or_default();
if env_client_cert.is_empty() || env_client_key.is_empty() {
debug!("No client certificate (env VAULT_CLIENT_CERT & VAULT_CLIENT_KEY are not set)");
return None;
}
#[cfg(feature = "rustls")]
{
let mut client_cert = match fs::read(&env_client_cert) {
Ok(content) => content,
Err(err) => {
error!("error reading client cert '{}': {}", env_client_cert, err);
return None;
}
};
let mut client_key = match fs::read(&env_client_key) {
Ok(content) => content,
Err(err) => {
error!("error reading client key '{}': {}", env_client_key, err);
return None;
}
};
client_cert.append(&mut client_key);
match reqwest::Identity::from_pem(&client_cert) {
Ok(pkcs8) => return Some(pkcs8),
Err(err) => error!("error creating identity: {}", err),
};
}
#[cfg(feature = "native-tls")]
{
error!("Client certificates not implemented for native-tls");
}
None
}
fn validate(&self) -> Result<(), String> {
if let Some(url) = &self.address {
self.validate_url(url)
} else {
Ok(())
}
}
fn validate_url(&self, url: &Url) -> Result<(), String> {
if !VALID_SCHEMES.contains(&url.scheme()) {
Err(format!("Invalid scheme for HTTP URL: {}", url.scheme()))
} else {
Ok(())
}
}
}