use crate::interface::InterfaceGenerator;
use anyhow::{bail, Result};
use heck::{ToSnakeCase, ToUpperCamelCase};
use std::collections::{BTreeMap, HashMap, HashSet};
use std::fmt::{self, Write as _};
use std::mem;
use wit_bindgen_core::wit_parser::{
Flags, FlagsRepr, Function, Int, InterfaceId, Resolve, SizeAlign, TypeId, World, WorldId,
WorldItem, WorldKey,
};
use wit_bindgen_core::{
name_package_module, uwrite, uwriteln, Files, InterfaceGenerator as _, Source, Types,
WorldGenerator,
};
mod interface;
/// How an interface is referred to from generated code.
struct InterfaceName {
    /// `true` when the interface comes from a `with` remapping (a `use ... as`
    /// alias) instead of being generated in this file.
    remapped: bool,
    /// Rust path to the interface: either the `__with_nameN` alias or the
    /// `::`-joined generated module path.
    path: String,
}
/// State of the wRPC Rust world generator across one `WorldGenerator` run.
#[derive(Default)]
struct RustWrpc {
    /// Type analysis results, filled in by `preprocess`.
    types: Types,
    /// Accumulated generated source, written to `<world>.rs` by `finish`.
    src: Source,
    /// User-supplied options.
    opts: Opts,
    /// `(module source, module path)` pairs for imports, emitted by `finish`.
    import_modules: Vec<(String, Vec<String>)>,
    /// `(module source, module path)` pairs for exports, emitted by `finish`.
    export_modules: Vec<(String, Vec<String>)>,
    /// Copy of `opts.skip` as a set; consumed outside this file (TODO confirm).
    skip: HashSet<String>,
    /// Per-interface naming decided by `name_interface`.
    interface_names: HashMap<InterfaceId, InterfaceName>,
    /// Whether `import_funcs` has run; checked by `finish_imports`.
    import_funcs_called: bool,
    /// Counter used to build unique `__with_nameN` aliases.
    with_name_counter: usize,
    /// World-key names seen by `name_interface`, used to detect unused `with` keys.
    generated_interfaces: HashSet<String>,
    /// The world being generated, recorded in `preprocess`.
    world: Option<WorldId>,
    /// Module paths (empty string = world root) that export functions.
    export_paths: Vec<String>,
    /// Per-interface generate/remap configuration.
    with: GenerationConfiguration,
}
/// Per-interface generation decisions keyed by world-key name, with an optional
/// fallback for names that have no explicit entry.
#[derive(Default)]
struct GenerationConfiguration {
    /// Explicit decisions keyed by world-key name.
    map: HashMap<String, InterfaceGeneration>,
    /// When `true`, names missing from `map` resolve to `InterfaceGeneration::Generate`.
    generate_by_default: bool,
}

impl GenerationConfiguration {
    /// Returns the decision for `key`.
    ///
    /// Explicit entries win. Otherwise returns `Generate` when
    /// `generate_by_default` is set, else `None`.
    fn get(&self, key: &str) -> Option<&InterfaceGeneration> {
        if let Some(explicit) = self.map.get(key) {
            return Some(explicit);
        }
        if self.generate_by_default {
            Some(&InterfaceGeneration::Generate)
        } else {
            None
        }
    }

    /// Records (or overwrites) the decision for `name`.
    fn insert(&mut self, name: String, generate: InterfaceGeneration) {
        let _previous = self.map.insert(name, generate);
    }

