use std::path::PathBuf;
use std::sync::{Arc, OnceLock};
use async_nats::{HeaderMap, Message};
use error::{ClientError, SerializationError};
use futures::Stream;
use topics::TopicGenerator;
use wadm_types::{
api::{
DeleteModelRequest, DeleteModelResponse, DeleteResult, DeployModelRequest,
DeployModelResponse, DeployResult, GetModelRequest, GetModelResponse, GetResult,
ModelSummary, PutModelResponse, PutResult, Status, StatusResponse, StatusResult,
VersionInfo, VersionResponse,
},
Manifest,
};
mod nats;
pub mod error;
pub use error::Result;
pub mod loader;
pub use loader::ManifestLoader;
pub mod topics;
/// Lazily-initialized header map holding a single `Content-Type: application/json` entry.
static HEADERS_CONTENT_TYPE_JSON: OnceLock<HeaderMap> = OnceLock::new();

/// Returns the process-wide JSON content-type headers, building them on first use.
///
/// Callers that need an owned map (e.g. to hand to a NATS request) clone the result.
fn get_headers_content_type_json() -> &'static HeaderMap {
    /// Constructs the header map stored in the static on first access.
    fn build_json_headers() -> HeaderMap {
        let mut map = HeaderMap::new();
        map.insert("Content-Type", "application/json");
        map
    }

    HEADERS_CONTENT_TYPE_JSON.get_or_init(build_json_headers)
}
/// A client for talking to the wadm API over NATS.
///
/// The client pairs a NATS connection with a [`TopicGenerator`] that produces the subject
/// for each API operation. The topic generator is held behind an [`Arc`], so cloning does
/// not copy it.
#[derive(Clone)]
pub struct Client {
    /// Generator for the subjects used by each request.
    topics: Arc<TopicGenerator>,
    /// Underlying NATS connection used for requests and subscriptions.
    client: async_nats::Client,
}
/// Options used by [`Client::new`] to establish the NATS connection.
///
/// Every field is optional and is forwarded unchanged to the internal `nats::get_client`
/// helper; how a `None` value is handled is decided there.
#[derive(Default, Clone)]
pub struct ClientConnectOptions {
    /// Address of the NATS server to connect to.
    pub url: Option<String>,
    /// Presumably an nkey seed used for authentication — TODO confirm against `nats::get_client`.
    pub seed: Option<String>,
    /// Presumably a user JWT used together with `seed` — TODO confirm against `nats::get_client`.
    pub jwt: Option<String>,
    /// Path to a credentials file (presumably a NATS `.creds` file — confirm).
    pub creds_path: Option<PathBuf>,
    /// Path to a CA certificate (presumably for TLS verification — confirm).
    pub ca_path: Option<PathBuf>,
}
impl Client {
pub async fn new(
lattice: &str,
prefix: Option<&str>,
opts: ClientConnectOptions,
) -> anyhow::Result<Self> {
let topics = TopicGenerator::new(lattice, prefix);
let nats_client =
nats::get_client(opts.url, opts.seed, opts.jwt, opts.creds_path, opts.ca_path).await?;
Ok(Client {
topics: Arc::new(topics),
client: nats_client,
})
}
#[doc(hidden)]
pub fn from_nats_client(
lattice: &str,
prefix: Option<&str>,
nats_client: async_nats::Client,
) -> Self {
let topics = TopicGenerator::new(lattice, prefix);
Client {
topics: Arc::new(topics),
client: nats_client,
}
}
pub async fn put_manifest(&self, manifest: impl ManifestLoader) -> Result<(String, String)> {
let manifest = manifest.load_manifest().await?;
let manifest_bytes = serde_json::to_vec(&manifest).map_err(SerializationError::from)?;
let topic = self.topics.model_put_topic();
let resp = self
.client
.request_with_headers(
topic,
get_headers_content_type_json().clone(),
manifest_bytes.into(),
)
.await?;
let body: PutModelResponse =
serde_json::from_slice(&resp.payload).map_err(SerializationError::from)?;
if matches!(body.result, PutResult::Error) {
return Err(ClientError::ApiError(body.message));
}
Ok((body.name, body.current_version))
}
/// Fetches the summaries of all manifests known to the server.
///
/// Sends an empty request on the list subject and decodes the reply as a JSON array of
/// [`ModelSummary`].
///
/// # Errors
///
/// Fails if the NATS request fails or the reply cannot be deserialized.
pub async fn list_manifests(&self) -> Result<Vec<ModelSummary>> {
    let subject = self.topics.model_list_topic();
    let reply = self
        .client
        .request(subject, Vec::<u8>::new().into())
        .await?;
    Ok(serde_json::from_slice::<Vec<ModelSummary>>(&reply.payload)
        .map_err(SerializationError::from)?)
}
pub async fn get_manifest(&self, name: &str, version: Option<&str>) -> Result<Manifest> {
let topic = self.topics.model_get_topic(name);
let body = if let Some(version) = version {
serde_json::to_vec(&GetModelRequest {
version: Some(version.to_string()),
})
.map_err(SerializationError::from)?
} else {
Vec::with_capacity(0)
};
let resp = self.client.request(topic, body.into()).await?;
let body: GetModelResponse =
serde_json::from_slice(&resp.payload).map_err(SerializationError::from)?;
match body.result {
GetResult::Error => Err(ClientError::ApiError(body.message)),
GetResult::NotFound => Err(ClientError::NotFound(name.to_string())),
GetResult::Success => body.manifest.ok_or_else(|| {
ClientError::ApiError("API returned success but didn't set a manifest".to_string())
}),
}
}
pub async fn delete_manifest(&self, name: &str, version: Option<&str>) -> Result<bool> {
let topic = self.topics.model_delete_topic(name);
let body = if let Some(version) = version {
serde_json::to_vec(&DeleteModelRequest {
version: Some(version.to_string()),
})
.map_err(SerializationError::from)?
} else {
Vec::with_capacity(0)
};
let resp = self.client.request(topic, body.into()).await?;
let body: DeleteModelResponse =
serde_json::from_slice(&resp.payload).map_err(SerializationError::from)?;
match body.result {
DeleteResult::Error => Err(ClientError::ApiError(body.message)),
DeleteResult::Noop => Ok(false),
DeleteResult::Deleted => Ok(true),
}
}
pub async fn list_versions(&self, name: &str) -> Result<Vec<VersionInfo>> {
let topic = self.topics.model_versions_topic(name);
let resp = self
.client
.request(topic, Vec::with_capacity(0).into())
.await?;
let body: VersionResponse =
serde_json::from_slice(&resp.payload).map_err(SerializationError::from)?;
match body.result {
GetResult::Error => Err(ClientError::ApiError(body.message)),
GetResult::NotFound => Err(ClientError::NotFound(name.to_string())),
GetResult::Success => Ok(body.versions),
}
}
pub async fn deploy_manifest(
&self,
name: &str,
version: Option<&str>,
) -> Result<(String, Option<String>)> {
let topic = self.topics.model_deploy_topic(name);
let body = if let Some(version) = version {
serde_json::to_vec(&DeployModelRequest {
version: Some(version.to_string()),
})
.map_err(SerializationError::from)?
} else {
Vec::with_capacity(0)
};
let resp = self.client.request(topic, body.into()).await?;
let body: DeployModelResponse =
serde_json::from_slice(&resp.payload).map_err(SerializationError::from)?;
match body.result {
DeployResult::Error => Err(ClientError::ApiError(body.message)),
DeployResult::NotFound => Err(ClientError::NotFound(name.to_string())),
DeployResult::Acknowledged => Ok((body.name, body.version)),
}
}
/// Stores a manifest and then deploys exactly the version that was just stored.
///
/// This is [`Client::put_manifest`] followed by [`Client::deploy_manifest`] with the
/// returned name and version. Returns the `(name, version)` pair from the put step.
///
/// # Errors
///
/// Propagates any error from either step. If the deploy step fails, the manifest has
/// already been stored.
pub async fn put_and_deploy_manifest(
    &self,
    manifest: impl ManifestLoader,
) -> Result<(String, String)> {
    let stored = self.put_manifest(manifest).await?;
    self.deploy_manifest(stored.0.as_str(), Some(stored.1.as_str()))
        .await?;
    Ok(stored)
}
pub async fn undeploy_manifest(&self, name: &str) -> Result<String> {
let topic = self.topics.model_undeploy_topic(name);
let resp = self
.client
.request(topic, Vec::with_capacity(0).into())
.await?;
let body: DeployModelResponse =
serde_json::from_slice(&resp.payload).map_err(SerializationError::from)?;
match body.result {
DeployResult::Error => Err(ClientError::ApiError(body.message)),
DeployResult::NotFound => Err(ClientError::NotFound(name.to_string())),
DeployResult::Acknowledged => Ok(body.name),
}
}
pub async fn get_manifest_status(&self, name: &str) -> Result<Status> {
let topic = self.topics.model_status_topic(name);
let resp = self
.client
.request(topic, Vec::with_capacity(0).into())
.await?;
let body: StatusResponse =
serde_json::from_slice(&resp.payload).map_err(SerializationError::from)?;
match body.result {
StatusResult::Error => Err(ClientError::ApiError(body.message)),
StatusResult::NotFound => Err(ClientError::NotFound(name.to_string())),
StatusResult::Ok => body.status.ok_or_else(|| {
ClientError::ApiError("API returned success but didn't set a status".to_string())
}),
}
}
/// Subscribes to the status subject for the manifest called `name`.
///
/// Returns a stream of raw NATS [`Message`]s; decoding the payloads is left to the caller.
///
/// # Errors
///
/// Returns [`ClientError::ApiError`] carrying the subscribe error's text if the
/// subscription cannot be created.
pub async fn subscribe_to_status(&self, name: &str) -> Result<impl Stream<Item = Message>> {
    let subject = self.topics.wadm_status_topic(name);
    match self.client.subscribe(subject).await {
        Ok(stream) => Ok(stream),
        Err(err) => Err(ClientError::ApiError(err.to_string())),
    }
}
}