1#![allow(clippy::type_complexity)]
4
5#[cfg(any(
6 not(any(
7 feature = "async-nats-0_39",
8 feature = "async-nats-0_38",
9 feature = "async-nats-0_37",
10 feature = "async-nats-0_36",
11 )),
12 all(feature = "async-nats-0_39", feature = "async-nats-0_38"),
13 all(feature = "async-nats-0_39", feature = "async-nats-0_37"),
14 all(feature = "async-nats-0_39", feature = "async-nats-0_36"),
15 all(feature = "async-nats-0_38", feature = "async-nats-0_37"),
16 all(feature = "async-nats-0_38", feature = "async-nats-0_36"),
17 all(feature = "async-nats-0_37", feature = "async-nats-0_36"),
18))]
19compile_error!(
20 "Either feature \"async-nats-0_39\", \"async-nats-0_38\", \"async-nats-0_37\" or \"async-nats-0_36\" must be enabled for this crate."
21);
22
23#[cfg(feature = "async-nats-0_39")]
24use async_nats_0_39 as async_nats;
25
26#[cfg(feature = "async-nats-0_38")]
27use async_nats_0_38 as async_nats;
28
29#[cfg(feature = "async-nats-0_37")]
30use async_nats_0_37 as async_nats;
31
32#[cfg(feature = "async-nats-0_36")]
33use async_nats_0_36 as async_nats;
34
35use core::future::Future;
36use core::iter::zip;
37use core::ops::{Deref, DerefMut};
38use core::pin::{pin, Pin};
39use core::task::{ready, Context, Poll};
40use core::{mem, str};
41
42use std::collections::HashMap;
43use std::sync::Arc;
44
45use anyhow::{anyhow, ensure, Context as _};
46use async_nats::{HeaderMap, PublishMessage, ServerInfo, StatusCode, Subject};
47use bytes::{Buf as _, Bytes};
48use futures::sink::SinkExt as _;
49use futures::{Stream, StreamExt};
50use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
51use tokio::select;
52use tokio::sync::{mpsc, oneshot};
53use tokio::task::JoinSet;
54use tokio_stream::wrappers::ReceiverStream;
55use tracing::{debug, error, instrument, trace, warn};
56use wrpc_transport::Index as _;
57
/// Protocol identifier embedded as a segment of every invocation subject
/// (see [`invocation_subject`]).
pub const PROTOCOL: &str = "wrpc.0.0.1";
59
60fn spawn_async(fut: impl Future<Output = ()> + Send + 'static) {
61 match tokio::runtime::Handle::try_current() {
62 Ok(rt) => {
63 rt.spawn(fut);
64 }
65 Err(_) => match tokio::runtime::Runtime::new() {
66 Ok(rt) => {
67 rt.spawn(fut);
68 }
69 Err(err) => error!(?err, "failed to create a new Tokio runtime"),
70 },
71 }
72}
73
74fn new_inbox(inbox: &str) -> String {
75 let id = nuid::next();
76 let mut s = String::with_capacity(inbox.len().saturating_add(id.len()));
77 s.push_str(inbox);
78 s.push_str(&id);
79 s
80}
81
/// Returns the subject carrying parameters below `prefix`, i.e. `"{prefix}.params"`.
#[must_use]
#[inline]
pub fn param_subject(prefix: &str) -> String {
    const SUFFIX: &str = ".params";
    let mut subject = String::with_capacity(prefix.len() + SUFFIX.len());
    subject.push_str(prefix);
    subject.push_str(SUFFIX);
    subject
}
87
/// Returns the subject carrying results below `prefix`, i.e. `"{prefix}.results"`.
#[must_use]
#[inline]
pub fn result_subject(prefix: &str) -> String {
    const SUFFIX: &str = ".results";
    let mut subject = String::with_capacity(prefix.len() + SUFFIX.len());
    subject.push_str(prefix);
    subject.push_str(SUFFIX);
    subject
}
93
/// Appends the indices of `path` to `prefix` as dot-separated subject segments.
///
/// No leading separator is produced when `prefix` is empty:
/// `index_path("a", &[1, 2]) == "a.1.2"` and `index_path("", &[3]) == "3"`.
#[must_use]
#[inline]
pub fn index_path(prefix: &str, path: &[usize]) -> String {
    let mut subject = String::with_capacity(prefix.len() + path.len() * 2);
    subject.push_str(prefix);
    path.iter().for_each(|idx| {
        if !subject.is_empty() {
            subject.push('.');
        }
        subject.push_str(&idx.to_string());
    });
    subject
}
109
/// Builds the subscription subject for `path` below `prefix`.
///
/// `Some(i)` becomes the numeric segment `i` and `None` becomes the `*` wildcard segment, e.g.
/// `subscribe_path("r", &[Some(1), None]) == "r.1.*"`. No leading separator is produced when
/// `prefix` is empty.
#[must_use]
#[inline]
pub fn subscribe_path(prefix: &str, path: &[Option<usize>]) -> String {
    let mut subject = String::with_capacity(prefix.len() + path.len() * 2);
    subject.push_str(prefix);
    for segment in path {
        if !subject.is_empty() {
            subject.push('.');
        }
        match segment {
            Some(idx) => subject.push_str(&idx.to_string()),
            None => subject.push('*'),
        }
    }
    subject
}
129
130#[must_use]
131#[inline]
132pub fn invocation_subject(prefix: &str, instance: &str, func: &str) -> String {
133 let mut s =
134 String::with_capacity(prefix.len() + PROTOCOL.len() + instance.len() + func.len() + 3);
135 if !prefix.is_empty() {
136 s.push_str(prefix);
137 s.push('.');
138 }
139 s.push_str(PROTOCOL);
140 s.push('.');
141 if !instance.is_empty() {
142 s.push_str(instance);
143 s.push('.');
144 }
145 s.push_str(func);
146 s
147}
148
/// Error reported when a reader/writer is found in an unexpected internal state
/// (e.g. its `Corrupted` variant).
fn corrupted_memory_error() -> std::io::Error {
    let kind = std::io::ErrorKind::Other;
    std::io::Error::new(kind, "corrupted memory state")
}
152
153pub struct Subscriber {
155 rx: ReceiverStream<Message>,
156 subject: Subject,
157 commands: mpsc::Sender<Command>,
158 tasks: Arc<JoinSet<()>>,
159}
160
161impl Drop for Subscriber {
162 fn drop(&mut self) {
163 let commands = self.commands.clone();
164 let subject = mem::replace(&mut self.subject, Subject::from_static(""));
165 let tasks = Arc::clone(&self.tasks);
166 spawn_async(async move {
167 trace!(?subject, "shutting down subscriber");
168 if let Err(err) = commands.send(Command::Unsubscribe(subject)).await {
169 warn!(?err, "failed to shutdown subscriber");
170 }
171 drop(tasks);
172 });
173 }
174}
175
176impl Deref for Subscriber {
177 type Target = ReceiverStream<Message>;
178
179 fn deref(&self) -> &Self::Target {
180 &self.rx
181 }
182}
183
184impl DerefMut for Subscriber {
185 fn deref_mut(&mut self) -> &mut Self::Target {
186 &mut self.rx
187 }
188}
189
190enum Command {
191 Subscribe(Subject, mpsc::Sender<Message>),
192 Unsubscribe(Subject),
193 Batch(Box<[Command]>),
194}
195
196pub struct Message {
198 reply: Option<Subject>,
199 payload: Bytes,
200 status: Option<async_nats::StatusCode>,
201 description: Option<String>,
202}
203
204#[derive(Clone, Debug)]
205pub struct Client {
206 nats: Arc<async_nats::Client>,
207 prefix: Arc<str>,
208 inbox: Arc<str>,
209 queue_group: Option<Arc<str>>,
210 commands: mpsc::Sender<Command>,
211 tasks: Arc<JoinSet<()>>,
212}
213
214impl Client {
215 pub async fn new(
216 nats: impl Into<Arc<async_nats::Client>>,
217 prefix: impl Into<Arc<str>>,
218 queue_group: Option<Arc<str>>,
219 ) -> anyhow::Result<Self> {
220 let nats = nats.into();
221 let mut inbox = nats.new_inbox();
222 inbox.push('.');
223 let mut subject = String::with_capacity(inbox.len().saturating_add(1));
224 subject.push_str(&inbox);
225 subject.push('>');
226 let mut sub = nats
227 .subscribe(Subject::from(subject))
228 .await
229 .context("failed to subscribe on an inbox subject")?;
230
231 let mut tasks = JoinSet::new();
232 let (cmd_tx, mut cmd_rx) = mpsc::channel(8192);
233 tasks.spawn({
234 async move {
235 fn handle_command(subs: &mut HashMap<String, mpsc::Sender<Message>>, cmd: Command) {
236 match cmd {
237 Command::Subscribe(s, tx) => {
238 subs.insert(s.into_string(), tx);
239 }
240 Command::Unsubscribe(s) => {
241 subs.remove(s.as_str());
242 }
243 Command::Batch(cmds) => {
244 for cmd in cmds {
245 handle_command(subs, cmd);
246 }
247 }
248 }
249 }
250 async fn handle_message(
251 subs: &mut HashMap<String, mpsc::Sender<Message>>,
252 async_nats::Message {
253 subject,
254 reply,
255 payload,
256 status,
257 description,
258 ..
259 }: async_nats::Message,
260 ) {
261 let Some(sub) = subs.get_mut(subject.as_str()) else {
262 debug!(?subject, "drop message with no subscriber");
263 return;
264 };
265 let Ok(sub) = sub.reserve().await else {
266 debug!(?subject, "drop message with closed subscriber");
267 subs.remove(subject.as_str());
268 return;
269 };
270 sub.send(Message {
271 reply,
272 payload,
273 status,
274 description,
275 });
276 }
277
278 let mut subs = HashMap::new();
279 loop {
280 select! {
281 Some(msg) = sub.next() => handle_message(&mut subs, msg).await,
282 Some(cmd) = cmd_rx.recv() => handle_command(&mut subs, cmd),
283 else => return,
284 }
285 }
286 }
287 });
288 Ok(Self {
289 nats,
290 prefix: prefix.into(),
291 inbox: inbox.into(),
292 queue_group,
293 commands: cmd_tx,
294 tasks: Arc::new(tasks),
295 })
296 }
297}
298
299pub struct ByteSubscription(Subscriber);
300
301impl Stream for ByteSubscription {
302 type Item = std::io::Result<Bytes>;
303
304 #[instrument(level = "trace", skip_all)]
305 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
306 match self.0.poll_next_unpin(cx) {
307 Poll::Ready(Some(Message { payload, .. })) => Poll::Ready(Some(Ok(payload))),
308 Poll::Ready(None) => Poll::Ready(None),
309 Poll::Pending => Poll::Pending,
310 }
311 }
312}
313
314#[derive(Default)]
315enum IndexTrie {
316 #[default]
317 Empty,
318 Leaf(Subscriber),
319 IndexNode {
320 subscriber: Option<Subscriber>,
321 nested: Vec<Option<IndexTrie>>,
322 },
323 WildcardNode {
324 subscriber: Option<Subscriber>,
325 nested: Option<Box<IndexTrie>>,
326 },
327}
328
329impl<'a> From<(&'a [Option<usize>], Subscriber)> for IndexTrie {
330 fn from((path, sub): (&'a [Option<usize>], Subscriber)) -> Self {
331 match path {
332 [] => Self::Leaf(sub),
333 [None, path @ ..] => Self::WildcardNode {
334 subscriber: None,
335 nested: Some(Box::new(Self::from((path, sub)))),
336 },
337 [Some(i), path @ ..] => Self::IndexNode {
338 subscriber: None,
339 nested: {
340 let n = i.saturating_add(1);
341 let mut nested = Vec::with_capacity(n);
342 nested.resize_with(n, Option::default);
343 nested[*i] = Some(Self::from((path, sub)));
344 nested
345 },
346 },
347 }
348 }
349}
350
351impl<P: AsRef<[Option<usize>]>> FromIterator<(P, Subscriber)> for IndexTrie {
352 fn from_iter<T: IntoIterator<Item = (P, Subscriber)>>(iter: T) -> Self {
353 let mut root = Self::Empty;
354 for (path, sub) in iter {
355 if !root.insert(path.as_ref(), sub) {
356 return Self::Empty;
357 }
358 }
359 root
360 }
361}
362
363impl IndexTrie {
364 #[inline]
365 fn is_empty(&self) -> bool {
366 matches!(self, IndexTrie::Empty)
367 }
368
369 #[instrument(level = "trace", skip_all)]
370 fn take(&mut self, path: &[usize]) -> Option<Subscriber> {
371 let Some((i, path)) = path.split_first() else {
372 return match mem::take(self) {
373 IndexTrie::Empty | IndexTrie::WildcardNode { .. } => None,
384 IndexTrie::Leaf(subscriber) => Some(subscriber),
385 IndexTrie::IndexNode { subscriber, nested } => {
386 if !nested.is_empty() {
387 *self = IndexTrie::IndexNode {
388 subscriber: None,
389 nested,
390 }
391 }
392 subscriber
393 }
394 };
395 };
396 match self {
397 Self::Empty | Self::Leaf(..) | Self::WildcardNode { .. } => None,
402 Self::IndexNode { ref mut nested, .. } => nested
403 .get_mut(*i)
404 .and_then(|nested| nested.as_mut().and_then(|nested| nested.take(path))),
405 }
406 }
407
408 #[instrument(level = "trace", skip_all)]
411 fn insert(&mut self, path: &[Option<usize>], sub: Subscriber) -> bool {
412 match self {
413 Self::Empty => {
414 *self = Self::from((path, sub));
415 true
416 }
417 Self::Leaf(..) => {
418 let Some((i, path)) = path.split_first() else {
419 return false;
420 };
421 let Self::Leaf(subscriber) = mem::take(self) else {
422 return false;
423 };
424 if let Some(i) = i {
425 let n = i.saturating_add(1);
426 let mut nested = Vec::with_capacity(n);
427 nested.resize_with(n, Option::default);
428 nested[*i] = Some(Self::from((path, sub)));
429 *self = Self::IndexNode {
430 subscriber: Some(subscriber),
431 nested,
432 };
433 } else {
434 *self = Self::WildcardNode {
435 subscriber: Some(subscriber),
436 nested: Some(Box::new(Self::from((path, sub)))),
437 };
438 }
439 true
440 }
441 Self::WildcardNode {
442 ref mut subscriber,
443 ref mut nested,
444 } => match (&subscriber, path) {
445 (None, []) => {
446 *subscriber = Some(sub);
447 true
448 }
449 (_, [None, path @ ..]) => {
450 if let Some(nested) = nested {
451 nested.insert(path, sub)
452 } else {
453 *nested = Some(Box::new(Self::from((path, sub))));
454 true
455 }
456 }
457 _ => false,
458 },
459 Self::IndexNode {
460 ref mut subscriber,
461 ref mut nested,
462 } => match (&subscriber, path) {
463 (None, []) => {
464 *subscriber = Some(sub);
465 true
466 }
467 (_, [Some(i), path @ ..]) => {
468 let cap = i.saturating_add(1);
469 if nested.len() < cap {
470 nested.resize_with(cap, Option::default);
471 }
472 let nested = &mut nested[*i];
473 if let Some(nested) = nested {
474 nested.insert(path, sub)
475 } else {
476 *nested = Some(Self::from((path, sub)));
477 true
478 }
479 }
480 _ => false,
481 },
482 }
483 }
484}
485
486pub struct Reader {
487 buffer: Bytes,
488 incoming: Option<Subscriber>,
489 nested: Arc<std::sync::Mutex<IndexTrie>>,
490 path: Box<[usize]>,
491}
492
493impl wrpc_transport::Index<Self> for Reader {
494 #[instrument(level = "trace", skip(self))]
495 fn index(&self, path: &[usize]) -> anyhow::Result<Self> {
496 ensure!(!path.is_empty());
497 trace!("locking index tree");
498 let mut nested = self
499 .nested
500 .lock()
501 .map_err(|err| anyhow!(err.to_string()).context("failed to lock map"))?;
502 trace!("taking index subscription");
503 let mut p = self.path.to_vec();
504 p.extend_from_slice(path);
505 let incoming = nested.take(&p);
506 Ok(Self {
507 buffer: Bytes::default(),
508 incoming,
509 nested: Arc::clone(&self.nested),
510 path: p.into_boxed_slice(),
511 })
512 }
513}
514
515impl AsyncRead for Reader {
516 #[instrument(level = "trace", skip_all, ret)]
517 fn poll_read(
518 mut self: Pin<&mut Self>,
519 cx: &mut Context<'_>,
520 buf: &mut ReadBuf<'_>,
521 ) -> Poll<std::io::Result<()>> {
522 let cap = buf.remaining();
523 if cap == 0 {
524 trace!("attempt to read empty buffer");
525 return Poll::Ready(Ok(()));
526 }
527
528 if !self.buffer.is_empty() {
529 if self.buffer.len() > cap {
530 trace!(cap, len = self.buffer.len(), "reading part of buffer");
531 buf.put_slice(&self.buffer.split_to(cap));
532 } else {
533 trace!(cap, len = self.buffer.len(), "reading full buffer");
534 buf.put_slice(&mem::take(&mut self.buffer));
535 }
536 return Poll::Ready(Ok(()));
537 }
538 let Some(incoming) = self.incoming.as_mut() else {
539 return Poll::Ready(Err(std::io::Error::new(
540 std::io::ErrorKind::NotFound,
541 format!("subscription not found for path {:?}", self.path),
542 )));
543 };
544 trace!("polling for next message");
545 match incoming.poll_next_unpin(cx) {
546 Poll::Ready(Some(Message { mut payload, .. })) => {
547 trace!(?payload, "received message");
548 if payload.is_empty() {
549 trace!("received stream shutdown message");
550 return Poll::Ready(Ok(()));
551 }
552 if payload.len() > cap {
553 trace!(len = payload.len(), cap, "partially reading the message");
554 buf.put_slice(&payload.split_to(cap));
555 self.buffer = payload;
556 } else {
557 trace!(len = payload.len(), cap, "filling the buffer with payload");
558 buf.put_slice(&payload);
559 }
560 Poll::Ready(Ok(()))
561 }
562 Poll::Ready(None) => {
563 trace!("subscription finished");
564 Poll::Ready(Ok(()))
565 }
566 Poll::Pending => Poll::Pending,
567 }
568 }
569}
570
571#[derive(Clone, Debug)]
572pub struct SubjectWriter {
573 nats: async_nats::Client,
574 tx: Subject,
575 shutdown: bool,
576 tasks: Arc<JoinSet<()>>,
577}
578
579impl SubjectWriter {
580 fn new(nats: async_nats::Client, tx: Subject, tasks: Arc<JoinSet<()>>) -> Self {
581 Self {
582 nats,
583 tx,
584 shutdown: false,
585 tasks,
586 }
587 }
588}
589
590impl wrpc_transport::Index<Self> for SubjectWriter {
591 #[instrument(level = "trace", skip(self))]
592 fn index(&self, path: &[usize]) -> anyhow::Result<Self> {
593 ensure!(!path.is_empty());
594 let tx = Subject::from(index_path(self.tx.as_str(), path));
595 Ok(Self {
596 nats: self.nats.clone(),
597 tx,
598 shutdown: false,
599 tasks: Arc::clone(&self.tasks),
600 })
601 }
602}
603
604impl AsyncWrite for SubjectWriter {
605 #[instrument(level = "trace", skip_all, ret, fields(subject = self.tx.as_str(), buf = format!("{buf:02x?}")))]
606 fn poll_write(
607 mut self: Pin<&mut Self>,
608 cx: &mut Context<'_>,
609 mut buf: &[u8],
610 ) -> Poll<std::io::Result<usize>> {
611 trace!("polling for readiness");
612 match self.nats.poll_ready_unpin(cx) {
613 Poll::Pending => return Poll::Pending,
614 Poll::Ready(Err(err)) => {
615 return Poll::Ready(Err(std::io::Error::new(
616 std::io::ErrorKind::BrokenPipe,
617 err,
618 )))
619 }
620 Poll::Ready(Ok(())) => {}
621 }
622 let ServerInfo { max_payload, .. } = self.nats.server_info();
623 if max_payload == 0 {
624 return Poll::Ready(Err(std::io::ErrorKind::WriteZero.into()));
625 }
626 if buf.len() > max_payload {
627 (buf, _) = buf.split_at(max_payload);
628 }
629 trace!("starting send");
630 let subject = self.tx.clone();
631 match self.nats.start_send_unpin(PublishMessage {
632 subject,
633 payload: Bytes::copy_from_slice(buf),
634 reply: None,
635 headers: None,
636 }) {
637 Ok(()) => Poll::Ready(Ok(buf.len())),
638 Err(err) => Poll::Ready(Err(std::io::Error::new(
639 std::io::ErrorKind::BrokenPipe,
640 err,
641 ))),
642 }
643 }
644
645 #[instrument(level = "trace", skip_all, ret, fields(subject = self.tx.as_str()))]
646 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
647 trace!("flushing");
648 self.nats
649 .poll_flush_unpin(cx)
650 .map_err(|_| std::io::ErrorKind::BrokenPipe.into())
651 }
652
653 #[instrument(level = "trace", skip_all, ret, fields(subject = self.tx.as_str()))]
654 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
655 trace!("writing stream shutdown message");
656 ready!(self.as_mut().poll_write(cx, &[]))?;
657 self.shutdown = true;
658 Poll::Ready(Ok(()))
659 }
660}
661
662impl Drop for SubjectWriter {
663 fn drop(&mut self) {
664 if !self.shutdown {
665 let nats = self.nats.clone();
666 let subject = mem::replace(&mut self.tx, Subject::from_static(""));
667 let tasks = Arc::clone(&self.tasks);
668 spawn_async(async move {
669 trace!("writing stream shutdown message");
670 if let Err(err) = nats.publish(subject, Bytes::default()).await {
671 warn!(?err, "failed to publish stream shutdown message");
672 }
673 drop(tasks);
674 });
675 }
676 }
677}
678
679#[derive(Default)]
680pub enum RootParamWriter {
681 #[default]
682 Corrupted,
683 Handshaking {
684 nats: async_nats::Client,
685 sub: Subscriber,
686 indexed: std::sync::Mutex<Vec<(Vec<usize>, oneshot::Sender<SubjectWriter>)>>,
687 buffer: Bytes,
688 tasks: Arc<JoinSet<()>>,
689 },
690 Draining {
691 tx: SubjectWriter,
692 buffer: Bytes,
693 },
694 Active(SubjectWriter),
695}
696
697impl RootParamWriter {
698 fn new(
699 nats: async_nats::Client,
700 sub: Subscriber,
701 buffer: Bytes,
702 tasks: Arc<JoinSet<()>>,
703 ) -> Self {
704 Self::Handshaking {
705 nats,
706 sub,
707 indexed: std::sync::Mutex::default(),
708 buffer,
709 tasks,
710 }
711 }
712}
713
714impl RootParamWriter {
715 #[instrument(level = "trace", skip_all, ret)]
716 fn poll_active(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
717 match &mut *self {
718 Self::Corrupted => Poll::Ready(Err(corrupted_memory_error())),
719 Self::Handshaking { sub, .. } => {
720 trace!("polling for handshake response");
721 match sub.poll_next_unpin(cx) {
722 Poll::Ready(Some(Message {
723 status: Some(StatusCode::NO_RESPONDERS),
724 ..
725 })) => Poll::Ready(Err(std::io::ErrorKind::NotConnected.into())),
726 Poll::Ready(Some(Message {
727 status: Some(StatusCode::TIMEOUT),
728 ..
729 })) => Poll::Ready(Err(std::io::ErrorKind::TimedOut.into())),
730 Poll::Ready(Some(Message {
731 status: Some(StatusCode::REQUEST_TERMINATED),
732 ..
733 })) => Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())),
734 Poll::Ready(Some(Message {
735 status: Some(code),
736 description,
737 ..
738 })) if !code.is_success() => Poll::Ready(Err(std::io::Error::new(
739 std::io::ErrorKind::Other,
740 if let Some(description) = description {
741 format!("received a response with code `{code}` ({description})")
742 } else {
743 format!("received a response with code `{code}`")
744 },
745 ))),
746 Poll::Ready(Some(Message {
747 reply: Some(tx), ..
748 })) => {
749 let Self::Handshaking {
750 nats,
751 indexed,
752 buffer,
753 tasks,
754 ..
755 } = mem::take(&mut *self)
756 else {
757 return Poll::Ready(Err(corrupted_memory_error()));
758 };
759 let tx = SubjectWriter::new(nats, Subject::from(param_subject(&tx)), tasks);
760 let indexed = indexed.into_inner().map_err(|err| {
761 std::io::Error::new(std::io::ErrorKind::Other, err.to_string())
762 })?;
763 for (path, tx_tx) in indexed {
764 let tx = tx.index(&path).map_err(|err| {
765 std::io::Error::new(std::io::ErrorKind::Other, err)
766 })?;
767 tx_tx.send(tx).map_err(|_| {
768 std::io::Error::from(std::io::ErrorKind::BrokenPipe)
769 })?;
770 }
771 trace!("handshake succeeded");
772 if buffer.is_empty() {
773 *self = Self::Active(tx);
774 Poll::Ready(Ok(()))
775 } else {
776 *self = Self::Draining { tx, buffer };
777 self.poll_active(cx)
778 }
779 }
780 Poll::Ready(Some(..)) => Poll::Ready(Err(std::io::Error::new(
781 std::io::ErrorKind::InvalidInput,
782 "peer did not specify a reply subject",
783 ))),
784 Poll::Ready(None) => {
785 *self = Self::Corrupted;
786 Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::BrokenPipe)))
787 }
788 Poll::Pending => Poll::Pending,
789 }
790 }
791 Self::Draining { tx, buffer } => {
792 let mut tx = pin!(tx);
793 while !buffer.is_empty() {
794 trace!(?tx.tx, "draining parameter buffer");
795 match tx.as_mut().poll_write(cx, buffer) {
796 Poll::Ready(Ok(n)) => {
797 buffer.advance(n);
798 }
799 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
800 Poll::Pending => return Poll::Pending,
801 }
802 }
803 let Self::Draining { tx, .. } = mem::take(&mut *self) else {
804 return Poll::Ready(Err(corrupted_memory_error()));
805 };
806 trace!("parameter buffer draining succeeded");
807 *self = Self::Active(tx);
808 Poll::Ready(Ok(()))
809 }
810 Self::Active(..) => Poll::Ready(Ok(())),
811 }
812 }
813}
814
815impl wrpc_transport::Index<IndexedParamWriter> for RootParamWriter {
816 #[instrument(level = "trace", skip(self))]
817 fn index(&self, path: &[usize]) -> anyhow::Result<IndexedParamWriter> {
818 ensure!(!path.is_empty());
819 match self {
820 Self::Corrupted => Err(anyhow!(corrupted_memory_error())),
821 Self::Handshaking { indexed, .. } => {
822 let (tx_tx, tx_rx) = oneshot::channel();
823 let mut indexed = indexed.lock().map_err(|err| {
824 std::io::Error::new(std::io::ErrorKind::Other, err.to_string())
825 })?;
826 indexed.push((path.to_vec(), tx_tx));
827 Ok(IndexedParamWriter::Handshaking {
828 tx_rx,
829 indexed: std::sync::Mutex::default(),
830 })
831 }
832 Self::Draining { tx, .. } | Self::Active(tx) => {
833 tx.index(path).map(IndexedParamWriter::Active)
834 }
835 }
836 }
837}
838
839impl AsyncWrite for RootParamWriter {
840 #[instrument(level = "trace", skip_all, ret, fields(buf = format!("{buf:02x?}")))]
841 fn poll_write(
842 mut self: Pin<&mut Self>,
843 cx: &mut Context<'_>,
844 buf: &[u8],
845 ) -> Poll<std::io::Result<usize>> {
846 match self.as_mut().poll_active(cx)? {
847 Poll::Ready(()) => {
848 let Self::Active(tx) = &mut *self else {
849 return Poll::Ready(Err(corrupted_memory_error()));
850 };
851 trace!("writing buffer");
852 pin!(tx).poll_write(cx, buf)
853 }
854 Poll::Pending => Poll::Pending,
855 }
856 }
857
858 #[instrument(level = "trace", skip_all, ret)]
859 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
860 match self.as_mut().poll_active(cx)? {
861 Poll::Ready(()) => {
862 let Self::Active(tx) = &mut *self else {
863 return Poll::Ready(Err(corrupted_memory_error()));
864 };
865 trace!("flushing");
866 pin!(tx).poll_flush(cx)
867 }
868 Poll::Pending => Poll::Pending,
869 }
870 }
871
872 #[instrument(level = "trace", skip_all, ret)]
873 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
874 match self.as_mut().poll_active(cx)? {
875 Poll::Ready(()) => {
876 let Self::Active(tx) = &mut *self else {
877 return Poll::Ready(Err(corrupted_memory_error()));
878 };
879 trace!("shutting down");
880 pin!(tx).poll_shutdown(cx)
881 }
882 Poll::Pending => Poll::Pending,
883 }
884 }
885}
886
887#[derive(Debug, Default)]
888pub enum IndexedParamWriter {
889 #[default]
890 Corrupted,
891 Handshaking {
892 tx_rx: oneshot::Receiver<SubjectWriter>,
893 indexed: std::sync::Mutex<Vec<(Vec<usize>, oneshot::Sender<SubjectWriter>)>>,
894 },
895 Active(SubjectWriter),
896}
897
898impl IndexedParamWriter {
899 #[instrument(level = "trace", skip_all, ret)]
900 fn poll_active(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
901 match &mut *self {
902 Self::Corrupted => Poll::Ready(Err(corrupted_memory_error())),
903 Self::Handshaking { tx_rx, .. } => {
904 trace!("polling for handshake");
905 match pin!(tx_rx).poll(cx) {
906 Poll::Ready(Ok(tx)) => {
907 let Self::Handshaking { indexed, .. } = mem::take(&mut *self) else {
908 return Poll::Ready(Err(corrupted_memory_error()));
909 };
910 let indexed = indexed.into_inner().map_err(|err| {
911 std::io::Error::new(std::io::ErrorKind::Other, err.to_string())
912 })?;
913 for (path, tx_tx) in indexed {
914 let tx = tx.index(&path).map_err(|err| {
915 std::io::Error::new(std::io::ErrorKind::Other, err)
916 })?;
917 tx_tx.send(tx).map_err(|_| {
918 std::io::Error::from(std::io::ErrorKind::BrokenPipe)
919 })?;
920 }
921 *self = Self::Active(tx);
922 Poll::Ready(Ok(()))
923 }
924 Poll::Ready(Err(..)) => Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())),
925 Poll::Pending => Poll::Pending,
926 }
927 }
928 Self::Active(..) => Poll::Ready(Ok(())),
929 }
930 }
931}
932
933impl wrpc_transport::Index<Self> for IndexedParamWriter {
934 #[instrument(level = "trace", skip_all)]
935 fn index(&self, path: &[usize]) -> anyhow::Result<Self> {
936 ensure!(!path.is_empty());
937 match self {
938 Self::Corrupted => Err(anyhow!(corrupted_memory_error())),
939 Self::Handshaking { indexed, .. } => {
940 let (tx_tx, tx_rx) = oneshot::channel();
941 let mut indexed = indexed.lock().map_err(|err| {
942 std::io::Error::new(std::io::ErrorKind::Other, err.to_string())
943 })?;
944 indexed.push((path.to_vec(), tx_tx));
945 Ok(Self::Handshaking {
946 tx_rx,
947 indexed: std::sync::Mutex::default(),
948 })
949 }
950 Self::Active(tx) => tx.index(path).map(Self::Active),
951 }
952 }
953}
954
955impl AsyncWrite for IndexedParamWriter {
956 #[instrument(level = "trace", skip_all, ret, fields(buf = format!("{buf:02x?}")))]
957 fn poll_write(
958 mut self: Pin<&mut Self>,
959 cx: &mut Context<'_>,
960 buf: &[u8],
961 ) -> Poll<std::io::Result<usize>> {
962 match self.as_mut().poll_active(cx)? {
963 Poll::Ready(()) => {
964 let Self::Active(tx) = &mut *self else {
965 return Poll::Ready(Err(corrupted_memory_error()));
966 };
967 trace!("writing buffer");
968 pin!(tx).poll_write(cx, buf)
969 }
970 Poll::Pending => Poll::Pending,
971 }
972 }
973
974 #[instrument(level = "trace", skip_all, ret)]
975 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
976 match self.as_mut().poll_active(cx)? {
977 Poll::Ready(()) => {
978 let Self::Active(tx) = &mut *self else {
979 return Poll::Ready(Err(corrupted_memory_error()));
980 };
981 trace!("flushing");
982 pin!(tx).poll_flush(cx)
983 }
984 Poll::Pending => Poll::Pending,
985 }
986 }
987
988 #[instrument(level = "trace", skip_all, ret)]
989 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
990 match self.as_mut().poll_active(cx)? {
991 Poll::Ready(()) => {
992 let Self::Active(tx) = &mut *self else {
993 return Poll::Ready(Err(corrupted_memory_error()));
994 };
995 trace!("shutting down");
996 pin!(tx).poll_shutdown(cx)
997 }
998 Poll::Pending => Poll::Pending,
999 }
1000 }
1001}
1002
1003pub enum ParamWriter {
1004 Root(RootParamWriter),
1005 Nested(IndexedParamWriter),
1006}
1007
1008impl wrpc_transport::Index<Self> for ParamWriter {
1009 fn index(&self, path: &[usize]) -> anyhow::Result<Self> {
1010 ensure!(!path.is_empty());
1011 match self {
1012 ParamWriter::Root(w) => w.index(path),
1013 ParamWriter::Nested(w) => w.index(path),
1014 }
1015 .map(Self::Nested)
1016 }
1017}
1018
1019impl AsyncWrite for ParamWriter {
1020 #[instrument(level = "trace", skip_all, ret, fields(buf = format!("{buf:02x?}")))]
1021 fn poll_write(
1022 mut self: Pin<&mut Self>,
1023 cx: &mut Context<'_>,
1024 buf: &[u8],
1025 ) -> Poll<std::io::Result<usize>> {
1026 match &mut *self {
1027 ParamWriter::Root(w) => pin!(w).poll_write(cx, buf),
1028 ParamWriter::Nested(w) => pin!(w).poll_write(cx, buf),
1029 }
1030 }
1031
1032 #[instrument(level = "trace", skip_all, ret)]
1033 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1034 match &mut *self {
1035 ParamWriter::Root(w) => pin!(w).poll_flush(cx),
1036 ParamWriter::Nested(w) => pin!(w).poll_flush(cx),
1037 }
1038 }
1039
1040 #[instrument(level = "trace", skip_all, ret)]
1041 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1042 match &mut *self {
1043 ParamWriter::Root(w) => pin!(w).poll_shutdown(cx),
1044 ParamWriter::Nested(w) => pin!(w).poll_shutdown(cx),
1045 }
1046 }
1047}
1048
1049impl wrpc_transport::Invoke for Client {
1050 type Context = Option<HeaderMap>;
1051 type Outgoing = ParamWriter;
1052 type Incoming = Reader;
1053
1054 #[instrument(level = "trace", skip(self, paths, params), fields(params = format!("{params:02x?}")))]
1055 async fn invoke<P: AsRef<[Option<usize>]> + Send + Sync>(
1056 &self,
1057 cx: Self::Context,
1058 instance: &str,
1059 func: &str,
1060 mut params: Bytes,
1061 paths: impl AsRef<[P]> + Send,
1062 ) -> anyhow::Result<(Self::Outgoing, Self::Incoming)> {
1063 let paths = paths.as_ref();
1064 let mut cmds = Vec::with_capacity(paths.len().saturating_add(2));
1065
1066 let rx = Subject::from(new_inbox(&self.inbox));
1067 let (handshake_tx, handshake_rx) = mpsc::channel(1);
1068 cmds.push(Command::Subscribe(rx.clone(), handshake_tx));
1069
1070 let result = Subject::from(result_subject(&rx));
1071 let (result_tx, result_rx) = mpsc::channel(16);
1072 cmds.push(Command::Subscribe(result.clone(), result_tx));
1073
1074 let nested = paths.iter().map(|path| {
1075 let (tx, rx) = mpsc::channel(16);
1076 let subject = Subject::from(subscribe_path(&result, path.as_ref()));
1077 cmds.push(Command::Subscribe(subject.clone(), tx));
1078 Subscriber {
1079 rx: ReceiverStream::new(rx),
1080 commands: self.commands.clone(),
1081 subject,
1082 tasks: Arc::clone(&self.tasks),
1083 }
1084 });
1085 let nested: IndexTrie = zip(paths.iter(), nested).collect();
1086 ensure!(
1087 paths.is_empty() == nested.is_empty(),
1088 "failed to construct subscription tree"
1089 );
1090
1091 self.commands
1092 .send(Command::Batch(cmds.into_boxed_slice()))
1093 .await
1094 .context("failed to subscribe")?;
1095
1096 let ServerInfo {
1097 mut max_payload, ..
1098 } = self.nats.server_info();
1099 max_payload = max_payload.saturating_sub(rx.len());
1100 let param_tx = Subject::from(invocation_subject(&self.prefix, instance, func));
1101 if let Some(headers) = cx {
1102 max_payload = max_payload
1104 .saturating_sub(b"NATS/1.0\r\n".len())
1105 .saturating_sub(b"\r\n".len());
1106 for (k, vs) in headers.iter() {
1107 let k: &[u8] = k.as_ref();
1108 for v in vs {
1109 max_payload = max_payload
1110 .saturating_sub(k.len())
1111 .saturating_sub(b": ".len())
1112 .saturating_sub(v.as_str().len())
1113 .saturating_sub(b"\r\n".len());
1114 }
1115 }
1116 trace!("publishing handshake");
1117 self.nats
1118 .publish_with_reply_and_headers(
1119 param_tx,
1120 rx.clone(),
1121 headers,
1122 params.split_to(max_payload.min(params.len())),
1123 )
1124 .await
1125 } else {
1126 trace!("publishing handshake");
1127 self.nats
1128 .publish_with_reply(
1129 param_tx,
1130 rx.clone(),
1131 params.split_to(max_payload.min(params.len())),
1132 )
1133 .await
1134 }
1135 .context("failed to publish handshake")?;
1136 let nats = Arc::clone(&self.nats);
1137 tokio::spawn(async move {
1138 if let Err(err) = nats.flush().await {
1139 error!(?err, "failed to flush");
1140 }
1141 });
1142 Ok((
1143 ParamWriter::Root(RootParamWriter::new(
1144 (*self.nats).clone(),
1145 Subscriber {
1146 rx: ReceiverStream::new(handshake_rx),
1147 commands: self.commands.clone(),
1148 subject: rx,
1149 tasks: Arc::clone(&self.tasks),
1150 },
1151 params,
1152 Arc::clone(&self.tasks),
1153 )),
1154 Reader {
1155 buffer: Bytes::default(),
1156 incoming: Some(Subscriber {
1157 rx: ReceiverStream::new(result_rx),
1158 commands: self.commands.clone(),
1159 subject: result,
1160 tasks: Arc::clone(&self.tasks),
1161 }),
1162 nested: Arc::new(std::sync::Mutex::new(nested)),
1163 path: Box::default(),
1164 },
1165 ))
1166 }
1167}
1168
1169async fn handle_message(
1170 nats: &async_nats::Client,
1171 rx: Subject,
1172 commands: mpsc::Sender<Command>,
1173 async_nats::Message {
1174 reply: tx,
1175 payload,
1176 headers,
1177 ..
1178 }: async_nats::Message,
1179 paths: &[Box<[Option<usize>]>],
1180 tasks: Arc<JoinSet<()>>,
1181) -> anyhow::Result<(Option<HeaderMap>, SubjectWriter, Reader)> {
1182 let tx = tx.context("peer did not specify a reply subject")?;
1183
1184 let mut cmds = Vec::with_capacity(paths.len().saturating_add(1));
1185
1186 let param = Subject::from(param_subject(&rx));
1187 let (param_tx, param_rx) = mpsc::channel(16);
1188 cmds.push(Command::Subscribe(param.clone(), param_tx));
1189
1190 let nested = paths.iter().map(|path| {
1191 let (tx, rx) = mpsc::channel(16);
1192 let subject = Subject::from(subscribe_path(¶m, path.as_ref()));
1193 cmds.push(Command::Subscribe(subject.clone(), tx));
1194 Subscriber {
1195 rx: ReceiverStream::new(rx),
1196 commands: commands.clone(),
1197 subject,
1198 tasks: Arc::clone(&tasks),
1199 }
1200 });
1201 let nested: IndexTrie = zip(paths.iter(), nested).collect();
1202 ensure!(
1203 paths.is_empty() == nested.is_empty(),
1204 "failed to construct subscription tree"
1205 );
1206
1207 commands
1208 .send(Command::Batch(cmds.into_boxed_slice()))
1209 .await
1210 .context("failed to subscribe")?;
1211
1212 trace!("publishing handshake response");
1213 nats.publish_with_reply(tx.clone(), rx, Bytes::default())
1214 .await
1215 .context("failed to publish handshake accept")?;
1216 Ok((
1217 headers,
1218 SubjectWriter::new(
1219 nats.clone(),
1220 Subject::from(result_subject(&tx)),
1221 Arc::clone(&tasks),
1222 ),
1223 Reader {
1224 buffer: payload,
1225 incoming: Some(Subscriber {
1226 rx: ReceiverStream::new(param_rx),
1227 commands,
1228 subject: param,
1229 tasks,
1230 }),
1231 nested: Arc::new(std::sync::Mutex::new(nested)),
1232 path: Box::default(),
1233 },
1234 ))
1235}
1236
1237impl wrpc_transport::Serve for Client {
1238 type Context = Option<HeaderMap>;
1239 type Outgoing = SubjectWriter;
1240 type Incoming = Reader;
1241
1242 #[instrument(level = "trace", skip(self, paths))]
1243 async fn serve(
1244 &self,
1245 instance: &str,
1246 func: &str,
1247 paths: impl Into<Arc<[Box<[Option<usize>]>]>> + Send,
1248 ) -> anyhow::Result<
1249 impl Stream<Item = anyhow::Result<(Self::Context, Self::Outgoing, Self::Incoming)>> + 'static,
1250 > {
1251 let subject = invocation_subject(&self.prefix, instance, func);
1252 let sub = if let Some(group) = &self.queue_group {
1253 debug!(subject, ?group, "queue-subscribing on invocation subject");
1254 self.nats
1255 .queue_subscribe(subject, group.to_string())
1256 .await?
1257 } else {
1258 debug!(subject, "subscribing on invocation subject");
1259 self.nats.subscribe(subject).await?
1260 };
1261 let nats = Arc::clone(&self.nats);
1262 let paths = paths.into();
1263 let commands = self.commands.clone();
1264 let inbox = Arc::clone(&self.inbox);
1265 let tasks = Arc::clone(&self.tasks);
1266 Ok(sub.then(move |msg| {
1267 let tasks = Arc::clone(&tasks);
1268 let nats = Arc::clone(&nats);
1269 let paths = Arc::clone(&paths);
1270 let commands = commands.clone();
1271 let rx = Subject::from(new_inbox(&inbox));
1272 async move { handle_message(&nats, rx, commands, msg, &paths, tasks).await }
1273 }))
1274 }
1275}