wasmcloud_host/nats/
policy.rs

1//! Policy manager implementation that uses NATS to send policy requests
2//! to a policy server.
3
4use core::time::Duration;
5
6use std::collections::{BTreeMap, HashMap};
7use std::sync::Arc;
8
9use anyhow::Context;
10use futures::{
11    stream::{AbortHandle, Abortable},
12    StreamExt,
13};
14use tokio::spawn;
15use tokio::sync::RwLock;
16use tracing::{debug, error, instrument, trace, warn};
17use ulid::Ulid;
18use uuid::Uuid;
19use wascap::jwt;
20
21use crate::policy::{
22    ComponentInformation, HostInfo, PerformInvocationRequest, PolicyClaims, PolicyManager,
23    ProviderInformation, Request, RequestBody, RequestKey, RequestKind, Response,
24    POLICY_TYPE_VERSION,
25};
26
/// Encapsulates making requests for policy decisions, and receiving updated decisions
#[derive(Debug, Clone)]
pub struct NatsPolicyManager {
    /// NATS client used to send policy requests and to subscribe to policy changes
    nats: async_nats::Client,
    /// Information about this host, attached to every policy request
    host_info: HostInfo,
    /// Subject policy requests are sent to; when `None`, every action is permitted
    /// without contacting a policy server
    policy_topic: Option<String>,
    /// Timeout applied to each policy request
    policy_timeout: Duration,
    /// Cached policy decisions, keyed by the request they were made for
    /// (shared between clones of the manager)
    decision_cache: Arc<RwLock<HashMap<RequestKey, Response>>>,
    /// Maps the id of each answered request to its cache key, so that a decision
    /// override (which identifies its target by request id) can find the cached
    /// entry to replace (shared between clones of the manager)
    request_to_key: Arc<RwLock<HashMap<String, RequestKey>>>,
    /// An abort handle for the policy changes subscription
    pub policy_changes: AbortHandle,
}
39
40impl NatsPolicyManager {
41    /// Construct a new policy manager. Can fail if policy_changes_topic is set but we fail to subscribe to it
42    #[instrument(skip(nats))]
43    pub async fn new(
44        nats: async_nats::Client,
45        host_info: HostInfo,
46        policy_topic: Option<String>,
47        policy_timeout: Option<Duration>,
48        policy_changes_topic: Option<String>,
49    ) -> anyhow::Result<Self> {
50        const DEFAULT_POLICY_TIMEOUT: Duration = Duration::from_secs(1);
51
52        let (policy_changes_abort, policy_changes_abort_reg) = AbortHandle::new_pair();
53
54        let manager = NatsPolicyManager {
55            nats: nats.clone(),
56            host_info,
57            policy_topic,
58            policy_timeout: policy_timeout.unwrap_or(DEFAULT_POLICY_TIMEOUT),
59            decision_cache: Arc::default(),
60            request_to_key: Arc::default(),
61            policy_changes: policy_changes_abort,
62        };
63
64        if let Some(policy_changes_topic) = policy_changes_topic {
65            let policy_changes = nats
66                .subscribe(policy_changes_topic)
67                .await
68                .context("failed to subscribe to policy changes")?;
69
70            let _policy_changes = spawn({
71                let manager = Arc::new(manager.clone());
72                Abortable::new(policy_changes, policy_changes_abort_reg).for_each(move |msg| {
73                    let manager = Arc::clone(&manager);
74                    async move {
75                        if let Err(e) = manager.override_decision(msg).await {
76                            error!("failed to process policy decision override: {}", e);
77                        }
78                    }
79                })
80            });
81        }
82
83        Ok(manager)
84    }
85
86    /// Sends a policy request to the policy server and caches the response
87    #[instrument(level = "trace", skip_all)]
88    pub async fn evaluate_action(&self, request: RequestBody) -> anyhow::Result<Response> {
89        let Some(policy_topic) = self.policy_topic.clone() else {
90            // Ensure we short-circuit and allow the request if no policy topic is configured
91            return Ok(Response {
92                request_id: String::new(),
93                permitted: true,
94                message: None,
95            });
96        };
97
98        let kind = match request {
99            RequestBody::StartComponent(_) => RequestKind::StartComponent,
100            RequestBody::StartProvider(_) => RequestKind::StartProvider,
101            RequestBody::PerformInvocation(_) => RequestKind::PerformInvocation,
102            RequestBody::Unknown => RequestKind::Unknown,
103        };
104        let cache_key = (&request).into();
105        if let Some(entry) = self.decision_cache.read().await.get(&cache_key) {
106            trace!(?cache_key, ?entry, "using cached policy decision");
107            return Ok(entry.clone());
108        }
109
110        let request_id = Uuid::from_u128(Ulid::new().into()).to_string();
111        trace!(?cache_key, "requesting policy decision");
112        let payload = serde_json::to_vec(&Request {
113            request_id: request_id.clone(),
114            request,
115            kind,
116            version: POLICY_TYPE_VERSION.to_string(),
117            host: self.host_info.clone(),
118        })
119        .context("failed to serialize policy request")?;
120        let request = async_nats::Request::new()
121            .payload(payload.into())
122            .timeout(Some(self.policy_timeout));
123        let res = self
124            .nats
125            .send_request(policy_topic, request)
126            .await
127            .context("policy request failed")?;
128        let decision = serde_json::from_slice::<Response>(&res.payload)
129            .context("failed to deserialize policy response")?;
130
131        self.decision_cache
132            .write()
133            .await
134            .insert(cache_key.clone(), decision.clone()); // cache policy decision
135        self.request_to_key
136            .write()
137            .await
138            .insert(request_id, cache_key); // cache request id -> decision key
139        Ok(decision)
140    }
141
142    #[instrument(skip(self))]
143    async fn override_decision(&self, msg: async_nats::Message) -> anyhow::Result<()> {
144        let Response {
145            request_id,
146            permitted,
147            message,
148        } = serde_json::from_slice(&msg.payload)
149            .context("failed to deserialize policy decision override")?;
150
151        debug!(request_id, "received policy decision override");
152
153        let mut decision_cache = self.decision_cache.write().await;
154        let request_to_key = self.request_to_key.read().await;
155
156        if let Some(key) = request_to_key.get(&request_id) {
157            decision_cache.insert(
158                key.clone(),
159                Response {
160                    request_id: request_id.clone(),
161                    permitted,
162                    message,
163                },
164            );
165        } else {
166            warn!(
167                request_id,
168                "received policy decision override for unknown request id"
169            );
170        }
171
172        Ok(())
173    }
174}
175
176#[async_trait::async_trait]
177impl PolicyManager for NatsPolicyManager {
178    #[instrument(level = "trace", skip_all)]
179    /// Use the policy manager to evaluate whether a component may be started
180    async fn evaluate_start_component(
181        &self,
182        component_id: &str,
183        image_ref: &str,
184        max_instances: u32,
185        annotations: &BTreeMap<String, String>,
186        claims: Option<&jwt::Claims<jwt::Component>>,
187    ) -> anyhow::Result<Response> {
188        let request = ComponentInformation {
189            component_id: component_id.to_string(),
190            image_ref: image_ref.to_string(),
191            max_instances,
192            annotations: annotations.clone(),
193            claims: claims.map(PolicyClaims::from),
194        };
195        self.evaluate_action(RequestBody::StartComponent(request))
196            .await
197    }
198
199    /// Use the policy manager to evaluate whether a provider may be started
200    #[instrument(level = "trace", skip_all)]
201    async fn evaluate_start_provider(
202        &self,
203        provider_id: &str,
204        provider_ref: &str,
205        annotations: &BTreeMap<String, String>,
206        claims: Option<&jwt::Claims<jwt::CapabilityProvider>>,
207    ) -> anyhow::Result<Response> {
208        let request = ProviderInformation {
209            provider_id: provider_id.to_string(),
210            image_ref: provider_ref.to_string(),
211            annotations: annotations.clone(),
212            claims: claims.map(PolicyClaims::from),
213        };
214        self.evaluate_action(RequestBody::StartProvider(request))
215            .await
216    }
217
218    /// Use the policy manager to evaluate whether a component may be invoked
219    #[instrument(level = "trace", skip_all)]
220    async fn evaluate_perform_invocation(
221        &self,
222        component_id: &str,
223        image_ref: &str,
224        annotations: &BTreeMap<String, String>,
225        claims: Option<&jwt::Claims<jwt::Component>>,
226        interface: String,
227        function: String,
228    ) -> anyhow::Result<Response> {
229        let request = PerformInvocationRequest {
230            interface,
231            function,
232            target: ComponentInformation {
233                component_id: component_id.to_string(),
234                image_ref: image_ref.to_string(),
235                max_instances: 0,
236                annotations: annotations.clone(),
237                claims: claims.map(PolicyClaims::from),
238            },
239        };
240        self.evaluate_action(RequestBody::PerformInvocation(request))
241            .await
242    }
243}