1use std::collections::HashMap;
4
5use anyhow::Context as _;
6use serde::{Deserialize, Serialize};
7
8use tracing::{instrument, trace};
9use wascap::{jwt, prelude::ClaimsBuilder};
10
/// Serializable form of a claims record, for either a component or a capability provider.
///
/// The enum is `#[serde(untagged)]`: no discriminator is written when serializing, and when
/// deserializing serde tries the variants in declaration order, so `Component` is attempted
/// before `Provider`. Do not reorder the variants without considering that.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub(crate) enum StoredClaims {
    /// Stored claims of a component.
    Component(StoredComponentClaims),
    /// Stored claims of a capability provider.
    Provider(StoredProviderClaims),
}
18
/// Flattened, string-based representation of component claims.
///
/// The `TryFrom<Claims>` conversions in this file fill these fields from
/// `jwt::Claims<jwt::Component>`, substituting the default value for any optional metadata
/// that is absent.
#[derive(Debug, Default, Serialize, Deserialize)]
pub(crate) struct StoredComponentClaims {
    /// Call alias of the component; empty when the metadata carried none.
    call_alias: String,
    /// Issuer of the claims. Also accepted under the key `iss` when deserializing.
    #[serde(alias = "iss")]
    issuer: String,
    /// Name of the component; empty when the metadata carried none.
    name: String,
    /// Revision, stored in string form. Also accepted under the key `rev` when deserializing.
    #[serde(alias = "rev")]
    revision: String,
    /// Subject of the claims. Also accepted under the key `sub` when deserializing.
    #[serde(alias = "sub")]
    subject: String,
    /// Tags of the component. Deserialized through `deserialize_messy_vec`, which accepts
    /// either an array of strings or a single comma-separated string.
    #[serde(deserialize_with = "deserialize_messy_vec")]
    tags: Vec<String>,
    /// Version string; empty when the metadata carried none.
    version: String,
}
33
/// Flattened, string-based representation of capability provider claims.
///
/// The `TryFrom<Claims>` conversions in this file fill these fields from
/// `jwt::Claims<jwt::CapabilityProvider>`, substituting the default value for any optional
/// metadata that is absent.
#[derive(Debug, Default, Serialize, Deserialize)]
pub(crate) struct StoredProviderClaims {
    /// Issuer of the claims. Also accepted under the key `iss` when deserializing.
    #[serde(alias = "iss")]
    issuer: String,
    /// Name of the provider; empty when the metadata carried none.
    name: String,
    /// Revision, stored in string form. Also accepted under the key `rev` when deserializing.
    #[serde(alias = "rev")]
    revision: String,
    /// Subject of the claims. Also accepted under the key `sub` when deserializing.
    #[serde(alias = "sub")]
    subject: String,
    /// Version string; empty when the metadata carried none.
    version: String,
    /// Configuration schema rendered to a string. Omitted from the serialized output when
    /// `None`, and defaults to `None` when the key is missing on deserialization.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    config_schema: Option<String>,
}
47
48impl TryFrom<Claims> for StoredClaims {
49 type Error = anyhow::Error;
50
51 fn try_from(claims: Claims) -> Result<Self, Self::Error> {
52 match claims {
53 Claims::Component(jwt::Claims {
54 issuer,
55 subject,
56 metadata,
57 ..
58 }) => {
59 let jwt::Component {
60 name,
61 tags,
62 rev,
63 ver,
64 call_alias,
65 ..
66 } = metadata.context("no metadata found on component claims")?;
67 Ok(StoredClaims::Component(StoredComponentClaims {
68 call_alias: call_alias.unwrap_or_default(),
69 issuer,
70 name: name.unwrap_or_default(),
71 revision: rev.unwrap_or_default().to_string(),
72 subject,
73 tags: tags.unwrap_or_default(),
74 version: ver.unwrap_or_default(),
75 }))
76 }
77 Claims::Provider(jwt::Claims {
78 issuer,
79 subject,
80 metadata,
81 ..
82 }) => {
83 let jwt::CapabilityProvider {
84 name,
85 rev,
86 ver,
87 config_schema,
88 ..
89 } = metadata.context("no metadata found on provider claims")?;
90 Ok(StoredClaims::Provider(StoredProviderClaims {
91 issuer,
92 name: name.unwrap_or_default(),
93 revision: rev.unwrap_or_default().to_string(),
94 subject,
95 version: ver.unwrap_or_default(),
96 config_schema: config_schema.map(|schema| schema.to_string()),
97 }))
98 }
99 }
100 }
101}
102
103impl TryFrom<&Claims> for StoredClaims {
104 type Error = anyhow::Error;
105
106 fn try_from(claims: &Claims) -> Result<Self, Self::Error> {
107 match claims {
108 Claims::Component(jwt::Claims {
109 issuer,
110 subject,
111 metadata,
112 ..
113 }) => {
114 let jwt::Component {
115 name,
116 tags,
117 rev,
118 ver,
119 call_alias,
120 ..
121 } = metadata
122 .as_ref()
123 .context("no metadata found on component claims")?;
124 Ok(StoredClaims::Component(StoredComponentClaims {
125 call_alias: call_alias.clone().unwrap_or_default(),
126 issuer: issuer.clone(),
127 name: name.clone().unwrap_or_default(),
128 revision: rev.unwrap_or_default().to_string(),
129 subject: subject.clone(),
130 tags: tags.clone().unwrap_or_default(),
131 version: ver.clone().unwrap_or_default(),
132 }))
133 }
134 Claims::Provider(jwt::Claims {
135 issuer,
136 subject,
137 metadata,
138 ..
139 }) => {
140 let jwt::CapabilityProvider {
141 name,
142 rev,
143 ver,
144 config_schema,
145 ..
146 } = metadata
147 .as_ref()
148 .context("no metadata found on provider claims")?;
149 Ok(StoredClaims::Provider(StoredProviderClaims {
150 issuer: issuer.clone(),
151 name: name.clone().unwrap_or_default(),
152 revision: rev.unwrap_or_default().to_string(),
153 subject: subject.clone(),
154 version: ver.clone().unwrap_or_default(),
155 config_schema: config_schema.as_ref().map(ToString::to_string),
156 }))
157 }
158 }
159 }
160}
161
162#[allow(clippy::implicit_hasher)]
163impl From<StoredClaims> for HashMap<String, String> {
164 fn from(claims: StoredClaims) -> Self {
165 match claims {
166 StoredClaims::Component(claims) => HashMap::from([
167 ("call_alias".to_string(), claims.call_alias),
168 ("iss".to_string(), claims.issuer.clone()), ("issuer".to_string(), claims.issuer),
170 ("name".to_string(), claims.name),
171 ("rev".to_string(), claims.revision.clone()), ("revision".to_string(), claims.revision),
173 ("sub".to_string(), claims.subject.clone()), ("subject".to_string(), claims.subject),
175 ("tags".to_string(), claims.tags.join(",")),
176 ("version".to_string(), claims.version),
177 ]),
178 StoredClaims::Provider(claims) => HashMap::from([
179 ("iss".to_string(), claims.issuer.clone()), ("issuer".to_string(), claims.issuer),
181 ("name".to_string(), claims.name),
182 ("rev".to_string(), claims.revision.clone()), ("revision".to_string(), claims.revision),
184 ("sub".to_string(), claims.subject.clone()), ("subject".to_string(), claims.subject),
186 ("version".to_string(), claims.version),
187 (
188 "config_schema".to_string(),
189 claims.config_schema.unwrap_or_default(),
190 ),
191 ]),
192 }
193 }
194}
195
/// Typed `wascap` JWT claims for either a component or a capability provider.
// Both variants embed their `jwt::Claims` value inline; the lint about differing variant
// sizes is silenced here instead of boxing one of the variants.
#[allow(clippy::large_enum_variant)]
pub(crate) enum Claims {
    /// Claims whose metadata describes a component.
    Component(jwt::Claims<jwt::Component>),
    /// Claims whose metadata describes a capability provider.
    Provider(jwt::Claims<jwt::CapabilityProvider>),
}
201
202impl Claims {
203 pub(crate) fn subject(&self) -> &str {
204 match self {
205 Claims::Component(claims) => &claims.subject,
206 Claims::Provider(claims) => &claims.subject,
207 }
208 }
209}
210
211impl From<StoredClaims> for Claims {
212 fn from(claims: StoredClaims) -> Self {
213 match claims {
214 StoredClaims::Component(claims) => {
215 let name = (!claims.name.is_empty()).then_some(claims.name);
216 let rev = claims.revision.parse().ok();
217 let ver = (!claims.version.is_empty()).then_some(claims.version);
218 let tags = (!claims.tags.is_empty()).then_some(claims.tags);
219 let call_alias = (!claims.call_alias.is_empty()).then_some(claims.call_alias);
220 let metadata = jwt::Component {
221 name,
222 tags,
223 rev,
224 ver,
225 call_alias,
226 ..Default::default()
227 };
228 let claims = ClaimsBuilder::new()
229 .subject(&claims.subject)
230 .issuer(&claims.issuer)
231 .with_metadata(metadata)
232 .build();
233 Claims::Component(claims)
234 }
235 StoredClaims::Provider(claims) => {
236 let name = (!claims.name.is_empty()).then_some(claims.name);
237 let rev = claims.revision.parse().ok();
238 let ver = (!claims.version.is_empty()).then_some(claims.version);
239 let config_schema: Option<serde_json::Value> = claims
240 .config_schema
241 .and_then(|schema| serde_json::from_str(&schema).ok());
242 let metadata = jwt::CapabilityProvider {
243 name,
244 rev,
245 ver,
246 config_schema,
247 ..Default::default()
248 };
249 let claims = ClaimsBuilder::new()
250 .subject(&claims.subject)
251 .issuer(&claims.issuer)
252 .with_metadata(metadata)
253 .build();
254 Claims::Provider(claims)
255 }
256 }
257 }
258}
259
260impl super::Host {
261 #[instrument(level = "debug", skip_all)]
262 pub(crate) async fn store_claims(&self, claims: Claims) -> anyhow::Result<()> {
264 match &claims {
265 Claims::Component(claims) => {
266 self.store_component_claims(claims.clone()).await?;
267 }
268 Claims::Provider(claims) => {
269 self.store_provider_claims(claims.clone()).await?;
270 }
271 };
272 let claims: StoredClaims = claims.try_into()?;
273 let subject = match &claims {
274 StoredClaims::Component(claims) => &claims.subject,
275 StoredClaims::Provider(claims) => &claims.subject,
276 };
277 let key = format!("CLAIMS_{subject}");
278 trace!(?claims, ?key, "storing claims");
279
280 let bytes = serde_json::to_vec(&claims)
281 .context("failed to serialize claims")?
282 .into();
283 self.data_store
284 .put(&key, bytes)
285 .await
286 .context("failed to put claims")?;
287 Ok(())
288 }
289
290 #[instrument(level = "trace", skip_all)]
291 pub(crate) async fn store_component_claims(
293 &self,
294 claims: jwt::Claims<jwt::Component>,
295 ) -> anyhow::Result<()> {
296 self.component_claims
297 .write()
298 .await
299 .insert(claims.subject.clone(), claims);
300 Ok(())
301 }
302
303 #[instrument(level = "trace", skip_all)]
304 pub(crate) async fn delete_component_claims(&self, subject: &str) -> anyhow::Result<()> {
306 self.component_claims.write().await.remove(subject);
307 Ok(())
308 }
309
310 #[instrument(level = "trace", skip_all)]
311 pub(crate) async fn store_provider_claims(
313 &self,
314 claims: jwt::Claims<jwt::CapabilityProvider>,
315 ) -> anyhow::Result<()> {
316 self.provider_claims
317 .write()
318 .await
319 .insert(claims.subject.clone(), claims);
320 Ok(())
321 }
322
323 #[instrument(level = "trace", skip_all)]
324 pub(crate) async fn delete_provider_claims(&self, subject: &str) -> anyhow::Result<()> {
326 self.provider_claims.write().await.remove(subject);
327 Ok(())
328 }
329}
330
331fn deserialize_messy_vec<'de, D: serde::Deserializer<'de>>(
332 deserializer: D,
333) -> Result<Vec<String>, D::Error> {
334 MessyVec::deserialize(deserializer).map(|messy_vec| messy_vec.0)
335}
/// Newtype over `Vec<String>` with a lenient `Deserialize` impl: the value may be given
/// either as an array of strings or as a single comma-separated string.
struct MessyVec(pub Vec<String>);
338
/// Serde visitor that produces a [`MessyVec`] from a sequence or from a string.
struct MessyVecVisitor;
340
341impl<'de> serde::de::Visitor<'de> for MessyVecVisitor {
344 type Value = MessyVec;
345
346 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
347 formatter.write_str("string or array of strings")
348 }
349
350 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
351 where
352 A: serde::de::SeqAccess<'de>,
353 {
354 let mut values = Vec::new();
355
356 while let Some(value) = seq.next_element()? {
357 values.push(value);
358 }
359
360 Ok(MessyVec(values))
361 }
362
363 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
364 where
365 E: serde::de::Error,
366 {
367 Ok(MessyVec(value.split(',').map(String::from).collect()))
368 }
369}
370
371impl<'de> Deserialize<'de> for MessyVec {
372 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
373 where
374 D: serde::de::Deserializer<'de>,
375 {
376 deserializer.deserialize_any(MessyVecVisitor)
377 }
378}