wadm_client/
lib.rs

1//! A client for interacting with Wadm.
2use std::path::PathBuf;
3use std::sync::{Arc, OnceLock};
4
5use async_nats::{HeaderMap, Message};
6use error::{ClientError, SerializationError};
7use futures::Stream;
8use topics::TopicGenerator;
9use wadm_types::{
10    api::{
11        DeleteModelRequest, DeleteModelResponse, DeleteResult, DeployModelRequest,
12        DeployModelResponse, DeployResult, GetModelRequest, GetModelResponse, GetResult,
13        ModelSummary, PutModelResponse, PutResult, Status, StatusResponse, StatusResult,
14        VersionInfo, VersionResponse,
15    },
16    Manifest,
17};
18
19mod nats;
20
21pub mod error;
22pub use error::Result;
23pub mod loader;
24pub use loader::ManifestLoader;
25pub mod topics;
26
27/// Headers for `Content-Type: application/json`
28static HEADERS_CONTENT_TYPE_JSON: OnceLock<HeaderMap> = OnceLock::new();
29/// Retrieve static content type headers
30fn get_headers_content_type_json() -> &'static HeaderMap {
31    HEADERS_CONTENT_TYPE_JSON.get_or_init(|| {
32        let mut headers = HeaderMap::new();
33        headers.insert("Content-Type", "application/json");
34        headers
35    })
36}
37
38#[derive(Clone)]
39pub struct Client {
40    topics: Arc<TopicGenerator>,
41    client: async_nats::Client,
42}
43
/// Options for connecting a Wadm client to a NATS server.
///
/// Every field is optional. Leaving all of them unset (e.g. via [`Default`]) connects anonymously
/// to a NATS server on localhost port 4222.
#[derive(Default, Clone)]
pub struct ClientConnectOptions {
    /// URL of the NATS server. When unset, the client connects to the default NATS address of
    /// 127.0.0.1:4222
    pub url: Option<String>,
    /// nkey seed used to authenticate with the NATS server, given either as the raw seed or as a
    /// path to a file containing it. Must be paired with `jwt`
    pub seed: Option<String>,
    /// JWT used to authenticate with the NATS server, given either as the raw JWT or as a path to a
    /// file containing it. Must be paired with `seed`
    pub jwt: Option<String>,
    /// Path to a credentials file used to authenticate with the NATS server. Mutually exclusive
    /// with `seed` and `jwt`
    pub creds_path: Option<PathBuf>,
    /// Optional path to a file of root CA certificates used when authenticating with the NATS
    /// server
    pub ca_path: Option<PathBuf>,
}
64
65impl Client {
66    /// Creates a new client with the given lattice ID, optional API prefix, and connection options.
67    /// Errors if it is unable to connect to the NATS server
68    pub async fn new(
69        lattice: &str,
70        prefix: Option<&str>,
71        opts: ClientConnectOptions,
72    ) -> anyhow::Result<Self> {
73        let topics = TopicGenerator::new(lattice, prefix);
74        let nats_client =
75            nats::get_client(opts.url, opts.seed, opts.jwt, opts.creds_path, opts.ca_path).await?;
76        Ok(Client {
77            topics: Arc::new(topics),
78            client: nats_client,
79        })
80    }
81
82    /// Creates a new client with the given lattice ID, optional API prefix, and NATS client. This
83    /// is not recommended and is hidden because the async-nats crate is not 1.0 yet. That means it
84    /// is a breaking API change every time we upgrade versions. DO NOT use this function unless you
85    /// are willing to accept this breaking change. This function is explicitly excluded from our
86    /// semver guarantees until async-nats is 1.0.
87    #[doc(hidden)]
88    pub fn from_nats_client(
89        lattice: &str,
90        prefix: Option<&str>,
91        nats_client: async_nats::Client,
92    ) -> Self {
93        let topics = TopicGenerator::new(lattice, prefix);
94        Client {
95            topics: Arc::new(topics),
96            client: nats_client,
97        }
98    }
99
100    /// Puts the given manifest into the lattice. The lattice can be anything that implements the
101    /// [`ManifestLoader`] trait (a path to a file, raw bytes, or an already parsed manifest).
102    ///
103    /// Returns the name and version of the manifest that was put into the lattice
104    pub async fn put_manifest(&self, manifest: impl ManifestLoader) -> Result<(String, String)> {
105        let manifest = manifest.load_manifest().await?;
106        let manifest_bytes = serde_json::to_vec(&manifest).map_err(SerializationError::from)?;
107        let topic = self.topics.model_put_topic();
108        let resp = self
109            .client
110            .request_with_headers(
111                topic,
112                get_headers_content_type_json().clone(),
113                manifest_bytes.into(),
114            )
115            .await?;
116        let body: PutModelResponse =
117            serde_json::from_slice(&resp.payload).map_err(SerializationError::from)?;
118        if matches!(body.result, PutResult::Error) {
119            return Err(ClientError::ApiError(body.message));
120        }
121        Ok((body.name, body.current_version))
122    }
123
124    /// Gets a list of all manifests in the lattice. This does not return the full manifest, just a
125    /// summary of its metadata and status
126    pub async fn list_manifests(&self) -> Result<Vec<ModelSummary>> {
127        let topic = self.topics.model_list_topic();
128        let resp = self
129            .client
130            .request(topic, Vec::with_capacity(0).into())
131            .await?;
132        let body: Vec<ModelSummary> =
133            serde_json::from_slice(&resp.payload).map_err(SerializationError::from)?;
134        Ok(body)
135    }
136
137    /// Gets a manifest from the lattice by name and optionally its version. If no version is set,
138    /// the latest version will be returned
139    pub async fn get_manifest(&self, name: &str, version: Option<&str>) -> Result<Manifest> {
140        let topic = self.topics.model_get_topic(name);
141        let body = if let Some(version) = version {
142            serde_json::to_vec(&GetModelRequest {
143                version: Some(version.to_string()),
144            })
145            .map_err(SerializationError::from)?
146        } else {
147            Vec::with_capacity(0)
148        };
149        let resp = self.client.request(topic, body.into()).await?;
150        let body: GetModelResponse =
151            serde_json::from_slice(&resp.payload).map_err(SerializationError::from)?;
152
153        match body.result {
154            GetResult::Error => Err(ClientError::ApiError(body.message)),
155            GetResult::NotFound => Err(ClientError::NotFound(name.to_string())),
156            GetResult::Success => body.manifest.ok_or_else(|| {
157                ClientError::ApiError("API returned success but didn't set a manifest".to_string())
158            }),
159        }
160    }
161
162    /// Deletes a manifest from the lattice by name and optionally its version. If no version is
163    /// set, all versions will be deleted
164    ///
165    /// Returns true if the manifest was deleted, false if it was a noop (meaning it wasn't found or
166    /// was already deleted)
167    pub async fn delete_manifest(&self, name: &str, version: Option<&str>) -> Result<bool> {
168        let topic = self.topics.model_delete_topic(name);
169        let body = if let Some(version) = version {
170            serde_json::to_vec(&DeleteModelRequest {
171                version: Some(version.to_string()),
172            })
173            .map_err(SerializationError::from)?
174        } else {
175            Vec::with_capacity(0)
176        };
177        let resp = self.client.request(topic, body.into()).await?;
178        let body: DeleteModelResponse =
179            serde_json::from_slice(&resp.payload).map_err(SerializationError::from)?;
180        match body.result {
181            DeleteResult::Error => Err(ClientError::ApiError(body.message)),
182            DeleteResult::Noop => Ok(false),
183            DeleteResult::Deleted => Ok(true),
184        }
185    }
186
187    /// Gets a list of all versions of a manifest in the lattice
188    pub async fn list_versions(&self, name: &str) -> Result<Vec<VersionInfo>> {
189        let topic = self.topics.model_versions_topic(name);
190        let resp = self
191            .client
192            .request(topic, Vec::with_capacity(0).into())
193            .await?;
194        let body: VersionResponse =
195            serde_json::from_slice(&resp.payload).map_err(SerializationError::from)?;
196        match body.result {
197            GetResult::Error => Err(ClientError::ApiError(body.message)),
198            GetResult::NotFound => Err(ClientError::NotFound(name.to_string())),
199            GetResult::Success => Ok(body.versions),
200        }
201    }
202
203    /// Deploys a manifest to the lattice. The optional version parameter can be used to deploy a
204    /// specific version of a manifest. If no version is set, the latest version will be deployed
205    ///
206    /// Please note that an OK response does not necessarily mean that the manifest was deployed
207    /// successfully, just that the server accepted the deployment request.
208    ///
209    /// Returns a tuple of the name and version of the manifest that was deployed
210    pub async fn deploy_manifest(
211        &self,
212        name: &str,
213        version: Option<&str>,
214    ) -> Result<(String, Option<String>)> {
215        let topic = self.topics.model_deploy_topic(name);
216        let body = if let Some(version) = version {
217            serde_json::to_vec(&DeployModelRequest {
218                version: Some(version.to_string()),
219            })
220            .map_err(SerializationError::from)?
221        } else {
222            Vec::with_capacity(0)
223        };
224        let resp = self.client.request(topic, body.into()).await?;
225        let body: DeployModelResponse =
226            serde_json::from_slice(&resp.payload).map_err(SerializationError::from)?;
227        match body.result {
228            DeployResult::Error => Err(ClientError::ApiError(body.message)),
229            DeployResult::NotFound => Err(ClientError::NotFound(name.to_string())),
230            DeployResult::Acknowledged => Ok((body.name, body.version)),
231        }
232    }
233
234    /// A shorthand method that is the equivalent of calling [`put_manifest`](Self::put_manifest)
235    /// and then [`deploy_manifest`](Self::deploy_manifest)
236    ///
237    /// Returns the name and version of the manifest that was deployed. Note that this will always
238    /// deploy the latest version of the manifest (i.e. the one that was just put)
239    pub async fn put_and_deploy_manifest(
240        &self,
241        manifest: impl ManifestLoader,
242    ) -> Result<(String, String)> {
243        let (name, version) = self.put_manifest(manifest).await?;
244        // We don't technically need to put the version since we just deployed, but to make sure we
245        // maintain that behvior we'll put it here just in case
246        self.deploy_manifest(&name, Some(&version)).await?;
247        Ok((name, version))
248    }
249
250    /// Undeploys the given manifest from the lattice
251    ///
252    /// Returns Ok(manifest_name) if the manifest undeploy request was acknowledged
253    pub async fn undeploy_manifest(&self, name: &str) -> Result<String> {
254        let topic = self.topics.model_undeploy_topic(name);
255        let resp = self
256            .client
257            .request(topic, Vec::with_capacity(0).into())
258            .await?;
259        let body: DeployModelResponse =
260            serde_json::from_slice(&resp.payload).map_err(SerializationError::from)?;
261        match body.result {
262            DeployResult::Error => Err(ClientError::ApiError(body.message)),
263            DeployResult::NotFound => Err(ClientError::NotFound(name.to_string())),
264            DeployResult::Acknowledged => Ok(body.name),
265        }
266    }
267
268    /// Gets the status of the given manifest
269    pub async fn get_manifest_status(&self, name: &str) -> Result<Status> {
270        let topic = self.topics.model_status_topic(name);
271        let resp = self
272            .client
273            .request(topic, Vec::with_capacity(0).into())
274            .await?;
275        let body: StatusResponse =
276            serde_json::from_slice(&resp.payload).map_err(SerializationError::from)?;
277        match body.result {
278            StatusResult::Error => Err(ClientError::ApiError(body.message)),
279            StatusResult::NotFound => Err(ClientError::NotFound(name.to_string())),
280            StatusResult::Ok => body.status.ok_or_else(|| {
281                ClientError::ApiError("API returned success but didn't set a status".to_string())
282            }),
283        }
284    }
285
286    /// Subscribes to the status of a given manifest
287    pub async fn subscribe_to_status(&self, name: &str) -> Result<impl Stream<Item = Message>> {
288        let subject = self.topics.wadm_status_topic(name);
289        let subscriber = self
290            .client
291            .subscribe(subject)
292            .await
293            .map_err(|e| ClientError::ApiError(e.to_string()))?;
294
295        Ok(subscriber)
296    }
297}