1use core::time::Duration;
5
6use std::collections::{BTreeMap, HashMap};
7use std::sync::Arc;
8
9use anyhow::Context;
10use futures::{
11 stream::{AbortHandle, Abortable},
12 StreamExt,
13};
14use tokio::spawn;
15use tokio::sync::RwLock;
16use tracing::{debug, error, instrument, trace, warn};
17use ulid::Ulid;
18use uuid::Uuid;
19use wascap::jwt;
20
21use crate::policy::{
22 ComponentInformation, HostInfo, PerformInvocationRequest, PolicyClaims, PolicyManager,
23 ProviderInformation, Request, RequestBody, RequestKey, RequestKind, Response,
24 POLICY_TYPE_VERSION,
25};
26
27#[derive(Debug, Clone)]
29pub struct NatsPolicyManager {
30 nats: async_nats::Client,
31 host_info: HostInfo,
32 policy_topic: Option<String>,
33 policy_timeout: Duration,
34 decision_cache: Arc<RwLock<HashMap<RequestKey, Response>>>,
35 request_to_key: Arc<RwLock<HashMap<String, RequestKey>>>,
36 pub policy_changes: AbortHandle,
38}
39
40impl NatsPolicyManager {
41 #[instrument(skip(nats))]
43 pub async fn new(
44 nats: async_nats::Client,
45 host_info: HostInfo,
46 policy_topic: Option<String>,
47 policy_timeout: Option<Duration>,
48 policy_changes_topic: Option<String>,
49 ) -> anyhow::Result<Self> {
50 const DEFAULT_POLICY_TIMEOUT: Duration = Duration::from_secs(1);
51
52 let (policy_changes_abort, policy_changes_abort_reg) = AbortHandle::new_pair();
53
54 let manager = NatsPolicyManager {
55 nats: nats.clone(),
56 host_info,
57 policy_topic,
58 policy_timeout: policy_timeout.unwrap_or(DEFAULT_POLICY_TIMEOUT),
59 decision_cache: Arc::default(),
60 request_to_key: Arc::default(),
61 policy_changes: policy_changes_abort,
62 };
63
64 if let Some(policy_changes_topic) = policy_changes_topic {
65 let policy_changes = nats
66 .subscribe(policy_changes_topic)
67 .await
68 .context("failed to subscribe to policy changes")?;
69
70 let _policy_changes = spawn({
71 let manager = Arc::new(manager.clone());
72 Abortable::new(policy_changes, policy_changes_abort_reg).for_each(move |msg| {
73 let manager = Arc::clone(&manager);
74 async move {
75 if let Err(e) = manager.override_decision(msg).await {
76 error!("failed to process policy decision override: {}", e);
77 }
78 }
79 })
80 });
81 }
82
83 Ok(manager)
84 }
85
86 #[instrument(level = "trace", skip_all)]
88 pub async fn evaluate_action(&self, request: RequestBody) -> anyhow::Result<Response> {
89 let Some(policy_topic) = self.policy_topic.clone() else {
90 return Ok(Response {
92 request_id: String::new(),
93 permitted: true,
94 message: None,
95 });
96 };
97
98 let kind = match request {
99 RequestBody::StartComponent(_) => RequestKind::StartComponent,
100 RequestBody::StartProvider(_) => RequestKind::StartProvider,
101 RequestBody::PerformInvocation(_) => RequestKind::PerformInvocation,
102 RequestBody::Unknown => RequestKind::Unknown,
103 };
104 let cache_key = (&request).into();
105 if let Some(entry) = self.decision_cache.read().await.get(&cache_key) {
106 trace!(?cache_key, ?entry, "using cached policy decision");
107 return Ok(entry.clone());
108 }
109
110 let request_id = Uuid::from_u128(Ulid::new().into()).to_string();
111 trace!(?cache_key, "requesting policy decision");
112 let payload = serde_json::to_vec(&Request {
113 request_id: request_id.clone(),
114 request,
115 kind,
116 version: POLICY_TYPE_VERSION.to_string(),
117 host: self.host_info.clone(),
118 })
119 .context("failed to serialize policy request")?;
120 let request = async_nats::Request::new()
121 .payload(payload.into())
122 .timeout(Some(self.policy_timeout));
123 let res = self
124 .nats
125 .send_request(policy_topic, request)
126 .await
127 .context("policy request failed")?;
128 let decision = serde_json::from_slice::<Response>(&res.payload)
129 .context("failed to deserialize policy response")?;
130
131 self.decision_cache
132 .write()
133 .await
134 .insert(cache_key.clone(), decision.clone()); self.request_to_key
136 .write()
137 .await
138 .insert(request_id, cache_key); Ok(decision)
140 }
141
142 #[instrument(skip(self))]
143 async fn override_decision(&self, msg: async_nats::Message) -> anyhow::Result<()> {
144 let Response {
145 request_id,
146 permitted,
147 message,
148 } = serde_json::from_slice(&msg.payload)
149 .context("failed to deserialize policy decision override")?;
150
151 debug!(request_id, "received policy decision override");
152
153 let mut decision_cache = self.decision_cache.write().await;
154 let request_to_key = self.request_to_key.read().await;
155
156 if let Some(key) = request_to_key.get(&request_id) {
157 decision_cache.insert(
158 key.clone(),
159 Response {
160 request_id: request_id.clone(),
161 permitted,
162 message,
163 },
164 );
165 } else {
166 warn!(
167 request_id,
168 "received policy decision override for unknown request id"
169 );
170 }
171
172 Ok(())
173 }
174}
175
176#[async_trait::async_trait]
177impl PolicyManager for NatsPolicyManager {
178 #[instrument(level = "trace", skip_all)]
179 async fn evaluate_start_component(
181 &self,
182 component_id: &str,
183 image_ref: &str,
184 max_instances: u32,
185 annotations: &BTreeMap<String, String>,
186 claims: Option<&jwt::Claims<jwt::Component>>,
187 ) -> anyhow::Result<Response> {
188 let request = ComponentInformation {
189 component_id: component_id.to_string(),
190 image_ref: image_ref.to_string(),
191 max_instances,
192 annotations: annotations.clone(),
193 claims: claims.map(PolicyClaims::from),
194 };
195 self.evaluate_action(RequestBody::StartComponent(request))
196 .await
197 }
198
199 #[instrument(level = "trace", skip_all)]
201 async fn evaluate_start_provider(
202 &self,
203 provider_id: &str,
204 provider_ref: &str,
205 annotations: &BTreeMap<String, String>,
206 claims: Option<&jwt::Claims<jwt::CapabilityProvider>>,
207 ) -> anyhow::Result<Response> {
208 let request = ProviderInformation {
209 provider_id: provider_id.to_string(),
210 image_ref: provider_ref.to_string(),
211 annotations: annotations.clone(),
212 claims: claims.map(PolicyClaims::from),
213 };
214 self.evaluate_action(RequestBody::StartProvider(request))
215 .await
216 }
217
218 #[instrument(level = "trace", skip_all)]
220 async fn evaluate_perform_invocation(
221 &self,
222 component_id: &str,
223 image_ref: &str,
224 annotations: &BTreeMap<String, String>,
225 claims: Option<&jwt::Claims<jwt::Component>>,
226 interface: String,
227 function: String,
228 ) -> anyhow::Result<Response> {
229 let request = PerformInvocationRequest {
230 interface,
231 function,
232 target: ComponentInformation {
233 component_id: component_id.to_string(),
234 image_ref: image_ref.to_string(),
235 max_instances: 0,
236 annotations: annotations.clone(),
237 claims: claims.map(PolicyClaims::from),
238 },
239 };
240 self.evaluate_action(RequestBody::PerformInvocation(request))
241 .await
242 }
243}