wascap/
wasm.rs

1//! Functions for extracting and embedding claims within a WebAssembly module
2
3use crate::{
4    errors::{self, ErrorKind},
5    jwt::{Claims, Component, Token, MIN_WASCAP_INTERNAL_REVISION},
6    Result,
7};
8use data_encoding::HEXUPPER;
9use nkeys::KeyPair;
10use ring::digest::{Context, Digest, SHA256};
11use std::{
12    io::Read,
13    mem,
14    time::{SystemTime, UNIX_EPOCH},
15};
16use wasm_encoder::ComponentSectionId;
17use wasm_encoder::Encode;
18use wasm_encoder::Section;
19use wasmparser::Parser;
20
/// Number of seconds in one day; used to turn a day offset into a seconds offset.
const SECS_PER_DAY: u64 = 86400;
/// Legacy name of the custom section holding the embedded JWT.
const SECTION_JWT: &str = "jwt"; // Versions of wascap prior to 0.9 used this section
/// Name of the custom section under which `embed_claims` writes the signed JWT.
const SECTION_WC_JWT: &str = "wasmcloud_jwt";
24
25/// Extracts a set of claims from the raw bytes of a WebAssembly module. In the case where no
26/// JWT is discovered in the module, this function returns `None`.
27/// If there is a token in the file with a valid hash, then you will get a `Token` back
28/// containing both the raw JWT and the decoded claims.
29///
30/// # Errors
31/// Will return an error if hash computation fails or it can't read the JWT from inside
32/// a section's data, etc
33pub fn extract_claims(contents: impl AsRef<[u8]>) -> Result<Option<Token<Component>>> {
34    use wasmparser::Payload::{ComponentSection, CustomSection, End, ModuleSection};
35
36    let target_hash = compute_hash(&strip_custom_section(contents.as_ref())?)?;
37    let parser = wasmparser::Parser::new(0);
38    let mut depth = 0;
39    for payload in parser.parse_all(contents.as_ref()) {
40        let payload = payload?;
41        match payload {
42            ModuleSection { .. } | ComponentSection { .. } => depth += 1,
43            End { .. } => depth -= 1,
44            CustomSection(c)
45                if (c.name() == SECTION_JWT) || (c.name() == SECTION_WC_JWT) && depth == 0 =>
46            {
47                let jwt = String::from_utf8(c.data().to_vec())?;
48                let claims: Claims<Component> = Claims::decode(&jwt)?;
49                let Some(ref meta) = claims.metadata else {
50                    return Err(errors::new(ErrorKind::InvalidAlgorithm));
51                };
52                if meta.module_hash != target_hash
53                    && claims.wascap_revision.unwrap_or_default() >= MIN_WASCAP_INTERNAL_REVISION
54                {
55                    return Err(errors::new(ErrorKind::InvalidModuleHash));
56                }
57                return Ok(Some(Token { jwt, claims }));
58            }
59            _ => {}
60        }
61    }
62    Ok(None)
63}
64
65/// This function will embed a set of claims inside the bytecode of a WebAssembly module. The claims
66/// are converted into a JWT and signed using the provided `KeyPair`.
67/// According to the WebAssembly [custom section](https://webassembly.github.io/spec/core/appendix/custom.html)
68/// specification, arbitrary sets of bytes can be stored in a WebAssembly module without impacting
69/// parsers or interpreters. Returns a vector of bytes representing the new WebAssembly module which can
70/// be saved to a `.wasm` file
71#[allow(clippy::missing_errors_doc)] // TODO: document errors
72pub fn embed_claims(
73    orig_bytecode: &[u8],
74    claims: &Claims<Component>,
75    kp: &KeyPair,
76) -> Result<Vec<u8>> {
77    let mut bytes = orig_bytecode.to_vec();
78    eprintln!("BEFORE STRIP!");
79    bytes = strip_custom_section(&bytes)?;
80
81    eprintln!("BEFORE META STUFF!");
82    let hash = compute_hash(&bytes)?;
83    let mut claims = (*claims).clone();
84    let meta = claims.metadata.map(|md| Component {
85        module_hash: hash,
86        ..md
87    });
88    claims.metadata = meta;
89
90    eprintln!("BEFORE ENCODE!");
91    let encoded = claims.encode(kp)?;
92    let encvec = encoded.as_bytes().to_vec();
93    wasm_gen::write_custom_section(&mut bytes, SECTION_WC_JWT, &encvec);
94
95    Ok(bytes)
96}
97
98/// Sign a buffer containing bytes for a WebAssembly component
99/// with provided claims
100#[allow(clippy::too_many_arguments)]
101#[allow(clippy::missing_errors_doc)] // TODO: document
102pub fn sign_buffer_with_claims(
103    name: String,
104    buf: impl AsRef<[u8]>,
105    mod_kp: &KeyPair,
106    acct_kp: &KeyPair,
107    expires_in_days: Option<u64>,
108    not_before_days: Option<u64>,
109    tags: Vec<String>,
110    provider: bool,
111    rev: Option<i32>,
112    ver: Option<String>,
113    call_alias: Option<String>,
114) -> Result<Vec<u8>> {
115    let claims = Claims::<Component>::with_dates(
116        name,
117        acct_kp.public_key(),
118        mod_kp.public_key(),
119        Some(tags),
120        days_from_now_to_jwt_time(not_before_days),
121        days_from_now_to_jwt_time(expires_in_days),
122        provider,
123        rev,
124        ver,
125        call_alias,
126    );
127    embed_claims(buf.as_ref(), &claims, acct_kp)
128}
129
/// Rebuilds the WebAssembly binary in `buf` without its JWT custom sections.
///
/// The input is walked payload by payload and every section is copied verbatim into a new
/// buffer, except custom sections named `SECTION_JWT` or `SECTION_WC_JWT`, which are dropped.
/// There is no depth check here, so those sections are dropped at every nesting level.
/// Nested modules and components are rebuilt through a stack of partially built outputs and
/// re-encoded into their parent once their `End` payload is reached.
///
/// # Errors
/// Propagates any parse error from `wasmparser`, and returns an `ErrorKind::IO` error
/// (`UnexpectedEof`, "Invalid section range") when a section's range ends past `buf`.
pub(crate) fn strip_custom_section(buf: &[u8]) -> Result<Vec<u8>> {
    use wasmparser::Payload::{ComponentSection, CustomSection, End, ModuleSection, Version};

    // `output` is the binary currently being rebuilt; `stack` holds the partially built
    // outputs of the enclosing binaries while a nested module/component is processed.
    let mut output: Vec<u8> = Vec::new();
    let mut stack = Vec::new();
    for payload in Parser::new(0).parse_all(buf) {
        let payload = payload?;
        // First pass over the payload: structural bookkeeping (headers and nesting).
        match payload {
            Version { encoding, .. } => {
                // Every (nested) binary starts with the header matching its encoding.
                output.extend_from_slice(match encoding {
                    wasmparser::Encoding::Component => &wasm_encoder::Component::HEADER,
                    wasmparser::Encoding::Module => &wasm_encoder::Module::HEADER,
                });
            }
            ModuleSection { .. } | ComponentSection { .. } => {
                // Entering a nested binary: park the parent's output and start a fresh one.
                // `continue` skips the verbatim copy below; the nested binary is rebuilt
                // from its own payloads and re-encoded at its `End`.
                stack.push(mem::take(&mut output));
                continue;
            }
            End { .. } => {
                // No parent on the stack means the outermost binary just ended: done.
                let Some(mut parent) = stack.pop() else { break };
                // Decide which section id to use from the header written at `Version`.
                if output.starts_with(&wasm_encoder::Component::HEADER) {
                    parent.push(ComponentSectionId::Component as u8);
                    // NOTE(review): presumably encodes the bytes with a length prefix —
                    // confirm against `wasm_encoder::Encode`.
                    output.encode(&mut parent);
                } else {
                    parent.push(ComponentSectionId::CoreModule as u8);
                    output.encode(&mut parent);
                }
                // Resume building the parent.
                output = parent;
            }
            _ => {}
        }

        // Second pass over the payload: copy section bytes, dropping the JWT sections.
        match payload {
            CustomSection(c) if (c.name() == SECTION_JWT) || (c.name() == SECTION_WC_JWT) => {
                // Intentionally not copied: this is the section being stripped.
            }
            _ => {
                // Payloads for which `as_section` yields nothing are not copied here.
                if let Some((id, range)) = payload.as_section() {
                    // Guard the slice below so an out-of-bounds range becomes an error
                    // instead of a panic.
                    if range.end <= buf.len() {
                        wasm_encoder::RawSection {
                            id,
                            data: &buf[range],
                        }
                        .append_to(&mut output);
                    } else {
                        return Err(errors::new(ErrorKind::IO(std::io::Error::new(
                            std::io::ErrorKind::UnexpectedEof,
                            "Invalid section range",
                        ))));
                    }
                }
            }
        }
    }

    Ok(output)
}
187
/// Returns the time elapsed between the Unix epoch and the current system time.
///
/// # Panics
/// Panics if the system clock reports a time earlier than the Unix epoch.
fn since_the_epoch() -> std::time::Duration {
    UNIX_EPOCH
        .elapsed()
        .expect("A timey wimey problem has occurred!")
}
194
195#[must_use]
196pub fn days_from_now_to_jwt_time(stamp: Option<u64>) -> Option<u64> {
197    stamp.map(|e| since_the_epoch().as_secs() + e * SECS_PER_DAY)
198}
199
200fn sha256_digest<R: Read>(mut reader: R) -> Result<Digest> {
201    let mut context = Context::new(&SHA256);
202    let mut buffer = [0; 1024];
203
204    loop {
205        let count = reader.read(&mut buffer)?;
206        if count == 0 {
207            break;
208        }
209        context.update(&buffer[..count]);
210    }
211
212    Ok(context.finish())
213}
214
215fn compute_hash(modbytes: &[u8]) -> Result<String> {
216    let digest = sha256_digest(modbytes)?;
217    Ok(HEXUPPER.encode(digest.as_ref()))
218}
219
#[cfg(test)]
mod test {
    use std::fs::File;

