wrpc_runtime_wasmtime/
lib.rs

1#![allow(clippy::type_complexity)] // TODO: https://github.com/bytecodealliance/wrpc/issues/2
2
3use core::any::Any;
4use core::borrow::Borrow;
5use core::fmt;
6use core::future::Future;
7use core::iter::zip;
8use core::pin::pin;
9use core::time::Duration;
10
11use std::collections::{BTreeMap, HashMap};
12use std::sync::Arc;
13
14use anyhow::{anyhow, bail, Context as _};
15use bytes::{Bytes, BytesMut};
16use futures::future::try_join_all;
17use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt as _};
18use tokio_util::codec::Encoder;
19use tracing::{debug, instrument, trace, warn};
20use uuid::Uuid;
21use wasmtime::component::{
22    types, Func, Resource, ResourceAny, ResourceTable, ResourceType, Type, Val,
23};
24use wasmtime::{AsContextMut, Engine};
25use wrpc_transport::Invoke;
26
27use crate::bindings::rpc::context::Context;
28use crate::bindings::rpc::error::Error;
29use crate::bindings::rpc::transport::{IncomingChannel, Invocation, OutgoingChannel};
30
31pub mod bindings;
32mod codec;
33mod polyfill;
34pub mod rpc;
35mod serve;
36
37pub use codec::*;
38pub use polyfill::*;
39pub use serve::*;
40
/// Maps a wasmtime function name to the name used for it over RPC.
///
/// At most one leading `[constructor]`, `[static]` or `[method]` marker is removed; a name
/// without any of these markers is returned unchanged.
///
/// The marker is matched textually because [`types::ComponentFunc`] does not include the kind
/// information and we want to avoid (re-)parsing the WIT here.
fn rpc_func_name(name: &str) -> &str {
    // Checked in this order; only the first marker that matches is stripped.
    const KIND_MARKERS: [&str; 3] = ["[constructor]", "[static]", "[method]"];
    KIND_MARKERS
        .iter()
        .find_map(|marker| name.strip_prefix(*marker))
        .unwrap_or(name)
}
55
56fn rpc_result_type<T: Borrow<Type>>(
57    host_resources: &HashMap<Box<str>, HashMap<Box<str>, (ResourceType, ResourceType)>>,
58    results_ty: impl IntoIterator<Item = T>,
59) -> Option<Option<Type>> {
60    let rpc_err_ty = host_resources
61        .get("wrpc:rpc/error@0.1.0")
62        .and_then(|instance| instance.get("error"));
63    let mut results_ty = results_ty.into_iter();
64    match (
65        rpc_err_ty,
66        results_ty.next().as_ref().map(Borrow::borrow),
67        results_ty.next(),
68    ) {
69        (Some((guest_rpc_err_ty, host_rpc_err_ty)), Some(Type::Result(result_ty)), None)
70            if *host_rpc_err_ty == ResourceType::host::<Error>()
71                && result_ty.err() == Some(Type::Own(*guest_rpc_err_ty)) =>
72        {
73            Some(result_ty.ok())
74        }
75        _ => None,
76    }
77}
78
/// Newtype around a [`Bytes`] buffer.
// NOTE(review): judging by the name, the bytes are presumably an opaque identifier of a resource
// owned by the remote peer; confirm against the codec before relying on this.
pub struct RemoteResource(pub Bytes);
80
/// A table of shared resources exported by the component
///
/// Entries are keyed by [`Uuid`] and hold the corresponding [`ResourceAny`] handle.
#[derive(Debug, Default)]
pub struct SharedResourceTable(HashMap<Uuid, ResourceAny>);
84
/// Host-side state required to perform wRPC invocations using the [Invoke] implementation `T`.
///
/// Implementors supply the invocation context, the client, the table of shared exported
/// resources and, optionally, an invocation timeout.
pub trait WrpcCtx<T: Invoke>: Send {
    /// Returns context to use for invocation
    fn context(&self) -> T::Context;

    /// Returns an [Invoke] implementation used to satisfy polyfilled imports
    fn client(&self) -> &T;

    /// Returns a table of shared exported resources
    fn shared_resources(&mut self) -> &mut SharedResourceTable;

