1use core::fmt;
2use core::fmt::Formatter;
3use core::future::Future;
4
5use core::pin::{pin, Pin};
6use core::time::Duration;
7use std::collections::HashMap;
8use std::io::BufRead;
9use std::sync::Arc;
10
11use anyhow::{bail, Context as _, Result};
12use async_nats::subject::ToSubject as _;
13use async_nats::HeaderMap;
14use base64::Engine;
15use bytes::Bytes;
16use futures::{stream, Stream, StreamExt as _, TryStreamExt as _};
17use nkeys::XKey;
18use once_cell::sync::OnceCell;
19use serde::{Deserialize, Serialize};
20use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
21use tokio::task::{spawn_blocking, JoinSet};
22use tokio::{select, spawn, try_join};
23use tracing::{debug, error, info, instrument, trace, warn, Instrument as _};
24use wasmcloud_core::nats::convert_header_map_to_hashmap;
25use wasmcloud_core::rpc::{health_subject, link_del_subject, link_put_subject, shutdown_subject};
26use wasmcloud_core::secrets::SecretValue;
27use wasmcloud_core::{
28 provider_config_update_subject, HealthCheckRequest, HealthCheckResponse, HostData,
29 InterfaceLinkDefinition, LatticeTarget,
30};
31
32#[cfg(feature = "otel")]
33use wasmcloud_core::TraceContext;
34#[cfg(feature = "otel")]
35use wasmcloud_tracing::context::attach_span_context;
36use wrpc_transport::InvokeExt as _;
37
38use crate::error::{ProviderInitError, ProviderInitResult};
39use crate::{with_connection_event_logging, Context, LinkConfig, Provider, DEFAULT_NATS_ADDR};
40
/// Name of the wRPC header from which [`invocation_context`] reads the ID of the invoking
/// entity.
const WRPC_SOURCE_ID_HEADER_NAME: &str = "source-id";