    use anyhow::Context as _;
    use data_encoding::BASE64;

    use super::*;
    use crate::jwt::{Claims, Component, WASCAP_INTERNAL_REVISION};

    const WASM_BASE64: &str =
        "AGFzbQEAAAAADAZkeWxpbmuAgMACAAGKgICAAAJgAn9/AX9gAAACwYCAgAAEA2VudgptZW1vcnlCYXNl\
         A38AA2VudgZtZW1vcnkCAIACA2VudgV0YWJsZQFwAAADZW52CXRhYmxlQmFzZQN/AAOEgICAAAMAAQEGi\
         4CAgAACfwFBAAt/AUEACwejgICAAAIKX3RyYW5zZm9ybQAAEl9fcG9zdF9pbnN0YW50aWF0ZQACCYGAgI\
         AAAArpgICAAAPBgICAAAECfwJ/IABBAEoEQEEAIQIFIAAPCwNAIAEgAmoiAywAAEHpAEYEQCADQfkAOgA\
         ACyACQQFqIgIgAEcNAAsgAAsLg4CAgAAAAQuVgICAAAACQCMAJAIjAkGAgMACaiQDEAELCw==";

    /// Reads a fixture file into memory, panicking if it cannot be opened or read.
    fn read_fixture(path: &str) -> Vec<u8> {
        let mut file = File::open(path).unwrap();
        let mut bytes = Vec::new();
        file.read_to_end(&mut bytes).unwrap();
        bytes
    }