    /// Iterates over the explicit entries only (never the default fallback).
    fn iter(&self) -> impl Iterator<Item = (&String, &InterfaceGeneration)> {
        self.map.iter()
    }
}
/// What to do with an interface named in the world.
enum InterfaceGeneration {
    /// Reuse bindings that already exist at this Rust path.
    Remap(String),
    /// Generate bindings for the interface in this file.
    Generate,
}
/// Parses one `--with` entry of the form `<key>=<value>`.
///
/// The input is split on the first `=`. A value of exactly `generate` yields
/// [`WithOption::Generate`]; any other value becomes [`WithOption::Path`].
/// Comma-separated lists are split beforehand via `value_delimiter = ','`.
///
/// # Errors
/// Returns a human-readable message when `s` contains no `=`.
#[cfg(feature = "clap")]
fn parse_with(s: &str) -> Result<(String, WithOption), String> {
    let Some((key, value)) = s.split_once('=') else {
        return Err(format!(
            "expected string of form `<key>=<value>[,<key>=<value>...]`; got `{s}`"
        ));
    };
    let option = if value == "generate" {
        WithOption::Generate
    } else {
        WithOption::Path(value.to_owned())
    };
    Ok((key.to_owned(), option))
}
// Options controlling the wRPC Rust bindings generator.
//
// Plain `//` comments are used deliberately: with the `clap` feature enabled,
// clap's derive turns `///` doc comments on fields into `--help` text.
#[derive(Default, Debug, Clone)]
#[cfg_attr(feature = "clap", derive(clap::Args))]
pub struct Opts {
// When set, `finish` parses the generated source with `syn` and re-prints it
// with `prettyplease`.
#[cfg_attr(feature = "clap", arg(long))]
pub format: bool,
// Copied into the generator's `skip` set by `Opts::build` and echoed in the
// generated "Options used" header. How entries are matched is not visible in
// this file — TODO confirm in `interface.rs`.
#[cfg_attr(feature = "clap", arg(long))]
pub skip: Vec<String>,
// Overrides the default `::wit_bindgen_wrpc::bitflags` path in generated code.
#[cfg_attr(feature = "clap", arg(long))]
pub bitflags_path: Option<String>,
// CLI flag is literally `--additional_derive_attribute` (or `-d`), repeatable.
// Echoed in the generated header; presumably extra derives applied to generated
// types — confirm in `interface.rs`.
#[cfg_attr(feature = "clap", arg(long = "additional_derive_attribute", short = 'd', default_values_t = Vec::<String>::new()))]
pub additional_derive_attributes: Vec<String>,
// `<key>=<value>` entries (see `parse_with`). `generate` forces generation of the
// keyed interface; any other value is a Rust path to reuse instead. Every key
// must be consumed, otherwise `finish` fails with "unused remappings".
#[cfg_attr(feature = "clap", arg(long, value_parser = parse_with, value_delimiter = ','))]
pub with: Vec<(String, WithOption)>,
// When set, interfaces without an explicit `with` entry are generated instead
// of causing a "no remapping found" error in `name_interface`.
#[cfg_attr(feature = "clap", arg(long))]
pub generate_all: bool,
// Not read in this file; presumably consulted by `interface.rs` — confirm.
#[cfg_attr(feature = "clap", arg(long))]
pub generate_unused_types: bool,
// The `*_path` options below override the crate paths used in generated code.
// Each one defaults to `::wit_bindgen_wrpc::<crate>` (see the matching getter
// on `RustWrpc`).
#[cfg_attr(feature = "clap", arg(long))]
pub anyhow_path: Option<String>,
#[cfg_attr(feature = "clap", arg(long))]
pub bytes_path: Option<String>,
#[cfg_attr(feature = "clap", arg(long))]
pub futures_path: Option<String>,
#[cfg_attr(feature = "clap", arg(long))]
pub tokio_path: Option<String>,
#[cfg_attr(feature = "clap", arg(long))]
pub tokio_util_path: Option<String>,
#[cfg_attr(feature = "clap", arg(long))]
pub tracing_path: Option<String>,
#[cfg_attr(feature = "clap", arg(long))]
pub wasm_tokio_path: Option<String>,
#[cfg_attr(feature = "clap", arg(long))]
pub wrpc_transport_path: Option<String>,
}
impl Opts {
#[must_use]
pub fn build(self) -> Box<dyn WorldGenerator> {
let mut r = RustWrpc::new();
r.skip = self.skip.iter().cloned().collect();
r.opts = self;
Box::new(r)
}
}
impl RustWrpc {
    /// Creates a generator with every field at its `Default` value.
    fn new() -> RustWrpc {
        RustWrpc::default()
    }

