wasmcloud_provider_wadm/
lib.rs

1use std::collections::HashMap;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use anyhow::{anyhow, bail, Context as _};
6use async_nats::HeaderMap;
7use futures::stream::{AbortHandle, Abortable};
8use futures::StreamExt;
9use opentelemetry_nats::NatsHeaderInjector;
10use tokio::sync::{OwnedSemaphorePermit, RwLock, Semaphore};
11use tracing::{debug, error, instrument, warn};
12use tracing_futures::Instrument as _;
13use wadm_client::{Client, ClientConnectOptions};
14use wadm_types::wasmcloud::wadm::handler::StatusUpdate;
15use wasmcloud_provider_sdk::{
16    core::HostData, get_connection, load_host_data, provider::WrpcClient, run_provider, Context,
17    LinkConfig, Provider,
18};
19use wasmcloud_provider_sdk::{initialize_observability, serve_provider_exports, LinkDeleteInfo};
20
21use crate::exports::wasmcloud::wadm::client::{ModelSummary, OamManifest, Status, VersionInfo};
22
23mod config;
24
25use config::{extract_wadm_config, ClientConfig};
26
27wit_bindgen_wrpc::generate!({
28    additional_derives: [
29        serde::Serialize,
30        serde::Deserialize,
31    ],
32    with: {
33        "wasmcloud:wadm/types@0.2.0": wadm_types::wasmcloud::wadm::types,
34        "wasmcloud:wadm/handler@0.2.0": generate,
35        "wasmcloud:wadm/client@0.2.0": generate
36    }
37});
38
39pub async fn run() -> anyhow::Result<()> {
40    WadmProvider::run().await
41}
42
43struct WadmClientBundle {
44    pub client: Client,
45    pub sub_handles: Vec<(String, AbortHandle)>,
46}
47
48impl Drop for WadmClientBundle {
49    fn drop(&mut self) {
50        for (_topic, handle) in &self.sub_handles {
51            handle.abort();
52        }
53    }
54}
55
56#[derive(Clone)]
57pub struct WadmProvider {
58    default_config: ClientConfig,
59    handler_components: Arc<RwLock<HashMap<String, WadmClientBundle>>>,
60    consumer_components: Arc<RwLock<HashMap<String, WadmClientBundle>>>,
61}
62
63impl Default for WadmProvider {
64    fn default() -> Self {
65        WadmProvider {
66            handler_components: Arc::new(RwLock::new(HashMap::new())),
67            consumer_components: Arc::new(RwLock::new(HashMap::new())),
68            default_config: Default::default(),
69        }
70    }
71}
72
73impl WadmProvider {
74    fn name() -> &'static str {
75        "wadm-provider"
76    }
77
78    pub async fn run() -> anyhow::Result<()> {
79        initialize_observability!(
80            WadmProvider::name(),
81            std::env::var_os("PROVIDER_SQLDB_POSTGRES_FLAMEGRAPH_PATH")
82        );
83
84        let host_data = load_host_data().context("failed to load host data")?;
85        let provider = Self::from_host_data(host_data);
86        let shutdown = run_provider(provider.clone(), WadmProvider::name())
87            .await
88            .context("failed to run provider")?;
89        let connection = get_connection();
90        let wrpc = connection
91            .get_wrpc_client(connection.provider_key())
92            .await?;
93        serve_provider_exports(&wrpc, provider, shutdown, serve)
94            .await
95            .context("failed to serve provider exports")
96    }
97
98    /// Build a [`WadmProvider`] from [`HostData`]
99    pub fn from_host_data(host_data: &HostData) -> WadmProvider {
100        let config = ClientConfig::try_from(host_data.config.clone());
101        if let Ok(config) = config {
102            WadmProvider {
103                default_config: config,
104                ..Default::default()
105            }
106        } else {
107            warn!("Failed to build connection configuration, falling back to default");
108            WadmProvider::default()
109        }
110    }
111
112    /// Attempt to connect to nats url and create a wadm client
113    /// If 'make_status_sub' is true, the client will subscribe to
114    /// wadm status updates for this component
115    async fn connect(
116        &self,
117        cfg: ClientConfig,
118        component_id: &str,
119        make_status_sub: bool,
120    ) -> anyhow::Result<WadmClientBundle> {
121        let ca_path: Option<PathBuf> = cfg.ctl_tls_ca_file.as_ref().map(PathBuf::from);
122
123        let url = format!("{}:{}", cfg.ctl_host, cfg.ctl_port);
124        let client_opts = ClientConnectOptions {
125            url: Some(url),
126            seed: cfg.ctl_seed,
127            jwt: cfg.ctl_jwt,
128            creds_path: cfg.ctl_credsfile.as_ref().map(PathBuf::from),
129            ca_path,
130        };
131
132        let client = Client::new(&cfg.lattice, None, client_opts).await?;
133
134        let mut sub_handles = Vec::new();
135        if make_status_sub {
136            if let Some(app_name) = &cfg.app_name {
137                let handle = self.handle_status(&client, component_id, app_name).await?;
138                sub_handles.push(("wadm.status".into(), handle));
139            } else {
140                bail!("app_name is required for status subscription");
141            }
142        }
143
144        Ok(WadmClientBundle {
145            client,
146            sub_handles,
147        })
148    }
149
150    /// Add a subscription to status events
151    #[instrument(level = "debug", skip(self, client))]
152    async fn handle_status(
153        &self,
154        client: &Client,
155        component_id: &str,
156        app_name: &str,
157    ) -> anyhow::Result<AbortHandle> {
158        debug!(
159            ?component_id,
160            ?app_name,
161            "spawning listener for component and app"
162        );
163
164        let mut subscriber = client
165            .subscribe_to_status(app_name)
166            .await
167            .map_err(|e| anyhow::anyhow!("Failed to subscribe to status: {}", e))?;
168
169        let component_id = Arc::new(component_id.to_string());
170        let app_name = Arc::new(app_name.to_string());
171
172        let (abort_handle, abort_registration) = AbortHandle::new_pair();
173        tokio::task::spawn(Abortable::new(
174            {
175                let semaphore = Arc::new(Semaphore::new(75));
176                let wrpc = match get_connection().get_wrpc_client(&component_id).await {
177                    Ok(wrpc) => Arc::new(wrpc),
178                    Err(err) => {
179                        error!(?err, "failed to construct wRPC client");
180                        return Err(anyhow!("Failed to construct wRPC client: {:?}", err));
181                    }
182                };
183                async move {
184                    // Listen for NATS message(s)
185                    while let Some(msg) = subscriber.next().await {
186                        // Parse the message into a StatusResponse
187                        match serde_json::from_slice::<wadm_types::api::Status>(&msg.payload) {
188                            Ok(status) => {
189                                debug!(?status, ?component_id, "received status");
190
191                                let span = tracing::debug_span!("handle_message", ?component_id);
192                                let permit = match semaphore.clone().acquire_owned().await {
193                                    Ok(p) => p,
194                                    Err(_) => {
195                                        warn!("Work pool has been closed, exiting queue subscribe");
196                                        break;
197                                    }
198                                };
199
200                                let component_id = Arc::clone(&component_id);
201                                let wrpc = Arc::clone(&wrpc);
202                                let app_name = Arc::clone(&app_name);
203                                tokio::spawn(async move {
204                                    dispatch_status_update(
205                                        &wrpc,
206                                        component_id.as_str(),
207                                        &app_name,
208                                        status.into(),
209                                        permit,
210                                    )
211                                    .instrument(span)
212                                    .await;
213                                });
214                            }
215                            Err(e) => {
216                                warn!("Failed to deserialize message: {}", e);
217                            }
218                        };
219                    }
220                }
221            },
222            abort_registration,
223        ));
224
225        Ok(abort_handle)
226    }
227
228    /// Helper function to get the NATS client from the context
229    async fn get_client(&self, ctx: Option<Context>) -> anyhow::Result<Client> {
230        if let Some(ref source_id) = ctx
231            .as_ref()
232            .and_then(|Context { component, .. }| component.clone())
233        {
234            let components = self.consumer_components.read().await;
235            let wadm_bundle = match components.get(source_id) {
236                Some(wadm_bundle) => wadm_bundle,
237                None => {
238                    error!("component not linked: {source_id}");
239                    bail!("component not linked: {source_id}")
240                }
241            };
242            Ok(wadm_bundle.client.clone())
243        } else {
244            error!("no component in request");
245            bail!("no component in request")
246        }
247    }
248}
249
250#[instrument(level = "debug", skip_all, fields(component_id = %component_id, app_name = %app))]
251async fn dispatch_status_update(
252    wrpc: &WrpcClient,
253    component_id: &str,
254    app: &str,
255    status: Status,
256    _permit: OwnedSemaphorePermit,
257) {
258    let update = StatusUpdate {
259        app: app.to_string(),
260        status,
261    };
262    debug!(
263        app = app,
264        component_id = component_id,
265        "sending status to component",
266    );
267
268    let cx: HeaderMap = NatsHeaderInjector::default_with_span().into();
269
270    if let Err(e) = wasmcloud::wadm::handler::handle_status_update(wrpc, Some(cx), &update).await {
271        error!(
272            error = %e,
273            "Unable to send message"
274        );
275    }
276}
277
278impl Provider for WadmProvider {
279    #[instrument(level = "debug", skip_all, fields(source_id))]
280    async fn receive_link_config_as_target(
281        &self,
282        link_config @ LinkConfig { source_id, .. }: LinkConfig<'_>,
283    ) -> anyhow::Result<()> {
284        let config = extract_wadm_config(&link_config, false)
285            .ok_or_else(|| anyhow!("Failed to extract WADM configuration"))?;
286
287        let merged_config = self.default_config.merge(&config);
288
289        let mut update_map = self.consumer_components.write().await;
290        let bundle = self
291            .connect(merged_config, source_id, false)
292            .await
293            .context("Failed to connect to NATS")?;
294
295        update_map.insert(source_id.into(), bundle);
296        Ok(())
297    }
298
299    #[instrument(level = "debug", skip_all, fields(target_id))]
300    async fn receive_link_config_as_source(
301        &self,
302        link_config @ LinkConfig { target_id, .. }: LinkConfig<'_>,
303    ) -> anyhow::Result<()> {
304        let config = extract_wadm_config(&link_config, true)
305            .ok_or_else(|| anyhow!("Failed to extract WADM configuration"))?;
306
307        let merged_config = self.default_config.merge(&config);
308
309        let mut update_map = self.handler_components.write().await;
310        let bundle = self
311            .connect(merged_config, target_id, true)
312            .await
313            .context("Failed to connect to NATS")?;
314
315        update_map.insert(target_id.into(), bundle);
316        Ok(())
317    }
318
319    #[instrument(level = "info", skip_all, fields(target_id = info.get_target_id()))]
320    async fn delete_link_as_source(&self, info: impl LinkDeleteInfo) -> anyhow::Result<()> {
321        let component_id = info.get_target_id();
322        self.handler_components.write().await.remove(component_id);
323        Ok(())
324    }
325
326    #[instrument(level = "info", skip_all, fields(source_id = info.get_source_id()))]
327    async fn delete_link_as_target(&self, info: impl LinkDeleteInfo) -> anyhow::Result<()> {
328        let component_id = info.get_source_id();
329        self.consumer_components.write().await.remove(component_id);
330        Ok(())
331    }
332
333    /// Handle shutdown request by closing all connections
334    async fn shutdown(&self) -> anyhow::Result<()> {
335        let mut handlers = self.handler_components.write().await;
336        handlers.clear();
337
338        let mut consumers = self.consumer_components.write().await;
339        consumers.clear();
340
341        // dropping all connections should send unsubscribes and close the connections, so no need
342        // to handle that here
343        Ok(())
344    }
345}
346
347impl exports::wasmcloud::wadm::client::Handler<Option<Context>> for WadmProvider {
348    #[instrument(level = "debug", skip(self, ctx), fields(model_name = %model_name))]
349    async fn deploy_model(
350        &self,
351        ctx: Option<Context>,
352        model_name: String,
353        version: Option<String>,
354        lattice: Option<String>,
355    ) -> anyhow::Result<Result<String, String>> {
356        let client = self.get_client(ctx).await?;
357        match client
358            .deploy_manifest(&model_name, version.as_deref())
359            .await
360        {
361            Ok((name, _version)) => Ok(Ok(name)),
362            Err(err) => {
363                error!("Deployment failed: {err}");
364                Ok(Err(format!("Deployment failed: {err}")))
365            }
366        }
367    }
368
369    #[instrument(level = "debug", skip(self, ctx), fields(model_name = %model_name))]
370    async fn undeploy_model(
371        &self,
372        ctx: Option<Context>,
373        model_name: String,
374        lattice: Option<String>,
375        non_destructive: bool,
376    ) -> anyhow::Result<Result<(), String>> {
377        let client = self.get_client(ctx).await?;
378        match client.undeploy_manifest(&model_name).await {
379            Ok(_) => Ok(Ok(())),
380            Err(err) => {
381                error!("Undeployment failed: {err}");
382                Ok(Err(format!("Undeployment failed: {err}")))
383            }
384        }
385    }
386
387    #[instrument(level = "debug", skip(self, ctx), fields(model = %model))]
388    async fn put_model(
389        &self,
390        ctx: Option<Context>,
391        model: String,
392        lattice: Option<String>,
393    ) -> anyhow::Result<Result<(String, String), String>> {
394        let client = self.get_client(ctx).await?;
395        match client.put_manifest(&model).await {
396            Ok(response) => Ok(Ok(response)),
397            Err(err) => {
398                error!("Failed to store model: {err}");
399                Ok(Err(format!("Failed to store model: {err}")))
400            }
401        }
402    }
403
404    #[instrument(level = "debug", skip(self, ctx), fields(manifest = ?manifest))]
405    async fn put_manifest(
406        &self,
407        ctx: Option<Context>,
408        manifest: OamManifest,
409        lattice: Option<String>,
410    ) -> anyhow::Result<Result<(String, String), String>> {
411        let client = self.get_client(ctx).await?;
412
413        let manifest = wadm_types::Manifest::from(manifest);
414
415        match client.put_manifest(manifest).await {
416            Ok(response) => Ok(Ok(response)),
417            Err(err) => {
418                error!("Failed to store manifest: {err}");
419                Ok(Err(format!("Failed to store manifest: {err}")))
420            }
421        }
422    }
423
424    #[instrument(level = "debug", skip(self, ctx), fields(model_name = %model_name))]
425    async fn get_model_history(
426        &self,
427        ctx: Option<Context>,
428        model_name: String,
429        lattice: Option<String>,
430    ) -> anyhow::Result<Result<Vec<VersionInfo>, String>> {
431        let client = self.get_client(ctx).await?;
432        match client.list_versions(&model_name).await {
433            Ok(history) => {
434                let converted_history: Vec<_> =
435                    history.into_iter().map(|item| item.into()).collect();
436                Ok(Ok(converted_history))
437            }
438            Err(err) => {
439                error!("Failed to retrieve model history: {err}");
440                Ok(Err(format!("Failed to retrieve model history: {err}")))
441            }
442        }
443    }
444
445    #[instrument(level = "debug", skip(self, ctx), fields(model_name = %model_name))]
446    async fn get_model_status(
447        &self,
448        ctx: Option<Context>,
449        model_name: String,
450        lattice: Option<String>,
451    ) -> anyhow::Result<Result<Status, String>> {
452        let client = self.get_client(ctx).await?;
453        match client.get_manifest_status(&model_name).await {
454            Ok(status) => Ok(Ok(status.into())),
455            Err(err) => {
456                error!("Failed to retrieve model status: {err}");
457                Ok(Err(format!("Failed to retrieve model status: {err}")))
458            }
459        }
460    }
461
462    #[instrument(level = "debug", skip(self, ctx), fields(model_name = %model_name))]
463    async fn get_model_details(
464        &self,
465        ctx: Option<Context>,
466        model_name: String,
467        version: Option<String>,
468        lattice: Option<String>,
469    ) -> anyhow::Result<Result<OamManifest, String>> {
470        let client = self.get_client(ctx).await?;
471        match client.get_manifest(&model_name, version.as_deref()).await {
472            Ok(details) => Ok(Ok(details.into())),
473            Err(err) => {
474                error!("Failed to retrieve model details: {err}");
475                Ok(Err(format!("Failed to retrieve model details: {err}")))
476            }
477        }
478    }
479
480    #[instrument(level = "debug", skip(self, ctx), fields(model_name = %model_name))]
481    async fn delete_model_version(
482        &self,
483        ctx: Option<Context>,
484        model_name: String,
485        version: Option<String>,
486        lattice: Option<String>,
487    ) -> anyhow::Result<Result<bool, String>> {
488        let client = self.get_client(ctx).await?;
489        match client
490            .delete_manifest(&model_name, version.as_deref())
491            .await
492        {
493            Ok(response) => Ok(Ok(response)),
494            Err(err) => {
495                error!("Failed to delete model version: {err}");
496                Ok(Err(format!("Failed to delete model version: {err}")))
497            }
498        }
499    }
500
501    #[instrument(level = "debug", skip(self, ctx))]
502    async fn get_models(
503        &self,
504        ctx: Option<Context>,
505        lattice: Option<String>,
506    ) -> anyhow::Result<Result<Vec<ModelSummary>, String>> {
507        let client = self.get_client(ctx).await?;
508        match client.list_manifests().await {
509            Ok(models) => Ok(Ok(models.into_iter().map(|model| model.into()).collect())),
510            Err(err) => {
511                error!("Failed to retrieve models: {err}");
512                Ok(Err(format!("Failed to retrieve models: {err}")))
513            }
514        }
515    }
516}