1#![allow(clippy::type_complexity)] use core::any::Any;
4use core::borrow::Borrow;
5use core::fmt;
6use core::future::Future;
7use core::iter::zip;
8use core::pin::pin;
9use core::time::Duration;
10
11use std::collections::{BTreeMap, HashMap};
12use std::sync::Arc;
13
14use anyhow::{anyhow, bail, Context as _};
15use bytes::{Bytes, BytesMut};
16use futures::future::try_join_all;
17use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt as _};
18use tokio_util::codec::Encoder;
19use tracing::{debug, instrument, trace, warn};
20use uuid::Uuid;
21use wasmtime::component::{
22 types, Func, Resource, ResourceAny, ResourceTable, ResourceType, Type, Val,
23};
24use wasmtime::{AsContextMut, Engine};
25use wrpc_transport::Invoke;
26
27use crate::bindings::rpc::context::Context;
28use crate::bindings::rpc::error::Error;
29use crate::bindings::rpc::transport::{IncomingChannel, Invocation, OutgoingChannel};
30
31pub mod bindings;
32mod codec;
33mod polyfill;
34pub mod rpc;
35mod serve;
36
37pub use codec::*;
38pub use polyfill::*;
39pub use serve::*;
40
/// Returns `name` with a leading component-model function-kind marker removed.
///
/// The recognized markers are `[constructor]`, `[static]` and `[method]`; at most one of them
/// is stripped. Names that start with none of them are returned unchanged.
fn rpc_func_name(name: &str) -> &str {
    const KIND_PREFIXES: [&str; 3] = ["[constructor]", "[static]", "[method]"];
    KIND_PREFIXES
        .iter()
        .find_map(|&prefix| name.strip_prefix(prefix))
        .unwrap_or(name)
}
55
56fn rpc_result_type<T: Borrow<Type>>(
57 host_resources: &HashMap<Box<str>, HashMap<Box<str>, (ResourceType, ResourceType)>>,
58 results_ty: impl IntoIterator<Item = T>,
59) -> Option<Option<Type>> {
60 let rpc_err_ty = host_resources
61 .get("wrpc:rpc/error@0.1.0")
62 .and_then(|instance| instance.get("error"));
63 let mut results_ty = results_ty.into_iter();
64 match (
65 rpc_err_ty,
66 results_ty.next().as_ref().map(Borrow::borrow),
67 results_ty.next(),
68 ) {
69 (Some((guest_rpc_err_ty, host_rpc_err_ty)), Some(Type::Result(result_ty)), None)
70 if *host_rpc_err_ty == ResourceType::host::<Error>()
71 && result_ty.err() == Some(Type::Own(*guest_rpc_err_ty)) =>
72 {
73 Some(result_ty.ok())
74 }
75 _ => None,
76 }
77}
78
/// Newtype around raw [`Bytes`].
///
/// NOTE(review): presumably the opaque, encoded handle of a resource that lives on a remote
/// wRPC peer — TODO confirm against the users of this type.
pub struct RemoteResource(pub Bytes);

/// Table of component resource handles ([`ResourceAny`]) keyed by [`Uuid`].
///
/// Reachable through [`WrpcCtx::shared_resources`].
#[derive(Debug, Default)]
pub struct SharedResourceTable(HashMap<Uuid, ResourceAny>);
84
/// wRPC state that a store's data must provide, generic over the [`Invoke`] transport `T`.
pub trait WrpcCtx<T: Invoke>: Send {
    /// Returns the transport context value (`T::Context`).
    ///
    /// NOTE(review): presumably supplied with outgoing invocations — TODO confirm.
    fn context(&self) -> T::Context;

    /// Returns the [`Invoke`] implementation (transport client).
    fn client(&self) -> &T;

    /// Returns mutable access to this context's [`SharedResourceTable`].
    fn shared_resources(&mut self) -> &mut SharedResourceTable;

