wascap/
wasm.rs

1//! Functions for extracting and embedding claims within a WebAssembly module
2
3use crate::{
4    errors::{self, ErrorKind},
5    jwt::{Claims, Component, Token, MIN_WASCAP_INTERNAL_REVISION},
6    Result,
7};
8use data_encoding::HEXUPPER;
9use nkeys::KeyPair;
10use ring::digest::{Context, Digest, SHA256};
11use std::{
12    io::Read,
13    mem,
14    time::{SystemTime, UNIX_EPOCH},
15};
16use wasm_encoder::ComponentSectionId;
17use wasm_encoder::Encode;
18use wasm_encoder::Section;
19use wasmparser::Parser;
/// Number of seconds in one day; used to turn day offsets into JWT timestamps.
const SECS_PER_DAY: u64 = 86400;
/// Custom section name in which versions of wascap prior to 0.9 stored the JWT.
/// Still recognized when extracting and stripped when re-signing.
const SECTION_JWT: &str = "jwt";
/// Custom section name into which `embed_claims` writes the signed JWT.
const SECTION_WC_JWT: &str = "wasmcloud_jwt";
23
24/// Extracts a set of claims from the raw bytes of a WebAssembly module. In the case where no
25/// JWT is discovered in the module, this function returns `None`.
26/// If there is a token in the file with a valid hash, then you will get a `Token` back
27/// containing both the raw JWT and the decoded claims.
28///
29/// # Errors
30/// Will return an error if hash computation fails or it can't read the JWT from inside
31/// a section's data, etc
32pub fn extract_claims(contents: impl AsRef<[u8]>) -> Result<Option<Token<Component>>> {
33    use wasmparser::Payload::{ComponentSection, CustomSection, End, ModuleSection};
34
35    let target_hash = compute_hash(&strip_custom_section(contents.as_ref())?)?;
36    let parser = wasmparser::Parser::new(0);
37    let mut depth = 0;
38    for payload in parser.parse_all(contents.as_ref()) {
39        let payload = payload?;
40        match payload {
41            ModuleSection { .. } | ComponentSection { .. } => depth += 1,
42            End { .. } => depth -= 1,
43            CustomSection(c)
44                if (c.name() == SECTION_JWT) || (c.name() == SECTION_WC_JWT) && depth == 0 =>
45            {
46                let jwt = String::from_utf8(c.data().to_vec())?;
47                let claims: Claims<Component> = Claims::decode(&jwt)?;
48                let Some(ref meta) = claims.metadata else {
49                    return Err(errors::new(ErrorKind::InvalidAlgorithm));
50                };
51                if meta.module_hash != target_hash
52                    && claims.wascap_revision.unwrap_or_default() >= MIN_WASCAP_INTERNAL_REVISION
53                {
54                    return Err(errors::new(ErrorKind::InvalidModuleHash));
55                }
56                return Ok(Some(Token { jwt, claims }));
57            }
58            _ => {}
59        }
60    }
61    Ok(None)
62}
63
64/// This function will embed a set of claims inside the bytecode of a WebAssembly module. The claims
65/// are converted into a JWT and signed using the provided `KeyPair`.
66/// According to the WebAssembly [custom section](https://webassembly.github.io/spec/core/appendix/custom.html)
67/// specification, arbitrary sets of bytes can be stored in a WebAssembly module without impacting
68/// parsers or interpreters. Returns a vector of bytes representing the new WebAssembly module which can
69/// be saved to a `.wasm` file
70#[allow(clippy::missing_errors_doc)] // TODO: document errors
71pub fn embed_claims(
72    orig_bytecode: &[u8],
73    claims: &Claims<Component>,
74    kp: &KeyPair,
75) -> Result<Vec<u8>> {
76    let mut bytes = orig_bytecode.to_vec();
77    bytes = strip_custom_section(&bytes)?;
78
79    let hash = compute_hash(&bytes)?;
80    let mut claims = (*claims).clone();
81    let meta = claims.metadata.map(|md| Component {
82        module_hash: hash,
83        ..md
84    });
85    claims.metadata = meta;
86
87    let encoded = claims.encode(kp)?;
88    let encvec = encoded.as_bytes().to_vec();
89    wasm_gen::write_custom_section(&mut bytes, SECTION_WC_JWT, &encvec);
90
91    Ok(bytes)
92}
93
94/// Sign a buffer containing bytes for a WebAssembly component
95/// with provided claims
96#[allow(clippy::too_many_arguments)]
97#[allow(clippy::missing_errors_doc)] // TODO: document
98pub fn sign_buffer_with_claims(
99    name: String,
100    buf: impl AsRef<[u8]>,
101    mod_kp: &KeyPair,
102    acct_kp: &KeyPair,
103    expires_in_days: Option<u64>,
104    not_before_days: Option<u64>,
105    tags: Vec<String>,
106    provider: bool,
107    rev: Option<i32>,
108    ver: Option<String>,
109    call_alias: Option<String>,
110) -> Result<Vec<u8>> {
111    let claims = Claims::<Component>::with_dates(
112        name,
113        acct_kp.public_key(),
114        mod_kp.public_key(),
115        Some(tags),
116        days_from_now_to_jwt_time(not_before_days),
117        days_from_now_to_jwt_time(expires_in_days),
118        provider,
119        rev,
120        ver,
121        call_alias,
122    );
123    embed_claims(buf.as_ref(), &claims, acct_kp)
124}
125
126pub(crate) fn strip_custom_section(buf: &[u8]) -> Result<Vec<u8>> {
127    use wasmparser::Payload::{ComponentSection, CustomSection, End, ModuleSection, Version};
128
129    let mut output: Vec<u8> = Vec::new();
130    let mut stack = Vec::new();
131    for payload in Parser::new(0).parse_all(buf) {
132        let payload = payload?;
133        match payload {
134            Version { encoding, .. } => {
135                output.extend_from_slice(match encoding {
136                    wasmparser::Encoding::Component => &wasm_encoder::Component::HEADER,
137                    wasmparser::Encoding::Module => &wasm_encoder::Module::HEADER,
138                });
139            }
140            ModuleSection { .. } | ComponentSection { .. } => {
141                stack.push(mem::take(&mut output));
142                continue;
143            }
144            End { .. } => {
145                let Some(mut parent) = stack.pop() else { break };
146                if output.starts_with(&wasm_encoder::Component::HEADER) {
147                    parent.push(ComponentSectionId::Component as u8);
148                    output.encode(&mut parent);
149                } else {
150                    parent.push(ComponentSectionId::CoreModule as u8);
151                    output.encode(&mut parent);
152                }
153                output = parent;
154            }
155            _ => {}
156        }
157
158        match payload {
159            CustomSection(c) if (c.name() == SECTION_JWT) || (c.name() == SECTION_WC_JWT) => {
160                // skip
161            }
162            _ => {
163                if let Some((id, range)) = payload.as_section() {
164                    if range.end <= buf.len() {
165                        wasm_encoder::RawSection {
166                            id,
167                            data: &buf[range],
168                        }
169                        .append_to(&mut output);
170                    } else {
171                        return Err(errors::new(ErrorKind::IO(std::io::Error::new(
172                            std::io::ErrorKind::UnexpectedEof,
173                            "Invalid section range",
174                        ))));
175                    }
176                }
177            }
178        }
179    }
180
181    Ok(output)
182}
183
/// Time elapsed between the Unix epoch and now, according to the system clock.
///
/// # Panics
/// Panics if the system clock reports a time earlier than the Unix epoch.
fn since_the_epoch() -> std::time::Duration {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("A timey wimey problem has occurred!")
}
190
191#[must_use]
192pub fn days_from_now_to_jwt_time(stamp: Option<u64>) -> Option<u64> {
193    stamp.map(|e| since_the_epoch().as_secs() + e * SECS_PER_DAY)
194}
195
196fn sha256_digest<R: Read>(mut reader: R) -> Result<Digest> {
197    let mut context = Context::new(&SHA256);
198    let mut buffer = [0; 1024];
199
200    loop {
201        let count = reader.read(&mut buffer)?;
202        if count == 0 {
203            break;
204        }
205        context.update(&buffer[..count]);
206    }
207
208    Ok(context.finish())
209}
210
211fn compute_hash(modbytes: &[u8]) -> Result<String> {
212    let digest = sha256_digest(modbytes)?;
213    Ok(HEXUPPER.encode(digest.as_ref()))
214}
215
#[cfg(test)]
mod test {
    use super::*;
    use crate::jwt::{Claims, Component, WASCAP_INTERNAL_REVISION};
    use data_encoding::BASE64;