    /// Decodes the inline base64 test module.
    fn inline_module() -> Vec<u8> {
        // Serialize and de-serialize this because the module loader adds bytes to
        // the above base64 encoded module.
        BASE64.decode(WASM_BASE64.as_bytes()).unwrap()
    }

    /// Builds the claims shared by all tests, issued by `kp`.
    fn sample_claims(kp: &KeyPair, call_alias: Option<String>) -> Claims<Component> {
        Claims {
            metadata: Some(Component::new(
                "testing".to_string(),
                Some(vec![]),
                false,
                Some(1),
                Some(String::new()),
                call_alias,
            )),
            expires: None,
            id: nuid::next().to_string(),
            issued_at: 0,
            issuer: kp.public_key(),
            subject: "test.wasm".to_string(),
            not_before: None,
            wascap_revision: Some(WASCAP_INTERNAL_REVISION),
        }
    }

    #[test]
    #[ignore]
    // NOTE: This test is ignored because we cannot use the newer workspace versions
    // for wasmparser here as components with incorrect versions trigger failures on
    // newer versions, and our test components happen to be malformed.
    //
    // This should *not* actually happen in practice but is a result of an invalid fixture.
    //
    fn strip_custom() -> anyhow::Result<()> {
        let component = read_fixture("./fixtures/guest.component.wasm");

        let kp = KeyPair::new_account();
        let claims = sample_claims(&kp, None);
        let signed = embed_claims(&component, &claims, &kp).context("failed to embed claims")?;

        super::strip_custom_section(&signed).context("failed to strip custom section")?;
        Ok(())
    }