    /// Optional timeout; defaults to `None`.
    ///
    /// NOTE(review): presumably applied to outgoing invocations — TODO confirm.
    fn timeout(&self) -> Option<Duration> {
        None
    }
}
101
/// Borrowed view of a store's wRPC context together with its [`ResourceTable`].
pub struct WrpcCtxView<'a, T: Invoke> {
    /// The wRPC context, type-erased behind [`WrpcCtx`].
    pub ctx: &'a mut dyn WrpcCtx<T>,
    /// Resource table holding invocations, channels, errors and contexts.
    pub table: &'a mut ResourceTable,
}
106
/// Implemented by store data types that expose wRPC state (required of `C::Data` by [`call`]).
pub trait WrpcView: Send {
    /// The transport used for outgoing invocations.
    type Invoke: Invoke;

    /// Returns a [`WrpcCtxView`] borrowing the wRPC context and resource table.
    fn wrpc(&mut self) -> WrpcCtxView<'_, Self::Invoke>;
}
112
113impl<T: WrpcView> WrpcView for &mut T {
114 type Invoke = T::Invoke;
115
116 fn wrpc(&mut self) -> WrpcCtxView<'_, Self::Invoke> {
117 T::wrpc(self)
118 }
119}
120
/// Helpers for pushing, inspecting and removing wRPC values (invocations, channels, errors and
/// contexts) in the [`ResourceTable`] exposed by [`WrpcView::wrpc`].
///
/// All methods have default implementations. The trait is implemented for every [`WrpcView`]
/// by the blanket impl that follows it.
pub trait WrpcViewExt: WrpcView {
    /// Pushes a pending invocation into the resource table as [`Invocation::Future`].
    ///
    /// The future's output is type-erased into a `Box<dyn Any + Send>` holding the
    /// `anyhow::Result<(Outgoing, Incoming)>` value. `delete_invocation` downcasts it back.
    ///
    /// # Errors
    ///
    /// Fails if the entry cannot be pushed into the table.
    fn push_invocation(
        &mut self,
        invocation: impl Future<
                Output = anyhow::Result<(
                    <Self::Invoke as Invoke>::Outgoing,
                    <Self::Invoke as Invoke>::Incoming,
                )>,
            > + Send
            + 'static,
    ) -> anyhow::Result<Resource<Invocation>> {
        self.wrpc()
            .table
            .push(Invocation::Future(Box::pin(async move {
                let res = invocation.await;
                // Erase the transport-specific type so `Invocation` need not be generic.
                Box::new(res) as Box<dyn Any + Send>
            })))
            .context("failed to push invocation to table")
    }

    /// Returns the result of an invocation if it has completed.
    ///
    /// Returns `Ok(None)` while the entry is [`Invocation::Future`] and `Ok(Some(..))` once it
    /// is [`Invocation::Ready`]. The entry stays in the table.
    ///
    /// # Errors
    ///
    /// Fails if `invocation` is not in the table or the ready value has an unexpected type.
    ///
    /// NOTE(review): the return type makes `downcast_ref` target
    /// `Box<anyhow::Result<(Outgoing, Incoming)>>`. However, `push_invocation` stores a bare
    /// `anyhow::Result<(Outgoing, Incoming)>`, which is also what `delete_invocation`
    /// downcasts to. If `Invocation::Ready` holds the future's output unchanged, this downcast
    /// can never succeed, and ready invocations always yield "invalid invocation type".
    /// TODO confirm. Returning `Option<&anyhow::Result<..>>` would fix it but changes the
    /// signature.
    fn get_invocation_result(
        &mut self,
        invocation: &Resource<Invocation>,
    ) -> anyhow::Result<
        Option<
            &Box<
                anyhow::Result<(
                    <Self::Invoke as Invoke>::Outgoing,
                    <Self::Invoke as Invoke>::Incoming,
                )>,
            >,
        >,
    > {
        let invocation = self
            .wrpc()
            .table
            .get(invocation)
            .context("failed to get invocation from table")?;
        match invocation {
            Invocation::Future(..) => Ok(None),
            Invocation::Ready(res) => {
                let res = res.downcast_ref().context("invalid invocation type")?;
                Ok(Some(res))
            }
        }
    }

