1use std::collections::HashMap;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use anyhow::{anyhow, bail, Context as _};
6use async_nats::HeaderMap;
7use futures::stream::{AbortHandle, Abortable};
8use futures::StreamExt;
9use opentelemetry_nats::NatsHeaderInjector;
10use tokio::sync::{OwnedSemaphorePermit, RwLock, Semaphore};
11use tracing::{debug, error, instrument, warn};
12use tracing_futures::Instrument as _;
13use wadm_client::{Client, ClientConnectOptions};
14use wadm_types::wasmcloud::wadm::handler::StatusUpdate;
15use wasmcloud_provider_sdk::{
16 core::HostData, get_connection, load_host_data, provider::WrpcClient, run_provider, Context,
17 LinkConfig, Provider,
18};
19use wasmcloud_provider_sdk::{initialize_observability, serve_provider_exports, LinkDeleteInfo};
20
21use crate::exports::wasmcloud::wadm::client::{ModelSummary, OamManifest, Status, VersionInfo};
22
23mod config;
24
25use config::{extract_wadm_config, ClientConfig};
26
// Generate wRPC bindings for the `wasmcloud:wadm@0.2.0` WIT interfaces.
// - `wasmcloud:wadm/types` is mapped onto the existing `wadm_types` crate rather than
//   being generated again here.
// - `wasmcloud:wadm/handler` and `wasmcloud:wadm/client` are generated in this crate: the
//   `exports::wasmcloud::wadm::client::Handler` trait is implemented by `WadmProvider`, and
//   the generated `wasmcloud::wadm::handler` bindings are used to push status updates to
//   linked components.
// - Generated types additionally derive serde `Serialize`/`Deserialize`.
wit_bindgen_wrpc::generate!({
    additional_derives: [
        serde::Serialize,
        serde::Deserialize,
    ],
    with: {
        "wasmcloud:wadm/types@0.2.0": wadm_types::wasmcloud::wadm::types,
        "wasmcloud:wadm/handler@0.2.0": generate,
        "wasmcloud:wadm/client@0.2.0": generate
    }
});
38
39pub async fn run() -> anyhow::Result<()> {
40 WadmProvider::run().await
41}
42
43struct WadmClientBundle {
44 pub client: Client,
45 pub sub_handles: Vec<(String, AbortHandle)>,
46}
47
48impl Drop for WadmClientBundle {
49 fn drop(&mut self) {
50 for (_topic, handle) in &self.sub_handles {
51 handle.abort();
52 }
53 }
54}
55
56#[derive(Clone)]
57pub struct WadmProvider {
58 default_config: ClientConfig,
59 handler_components: Arc<RwLock<HashMap<String, WadmClientBundle>>>,
60 consumer_components: Arc<RwLock<HashMap<String, WadmClientBundle>>>,
61}
62
63impl Default for WadmProvider {
64 fn default() -> Self {
65 WadmProvider {
66 handler_components: Arc::new(RwLock::new(HashMap::new())),
67 consumer_components: Arc::new(RwLock::new(HashMap::new())),
68 default_config: Default::default(),
69 }
70 }
71}
72
73impl WadmProvider {
74 fn name() -> &'static str {
75 "wadm-provider"
76 }
77
78 pub async fn run() -> anyhow::Result<()> {
79 initialize_observability!(
80 WadmProvider::name(),
81 std::env::var_os("PROVIDER_SQLDB_POSTGRES_FLAMEGRAPH_PATH")
82 );
83
84 let host_data = load_host_data().context("failed to load host data")?;
85 let provider = Self::from_host_data(host_data);
86 let shutdown = run_provider(provider.clone(), WadmProvider::name())
87 .await
88 .context("failed to run provider")?;
89 let connection = get_connection();
90 let wrpc = connection
91 .get_wrpc_client(connection.provider_key())
92 .await?;
93 serve_provider_exports(&wrpc, provider, shutdown, serve)
94 .await
95 .context("failed to serve provider exports")
96 }
97
98 pub fn from_host_data(host_data: &HostData) -> WadmProvider {
100 let config = ClientConfig::try_from(host_data.config.clone());
101 if let Ok(config) = config {
102 WadmProvider {
103 default_config: config,
104 ..Default::default()
105 }
106 } else {
107 warn!("Failed to build connection configuration, falling back to default");
108 WadmProvider::default()
109 }
110 }
111
112 async fn connect(
116 &self,
117 cfg: ClientConfig,
118 component_id: &str,
119 make_status_sub: bool,
120 ) -> anyhow::Result<WadmClientBundle> {
121 let ca_path: Option<PathBuf> = cfg.ctl_tls_ca_file.as_ref().map(PathBuf::from);
122
123 let url = format!("{}:{}", cfg.ctl_host, cfg.ctl_port);
124 let client_opts = ClientConnectOptions {
125 url: Some(url),
126 seed: cfg.ctl_seed,
127 jwt: cfg.ctl_jwt,
128 creds_path: cfg.ctl_credsfile.as_ref().map(PathBuf::from),
129 ca_path,
130 };
131
132 let client = Client::new(&cfg.lattice, None, client_opts).await?;
133
134 let mut sub_handles = Vec::new();
135 if make_status_sub {
136 if let Some(app_name) = &cfg.app_name {
137 let handle = self.handle_status(&client, component_id, app_name).await?;
138 sub_handles.push(("wadm.status".into(), handle));
139 } else {
140 bail!("app_name is required for status subscription");
141 }
142 }
143
144 Ok(WadmClientBundle {
145 client,
146 sub_handles,
147 })
148 }
149
150 #[instrument(level = "debug", skip(self, client))]
152 async fn handle_status(
153 &self,
154 client: &Client,
155 component_id: &str,
156 app_name: &str,
157 ) -> anyhow::Result<AbortHandle> {
158 debug!(
159 ?component_id,
160 ?app_name,
161 "spawning listener for component and app"
162 );
163
164 let mut subscriber = client
165 .subscribe_to_status(app_name)
166 .await
167 .map_err(|e| anyhow::anyhow!("Failed to subscribe to status: {}", e))?;
168
169 let component_id = Arc::new(component_id.to_string());
170 let app_name = Arc::new(app_name.to_string());
171
172 let (abort_handle, abort_registration) = AbortHandle::new_pair();
173 tokio::task::spawn(Abortable::new(
174 {
175 let semaphore = Arc::new(Semaphore::new(75));
176 let wrpc = match get_connection().get_wrpc_client(&component_id).await {
177 Ok(wrpc) => Arc::new(wrpc),
178 Err(err) => {
179 error!(?err, "failed to construct wRPC client");
180 return Err(anyhow!("Failed to construct wRPC client: {:?}", err));
181 }
182 };
183 async move {
184 while let Some(msg) = subscriber.next().await {
186 match serde_json::from_slice::<wadm_types::api::Status>(&msg.payload) {
188 Ok(status) => {
189 debug!(?status, ?component_id, "received status");
190
191 let span = tracing::debug_span!("handle_message", ?component_id);
192 let permit = match semaphore.clone().acquire_owned().await {
193 Ok(p) => p,
194 Err(_) => {
195 warn!("Work pool has been closed, exiting queue subscribe");
196 break;
197 }
198 };
199
200 let component_id = Arc::clone(&component_id);
201 let wrpc = Arc::clone(&wrpc);
202 let app_name = Arc::clone(&app_name);
203 tokio::spawn(async move {
204 dispatch_status_update(
205 &wrpc,
206 component_id.as_str(),
207 &app_name,
208 status.into(),
209 permit,
210 )
211 .instrument(span)
212 .await;
213 });
214 }
215 Err(e) => {
216 warn!("Failed to deserialize message: {}", e);
217 }
218 };
219 }
220 }
221 },
222 abort_registration,
223 ));
224
225 Ok(abort_handle)
226 }
227
228 async fn get_client(&self, ctx: Option<Context>) -> anyhow::Result<Client> {
230 if let Some(ref source_id) = ctx
231 .as_ref()
232 .and_then(|Context { component, .. }| component.clone())
233 {
234 let components = self.consumer_components.read().await;
235 let wadm_bundle = match components.get(source_id) {
236 Some(wadm_bundle) => wadm_bundle,
237 None => {
238 error!("component not linked: {source_id}");
239 bail!("component not linked: {source_id}")
240 }
241 };
242 Ok(wadm_bundle.client.clone())
243 } else {
244 error!("no component in request");
245 bail!("no component in request")
246 }
247 }
248}
249
250#[instrument(level = "debug", skip_all, fields(component_id = %component_id, app_name = %app))]
251async fn dispatch_status_update(
252 wrpc: &WrpcClient,
253 component_id: &str,
254 app: &str,
255 status: Status,
256 _permit: OwnedSemaphorePermit,
257) {
258 let update = StatusUpdate {
259 app: app.to_string(),
260 status,
261 };
262 debug!(
263 app = app,
264 component_id = component_id,
265 "sending status to component",
266 );
267
268 let cx: HeaderMap = NatsHeaderInjector::default_with_span().into();
269
270 if let Err(e) = wasmcloud::wadm::handler::handle_status_update(wrpc, Some(cx), &update).await {
271 error!(
272 error = %e,
273 "Unable to send message"
274 );
275 }
276}
277
278impl Provider for WadmProvider {
279 #[instrument(level = "debug", skip_all, fields(source_id))]
280 async fn receive_link_config_as_target(
281 &self,
282 link_config @ LinkConfig { source_id, .. }: LinkConfig<'_>,
283 ) -> anyhow::Result<()> {
284 let config = extract_wadm_config(&link_config, false)
285 .ok_or_else(|| anyhow!("Failed to extract WADM configuration"))?;
286
287 let merged_config = self.default_config.merge(&config);
288
289 let mut update_map = self.consumer_components.write().await;
290 let bundle = self
291 .connect(merged_config, source_id, false)
292 .await
293 .context("Failed to connect to NATS")?;
294
295 update_map.insert(source_id.into(), bundle);
296 Ok(())
297 }
298
299 #[instrument(level = "debug", skip_all, fields(target_id))]
300 async fn receive_link_config_as_source(
301 &self,
302 link_config @ LinkConfig { target_id, .. }: LinkConfig<'_>,
303 ) -> anyhow::Result<()> {
304 let config = extract_wadm_config(&link_config, true)
305 .ok_or_else(|| anyhow!("Failed to extract WADM configuration"))?;
306
307 let merged_config = self.default_config.merge(&config);
308
309 let mut update_map = self.handler_components.write().await;
310 let bundle = self
311 .connect(merged_config, target_id, true)
312 .await
313 .context("Failed to connect to NATS")?;
314
315 update_map.insert(target_id.into(), bundle);
316 Ok(())
317 }
318
319 #[instrument(level = "info", skip_all, fields(target_id = info.get_target_id()))]
320 async fn delete_link_as_source(&self, info: impl LinkDeleteInfo) -> anyhow::Result<()> {
321 let component_id = info.get_target_id();
322 self.handler_components.write().await.remove(component_id);
323 Ok(())
324 }
325
326 #[instrument(level = "info", skip_all, fields(source_id = info.get_source_id()))]
327 async fn delete_link_as_target(&self, info: impl LinkDeleteInfo) -> anyhow::Result<()> {
328 let component_id = info.get_source_id();
329 self.consumer_components.write().await.remove(component_id);
330 Ok(())
331 }
332
333 async fn shutdown(&self) -> anyhow::Result<()> {
335 let mut handlers = self.handler_components.write().await;
336 handlers.clear();
337
338 let mut consumers = self.consumer_components.write().await;
339 consumers.clear();
340
341 Ok(())
344 }
345}
346
347impl exports::wasmcloud::wadm::client::Handler<Option<Context>> for WadmProvider {
348 #[instrument(level = "debug", skip(self, ctx), fields(model_name = %model_name))]
349 async fn deploy_model(
350 &self,
351 ctx: Option<Context>,
352 model_name: String,
353 version: Option<String>,
354 lattice: Option<String>,
355 ) -> anyhow::Result<Result<String, String>> {
356 let client = self.get_client(ctx).await?;
357 match client
358 .deploy_manifest(&model_name, version.as_deref())
359 .await
360 {
361 Ok((name, _version)) => Ok(Ok(name)),
362 Err(err) => {
363 error!("Deployment failed: {err}");
364 Ok(Err(format!("Deployment failed: {err}")))
365 }
366 }
367 }
368
369 #[instrument(level = "debug", skip(self, ctx), fields(model_name = %model_name))]
370 async fn undeploy_model(
371 &self,
372 ctx: Option<Context>,
373 model_name: String,
374 lattice: Option<String>,
375 non_destructive: bool,
376 ) -> anyhow::Result<Result<(), String>> {
377 let client = self.get_client(ctx).await?;
378 match client.undeploy_manifest(&model_name).await {
379 Ok(_) => Ok(Ok(())),
380 Err(err) => {
381 error!("Undeployment failed: {err}");
382 Ok(Err(format!("Undeployment failed: {err}")))
383 }
384 }
385 }
386
387 #[instrument(level = "debug", skip(self, ctx), fields(model = %model))]
388 async fn put_model(
389 &self,
390 ctx: Option<Context>,
391 model: String,
392 lattice: Option<String>,
393 ) -> anyhow::Result<Result<(String, String), String>> {
394 let client = self.get_client(ctx).await?;
395 match client.put_manifest(&model).await {
396 Ok(response) => Ok(Ok(response)),
397 Err(err) => {
398 error!("Failed to store model: {err}");
399 Ok(Err(format!("Failed to store model: {err}")))
400 }
401 }
402 }
403
404 #[instrument(level = "debug", skip(self, ctx), fields(manifest = ?manifest))]
405 async fn put_manifest(
406 &self,
407 ctx: Option<Context>,
408 manifest: OamManifest,
409 lattice: Option<String>,
410 ) -> anyhow::Result<Result<(String, String), String>> {
411 let client = self.get_client(ctx).await?;
412
413 let manifest = wadm_types::Manifest::from(manifest);
414
415 match client.put_manifest(manifest).await {
416 Ok(response) => Ok(Ok(response)),
417 Err(err) => {
418 error!("Failed to store manifest: {err}");
419 Ok(Err(format!("Failed to store manifest: {err}")))
420 }
421 }
422 }
423
424 #[instrument(level = "debug", skip(self, ctx), fields(model_name = %model_name))]
425 async fn get_model_history(
426 &self,
427 ctx: Option<Context>,
428 model_name: String,
429 lattice: Option<String>,
430 ) -> anyhow::Result<Result<Vec<VersionInfo>, String>> {
431 let client = self.get_client(ctx).await?;
432 match client.list_versions(&model_name).await {
433 Ok(history) => {
434 let converted_history: Vec<_> =
435 history.into_iter().map(|item| item.into()).collect();
436 Ok(Ok(converted_history))
437 }
438 Err(err) => {
439 error!("Failed to retrieve model history: {err}");
440 Ok(Err(format!("Failed to retrieve model history: {err}")))
441 }
442 }
443 }
444
445 #[instrument(level = "debug", skip(self, ctx), fields(model_name = %model_name))]
446 async fn get_model_status(
447 &self,
448 ctx: Option<Context>,
449 model_name: String,
450 lattice: Option<String>,
451 ) -> anyhow::Result<Result<Status, String>> {
452 let client = self.get_client(ctx).await?;
453 match client.get_manifest_status(&model_name).await {
454 Ok(status) => Ok(Ok(status.into())),
455 Err(err) => {
456 error!("Failed to retrieve model status: {err}");
457 Ok(Err(format!("Failed to retrieve model status: {err}")))
458 }
459 }
460 }
461
462 #[instrument(level = "debug", skip(self, ctx), fields(model_name = %model_name))]
463 async fn get_model_details(
464 &self,
465 ctx: Option<Context>,
466 model_name: String,
467 version: Option<String>,
468 lattice: Option<String>,
469 ) -> anyhow::Result<Result<OamManifest, String>> {
470 let client = self.get_client(ctx).await?;
471 match client.get_manifest(&model_name, version.as_deref()).await {
472 Ok(details) => Ok(Ok(details.into())),
473 Err(err) => {
474 error!("Failed to retrieve model details: {err}");
475 Ok(Err(format!("Failed to retrieve model details: {err}")))
476 }
477 }
478 }
479
480 #[instrument(level = "debug", skip(self, ctx), fields(model_name = %model_name))]
481 async fn delete_model_version(
482 &self,
483 ctx: Option<Context>,
484 model_name: String,
485 version: Option<String>,
486 lattice: Option<String>,
487 ) -> anyhow::Result<Result<bool, String>> {
488 let client = self.get_client(ctx).await?;
489 match client
490 .delete_manifest(&model_name, version.as_deref())
491 .await
492 {
493 Ok(response) => Ok(Ok(response)),
494 Err(err) => {
495 error!("Failed to delete model version: {err}");
496 Ok(Err(format!("Failed to delete model version: {err}")))
497 }
498 }
499 }
500
501 #[instrument(level = "debug", skip(self, ctx))]
502 async fn get_models(
503 &self,
504 ctx: Option<Context>,
505 lattice: Option<String>,
506 ) -> anyhow::Result<Result<Vec<ModelSummary>, String>> {
507 let client = self.get_client(ctx).await?;
508 match client.list_manifests().await {
509 Ok(models) => Ok(Ok(models.into_iter().map(|model| model.into()).collect())),
510 Err(err) => {
511 error!("Failed to retrieve models: {err}");
512 Ok(Err(format!("Failed to retrieve models: {err}")))
513 }
514 }
515 }
516}