    /// Optional invocation timeout, component will trap if invocation is not finished within the
    /// returned [Duration]. If this method returns [None], then no timeout will be used.
    ///
    /// The default implementation returns [None].
    fn timeout(&self) -> Option<Duration> {
        None
    }
}
101
/// Borrowed view bundling a [`WrpcCtx`] implementation with a [`ResourceTable`].
///
/// This is the value returned by [`WrpcView::wrpc`]; both borrows share the lifetime `'a`.
pub struct WrpcCtxView<'a, T: Invoke> {
    /// The wRPC context implementation, borrowed as a trait object
    pub ctx: &'a mut dyn WrpcCtx<T>,
    /// The resource table in which wRPC host resources are stored
    pub table: &'a mut ResourceTable,
}
106
/// Implemented by types that can hand out a [`WrpcCtxView`], giving access to their wRPC
/// context and resource table.
pub trait WrpcView: Send {
    /// The [Invoke] implementation used by this view
    type Invoke: Invoke;

    /// Returns a mutable view of the wRPC context and the resource table
    fn wrpc(&mut self) -> WrpcCtxView<'_, Self::Invoke>;
}
112
113impl<T: WrpcView> WrpcView for &mut T {
114    type Invoke = T::Invoke;
115
116    fn wrpc(&mut self) -> WrpcCtxView<'_, Self::Invoke> {
117        T::wrpc(self)
118    }
119}
120
121pub trait WrpcViewExt: WrpcView {
122    fn push_invocation(
123        &mut self,
124        invocation: impl Future<
125                Output = anyhow::Result<(
126                    <Self::Invoke as Invoke>::Outgoing,
127                    <Self::Invoke as Invoke>::Incoming,
128                )>,
129            > + Send
130            + 'static,
131    ) -> anyhow::Result<Resource<Invocation>> {
132        self.wrpc()
133            .table
134            .push(Invocation::Future(Box::pin(async move {
135                let res = invocation.await;
136                Box::new(res) as Box<dyn Any + Send>
137            })))
138            .context("failed to push invocation to table")
139    }
140
141    fn get_invocation_result(
142        &mut self,
143        invocation: &Resource<Invocation>,
144    ) -> anyhow::Result<
145        Option<
146            &Box<
147                anyhow::Result<(
148                    <Self::Invoke as Invoke>::Outgoing,
149                    <Self::Invoke as Invoke>::Incoming,
150                )>,
151            >,
152        >,
153    > {
154        let invocation = self
155            .wrpc()
156            .table
157            .get(invocation)
158            .context("failed to get invocation from table")?;
159        match invocation {
160            Invocation::Future(..) => Ok(None),
161            Invocation::Ready(res) => {
162                let res = res.downcast_ref().context("invalid invocation type")?;
163                Ok(Some(res))
164            }
165        }
166    }
167
168    fn delete_invocation(
169        &mut self,
170        invocation: Resource<Invocation>,
171    ) -> anyhow::Result<
172        impl Future<
173            Output = anyhow::Result<(
174                <Self::Invoke as Invoke>::Outgoing,
175                <Self::Invoke as Invoke>::Incoming,
176            )>,
177        >,
178    > {
179        let invocation = self
180            .wrpc()
181            .table
182            .delete(invocation)
183            .context("failed to delete invocation from table")?;
184        Ok(async move {
185            let res = match invocation {
186                Invocation::Future(fut) => fut.await,
187                Invocation::Ready(res) => res,
188            };
189            let res = res
190                .downcast()
191                .map_err(|_| anyhow!("invalid invocation type"))?;
192            *res
193        })
194    }
195
196    fn push_outgoing_channel(
197        &mut self,
198        outgoing: <Self::Invoke as Invoke>::Outgoing,
199    ) -> anyhow::Result<Resource<OutgoingChannel>> {
200        self.wrpc()
201            .table
202            .push(OutgoingChannel(Arc::new(std::sync::RwLock::new(Box::new(
203                outgoing,
204            )))))
205            .context("failed to push outgoing channel to table")
206    }
207
208    fn delete_outgoing_channel(
209        &mut self,
210        outgoing: Resource<OutgoingChannel>,
211    ) -> anyhow::Result<<Self::Invoke as Invoke>::Outgoing> {
212        let OutgoingChannel(outgoing) = self
213            .wrpc()
214            .table
215            .delete(outgoing)
216            .context("failed to delete outgoing channel from table")?;
217        let outgoing =
218            Arc::into_inner(outgoing).context("outgoing channel has an active stream")?;
219        let Ok(outgoing) = outgoing.into_inner() else {
220            bail!("lock poisoned");
221        };
222        let outgoing = outgoing
223            .downcast()
224            .map_err(|_| anyhow!("invalid outgoing channel type"))?;
225        Ok(*outgoing)
226    }
227
228    fn push_incoming_channel(
229        &mut self,
230        incoming: <Self::Invoke as Invoke>::Incoming,
231    ) -> anyhow::Result<Resource<IncomingChannel>> {
232        self.wrpc()
233            .table
234            .push(IncomingChannel(Arc::new(std::sync::RwLock::new(Box::new(
235                incoming,
236            )))))
237            .context("failed to push incoming channel to table")
238    }
239
240    fn delete_incoming_channel(
241        &mut self,
242        incoming: Resource<IncomingChannel>,
243    ) -> anyhow::Result<<Self::Invoke as Invoke>::Incoming> {
244        let IncomingChannel(incoming) = self
245            .wrpc()
246            .table
247            .delete(incoming)
248            .context("failed to delete incoming channel from table")?;
249        let incoming =
250            Arc::into_inner(incoming).context("incoming channel has an active stream")?;
251        let Ok(incoming) = incoming.into_inner() else {
252            bail!("lock poisoned");
253        };
254        let incoming = incoming
255            .downcast()
256            .map_err(|_| anyhow!("invalid incoming channel type"))?;
257        Ok(*incoming)
258    }
259
260    fn push_error(&mut self, error: Error) -> anyhow::Result<Resource<Error>> {
261        self.wrpc()
262            .table
263            .push(error)
264            .context("failed to push error to table")
265    }
266
267    fn get_error(&mut self, error: &Resource<Error>) -> anyhow::Result<&Error> {
268        let error = self
269            .wrpc()
270            .table
271            .get(error)
272            .context("failed to get error from table")?;
273        Ok(error)
274    }
275
276    fn get_error_mut(&mut self, error: &Resource<Error>) -> anyhow::Result<&mut Error> {
277        let error = self
278            .wrpc()
279            .table
280            .get_mut(error)
281            .context("failed to get error from table")?;
282        Ok(error)
283    }
284
285    fn delete_error(&mut self, error: Resource<Error>) -> anyhow::Result<Error> {
286        let error = self
287            .wrpc()
288            .table
289            .delete(error)
290            .context("failed to delete error from table")?;
291        Ok(error)
292    }
293
294    fn push_context(
295        &mut self,
296        cx: <Self::Invoke as Invoke>::Context,
297    ) -> anyhow::Result<Resource<Context>>
298    where
299        <Self::Invoke as Invoke>::Context: 'static,
300    {
301        self.wrpc()
302            .table
303            .push(Context(Box::new(cx)))
304            .context("failed to push context to table")
305    }
306
307    fn delete_context(
308        &mut self,
309        cx: Resource<Context>,
310    ) -> anyhow::Result<<Self::Invoke as Invoke>::Context>
311    where
312        <Self::Invoke as Invoke>::Context: 'static,
313    {
314        let Context(cx) = self
315            .wrpc()
316            .table
317            .delete(cx)
318            .context("failed to delete context from table")?;
319        let cx = cx.downcast().map_err(|_| anyhow!("invalid context type"))?;
320        Ok(*cx)
321    }
322}
323
324impl<T: WrpcView> WrpcViewExt for T {}
325
/// Error type returned by [call]
///
/// Each variant identifies the step of [call] that failed.
pub enum CallError {
    /// A parameter value could not be decoded from the incoming stream
    Decode(anyhow::Error),
    /// A result value could not be encoded
    Encode(anyhow::Error),
    /// The error resource returned by the guest could not be removed from the resource table
    Table(anyhow::Error),
    /// The function call itself failed
    Call(anyhow::Error),
    /// The values returned by the function did not match the expected RPC result shape
    TypeMismatch(anyhow::Error),
    /// The encoded results could not be written to the outgoing stream
    Write(anyhow::Error),
    /// The outgoing stream could not be flushed
    Flush(anyhow::Error),
    /// A deferred result writer failed
    Deferred(anyhow::Error),
    /// Post-return cleanup of the function failed
    PostReturn(anyhow::Error),
    /// The guest returned an RPC error as the `err` case of its result
    Guest(Error),
}