    /// Builds an [`InterfaceGenerator`] that writes into a fresh `Source` buffer
    /// while holding this generator mutably.
    ///
    /// `in_import` is forwarded unchanged to the interface generator.
    fn interface<'a>(
        &'a mut self,
        identifier: Identifier<'a>,
        resolve: &'a Resolve,
        in_import: bool,
    ) -> InterfaceGenerator<'a> {
        // NOTE(review): `sizes` is computed but never used below — confirm whether
        // it (and the `SizeAlign` import) can be dropped.
        let mut sizes = SizeAlign::default();
        sizes.fill(resolve);
        InterfaceGenerator {
            identifier,
            src: Source::default(),
            in_import,
            gen: self,
            resolve,
        }
    }

    /// Writes `modules` into `self.src`. Each module's source text is nested
    /// inside `pub mod` wrappers for every element of its path except the last.
    ///
    /// Each entry is `(module_source, path)`. Wrappers are emitted in sorted
    /// order and carry `#[allow(dead_code)]`. An empty path places the source at
    /// the top level (this previously underflowed `path.len() - 1` and panicked).
    fn emit_modules(&mut self, modules: Vec<(String, Vec<String>)>) {
        // A node in the module tree: child modules plus raw source chunks.
        #[derive(Default)]
        struct Module {
            submodules: BTreeMap<String, Module>,
            contents: Vec<String>,
        }

        // Recursively writes `module`: child modules first (sorted by name via
        // `BTreeMap`), then this level's own source chunks.
        fn emit(me: &mut Source, module: Module) {
            for (name, submodule) in module.submodules {
                uwriteln!(me, "#[allow(dead_code)]");
                uwriteln!(me, "pub mod {name} {{");
                emit(me, submodule);
                uwriteln!(me, "}}");
            }
            for submodule in module.contents {
                uwriteln!(me, "{submodule}");
            }
        }

        let mut root = Module::default();
        for (module, path) in modules {
            // Every path segment except the last becomes an enclosing `pub mod`.
            let parents = path
                .split_last()
                .map_or(&[][..], |(_, parents)| parents);
            let mut cur = &mut root;
            for name in parents {
                cur = cur.submodules.entry(name.clone()).or_default();
            }
            cur.contents.push(module);
        }
        emit(&mut self.src, root);
    }

    // Crate paths referenced by generated code. Each can be overridden through
    // the matching `Opts::*_path` field and otherwise falls back to the crate
    // re-exported by `wit_bindgen_wrpc`.

    fn anyhow_path(&self) -> &str {
        self.opts
            .anyhow_path
            .as_deref()
            .unwrap_or("::wit_bindgen_wrpc::anyhow")
    }

    fn bitflags_path(&self) -> &str {
        self.opts
            .bitflags_path
            .as_deref()
            .unwrap_or("::wit_bindgen_wrpc::bitflags")
    }

    fn bytes_path(&self) -> &str {
        self.opts
            .bytes_path
            .as_deref()
            .unwrap_or("::wit_bindgen_wrpc::bytes")
    }

    fn futures_path(&self) -> &str {
        self.opts
            .futures_path
            .as_deref()
            .unwrap_or("::wit_bindgen_wrpc::futures")
    }

    fn tokio_path(&self) -> &str {
        self.opts
            .tokio_path
            .as_deref()
            .unwrap_or("::wit_bindgen_wrpc::tokio")
    }

    fn tokio_util_path(&self) -> &str {
        self.opts
            .tokio_util_path
            .as_deref()
            .unwrap_or("::wit_bindgen_wrpc::tokio_util")
    }

    fn tracing_path(&self) -> &str {
        self.opts
            .tracing_path
            .as_deref()
            .unwrap_or("::wit_bindgen_wrpc::tracing")
    }

    fn wasm_tokio_path(&self) -> &str {
        self.opts
            .wasm_tokio_path
            .as_deref()
            .unwrap_or("::wit_bindgen_wrpc::wasm_tokio")
    }

    fn wrpc_transport_path(&self) -> &str {
        self.opts
            .wrpc_transport_path
            .as_deref()
            .unwrap_or("::wit_bindgen_wrpc::wrpc_transport")
    }

    /// Records how interface `id` (keyed by `name`) is referred to in generated code.
    ///
    /// Looks up the `with` configuration for the interface's world-key name and
    /// marks that name as used.
    /// - For `Remap(path)`, a `use {path} as __with_nameN;` alias is written to
    ///   `self.src`.
    /// - For `Generate`, the module path is computed, prefixed by `exports` when
    ///   `is_export` is set.
    ///
    /// Returns `Ok(true)` when the interface was remapped. Callers then skip
    /// generating bindings for it.
    ///
    /// # Errors
    /// Fails when the name has no `with` entry and `generate_by_default` is off.
    fn name_interface(
        &mut self,
        resolve: &Resolve,
        id: InterfaceId,
        name: &WorldKey,
        is_export: bool,
    ) -> Result<bool> {
        let with_name = resolve.name_world_key(name);
        let Some(remapping) = self.with.get(&with_name) else {
            bail!("no remapping found for {with_name:?} - use the `generate!` macro's `with` option to force the interface to be generated or specify where it is already defined:
```
with: {{\n\t{with_name:?}: generate\n}}
```")
        };
        self.generated_interfaces.insert(with_name);
        let entry = match remapping {
            InterfaceGeneration::Remap(remapped_path) => {
                let name = format!("__with_name{}", self.with_name_counter);
                self.with_name_counter += 1;
                uwriteln!(self.src, "use {remapped_path} as {name};");
                InterfaceName {
                    remapped: true,
                    path: name,
                }
            }
            InterfaceGeneration::Generate => {
                let path = compute_module_path(name, resolve, is_export).join("::");
                InterfaceName {
                    remapped: false,
                    path,
                }
            }
        };
        let remapped = entry.remapped;
        self.interface_names.insert(id, entry);
        Ok(remapped)
    }

    /// Emits the top-level `serve` function. It bounds the handler by the
    /// `Handler` trait of every exporting module and concatenates the streams
    /// returned by each module's `serve_interface`.
    ///
    /// Does nothing when nothing is exported. Generated code refers to `Vec` and
    /// `Ok` by fully-qualified paths, so that world-level user types with those
    /// names emitted into the same module (e.g. WIT `vec`/`ok`) cannot shadow
    /// them.
    fn finish_serve_function(&mut self) {
        const ROOT: &str = "Handler<T::Context>";
        if self.export_paths.is_empty() {
            return;
        }
        // One `Handler` bound per exporting module; the world root has an empty path.
        let bound = self
            .export_paths
            .iter()
            .map(|path| {
                if path.is_empty() {
                    ROOT.to_string()
                } else {
                    format!("{path}::{ROOT}")
                }
            })
            .collect::<Vec<_>>()
            .join(" + ");
        let anyhow = self.anyhow_path().to_string();
        let futures = self.futures_path().to_string();
        let tokio = self.tokio_path().to_string();
        let wrpc_transport = self.wrpc_transport_path().to_string();
        uwriteln!(
            self.src,
            r#"
#[allow(clippy::manual_async_fn)]
pub fn serve<'a, T: {wrpc_transport}::Serve>(
wrpc: &'a T,
handler: impl {bound} + ::core::marker::Send + ::core::marker::Sync + ::core::clone::Clone + 'static,
) -> impl ::core::future::Future<
Output = {anyhow}::Result<
::std::vec::Vec<
(
&'static str,
&'static str,
::core::pin::Pin<
::std::boxed::Box<
dyn {futures}::Stream<
Item = {anyhow}::Result<
::core::pin::Pin<
::std::boxed::Box<
dyn ::core::future::Future<
Output = {anyhow}::Result<()>
> + ::core::marker::Send + 'static
>
>
>
> + ::core::marker::Send + 'static
>
>
)
>
>
> + ::core::marker::Send + {wrpc_transport}::Captures<'a> {{
async move {{
let interfaces = {tokio}::try_join!("#
        );
        for path in &self.export_paths {
            if !path.is_empty() {
                self.src.push_str(path);
                self.src.push_str("::");
            }
            self.src.push_str("serve_interface(wrpc, handler.clone()),");
        }
        uwriteln!(
            self.src,
            r#"
)?;
let mut streams = ::std::vec::Vec::new();"#
        );
        // `try_join!` yields a tuple with one element per exporting module.
        for i in 0..self.export_paths.len() {
            uwrite!(
                self.src,
                r"
for s in interfaces.{i} {{
streams.push(s);
}}"
            );
        }
        uwriteln!(
            self.src,
            r#"
::core::result::Result::Ok(streams)
}}
}}"#
        );
    }
}
impl WorldGenerator for RustWrpc {
fn preprocess(&mut self, resolve: &Resolve, world: WorldId) {
wit_bindgen_core::generated_preamble(&mut self.src, env!("CARGO_PKG_VERSION"));
uwriteln!(self.src, "// Options used:");
if !self.opts.skip.is_empty() {
uwriteln!(self.src, "// * skip: {:?}", self.opts.skip);
}
if !self.opts.additional_derive_attributes.is_empty() {
uwriteln!(
self.src,
"// * additional derives {:?}",
self.opts.additional_derive_attributes
);
}
for (k, v) in &self.opts.with {
uwriteln!(self.src, "// * with {k:?} = {v:?}");
}
self.types.analyze(resolve);
self.world = Some(world);
let world = &resolve.worlds[world];
for (key, item) in world.imports.iter().chain(world.exports.iter()) {
if let WorldItem::Interface { id, .. } = item {
if resolve.interfaces[*id].package == world.package {
let name = resolve.name_world_key(key);
if self.with.get(&name).is_none() {
self.with.insert(name, InterfaceGeneration::Generate);
}
}
}
}
for (k, v) in &self.opts.with {
self.with.insert(k.clone(), v.clone().into());
}
self.with.generate_by_default = self.opts.generate_all;
}
fn import_interface(
&mut self,
resolve: &Resolve,
name: &WorldKey,
id: InterfaceId,
_files: &mut Files,
) -> Result<()> {
let mut gen = self.interface(Identifier::Interface(id, name), resolve, true);
let (snake, module_path) = gen.start_append_submodule(name);
if gen.gen.name_interface(resolve, id, name, false)? {
return Ok(());
}
gen.types(id);
let interface = &resolve.interfaces[id];
let name = match name {
WorldKey::Name(s) => s.to_string(),
WorldKey::Interface(..) => interface
.name
.as_ref()
.expect("interface name missing")
.to_string(),
};
let instance = if let Some(package) = interface.package {
resolve.id_of_name(package, &name)
} else {
name
};
gen.generate_imports(&instance, resolve.interfaces[id].functions.values());
gen.finish_append_submodule(&snake, module_path);
Ok(())
}
fn import_funcs(
&mut self,
resolve: &Resolve,
world: WorldId,
funcs: &[(&str, &Function)],
_files: &mut Files,
) {
self.import_funcs_called = true;
let mut gen = self.interface(Identifier::World(world), resolve, true);
let World {
ref name, package, ..
} = resolve.worlds[world];
let instance = if let Some(package) = package {
resolve.id_of_name(package, name)
} else {
name.to_string()
};
gen.generate_imports(&instance, funcs.iter().map(|(_, func)| *func));
let src = gen.finish();
self.src.push_str(&src);
}
fn export_interface(
&mut self,
resolve: &Resolve,
name: &WorldKey,
id: InterfaceId,
_files: &mut Files,
) -> Result<()> {
let mut gen = self.interface(Identifier::Interface(id, name), resolve, false);
let (snake, module_path) = gen.start_append_submodule(name);
if gen.gen.name_interface(resolve, id, name, true)? {
return Ok(());
}
gen.types(id);
let exports = gen.generate_exports(
Identifier::Interface(id, name),
resolve.interfaces[id].functions.values(),
);
gen.finish_append_submodule(&snake, module_path);
if exports {
self.export_paths
.push(self.interface_names[&id].path.clone());
}
Ok(())
}
fn export_funcs(
&mut self,
resolve: &Resolve,
world: WorldId,
funcs: &[(&str, &Function)],
_files: &mut Files,
) -> Result<()> {
let mut gen = self.interface(Identifier::World(world), resolve, false);
let exports = gen.generate_exports(Identifier::World(world), funcs.iter().map(|f| f.1));
let src = gen.finish();
self.src.push_str(&src);
if exports {
self.export_paths.push(String::new());
}
Ok(())
}
fn import_types(
&mut self,
resolve: &Resolve,
world: WorldId,
types: &[(&str, TypeId)],
_files: &mut Files,
) {
let mut gen = self.interface(Identifier::World(world), resolve, true);
for (name, ty) in types {
gen.define_type(name, *ty);
}
let src = gen.finish();
self.src.push_str(&src);
}
fn finish_imports(&mut self, resolve: &Resolve, world: WorldId, files: &mut Files) {
if !self.import_funcs_called {
self.import_funcs(resolve, world, &[], files);
}
}
fn finish(&mut self, resolve: &Resolve, world: WorldId, files: &mut Files) -> Result<()> {
let name = &resolve.worlds[world].name;
let imports = mem::take(&mut self.import_modules);
self.emit_modules(imports);
let exports = mem::take(&mut self.export_modules);
self.emit_modules(exports);
self.finish_serve_function();
let mut src = mem::take(&mut self.src);
if self.opts.format {
let syntax_tree = syn::parse_file(src.as_str()).unwrap();
*src.as_mut_string() = prettyplease::unparse(&syntax_tree);
}
let module_name = name.to_snake_case();
files.push(&format!("{module_name}.rs"), src.as_bytes());
let remapped_keys = self
.with
.iter()
.map(|(k, _)| k)
.cloned()
.collect::<HashSet<String>>();
let mut unused_keys = remapped_keys
.difference(&self.generated_interfaces)
.collect::<Vec<&String>>();
unused_keys.sort();
if !unused_keys.is_empty() {
bail!("unused remappings provided via `with`: {unused_keys:?}");
}
Ok(())
}
}
/// Computes the Rust module path segments for a world item.
///
/// Exports are prefixed with `exports`.
/// - A plain-named item maps to a single segment.
/// - An interface key maps to `namespace`, the package module name, and the
///   interface name.
///
/// Every segment is passed through [`to_rust_ident`].
///
/// # Panics
/// Panics if an interface referenced by `WorldKey::Interface` has no package or
/// no name, which `wit_parser` does not produce for such keys.
fn compute_module_path(name: &WorldKey, resolve: &Resolve, is_export: bool) -> Vec<String> {
    let mut path = Vec::new();
    if is_export {
        path.push("exports".to_string());
    }
    match name {
        WorldKey::Name(name) => path.push(to_rust_ident(name)),
        WorldKey::Interface(id) => {
            let iface = &resolve.interfaces[*id];
            let pkg = iface
                .package
                .expect("interface keyed by `WorldKey::Interface` must belong to a package");
            // Borrow instead of cloning the whole package name.
            let pkgname = &resolve.packages[pkg].name;
            path.push(to_rust_ident(&pkgname.namespace));
            path.push(to_rust_ident(&name_package_module(resolve, pkg)));
            path.push(to_rust_ident(
                iface
                    .name
                    .as_deref()
                    .expect("interface keyed by `WorldKey::Interface` must be named"),
            ));
        }
    }
    path
}
/// The scope an [`InterfaceGenerator`] is generating for.
enum Identifier<'a> {
    /// Items declared directly on a world.
    World(WorldId),
    /// An interface, together with the key under which the world names it.
    Interface(InterfaceId, &'a WorldKey),
}
/// Value of a `with` entry: reuse an existing path, or generate the interface.
#[derive(Debug, Clone)]
pub enum WithOption {
    /// Reuse bindings already defined at this Rust path.
    Path(String),
    /// Generate bindings for the interface.
    Generate,
}

