wasmcloud_provider_sdk/
provider.rs

1use core::fmt;
2use core::fmt::Formatter;
3use core::future::Future;
4
5use core::pin::{pin, Pin};
6use core::time::Duration;
7use std::collections::HashMap;
8use std::io::BufRead;
9use std::sync::Arc;
10
11use anyhow::{bail, Context as _, Result};
12use async_nats::subject::ToSubject as _;
13use async_nats::HeaderMap;
14use base64::Engine;
15use bytes::Bytes;
16use futures::{stream, Stream, StreamExt as _, TryStreamExt as _};
17use nkeys::XKey;
18use once_cell::sync::OnceCell;
19use serde::{Deserialize, Serialize};
20use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
21use tokio::task::{spawn_blocking, JoinSet};
22use tokio::{select, spawn, try_join};
23use tracing::{debug, error, info, instrument, trace, warn, Instrument as _};
24use wasmcloud_core::nats::convert_header_map_to_hashmap;
25use wasmcloud_core::rpc::{health_subject, link_del_subject, link_put_subject, shutdown_subject};
26use wasmcloud_core::secrets::SecretValue;
27use wasmcloud_core::{
28    provider_config_update_subject, HealthCheckRequest, HealthCheckResponse, HostData,
29    InterfaceLinkDefinition, LatticeTarget,
30};
31
32#[cfg(feature = "otel")]
33use wasmcloud_core::TraceContext;
34#[cfg(feature = "otel")]
35use wasmcloud_tracing::context::attach_span_context;
36use wrpc_transport::InvokeExt as _;
37
38use crate::error::{ProviderInitError, ProviderInitResult};
39use crate::{with_connection_event_logging, Context, LinkConfig, Provider, DEFAULT_NATS_ADDR};
40
/// Name of the header that should be passed for invocations that identifies the source
const WRPC_SOURCE_ID_HEADER_NAME: &str = "source-id";

/// Host data for this provider process, filled at most once by [`load_host_data`] (read from
/// stdin) or [`initialize_host_data`]; every later call returns this same value.
static HOST_DATA: OnceCell<HostData> = OnceCell::new();
/// Lattice connection shared with generated code. Read through [`get_connection`], which panics
/// while this cell is still empty.
static CONNECTION: OnceCell<ProviderConnection> = OnceCell::new();
46
47/// Retrieves the currently configured connection to the lattice. DO NOT call this method until
48/// after the provider is running (meaning [`run_provider`] has been called)
49/// or this method will panic. Only in extremely rare cases should this be called manually and it
50/// will only be used by generated code
51// NOTE(thomastaylor312): This isn't the most elegant solution, but providers that need to send
52// messages to the lattice rather than just responding need to get the same connection used when the
53// provider was started, which means a global static
54pub fn get_connection() -> &'static ProviderConnection {
55    CONNECTION
56        .get()
57        .expect("Provider connection not initialized")
58}
59
60/// Loads configuration data sent from the host over stdin. The returned host data contains all the
61/// configuration information needed to connect to the lattice and any additional configuration
62/// provided to this provider (like `config_json`).
63///
64/// NOTE: this function will read the data from stdin exactly once. If this function is called more
65/// than once, it will return a copy of the original data fetched
66pub fn load_host_data() -> ProviderInitResult<&'static HostData> {
67    HOST_DATA.get_or_try_init(_load_host_data)
68}
69
70/// Initializes the host data with the provided data. This is useful for testing or if the host data
71/// is not being provided over stdin.
72///
73/// If the host data has already been initialized, this function will return the existing host data.
74pub fn initialize_host_data(host_data: HostData) -> ProviderInitResult<&'static HostData> {
75    HOST_DATA.get_or_try_init(|| Ok(host_data))
76}
77
78// Internal function for populating the host data
79fn _load_host_data() -> ProviderInitResult<HostData> {
80    let mut buffer = String::new();
81    let stdin = std::io::stdin();
82    {
83        let mut handle = stdin.lock();
84        handle.read_line(&mut buffer).map_err(|e| {
85            ProviderInitError::Initialization(format!(
86                "failed to read host data configuration from stdin: {e}"
87            ))
88        })?;
89    }
90    // remove spaces, tabs, and newlines before and after base64-encoded data
91    let buffer = buffer.trim();
92    if buffer.is_empty() {
93        return Err(ProviderInitError::Initialization(
94            "stdin is empty - expecting host data configuration".to_string(),
95        ));
96    }
97    let bytes = base64::engine::general_purpose::STANDARD
98        .decode(buffer.as_bytes())
99        .map_err(|e| {
100            ProviderInitError::Initialization(format!(
101            "host data configuration passed through stdin has invalid encoding (expected base64): \
102             {e}"
103        ))
104        })?;
105    let host_data: HostData = serde_json::from_slice(&bytes).map_err(|e| {
106        ProviderInitError::Initialization(format!(
107            "parsing host data: {}:\n{}",
108            e,
109            String::from_utf8_lossy(&bytes)
110        ))
111    })?;
112    Ok(host_data)
113}
114
/// Receiving end of the broadcast quit channel; a `()` is broadcast when the provider should stop.
pub type QuitSignal = broadcast::Receiver<()>;

