1use core::time::Duration;
2
3use std::collections::{BTreeMap, HashMap};
4use std::hash::Hash;
5use std::sync::Arc;
6use std::time::{SystemTime, UNIX_EPOCH};
7
8use anyhow::Context;
9use futures::{
10 stream::{AbortHandle, Abortable},
11 StreamExt,
12};
13use serde::{Deserialize, Serialize};
14use tokio::spawn;
15use tokio::sync::RwLock;
16use tracing::{debug, error, instrument, trace, warn};
17use ulid::Ulid;
18use uuid::Uuid;
19use wascap::jwt;
20
/// Schema version placed in the `version` field of every policy request sent by this host.
const POLICY_TYPE_VERSION: &str = "v1";
24
25#[derive(Clone, Debug, Default, Eq, PartialEq, Serialize, Hash)]
26pub struct PolicyClaims {
28 #[serde(rename = "publicKey")]
30 pub public_key: String,
31 pub issuer: String,
33 #[serde(rename = "issuedAt")]
35 pub issued_at: String,
36 #[serde(rename = "expiresAt")]
38 pub expires_at: Option<u64>,
39 pub expired: bool,
41}
42
43#[derive(Clone, Debug, Default, Eq, PartialEq, Serialize, Hash)]
44pub struct ComponentInformation {
46 #[serde(rename = "componentId")]
48 pub component_id: String,
49 #[serde(rename = "imageRef")]
51 pub image_ref: String,
52 #[serde(rename = "maxInstances")]
54 pub max_instances: u32,
55 pub annotations: BTreeMap<String, String>,
57 pub claims: Option<PolicyClaims>,
59}
60
61#[derive(Clone, Debug, Default, Eq, PartialEq, Serialize, Hash)]
62pub struct ProviderInformation {
64 #[serde(rename = "providerId")]
66 pub provider_id: String,
67 #[serde(rename = "imageRef")]
69 pub image_ref: String,
70 pub annotations: BTreeMap<String, String>,
72 pub claims: Option<PolicyClaims>,
74}
75
76#[derive(Clone, Debug, Default, Eq, PartialEq, Serialize, Hash)]
77pub struct PerformInvocationRequest {
79 pub interface: String,
81 pub function: String,
83 pub target: ComponentInformation,
85}
86
87#[derive(Clone, Debug, Serialize)]
89pub struct HostInfo {
90 #[serde(rename = "publicKey")]
92 pub public_key: String,
93 #[serde(rename = "lattice")]
95 pub lattice: String,
96 pub labels: HashMap<String, String>,
98}
99
100#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Hash)]
102pub enum RequestKind {
103 #[serde(rename = "performInvocation")]
105 PerformInvocation,
106 #[serde(rename = "startComponent")]
108 StartComponent,
109 #[serde(rename = "startProvider")]
111 StartProvider,
112 #[serde(rename = "unknown")]
114 Unknown,
115}
116
117#[derive(Clone, Debug, Eq, PartialEq, Serialize, Hash)]
118#[serde(untagged)]
119pub enum RequestBody {
121 PerformInvocation(PerformInvocationRequest),
123 StartComponent(ComponentInformation),
125 StartProvider(ProviderInformation),
127 Unknown,
129}
130
131impl From<&RequestBody> for RequestKey {
132 fn from(val: &RequestBody) -> RequestKey {
133 match val {
134 RequestBody::StartComponent(ref req) => RequestKey {
135 kind: RequestKind::StartComponent,
136 cache_key: format!("{}_{}", req.component_id, req.image_ref),
137 },
138 RequestBody::StartProvider(ref req) => RequestKey {
139 kind: RequestKind::StartProvider,
140 cache_key: format!("{}_{}", req.provider_id, req.image_ref),
141 },
142 RequestBody::PerformInvocation(ref req) => RequestKey {
143 kind: RequestKind::PerformInvocation,
144 cache_key: format!(
145 "{}_{}_{}_{}",
146 req.target.component_id, req.target.image_ref, req.interface, req.function
147 ),
148 },
149 RequestBody::Unknown => RequestKey {
150 kind: RequestKind::Unknown,
151 cache_key: String::new(),
152 },
153 }
154 }
155}
156
157#[derive(Serialize)]
159struct Request {
160 #[serde(rename = "requestId")]
162 #[allow(clippy::struct_field_names)]
163 request_id: String,
164 kind: RequestKind,
166 version: String,
168 request: RequestBody,
170 host: HostInfo,
172}
173
174#[derive(Clone, Debug, Hash, Eq, PartialEq)]
175struct RequestKey {
176 kind: RequestKind,
178 cache_key: String,
183}
184
185#[derive(Clone, Debug, Deserialize)]
187pub struct Response {
188 #[serde(rename = "requestId")]
190 pub request_id: String,
191 pub permitted: bool,
193 #[serde(skip_serializing_if = "Option::is_none")]
195 pub message: Option<String>,
196}
197
/// Returns `true` once a JWT `exp` timestamp (seconds since the Unix epoch) has been reached.
///
/// RFC 7519 §4.1.4 only treats a token as valid *before* its expiration time, so the claims
/// count as expired when the current time is at or after `expires`. This matches wascap's
/// own expiration check.
///
/// If the system clock reports a time before the Unix epoch, the current time is treated as
/// `0` instead of panicking the host.
fn is_expired(expires: u64) -> bool {
    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs());
    now >= expires
}
205
206impl From<&jwt::Claims<jwt::Component>> for PolicyClaims {
207 fn from(claims: &jwt::Claims<jwt::Component>) -> Self {
208 PolicyClaims {
209 public_key: claims.subject.to_string(),
210 issuer: claims.issuer.to_string(),
211 issued_at: claims.issued_at.to_string(),
212 expires_at: claims.expires,
213 expired: claims.expires.is_some_and(is_expired),
214 }
215 }
216}
217
218impl From<&jwt::Claims<jwt::CapabilityProvider>> for PolicyClaims {
219 fn from(claims: &jwt::Claims<jwt::CapabilityProvider>) -> Self {
220 PolicyClaims {
221 public_key: claims.subject.to_string(),
222 issuer: claims.issuer.to_string(),
223 issued_at: claims.issued_at.to_string(),
224 expires_at: claims.expires,
225 expired: claims.expires.is_some_and(is_expired),
226 }
227 }
228}
229
230#[derive(Debug)]
232pub struct Manager {
233 nats: async_nats::Client,
234 host_info: HostInfo,
235 policy_topic: Option<String>,
236 policy_timeout: Duration,
237 decision_cache: Arc<RwLock<HashMap<RequestKey, Response>>>,
238 request_to_key: Arc<RwLock<HashMap<String, RequestKey>>>,
239 pub policy_changes: AbortHandle,
241}
242
243impl Manager {
244 #[instrument(skip(nats))]
246 pub async fn new(
247 nats: async_nats::Client,
248 host_info: HostInfo,
249 policy_topic: Option<String>,
250 policy_timeout: Option<Duration>,
251 policy_changes_topic: Option<String>,
252 ) -> anyhow::Result<Arc<Self>> {
253 const DEFAULT_POLICY_TIMEOUT: Duration = Duration::from_secs(1);
254
255 let (policy_changes_abort, policy_changes_abort_reg) = AbortHandle::new_pair();
256
257 let manager = Manager {
258 nats: nats.clone(),
259 host_info,
260 policy_topic,
261 policy_timeout: policy_timeout.unwrap_or(DEFAULT_POLICY_TIMEOUT),
262 decision_cache: Arc::default(),
263 request_to_key: Arc::default(),
264 policy_changes: policy_changes_abort,
265 };
266 let manager = Arc::new(manager);
267
268 if let Some(policy_changes_topic) = policy_changes_topic {
269 let policy_changes = nats
270 .subscribe(policy_changes_topic)
271 .await
272 .context("failed to subscribe to policy changes")?;
273
274 let _policy_changes = spawn({
275 let manager = Arc::clone(&manager);
276 Abortable::new(policy_changes, policy_changes_abort_reg).for_each(move |msg| {
277 let manager = Arc::clone(&manager);
278 async move {
279 if let Err(e) = manager.override_decision(msg).await {
280 error!("failed to process policy decision override: {}", e);
281 }
282 }
283 })
284 });
285 }
286
287 Ok(manager)
288 }
289
290 #[instrument(level = "trace", skip_all)]
291 pub async fn evaluate_start_component(
293 &self,
294 component_id: impl AsRef<str>,
295 image_ref: impl AsRef<str>,
296 max_instances: u32,
297 annotations: &BTreeMap<String, String>,
298 claims: Option<&jwt::Claims<jwt::Component>>,
299 ) -> anyhow::Result<Response> {
300 let request = ComponentInformation {
301 component_id: component_id.as_ref().to_string(),
302 image_ref: image_ref.as_ref().to_string(),
303 max_instances,
304 annotations: annotations.clone(),
305 claims: claims.map(PolicyClaims::from),
306 };
307 self.evaluate_action(RequestBody::StartComponent(request))
308 .await
309 }
310
311 #[instrument(level = "trace", skip_all)]
313 pub async fn evaluate_start_provider(
314 &self,
315 provider_id: impl AsRef<str>,
316 provider_ref: impl AsRef<str>,
317 annotations: &BTreeMap<String, String>,
318 claims: Option<&jwt::Claims<jwt::CapabilityProvider>>,
319 ) -> anyhow::Result<Response> {
320 let request = ProviderInformation {
321 provider_id: provider_id.as_ref().to_string(),
322 image_ref: provider_ref.as_ref().to_string(),
323 annotations: annotations.clone(),
324 claims: claims.map(PolicyClaims::from),
325 };
326 self.evaluate_action(RequestBody::StartProvider(request))
327 .await
328 }
329
330 #[instrument(level = "trace", skip_all)]
332 pub async fn evaluate_perform_invocation(
333 &self,
334 component_id: impl AsRef<str>,
335 image_ref: impl AsRef<str>,
336 annotations: &BTreeMap<String, String>,
337 claims: Option<&jwt::Claims<jwt::Component>>,
338 interface: String,
339 function: String,
340 ) -> anyhow::Result<Response> {
341 let request = PerformInvocationRequest {
342 interface,
343 function,
344 target: ComponentInformation {
345 component_id: component_id.as_ref().to_string(),
346 image_ref: image_ref.as_ref().to_string(),
347 max_instances: 0,
348 annotations: annotations.clone(),
349 claims: claims.map(PolicyClaims::from),
350 },
351 };
352 self.evaluate_action(RequestBody::PerformInvocation(request))
353 .await
354 }
355
356 #[instrument(level = "trace", skip_all)]
358 pub async fn evaluate_action(&self, request: RequestBody) -> anyhow::Result<Response> {
359 let Some(policy_topic) = self.policy_topic.clone() else {
360 return Ok(Response {
362 request_id: String::new(),
363 permitted: true,
364 message: None,
365 });
366 };
367
368 let kind = match request {
369 RequestBody::StartComponent(_) => RequestKind::StartComponent,
370 RequestBody::StartProvider(_) => RequestKind::StartProvider,
371 RequestBody::PerformInvocation(_) => RequestKind::PerformInvocation,
372 RequestBody::Unknown => RequestKind::Unknown,
373 };
374 let cache_key = (&request).into();
375 if let Some(entry) = self.decision_cache.read().await.get(&cache_key) {
376 trace!(?cache_key, ?entry, "using cached policy decision");
377 return Ok(entry.clone());
378 }
379
380 let request_id = Uuid::from_u128(Ulid::new().into()).to_string();
381 trace!(?cache_key, "requesting policy decision");
382 let payload = serde_json::to_vec(&Request {
383 request_id: request_id.clone(),
384 request,
385 kind,
386 version: POLICY_TYPE_VERSION.to_string(),
387 host: self.host_info.clone(),
388 })
389 .context("failed to serialize policy request")?;
390 let request = async_nats::Request::new()
391 .payload(payload.into())
392 .timeout(Some(self.policy_timeout));
393 let res = self
394 .nats
395 .send_request(policy_topic, request)
396 .await
397 .context("policy request failed")?;
398 let decision = serde_json::from_slice::<Response>(&res.payload)
399 .context("failed to deserialize policy response")?;
400
401 self.decision_cache
402 .write()
403 .await
404 .insert(cache_key.clone(), decision.clone()); self.request_to_key
406 .write()
407 .await
408 .insert(request_id, cache_key); Ok(decision)
410 }
411
412 #[instrument(skip(self))]
413 async fn override_decision(&self, msg: async_nats::Message) -> anyhow::Result<()> {
414 let Response {
415 request_id,
416 permitted,
417 message,
418 } = serde_json::from_slice(&msg.payload)
419 .context("failed to deserialize policy decision override")?;
420
421 debug!(request_id, "received policy decision override");
422
423 let mut decision_cache = self.decision_cache.write().await;
424 let request_to_key = self.request_to_key.read().await;
425
426 if let Some(key) = request_to_key.get(&request_id) {
427 decision_cache.insert(
428 key.clone(),
429 Response {
430 request_id: request_id.clone(),
431 permitted,
432 message,
433 },
434 );
435 } else {
436 warn!(
437 request_id,
438 "received policy decision override for unknown request id"
439 );
440 }
441
442 Ok(())
443 }
444}