wrpc_runtime_wasmtime/rpc/host/transport.rs

1use core::marker::PhantomData;
2
3use std::sync::Arc;
4
5use anyhow::{bail, Context as _};
6use wasmtime::component::Resource;
7use wasmtime_wasi::p2::bindings::io::poll::Pollable;
8use wasmtime_wasi::p2::bindings::io::streams::{InputStream, OutputStream};
9use wasmtime_wasi::p2::pipe::{AsyncReadStream, AsyncWriteStream};
10use wasmtime_wasi::p2::subscribe;
11use wrpc_transport::{Index as _, Invoke};
12
13use crate::bindings::rpc::error::Error;
14use crate::bindings::rpc::transport::{
15    Host, HostIncomingChannel, HostInvocation, HostOutgoingChannel, IncomingChannel, Invocation,
16    OutgoingChannel,
17};
18use crate::rpc::{IncomingChannelStream, OutgoingChannelStream, WrpcRpcImpl};
19use crate::{WrpcView, WrpcViewExt as _};
20
21impl<T: WrpcView> Host for WrpcRpcImpl<T> {}
22
23impl<T: WrpcView> HostInvocation for WrpcRpcImpl<T> {
24    fn subscribe(
25        &mut self,
26        invocation: Resource<Invocation>,
27    ) -> wasmtime::Result<Resource<Pollable>> {
28        subscribe(self.0.wrpc().table, invocation)
29    }
30
31    async fn finish(
32        &mut self,
33        invocation: Resource<Invocation>,
34    ) -> wasmtime::Result<
35        Result<(Resource<OutgoingChannel>, Resource<IncomingChannel>), Resource<Error>>,
36    > {
37        let invocation = self.0.delete_invocation(invocation)?;
38        match invocation.await {
39            Ok((tx, rx)) => {
40                let rx = self.0.push_incoming_channel(rx)?;
41                let tx = self.0.push_outgoing_channel(tx)?;
42                Ok(Ok((tx, rx)))
43            }
44            Err(error) => {
45                let error = self.0.push_error(Error::Invoke(error))?;
46                Ok(Err(error))
47            }
48        }
49    }
50
51    fn drop(&mut self, invocation: Resource<Invocation>) -> wasmtime::Result<()> {
52        _ = self.0.delete_invocation(invocation)?;
53        Ok(())
54    }
55}
56
57impl<T: WrpcView> HostIncomingChannel for WrpcRpcImpl<T> {
58    fn data(
59        &mut self,
60        incoming: Resource<IncomingChannel>,
61    ) -> wasmtime::Result<Option<Resource<InputStream>>> {
62        let IncomingChannel(stream) = self
63            .0
64            .wrpc()
65            .table
66            .get_mut(&incoming)
67            .context("failed to get incoming channel from table")?;
68        if Arc::get_mut(stream).is_none() {
69            return Ok(None);
70        }
71        let stream = Arc::clone(stream);
72        let stream = self
73            .0
74            .wrpc()
75            .table
76            .push_child(
77                Box::new(AsyncReadStream::new(IncomingChannelStream {
78                    incoming: IncomingChannel(stream),
79                    _ty: PhantomData::<<T::Invoke as Invoke>::Incoming>,
80                })) as InputStream,
81                &incoming,
82            )
83            .context("failed to push input stream to table")?;
84        Ok(Some(stream))
85    }
86
87    fn index(
88        &mut self,
89        incoming: Resource<IncomingChannel>,
90        path: Vec<u32>,
91    ) -> wasmtime::Result<Result<Resource<IncomingChannel>, Resource<Error>>> {
92        let path = path
93            .into_iter()
94            .map(usize::try_from)
95            .collect::<Result<Box<[_]>, _>>()
96            .context("failed to construct subscription path")?;
97        let IncomingChannel(incoming) = self
98            .0
99            .wrpc()
100            .table
101            .get(&incoming)
102            .context("failed to get incoming channel from table")?;
103        let incoming = {
104            let Ok(incoming) = incoming.read() else {
105                bail!("lock poisoned");
106            };
107            let incoming = incoming
108                .downcast_ref::<<T::Invoke as Invoke>::Incoming>()
109                .context("invalid incoming channel type")?;
110            incoming.index(&path)
111        };
112        match incoming {
113            Ok(incoming) => {
114                let incoming = self.0.push_incoming_channel(incoming)?;
115                Ok(Ok(incoming))
116            }
117            Err(error) => {
118                let error = self.0.push_error(Error::IncomingIndex(error))?;
119                Ok(Err(error))
120            }
121        }
122    }
123
124    fn drop(&mut self, incoming: Resource<IncomingChannel>) -> wasmtime::Result<()> {
125        self.0.delete_incoming_channel(incoming)?;
126        Ok(())
127    }
128}
129
130impl<T: WrpcView> HostOutgoingChannel for WrpcRpcImpl<T> {
131    fn data(
132        &mut self,
133        outgoing: Resource<OutgoingChannel>,
134    ) -> wasmtime::Result<Option<Resource<OutputStream>>> {
135        let OutgoingChannel(stream) = self
136            .0
137            .wrpc()
138            .table
139            .get_mut(&outgoing)
140            .context("failed to get outgoing channel from table")?;
141        if Arc::get_mut(stream).is_none() {
142            return Ok(None);
143        }
144        let stream = Arc::clone(stream);
145        let stream = self
146            .0
147            .wrpc()
148            .table
149            .push_child(
150                Box::new(AsyncWriteStream::new(
151                    8192,
152                    OutgoingChannelStream {
153                        outgoing: OutgoingChannel(stream),
154                        _ty: PhantomData::<<T::Invoke as Invoke>::Outgoing>,
155                    },
156                )) as OutputStream,
157                &outgoing,
158            )
159            .context("failed to push output stream to table")?;
160        Ok(Some(stream))
161    }
162
163    fn index(
164        &mut self,
165        outgoing: Resource<OutgoingChannel>,
166        path: Vec<u32>,
167    ) -> wasmtime::Result<Result<Resource<OutgoingChannel>, Resource<Error>>> {
168        let path = path
169            .into_iter()
170            .map(usize::try_from)
171            .collect::<Result<Box<[_]>, _>>()
172            .context("failed to construct subscription path")?;
173        let OutgoingChannel(outgoing) = self
174            .0
175            .wrpc()
176            .table
177            .get(&outgoing)
178            .context("failed to get outgoing channel from table")?;
179        let incoming = {
180            let Ok(outgoing) = outgoing.read() else {
181                bail!("lock poisoned");
182            };
183            let outgoing = outgoing
184                .downcast_ref::<<T::Invoke as Invoke>::Outgoing>()
185                .context("invalid outgoing channel type")?;
186            outgoing.index(&path)
187        };
188        match incoming {
189            Ok(outgoing) => {
190                let outgoing = self.0.push_outgoing_channel(outgoing)?;
191                Ok(Ok(outgoing))
192            }
193            Err(error) => {
194                let error = self.0.push_error(Error::OutgoingIndex(error))?;
195                Ok(Err(error))
196            }
197        }
198    }
199
200    fn drop(&mut self, outgoing: Resource<OutgoingChannel>) -> wasmtime::Result<()> {
201        self.0.delete_outgoing_channel(outgoing)?;
202        Ok(())
203    }
204}