    const WASM_BASE64: &str =
        "AGFzbQEAAAAADAZkeWxpbmuAgMACAAGKgICAAAJgAn9/AX9gAAACwYCAgAAEA2VudgptZW1vcnlCYXNl\
         A38AA2VudgZtZW1vcnkCAIACA2VudgV0YWJsZQFwAAADZW52CXRhYmxlQmFzZQN/AAOEgICAAAMAAQEGi\
         4CAgAACfwFBAAt/AUEACwejgICAAAIKX3RyYW5zZm9ybQAAEl9fcG9zdF9pbnN0YW50aWF0ZQACCYGAgI\
         AAAArpgICAAAPBgICAAAECfwJ/IABBAEoEQEEAIQIFIAAPCwNAIAEgAmoiAywAAEHpAEYEQCADQfkAOgA\
         ACyACQQFqIgIgAEcNAAsgAAsLg4CAgAAAAQuVgICAAAACQCMAJAIjAkGAgMACaiQDEAELCw==";

    /// Reads a fixture file (path relative to the crate root), naming the path on failure.
    fn read_fixture(path: &str) -> Vec<u8> {
        std::fs::read(path).unwrap_or_else(|e| panic!("failed to read fixture {path}: {e}"))
    }

    /// Builds the claims shared by every test: issued by `kp`, subject `test.wasm`.
    fn sample_claims(kp: &KeyPair, call_alias: Option<String>) -> Claims<Component> {
        Claims {
            metadata: Some(Component::new(
                "testing".to_string(),
                Some(vec![]),
                false,
                Some(1),
                Some(String::new()),
                call_alias,
            )),
            expires: None,
            id: nuid::next().to_string(),
            issued_at: 0,
            issuer: kp.public_key(),
            subject: "test.wasm".to_string(),
            not_before: None,
            wascap_revision: Some(WASCAP_INTERNAL_REVISION),
        }
    }