// No methods are overridden; this relies on the manual `Debug` and `Display` impls of
// `CallError`.
impl core::error::Error for CallError {}
341
342impl fmt::Debug for CallError {
343    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
344        match self {
345            CallError::Decode(error)
346            | CallError::Encode(error)
347            | CallError::Table(error)
348            | CallError::Call(error)
349            | CallError::TypeMismatch(error)
350            | CallError::Write(error)
351            | CallError::Flush(error)
352            | CallError::Deferred(error)
353            | CallError::PostReturn(error) => error.fmt(f),
354            CallError::Guest(error) => error.fmt(f),
355        }
356    }
357}
358
359impl fmt::Display for CallError {
360    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
361        match self {
362            CallError::Decode(error)
363            | CallError::Encode(error)
364            | CallError::Table(error)
365            | CallError::Call(error)
366            | CallError::TypeMismatch(error)
367            | CallError::Write(error)
368            | CallError::Flush(error)
369            | CallError::Deferred(error)
370            | CallError::PostReturn(error) => error.fmt(f),
371            CallError::Guest(error) => error.fmt(f),
372        }
373    }
374}
375
/// Decodes parameters from `rx`, calls `func` with them and transmits its results over `tx`.
///
/// The steps are performed in this order:
///
/// 1. One value per entry of `params_ty` is decoded from `rx` using `read_value`.
/// 2. `func` is called with the decoded parameters.
/// 3. The results are encoded into a buffer. If `rpc_result_type` recognizes `results_ty` as a
///    single `result<_, rpc-error>` or `result<T, rpc-error>`, only the `ok` payload (if any) is
///    encoded, and an `err` value is returned as [`CallError::Guest`] instead of being
///    transmitted. Otherwise every result is encoded using its type from `results_ty`.
/// 4. The buffer is written to `tx`, which is then flushed and shut down. A failure to shut
///    down is only traced.
/// 5. The deferred writers produced by the encoders are awaited together, each one receiving
///    `tx.index(&[i])`, where `i` is the position of the result that produced it.
/// 6. Post-return cleanup is performed on `func`.
///
/// # Errors
///
/// Returns the [`CallError`] variant corresponding to the step that failed. Every early return
/// skips the remaining steps, including post-return cleanup.
#[allow(clippy::too_many_arguments)]
pub async fn call<C, I, O>(
    mut store: C,
    rx: I,
    mut tx: O,
    guest_resources: &[ResourceType],
    host_resources: &HashMap<Box<str>, HashMap<Box<str>, (ResourceType, ResourceType)>>,
    params_ty: impl ExactSizeIterator<Item = &Type>,
    results_ty: &[Type],
    func: Func,
) -> Result<(), CallError>
where
    I: AsyncRead + wrpc_transport::Index<I> + Send + Sync + Unpin + 'static,
    O: AsyncWrite + wrpc_transport::Index<O> + Send + Sync + Unpin + 'static,
    C: AsContextMut,
    C::Data: WrpcView,
{
    // `Val::Bool(false)` is only a placeholder, each slot is overwritten while decoding
    let mut params = vec![Val::Bool(false); params_ty.len()];
    let mut rx = pin!(rx);
    for (i, (v, ty)) in zip(&mut params, params_ty).enumerate() {
        read_value(&mut store, &mut rx, guest_resources, v, ty, &[i])
            .await
            .with_context(|| format!("failed to decode parameter value {i}"))
            .map_err(CallError::Decode)?;
    }
    // placeholder values again, filled in by the call
    let mut results = vec![Val::Bool(false); results_ty.len()];
    func.call_async(&mut store, &params, &mut results)
        .await
        .context("failed to call function")
        .map_err(CallError::Call)?;

    let mut buf = BytesMut::default();
    // one entry per encoded result, in result order
    let mut deferred = vec![];
    match (
        &rpc_result_type(host_resources, results_ty),
        results.as_slice(),
    ) {
        // not an RPC result, encode every returned value as-is
        (None, results) => {
            for (i, (v, ty)) in zip(results, results_ty).enumerate() {
                let mut enc = ValEncoder::new(store.as_context_mut(), ty, guest_resources);
                enc.encode(v, &mut buf)
                    .with_context(|| format!("failed to encode result value {i}"))
                    .map_err(CallError::Encode)?;
                deferred.push(enc.deferred);
            }
        }
        // `result<_, rpc-error>`
        (Some(None), [Val::Result(Ok(None))]) => {}
        // `result<T, rpc-error>`
        (Some(Some(ty)), [Val::Result(Ok(Some(v)))]) => {
            let mut enc = ValEncoder::new(store.as_context_mut(), ty, guest_resources);
            enc.encode(v, &mut buf)
                .context("failed to encode result value 0")
                .map_err(CallError::Encode)?;
            deferred.push(enc.deferred);
        }
        // the guest returned an RPC error, take it out of the table and hand it to the caller
        (Some(..), [Val::Result(Err(Some(err)))]) => {
            let Val::Resource(err) = &**err else {
                return Err(CallError::TypeMismatch(anyhow!(
                    "RPC result error value is not a resource"
                )));
            };
            let mut store = store.as_context_mut();
            let err = err
                .try_into_resource(&mut store)
                .context("RPC result error resource type mismatch")
                .map_err(CallError::TypeMismatch)?;
            let err = store
                .data_mut()
                .delete_error(err)
                .map_err(CallError::Table)?;
            return Err(CallError::Guest(err));
        }
        // the returned values do not match the shape implied by the RPC result type
        _ => return Err(CallError::TypeMismatch(anyhow!("RPC result type mismatch"))),
    }

    debug!("transmitting results");
    tx.write_all(&buf)
        .await
        .context("failed to transmit results")
        .map_err(CallError::Write)?;
    tx.flush()
        .await
        .context("failed to flush outgoing stream")
        .map_err(CallError::Flush)?;
    // best-effort, a shutdown failure does not fail the call
    if let Err(err) = tx.shutdown().await {
        trace!(?err, "failed to shutdown outgoing stream");
    }
    // run the deferred writers, each on the sub-stream indexed by the position of its result
    try_join_all(
        zip(0.., deferred)
            .filter_map(|(i, f)| f.map(|f| (tx.index(&[i]), f)))
            .map(|(w, f)| async move {
                let w = w?;
                f(w).await
            }),
    )
    .await
    .map_err(CallError::Deferred)?;
    func.post_return_async(&mut store)
        .await
        .context("failed to perform post-return cleanup")
        .map_err(CallError::PostReturn)?;
    Ok(())
}
480
481/// Recursively iterates the component item type and collects all exported resource types
482#[instrument(level = "debug", skip_all)]
483pub fn collect_item_resource_exports(
484    engine: &Engine,
485    ty: types::ComponentItem,
486    resources: &mut impl Extend<types::ResourceType>,
487) {
488    match ty {
489        types::ComponentItem::ComponentFunc(_)
490        | types::ComponentItem::CoreFunc(_)
491        | types::ComponentItem::Module(_)
492        | types::ComponentItem::Type(_) => {}
493        types::ComponentItem::Component(ty) => {
494            collect_component_resource_exports(engine, &ty, resources)
495        }
496
497        types::ComponentItem::ComponentInstance(ty) => {
498            collect_instance_resource_exports(engine, &ty, resources)
499        }
500        types::ComponentItem::Resource(ty) => {
501            debug!(?ty, "collect resource export");
502            resources.extend([ty])
503        }
504    }
505}
506
507/// Recursively iterates the instance type and collects all exported resource types
508#[instrument(level = "debug", skip_all)]
509pub fn collect_instance_resource_exports(
510    engine: &Engine,
511    ty: &types::ComponentInstance,
512    resources: &mut impl Extend<types::ResourceType>,
513) {
514    for (name, ty) in ty.exports(engine) {
515        trace!(name, ?ty, "collect instance item resource exports");
516        collect_item_resource_exports(engine, ty, resources);
517    }
518}
519
520/// Recursively iterates the component type and collects all exported resource types
521#[instrument(level = "debug", skip_all)]
522pub fn collect_component_resource_exports(
523    engine: &Engine,
524    ty: &types::Component,
525    resources: &mut impl Extend<types::ResourceType>,
526) {
527    for (name, ty) in ty.exports(engine) {
528        trace!(name, ?ty, "collect component item resource exports");
529        collect_item_resource_exports(engine, ty, resources);
530    }
531}
532
533/// Iterates the component type and collects all imported resource types
534#[instrument(level = "debug", skip_all)]
535pub fn collect_component_resource_imports(
536    engine: &Engine,
537    ty: &types::Component,
538    resources: &mut BTreeMap<Box<str>, HashMap<Box<str>, types::ResourceType>>,
539) {
540    for (name, ty) in ty.imports(engine) {
541        match ty {
542            types::ComponentItem::ComponentFunc(..)
543            | types::ComponentItem::CoreFunc(..)
544            | types::ComponentItem::Module(..)
545            | types::ComponentItem::Type(..)
546            | types::ComponentItem::Component(..) => {}
547            types::ComponentItem::ComponentInstance(ty) => {
548                let instance = name;
549                for (name, ty) in ty.exports(engine) {
550                    if let types::ComponentItem::Resource(ty) = ty {
551                        debug!(instance, name, ?ty, "collect instance resource import");
552                        if let Some(resources) = resources.get_mut(instance) {
553                            resources.insert(name.into(), ty);
554                        } else {
555                            resources.insert(instance.into(), HashMap::from([(name.into(), ty)]));
556                        }
557                    }
558                }
559            }
560            types::ComponentItem::Resource(ty) => {
561                debug!(name, "collect component resource import");
562                if let Some(resources) = resources.get_mut("") {
563                    resources.insert(name.into(), ty);
564                } else {
565                    resources.insert("".into(), HashMap::from([(name.into(), ty)]));
566                }
567            }
568        }
569    }
570}