/// JSON payload of a shutdown request received on the provider's shutdown subject.
///
/// `Default` lets a payload that fails to parse fall back to an empty `host_id` (the shutdown
/// subscriber uses `unwrap_or_default`).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
struct ShutdownMessage {
    /// The ID of the host that sent the message
    pub host_id: String,
}
122
#[doc(hidden)]
/// Spawns a task that processes a subscription until it is closed or exhausted, or until a value
/// is received on the channel.
///
/// - `sub` is a mutable Subscriber (regular or queue subscription).
/// - `channel` may be either a tokio `mpsc::Receiver` or a `broadcast::Receiver`. It counts as
///   signaled on any result of `recv()`: a value was sent, the channel was closed, or (for
///   `broadcast`) the receiver lagged. On signal, `sub` is unsubscribed and the task exits.
/// - `msg` is the variable name bound to each message in the handler.
/// - `on_item` is the handler body. It runs inside the processing loop, so `continue` moves on to
///   the next message.
///
/// The spawned task is instrumented with the span that is current where the macro is expanded.
/// This keeps events emitted by `on_item` attached to the caller's span.
///
/// Evaluates to the `JoinHandle` of the spawned task.
macro_rules! process_until_quit {
    ($sub:ident, $channel:ident, $msg:ident, $on_item:tt) => {
        spawn(
            async move {
                loop {
                    select! {
                        _ = $channel.recv() => {
                            let _ = $sub.unsubscribe().await;
                            break;
                        },
                        __msg = $sub.next() => {
                            match __msg {
                                None => break,
                                Some($msg) => $on_item
                            }
                        }
                    }
                }
            }
            // `.instrument(..)` on the caller's future does not propagate into tasks spawned from
            // it, so explicitly carry the current span into this task.
            .in_current_span(),
        )
    };
}
150
151async fn subscribe_health(
152    nats: Arc<async_nats::Client>,
153    mut quit: broadcast::Receiver<()>,
154    lattice: &str,
155    provider_key: &str,
156) -> ProviderInitResult<mpsc::Receiver<(HealthCheckRequest, oneshot::Sender<HealthCheckResponse>)>>
157{
158    let mut sub = nats
159        .subscribe(health_subject(lattice, provider_key))
160        .await?;
161    let (health_tx, health_rx) = mpsc::channel(1);
162    spawn({
163        let nats = Arc::clone(&nats);
164        async move {
165            process_until_quit!(sub, quit, msg, {
166                let (tx, rx) = oneshot::channel();
167                if let Err(err) = health_tx.send((HealthCheckRequest {}, tx)).await {
168                    error!(%err, "failed to send health check request");
169                    continue;
170                }
171                match rx.await.as_ref().map(serde_json::to_vec) {
172                    Err(err) => {
173                        error!(%err, "failed to receive health check response");
174                    }
175                    Ok(Ok(t)) => {
176                        if let Some(reply_to) = msg.reply {
177                            if let Err(err) = nats.publish(reply_to, t.into()).await {
178                                error!(%err, "failed sending health check response");
179                            }
180                        }
181                    }
182                    Ok(Err(err)) => {
183                        // extremely unlikely that InvocationResponse would fail to serialize
184                        error!(%err, "failed serializing HealthCheckResponse");
185                    }
186                }
187            });
188        }
189        .instrument(tracing::debug_span!("subscribe_health"))
190    });
191    Ok(health_rx)
192}
193
194async fn subscribe_shutdown(
195    nats: Arc<async_nats::Client>,
196    quit: broadcast::Sender<()>,
197    lattice: &str,
198    provider_key: &str,
199    host_id: impl Into<Arc<str>>,
200) -> ProviderInitResult<mpsc::Receiver<oneshot::Sender<()>>> {
201    let mut sub = nats
202        .subscribe(shutdown_subject(lattice, provider_key, "default"))
203        .await?;
204    let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
205    let host_id = host_id.into();
206    spawn({
207        async move {
208            loop {
209                let msg = sub.next().await;
210                // Check if we really need to shut down
211                if let Some(async_nats::Message {
212                    reply: Some(reply_to),
213                    payload,
214                    ..
215                }) = msg
216                {
217                    let ShutdownMessage {
218                        host_id: ref req_host_id,
219                    } = serde_json::from_slice(&payload).unwrap_or_default();
220                    if req_host_id == host_id.as_ref() {
221                        info!("Received termination signal and stopping");
222                        // Tell provider to shutdown - before we shut down nats subscriptions,
223                        // in case it needs to do any message passing during shutdown
224                        let (tx, rx) = oneshot::channel();
225                        match shutdown_tx.send(tx).await {
226                            Ok(()) => {
227                                if let Err(err) = rx.await {
228                                    error!(%err, "failed to await shutdown");
229                                }
230                            }
231                            Err(err) => error!(%err, "failed to send shutdown"),
232                        }
233                        if let Err(err) = nats.publish(reply_to, "shutting down".into()).await {
234                            warn!(%err, "failed to send shutdown ack");
235                        }
236                        // unsubscribe from shutdown topic
237                        if let Err(err) = sub.unsubscribe().await {
238                            warn!(%err, "failed to unsubscribe from shutdown topic");
239                        }
240                        // send shutdown signal to all listeners: quit all subscribers and signal main thread to quit
241                        if let Err(err) = quit.send(()) {
242                            error!(%err, "Problem shutting down:  failure to send signal");
243                        }
244                        break;
245                    }
246                    trace!("Ignoring termination signal (request targeted for different host)");
247                }
248            }
249        }
250        .instrument(tracing::debug_span!("shutdown_subscriber"))
251    });
252    Ok(shutdown_rx)
253}
254
255async fn subscribe_link_put(
256    nats: Arc<async_nats::Client>,
257    mut quit: broadcast::Receiver<()>,
258    lattice: &str,
259    provider_xkey: &str,
260) -> ProviderInitResult<mpsc::Receiver<(InterfaceLinkDefinition, oneshot::Sender<()>)>> {
261    let (link_put_tx, link_put_rx) = mpsc::channel(1);
262    let mut sub = nats
263        .subscribe(link_put_subject(lattice, provider_xkey))
264        .await?;
265    spawn(async move {
266        process_until_quit!(sub, quit, msg, {
267            match serde_json::from_slice::<InterfaceLinkDefinition>(&msg.payload) {
268                Ok(ld) => {
269                    let span = tracing::Span::current();
270                    span.record("source_id", tracing::field::display(&ld.source_id));
271                    span.record("target", tracing::field::display(&ld.target));
272                    span.record("wit_namespace", tracing::field::display(&ld.wit_namespace));
273                    span.record("wit_package", tracing::field::display(&ld.wit_package));
274                    span.record(
275                        "wit_interfaces",
276                        tracing::field::display(&ld.interfaces.join(",")),
277                    );
278                    span.record("link_name", tracing::field::display(&ld.name));
279                    let (tx, rx) = oneshot::channel();
280                    if let Err(err) = link_put_tx.send((ld, tx)).await {
281                        error!(%err, "failed to send link put request");
282                        continue;
283                    }
284                    if let Err(err) = rx.await {
285                        error!(%err, "failed to await link_put");
286                    }
287                }
288                Err(err) => {
289                    error!(%err, "received invalid link def data on message");
290                }
291            }
292        });
293    });
294    Ok(link_put_rx)
295}
296
297async fn subscribe_link_del(
298    nats: Arc<async_nats::Client>,
299    mut quit: broadcast::Receiver<()>,
300    lattice: &str,
301    provider_key: &str,
302) -> ProviderInitResult<mpsc::Receiver<(InterfaceLinkDefinition, oneshot::Sender<()>)>> {
303    let subject = link_del_subject(lattice, provider_key).to_subject();
304    debug!(%subject, "subscribing for link del");
305    let mut sub = nats.subscribe(subject.clone()).await?;
306    let (link_del_tx, link_del_rx) = mpsc::channel(1);
307    let span = tracing::trace_span!("subscribe_link_del", %subject);
308    spawn(
309        async move {
310            process_until_quit!(sub, quit, msg, {
311                if let Ok(ld) = serde_json::from_slice::<InterfaceLinkDefinition>(&msg.payload) {
312                    let (tx, rx) = oneshot::channel();
313                    if let Err(err) = link_del_tx.send((ld, tx)).await {
314                        error!(%err, "failed to send link del request");
315                        continue;
316                    }
317                    if let Err(err) = rx.await {
318                        error!(%err, "failed to await link_del");
319                    }
320                } else {
321                    error!("received invalid link on link_del");
322                }
323            });
324        }
325        .instrument(span),
326    );
327    Ok(link_del_rx)
328}
329
330/// Subscribe to configuration updates that are passed by the host.
331///
332/// We expect the hosts to send configuration updates messages over NATS,
333/// with information on whether the configuration applies to a specific link,
334/// and the contents of the new/updated configuration.
335async fn subscribe_config_update(
336    nats: Arc<async_nats::Client>,
337    mut quit: broadcast::Receiver<()>,
338    lattice: &str,
339    provider_key: &str,
340) -> ProviderInitResult<mpsc::Receiver<(HashMap<String, String>, oneshot::Sender<()>)>> {
341    let (config_update_tx, config_update_rx) = mpsc::channel(1);
342    let mut sub = nats
343        .subscribe(provider_config_update_subject(lattice, provider_key).to_subject())
344        .await?;
345    spawn({
346        async move {
347            process_until_quit!(sub, quit, msg, {
348                match serde_json::from_slice::<HashMap<String, String>>(&msg.payload) {
349                    Ok(update) => {
350                        let (tx, rx) = oneshot::channel();
351                        // Perform the config update on the host
352                        if let Err(err) = config_update_tx.send((update, tx)).await {
353                            error!(%err, "failed to send config update");
354                            continue;
355                        }
356                        // Wait for the response from the rx to perform it
357                        if let Err(err) = rx.await.as_ref() {
358                            error!(%err, "failed to receive config update response");
359                        }
360                    }
361                    Err(err) => {
362                        error!(%err, "received invalid config update data on message");
363                    }
364                }
365            });
366        }
367        .instrument(tracing::debug_span!("subscribe_config_update"))
368    });
369
370    Ok(config_update_rx)
371}
372
373pub struct ProviderCommandReceivers {
374    health: mpsc::Receiver<(HealthCheckRequest, oneshot::Sender<HealthCheckResponse>)>,
375    shutdown: mpsc::Receiver<oneshot::Sender<()>>,
376    link_put: mpsc::Receiver<(InterfaceLinkDefinition, oneshot::Sender<()>)>,
377    link_del: mpsc::Receiver<(InterfaceLinkDefinition, oneshot::Sender<()>)>,
378    config_update: mpsc::Receiver<(HashMap<String, String>, oneshot::Sender<()>)>,
379}
380
381impl ProviderCommandReceivers {
382    pub async fn new(
383        nats: Arc<async_nats::Client>,
384        quit_tx: &broadcast::Sender<()>,
385        lattice: &str,
386        provider_key: &str,
387        provider_link_put_id: &str,
388        host_id: &str,
389    ) -> ProviderInitResult<Self> {
390        let (health, shutdown, link_put, link_del, config_update) = try_join!(
391            subscribe_health(
392                Arc::clone(&nats),
393                quit_tx.subscribe(),
394                lattice,
395                provider_key
396            ),
397            subscribe_shutdown(
398                Arc::clone(&nats),
399                quit_tx.clone(),
400                lattice,
401                provider_key,
402                host_id
403            ),
404            subscribe_link_put(
405                Arc::clone(&nats),
406                quit_tx.subscribe(),
407                lattice,
408                provider_link_put_id
409            ),
410            subscribe_link_del(
411                Arc::clone(&nats),
412                quit_tx.subscribe(),
413                lattice,
414                provider_key
415            ),
416            subscribe_config_update(
417                Arc::clone(&nats),
418                quit_tx.subscribe(),
419                lattice,
420                provider_key
421            ),
422        )?;
423        Ok(Self {
424            health,
425            shutdown,
426            link_put,
427            link_del,
428            config_update,
429        })
430    }
431}
432
/// State of provider initialization
///
/// Built by `init_provider` from the host data and the newly established NATS connection.
pub(crate) struct ProviderInitState {
    /// NATS client connected to the lattice
    pub nats: Arc<async_nats::Client>,
    /// Receiving end of the quit broadcast
    pub quit_rx: broadcast::Receiver<()>,
    /// Sending end of the quit broadcast
    pub quit_tx: broadcast::Sender<()>,
    /// ID of the host running this provider, taken from the host data
    pub host_id: String,
    /// Lattice passed as `lattice` when the NATS command subjects are built
    pub lattice_rpc_prefix: String,
    /// Provider key from the host data
    pub provider_key: String,
    /// Link definitions supplied in the host data at startup
    pub link_definitions: Vec<InterfaceLinkDefinition>,
    /// Receiving ends of the provider command subscriptions
    pub commands: ProviderCommandReceivers,
    /// Provider configuration supplied in the host data
    pub config: HashMap<String, String>,
    /// Secrets supplied in the host data
    pub secrets: HashMap<String, SecretValue>,
    /// The host's xkey, used together with the provider xkey for decrypting secrets.
    ///
    /// When the host supplies its public xkey, this key is built from that public key alone, so
    /// [`XKey::seed()`] will error and must not be relied on. On hosts that do not supply one
    /// (pre-1.1.0), a freshly generated placeholder key is stored instead.
    host_public_xkey: XKey,
    /// The provider's own xkey.
    ///
    /// Built from the private key seed in the host data, or freshly generated when the host does
    /// not supply one. On hosts that provide xkeys, its public key is used in the link-put subject.
    provider_private_xkey: XKey,
}
450
451#[instrument]
452async fn init_provider(name: &str) -> ProviderInitResult<ProviderInitState> {
453    let HostData {
454        host_id,
455        lattice_rpc_prefix,
456        lattice_rpc_user_jwt,
457        lattice_rpc_user_seed,
458        lattice_rpc_url,
459        provider_key,
460        env_values: _,
461        cluster_issuers: _,
462        instance_id,
463        link_definitions,
464        config,
465        secrets,
466        default_rpc_timeout_ms: _,
467        link_name: _link_name,
468        host_xkey_public_key,
469        provider_xkey_private_key,
470        ..
471    } = spawn_blocking(load_host_data).await.map_err(|e| {
472        ProviderInitError::Initialization(format!("failed to load host data: {e}"))
473    })??;
474
475    let (quit_tx, quit_rx) = broadcast::channel(1);
476
477    // If the xkey strings are empty, it just means that the host is <1.1.0 and does not support secrets.
478    // There aren't any negative side effects here, so it's really just a warning to update to 1.1.0.
479    let host_public_xkey = if host_xkey_public_key.is_empty() {
480        warn!("Provider is running on a host that does not provide a host xkey, secrets will not be supported");
481        XKey::new()
482    } else {
483        XKey::from_public_key(host_xkey_public_key).map_err(|e| {
484            ProviderInitError::Initialization(format!(
485                "failed to create host xkey from public key: {e}"
486            ))
487        })?
488    };
489    let provider_private_xkey = if provider_xkey_private_key.is_empty() {
490        warn!("Provider is running on a host that does not provide a provider xkey, secrets will not be supported");
491        XKey::new()
492    } else {
493        XKey::from_seed(provider_xkey_private_key).map_err(|e| {
494            ProviderInitError::Initialization(format!(
495                "failed to create provider xkey from private key: {e}"
496            ))
497        })?
498    };
499
500    // wasmCloud 1.1.0 hosts provide xkeys and publish links to the provider using the xkey public key in the NATS subject.
501    // Older hosts will use the provider key in the NATS subject.
502    // This allows for backwards compatibility with older hosts.
503    let provider_link_put_id = if host_xkey_public_key.is_empty()
504        && provider_xkey_private_key.is_empty()
505    {
506        debug!("Provider is running on a host that does not provide xkeys, using provider key in NATS subject");
507        provider_key.to_string()
508    } else {
509        debug!("Provider is running on a host that provides xkeys, using provider xkey in NATS subject");
510        provider_private_xkey.public_key()
511    };
512
513    info!(
514        "Starting capability provider {provider_key} instance {instance_id} with nats url {lattice_rpc_url}"
515    );
516
517    // Build the NATS client
518    let nats_addr = if !lattice_rpc_url.is_empty() {
519        lattice_rpc_url.as_str()
520    } else {
521        DEFAULT_NATS_ADDR
522    };
523
524    let nats = with_connection_event_logging(
525        match (lattice_rpc_user_jwt.trim(), lattice_rpc_user_seed.trim()) {
526            ("", "") => async_nats::ConnectOptions::default(),
527            (rpc_jwt, rpc_seed) => {
528                let key_pair = Arc::new(nkeys::KeyPair::from_seed(rpc_seed).unwrap());
529                let jwt = rpc_jwt.to_owned();
530                async_nats::ConnectOptions::with_jwt(jwt, move |nonce| {
531                    let key_pair = key_pair.clone();
532                    async move { key_pair.sign(&nonce).map_err(async_nats::AuthError::new) }
533                })
534            }
535        },
536    )
537    .name(name)
538    .connect(nats_addr)
539    .await?;
540    let nats = Arc::new(nats);
541
542    // Listen and process various provider events/functionality
543    let commands = ProviderCommandReceivers::new(
544        Arc::clone(&nats),
545        &quit_tx,
546        lattice_rpc_prefix,
547        provider_key,
548        &provider_link_put_id,
549        host_id,
550    )
551    .await?;
552    Ok(ProviderInitState {
553        nats,
554        quit_rx,
555        quit_tx,
556        host_id: host_id.clone(),
557        lattice_rpc_prefix: lattice_rpc_prefix.clone(),
558        provider_key: provider_key.clone(),
559        link_definitions: link_definitions.clone(),
560        config: config.clone(),
561        secrets: secrets.clone(),
562        host_public_xkey,
563        provider_private_xkey,
564        commands,
565    })
566}
567
568/// Appropriately receive a link (depending on if it's source/target) for a provider
569pub async fn receive_link_for_provider<P>(
570    provider: &P,
571    connection: &ProviderConnection,
572    ld: InterfaceLinkDefinition,
573) -> Result<()>
574where
575    P: Provider,
576{
577    match if ld.source_id == *connection.provider_id {
578        provider
579            .receive_link_config_as_source(LinkConfig {
580                source_id: &ld.source_id,
581                target_id: &ld.target,
582                link_name: &ld.name,
583                config: &ld.source_config,
584                secrets: &decrypt_link_secret(
585                    ld.source_secrets.as_deref(),
586                    &connection.provider_xkey,
587                    &connection.host_xkey,
588                )?,
589                wit_metadata: (&ld.wit_namespace, &ld.wit_package, &ld.interfaces),
590            })
591            .await
592    } else if ld.target == *connection.provider_id {
593        provider
594            .receive_link_config_as_target(LinkConfig {
595                source_id: &ld.source_id,
596                target_id: &ld.target,
597                link_name: &ld.name,
598                config: &ld.target_config,
599                secrets: &decrypt_link_secret(
600                    ld.target_secrets.as_deref(),
601                    &connection.provider_xkey,
602                    &connection.host_xkey,
603                )?,
604                wit_metadata: (&ld.wit_namespace, &ld.wit_package, &ld.interfaces),
605            })
606            .await
607    } else {
608        bail!("received link put where provider was neither source nor target");
609    } {
610        Ok(()) => connection.put_link(ld).await,
611        Err(e) => {
612            warn!(error = %e, "receiving link failed");
613        }
614    };
615    Ok(())
616}
617
618/// Given a serialized and encrypted [`HashMap<String, SecretValue>`], decrypts the secrets and deserializes
619/// the inner bytes into a [`HashMap<String, SecretValue>`]. This can either fail due to a decryption error
620/// or a deserialization error.
621///
622/// This will return an empty [`HashMap`] if no secrets are provided.
623fn decrypt_link_secret(
624    secrets: Option<&[u8]>,
625    provider_xkey: &XKey,
626    host_xkey: &XKey,
627) -> Result<HashMap<String, SecretValue>> {
628    // Note that we only `unwrap_or` in the fallback case where there are no secrets,
629    // not when the decryption or deserialization fails.
630    secrets
631        .map(|secrets| {
632            provider_xkey.open(secrets, host_xkey).map(|secrets| {
633                serde_json::from_slice(&secrets).context("failed to deserialize secrets")
634            })?
635        })
636        .unwrap_or(Ok(HashMap::with_capacity(0)))
637}
638
639async fn delete_link_for_provider<P>(
640    provider: &P,
641    connection: &ProviderConnection,
642    ld: InterfaceLinkDefinition,
643) -> Result<()>
644where
645    P: Provider,
646{
647    debug!(
648        provider_id = &connection.provider_id.to_string(),
649        "Deleting link for provider {ld:?}"
650    );
651    if *ld.source_id == *connection.provider_id {
652        if let Err(e) = provider.delete_link_as_source(&ld).await {
653            error!(error = %e, target = &ld.target, "failed to delete link to component");
654        }
655    } else if *ld.target == *connection.provider_id {
656        if let Err(e) = provider.delete_link_as_target(&ld).await {
657            error!(error = %e, source = &ld.source_id, "failed to delete link from component");
658        }
659    }
660    connection.delete_link(&ld.source_id, &ld.target).await;
661    Ok(())
662}
663
664/// Handle provider commands in a loop.
665pub async fn handle_provider_commands(
666    provider: impl Provider,
667    connection: &ProviderConnection,
668    mut quit_rx: broadcast::Receiver<()>,
669    quit_tx: broadcast::Sender<()>,
670    ProviderCommandReceivers {
671        mut health,
672        mut shutdown,
673        mut link_put,
674        mut link_del,
675        mut config_update,
676    }: ProviderCommandReceivers,
677) {
678    loop {
679        select! {
680            // run until we receive a shutdown request from host
681            _ = quit_rx.recv() => {
682                // flush async_nats client
683                connection.flush().await;
684                return
685            }
686            req = health.recv() => {
687                if let Some((req, tx)) = req {
688                    let res = match provider.health_request(&req).await {
689                        Ok(v) => v,
690                        Err(e) => {
691                            error!(error = %e, "provider health request failed");
692                            return;
693                        }
694                    };
695                    if tx.send(res).is_err() {
696                        error!("failed to send health check response");
697                    }
698                } else {
699                    error!("failed to handle health check, shutdown");
700                    if let Err(e) = provider.shutdown().await {
701                        error!(error = %e, "failed to shutdown provider");
702                    }
703                    if quit_tx.send(()).is_err() {
704                        error!("failed to send quit");
705                    };
706                    return
707                };
708            }
709            req = shutdown.recv() => {
710                if let Some(tx) = req {
711                    if let Err(e) = provider.shutdown().await {
712                        error!(error = %e, "failed to shutdown provider");
713                    }
714                    if tx.send(()).is_err() {
715                        error!("failed to send shutdown response");
716                    }
717                } else {
718                    error!("failed to handle shutdown, shutdown");
719                    if let Err(e) = provider.shutdown().await {
720                        error!(error = %e, "failed to shutdown provider");
721                    }
722                    if quit_tx.send(()).is_err() {
723                        error!("failed to send quit");
724                    };
725                    return
726                };
727            }
728            req = link_put.recv() => {
729                if let Some((ld, tx)) = req {
730                    // If the link has already been put, return early
731                    if connection.is_linked(&ld.source_id, &ld.target, &ld.wit_namespace, &ld.wit_package, &ld.name).await {
732                        warn!(
733                            source = &ld.source_id,
734                            target = &ld.target,
735                            link_name = &ld.name,
736                            "Ignoring duplicate link put"
737                        );
738                    } else {
739                        info!("Linking component with provider");
740                        if let Err(e) = receive_link_for_provider(&provider, connection, ld).await {
741                            error!(error = %e, "failed to receive link for provider");
742                        }
743                    }
744                    if tx.send(()).is_err() {
745                        error!("failed to send link put response");
746                    }
747                } else {
748                    error!("failed to handle link put, shutdown");
749                    if let Err(e) = provider.shutdown().await {
750                        error!(error = %e, "failed to shutdown provider");
751                    }
752                    if quit_tx.send(()).is_err() {
753                        error!("failed to send quit");
754                    };
755                    return;
756                };
757            }
758            req = link_del.recv() => {
759                if let Some((ld, tx)) = req {
760                    // notify provider that link is deleted
761                    if let Err(e) = delete_link_for_provider(&provider, connection, ld).await {
762                        error!(error = %e, "failed to delete link for provider");
763                    }
764
765                    if tx.send(()).is_err() {
766                        error!("failed to send link del response");
767                    }
768                } else {
769                    error!("failed to handle link del, shutdown");
770                    if let Err(e) = provider.shutdown().await {
771                        error!(error = %e, "failed to shutdown provider");
772                    }
773                    if quit_tx.send(()).is_err() {
774                        error!("failed to send quit");
775                    };
776                    return
777                };
778            }
779            req = config_update.recv() => {
780                if let Some((cfg, tx)) = req {
781                    // Notify the provider that some config has been updated
782                    if let Err(e) = provider.on_config_update(&cfg).await {
783                        error!(error = %e, "failed to pass through config update for provider");
784                    }
785
786                    if tx.send(()).is_err() {
787                        error!("failed to send config update response");
788                    }
789                } else {
790                    error!("failed to handle config update, shutdown");
791                    if let Err(e) = provider.shutdown().await {
792                        error!(error = %e, "failed to shutdown provider");
793                    }
794                    if quit_tx.send(()).is_err() {
795                        error!("failed to send quit");
796                    };
797                    return
798                };
799            }
800        }
801    }
802}
803
804/// Runs the provider handler given a provider implementation and a name.
805/// It returns a [Future], which will become ready once shutdown signal is received.
806pub async fn run_provider(
807    provider: impl Provider,
808    friendly_name: &str,
809) -> ProviderInitResult<impl Future<Output = ()>> {
810    let init_state = init_provider(friendly_name).await?;
811
812    // Run user-implemented provider-internal specific initialization
813    if let Err(e) = provider.init(&init_state).await {
814        return Err(ProviderInitError::Initialization(format!(
815            "provider init failed: {e}"
816        )));
817    }
818
819    let ProviderInitState {
820        nats,
821        quit_rx,
822        quit_tx,
823        host_id,
824        lattice_rpc_prefix,
825        provider_key,
826        link_definitions,
827        commands,
828        config,
829        secrets: _secrets,
830        host_public_xkey: host_xkey,
831        provider_private_xkey: provider_xkey,
832    } = init_state;
833
834    let connection = ProviderConnection::new(
835        Arc::clone(&nats),
836        provider_key,
837        lattice_rpc_prefix,
838        host_id,
839        config,
840        provider_xkey,
841        host_xkey,
842    )?;
843    CONNECTION.set(connection).map_err(|_| {
844        ProviderInitError::Initialization("Provider connection was already initialized".to_string())
845    })?;
846    let connection = get_connection();
847
848    // Provide all links to the provider at startup to establish the initial state
849    for ld in link_definitions {
850        if let Err(e) = receive_link_for_provider(&provider, connection, ld).await {
851            error!(
852                error = %e,
853                "failed to initialize link during provider startup",
854            );
855        }
856    }
857
858    debug!(?friendly_name, "provider finished initialization");
859    Ok(handle_provider_commands(
860        provider, connection, quit_rx, quit_tx, commands,
861    ))
862}
863
/// This is the type returned by the `serve` function generated by [`wit-bindgen-wrpc`]
///
/// Each element is `(instance, function name, invocations)`. `invocations` is a stream whose
/// items are either an error accepting an invocation, or a boxed future that serves that single
/// invocation when awaited. [`serve_provider_exports`] merges all of these streams and drives
/// the resulting futures.
pub type InvocationStreams = Vec<(
    &'static str,
    &'static str,
    Pin<
        Box<
            dyn Stream<
                    Item = anyhow::Result<
                        Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send + 'static>>,
                    >,
                > + Send
                + 'static,
        >,
    >,
)>;
879
880/// Serve exports of the provider using the `serve` function generated by [`wit-bindgen-wrpc`]
881pub async fn serve_provider_exports<'a, P, F, Fut>(
882    client: &'a WrpcClient,
883    provider: P,
884    shutdown: impl Future<Output = ()>,
885    serve: F,
886) -> anyhow::Result<()>
887where
888    F: FnOnce(&'a WrpcClient, P) -> Fut,
889    Fut: Future<Output = anyhow::Result<InvocationStreams>> + wrpc_transport::Captures<'a>,
890{
891    let invocations = serve(client, provider)
892        .await
893        .context("failed to serve exports")?;
894    let mut invocations = stream::select_all(
895        invocations
896            .into_iter()
897            .map(|(instance, name, invocations)| invocations.map(move |res| (instance, name, res))),
898    );
899    let mut shutdown = pin!(shutdown);
900    let mut tasks = JoinSet::new();
901    loop {
902        select! {
903            Some((instance, name, res)) = invocations.next() => {
904                match res {
905                    Ok(fut) => {
906                        tasks.spawn(async move {
907                            if let Err(err) = fut.await {
908                                warn!(?err, instance, name, "failed to serve invocation");
909                            }
910                            trace!(instance, name, "successfully served invocation");
911                        });
912                    },
913                    Err(err) => {
914                        warn!(?err, instance, name, "failed to accept invocation");
915                    }
916                }
917            },
918            () = &mut shutdown => {
919                return Ok(())
920            }
921        }
922    }
923}
924
/// Source ID for a link
///
/// Key type of [`ProviderConnection::target_links`]: the component ID of the link's source.
type SourceId = String;
927
928#[derive(Clone)]
929pub struct ProviderConnection {
930    /// Links from the provider to other components, aka where the provider is the
931    /// source of the link. Indexed by the component ID of the target
932    pub source_links: Arc<RwLock<HashMap<LatticeTarget, InterfaceLinkDefinition>>>,
933    /// Links from other components to the provider, aka where the provider is the
934    /// target of the link. Indexed by the component ID of the source
935    pub target_links: Arc<RwLock<HashMap<SourceId, InterfaceLinkDefinition>>>,
936
937    /// NATS client used for performing RPCs
938    pub nats: Arc<async_nats::Client>,
939
940    /// Lattice name
941    pub lattice: Arc<str>,
942    pub host_id: String,
943    pub provider_id: Arc<str>,
944
945    /// Secrets XKeys
946    pub provider_xkey: Arc<XKey>,
947    pub host_xkey: Arc<XKey>,
948
949    // TODO: Reference this field to get static config
950    #[allow(unused)]
951    pub config: HashMap<String, String>,
952}
953
954impl fmt::Debug for ProviderConnection {
955    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
956        f.debug_struct("ProviderConnection")
957            .field("provider_id", &self.provider_key())
958            .field("host_id", &self.host_id)
959            .field("lattice", &self.lattice)
960            .finish()
961    }
962}
963
964/// Extracts trace context from incoming headers
965pub fn invocation_context(headers: &HeaderMap) -> Context {
966    #[cfg(feature = "otel")]
967    {
968        let trace_context: TraceContext = convert_header_map_to_hashmap(headers)
969            .into_iter()
970            .collect::<Vec<(String, String)>>();
971        attach_span_context(&trace_context);
972    }
973    // Determine source ID for the invocation
974    let source_id = headers
975        .get(WRPC_SOURCE_ID_HEADER_NAME)
976        .map_or_else(|| "<unknown>".into(), ToString::to_string);
977    Context {
978        component: Some(source_id),
979        tracing: convert_header_map_to_hashmap(headers),
980    }
981}
982
983#[derive(Clone)]
984pub struct WrpcClient {
985    nats: wrpc_transport_nats::Client,
986    timeout: Duration,
987    provider_id: Arc<str>,
988    target: Arc<str>,
989}
990
991impl wrpc_transport::Invoke for WrpcClient {
992    type Context = Option<HeaderMap>;
993    type Outgoing = <wrpc_transport_nats::Client as wrpc_transport::Invoke>::Outgoing;
994    type Incoming = <wrpc_transport_nats::Client as wrpc_transport::Invoke>::Incoming;
995
996    async fn invoke<P>(
997        &self,
998        cx: Self::Context,
999        instance: &str,
1000        func: &str,
1001        params: Bytes,
1002        paths: impl AsRef<[P]> + Send,
1003    ) -> anyhow::Result<(Self::Outgoing, Self::Incoming)>
1004    where
1005        P: AsRef<[Option<usize>]> + Send + Sync,
1006    {
1007        let mut headers = cx.unwrap_or_default();
1008        headers.insert("source-id", &*self.provider_id);
1009        headers.insert("target-id", &*self.target);
1010        self.nats
1011            .timeout(self.timeout)
1012            .invoke(Some(headers), instance, func, params, paths)
1013            .await
1014    }
1015}
1016
1017impl wrpc_transport::Serve for WrpcClient {
1018    type Context = Option<Context>;
1019    type Outgoing = <wrpc_transport_nats::Client as wrpc_transport::Serve>::Outgoing;
1020    type Incoming = <wrpc_transport_nats::Client as wrpc_transport::Serve>::Incoming;
1021
1022    async fn serve(
1023        &self,
1024        instance: &str,
1025        func: &str,
1026        paths: impl Into<Arc<[Box<[Option<usize>]>]>> + Send,
1027    ) -> anyhow::Result<
1028        impl Stream<Item = anyhow::Result<(Self::Context, Self::Outgoing, Self::Incoming)>>
1029            + Send
1030            + 'static,
1031    > {
1032        let invocations = self.nats.serve(instance, func, paths).await?;
1033        Ok(invocations.and_then(|(cx, tx, rx)| async move {
1034            Ok((cx.as_ref().map(invocation_context), tx, rx))
1035        }))
1036    }
1037}
1038
1039impl ProviderConnection {
1040    pub fn new(
1041        nats: impl Into<Arc<async_nats::Client>>,
1042        provider_id: impl Into<Arc<str>>,
1043        lattice: impl Into<Arc<str>>,
1044        host_id: String,
1045        config: HashMap<String, String>,
1046        provider_private_xkey: impl Into<Arc<XKey>>,
1047        host_public_xkey: impl Into<Arc<XKey>>,
1048    ) -> ProviderInitResult<ProviderConnection> {
1049        Ok(ProviderConnection {
1050            source_links: Arc::default(),
1051            target_links: Arc::default(),
1052            nats: nats.into(),
1053            lattice: lattice.into(),
1054            host_id,
1055            provider_id: provider_id.into(),
1056            config,
1057            provider_xkey: provider_private_xkey.into(),
1058            host_xkey: host_public_xkey.into(),
1059        })
1060    }
1061
1062    /// Retrieve a wRPC client that can be used based on the NATS client of this connection
1063    ///
1064    /// # Arguments
1065    ///
1066    /// * `target` - Target ID to which invocations will be sent
1067    pub async fn get_wrpc_client(&self, target: &str) -> anyhow::Result<WrpcClient> {
1068        self.get_wrpc_client_custom(target, None).await
1069    }
1070
1071    /// Retrieve a wRPC client that can be used based on the NATS client of this connection,
1072    /// customized with invocation timeout
1073    ///
1074    /// # Arguments
1075    ///
1076    /// * `target` - Target ID to which invocations will be sent
1077    /// * `timeout` - Timeout to be set on the client (by default if this is unset it will be 10 seconds)
1078    pub async fn get_wrpc_client_custom(
1079        &self,
1080        target: &str,
1081        timeout: Option<Duration>,
1082    ) -> anyhow::Result<WrpcClient> {
1083        let prefix = Arc::from(format!("{}.{target}", &self.lattice));
1084        let nats = wrpc_transport_nats::Client::new(
1085            Arc::clone(&self.nats),
1086            Arc::clone(&prefix),
1087            Some(prefix),
1088        )
1089        .await?;
1090        Ok(WrpcClient {
1091            nats,
1092            provider_id: Arc::clone(&self.provider_id),
1093            target: Arc::from(target),
1094            timeout: timeout.unwrap_or_else(|| Duration::from_secs(10)),
1095        })
1096    }
1097
1098    /// Get the provider key that was assigned to this host at startup
1099    #[must_use]
1100    pub fn provider_key(&self) -> &str {
1101        &self.provider_id
1102    }
1103
1104    /// Stores link in the [`ProviderConnection`], either as a source link or target link
1105    /// depending on if the provider is the source or target of the link
1106    pub async fn put_link(&self, ld: InterfaceLinkDefinition) {
1107        if ld.source_id == *self.provider_id {
1108            self.source_links
1109                .write()
1110                .await
1111                .insert(ld.target.to_string(), ld);
1112        } else {
1113            self.target_links
1114                .write()
1115                .await
1116                .insert(ld.source_id.to_string(), ld);
1117        }
1118    }
1119
1120    /// Deletes link from the [`ProviderConnection`], either a source link or target link
1121    /// based on if the provider is the source or target of the link
1122    pub async fn delete_link(&self, source_id: &str, target: &str) {
1123        if source_id == &*self.provider_id {
1124            self.source_links.write().await.remove(target);
1125        } else if target == &*self.provider_id {
1126            self.target_links.write().await.remove(source_id);
1127        }
1128    }
1129
1130    /// Returns true if the source is linked to this provider or if the provider is linked to the target
1131    /// on the given interface and link name
1132    pub async fn is_linked(
1133        &self,
1134        source_id: &str,
1135        target_id: &str,
1136        wit_namespace: &str,
1137        wit_package: &str,
1138        link_name: &str,
1139    ) -> bool {
1140        // Provider is the source of the link, so we check if the target is linked
1141        if &*self.provider_id == source_id {
1142            if let Some(link) = self.source_links.read().await.get(target_id) {
1143                // In older host versions, the wit_namespace and wit_package are not provided
1144                // so we should see if it's empty
1145                (link.wit_namespace.is_empty() || link.wit_namespace == wit_namespace)
1146                    && (link.wit_package.is_empty() || link.wit_package == wit_package)
1147                    && link.name == link_name
1148            } else {
1149                false
1150            }
1151        // Provider is the target of the link, so we check if the source is linked
1152        } else if &*self.provider_id == target_id {
1153            if let Some(link) = self.target_links.read().await.get(source_id) {
1154                // In older host versions, the wit_namespace and wit_package are not provided
1155                // so we should see if it's empty
1156                (link.wit_namespace.is_empty() || link.wit_namespace == wit_namespace)
1157                    && (link.wit_package.is_empty() || link.wit_package == wit_package)
1158                    && link.name == link_name
1159            } else {
1160                false
1161            }
1162        } else {
1163            // Shouldn't occur, but if the provider is neither source nor target then it's not linked
1164            false
1165        }
1166    }
1167
1168    /// flush nats - called before main process exits
1169    pub(crate) async fn flush(&self) {
1170        if let Err(err) = self.nats.flush().await {
1171            error!(%err, "error flushing NATS client");
1172        }
1173    }
1174}