tonic_prost/
codec.rs

1use prost::Message;
2use std::marker::PhantomData;
3use tonic::codec::{BufferSettings, Codec, DecodeBuf, Decoder, EncodeBuf, Encoder};
4use tonic::Status;
5
/// A [`Codec`] that implements `application/grpc+proto` via the prost library.
///
/// `T` is the message type this codec encodes and `U` is the message type it
/// decodes. The codec holds no state of its own. To encode or decode with
/// non-default [`BufferSettings`], build the halves directly with
/// [`ProstCodec::raw_encoder`] and [`ProstCodec::raw_decoder`].
#[derive(Debug, Clone)]
pub struct ProstCodec<T, U> {
    _pd: PhantomData<(T, U)>,
}

impl<T, U> ProstCodec<T, U> {
    /// Create a new `ProstCodec`.
    ///
    /// This takes no configuration. The encoder and decoder handed out through
    /// the [`Codec`] implementation use default [`BufferSettings`]. To control
    /// how memory is allocated and grows per RPC, use
    /// [`ProstCodec::raw_encoder`] / [`ProstCodec::raw_decoder`] instead.
    pub fn new() -> Self {
        Self { _pd: PhantomData }
    }
}

impl<T, U> Default for ProstCodec<T, U> {
    /// Equivalent to [`ProstCodec::new`].
    fn default() -> Self {
        Self::new()
    }
}
25
26impl<T, U> ProstCodec<T, U>
27where
28    T: Message + Send + 'static,
29    U: Message + Default + Send + 'static,
30{
31    /// A tool for building custom codecs based on prost encoding and decoding.
32    /// See the codec_buffers example for one possible way to use this.
33    pub fn raw_encoder(buffer_settings: BufferSettings) -> <Self as Codec>::Encoder {
34        ProstEncoder {
35            _pd: PhantomData,
36            buffer_settings,
37        }
38    }
39
40    /// A tool for building custom codecs based on prost encoding and decoding.
41    /// See the codec_buffers example for one possible way to use this.
42    pub fn raw_decoder(buffer_settings: BufferSettings) -> <Self as Codec>::Decoder {
43        ProstDecoder {
44            _pd: PhantomData,
45            buffer_settings,
46        }
47    }
48}
49
50impl<T, U> Codec for ProstCodec<T, U>
51where
52    T: Message + Send + 'static,
53    U: Message + Default + Send + 'static,
54{
55    type Encode = T;
56    type Decode = U;
57
58    type Encoder = ProstEncoder<T>;
59    type Decoder = ProstDecoder<U>;
60
61    fn encoder(&mut self) -> Self::Encoder {
62        ProstEncoder {
63            _pd: PhantomData,
64            buffer_settings: BufferSettings::default(),
65        }
66    }
67
68    fn decoder(&mut self) -> Self::Decoder {
69        ProstDecoder {
70            _pd: PhantomData,
71            buffer_settings: BufferSettings::default(),
72        }
73    }
74}
75
76/// A [`Encoder`] that knows how to encode `T`.
77#[derive(Debug, Clone, Default)]
78pub struct ProstEncoder<T> {
79    _pd: PhantomData<T>,
80    buffer_settings: BufferSettings,
81}
82
83impl<T> ProstEncoder<T> {
84    /// Get a new encoder with explicit buffer settings
85    pub fn new(buffer_settings: BufferSettings) -> Self {
86        Self {
87            _pd: PhantomData,
88            buffer_settings,
89        }
90    }
91}
92
93impl<T: Message> Encoder for ProstEncoder<T> {
94    type Item = T;
95    type Error = Status;
96
97    fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
98        item.encode(buf)
99            .expect("Message only errors if not enough space");
100
101        Ok(())
102    }
103
104    fn buffer_settings(&self) -> BufferSettings {
105        self.buffer_settings
106    }
107}
108
109/// A [`Decoder`] that knows how to decode `U`.
110#[derive(Debug, Clone, Default)]
111pub struct ProstDecoder<U> {
112    _pd: PhantomData<U>,
113    buffer_settings: BufferSettings,
114}
115
116impl<U> ProstDecoder<U> {
117    /// Get a new decoder with explicit buffer settings
118    pub fn new(buffer_settings: BufferSettings) -> Self {
119        Self {
120            _pd: PhantomData,
121            buffer_settings,
122        }
123    }
124}
125
126impl<U: Message + Default> Decoder for ProstDecoder<U> {
127    type Item = U;
128    type Error = Status;
129
130    fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
131        let item = Message::decode(buf)
132            .map(Option::Some)
133            .map_err(from_decode_error)?;
134
135        Ok(item)
136    }
137
138    fn buffer_settings(&self) -> BufferSettings {
139        self.buffer_settings
140    }
141}
142
143fn from_decode_error(error: prost::DecodeError) -> Status {
144    // Map Protobuf parse errors to an INTERNAL status code, as per
145    // https://github.com/grpc/grpc/blob/master/doc/statuscodes.md
146    Status::internal(error.to_string())
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::{Buf, BufMut, BytesMut};
    use http_body::Body;
    use http_body_util::BodyExt as _;
    use std::pin::pin;
    use tonic::codec::SingleMessageCompressionOverride;
    use tonic::codec::{EncodeBody, Streaming, HEADER_SIZE};

    /// Payload length, in bytes, of the messages used by `decode`.
    const LEN: usize = 10000;
    // The maximum uncompressed size in bytes for a message. Set to 2MB.
    const MAX_MESSAGE_SIZE: usize = 2 * 1024 * 1024;

    /// A single framed message delivered in one body frame decodes to exactly one item.
    #[tokio::test]
    async fn decode() {
        let decoder = MockDecoder::default();

        let msg = vec![0u8; LEN];

        let mut buf = BytesMut::new();

        // gRPC length-prefixed framing: a one-byte compressed flag (0 = not
        // compressed) followed by a big-endian u32 payload length.
        buf.reserve(msg.len() + HEADER_SIZE);
        buf.put_u8(0);
        buf.put_u32(msg.len() as u32);

        buf.put(&msg[..]);

        // Send the whole framed message (header + payload) in the first frame.
        let body = body::MockBody::new(&buf[..], LEN + HEADER_SIZE, 0);

        let mut stream = Streaming::new_request(decoder, body, None, None);

        let mut i = 0usize;
        while let Some(output_msg) = stream.message().await.unwrap() {
            assert_eq!(output_msg.len(), msg.len());
            i += 1;
        }
        assert_eq!(i, 1);
    }

    /// A message one byte over the decode limit is rejected with OUT_OF_RANGE.
    #[tokio::test]
    async fn decode_max_message_size_exceeded() {
        let decoder = MockDecoder::default();

        let msg = vec![0u8; MAX_MESSAGE_SIZE + 1];

        let mut buf = BytesMut::new();

        buf.reserve(msg.len() + HEADER_SIZE);
        buf.put_u8(0);
        buf.put_u32(msg.len() as u32);

        buf.put(&msg[..]);

        // The first frame carries the entire framed message.
        let body = body::MockBody::new(&buf[..], MAX_MESSAGE_SIZE + HEADER_SIZE + 1, 0);

        let mut stream = Streaming::new_request(decoder, body, None, Some(MAX_MESSAGE_SIZE));

        let actual = stream.message().await.unwrap_err();

        let expected = Status::out_of_range(format!(
            "Error, decoded message length too large: found {} bytes, the limit is: {} bytes",
            msg.len(),
            MAX_MESSAGE_SIZE
        ));

        assert_eq!(actual.code(), expected.code());
        assert_eq!(actual.message(), expected.message());
    }

    /// Encoding a long stream of small messages produces frames without error.
    #[tokio::test]
    async fn encode() {
        let encoder = MockEncoder::default();

        let msg = Vec::from(&[0u8; 1024][..]);

        let messages = std::iter::repeat_with(move || Ok::<_, Status>(msg.clone())).take(10000);
        let source = tokio_stream::iter(messages);

        let mut body = pin!(EncodeBody::new_server(
            encoder,
            source,
            None,
            SingleMessageCompressionOverride::default(),
            None,
        ));

        while let Some(r) = body.frame().await {
            r.unwrap();
        }
    }

    /// A message over the encode limit ends the body with trailers carrying
    /// grpc-status 11 (OUT_OF_RANGE).
    #[tokio::test]
    async fn encode_max_message_size_exceeded() {
        let encoder = MockEncoder::default();

        let msg = vec![0u8; MAX_MESSAGE_SIZE + 1];

        let messages = std::iter::once(Ok::<_, Status>(msg));
        let source = tokio_stream::iter(messages);

        let mut body = pin!(EncodeBody::new_server(
            encoder,
            source,
            None,
            SingleMessageCompressionOverride::default(),
            Some(MAX_MESSAGE_SIZE),
        ));

        let frame = body
            .frame()
            .await
            .expect("at least one frame")
            .expect("no error polling frame");
        assert_eq!(
            frame
                .into_trailers()
                .expect("got trailers")
                .get(Status::GRPC_STATUS)
                .expect("grpc-status header"),
            "11"
        );
        assert!(body.is_end_stream());
    }

    /// A payload too large for the u32 length prefix ends the body with trailers
    /// carrying grpc-status 8 (RESOURCE_EXHAUSTED), even with no configured limit.
    // skip on windows because CI stumbles over our 4GB allocation
    #[cfg(not(target_family = "windows"))]
    #[tokio::test]
    async fn encode_too_big() {
        let encoder = MockEncoder::default();

        let msg = vec![0u8; u32::MAX as usize + 1];

        let messages = std::iter::once(Ok::<_, Status>(msg));
        let source = tokio_stream::iter(messages);

        let mut body = pin!(EncodeBody::new_server(
            encoder,
            source,
            None,
            SingleMessageCompressionOverride::default(),
            Some(usize::MAX),
        ));

        let frame = body
            .frame()
            .await
            .expect("at least one frame")
            .expect("no error polling frame");
        assert_eq!(
            frame
                .into_trailers()
                .expect("got trailers")
                .get(Status::GRPC_STATUS)
                .expect("grpc-status header"),
            "8"
        );
        assert!(body.is_end_stream());
    }

    /// Encoder that writes the raw bytes of each `Vec<u8>` item into the buffer.
    #[derive(Debug, Clone, Default)]
    struct MockEncoder {}

    impl Encoder for MockEncoder {
        type Item = Vec<u8>;
        type Error = Status;

        fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
            buf.put(&item[..]);
            Ok(())
        }

        fn buffer_settings(&self) -> BufferSettings {
            Default::default()
        }
    }

    /// Decoder that copies the current chunk into a `Vec<u8>` and then advances
    /// by `LEN` bytes.
    ///
    /// NOTE(review): this assumes each decoded payload is exactly `LEN` bytes and
    /// arrives as a single chunk, which holds for the `decode` test above.
    #[derive(Debug, Clone, Default)]
    struct MockDecoder {}

    impl Decoder for MockDecoder {
        type Item = Vec<u8>;
        type Error = Status;

        fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
            let out = Vec::from(buf.chunk());
            buf.advance(LEN);
            Ok(Some(out))
        }

        fn buffer_settings(&self) -> BufferSettings {
            Default::default()
        }
    }

    mod body {
        use bytes::Bytes;
        use http_body::{Body, Frame};
        use std::{
            pin::Pin,
            task::{Context, Poll},
        };
        use tonic::Status;

        /// HTTP body that yields its data in frames, alternating data frames with
        /// `Poll::Pending` results to simulate a slow peer.
        #[derive(Debug)]
        pub(super) struct MockBody {
            data: Bytes,

            // size of the first data frame, used when `count` is 0 at send time;
            // every later data frame carries all remaining data
            partial_len: usize,

            // number of `poll_frame` calls made while data remained; even values
            // send data, odd values return `Pending`
            count: usize,
        }

        impl MockBody {
            pub(super) fn new(b: &[u8], partial_len: usize, count: usize) -> Self {
                MockBody {
                    data: Bytes::copy_from_slice(b),
                    partial_len,
                    count,
                }
            }
        }

        impl Body for MockBody {
            type Data = Bytes;
            type Error = Status;

            fn poll_frame(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
                // every other call to poll_frame returns data; the calls in between
                // wake the task and return Pending
                let should_send = self.count % 2 == 0;
                let data_len = self.data.len();
                let partial_len = self.partial_len;
                let count = self.count;
                if data_len > 0 {
                    let result = if should_send {
                        let response =
                            self.data
                                .split_to(if count == 0 { partial_len } else { data_len });
                        Poll::Ready(Some(Ok(Frame::data(response))))
                    } else {
                        cx.waker().wake_by_ref();
                        Poll::Pending
                    };
                    // make some fake progress
                    self.count += 1;
                    result
                } else {
                    Poll::Ready(None)
                }
            }
        }
    }
}