impl std::fmt::Display for WithOption {
    /// Paths are rendered wrapped in double quotes; `Generate` as the bare word
    /// `generate`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Path(path) => write!(f, "\"{path}\""),
            Self::Generate => f.write_str("generate"),
        }
    }
}
impl From<WithOption> for InterfaceGeneration {
    /// `Generate` maps to `Generate`; `Path(p)` maps to `Remap(p)`.
    fn from(opt: WithOption) -> Self {
        match opt {
            WithOption::Generate => Self::Generate,
            WithOption::Path(path) => Self::Remap(path),
        }
    }
}
/// Flags describing how a generated function signature should be rendered.
/// Consumed outside this file (TODO confirm exact semantics in `interface.rs`).
#[derive(Default)]
struct FnSig {
    /// Presumably omit `pub` on the generated function.
    private: bool,
    /// Presumably name the function after the WIT item name.
    use_item_name: bool,
    /// Optional receiver text (e.g. a `self` form) for the signature.
    self_arg: Option<String>,
    /// Presumably whether the first WIT parameter is the receiver.
    self_is_first_param: bool,
}
/// Converts a WIT identifier into a snake_case Rust identifier.
///
/// If the converted name is a Rust keyword (strict, 2018+, or reserved), an
/// underscore is appended (`type` -> `type_`).
///
/// The keyword check runs on the *converted* name. Inputs such as `TYPE` or
/// `For` snake-case to a keyword too, and previously slipped through unescaped.
#[must_use]
pub fn to_rust_ident(name: &str) -> String {
    const KEYWORDS: &[&str] = &[
        // strict keywords
        "as", "break", "const", "continue", "crate", "else", "enum", "extern", "false", "fn",
        "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub", "ref",
        "return", "self", "static", "struct", "super", "trait", "true", "type", "unsafe", "use",
        "where", "while", // 2018+ strict keywords
        "async", "await", "dyn", // reserved keywords
        "abstract", "become", "box", "do", "final", "macro", "override", "priv", "typeof",
        "unsized", "virtual", "yield", "try",
    ];
    let snake = name.to_snake_case();
    if KEYWORDS.contains(&snake.as_str()) {
        format!("{snake}_")
    } else {
        snake
    }
}
/// Converts a WIT name to UpperCamelCase for Rust type names.
///
/// A result of `Handler` becomes `Handler_`, so it cannot clash with the
/// generated `Handler` trait. The comparison runs on the converted name, so
/// inputs like `HANDLER` are caught as well as `handler`.
fn to_upper_camel_case(name: &str) -> String {
    let camel = name.to_upper_camel_case();
    if camel == "Handler" {
        "Handler_".to_string()
    } else {
        camel
    }
}
/// Maps a WIT integer representation to the name of the Rust unsigned integer
/// type with the same width.
fn int_repr(repr: Int) -> &'static str {
    match repr {
        Int::U64 => "u64",
        Int::U32 => "u32",
        Int::U16 => "u16",
        Int::U8 => "u8",
    }
}
/// Rust integer type used to store a WIT `flags` value.
enum RustFlagsRepr {
    /// Up to 8 flags.
    U8,
    /// Up to 16 flags.
    U16,
    /// One 32-bit word of flags.
    U32,
    /// Two 32-bit words of flags.
    U64,
    /// Three or four 32-bit words of flags.
    U128,
}
impl RustFlagsRepr {
    /// Picks the Rust storage type for `f` from its WIT representation.
    ///
    /// # Panics
    /// Panics for a 32-bit-word count other than 1 to 4. The message reports
    /// the word count times 32.
    fn new(f: &Flags) -> RustFlagsRepr {
        let words = match f.repr() {
            FlagsRepr::U8 => return RustFlagsRepr::U8,
            FlagsRepr::U16 => return RustFlagsRepr::U16,
            FlagsRepr::U32(words) => words,
        };
        match words {
            1 => RustFlagsRepr::U32,
            2 => RustFlagsRepr::U64,
            3 | 4 => RustFlagsRepr::U128,
            n => panic!("unsupported number of flags: {}", n * 32),
        }
    }
}
impl fmt::Display for RustFlagsRepr {
    /// Writes the Rust type name, honoring width/alignment flags like `str` does.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let type_name = match self {
            RustFlagsRepr::U8 => "u8",
            RustFlagsRepr::U16 => "u16",
            RustFlagsRepr::U32 => "u32",
            RustFlagsRepr::U64 => "u64",
            RustFlagsRepr::U128 => "u128",
        };
        f.pad(type_name)
    }
}
/// Error reporting that a world key has no corresponding `with` mapping.
#[derive(Debug, Clone)]
pub struct MissingWith(pub String);

impl fmt::Display for MissingWith {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let key = &self.0;
        write!(f, "missing `with` mapping for the key `{key}`")
    }
}

impl std::error::Error for MissingWith {}