    /// Removes an invocation from the table and returns a future resolving to its result.
    ///
    /// A still-pending invocation is awaited by the returned future; a ready one resolves
    /// immediately. The future fails if the stored result is not an
    /// `anyhow::Result<(Outgoing, Incoming)>`.
    ///
    /// # Errors
    ///
    /// Fails if `invocation` cannot be deleted from the table.
    fn delete_invocation(
        &mut self,
        invocation: Resource<Invocation>,
    ) -> anyhow::Result<
        impl Future<
            Output = anyhow::Result<(
                <Self::Invoke as Invoke>::Outgoing,
                <Self::Invoke as Invoke>::Incoming,
            )>,
        >,
    > {
        let invocation = self
            .wrpc()
            .table
            .delete(invocation)
            .context("failed to delete invocation from table")?;
        Ok(async move {
            let res = match invocation {
                Invocation::Future(fut) => fut.await,
                Invocation::Ready(res) => res,
            };
            let res = res
                .downcast()
                .map_err(|_| anyhow!("invalid invocation type"))?;
            *res
        })
    }

    /// Pushes an outgoing channel into the table as a type-erased [`OutgoingChannel`]
    /// (wrapped in `Arc<std::sync::RwLock<Box<..>>>`).
    ///
    /// # Errors
    ///
    /// Fails if the entry cannot be pushed into the table.
    fn push_outgoing_channel(
        &mut self,
        outgoing: <Self::Invoke as Invoke>::Outgoing,
    ) -> anyhow::Result<Resource<OutgoingChannel>> {
        self.wrpc()
            .table
            .push(OutgoingChannel(Arc::new(std::sync::RwLock::new(Box::new(
                outgoing,
            )))))
            .context("failed to push outgoing channel to table")
    }

    /// Removes an outgoing channel from the table and recovers the concrete `Outgoing` value.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - the handle cannot be deleted;
    /// - other clones of the channel's `Arc` are still alive (reported as an active stream);
    /// - the lock is poisoned;
    /// - the stored value is not of the expected type.
    fn delete_outgoing_channel(
        &mut self,
        outgoing: Resource<OutgoingChannel>,
    ) -> anyhow::Result<<Self::Invoke as Invoke>::Outgoing> {
        let OutgoingChannel(outgoing) = self
            .wrpc()
            .table
            .delete(outgoing)
            .context("failed to delete outgoing channel from table")?;
        // `Arc::into_inner` only succeeds if this is the last reference to the channel.
        let outgoing =
            Arc::into_inner(outgoing).context("outgoing channel has an active stream")?;
        let Ok(outgoing) = outgoing.into_inner() else {
            bail!("lock poisoned");
        };
        let outgoing = outgoing
            .downcast()
            .map_err(|_| anyhow!("invalid outgoing channel type"))?;
        Ok(*outgoing)
    }

    /// Pushes an incoming channel into the table as a type-erased [`IncomingChannel`]
    /// (wrapped in `Arc<std::sync::RwLock<Box<..>>>`).
    ///
    /// # Errors
    ///
    /// Fails if the entry cannot be pushed into the table.
    fn push_incoming_channel(
        &mut self,
        incoming: <Self::Invoke as Invoke>::Incoming,
    ) -> anyhow::Result<Resource<IncomingChannel>> {
        self.wrpc()
            .table
            .push(IncomingChannel(Arc::new(std::sync::RwLock::new(Box::new(
                incoming,
            )))))
            .context("failed to push incoming channel to table")
    }

    /// Removes an incoming channel from the table and recovers the concrete `Incoming` value.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - the handle cannot be deleted;
    /// - other clones of the channel's `Arc` are still alive (reported as an active stream);
    /// - the lock is poisoned;
    /// - the stored value is not of the expected type.
    fn delete_incoming_channel(
        &mut self,
        incoming: Resource<IncomingChannel>,
    ) -> anyhow::Result<<Self::Invoke as Invoke>::Incoming> {
        let IncomingChannel(incoming) = self
            .wrpc()
            .table
            .delete(incoming)
            .context("failed to delete incoming channel from table")?;
        // `Arc::into_inner` only succeeds if this is the last reference to the channel.
        let incoming =
            Arc::into_inner(incoming).context("incoming channel has an active stream")?;
        let Ok(incoming) = incoming.into_inner() else {
            bail!("lock poisoned");
        };
        let incoming = incoming
            .downcast()
            .map_err(|_| anyhow!("invalid incoming channel type"))?;
        Ok(*incoming)
    }

