wasmcloud_host/
policy.rs

1use core::time::Duration;
2
3use std::collections::{BTreeMap, HashMap};
4use std::hash::Hash;
5use std::sync::Arc;
6use std::time::{SystemTime, UNIX_EPOCH};
7
8use anyhow::Context;
9use futures::{
10    stream::{AbortHandle, Abortable},
11    StreamExt,
12};
13use serde::{Deserialize, Serialize};
14use tokio::spawn;
15use tokio::sync::RwLock;
16use tracing::{debug, error, instrument, trace, warn};
17use ulid::Ulid;
18use uuid::Uuid;
19use wascap::jwt;
20
/// Schema version stamped on every outgoing policy request.
///
/// All request types share `v1` until the schema changes; at that point the version can be
/// tracked per request type instead.
const POLICY_TYPE_VERSION: &str = "v1";
24
25#[derive(Clone, Debug, Default, Eq, PartialEq, Serialize, Hash)]
26/// Claims associated with a policy request, if embedded inside the component or provider
27pub struct PolicyClaims {
28    /// The public key of the component
29    #[serde(rename = "publicKey")]
30    pub public_key: String,
31    /// The issuer key of the component
32    pub issuer: String,
33    /// The time the claims were signed
34    #[serde(rename = "issuedAt")]
35    pub issued_at: String,
36    /// The time the claims expire, if any
37    #[serde(rename = "expiresAt")]
38    pub expires_at: Option<u64>,
39    /// Whether the claims have expired already. This is included in case the policy server is fulfilled by an component, which cannot access the system clock
40    pub expired: bool,
41}
42
43#[derive(Clone, Debug, Default, Eq, PartialEq, Serialize, Hash)]
44/// Relevant policy information for evaluating a component
45pub struct ComponentInformation {
46    /// The unique identifier of the component
47    #[serde(rename = "componentId")]
48    pub component_id: String,
49    /// The image reference of the component
50    #[serde(rename = "imageRef")]
51    pub image_ref: String,
52    /// The requested maximum number of concurrent instances for this component
53    #[serde(rename = "maxInstances")]
54    pub max_instances: u32,
55    /// Annotations associated with the component
56    pub annotations: BTreeMap<String, String>,
57    /// Claims, if embedded, within the component
58    pub claims: Option<PolicyClaims>,
59}
60
61#[derive(Clone, Debug, Default, Eq, PartialEq, Serialize, Hash)]
62/// Relevant policy information for evaluating a provider
63pub struct ProviderInformation {
64    /// The unique identifier of the provider
65    #[serde(rename = "providerId")]
66    pub provider_id: String,
67    /// The image reference of the provider
68    #[serde(rename = "imageRef")]
69    pub image_ref: String,
70    /// Annotations associated with the provider
71    pub annotations: BTreeMap<String, String>,
72    /// Claims, if embedded, within the provider
73    pub claims: Option<PolicyClaims>,
74}
75
76#[derive(Clone, Debug, Default, Eq, PartialEq, Serialize, Hash)]
77/// A request to invoke a component function
78pub struct PerformInvocationRequest {
79    /// The interface of the invocation
80    pub interface: String,
81    /// The function of the invocation
82    pub function: String,
83    /// Target of the invocation
84    pub target: ComponentInformation,
85}
86
87/// Relevant information about the host that is receiving the invocation, or starting the component or provider
88#[derive(Clone, Debug, Serialize)]
89pub struct HostInfo {
90    /// The public key ID of the host
91    #[serde(rename = "publicKey")]
92    pub public_key: String,
93    /// The name of the lattice the host is running in
94    #[serde(rename = "lattice")]
95    pub lattice: String,
96    /// The labels associated with the host
97    pub labels: HashMap<String, String>,
98}
99
100/// The action being requested
101#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Hash)]
102pub enum RequestKind {
103    /// The host is checking whether it may invoke the target component
104    #[serde(rename = "performInvocation")]
105    PerformInvocation,
106    /// The host is checking whether it may start the target component
107    #[serde(rename = "startComponent")]
108    StartComponent,
109    /// The host is checking whether it may start the target provider
110    #[serde(rename = "startProvider")]
111    StartProvider,
112    /// An unknown or unsupported request type
113    #[serde(rename = "unknown")]
114    Unknown,
115}
116
117#[derive(Clone, Debug, Eq, PartialEq, Serialize, Hash)]
118#[serde(untagged)]
119/// The body of a policy request, typed by the request kind
120pub enum RequestBody {
121    /// A request to invoke a function on a component
122    PerformInvocation(PerformInvocationRequest),
123    /// A request to start a component on a host
124    StartComponent(ComponentInformation),
125    /// A request to start a provider on a host
126    StartProvider(ProviderInformation),
127    /// Request body has an unknown type
128    Unknown,
129}
130
131impl From<&RequestBody> for RequestKey {
132    fn from(val: &RequestBody) -> RequestKey {
133        match val {
134            RequestBody::StartComponent(ref req) => RequestKey {
135                kind: RequestKind::StartComponent,
136                cache_key: format!("{}_{}", req.component_id, req.image_ref),
137            },
138            RequestBody::StartProvider(ref req) => RequestKey {
139                kind: RequestKind::StartProvider,
140                cache_key: format!("{}_{}", req.provider_id, req.image_ref),
141            },
142            RequestBody::PerformInvocation(ref req) => RequestKey {
143                kind: RequestKind::PerformInvocation,
144                cache_key: format!(
145                    "{}_{}_{}_{}",
146                    req.target.component_id, req.target.image_ref, req.interface, req.function
147                ),
148            },
149            RequestBody::Unknown => RequestKey {
150                kind: RequestKind::Unknown,
151                cache_key: String::new(),
152            },
153        }
154    }
155}
156
157/// A request for a policy decision
158#[derive(Serialize)]
159struct Request {
160    /// A unique request id. This value is returned in the response
161    #[serde(rename = "requestId")]
162    #[allow(clippy::struct_field_names)]
163    request_id: String,
164    /// The kind of policy request being made
165    kind: RequestKind,
166    /// The version of the policy request body
167    version: String,
168    /// The policy request body
169    request: RequestBody,
170    /// Information about the host making the request
171    host: HostInfo,
172}
173
174#[derive(Clone, Debug, Hash, Eq, PartialEq)]
175struct RequestKey {
176    /// The kind of request being made
177    kind: RequestKind,
178    /// Information about this request combined to form a unique string.
179    /// For example, a StartComponent request can be uniquely cached based
180    /// on the component_id and image_ref, so this cache_key is a concatenation
181    /// of those values
182    cache_key: String,
183}
184
185/// A policy decision response
186#[derive(Clone, Debug, Deserialize)]
187pub struct Response {
188    /// The request id copied from the request
189    #[serde(rename = "requestId")]
190    pub request_id: String,
191    /// Whether the request is permitted
192    pub permitted: bool,
193    /// An optional error explaining why the request was denied. Suitable for logging
194    #[serde(skip_serializing_if = "Option::is_none")]
195    pub message: Option<String>,
196}
197
/// Returns `true` when the current system time is strictly later than `expires`.
///
/// `expires` is a timestamp in whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the Unix epoch, no `u64` expiry can lie in the
/// past, so the function returns `false` instead of panicking.
fn is_expired(expires: u64) -> bool {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .is_ok_and(|since_epoch| since_epoch.as_secs() > expires)
}
205
206impl From<&jwt::Claims<jwt::Component>> for PolicyClaims {
207    fn from(claims: &jwt::Claims<jwt::Component>) -> Self {
208        PolicyClaims {
209            public_key: claims.subject.to_string(),
210            issuer: claims.issuer.to_string(),
211            issued_at: claims.issued_at.to_string(),
212            expires_at: claims.expires,
213            expired: claims.expires.is_some_and(is_expired),
214        }
215    }
216}
217
218impl From<&jwt::Claims<jwt::CapabilityProvider>> for PolicyClaims {
219    fn from(claims: &jwt::Claims<jwt::CapabilityProvider>) -> Self {
220        PolicyClaims {
221            public_key: claims.subject.to_string(),
222            issuer: claims.issuer.to_string(),
223            issued_at: claims.issued_at.to_string(),
224            expires_at: claims.expires,
225            expired: claims.expires.is_some_and(is_expired),
226        }
227    }
228}
229
230/// Encapsulates making requests for policy decisions, and receiving updated decisions
231#[derive(Debug)]
232pub struct Manager {
233    nats: async_nats::Client,
234    host_info: HostInfo,
235    policy_topic: Option<String>,
236    policy_timeout: Duration,
237    decision_cache: Arc<RwLock<HashMap<RequestKey, Response>>>,
238    request_to_key: Arc<RwLock<HashMap<String, RequestKey>>>,
239    /// An abort handle for the policy changes subscription
240    pub policy_changes: AbortHandle,
241}
242
243impl Manager {
244    /// Construct a new policy manager. Can fail if policy_changes_topic is set but we fail to subscribe to it
245    #[instrument(skip(nats))]
246    pub async fn new(
247        nats: async_nats::Client,
248        host_info: HostInfo,
249        policy_topic: Option<String>,
250        policy_timeout: Option<Duration>,
251        policy_changes_topic: Option<String>,
252    ) -> anyhow::Result<Arc<Self>> {
253        const DEFAULT_POLICY_TIMEOUT: Duration = Duration::from_secs(1);
254
255        let (policy_changes_abort, policy_changes_abort_reg) = AbortHandle::new_pair();
256
257        let manager = Manager {
258            nats: nats.clone(),
259            host_info,
260            policy_topic,
261            policy_timeout: policy_timeout.unwrap_or(DEFAULT_POLICY_TIMEOUT),
262            decision_cache: Arc::default(),
263            request_to_key: Arc::default(),
264            policy_changes: policy_changes_abort,
265        };
266        let manager = Arc::new(manager);
267
268        if let Some(policy_changes_topic) = policy_changes_topic {
269            let policy_changes = nats
270                .subscribe(policy_changes_topic)
271                .await
272                .context("failed to subscribe to policy changes")?;
273
274            let _policy_changes = spawn({
275                let manager = Arc::clone(&manager);
276                Abortable::new(policy_changes, policy_changes_abort_reg).for_each(move |msg| {
277                    let manager = Arc::clone(&manager);
278                    async move {
279                        if let Err(e) = manager.override_decision(msg).await {
280                            error!("failed to process policy decision override: {}", e);
281                        }
282                    }
283                })
284            });
285        }
286
287        Ok(manager)
288    }
289
290    #[instrument(level = "trace", skip_all)]
291    /// Use the policy manager to evaluate whether a component may be started
292    pub async fn evaluate_start_component(
293        &self,
294        component_id: impl AsRef<str>,
295        image_ref: impl AsRef<str>,
296        max_instances: u32,
297        annotations: &BTreeMap<String, String>,
298        claims: Option<&jwt::Claims<jwt::Component>>,
299    ) -> anyhow::Result<Response> {
300        let request = ComponentInformation {
301            component_id: component_id.as_ref().to_string(),
302            image_ref: image_ref.as_ref().to_string(),
303            max_instances,
304            annotations: annotations.clone(),
305            claims: claims.map(PolicyClaims::from),
306        };
307        self.evaluate_action(RequestBody::StartComponent(request))
308            .await
309    }
310
311    /// Use the policy manager to evaluate whether a provider may be started
312    #[instrument(level = "trace", skip_all)]
313    pub async fn evaluate_start_provider(
314        &self,
315        provider_id: impl AsRef<str>,
316        provider_ref: impl AsRef<str>,
317        annotations: &BTreeMap<String, String>,
318        claims: Option<&jwt::Claims<jwt::CapabilityProvider>>,
319    ) -> anyhow::Result<Response> {
320        let request = ProviderInformation {
321            provider_id: provider_id.as_ref().to_string(),
322            image_ref: provider_ref.as_ref().to_string(),
323            annotations: annotations.clone(),
324            claims: claims.map(PolicyClaims::from),
325        };
326        self.evaluate_action(RequestBody::StartProvider(request))
327            .await
328    }
329
330    /// Use the policy manager to evaluate whether a component may be invoked
331    #[instrument(level = "trace", skip_all)]
332    pub async fn evaluate_perform_invocation(
333        &self,
334        component_id: impl AsRef<str>,
335        image_ref: impl AsRef<str>,
336        annotations: &BTreeMap<String, String>,
337        claims: Option<&jwt::Claims<jwt::Component>>,
338        interface: String,
339        function: String,
340    ) -> anyhow::Result<Response> {
341        let request = PerformInvocationRequest {
342            interface,
343            function,
344            target: ComponentInformation {
345                component_id: component_id.as_ref().to_string(),
346                image_ref: image_ref.as_ref().to_string(),
347                max_instances: 0,
348                annotations: annotations.clone(),
349                claims: claims.map(PolicyClaims::from),
350            },
351        };
352        self.evaluate_action(RequestBody::PerformInvocation(request))
353            .await
354    }
355
356    /// Sends a policy request to the policy server and caches the response
357    #[instrument(level = "trace", skip_all)]
358    pub async fn evaluate_action(&self, request: RequestBody) -> anyhow::Result<Response> {
359        let Some(policy_topic) = self.policy_topic.clone() else {
360            // Ensure we short-circuit and allow the request if no policy topic is configured
361            return Ok(Response {
362                request_id: String::new(),
363                permitted: true,
364                message: None,
365            });
366        };
367
368        let kind = match request {
369            RequestBody::StartComponent(_) => RequestKind::StartComponent,
370            RequestBody::StartProvider(_) => RequestKind::StartProvider,
371            RequestBody::PerformInvocation(_) => RequestKind::PerformInvocation,
372            RequestBody::Unknown => RequestKind::Unknown,
373        };
374        let cache_key = (&request).into();
375        if let Some(entry) = self.decision_cache.read().await.get(&cache_key) {
376            trace!(?cache_key, ?entry, "using cached policy decision");
377            return Ok(entry.clone());
378        }
379
380        let request_id = Uuid::from_u128(Ulid::new().into()).to_string();
381        trace!(?cache_key, "requesting policy decision");
382        let payload = serde_json::to_vec(&Request {
383            request_id: request_id.clone(),
384            request,
385            kind,
386            version: POLICY_TYPE_VERSION.to_string(),
387            host: self.host_info.clone(),
388        })
389        .context("failed to serialize policy request")?;
390        let request = async_nats::Request::new()
391            .payload(payload.into())
392            .timeout(Some(self.policy_timeout));
393        let res = self
394            .nats
395            .send_request(policy_topic, request)
396            .await
397            .context("policy request failed")?;
398        let decision = serde_json::from_slice::<Response>(&res.payload)
399            .context("failed to deserialize policy response")?;
400
401        self.decision_cache
402            .write()
403            .await
404            .insert(cache_key.clone(), decision.clone()); // cache policy decision
405        self.request_to_key
406            .write()
407            .await
408            .insert(request_id, cache_key); // cache request id -> decision key
409        Ok(decision)
410    }
411
412    #[instrument(skip(self))]
413    async fn override_decision(&self, msg: async_nats::Message) -> anyhow::Result<()> {
414        let Response {
415            request_id,
416            permitted,
417            message,
418        } = serde_json::from_slice(&msg.payload)
419            .context("failed to deserialize policy decision override")?;
420
421        debug!(request_id, "received policy decision override");
422
423        let mut decision_cache = self.decision_cache.write().await;
424        let request_to_key = self.request_to_key.read().await;
425
426        if let Some(key) = request_to_key.get(&request_id) {
427            decision_cache.insert(
428                key.clone(),
429                Response {
430                    request_id: request_id.clone(),
431                    permitted,
432                    message,
433                },
434            );
435        } else {
436            warn!(
437                request_id,
438                "received policy decision override for unknown request id"
439            );
440        }
441
442        Ok(())
443    }
444}