/// Process-wide host data, set at most once by [`load_host_data`] or [`initialize_host_data`].
static HOST_DATA: OnceCell<HostData> = OnceCell::new();
/// Process-wide provider connection, set by [`run_provider`] and read via [`get_connection`].
static CONNECTION: OnceCell<ProviderConnection> = OnceCell::new();
46
47pub fn get_connection() -> &'static ProviderConnection {
55 CONNECTION
56 .get()
57 .expect("Provider connection not initialized")
58}
59
60pub fn load_host_data() -> ProviderInitResult<&'static HostData> {
67 HOST_DATA.get_or_try_init(_load_host_data)
68}
69
70pub fn initialize_host_data(host_data: HostData) -> ProviderInitResult<&'static HostData> {
75 HOST_DATA.get_or_try_init(|| Ok(host_data))
76}
77
78fn _load_host_data() -> ProviderInitResult<HostData> {
80 let mut buffer = String::new();
81 let stdin = std::io::stdin();
82 {
83 let mut handle = stdin.lock();
84 handle.read_line(&mut buffer).map_err(|e| {
85 ProviderInitError::Initialization(format!(
86 "failed to read host data configuration from stdin: {e}"
87 ))
88 })?;
89 }
90 let buffer = buffer.trim();
92 if buffer.is_empty() {
93 return Err(ProviderInitError::Initialization(
94 "stdin is empty - expecting host data configuration".to_string(),
95 ));
96 }
97 let bytes = base64::engine::general_purpose::STANDARD
98 .decode(buffer.as_bytes())
99 .map_err(|e| {
100 ProviderInitError::Initialization(format!(
101 "host data configuration passed through stdin has invalid encoding (expected base64): \
102 {e}"
103 ))
104 })?;
105 let host_data: HostData = serde_json::from_slice(&bytes).map_err(|e| {
106 ProviderInitError::Initialization(format!(
107 "parsing host data: {}:\n{}",
108 e,
109 String::from_utf8_lossy(&bytes)
110 ))
111 })?;
112 Ok(host_data)
113}
114
/// Receiving half of the broadcast channel on which the quit signal is sent.
pub type QuitSignal = broadcast::Receiver<()>;
116
/// Payload of a shutdown request as deserialized by `subscribe_shutdown`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
struct ShutdownMessage {
    /// ID of the host the request is addressed to; compared against this provider's host ID.
    pub host_id: String,
}
122
/// Spawns a task that handles every item of the subscription `$sub` until a quit signal
/// arrives on `$channel` or the subscription ends.
///
/// * `$sub` - subscription stream; it is unsubscribed (best effort) when quit is received.
/// * `$channel` - quit receiver; any result of `recv()` stops the loop.
/// * `$msg` - identifier bound to each received item inside `$on_item`.
/// * `$on_item` - block executed per item; it runs inside the loop, so `continue` skips to
///   the next item.
///
/// The spawned task runs inside the span that is current at the expansion site. A freshly
/// spawned task would otherwise start without any span, so `.instrument(..)` wrappers
/// applied by callers would not cover the message handling.
#[doc(hidden)]
macro_rules! process_until_quit {
    ($sub:ident, $channel:ident, $msg:ident, $on_item:tt) => {
        spawn(tracing::Instrument::instrument(
            async move {
                loop {
                    select! {
                        _ = $channel.recv() => {
                            let _ = $sub.unsubscribe().await;
                            break;
                        },
                        __msg = $sub.next() => {
                            match __msg {
                                None => break,
                                Some($msg) => $on_item
                            }
                        }
                    }
                }
            },
            tracing::Span::current(),
        ))
    };
}
150
151async fn subscribe_health(
152 nats: Arc<async_nats::Client>,
153 mut quit: broadcast::Receiver<()>,
154 lattice: &str,
155 provider_key: &str,
156) -> ProviderInitResult<mpsc::Receiver<(HealthCheckRequest, oneshot::Sender<HealthCheckResponse>)>>
157{
158 let mut sub = nats
159 .subscribe(health_subject(lattice, provider_key))
160 .await?;
161 let (health_tx, health_rx) = mpsc::channel(1);
162 spawn({
163 let nats = Arc::clone(&nats);
164 async move {
165 process_until_quit!(sub, quit, msg, {
166 let (tx, rx) = oneshot::channel();
167 if let Err(err) = health_tx.send((HealthCheckRequest {}, tx)).await {
168 error!(%err, "failed to send health check request");
169 continue;
170 }
171 match rx.await.as_ref().map(serde_json::to_vec) {
172 Err(err) => {
173 error!(%err, "failed to receive health check response");
174 }
175 Ok(Ok(t)) => {
176 if let Some(reply_to) = msg.reply {
177 if let Err(err) = nats.publish(reply_to, t.into()).await {
178 error!(%err, "failed sending health check response");
179 }
180 }
181 }
182 Ok(Err(err)) => {
183 error!(%err, "failed serializing HealthCheckResponse");
185 }
186 }
187 });
188 }
189 .instrument(tracing::debug_span!("subscribe_health"))
190 });
191 Ok(health_rx)
192}
193
194async fn subscribe_shutdown(
195 nats: Arc<async_nats::Client>,
196 quit: broadcast::Sender<()>,
197 lattice: &str,
198 provider_key: &str,
199 host_id: impl Into<Arc<str>>,
200) -> ProviderInitResult<mpsc::Receiver<oneshot::Sender<()>>> {
201 let mut sub = nats
202 .subscribe(shutdown_subject(lattice, provider_key, "default"))
203 .await?;
204 let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
205 let host_id = host_id.into();
206 spawn({
207 async move {
208 loop {
209 let msg = sub.next().await;
210 if let Some(async_nats::Message {
212 reply: Some(reply_to),
213 payload,
214 ..
215 }) = msg
216 {
217 let ShutdownMessage {
218 host_id: ref req_host_id,
219 } = serde_json::from_slice(&payload).unwrap_or_default();
220 if req_host_id == host_id.as_ref() {
221 info!("Received termination signal and stopping");
222 let (tx, rx) = oneshot::channel();
225 match shutdown_tx.send(tx).await {
226 Ok(()) => {
227 if let Err(err) = rx.await {
228 error!(%err, "failed to await shutdown");
229 }
230 }
231 Err(err) => error!(%err, "failed to send shutdown"),
232 }
233 if let Err(err) = nats.publish(reply_to, "shutting down".into()).await {
234 warn!(%err, "failed to send shutdown ack");
235 }
236 if let Err(err) = sub.unsubscribe().await {
238 warn!(%err, "failed to unsubscribe from shutdown topic");
239 }
240 if let Err(err) = quit.send(()) {
242 error!(%err, "Problem shutting down: failure to send signal");
243 }
244 break;
245 }
246 trace!("Ignoring termination signal (request targeted for different host)");
247 }
248 }
249 }
250 .instrument(tracing::debug_span!("shutdown_subscriber"))
251 });
252 Ok(shutdown_rx)
253}
254
255async fn subscribe_link_put(
256 nats: Arc<async_nats::Client>,
257 mut quit: broadcast::Receiver<()>,
258 lattice: &str,
259 provider_xkey: &str,
260) -> ProviderInitResult<mpsc::Receiver<(InterfaceLinkDefinition, oneshot::Sender<()>)>> {
261 let (link_put_tx, link_put_rx) = mpsc::channel(1);
262 let mut sub = nats
263 .subscribe(link_put_subject(lattice, provider_xkey))
264 .await?;
265 spawn(async move {
266 process_until_quit!(sub, quit, msg, {
267 match serde_json::from_slice::<InterfaceLinkDefinition>(&msg.payload) {
268 Ok(ld) => {
269 let span = tracing::Span::current();
270 span.record("source_id", tracing::field::display(&ld.source_id));
271 span.record("target", tracing::field::display(&ld.target));
272 span.record("wit_namespace", tracing::field::display(&ld.wit_namespace));
273 span.record("wit_package", tracing::field::display(&ld.wit_package));
274 span.record(
275 "wit_interfaces",
276 tracing::field::display(&ld.interfaces.join(",")),
277 );
278 span.record("link_name", tracing::field::display(&ld.name));
279 let (tx, rx) = oneshot::channel();
280 if let Err(err) = link_put_tx.send((ld, tx)).await {
281 error!(%err, "failed to send link put request");
282 continue;
283 }
284 if let Err(err) = rx.await {
285 error!(%err, "failed to await link_put");
286 }
287 }
288 Err(err) => {
289 error!(%err, "received invalid link def data on message");
290 }
291 }
292 });
293 });
294 Ok(link_put_rx)
295}
296
297async fn subscribe_link_del(
298 nats: Arc<async_nats::Client>,
299 mut quit: broadcast::Receiver<()>,
300 lattice: &str,
301 provider_key: &str,
302) -> ProviderInitResult<mpsc::Receiver<(InterfaceLinkDefinition, oneshot::Sender<()>)>> {
303 let subject = link_del_subject(lattice, provider_key).to_subject();
304 debug!(%subject, "subscribing for link del");
305 let mut sub = nats.subscribe(subject.clone()).await?;
306 let (link_del_tx, link_del_rx) = mpsc::channel(1);
307 let span = tracing::trace_span!("subscribe_link_del", %subject);
308 spawn(
309 async move {
310 process_until_quit!(sub, quit, msg, {
311 if let Ok(ld) = serde_json::from_slice::<InterfaceLinkDefinition>(&msg.payload) {
312 let (tx, rx) = oneshot::channel();
313 if let Err(err) = link_del_tx.send((ld, tx)).await {
314 error!(%err, "failed to send link del request");
315 continue;
316 }
317 if let Err(err) = rx.await {
318 error!(%err, "failed to await link_del");
319 }
320 } else {
321 error!("received invalid link on link_del");
322 }
323 });
324 }
325 .instrument(span),
326 );
327 Ok(link_del_rx)
328}
329
330async fn subscribe_config_update(
336 nats: Arc<async_nats::Client>,
337 mut quit: broadcast::Receiver<()>,
338 lattice: &str,
339 provider_key: &str,
340) -> ProviderInitResult<mpsc::Receiver<(HashMap<String, String>, oneshot::Sender<()>)>> {
341 let (config_update_tx, config_update_rx) = mpsc::channel(1);
342 let mut sub = nats
343 .subscribe(provider_config_update_subject(lattice, provider_key).to_subject())
344 .await?;
345 spawn({
346 async move {
347 process_until_quit!(sub, quit, msg, {
348 match serde_json::from_slice::<HashMap<String, String>>(&msg.payload) {
349 Ok(update) => {
350 let (tx, rx) = oneshot::channel();
351 if let Err(err) = config_update_tx.send((update, tx)).await {
353 error!(%err, "failed to send config update");
354 continue;
355 }
356 if let Err(err) = rx.await.as_ref() {
358 error!(%err, "failed to receive config update response");
359 }
360 }
361 Err(err) => {
362 error!(%err, "received invalid config update data on message");
363 }
364 }
365 });
366 }
367 .instrument(tracing::debug_span!("subscribe_config_update"))
368 });
369
370 Ok(config_update_rx)
371}
372
/// Receiving ends of the channels fed by the command subscriptions created in
/// [`ProviderCommandReceivers::new`].
///
/// Every item carries a oneshot sender that the consumer uses to signal completion (or, for
/// health checks, to return the response).
pub struct ProviderCommandReceivers {
    /// Health check requests, each with a sender for the response.
    health: mpsc::Receiver<(HealthCheckRequest, oneshot::Sender<HealthCheckResponse>)>,
    /// Shutdown requests addressed to this provider's host.
    shutdown: mpsc::Receiver<oneshot::Sender<()>>,
    /// Link definitions to be established.
    link_put: mpsc::Receiver<(InterfaceLinkDefinition, oneshot::Sender<()>)>,
    /// Link definitions to be removed.
    link_del: mpsc::Receiver<(InterfaceLinkDefinition, oneshot::Sender<()>)>,
    /// Configuration updates as string key/value maps.
    config_update: mpsc::Receiver<(HashMap<String, String>, oneshot::Sender<()>)>,
}
380
381impl ProviderCommandReceivers {
382 pub async fn new(
383 nats: Arc<async_nats::Client>,
384 quit_tx: &broadcast::Sender<()>,
385 lattice: &str,
386 provider_key: &str,
387 provider_link_put_id: &str,
388 host_id: &str,
389 ) -> ProviderInitResult<Self> {
390 let (health, shutdown, link_put, link_del, config_update) = try_join!(
391 subscribe_health(
392 Arc::clone(&nats),
393 quit_tx.subscribe(),
394 lattice,
395 provider_key
396 ),
397 subscribe_shutdown(
398 Arc::clone(&nats),
399 quit_tx.clone(),
400 lattice,
401 provider_key,
402 host_id
403 ),
404 subscribe_link_put(
405 Arc::clone(&nats),
406 quit_tx.subscribe(),
407 lattice,
408 provider_link_put_id
409 ),
410 subscribe_link_del(
411 Arc::clone(&nats),
412 quit_tx.subscribe(),
413 lattice,
414 provider_key
415 ),
416 subscribe_config_update(
417 Arc::clone(&nats),
418 quit_tx.subscribe(),
419 lattice,
420 provider_key
421 ),
422 )?;
423 Ok(Self {
424 health,
425 shutdown,
426 link_put,
427 link_del,
428 config_update,
429 })
430 }
431}
432
/// State assembled by `init_provider` from the host data and the freshly opened NATS
/// connection.
pub(crate) struct ProviderInitState {
    /// Connected NATS client.
    pub nats: Arc<async_nats::Client>,
    /// Receiving half of the quit broadcast channel.
    pub quit_rx: broadcast::Receiver<()>,
    /// Sending half of the quit broadcast channel.
    pub quit_tx: broadcast::Sender<()>,
    /// ID of the host, copied from the host data.
    pub host_id: String,
    /// Lattice RPC prefix, copied from the host data.
    pub lattice_rpc_prefix: String,
    /// Provider key, copied from the host data.
    pub provider_key: String,
    /// Link definitions supplied with the host data.
    pub link_definitions: Vec<InterfaceLinkDefinition>,
    /// Receivers for commands arriving over NATS.
    pub commands: ProviderCommandReceivers,
    /// Provider configuration supplied with the host data.
    pub config: HashMap<String, String>,
    /// Secrets supplied with the host data.
    pub secrets: HashMap<String, SecretValue>,
    /// Host xkey built from the public key in the host data, or a freshly generated key when
    /// none was provided.
    host_public_xkey: XKey,
    /// Provider xkey built from the seed in the host data, or a freshly generated key when
    /// none was provided.
    provider_private_xkey: XKey,
}
450
451#[instrument]
452async fn init_provider(name: &str) -> ProviderInitResult<ProviderInitState> {
453 let HostData {
454 host_id,
455 lattice_rpc_prefix,
456 lattice_rpc_user_jwt,
457 lattice_rpc_user_seed,
458 lattice_rpc_url,
459 provider_key,
460 env_values: _,
461 cluster_issuers: _,
462 instance_id,
463 link_definitions,
464 config,
465 secrets,
466 default_rpc_timeout_ms: _,
467 link_name: _link_name,
468 host_xkey_public_key,
469 provider_xkey_private_key,
470 ..
471 } = spawn_blocking(load_host_data).await.map_err(|e| {
472 ProviderInitError::Initialization(format!("failed to load host data: {e}"))
473 })??;
474
475 let (quit_tx, quit_rx) = broadcast::channel(1);
476
477 let host_public_xkey = if host_xkey_public_key.is_empty() {
480 warn!("Provider is running on a host that does not provide a host xkey, secrets will not be supported");
481 XKey::new()
482 } else {
483 XKey::from_public_key(host_xkey_public_key).map_err(|e| {
484 ProviderInitError::Initialization(format!(
485 "failed to create host xkey from public key: {e}"
486 ))
487 })?
488 };
489 let provider_private_xkey = if provider_xkey_private_key.is_empty() {
490 warn!("Provider is running on a host that does not provide a provider xkey, secrets will not be supported");
491 XKey::new()
492 } else {
493 XKey::from_seed(provider_xkey_private_key).map_err(|e| {
494 ProviderInitError::Initialization(format!(
495 "failed to create provider xkey from private key: {e}"
496 ))
497 })?
498 };
499
500 let provider_link_put_id = if host_xkey_public_key.is_empty()
504 && provider_xkey_private_key.is_empty()
505 {
506 debug!("Provider is running on a host that does not provide xkeys, using provider key in NATS subject");
507 provider_key.to_string()
508 } else {
509 debug!("Provider is running on a host that provides xkeys, using provider xkey in NATS subject");
510 provider_private_xkey.public_key()
511 };
512
513 info!(
514 "Starting capability provider {provider_key} instance {instance_id} with nats url {lattice_rpc_url}"
515 );
516
517 let nats_addr = if !lattice_rpc_url.is_empty() {
519 lattice_rpc_url.as_str()
520 } else {
521 DEFAULT_NATS_ADDR
522 };
523
524 let nats = with_connection_event_logging(
525 match (lattice_rpc_user_jwt.trim(), lattice_rpc_user_seed.trim()) {
526 ("", "") => async_nats::ConnectOptions::default(),
527 (rpc_jwt, rpc_seed) => {
528 let key_pair = Arc::new(nkeys::KeyPair::from_seed(rpc_seed).unwrap());
529 let jwt = rpc_jwt.to_owned();
530 async_nats::ConnectOptions::with_jwt(jwt, move |nonce| {
531 let key_pair = key_pair.clone();
532 async move { key_pair.sign(&nonce).map_err(async_nats::AuthError::new) }
533 })
534 }
535 },
536 )
537 .name(name)
538 .connect(nats_addr)
539 .await?;
540 let nats = Arc::new(nats);
541
542 let commands = ProviderCommandReceivers::new(
544 Arc::clone(&nats),
545 &quit_tx,
546 lattice_rpc_prefix,
547 provider_key,
548 &provider_link_put_id,
549 host_id,
550 )
551 .await?;
552 Ok(ProviderInitState {
553 nats,
554 quit_rx,
555 quit_tx,
556 host_id: host_id.clone(),
557 lattice_rpc_prefix: lattice_rpc_prefix.clone(),
558 provider_key: provider_key.clone(),
559 link_definitions: link_definitions.clone(),
560 config: config.clone(),
561 secrets: secrets.clone(),
562 host_public_xkey,
563 provider_private_xkey,
564 commands,
565 })
566}
567
568pub async fn receive_link_for_provider<P>(
570 provider: &P,
571 connection: &ProviderConnection,
572 ld: InterfaceLinkDefinition,
573) -> Result<()>
574where
575 P: Provider,
576{
577 match if ld.source_id == *connection.provider_id {
578 provider
579 .receive_link_config_as_source(LinkConfig {
580 source_id: &ld.source_id,
581 target_id: &ld.target,
582 link_name: &ld.name,
583 config: &ld.source_config,
584 secrets: &decrypt_link_secret(
585 ld.source_secrets.as_deref(),
586 &connection.provider_xkey,
587 &connection.host_xkey,
588 )?,
589 wit_metadata: (&ld.wit_namespace, &ld.wit_package, &ld.interfaces),
590 })
591 .await
592 } else if ld.target == *connection.provider_id {
593 provider
594 .receive_link_config_as_target(LinkConfig {
595 source_id: &ld.source_id,
596 target_id: &ld.target,
597 link_name: &ld.name,
598 config: &ld.target_config,
599 secrets: &decrypt_link_secret(
600 ld.target_secrets.as_deref(),
601 &connection.provider_xkey,
602 &connection.host_xkey,
603 )?,
604 wit_metadata: (&ld.wit_namespace, &ld.wit_package, &ld.interfaces),
605 })
606 .await
607 } else {
608 bail!("received link put where provider was neither source nor target");
609 } {
610 Ok(()) => connection.put_link(ld).await,
611 Err(e) => {
612 warn!(error = %e, "receiving link failed");
613 }
614 };
615 Ok(())
616}
617
618fn decrypt_link_secret(
624 secrets: Option<&[u8]>,
625 provider_xkey: &XKey,
626 host_xkey: &XKey,
627) -> Result<HashMap<String, SecretValue>> {
628 secrets
631 .map(|secrets| {
632 provider_xkey.open(secrets, host_xkey).map(|secrets| {
633 serde_json::from_slice(&secrets).context("failed to deserialize secrets")
634 })?
635 })
636 .unwrap_or(Ok(HashMap::with_capacity(0)))
637}
638
639async fn delete_link_for_provider<P>(
640 provider: &P,
641 connection: &ProviderConnection,
642 ld: InterfaceLinkDefinition,
643) -> Result<()>
644where
645 P: Provider,
646{
647 debug!(
648 provider_id = &connection.provider_id.to_string(),
649 "Deleting link for provider {ld:?}"
650 );
651 if *ld.source_id == *connection.provider_id {
652 if let Err(e) = provider.delete_link_as_source(&ld).await {
653 error!(error = %e, target = &ld.target, "failed to delete link to component");
654 }
655 } else if *ld.target == *connection.provider_id {
656 if let Err(e) = provider.delete_link_as_target(&ld).await {
657 error!(error = %e, source = &ld.source_id, "failed to delete link from component");
658 }
659 }
660 connection.delete_link(&ld.source_id, &ld.target).await;
661 Ok(())
662}
663
664pub async fn handle_provider_commands(
666 provider: impl Provider,
667 connection: &ProviderConnection,
668 mut quit_rx: broadcast::Receiver<()>,
669 quit_tx: broadcast::Sender<()>,
670 ProviderCommandReceivers {
671 mut health,
672 mut shutdown,
673 mut link_put,
674 mut link_del,
675 mut config_update,
676 }: ProviderCommandReceivers,
677) {
678 loop {
679 select! {
680 _ = quit_rx.recv() => {
682 connection.flush().await;
684 return
685 }
686 req = health.recv() => {
687 if let Some((req, tx)) = req {
688 let res = match provider.health_request(&req).await {
689 Ok(v) => v,
690 Err(e) => {
691 error!(error = %e, "provider health request failed");
692 return;
693 }
694 };
695 if tx.send(res).is_err() {
696 error!("failed to send health check response");
697 }
698 } else {
699 error!("failed to handle health check, shutdown");
700 if let Err(e) = provider.shutdown().await {
701 error!(error = %e, "failed to shutdown provider");
702 }
703 if quit_tx.send(()).is_err() {
704 error!("failed to send quit");
705 };
706 return
707 };
708 }
709 req = shutdown.recv() => {
710 if let Some(tx) = req {
711 if let Err(e) = provider.shutdown().await {
712 error!(error = %e, "failed to shutdown provider");
713 }
714 if tx.send(()).is_err() {
715 error!("failed to send shutdown response");
716 }
717 } else {
718 error!("failed to handle shutdown, shutdown");
719 if let Err(e) = provider.shutdown().await {
720 error!(error = %e, "failed to shutdown provider");
721 }
722 if quit_tx.send(()).is_err() {
723 error!("failed to send quit");
724 };
725 return
726 };
727 }
728 req = link_put.recv() => {
729 if let Some((ld, tx)) = req {
730 if connection.is_linked(&ld.source_id, &ld.target, &ld.wit_namespace, &ld.wit_package, &ld.name).await {
732 warn!(
733 source = &ld.source_id,
734 target = &ld.target,
735 link_name = &ld.name,
736 "Ignoring duplicate link put"
737 );
738 } else {
739 info!("Linking component with provider");
740 if let Err(e) = receive_link_for_provider(&provider, connection, ld).await {
741 error!(error = %e, "failed to receive link for provider");
742 }
743 }
744 if tx.send(()).is_err() {
745 error!("failed to send link put response");
746 }
747 } else {
748 error!("failed to handle link put, shutdown");
749 if let Err(e) = provider.shutdown().await {
750 error!(error = %e, "failed to shutdown provider");
751 }
752 if quit_tx.send(()).is_err() {
753 error!("failed to send quit");
754 };
755 return;
756 };
757 }
758 req = link_del.recv() => {
759 if let Some((ld, tx)) = req {
760 if let Err(e) = delete_link_for_provider(&provider, connection, ld).await {
762 error!(error = %e, "failed to delete link for provider");
763 }
764
765 if tx.send(()).is_err() {
766 error!("failed to send link del response");
767 }
768 } else {
769 error!("failed to handle link del, shutdown");
770 if let Err(e) = provider.shutdown().await {
771 error!(error = %e, "failed to shutdown provider");
772 }
773 if quit_tx.send(()).is_err() {
774 error!("failed to send quit");
775 };
776 return
777 };
778 }
779 req = config_update.recv() => {
780 if let Some((cfg, tx)) = req {
781 if let Err(e) = provider.on_config_update(&cfg).await {
783 error!(error = %e, "failed to pass through config update for provider");
784 }
785
786 if tx.send(()).is_err() {
787 error!("failed to send config update response");
788 }
789 } else {
790 error!("failed to handle config update, shutdown");
791 if let Err(e) = provider.shutdown().await {
792 error!(error = %e, "failed to shutdown provider");
793 }
794 if quit_tx.send(()).is_err() {
795 error!("failed to send quit");
796 };
797 return
798 };
799 }
800 }
801 }
802}
803
804pub async fn run_provider(
807 provider: impl Provider,
808 friendly_name: &str,
809) -> ProviderInitResult<impl Future<Output = ()>> {
810 let init_state = init_provider(friendly_name).await?;
811
812 if let Err(e) = provider.init(&init_state).await {
814 return Err(ProviderInitError::Initialization(format!(
815 "provider init failed: {e}"
816 )));
817 }
818
819 let ProviderInitState {
820 nats,
821 quit_rx,
822 quit_tx,
823 host_id,
824 lattice_rpc_prefix,
825 provider_key,
826 link_definitions,
827 commands,
828 config,
829 secrets: _secrets,
830 host_public_xkey: host_xkey,
831 provider_private_xkey: provider_xkey,
832 } = init_state;
833
834 let connection = ProviderConnection::new(
835 Arc::clone(&nats),
836 provider_key,
837 lattice_rpc_prefix,
838 host_id,
839 config,
840 provider_xkey,
841 host_xkey,
842 )?;
843 CONNECTION.set(connection).map_err(|_| {
844 ProviderInitError::Initialization("Provider connection was already initialized".to_string())
845 })?;
846 let connection = get_connection();
847
848 for ld in link_definitions {
850 if let Err(e) = receive_link_for_provider(&provider, connection, ld).await {
851 error!(
852 error = %e,
853 "failed to initialize link during provider startup",
854 );
855 }
856 }
857
858 debug!(?friendly_name, "provider finished initialization");
859 Ok(handle_provider_commands(
860 provider, connection, quit_rx, quit_tx, commands,
861 ))
862}
863
/// Streams of incoming invocations as consumed by [`serve_provider_exports`].
///
/// Each entry is a pair of static strings (bound as `instance` and `name` by the consumer)
/// plus a stream. Every stream item is either an error from accepting an invocation or a
/// boxed future that handles one invocation.
pub type InvocationStreams = Vec<(
    &'static str,
    &'static str,
    Pin<
        Box<
            dyn Stream<
                    Item = anyhow::Result<
                        Pin<Box<dyn Future<Output = anyhow::Result<()>> + Send + 'static>>,
                    >,
                > + Send
                + 'static,
        >,
    >,
)>;
879
880pub async fn serve_provider_exports<'a, P, F, Fut>(
882 client: &'a WrpcClient,
883 provider: P,
884 shutdown: impl Future<Output = ()>,
885 serve: F,
886) -> anyhow::Result<()>
887where
888 F: FnOnce(&'a WrpcClient, P) -> Fut,
889 Fut: Future<Output = anyhow::Result<InvocationStreams>> + wrpc_transport::Captures<'a>,
890{
891 let invocations = serve(client, provider)
892 .await
893 .context("failed to serve exports")?;
894 let mut invocations = stream::select_all(
895 invocations
896 .into_iter()
897 .map(|(instance, name, invocations)| invocations.map(move |res| (instance, name, res))),
898 );
899 let mut shutdown = pin!(shutdown);
900 let mut tasks = JoinSet::new();
901 loop {
902 select! {
903 Some((instance, name, res)) = invocations.next() => {
904 match res {
905 Ok(fut) => {
906 tasks.spawn(async move {
907 if let Err(err) = fut.await {
908 warn!(?err, instance, name, "failed to serve invocation");
909 }
910 trace!(instance, name, "successfully served invocation");
911 });
912 },
913 Err(err) => {
914 warn!(?err, instance, name, "failed to accept invocation");
915 }
916 }
917 },
918 () = &mut shutdown => {
919 return Ok(())
920 }
921 }
922 }
923}
924
/// Key type of `ProviderConnection::target_links`: the `source_id` of a link definition.
type SourceId = String;
927
/// Connection state of a running provider: the NATS client, identity information and the
/// links currently recorded for the provider.
#[derive(Clone)]
pub struct ProviderConnection {
    /// Links in which this provider is the source, keyed by the link's target.
    pub source_links: Arc<RwLock<HashMap<LatticeTarget, InterfaceLinkDefinition>>>,
    /// Links recorded for any other case, keyed by the link's source ID.
    pub target_links: Arc<RwLock<HashMap<SourceId, InterfaceLinkDefinition>>>,

    /// NATS client used for communication.
    pub nats: Arc<async_nats::Client>,

    /// Lattice identifier; used as the first segment of the wRPC prefix.
    pub lattice: Arc<str>,
    /// ID of the host this provider runs on.
    pub host_id: String,
    /// ID of this provider; compared against link source/target IDs.
    pub provider_id: Arc<str>,

    /// Provider xkey, used to open encrypted link secrets.
    pub provider_xkey: Arc<XKey>,
    /// Host xkey, used together with `provider_xkey` when opening link secrets.
    pub host_xkey: Arc<XKey>,

    /// Provider configuration; not read anywhere in this file.
    #[allow(unused)]
    pub config: HashMap<String, String>,
}
953
954impl fmt::Debug for ProviderConnection {
955 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
956 f.debug_struct("ProviderConnection")
957 .field("provider_id", &self.provider_key())
958 .field("host_id", &self.host_id)
959 .field("lattice", &self.lattice)
960 .finish()
961 }
962}
963
964pub fn invocation_context(headers: &HeaderMap) -> Context {
966 #[cfg(feature = "otel")]
967 {
968 let trace_context: TraceContext = convert_header_map_to_hashmap(headers)
969 .into_iter()
970 .collect::<Vec<(String, String)>>();
971 attach_span_context(&trace_context);
972 }
973 let source_id = headers
975 .get(WRPC_SOURCE_ID_HEADER_NAME)
976 .map_or_else(|| "<unknown>".into(), ToString::to_string);
977 Context {
978 component: Some(source_id),
979 tracing: convert_header_map_to_hashmap(headers),
980 }
981}
982
/// wRPC client bound to one target, created by `ProviderConnection::get_wrpc_client_custom`.
#[derive(Clone)]
pub struct WrpcClient {
    /// Underlying NATS-based wRPC client.
    nats: wrpc_transport_nats::Client,
    /// Timeout applied to every outgoing invocation.
    timeout: Duration,
    /// ID of this provider, sent as the `source-id` header on outgoing invocations.
    provider_id: Arc<str>,
    /// ID of the target, sent as the `target-id` header on outgoing invocations.
    target: Arc<str>,
}
990
991impl wrpc_transport::Invoke for WrpcClient {
992 type Context = Option<HeaderMap>;
993 type Outgoing = <wrpc_transport_nats::Client as wrpc_transport::Invoke>::Outgoing;
994 type Incoming = <wrpc_transport_nats::Client as wrpc_transport::Invoke>::Incoming;
995
996 async fn invoke<P>(
997 &self,
998 cx: Self::Context,
999 instance: &str,
1000 func: &str,
1001 params: Bytes,
1002 paths: impl AsRef<[P]> + Send,
1003 ) -> anyhow::Result<(Self::Outgoing, Self::Incoming)>
1004 where
1005 P: AsRef<[Option<usize>]> + Send + Sync,
1006 {
1007 let mut headers = cx.unwrap_or_default();
1008 headers.insert("source-id", &*self.provider_id);
1009 headers.insert("target-id", &*self.target);
1010 self.nats
1011 .timeout(self.timeout)
1012 .invoke(Some(headers), instance, func, params, paths)
1013 .await
1014 }
1015}
1016
1017impl wrpc_transport::Serve for WrpcClient {
1018 type Context = Option<Context>;
1019 type Outgoing = <wrpc_transport_nats::Client as wrpc_transport::Serve>::Outgoing;
1020 type Incoming = <wrpc_transport_nats::Client as wrpc_transport::Serve>::Incoming;
1021
1022 async fn serve(
1023 &self,
1024 instance: &str,
1025 func: &str,
1026 paths: impl Into<Arc<[Box<[Option<usize>]>]>> + Send,
1027 ) -> anyhow::Result<
1028 impl Stream<Item = anyhow::Result<(Self::Context, Self::Outgoing, Self::Incoming)>>
1029 + Send
1030 + 'static,
1031 > {
1032 let invocations = self.nats.serve(instance, func, paths).await?;
1033 Ok(invocations.and_then(|(cx, tx, rx)| async move {
1034 Ok((cx.as_ref().map(invocation_context), tx, rx))
1035 }))
1036 }
1037}
1038
1039impl ProviderConnection {
1040 pub fn new(
1041 nats: impl Into<Arc<async_nats::Client>>,
1042 provider_id: impl Into<Arc<str>>,
1043 lattice: impl Into<Arc<str>>,
1044 host_id: String,
1045 config: HashMap<String, String>,
1046 provider_private_xkey: impl Into<Arc<XKey>>,
1047 host_public_xkey: impl Into<Arc<XKey>>,
1048 ) -> ProviderInitResult<ProviderConnection> {
1049 Ok(ProviderConnection {
1050 source_links: Arc::default(),
1051 target_links: Arc::default(),
1052 nats: nats.into(),
1053 lattice: lattice.into(),
1054 host_id,
1055 provider_id: provider_id.into(),
1056 config,
1057 provider_xkey: provider_private_xkey.into(),
1058 host_xkey: host_public_xkey.into(),
1059 })
1060 }
1061
1062 pub async fn get_wrpc_client(&self, target: &str) -> anyhow::Result<WrpcClient> {
1068 self.get_wrpc_client_custom(target, None).await
1069 }
1070
1071 pub async fn get_wrpc_client_custom(
1079 &self,
1080 target: &str,
1081 timeout: Option<Duration>,
1082 ) -> anyhow::Result<WrpcClient> {
1083 let prefix = Arc::from(format!("{}.{target}", &self.lattice));
1084 let nats = wrpc_transport_nats::Client::new(
1085 Arc::clone(&self.nats),
1086 Arc::clone(&prefix),
1087 Some(prefix),
1088 )
1089 .await?;
1090 Ok(WrpcClient {
1091 nats,
1092 provider_id: Arc::clone(&self.provider_id),
1093 target: Arc::from(target),
1094 timeout: timeout.unwrap_or_else(|| Duration::from_secs(10)),
1095 })
1096 }
1097
1098 #[must_use]
1100 pub fn provider_key(&self) -> &str {
1101 &self.provider_id
1102 }
1103
1104 pub async fn put_link(&self, ld: InterfaceLinkDefinition) {
1107 if ld.source_id == *self.provider_id {
1108 self.source_links
1109 .write()
1110 .await
1111 .insert(ld.target.to_string(), ld);
1112 } else {
1113 self.target_links
1114 .write()
1115 .await
1116 .insert(ld.source_id.to_string(), ld);
1117 }
1118 }
1119
1120 pub async fn delete_link(&self, source_id: &str, target: &str) {
1123 if source_id == &*self.provider_id {
1124 self.source_links.write().await.remove(target);
1125 } else if target == &*self.provider_id {
1126 self.target_links.write().await.remove(source_id);
1127 }
1128 }
1129
1130 pub async fn is_linked(
1133 &self,
1134 source_id: &str,
1135 target_id: &str,
1136 wit_namespace: &str,
1137 wit_package: &str,
1138 link_name: &str,
1139 ) -> bool {
1140 if &*self.provider_id == source_id {
1142 if let Some(link) = self.source_links.read().await.get(target_id) {
1143 (link.wit_namespace.is_empty() || link.wit_namespace == wit_namespace)
1146 && (link.wit_package.is_empty() || link.wit_package == wit_package)
1147 && link.name == link_name
1148 } else {
1149 false
1150 }
1151 } else if &*self.provider_id == target_id {
1153 if let Some(link) = self.target_links.read().await.get(source_id) {
1154 (link.wit_namespace.is_empty() || link.wit_namespace == wit_namespace)
1157 && (link.wit_package.is_empty() || link.wit_package == wit_package)
1158 && link.name == link_name
1159 } else {
1160 false
1161 }
1162 } else {
1163 false
1165 }
1166 }
1167
1168 pub(crate) async fn flush(&self) {
1170 if let Err(err) = self.nats.flush().await {
1171 error!(%err, "error flushing NATS client");
1172 }
1173 }
1174}