    #[test]
    fn strip_custom() {
        let buffer = read_fixture("./fixtures/guest.component.wasm");

        let kp = KeyPair::new_account();
        let claims = sample_claims(&kp, None);
        let modified_bytecode = embed_claims(&buffer, &claims, &kp).unwrap();

        let stripped = super::strip_custom_section(&modified_bytecode).unwrap();
        // Once stripped, no JWT should remain to be extracted.
        assert!(extract_claims(&stripped).unwrap().is_none());
    }

    #[test]
    fn legacy_modules_still_extract() {
        // Ensure that we can still extract claims from legacy (signed prior to 0.9.0) modules
        // without a hash violation error.
        let buffer = read_fixture("./fixtures/logger.wasm");

        let t = extract_claims(&buffer).unwrap();
        assert!(t.is_some());
    }

    #[test]
    fn decode_wasi_preview() {
        let buffer = read_fixture("./fixtures/guest.component.wasm");

        let kp = KeyPair::new_account();
        let claims = sample_claims(&kp, None);
        let modified_bytecode = embed_claims(&buffer, &claims, &kp).unwrap();

        let token = extract_claims(modified_bytecode)
            .unwrap()
            .expect("embedded claims should be extractable");
        assert_eq!(claims.issuer, token.claims.issuer);
    }

    #[test]
    fn claims_roundtrip() {
        // Serialize and de-serialize this because the module loader adds bytes to
        // the above base64 encoded module.
        let dec_module = BASE64.decode(WASM_BASE64.as_bytes()).unwrap();

        let kp = KeyPair::new_account();
        let claims = sample_claims(&kp, None);
        let modified_bytecode = embed_claims(&dec_module, &claims, &kp).unwrap();

        let token = extract_claims(modified_bytecode)
            .unwrap()
            .expect("embedded claims should be extractable");
        assert_eq!(claims.issuer, token.claims.issuer);
    }

    #[test]
    fn claims_doublesign_roundtrip() {
        // Verify that we can sign a previously signed module by stripping the old
        // custom JWT and maintaining valid hashes.
        let dec_module = BASE64.decode(WASM_BASE64.as_bytes()).unwrap();

        let kp = KeyPair::new_account();
        let claims = sample_claims(&kp, None);
        let modified_bytecode = embed_claims(&dec_module, &claims, &kp).unwrap();

        let new_claims = Claims {
            subject: "altered.wasm".to_string(),
            ..claims.clone()
        };

        let modified_bytecode2 = embed_claims(&modified_bytecode, &new_claims, &kp).unwrap();
        let token = extract_claims(modified_bytecode2)
            .unwrap()
            .expect("re-signed claims should be extractable");
        assert_eq!(claims.issuer, token.claims.issuer);
        assert_eq!(token.claims.subject, "altered.wasm");
    }

    #[test]
    fn claims_logging_roundtrip() {
        // Serialize and de-serialize this because the module loader adds bytes to
        // the above base64 encoded module.
        let dec_module = BASE64.decode(WASM_BASE64.as_bytes()).unwrap();

        let kp = KeyPair::new_account();
        let claims = sample_claims(&kp, Some("somealias".to_string()));
        let modified_bytecode = embed_claims(&dec_module, &claims, &kp).unwrap();

        let token = extract_claims(modified_bytecode)
            .unwrap()
            .expect("embedded claims should be extractable");
        assert_eq!(claims.issuer, token.claims.issuer);
        assert_eq!(claims.subject, token.claims.subject);

        let claims_met = claims.metadata.as_ref().unwrap();
        let token_met = token.claims.metadata.as_ref().unwrap();

        assert_eq!(claims_met.call_alias, token_met.call_alias);
    }
}