wasmcloud_host/wasmbus/
claims.rs

1//! This module contains structs and logic for managing claims in the host
2
3use std::collections::HashMap;
4
5use anyhow::Context as _;
6use serde::{Deserialize, Serialize};
7
8use tracing::{instrument, trace};
9use wascap::{jwt, prelude::ClaimsBuilder};
10
11// TODO: remove StoredClaims in #1093
12#[derive(Debug, Serialize, Deserialize)]
13#[serde(untagged)]
14pub(crate) enum StoredClaims {
15    Component(StoredComponentClaims),
16    Provider(StoredProviderClaims),
17}
18
19#[derive(Debug, Default, Serialize, Deserialize)]
20pub(crate) struct StoredComponentClaims {
21    call_alias: String,
22    #[serde(alias = "iss")]
23    issuer: String,
24    name: String,
25    #[serde(alias = "rev")]
26    revision: String,
27    #[serde(alias = "sub")]
28    subject: String,
29    #[serde(deserialize_with = "deserialize_messy_vec")]
30    tags: Vec<String>,
31    version: String,
32}
33
34#[derive(Debug, Default, Serialize, Deserialize)]
35pub(crate) struct StoredProviderClaims {
36    #[serde(alias = "iss")]
37    issuer: String,
38    name: String,
39    #[serde(alias = "rev")]
40    revision: String,
41    #[serde(alias = "sub")]
42    subject: String,
43    version: String,
44    #[serde(default, skip_serializing_if = "Option::is_none")]
45    config_schema: Option<String>,
46}
47
48impl TryFrom<Claims> for StoredClaims {
49    type Error = anyhow::Error;
50
51    fn try_from(claims: Claims) -> Result<Self, Self::Error> {
52        match claims {
53            Claims::Component(jwt::Claims {
54                issuer,
55                subject,
56                metadata,
57                ..
58            }) => {
59                let jwt::Component {
60                    name,
61                    tags,
62                    rev,
63                    ver,
64                    call_alias,
65                    ..
66                } = metadata.context("no metadata found on component claims")?;
67                Ok(StoredClaims::Component(StoredComponentClaims {
68                    call_alias: call_alias.unwrap_or_default(),
69                    issuer,
70                    name: name.unwrap_or_default(),
71                    revision: rev.unwrap_or_default().to_string(),
72                    subject,
73                    tags: tags.unwrap_or_default(),
74                    version: ver.unwrap_or_default(),
75                }))
76            }
77            Claims::Provider(jwt::Claims {
78                issuer,
79                subject,
80                metadata,
81                ..
82            }) => {
83                let jwt::CapabilityProvider {
84                    name,
85                    rev,
86                    ver,
87                    config_schema,
88                    ..
89                } = metadata.context("no metadata found on provider claims")?;
90                Ok(StoredClaims::Provider(StoredProviderClaims {
91                    issuer,
92                    name: name.unwrap_or_default(),
93                    revision: rev.unwrap_or_default().to_string(),
94                    subject,
95                    version: ver.unwrap_or_default(),
96                    config_schema: config_schema.map(|schema| schema.to_string()),
97                }))
98            }
99        }
100    }
101}
102
103impl TryFrom<&Claims> for StoredClaims {
104    type Error = anyhow::Error;
105
106    fn try_from(claims: &Claims) -> Result<Self, Self::Error> {
107        match claims {
108            Claims::Component(jwt::Claims {
109                issuer,
110                subject,
111                metadata,
112                ..
113            }) => {
114                let jwt::Component {
115                    name,
116                    tags,
117                    rev,
118                    ver,
119                    call_alias,
120                    ..
121                } = metadata
122                    .as_ref()
123                    .context("no metadata found on component claims")?;
124                Ok(StoredClaims::Component(StoredComponentClaims {
125                    call_alias: call_alias.clone().unwrap_or_default(),
126                    issuer: issuer.clone(),
127                    name: name.clone().unwrap_or_default(),
128                    revision: rev.unwrap_or_default().to_string(),
129                    subject: subject.clone(),
130                    tags: tags.clone().unwrap_or_default(),
131                    version: ver.clone().unwrap_or_default(),
132                }))
133            }
134            Claims::Provider(jwt::Claims {
135                issuer,
136                subject,
137                metadata,
138                ..
139            }) => {
140                let jwt::CapabilityProvider {
141                    name,
142                    rev,
143                    ver,
144                    config_schema,
145                    ..
146                } = metadata
147                    .as_ref()
148                    .context("no metadata found on provider claims")?;
149                Ok(StoredClaims::Provider(StoredProviderClaims {
150                    issuer: issuer.clone(),
151                    name: name.clone().unwrap_or_default(),
152                    revision: rev.unwrap_or_default().to_string(),
153                    subject: subject.clone(),
154                    version: ver.clone().unwrap_or_default(),
155                    config_schema: config_schema.as_ref().map(ToString::to_string),
156                }))
157            }
158        }
159    }
160}
161
162#[allow(clippy::implicit_hasher)]
163impl From<StoredClaims> for HashMap<String, String> {
164    fn from(claims: StoredClaims) -> Self {
165        match claims {
166            StoredClaims::Component(claims) => HashMap::from([
167                ("call_alias".to_string(), claims.call_alias),
168                ("iss".to_string(), claims.issuer.clone()), // TODO: remove in #1093
169                ("issuer".to_string(), claims.issuer),
170                ("name".to_string(), claims.name),
171                ("rev".to_string(), claims.revision.clone()), // TODO: remove in #1093
172                ("revision".to_string(), claims.revision),
173                ("sub".to_string(), claims.subject.clone()), // TODO: remove in #1093
174                ("subject".to_string(), claims.subject),
175                ("tags".to_string(), claims.tags.join(",")),
176                ("version".to_string(), claims.version),
177            ]),
178            StoredClaims::Provider(claims) => HashMap::from([
179                ("iss".to_string(), claims.issuer.clone()), // TODO: remove in #1093
180                ("issuer".to_string(), claims.issuer),
181                ("name".to_string(), claims.name),
182                ("rev".to_string(), claims.revision.clone()), // TODO: remove in #1093
183                ("revision".to_string(), claims.revision),
184                ("sub".to_string(), claims.subject.clone()), // TODO: remove in #1093
185                ("subject".to_string(), claims.subject),
186                ("version".to_string(), claims.version),
187                (
188                    "config_schema".to_string(),
189                    claims.config_schema.unwrap_or_default(),
190                ),
191            ]),
192        }
193    }
194}
195
196#[allow(clippy::large_enum_variant)] // Without this clippy complains component is at least 0 bytes while provider is at least 280 bytes. That doesn't make sense
197pub(crate) enum Claims {
198    Component(jwt::Claims<jwt::Component>),
199    Provider(jwt::Claims<jwt::CapabilityProvider>),
200}
201
202impl Claims {
203    pub(crate) fn subject(&self) -> &str {
204        match self {
205            Claims::Component(claims) => &claims.subject,
206            Claims::Provider(claims) => &claims.subject,
207        }
208    }
209}
210
211impl From<StoredClaims> for Claims {
212    fn from(claims: StoredClaims) -> Self {
213        match claims {
214            StoredClaims::Component(claims) => {
215                let name = (!claims.name.is_empty()).then_some(claims.name);
216                let rev = claims.revision.parse().ok();
217                let ver = (!claims.version.is_empty()).then_some(claims.version);
218                let tags = (!claims.tags.is_empty()).then_some(claims.tags);
219                let call_alias = (!claims.call_alias.is_empty()).then_some(claims.call_alias);
220                let metadata = jwt::Component {
221                    name,
222                    tags,
223                    rev,
224                    ver,
225                    call_alias,
226                    ..Default::default()
227                };
228                let claims = ClaimsBuilder::new()
229                    .subject(&claims.subject)
230                    .issuer(&claims.issuer)
231                    .with_metadata(metadata)
232                    .build();
233                Claims::Component(claims)
234            }
235            StoredClaims::Provider(claims) => {
236                let name = (!claims.name.is_empty()).then_some(claims.name);
237                let rev = claims.revision.parse().ok();
238                let ver = (!claims.version.is_empty()).then_some(claims.version);
239                let config_schema: Option<serde_json::Value> = claims
240                    .config_schema
241                    .and_then(|schema| serde_json::from_str(&schema).ok());
242                let metadata = jwt::CapabilityProvider {
243                    name,
244                    rev,
245                    ver,
246                    config_schema,
247                    ..Default::default()
248                };
249                let claims = ClaimsBuilder::new()
250                    .subject(&claims.subject)
251                    .issuer(&claims.issuer)
252                    .with_metadata(metadata)
253                    .build();
254                Claims::Provider(claims)
255            }
256        }
257    }
258}
259
260impl super::Host {
261    #[instrument(level = "debug", skip_all)]
262    /// Store claims in the data store
263    pub(crate) async fn store_claims(&self, claims: Claims) -> anyhow::Result<()> {
264        match &claims {
265            Claims::Component(claims) => {
266                self.store_component_claims(claims.clone()).await?;
267            }
268            Claims::Provider(claims) => {
269                self.store_provider_claims(claims.clone()).await?;
270            }
271        };
272        let claims: StoredClaims = claims.try_into()?;
273        let subject = match &claims {
274            StoredClaims::Component(claims) => &claims.subject,
275            StoredClaims::Provider(claims) => &claims.subject,
276        };
277        let key = format!("CLAIMS_{subject}");
278        trace!(?claims, ?key, "storing claims");
279
280        let bytes = serde_json::to_vec(&claims)
281            .context("failed to serialize claims")?
282            .into();
283        self.data_store
284            .put(&key, bytes)
285            .await
286            .context("failed to put claims")?;
287        Ok(())
288    }
289
290    #[instrument(level = "trace", skip_all)]
291    /// Store claims in the host in-memory cache
292    pub(crate) async fn store_component_claims(
293        &self,
294        claims: jwt::Claims<jwt::Component>,
295    ) -> anyhow::Result<()> {
296        self.component_claims
297            .write()
298            .await
299            .insert(claims.subject.clone(), claims);
300        Ok(())
301    }
302
303    #[instrument(level = "trace", skip_all)]
304    /// Remove claims from the host in-memory cache
305    pub(crate) async fn delete_component_claims(&self, subject: &str) -> anyhow::Result<()> {
306        self.component_claims.write().await.remove(subject);
307        Ok(())
308    }
309
310    #[instrument(level = "trace", skip_all)]
311    /// Store claims in the host in-memory cache
312    pub(crate) async fn store_provider_claims(
313        &self,
314        claims: jwt::Claims<jwt::CapabilityProvider>,
315    ) -> anyhow::Result<()> {
316        self.provider_claims
317            .write()
318            .await
319            .insert(claims.subject.clone(), claims);
320        Ok(())
321    }
322
323    #[instrument(level = "trace", skip_all)]
324    /// Remove claims from the host in-memory cache
325    pub(crate) async fn delete_provider_claims(&self, subject: &str) -> anyhow::Result<()> {
326        self.provider_claims.write().await.remove(subject);
327        Ok(())
328    }
329}
330
331fn deserialize_messy_vec<'de, D: serde::Deserializer<'de>>(
332    deserializer: D,
333) -> Result<Vec<String>, D::Error> {
334    MessyVec::deserialize(deserializer).map(|messy_vec| messy_vec.0)
335}
/// Helper struct to deserialize either a comma-delimited string or an actual array of strings.
///
/// The wrapped vector holds the resulting list of strings.
struct MessyVec(pub Vec<String>);
338
/// Stateless serde visitor that produces a [`MessyVec`].
struct MessyVecVisitor;
340
341// Since this is "temporary" code to preserve backwards compatibility with already-serialized claims,
342// we use fully-qualified names instead of importing
343impl<'de> serde::de::Visitor<'de> for MessyVecVisitor {
344    type Value = MessyVec;
345
346    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
347        formatter.write_str("string or array of strings")
348    }
349
350    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
351    where
352        A: serde::de::SeqAccess<'de>,
353    {
354        let mut values = Vec::new();
355
356        while let Some(value) = seq.next_element()? {
357            values.push(value);
358        }
359
360        Ok(MessyVec(values))
361    }
362
363    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
364    where
365        E: serde::de::Error,
366    {
367        Ok(MessyVec(value.split(',').map(String::from).collect()))
368    }
369}
370
371impl<'de> Deserialize<'de> for MessyVec {
372    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
373    where
374        D: serde::de::Deserializer<'de>,
375    {
376        deserializer.deserialize_any(MessyVecVisitor)
377    }
378}