wrpc_runtime_wasmtime/rpc/host/
transport.rs1use core::marker::PhantomData;
2
3use std::sync::Arc;
4
5use anyhow::{bail, Context as _};
6use wasmtime::component::Resource;
7use wasmtime_wasi::p2::bindings::io::poll::Pollable;
8use wasmtime_wasi::p2::bindings::io::streams::{InputStream, OutputStream};
9use wasmtime_wasi::p2::pipe::{AsyncReadStream, AsyncWriteStream};
10use wasmtime_wasi::p2::subscribe;
11use wrpc_transport::{Index as _, Invoke};
12
13use crate::bindings::rpc::error::Error;
14use crate::bindings::rpc::transport::{
15 Host, HostIncomingChannel, HostInvocation, HostOutgoingChannel, IncomingChannel, Invocation,
16 OutgoingChannel,
17};
18use crate::rpc::{IncomingChannelStream, OutgoingChannelStream, WrpcRpcImpl};
19use crate::{WrpcView, WrpcViewExt as _};
20
21impl<T: WrpcView> Host for WrpcRpcImpl<T> {}
22
23impl<T: WrpcView> HostInvocation for WrpcRpcImpl<T> {
24 fn subscribe(
25 &mut self,
26 invocation: Resource<Invocation>,
27 ) -> wasmtime::Result<Resource<Pollable>> {
28 subscribe(self.0.wrpc().table, invocation)
29 }
30
31 async fn finish(
32 &mut self,
33 invocation: Resource<Invocation>,
34 ) -> wasmtime::Result<
35 Result<(Resource<OutgoingChannel>, Resource<IncomingChannel>), Resource<Error>>,
36 > {
37 let invocation = self.0.delete_invocation(invocation)?;
38 match invocation.await {
39 Ok((tx, rx)) => {
40 let rx = self.0.push_incoming_channel(rx)?;
41 let tx = self.0.push_outgoing_channel(tx)?;
42 Ok(Ok((tx, rx)))
43 }
44 Err(error) => {
45 let error = self.0.push_error(Error::Invoke(error))?;
46 Ok(Err(error))
47 }
48 }
49 }
50
51 fn drop(&mut self, invocation: Resource<Invocation>) -> wasmtime::Result<()> {
52 _ = self.0.delete_invocation(invocation)?;
53 Ok(())
54 }
55}
56
57impl<T: WrpcView> HostIncomingChannel for WrpcRpcImpl<T> {
58 fn data(
59 &mut self,
60 incoming: Resource<IncomingChannel>,
61 ) -> wasmtime::Result<Option<Resource<InputStream>>> {
62 let IncomingChannel(stream) = self
63 .0
64 .wrpc()
65 .table
66 .get_mut(&incoming)
67 .context("failed to get incoming channel from table")?;
68 if Arc::get_mut(stream).is_none() {
69 return Ok(None);
70 }
71 let stream = Arc::clone(stream);
72 let stream = self
73 .0
74 .wrpc()
75 .table
76 .push_child(
77 Box::new(AsyncReadStream::new(IncomingChannelStream {
78 incoming: IncomingChannel(stream),
79 _ty: PhantomData::<<T::Invoke as Invoke>::Incoming>,
80 })) as InputStream,
81 &incoming,
82 )
83 .context("failed to push input stream to table")?;
84 Ok(Some(stream))
85 }
86
87 fn index(
88 &mut self,
89 incoming: Resource<IncomingChannel>,
90 path: Vec<u32>,
91 ) -> wasmtime::Result<Result<Resource<IncomingChannel>, Resource<Error>>> {
92 let path = path
93 .into_iter()
94 .map(usize::try_from)
95 .collect::<Result<Box<[_]>, _>>()
96 .context("failed to construct subscription path")?;
97 let IncomingChannel(incoming) = self
98 .0
99 .wrpc()
100 .table
101 .get(&incoming)
102 .context("failed to get incoming channel from table")?;
103 let incoming = {
104 let Ok(incoming) = incoming.read() else {
105 bail!("lock poisoned");
106 };
107 let incoming = incoming
108 .downcast_ref::<<T::Invoke as Invoke>::Incoming>()
109 .context("invalid incoming channel type")?;
110 incoming.index(&path)
111 };
112 match incoming {
113 Ok(incoming) => {
114 let incoming = self.0.push_incoming_channel(incoming)?;
115 Ok(Ok(incoming))
116 }
117 Err(error) => {
118 let error = self.0.push_error(Error::IncomingIndex(error))?;
119 Ok(Err(error))
120 }
121 }
122 }
123
124 fn drop(&mut self, incoming: Resource<IncomingChannel>) -> wasmtime::Result<()> {
125 self.0.delete_incoming_channel(incoming)?;
126 Ok(())
127 }
128}
129
130impl<T: WrpcView> HostOutgoingChannel for WrpcRpcImpl<T> {
131 fn data(
132 &mut self,
133 outgoing: Resource<OutgoingChannel>,
134 ) -> wasmtime::Result<Option<Resource<OutputStream>>> {
135 let OutgoingChannel(stream) = self
136 .0
137 .wrpc()
138 .table
139 .get_mut(&outgoing)
140 .context("failed to get outgoing channel from table")?;
141 if Arc::get_mut(stream).is_none() {
142 return Ok(None);
143 }
144 let stream = Arc::clone(stream);
145 let stream = self
146 .0
147 .wrpc()
148 .table
149 .push_child(
150 Box::new(AsyncWriteStream::new(
151 8192,
152 OutgoingChannelStream {
153 outgoing: OutgoingChannel(stream),
154 _ty: PhantomData::<<T::Invoke as Invoke>::Outgoing>,
155 },
156 )) as OutputStream,
157 &outgoing,
158 )
159 .context("failed to push output stream to table")?;
160 Ok(Some(stream))
161 }
162
163 fn index(
164 &mut self,
165 outgoing: Resource<OutgoingChannel>,
166 path: Vec<u32>,
167 ) -> wasmtime::Result<Result<Resource<OutgoingChannel>, Resource<Error>>> {
168 let path = path
169 .into_iter()
170 .map(usize::try_from)
171 .collect::<Result<Box<[_]>, _>>()
172 .context("failed to construct subscription path")?;
173 let OutgoingChannel(outgoing) = self
174 .0
175 .wrpc()
176 .table
177 .get(&outgoing)
178 .context("failed to get outgoing channel from table")?;
179 let incoming = {
180 let Ok(outgoing) = outgoing.read() else {
181 bail!("lock poisoned");
182 };
183 let outgoing = outgoing
184 .downcast_ref::<<T::Invoke as Invoke>::Outgoing>()
185 .context("invalid outgoing channel type")?;
186 outgoing.index(&path)
187 };
188 match incoming {
189 Ok(outgoing) => {
190 let outgoing = self.0.push_outgoing_channel(outgoing)?;
191 Ok(Ok(outgoing))
192 }
193 Err(error) => {
194 let error = self.0.push_error(Error::OutgoingIndex(error))?;
195 Ok(Err(error))
196 }
197 }
198 }
199
200 fn drop(&mut self, outgoing: Resource<OutgoingChannel>) -> wasmtime::Result<()> {
201 self.0.delete_outgoing_channel(outgoing)?;
202 Ok(())
203 }
204}