    /// Pushes an [`Error`] into the table.
    ///
    /// # Errors
    ///
    /// Fails if the entry cannot be pushed into the table.
    fn push_error(&mut self, error: Error) -> anyhow::Result<Resource<Error>> {
        self.wrpc()
            .table
            .push(error)
            .context("failed to push error to table")
    }

    /// Borrows an [`Error`] stored in the table.
    ///
    /// # Errors
    ///
    /// Fails if `error` is not in the table.
    fn get_error(&mut self, error: &Resource<Error>) -> anyhow::Result<&Error> {
        let error = self
            .wrpc()
            .table
            .get(error)
            .context("failed to get error from table")?;
        Ok(error)
    }

    /// Mutably borrows an [`Error`] stored in the table.
    ///
    /// # Errors
    ///
    /// Fails if `error` is not in the table.
    fn get_error_mut(&mut self, error: &Resource<Error>) -> anyhow::Result<&mut Error> {
        let error = self
            .wrpc()
            .table
            .get_mut(error)
            .context("failed to get error from table")?;
        Ok(error)
    }

    /// Removes an [`Error`] from the table and returns it.
    ///
    /// # Errors
    ///
    /// Fails if `error` cannot be deleted from the table.
    fn delete_error(&mut self, error: Resource<Error>) -> anyhow::Result<Error> {
        let error = self
            .wrpc()
            .table
            .delete(error)
            .context("failed to delete error from table")?;
        Ok(error)
    }

    /// Pushes a transport context into the table, boxed inside [`Context`].
    ///
    /// # Errors
    ///
    /// Fails if the entry cannot be pushed into the table.
    fn push_context(
        &mut self,
        cx: <Self::Invoke as Invoke>::Context,
    ) -> anyhow::Result<Resource<Context>>
    where
        <Self::Invoke as Invoke>::Context: 'static,
    {
        self.wrpc()
            .table
            .push(Context(Box::new(cx)))
            .context("failed to push context to table")
    }

    /// Removes a [`Context`] from the table and downcasts it back to the transport's
    /// context type.
    ///
    /// # Errors
    ///
    /// Fails if `cx` cannot be deleted or the stored value is not of the expected type.
    fn delete_context(
        &mut self,
        cx: Resource<Context>,
    ) -> anyhow::Result<<Self::Invoke as Invoke>::Context>
    where
        <Self::Invoke as Invoke>::Context: 'static,
    {
        let Context(cx) = self
            .wrpc()
            .table
            .delete(cx)
            .context("failed to delete context from table")?;
        let cx = cx.downcast().map_err(|_| anyhow!("invalid context type"))?;
        Ok(*cx)
    }
}

