wrpc_transport_nats/
lib.rs

1//! wRPC NATS.io transport
2
3#![allow(clippy::type_complexity)]
4
5#[cfg(any(
6    not(any(
7        feature = "async-nats-0_39",
8        feature = "async-nats-0_38",
9        feature = "async-nats-0_37",
10        feature = "async-nats-0_36",
11    )),
12    all(feature = "async-nats-0_39", feature = "async-nats-0_38"),
13    all(feature = "async-nats-0_39", feature = "async-nats-0_37"),
14    all(feature = "async-nats-0_39", feature = "async-nats-0_36"),
15    all(feature = "async-nats-0_38", feature = "async-nats-0_37"),
16    all(feature = "async-nats-0_38", feature = "async-nats-0_36"),
17    all(feature = "async-nats-0_37", feature = "async-nats-0_36"),
18))]
19compile_error!(
20    "Either feature \"async-nats-0_39\", \"async-nats-0_38\", \"async-nats-0_37\" or \"async-nats-0_36\" must be enabled for this crate."
21);
22
23#[cfg(feature = "async-nats-0_39")]
24use async_nats_0_39 as async_nats;
25
26#[cfg(feature = "async-nats-0_38")]
27use async_nats_0_38 as async_nats;
28
29#[cfg(feature = "async-nats-0_37")]
30use async_nats_0_37 as async_nats;
31
32#[cfg(feature = "async-nats-0_36")]
33use async_nats_0_36 as async_nats;
34
35use core::future::Future;
36use core::iter::zip;
37use core::ops::{Deref, DerefMut};
38use core::pin::{pin, Pin};
39use core::task::{ready, Context, Poll};
40use core::{mem, str};
41
42use std::collections::HashMap;
43use std::sync::Arc;
44
45use anyhow::{anyhow, ensure, Context as _};
46use async_nats::{HeaderMap, PublishMessage, ServerInfo, StatusCode, Subject};
47use bytes::{Buf as _, Bytes};
48use futures::sink::SinkExt as _;
49use futures::{Stream, StreamExt};
50use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
51use tokio::select;
52use tokio::sync::{mpsc, oneshot};
53use tokio::task::JoinSet;
54use tokio_stream::wrappers::ReceiverStream;
55use tracing::{debug, error, instrument, trace, warn};
56use wrpc_transport::Index as _;
57
/// wRPC protocol identifier, embedded as a segment of every invocation subject
/// built by `invocation_subject`.
pub const PROTOCOL: &str = "wrpc.0.0.1";
59
60fn spawn_async(fut: impl Future<Output = ()> + Send + 'static) {
61    match tokio::runtime::Handle::try_current() {
62        Ok(rt) => {
63            rt.spawn(fut);
64        }
65        Err(_) => match tokio::runtime::Runtime::new() {
66            Ok(rt) => {
67                rt.spawn(fut);
68            }
69            Err(err) => error!(?err, "failed to create a new Tokio runtime"),
70        },
71    }
72}
73
74fn new_inbox(inbox: &str) -> String {
75    let id = nuid::next();
76    let mut s = String::with_capacity(inbox.len().saturating_add(id.len()));
77    s.push_str(inbox);
78    s.push_str(&id);
79    s
80}
81
/// Returns the subject on which parameters for the invocation rooted at `prefix` are sent,
/// i.e. `"{prefix}.params"`.
#[must_use]
#[inline]
pub fn param_subject(prefix: &str) -> String {
    [prefix, ".params"].concat()
}
87
/// Returns the subject on which results for the invocation rooted at `prefix` are sent,
/// i.e. `"{prefix}.results"`.
#[must_use]
#[inline]
pub fn result_subject(prefix: &str) -> String {
    [prefix, ".results"].concat()
}
93
/// Appends the decimal `path` indices to `prefix`, separating every segment with `.`.
///
/// An empty `prefix` contributes no leading separator: `("a", [1, 2])` yields `"a.1.2"`
/// and `("", [1, 2])` yields `"1.2"`.
#[must_use]
#[inline]
pub fn index_path(prefix: &str, path: &[usize]) -> String {
    let mut subject = String::with_capacity(prefix.len() + path.len() * 2);
    subject.push_str(prefix);
    path.iter().fold(subject, |mut acc, index| {
        if !acc.is_empty() {
            acc.push('.');
        }
        acc.push_str(&index.to_string());
        acc
    })
}
109
/// Like `index_path`, but each `None` segment becomes the single-token wildcard `*`.
///
/// Example: `("p", [Some(1), None])` yields `"p.1.*"`.
#[must_use]
#[inline]
pub fn subscribe_path(prefix: &str, path: &[Option<usize>]) -> String {
    let mut subject = String::with_capacity(prefix.len() + path.len() * 2);
    subject.push_str(prefix);
    for segment in path {
        if !subject.is_empty() {
            subject.push('.');
        }
        match segment {
            Some(index) => subject.push_str(&index.to_string()),
            None => subject.push('*'),
        }
    }
    subject
}
129
130#[must_use]
131#[inline]
132pub fn invocation_subject(prefix: &str, instance: &str, func: &str) -> String {
133    let mut s =
134        String::with_capacity(prefix.len() + PROTOCOL.len() + instance.len() + func.len() + 3);
135    if !prefix.is_empty() {
136        s.push_str(prefix);
137        s.push('.');
138    }
139    s.push_str(PROTOCOL);
140    s.push('.');
141    if !instance.is_empty() {
142        s.push_str(instance);
143        s.push('.');
144    }
145    s.push_str(func);
146    s
147}
148
/// Error reported when a writer is not in the state its code path expects.
///
/// This includes the `Corrupted` placeholder left behind after `mem::take`.
fn corrupted_memory_error() -> std::io::Error {
    use std::io::{Error, ErrorKind};

    Error::new(ErrorKind::Other, "corrupted memory state")
}
152
/// Transport subscriber
///
/// Receives the messages that the client's demultiplexing task routes to `subject`.
/// Dropping it asks that task (via `commands`) to unregister the subject.
pub struct Subscriber {
    /// Messages delivered for `subject`.
    rx: ReceiverStream<Message>,
    /// Subject this subscriber is registered under.
    subject: Subject,
    /// Command channel to the client's demultiplexing task.
    commands: mpsc::Sender<Command>,
    /// Shared handle to the client's background task set, held so it is not dropped
    /// before this subscriber has been unregistered.
    tasks: Arc<JoinSet<()>>,
}
160
161impl Drop for Subscriber {
162    fn drop(&mut self) {
163        let commands = self.commands.clone();
164        let subject = mem::replace(&mut self.subject, Subject::from_static(""));
165        let tasks = Arc::clone(&self.tasks);
166        spawn_async(async move {
167            trace!(?subject, "shutting down subscriber");
168            if let Err(err) = commands.send(Command::Unsubscribe(subject)).await {
169                warn!(?err, "failed to shutdown subscriber");
170            }
171            drop(tasks);
172        });
173    }
174}
175
176impl Deref for Subscriber {
177    type Target = ReceiverStream<Message>;
178
179    fn deref(&self) -> &Self::Target {
180        &self.rx
181    }
182}
183
184impl DerefMut for Subscriber {
185    fn deref_mut(&mut self) -> &mut Self::Target {
186        &mut self.rx
187    }
188}
189
/// Request sent to the client's demultiplexing task to update its subject → subscriber map.
enum Command {
    /// Route messages received on the subject to the sender, replacing any previous entry.
    Subscribe(Subject, mpsc::Sender<Message>),
    /// Stop routing messages received on the subject.
    Unsubscribe(Subject),
    /// Apply each contained command in order.
    Batch(Box<[Command]>),
}
195
/// Subset of [`async_nats::Message`](async_nats::Message) used by this crate
pub struct Message {
    /// Reply subject set by the sender, if any.
    reply: Option<Subject>,
    /// Message body.
    payload: Bytes,
    /// Status code, if one was set on the message.
    status: Option<async_nats::StatusCode>,
    /// Status description accompanying `status`, if any.
    description: Option<String>,
}
203
/// wRPC client over a NATS connection.
///
/// Cloning is cheap: all state is reference-counted or a channel sender.
#[derive(Clone, Debug)]
pub struct Client {
    /// Underlying NATS client.
    nats: Arc<async_nats::Client>,
    /// Subject prefix used by this client.
    prefix: Arc<str>,
    /// Inbox subject prefix owned by this client; ends with `.`.
    inbox: Arc<str>,
    /// Optional NATS queue group.
    queue_group: Option<Arc<str>>,
    /// Command channel to the inbox demultiplexing task spawned by `Client::new`.
    commands: mpsc::Sender<Command>,
    /// Background tasks (the demultiplexing task); shared with subscribers and writers.
    tasks: Arc<JoinSet<()>>,
}
213
214impl Client {
215    pub async fn new(
216        nats: impl Into<Arc<async_nats::Client>>,
217        prefix: impl Into<Arc<str>>,
218        queue_group: Option<Arc<str>>,
219    ) -> anyhow::Result<Self> {
220        let nats = nats.into();
221        let mut inbox = nats.new_inbox();
222        inbox.push('.');
223        let mut subject = String::with_capacity(inbox.len().saturating_add(1));
224        subject.push_str(&inbox);
225        subject.push('>');
226        let mut sub = nats
227            .subscribe(Subject::from(subject))
228            .await
229            .context("failed to subscribe on an inbox subject")?;
230
231        let mut tasks = JoinSet::new();
232        let (cmd_tx, mut cmd_rx) = mpsc::channel(8192);
233        tasks.spawn({
234            async move {
235                fn handle_command(subs: &mut HashMap<String, mpsc::Sender<Message>>, cmd: Command) {
236                    match cmd {
237                        Command::Subscribe(s, tx) => {
238                            subs.insert(s.into_string(), tx);
239                        }
240                        Command::Unsubscribe(s) => {
241                            subs.remove(s.as_str());
242                        }
243                        Command::Batch(cmds) => {
244                            for cmd in cmds {
245                                handle_command(subs, cmd);
246                            }
247                        }
248                    }
249                }
250                async fn handle_message(
251                    subs: &mut HashMap<String, mpsc::Sender<Message>>,
252                    async_nats::Message {
253                        subject,
254                        reply,
255                        payload,
256                        status,
257                        description,
258                        ..
259                    }: async_nats::Message,
260                ) {
261                    let Some(sub) = subs.get_mut(subject.as_str()) else {
262                        debug!(?subject, "drop message with no subscriber");
263                        return;
264                    };
265                    let Ok(sub) = sub.reserve().await else {
266                        debug!(?subject, "drop message with closed subscriber");
267                        subs.remove(subject.as_str());
268                        return;
269                    };
270                    sub.send(Message {
271                        reply,
272                        payload,
273                        status,
274                        description,
275                    });
276                }
277
278                let mut subs = HashMap::new();
279                loop {
280                    select! {
281                        Some(msg) = sub.next() => handle_message(&mut subs, msg).await,
282                        Some(cmd) = cmd_rx.recv() => handle_command(&mut subs, cmd),
283                        else => return,
284                    }
285                }
286            }
287        });
288        Ok(Self {
289            nats,
290            prefix: prefix.into(),
291            inbox: inbox.into(),
292            queue_group,
293            commands: cmd_tx,
294            tasks: Arc::new(tasks),
295        })
296    }
297}
298
/// [`Subscriber`] adapter yielding only the payload of each received message.
pub struct ByteSubscription(Subscriber);
300
301impl Stream for ByteSubscription {
302    type Item = std::io::Result<Bytes>;
303
304    #[instrument(level = "trace", skip_all)]
305    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
306        match self.0.poll_next_unpin(cx) {
307            Poll::Ready(Some(Message { payload, .. })) => Poll::Ready(Some(Ok(payload))),
308            Poll::Ready(None) => Poll::Ready(None),
309            Poll::Pending => Poll::Pending,
310        }
311    }
312}
313
/// Trie of subscribers keyed by index paths, where a path segment is either a concrete
/// index (`Some(i)`) or a wildcard (`None`).
#[derive(Default)]
enum IndexTrie {
    /// No subscriber and no children.
    #[default]
    Empty,
    /// A subscriber with no children.
    Leaf(Subscriber),
    /// Node whose children are addressed by concrete index; `nested[i]` is child `i`.
    IndexNode {
        subscriber: Option<Subscriber>,
        nested: Vec<Option<IndexTrie>>,
    },
    /// Node with a single child reached through a wildcard segment.
    WildcardNode {
        subscriber: Option<Subscriber>,
        nested: Option<Box<IndexTrie>>,
    },
}
328
329impl<'a> From<(&'a [Option<usize>], Subscriber)> for IndexTrie {
330    fn from((path, sub): (&'a [Option<usize>], Subscriber)) -> Self {
331        match path {
332            [] => Self::Leaf(sub),
333            [None, path @ ..] => Self::WildcardNode {
334                subscriber: None,
335                nested: Some(Box::new(Self::from((path, sub)))),
336            },
337            [Some(i), path @ ..] => Self::IndexNode {
338                subscriber: None,
339                nested: {
340                    let n = i.saturating_add(1);
341                    let mut nested = Vec::with_capacity(n);
342                    nested.resize_with(n, Option::default);
343                    nested[*i] = Some(Self::from((path, sub)));
344                    nested
345                },
346            },
347        }
348    }
349}
350
351impl<P: AsRef<[Option<usize>]>> FromIterator<(P, Subscriber)> for IndexTrie {
352    fn from_iter<T: IntoIterator<Item = (P, Subscriber)>>(iter: T) -> Self {
353        let mut root = Self::Empty;
354        for (path, sub) in iter {
355            if !root.insert(path.as_ref(), sub) {
356                return Self::Empty;
357            }
358        }
359        root
360    }
361}
362
impl IndexTrie {
    /// Returns `true` if this node is [`IndexTrie::Empty`].
    #[inline]
    fn is_empty(&self) -> bool {
        matches!(self, IndexTrie::Empty)
    }

    /// Removes and returns the subscriber stored exactly at the concrete index `path`.
    ///
    /// Only `IndexNode` edges are followed; wildcard subtrees are not demultiplexed yet (see
    /// TODOs), so lookups through or at a `WildcardNode` yield `None`.
    ///
    /// When `path` is exhausted the node is replaced as follows:
    /// - `Leaf` becomes `Empty`.
    /// - `IndexNode` keeps its children (with its subscriber taken), or becomes `Empty` if it
    ///   has none.
    /// - `WildcardNode` becomes `Empty`, dropping its whole subtree.
    ///   NOTE(review): confirm that discarding a wildcard subtree here is intended.
    #[instrument(level = "trace", skip_all)]
    fn take(&mut self, path: &[usize]) -> Option<Subscriber> {
        let Some((i, path)) = path.split_first() else {
            return match mem::take(self) {
                // TODO: Demux the subscription
                //IndexTrie::WildcardNode { subscriber, nested } => {
                //    if let Some(nested) = nested {
                //        *self = IndexTrie::WildcardNode {
                //            subscriber: None,
                //            nested: Some(nested),
                //        }
                //    }
                //    subscriber
                //}
                IndexTrie::Empty | IndexTrie::WildcardNode { .. } => None,
                IndexTrie::Leaf(subscriber) => Some(subscriber),
                IndexTrie::IndexNode { subscriber, nested } => {
                    // Keep the children reachable; `self` was reset to `Empty` by `mem::take`.
                    if !nested.is_empty() {
                        *self = IndexTrie::IndexNode {
                            subscriber: None,
                            nested,
                        }
                    }
                    subscriber
                }
            };
        };
        match self {
            // TODO: Demux the subscription
            //Self::WildcardNode { ref mut nested, .. } => {
            //    nested.as_mut().and_then(|nested| nested.take(path))
            //}
            Self::Empty | Self::Leaf(..) | Self::WildcardNode { .. } => None,
            // Descend into child `i`, if that slot exists and is populated.
            Self::IndexNode { ref mut nested, .. } => nested
                .get_mut(*i)
                .and_then(|nested| nested.as_mut().and_then(|nested| nested.take(path))),
        }
    }

    /// Inserts `sub` under a `path` - returns `false` if it failed and `true` if it succeeded.
    /// Tree state after `false` is returned is undefined.
    ///
    /// Insertion fails when:
    /// - `path` ends at a node that already has a subscriber;
    /// - a wildcard segment meets an `IndexNode`;
    /// - an index segment meets a `WildcardNode`.
    ///
    /// A `Leaf` reached with a non-empty remaining path is converted into an inner node that
    /// keeps the leaf's subscriber.
    #[instrument(level = "trace", skip_all)]
    fn insert(&mut self, path: &[Option<usize>], sub: Subscriber) -> bool {
        match self {
            Self::Empty => {
                *self = Self::from((path, sub));
                true
            }
            Self::Leaf(..) => {
                let Some((i, path)) = path.split_first() else {
                    // A subscriber is already stored at this exact path.
                    return false;
                };
                let Self::Leaf(subscriber) = mem::take(self) else {
                    return false;
                };
                if let Some(i) = i {
                    let n = i.saturating_add(1);
                    let mut nested = Vec::with_capacity(n);
                    nested.resize_with(n, Option::default);
                    nested[*i] = Some(Self::from((path, sub)));
                    *self = Self::IndexNode {
                        subscriber: Some(subscriber),
                        nested,
                    };
                } else {
                    *self = Self::WildcardNode {
                        subscriber: Some(subscriber),
                        nested: Some(Box::new(Self::from((path, sub)))),
                    };
                }
                true
            }
            Self::WildcardNode {
                ref mut subscriber,
                ref mut nested,
            } => match (&subscriber, path) {
                (None, []) => {
                    *subscriber = Some(sub);
                    true
                }
                (_, [None, path @ ..]) => {
                    if let Some(nested) = nested {
                        nested.insert(path, sub)
                    } else {
                        *nested = Some(Box::new(Self::from((path, sub))));
                        true
                    }
                }
                _ => false,
            },
            Self::IndexNode {
                ref mut subscriber,
                ref mut nested,
            } => match (&subscriber, path) {
                (None, []) => {
                    *subscriber = Some(sub);
                    true
                }
                (_, [Some(i), path @ ..]) => {
                    // Grow the child vector so that slot `i` exists.
                    let cap = i.saturating_add(1);
                    if nested.len() < cap {
                        nested.resize_with(cap, Option::default);
                    }
                    let nested = &mut nested[*i];
                    if let Some(nested) = nested {
                        nested.insert(path, sub)
                    } else {
                        *nested = Some(Self::from((path, sub)));
                        true
                    }
                }
                _ => false,
            },
        }
    }
}
485
/// Byte reader over the messages of one (possibly nested) subscription.
pub struct Reader {
    /// Unread remainder of a message that did not fit into a previous read buffer.
    buffer: Bytes,
    /// Subscription for `path`, or `None` if none was registered for it.
    incoming: Option<Subscriber>,
    /// Trie of not-yet-claimed nested subscriptions, shared with all readers indexed from
    /// the same root.
    nested: Arc<std::sync::Mutex<IndexTrie>>,
    /// Absolute index path of this reader.
    path: Box<[usize]>,
}
492
493impl wrpc_transport::Index<Self> for Reader {
494    #[instrument(level = "trace", skip(self))]
495    fn index(&self, path: &[usize]) -> anyhow::Result<Self> {
496        ensure!(!path.is_empty());
497        trace!("locking index tree");
498        let mut nested = self
499            .nested
500            .lock()
501            .map_err(|err| anyhow!(err.to_string()).context("failed to lock map"))?;
502        trace!("taking index subscription");
503        let mut p = self.path.to_vec();
504        p.extend_from_slice(path);
505        let incoming = nested.take(&p);
506        Ok(Self {
507            buffer: Bytes::default(),
508            incoming,
509            nested: Arc::clone(&self.nested),
510            path: p.into_boxed_slice(),
511        })
512    }
513}
514
515impl AsyncRead for Reader {
516    #[instrument(level = "trace", skip_all, ret)]
517    fn poll_read(
518        mut self: Pin<&mut Self>,
519        cx: &mut Context<'_>,
520        buf: &mut ReadBuf<'_>,
521    ) -> Poll<std::io::Result<()>> {
522        let cap = buf.remaining();
523        if cap == 0 {
524            trace!("attempt to read empty buffer");
525            return Poll::Ready(Ok(()));
526        }
527
528        if !self.buffer.is_empty() {
529            if self.buffer.len() > cap {
530                trace!(cap, len = self.buffer.len(), "reading part of buffer");
531                buf.put_slice(&self.buffer.split_to(cap));
532            } else {
533                trace!(cap, len = self.buffer.len(), "reading full buffer");
534                buf.put_slice(&mem::take(&mut self.buffer));
535            }
536            return Poll::Ready(Ok(()));
537        }
538        let Some(incoming) = self.incoming.as_mut() else {
539            return Poll::Ready(Err(std::io::Error::new(
540                std::io::ErrorKind::NotFound,
541                format!("subscription not found for path {:?}", self.path),
542            )));
543        };
544        trace!("polling for next message");
545        match incoming.poll_next_unpin(cx) {
546            Poll::Ready(Some(Message { mut payload, .. })) => {
547                trace!(?payload, "received message");
548                if payload.is_empty() {
549                    trace!("received stream shutdown message");
550                    return Poll::Ready(Ok(()));
551                }
552                if payload.len() > cap {
553                    trace!(len = payload.len(), cap, "partially reading the message");
554                    buf.put_slice(&payload.split_to(cap));
555                    self.buffer = payload;
556                } else {
557                    trace!(len = payload.len(), cap, "filling the buffer with payload");
558                    buf.put_slice(&payload);
559                }
560                Poll::Ready(Ok(()))
561            }
562            Poll::Ready(None) => {
563                trace!("subscription finished");
564                Poll::Ready(Ok(()))
565            }
566            Poll::Pending => Poll::Pending,
567        }
568    }
569}
570
/// Byte writer publishing each write as a NATS message on a fixed subject.
#[derive(Clone, Debug)]
pub struct SubjectWriter {
    /// NATS client used as the publishing sink.
    nats: async_nats::Client,
    /// Subject messages are published to.
    tx: Subject,
    /// Set once the stream shutdown (empty) message has been written by `poll_shutdown`.
    /// Otherwise `Drop` publishes that message.
    shutdown: bool,
    /// Shared background task set, kept alive for as long as this writer.
    tasks: Arc<JoinSet<()>>,
}
578
579impl SubjectWriter {
580    fn new(nats: async_nats::Client, tx: Subject, tasks: Arc<JoinSet<()>>) -> Self {
581        Self {
582            nats,
583            tx,
584            shutdown: false,
585            tasks,
586        }
587    }
588}
589
590impl wrpc_transport::Index<Self> for SubjectWriter {
591    #[instrument(level = "trace", skip(self))]
592    fn index(&self, path: &[usize]) -> anyhow::Result<Self> {
593        ensure!(!path.is_empty());
594        let tx = Subject::from(index_path(self.tx.as_str(), path));
595        Ok(Self {
596            nats: self.nats.clone(),
597            tx,
598            shutdown: false,
599            tasks: Arc::clone(&self.tasks),
600        })
601    }
602}
603
604impl AsyncWrite for SubjectWriter {
605    #[instrument(level = "trace", skip_all, ret, fields(subject = self.tx.as_str(), buf = format!("{buf:02x?}")))]
606    fn poll_write(
607        mut self: Pin<&mut Self>,
608        cx: &mut Context<'_>,
609        mut buf: &[u8],
610    ) -> Poll<std::io::Result<usize>> {
611        trace!("polling for readiness");
612        match self.nats.poll_ready_unpin(cx) {
613            Poll::Pending => return Poll::Pending,
614            Poll::Ready(Err(err)) => {
615                return Poll::Ready(Err(std::io::Error::new(
616                    std::io::ErrorKind::BrokenPipe,
617                    err,
618                )))
619            }
620            Poll::Ready(Ok(())) => {}
621        }
622        let ServerInfo { max_payload, .. } = self.nats.server_info();
623        if max_payload == 0 {
624            return Poll::Ready(Err(std::io::ErrorKind::WriteZero.into()));
625        }
626        if buf.len() > max_payload {
627            (buf, _) = buf.split_at(max_payload);
628        }
629        trace!("starting send");
630        let subject = self.tx.clone();
631        match self.nats.start_send_unpin(PublishMessage {
632            subject,
633            payload: Bytes::copy_from_slice(buf),
634            reply: None,
635            headers: None,
636        }) {
637            Ok(()) => Poll::Ready(Ok(buf.len())),
638            Err(err) => Poll::Ready(Err(std::io::Error::new(
639                std::io::ErrorKind::BrokenPipe,
640                err,
641            ))),
642        }
643    }
644
645    #[instrument(level = "trace", skip_all, ret, fields(subject = self.tx.as_str()))]
646    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
647        trace!("flushing");
648        self.nats
649            .poll_flush_unpin(cx)
650            .map_err(|_| std::io::ErrorKind::BrokenPipe.into())
651    }
652
653    #[instrument(level = "trace", skip_all, ret, fields(subject = self.tx.as_str()))]
654    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
655        trace!("writing stream shutdown message");
656        ready!(self.as_mut().poll_write(cx, &[]))?;
657        self.shutdown = true;
658        Poll::Ready(Ok(()))
659    }
660}
661
662impl Drop for SubjectWriter {
663    fn drop(&mut self) {
664        if !self.shutdown {
665            let nats = self.nats.clone();
666            let subject = mem::replace(&mut self.tx, Subject::from_static(""));
667            let tasks = Arc::clone(&self.tasks);
668            spawn_async(async move {
669                trace!("writing stream shutdown message");
670                if let Err(err) = nats.publish(subject, Bytes::default()).await {
671                    warn!(?err, "failed to publish stream shutdown message");
672                }
673                drop(tasks);
674            });
675        }
676    }
677}
678
/// Writer for the root of an invocation's parameters.
///
/// It buffers until the peer's handshake reply reveals the subject to publish to.
#[derive(Default)]
pub enum RootParamWriter {
    /// Placeholder left behind by `mem::take` during state transitions; any use is an error.
    #[default]
    Corrupted,
    /// Waiting for the peer's reply on `sub`, whose reply subject determines where
    /// parameters are published.
    Handshaking {
        nats: async_nats::Client,
        sub: Subscriber,
        /// Nested writers requested before the handshake completed, each waiting for its
        /// indexed [`SubjectWriter`].
        indexed: std::sync::Mutex<Vec<(Vec<usize>, oneshot::Sender<SubjectWriter>)>>,
        /// Bytes to publish once the handshake completes.
        buffer: Bytes,
        tasks: Arc<JoinSet<()>>,
    },
    /// Handshake done; publishing the bytes buffered during the handshake.
    Draining {
        tx: SubjectWriter,
        buffer: Bytes,
    },
    /// Handshake done and buffer drained; writes go straight to the writer.
    Active(SubjectWriter),
}
696
697impl RootParamWriter {
698    fn new(
699        nats: async_nats::Client,
700        sub: Subscriber,
701        buffer: Bytes,
702        tasks: Arc<JoinSet<()>>,
703    ) -> Self {
704        Self::Handshaking {
705            nats,
706            sub,
707            indexed: std::sync::Mutex::default(),
708            buffer,
709            tasks,
710        }
711    }
712}
713
impl RootParamWriter {
    /// Drives the writer towards [`RootParamWriter::Active`].
    ///
    /// Returns `Ready(Ok(()))` once it is active. Error replies (no responders, timeout,
    /// terminated request, other non-success status) and replies without a reply subject
    /// are surfaced as I/O errors.
    #[instrument(level = "trace", skip_all, ret)]
    fn poll_active(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        match &mut *self {
            Self::Corrupted => Poll::Ready(Err(corrupted_memory_error())),
            Self::Handshaking { sub, .. } => {
                trace!("polling for handshake response");
                match sub.poll_next_unpin(cx) {
                    Poll::Ready(Some(Message {
                        status: Some(StatusCode::NO_RESPONDERS),
                        ..
                    })) => Poll::Ready(Err(std::io::ErrorKind::NotConnected.into())),
                    Poll::Ready(Some(Message {
                        status: Some(StatusCode::TIMEOUT),
                        ..
                    })) => Poll::Ready(Err(std::io::ErrorKind::TimedOut.into())),
                    Poll::Ready(Some(Message {
                        status: Some(StatusCode::REQUEST_TERMINATED),
                        ..
                    })) => Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())),
                    Poll::Ready(Some(Message {
                        status: Some(code),
                        description,
                        ..
                    })) if !code.is_success() => Poll::Ready(Err(std::io::Error::new(
                        std::io::ErrorKind::Other,
                        if let Some(description) = description {
                            format!("received a response with code `{code}` ({description})")
                        } else {
                            format!("received a response with code `{code}`")
                        },
                    ))),
                    Poll::Ready(Some(Message {
                        reply: Some(tx), ..
                    })) => {
                        // The peer's reply subject is the root of its parameter subjects.
                        let Self::Handshaking {
                            nats,
                            indexed,
                            buffer,
                            tasks,
                            ..
                        } = mem::take(&mut *self)
                        else {
                            return Poll::Ready(Err(corrupted_memory_error()));
                        };
                        let tx = SubjectWriter::new(nats, Subject::from(param_subject(&tx)), tasks);
                        let indexed = indexed.into_inner().map_err(|err| {
                            std::io::Error::new(std::io::ErrorKind::Other, err.to_string())
                        })?;
                        // Hand out the writers for nested paths requested during the handshake.
                        // NOTE(review): if a requester has been dropped, `send` fails and this
                        // returns `BrokenPipe`, leaving `self` as `Corrupted`. Confirm whether
                        // a dropped nested writer should fail the root writer.
                        for (path, tx_tx) in indexed {
                            let tx = tx.index(&path).map_err(|err| {
                                std::io::Error::new(std::io::ErrorKind::Other, err)
                            })?;
                            tx_tx.send(tx).map_err(|_| {
                                std::io::Error::from(std::io::ErrorKind::BrokenPipe)
                            })?;
                        }
                        trace!("handshake succeeded");
                        if buffer.is_empty() {
                            *self = Self::Active(tx);
                            Poll::Ready(Ok(()))
                        } else {
                            *self = Self::Draining { tx, buffer };
                            self.poll_active(cx)
                        }
                    }
                    Poll::Ready(Some(..)) => Poll::Ready(Err(std::io::Error::new(
                        std::io::ErrorKind::InvalidInput,
                        "peer did not specify a reply subject",
                    ))),
                    Poll::Ready(None) => {
                        *self = Self::Corrupted;
                        Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::BrokenPipe)))
                    }
                    Poll::Pending => Poll::Pending,
                }
            }
            Self::Draining { tx, buffer } => {
                let mut tx = pin!(tx);
                // Publish buffered bytes; each write may publish only part of the buffer.
                while !buffer.is_empty() {
                    trace!(?tx.tx, "draining parameter buffer");
                    match tx.as_mut().poll_write(cx, buffer) {
                        Poll::Ready(Ok(n)) => {
                            buffer.advance(n);
                        }
                        Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                        Poll::Pending => return Poll::Pending,
                    }
                }
                let Self::Draining { tx, .. } = mem::take(&mut *self) else {
                    return Poll::Ready(Err(corrupted_memory_error()));
                };
                trace!("parameter buffer draining succeeded");
                *self = Self::Active(tx);
                Poll::Ready(Ok(()))
            }
            Self::Active(..) => Poll::Ready(Ok(())),
        }
    }
}
814
815impl wrpc_transport::Index<IndexedParamWriter> for RootParamWriter {
816    #[instrument(level = "trace", skip(self))]
817    fn index(&self, path: &[usize]) -> anyhow::Result<IndexedParamWriter> {
818        ensure!(!path.is_empty());
819        match self {
820            Self::Corrupted => Err(anyhow!(corrupted_memory_error())),
821            Self::Handshaking { indexed, .. } => {
822                let (tx_tx, tx_rx) = oneshot::channel();
823                let mut indexed = indexed.lock().map_err(|err| {
824                    std::io::Error::new(std::io::ErrorKind::Other, err.to_string())
825                })?;
826                indexed.push((path.to_vec(), tx_tx));
827                Ok(IndexedParamWriter::Handshaking {
828                    tx_rx,
829                    indexed: std::sync::Mutex::default(),
830                })
831            }
832            Self::Draining { tx, .. } | Self::Active(tx) => {
833                tx.index(path).map(IndexedParamWriter::Active)
834            }
835        }
836    }
837}
838
839impl AsyncWrite for RootParamWriter {
840    #[instrument(level = "trace", skip_all, ret, fields(buf = format!("{buf:02x?}")))]
841    fn poll_write(
842        mut self: Pin<&mut Self>,
843        cx: &mut Context<'_>,
844        buf: &[u8],
845    ) -> Poll<std::io::Result<usize>> {
846        match self.as_mut().poll_active(cx)? {
847            Poll::Ready(()) => {
848                let Self::Active(tx) = &mut *self else {
849                    return Poll::Ready(Err(corrupted_memory_error()));
850                };
851                trace!("writing buffer");
852                pin!(tx).poll_write(cx, buf)
853            }
854            Poll::Pending => Poll::Pending,
855        }
856    }
857
858    #[instrument(level = "trace", skip_all, ret)]
859    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
860        match self.as_mut().poll_active(cx)? {
861            Poll::Ready(()) => {
862                let Self::Active(tx) = &mut *self else {
863                    return Poll::Ready(Err(corrupted_memory_error()));
864                };
865                trace!("flushing");
866                pin!(tx).poll_flush(cx)
867            }
868            Poll::Pending => Poll::Pending,
869        }
870    }
871
872    #[instrument(level = "trace", skip_all, ret)]
873    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
874        match self.as_mut().poll_active(cx)? {
875            Poll::Ready(()) => {
876                let Self::Active(tx) = &mut *self else {
877                    return Poll::Ready(Err(corrupted_memory_error()));
878                };
879                trace!("shutting down");
880                pin!(tx).poll_shutdown(cx)
881            }
882            Poll::Pending => Poll::Pending,
883        }
884    }
885}
886
887#[derive(Debug, Default)]
888pub enum IndexedParamWriter {
889    #[default]
890    Corrupted,
891    Handshaking {
892        tx_rx: oneshot::Receiver<SubjectWriter>,
893        indexed: std::sync::Mutex<Vec<(Vec<usize>, oneshot::Sender<SubjectWriter>)>>,
894    },
895    Active(SubjectWriter),
896}
897
898impl IndexedParamWriter {
899    #[instrument(level = "trace", skip_all, ret)]
900    fn poll_active(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
901        match &mut *self {
902            Self::Corrupted => Poll::Ready(Err(corrupted_memory_error())),
903            Self::Handshaking { tx_rx, .. } => {
904                trace!("polling for handshake");
905                match pin!(tx_rx).poll(cx) {
906                    Poll::Ready(Ok(tx)) => {
907                        let Self::Handshaking { indexed, .. } = mem::take(&mut *self) else {
908                            return Poll::Ready(Err(corrupted_memory_error()));
909                        };
910                        let indexed = indexed.into_inner().map_err(|err| {
911                            std::io::Error::new(std::io::ErrorKind::Other, err.to_string())
912                        })?;
913                        for (path, tx_tx) in indexed {
914                            let tx = tx.index(&path).map_err(|err| {
915                                std::io::Error::new(std::io::ErrorKind::Other, err)
916                            })?;
917                            tx_tx.send(tx).map_err(|_| {
918                                std::io::Error::from(std::io::ErrorKind::BrokenPipe)
919                            })?;
920                        }
921                        *self = Self::Active(tx);
922                        Poll::Ready(Ok(()))
923                    }
924                    Poll::Ready(Err(..)) => Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())),
925                    Poll::Pending => Poll::Pending,
926                }
927            }
928            Self::Active(..) => Poll::Ready(Ok(())),
929        }
930    }
931}
932
933impl wrpc_transport::Index<Self> for IndexedParamWriter {
934    #[instrument(level = "trace", skip_all)]
935    fn index(&self, path: &[usize]) -> anyhow::Result<Self> {
936        ensure!(!path.is_empty());
937        match self {
938            Self::Corrupted => Err(anyhow!(corrupted_memory_error())),
939            Self::Handshaking { indexed, .. } => {
940                let (tx_tx, tx_rx) = oneshot::channel();
941                let mut indexed = indexed.lock().map_err(|err| {
942                    std::io::Error::new(std::io::ErrorKind::Other, err.to_string())
943                })?;
944                indexed.push((path.to_vec(), tx_tx));
945                Ok(Self::Handshaking {
946                    tx_rx,
947                    indexed: std::sync::Mutex::default(),
948                })
949            }
950            Self::Active(tx) => tx.index(path).map(Self::Active),
951        }
952    }
953}
954
955impl AsyncWrite for IndexedParamWriter {
956    #[instrument(level = "trace", skip_all, ret, fields(buf = format!("{buf:02x?}")))]
957    fn poll_write(
958        mut self: Pin<&mut Self>,
959        cx: &mut Context<'_>,
960        buf: &[u8],
961    ) -> Poll<std::io::Result<usize>> {
962        match self.as_mut().poll_active(cx)? {
963            Poll::Ready(()) => {
964                let Self::Active(tx) = &mut *self else {
965                    return Poll::Ready(Err(corrupted_memory_error()));
966                };
967                trace!("writing buffer");
968                pin!(tx).poll_write(cx, buf)
969            }
970            Poll::Pending => Poll::Pending,
971        }
972    }
973
974    #[instrument(level = "trace", skip_all, ret)]
975    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
976        match self.as_mut().poll_active(cx)? {
977            Poll::Ready(()) => {
978                let Self::Active(tx) = &mut *self else {
979                    return Poll::Ready(Err(corrupted_memory_error()));
980                };
981                trace!("flushing");
982                pin!(tx).poll_flush(cx)
983            }
984            Poll::Pending => Poll::Pending,
985        }
986    }
987
988    #[instrument(level = "trace", skip_all, ret)]
989    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
990        match self.as_mut().poll_active(cx)? {
991            Poll::Ready(()) => {
992                let Self::Active(tx) = &mut *self else {
993                    return Poll::Ready(Err(corrupted_memory_error()));
994                };
995                trace!("shutting down");
996                pin!(tx).poll_shutdown(cx)
997            }
998            Poll::Pending => Poll::Pending,
999        }
1000    }
1001}
1002
1003pub enum ParamWriter {
1004    Root(RootParamWriter),
1005    Nested(IndexedParamWriter),
1006}
1007
1008impl wrpc_transport::Index<Self> for ParamWriter {
1009    fn index(&self, path: &[usize]) -> anyhow::Result<Self> {
1010        ensure!(!path.is_empty());
1011        match self {
1012            ParamWriter::Root(w) => w.index(path),
1013            ParamWriter::Nested(w) => w.index(path),
1014        }
1015        .map(Self::Nested)
1016    }
1017}
1018
1019impl AsyncWrite for ParamWriter {
1020    #[instrument(level = "trace", skip_all, ret, fields(buf = format!("{buf:02x?}")))]
1021    fn poll_write(
1022        mut self: Pin<&mut Self>,
1023        cx: &mut Context<'_>,
1024        buf: &[u8],
1025    ) -> Poll<std::io::Result<usize>> {
1026        match &mut *self {
1027            ParamWriter::Root(w) => pin!(w).poll_write(cx, buf),
1028            ParamWriter::Nested(w) => pin!(w).poll_write(cx, buf),
1029        }
1030    }
1031
1032    #[instrument(level = "trace", skip_all, ret)]
1033    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1034        match &mut *self {
1035            ParamWriter::Root(w) => pin!(w).poll_flush(cx),
1036            ParamWriter::Nested(w) => pin!(w).poll_flush(cx),
1037        }
1038    }
1039
1040    #[instrument(level = "trace", skip_all, ret)]
1041    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1042        match &mut *self {
1043            ParamWriter::Root(w) => pin!(w).poll_shutdown(cx),
1044            ParamWriter::Nested(w) => pin!(w).poll_shutdown(cx),
1045        }
1046    }
1047}
1048
1049impl wrpc_transport::Invoke for Client {
1050    type Context = Option<HeaderMap>;
1051    type Outgoing = ParamWriter;
1052    type Incoming = Reader;
1053
1054    #[instrument(level = "trace", skip(self, paths, params), fields(params = format!("{params:02x?}")))]
1055    async fn invoke<P: AsRef<[Option<usize>]> + Send + Sync>(
1056        &self,
1057        cx: Self::Context,
1058        instance: &str,
1059        func: &str,
1060        mut params: Bytes,
1061        paths: impl AsRef<[P]> + Send,
1062    ) -> anyhow::Result<(Self::Outgoing, Self::Incoming)> {
1063        let paths = paths.as_ref();
1064        let mut cmds = Vec::with_capacity(paths.len().saturating_add(2));
1065
1066        let rx = Subject::from(new_inbox(&self.inbox));
1067        let (handshake_tx, handshake_rx) = mpsc::channel(1);
1068        cmds.push(Command::Subscribe(rx.clone(), handshake_tx));
1069
1070        let result = Subject::from(result_subject(&rx));
1071        let (result_tx, result_rx) = mpsc::channel(16);
1072        cmds.push(Command::Subscribe(result.clone(), result_tx));
1073
1074        let nested = paths.iter().map(|path| {
1075            let (tx, rx) = mpsc::channel(16);
1076            let subject = Subject::from(subscribe_path(&result, path.as_ref()));
1077            cmds.push(Command::Subscribe(subject.clone(), tx));
1078            Subscriber {
1079                rx: ReceiverStream::new(rx),
1080                commands: self.commands.clone(),
1081                subject,
1082                tasks: Arc::clone(&self.tasks),
1083            }
1084        });
1085        let nested: IndexTrie = zip(paths.iter(), nested).collect();
1086        ensure!(
1087            paths.is_empty() == nested.is_empty(),
1088            "failed to construct subscription tree"
1089        );
1090
1091        self.commands
1092            .send(Command::Batch(cmds.into_boxed_slice()))
1093            .await
1094            .context("failed to subscribe")?;
1095
1096        let ServerInfo {
1097            mut max_payload, ..
1098        } = self.nats.server_info();
1099        max_payload = max_payload.saturating_sub(rx.len());
1100        let param_tx = Subject::from(invocation_subject(&self.prefix, instance, func));
1101        if let Some(headers) = cx {
1102            // based on https://github.com/nats-io/nats.rs/blob/0942c473ce56163fdd1fbc62762f8164e3afa7bf/async-nats/src/header.rs#L215-L224
1103            max_payload = max_payload
1104                .saturating_sub(b"NATS/1.0\r\n".len())
1105                .saturating_sub(b"\r\n".len());
1106            for (k, vs) in headers.iter() {
1107                let k: &[u8] = k.as_ref();
1108                for v in vs {
1109                    max_payload = max_payload
1110                        .saturating_sub(k.len())
1111                        .saturating_sub(b": ".len())
1112                        .saturating_sub(v.as_str().len())
1113                        .saturating_sub(b"\r\n".len());
1114                }
1115            }
1116            trace!("publishing handshake");
1117            self.nats
1118                .publish_with_reply_and_headers(
1119                    param_tx,
1120                    rx.clone(),
1121                    headers,
1122                    params.split_to(max_payload.min(params.len())),
1123                )
1124                .await
1125        } else {
1126            trace!("publishing handshake");
1127            self.nats
1128                .publish_with_reply(
1129                    param_tx,
1130                    rx.clone(),
1131                    params.split_to(max_payload.min(params.len())),
1132                )
1133                .await
1134        }
1135        .context("failed to publish handshake")?;
1136        let nats = Arc::clone(&self.nats);
1137        tokio::spawn(async move {
1138            if let Err(err) = nats.flush().await {
1139                error!(?err, "failed to flush");
1140            }
1141        });
1142        Ok((
1143            ParamWriter::Root(RootParamWriter::new(
1144                (*self.nats).clone(),
1145                Subscriber {
1146                    rx: ReceiverStream::new(handshake_rx),
1147                    commands: self.commands.clone(),
1148                    subject: rx,
1149                    tasks: Arc::clone(&self.tasks),
1150                },
1151                params,
1152                Arc::clone(&self.tasks),
1153            )),
1154            Reader {
1155                buffer: Bytes::default(),
1156                incoming: Some(Subscriber {
1157                    rx: ReceiverStream::new(result_rx),
1158                    commands: self.commands.clone(),
1159                    subject: result,
1160                    tasks: Arc::clone(&self.tasks),
1161                }),
1162                nested: Arc::new(std::sync::Mutex::new(nested)),
1163                path: Box::default(),
1164            },
1165        ))
1166    }
1167}
1168
1169async fn handle_message(
1170    nats: &async_nats::Client,
1171    rx: Subject,
1172    commands: mpsc::Sender<Command>,
1173    async_nats::Message {
1174        reply: tx,
1175        payload,
1176        headers,
1177        ..
1178    }: async_nats::Message,
1179    paths: &[Box<[Option<usize>]>],
1180    tasks: Arc<JoinSet<()>>,
1181) -> anyhow::Result<(Option<HeaderMap>, SubjectWriter, Reader)> {
1182    let tx = tx.context("peer did not specify a reply subject")?;
1183
1184    let mut cmds = Vec::with_capacity(paths.len().saturating_add(1));
1185
1186    let param = Subject::from(param_subject(&rx));
1187    let (param_tx, param_rx) = mpsc::channel(16);
1188    cmds.push(Command::Subscribe(param.clone(), param_tx));
1189
1190    let nested = paths.iter().map(|path| {
1191        let (tx, rx) = mpsc::channel(16);
1192        let subject = Subject::from(subscribe_path(&param, path.as_ref()));
1193        cmds.push(Command::Subscribe(subject.clone(), tx));
1194        Subscriber {
1195            rx: ReceiverStream::new(rx),
1196            commands: commands.clone(),
1197            subject,
1198            tasks: Arc::clone(&tasks),
1199        }
1200    });
1201    let nested: IndexTrie = zip(paths.iter(), nested).collect();
1202    ensure!(
1203        paths.is_empty() == nested.is_empty(),
1204        "failed to construct subscription tree"
1205    );
1206
1207    commands
1208        .send(Command::Batch(cmds.into_boxed_slice()))
1209        .await
1210        .context("failed to subscribe")?;
1211
1212    trace!("publishing handshake response");
1213    nats.publish_with_reply(tx.clone(), rx, Bytes::default())
1214        .await
1215        .context("failed to publish handshake accept")?;
1216    Ok((
1217        headers,
1218        SubjectWriter::new(
1219            nats.clone(),
1220            Subject::from(result_subject(&tx)),
1221            Arc::clone(&tasks),
1222        ),
1223        Reader {
1224            buffer: payload,
1225            incoming: Some(Subscriber {
1226                rx: ReceiverStream::new(param_rx),
1227                commands,
1228                subject: param,
1229                tasks,
1230            }),
1231            nested: Arc::new(std::sync::Mutex::new(nested)),
1232            path: Box::default(),
1233        },
1234    ))
1235}
1236
1237impl wrpc_transport::Serve for Client {
1238    type Context = Option<HeaderMap>;
1239    type Outgoing = SubjectWriter;
1240    type Incoming = Reader;
1241
1242    #[instrument(level = "trace", skip(self, paths))]
1243    async fn serve(
1244        &self,
1245        instance: &str,
1246        func: &str,
1247        paths: impl Into<Arc<[Box<[Option<usize>]>]>> + Send,
1248    ) -> anyhow::Result<
1249        impl Stream<Item = anyhow::Result<(Self::Context, Self::Outgoing, Self::Incoming)>> + 'static,
1250    > {
1251        let subject = invocation_subject(&self.prefix, instance, func);
1252        let sub = if let Some(group) = &self.queue_group {
1253            debug!(subject, ?group, "queue-subscribing on invocation subject");
1254            self.nats
1255                .queue_subscribe(subject, group.to_string())
1256                .await?
1257        } else {
1258            debug!(subject, "subscribing on invocation subject");
1259            self.nats.subscribe(subject).await?
1260        };
1261        let nats = Arc::clone(&self.nats);
1262        let paths = paths.into();
1263        let commands = self.commands.clone();
1264        let inbox = Arc::clone(&self.inbox);
1265        let tasks = Arc::clone(&self.tasks);
1266        Ok(sub.then(move |msg| {
1267            let tasks = Arc::clone(&tasks);
1268            let nats = Arc::clone(&nats);
1269            let paths = Arc::clone(&paths);
1270            let commands = commands.clone();
1271            let rx = Subject::from(new_inbox(&inbox));
1272            async move { handle_message(&nats, rx, commands, msg, &paths, tasks).await }
1273        }))
1274    }
1275}