    #[test]
    fn legacy_modules_still_extract() {
        // Ensure that we can still extract claims from legacy (signed prior to 0.9.0) modules without
        // a hash violation error
        let legacy = read_fixture("./fixtures/logger.wasm");

        let extracted = extract_claims(&legacy).unwrap();
        assert!(extracted.is_some());
    }

    #[test]
    #[ignore]
    // see: `strip_custom` as to why this is ignored
    fn decode_wasi_preview() -> anyhow::Result<()> {
        let component = read_fixture("./fixtures/guest.component.wasm");

        let kp = KeyPair::new_account();
        let claims = sample_claims(&kp, None);
        let signed = embed_claims(&component, &claims, &kp).context("failed to embed claims")?;

        let extracted = extract_claims(signed).context("failed to extract claims")?;
        let Some(token) = extracted else {
            unreachable!()
        };
        assert_eq!(claims.issuer, token.claims.issuer);

        Ok(())
    }

    #[test]
    fn claims_roundtrip() {
        let module = inline_module();

        let kp = KeyPair::new_account();
        let claims = sample_claims(&kp, None);
        let signed = embed_claims(&module, &claims, &kp).unwrap();

        let Some(token) = extract_claims(signed).unwrap() else {
            unreachable!()
        };
        assert_eq!(claims.issuer, token.claims.issuer);
    }

    #[test]
    fn claims_doublesign_roundtrip() {
        // Verify that we can sign a previously signed module by stripping the old
        // custom JWT and maintaining valid hashes
        let module = inline_module();

        let kp = KeyPair::new_account();
        let claims = sample_claims(&kp, None);
        let expected_issuer = claims.issuer.clone();
        let signed_once = embed_claims(&module, &claims, &kp).unwrap();

        let altered_claims = Claims {
            subject: "altered.wasm".to_string(),
            ..claims
        };

        let signed_twice = embed_claims(&signed_once, &altered_claims, &kp).unwrap();
        let Some(token) = extract_claims(signed_twice).unwrap() else {
            unreachable!()
        };
        assert_eq!(expected_issuer, token.claims.issuer);
        assert_eq!(token.claims.subject, "altered.wasm");
    }

    #[test]
    fn claims_logging_roundtrip() {
        let module = inline_module();

        let kp = KeyPair::new_account();
        let claims = sample_claims(&kp, Some("somealias".to_string()));
        let signed = embed_claims(&module, &claims, &kp).unwrap();

        let Some(token) = extract_claims(signed).unwrap() else {
            unreachable!()
        };
        assert_eq!(claims.issuer, token.claims.issuer);
        assert_eq!(claims.subject, token.claims.subject);

        let expected_meta = claims.metadata.as_ref().unwrap();
        let actual_meta = token.claims.metadata.as_ref().unwrap();

        assert_eq!(expected_meta.call_alias, actual_meta.call_alias);
    }
}