// Every `WrpcView` gets the `WrpcViewExt` helpers via their default implementations.
impl<T: WrpcView> WrpcViewExt for T {}
325
/// Error returned by [`call`], identifying the stage at which handling a call failed.
pub enum CallError {
    /// A parameter value could not be decoded from the incoming stream.
    Decode(anyhow::Error),
    /// A result value could not be encoded.
    Encode(anyhow::Error),
    /// A resource table operation failed (e.g. removing the guest's RPC error resource).
    Table(anyhow::Error),
    /// Calling the guest function failed.
    Call(anyhow::Error),
    /// The guest's results did not have the expected RPC result shape.
    TypeMismatch(anyhow::Error),
    /// Writing the encoded results to the outgoing stream failed.
    Write(anyhow::Error),
    /// Flushing the outgoing stream failed.
    Flush(anyhow::Error),
    /// A deferred encoder, driven on an indexed sub-stream of the outgoing stream, failed.
    Deferred(anyhow::Error),
    /// Post-return cleanup of the guest function failed.
    PostReturn(anyhow::Error),
    /// The guest returned an RPC error value.
    Guest(Error),
}
339
340impl core::error::Error for CallError {}
341
/// Formats only the wrapped error. It delegates to `fmt` on the wrapped value with the same
/// formatter, and the variant name is not included in the output.
impl fmt::Debug for CallError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            CallError::Decode(error)
            | CallError::Encode(error)
            | CallError::Table(error)
            | CallError::Call(error)
            | CallError::TypeMismatch(error)
            | CallError::Write(error)
            | CallError::Flush(error)
            | CallError::Deferred(error)
            | CallError::PostReturn(error) => error.fmt(f),
            CallError::Guest(error) => error.fmt(f),
        }
    }
}
358
/// Displays only the wrapped error. It delegates to `fmt` on the wrapped value with the same
/// formatter, so formatter flags such as `{:#}` are passed through unchanged.
impl fmt::Display for CallError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            CallError::Decode(error)
            | CallError::Encode(error)
            | CallError::Table(error)
            | CallError::Call(error)
            | CallError::TypeMismatch(error)
            | CallError::Write(error)
            | CallError::Flush(error)
            | CallError::Deferred(error)
            | CallError::PostReturn(error) => error.fmt(f),
            CallError::Guest(error) => error.fmt(f),
        }
    }
}
375
376#[allow(clippy::too_many_arguments)]
377pub async fn call<C, I, O>(
378 mut store: C,
379 rx: I,
380 mut tx: O,
381 guest_resources: &[ResourceType],
382 host_resources: &HashMap<Box<str>, HashMap<Box<str>, (ResourceType, ResourceType)>>,
383 params_ty: impl ExactSizeIterator<Item = &Type>,
384 results_ty: &[Type],
385 func: Func,
386) -> Result<(), CallError>
387where
388 I: AsyncRead + wrpc_transport::Index<I> + Send + Sync + Unpin + 'static,
389 O: AsyncWrite + wrpc_transport::Index<O> + Send + Sync + Unpin + 'static,
390 C: AsContextMut,
391 C::Data: WrpcView,
392{
393 let mut params = vec![Val::Bool(false); params_ty.len()];
394 let mut rx = pin!(rx);
395 for (i, (v, ty)) in zip(&mut params, params_ty).enumerate() {
396 read_value(&mut store, &mut rx, guest_resources, v, ty, &[i])
397 .await
398 .with_context(|| format!("failed to decode parameter value {i}"))
399 .map_err(CallError::Decode)?;
400 }
401 let mut results = vec![Val::Bool(false); results_ty.len()];
402 func.call_async(&mut store, ¶ms, &mut results)
403 .await
404 .context("failed to call function")
405 .map_err(CallError::Call)?;
406
407 let mut buf = BytesMut::default();
408 let mut deferred = vec![];
409 match (
410 &rpc_result_type(host_resources, results_ty),
411 results.as_slice(),
412 ) {
413 (None, results) => {
414 for (i, (v, ty)) in zip(results, results_ty).enumerate() {
415 let mut enc = ValEncoder::new(store.as_context_mut(), ty, guest_resources);
416 enc.encode(v, &mut buf)
417 .with_context(|| format!("failed to encode result value {i}"))
418 .map_err(CallError::Encode)?;
419 deferred.push(enc.deferred);
420 }
421 }
422 (Some(None), [Val::Result(Ok(None))]) => {}
424 (Some(Some(ty)), [Val::Result(Ok(Some(v)))]) => {
426 let mut enc = ValEncoder::new(store.as_context_mut(), ty, guest_resources);
427 enc.encode(v, &mut buf)
428 .context("failed to encode result value 0")
429 .map_err(CallError::Encode)?;
430 deferred.push(enc.deferred);
431 }
432 (Some(..), [Val::Result(Err(Some(err)))]) => {
433 let Val::Resource(err) = &**err else {
434 return Err(CallError::TypeMismatch(anyhow!(
435 "RPC result error value is not a resource"
436 )));
437 };
438 let mut store = store.as_context_mut();
439 let err = err
440 .try_into_resource(&mut store)
441 .context("RPC result error resource type mismatch")
442 .map_err(CallError::TypeMismatch)?;
443 let err = store
444 .data_mut()
445 .delete_error(err)
446 .map_err(CallError::Table)?;
447 return Err(CallError::Guest(err));
448 }
449 _ => return Err(CallError::TypeMismatch(anyhow!("RPC result type mismatch"))),
450 }
451
452 debug!("transmitting results");
453 tx.write_all(&buf)
454 .await
455 .context("failed to transmit results")
456 .map_err(CallError::Write)?;
457 tx.flush()
458 .await
459 .context("failed to flush outgoing stream")
460 .map_err(CallError::Flush)?;
461 if let Err(err) = tx.shutdown().await {
462 trace!(?err, "failed to shutdown outgoing stream");
463 }
464 try_join_all(
465 zip(0.., deferred)
466 .filter_map(|(i, f)| f.map(|f| (tx.index(&[i]), f)))
467 .map(|(w, f)| async move {
468 let w = w?;
469 f(w).await
470 }),
471 )
472 .await
473 .map_err(CallError::Deferred)?;
474 func.post_return_async(&mut store)
475 .await
476 .context("failed to perform post-return cleanup")
477 .map_err(CallError::PostReturn)?;
478 Ok(())
479}
480
481#[instrument(level = "debug", skip_all)]
483pub fn collect_item_resource_exports(
484 engine: &Engine,
485 ty: types::ComponentItem,
486 resources: &mut impl Extend<types::ResourceType>,
487) {
488 match ty {
489 types::ComponentItem::ComponentFunc(_)
490 | types::ComponentItem::CoreFunc(_)
491 | types::ComponentItem::Module(_)
492 | types::ComponentItem::Type(_) => {}
493 types::ComponentItem::Component(ty) => {
494 collect_component_resource_exports(engine, &ty, resources)
495 }
496
497 types::ComponentItem::ComponentInstance(ty) => {
498 collect_instance_resource_exports(engine, &ty, resources)
499 }
500 types::ComponentItem::Resource(ty) => {
501 debug!(?ty, "collect resource export");
502 resources.extend([ty])
503 }
504 }
505}
506
507#[instrument(level = "debug", skip_all)]
509pub fn collect_instance_resource_exports(
510 engine: &Engine,
511 ty: &types::ComponentInstance,
512 resources: &mut impl Extend<types::ResourceType>,
513) {
514 for (name, ty) in ty.exports(engine) {
515 trace!(name, ?ty, "collect instance item resource exports");
516 collect_item_resource_exports(engine, ty, resources);
517 }
518}
519
520#[instrument(level = "debug", skip_all)]
522pub fn collect_component_resource_exports(
523 engine: &Engine,
524 ty: &types::Component,
525 resources: &mut impl Extend<types::ResourceType>,
526) {
527 for (name, ty) in ty.exports(engine) {
528 trace!(name, ?ty, "collect component item resource exports");
529 collect_item_resource_exports(engine, ty, resources);
530 }
531}
532
533#[instrument(level = "debug", skip_all)]
535pub fn collect_component_resource_imports(
536 engine: &Engine,
537 ty: &types::Component,
538 resources: &mut BTreeMap<Box<str>, HashMap<Box<str>, types::ResourceType>>,
539) {
540 for (name, ty) in ty.imports(engine) {
541 match ty {
542 types::ComponentItem::ComponentFunc(..)
543 | types::ComponentItem::CoreFunc(..)
544 | types::ComponentItem::Module(..)
545 | types::ComponentItem::Type(..)
546 | types::ComponentItem::Component(..) => {}
547 types::ComponentItem::ComponentInstance(ty) => {
548 let instance = name;
549 for (name, ty) in ty.exports(engine) {
550 if let types::ComponentItem::Resource(ty) = ty {
551 debug!(instance, name, ?ty, "collect instance resource import");
552 if let Some(resources) = resources.get_mut(instance) {
553 resources.insert(name.into(), ty);
554 } else {
555 resources.insert(instance.into(), HashMap::from([(name.into(), ty)]));
556 }
557 }
558 }
559 }
560 types::ComponentItem::Resource(ty) => {
561 debug!(name, "collect component resource import");
562 if let Some(resources) = resources.get_mut("") {
563 resources.insert(name.into(), ty);
564 } else {
565 resources.insert("".into(), HashMap::from([(name.into(), ty)]));
566 }
567 }
568 }
569 }
570}