1use prost::Message;
2use std::marker::PhantomData;
3use tonic::codec::{BufferSettings, Codec, DecodeBuf, Decoder, EncodeBuf, Encoder};
4use tonic::Status;
5
/// Zero-sized codec handle that pairs an outgoing message type `T` with an
/// incoming message type `U`.
///
/// The type parameters are only carried in [`PhantomData`]; no value of `T` or
/// `U` is ever stored. `Clone` and `Debug` are therefore implemented by hand so
/// that they hold for every `T` and `U`, the same way the `Default` impl in
/// this file does. A `#[derive]` would add `T: Clone`, `U: Clone`, `T: Debug`
/// and `U: Debug` bounds that nothing here needs.
pub struct ProstCodec<T, U> {
    _pd: PhantomData<(T, U)>,
}

impl<T, U> Clone for ProstCodec<T, U> {
    fn clone(&self) -> Self {
        Self { _pd: PhantomData }
    }
}

impl<T, U> std::fmt::Debug for ProstCodec<T, U> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same field layout the derived impl would print.
        f.debug_struct("ProstCodec")
            .field("_pd", &self._pd)
            .finish()
    }
}
11
12impl<T, U> ProstCodec<T, U> {
13 pub fn new() -> Self {
16 Self { _pd: PhantomData }
17 }
18}
19
20impl<T, U> Default for ProstCodec<T, U> {
21 fn default() -> Self {
22 Self::new()
23 }
24}
25
26impl<T, U> ProstCodec<T, U>
27where
28 T: Message + Send + 'static,
29 U: Message + Default + Send + 'static,
30{
31 pub fn raw_encoder(buffer_settings: BufferSettings) -> <Self as Codec>::Encoder {
34 ProstEncoder {
35 _pd: PhantomData,
36 buffer_settings,
37 }
38 }
39
40 pub fn raw_decoder(buffer_settings: BufferSettings) -> <Self as Codec>::Decoder {
43 ProstDecoder {
44 _pd: PhantomData,
45 buffer_settings,
46 }
47 }
48}
49
50impl<T, U> Codec for ProstCodec<T, U>
51where
52 T: Message + Send + 'static,
53 U: Message + Default + Send + 'static,
54{
55 type Encode = T;
56 type Decode = U;
57
58 type Encoder = ProstEncoder<T>;
59 type Decoder = ProstDecoder<U>;
60
61 fn encoder(&mut self) -> Self::Encoder {
62 ProstEncoder {
63 _pd: PhantomData,
64 buffer_settings: BufferSettings::default(),
65 }
66 }
67
68 fn decoder(&mut self) -> Self::Decoder {
69 ProstDecoder {
70 _pd: PhantomData,
71 buffer_settings: BufferSettings::default(),
72 }
73 }
74}
75
76#[derive(Debug, Clone, Default)]
78pub struct ProstEncoder<T> {
79 _pd: PhantomData<T>,
80 buffer_settings: BufferSettings,
81}
82
83impl<T> ProstEncoder<T> {
84 pub fn new(buffer_settings: BufferSettings) -> Self {
86 Self {
87 _pd: PhantomData,
88 buffer_settings,
89 }
90 }
91}
92
93impl<T: Message> Encoder for ProstEncoder<T> {
94 type Item = T;
95 type Error = Status;
96
97 fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
98 item.encode(buf)
99 .expect("Message only errors if not enough space");
100
101 Ok(())
102 }
103
104 fn buffer_settings(&self) -> BufferSettings {
105 self.buffer_settings
106 }
107}
108
109#[derive(Debug, Clone, Default)]
111pub struct ProstDecoder<U> {
112 _pd: PhantomData<U>,
113 buffer_settings: BufferSettings,
114}
115
116impl<U> ProstDecoder<U> {
117 pub fn new(buffer_settings: BufferSettings) -> Self {
119 Self {
120 _pd: PhantomData,
121 buffer_settings,
122 }
123 }
124}
125
126impl<U: Message + Default> Decoder for ProstDecoder<U> {
127 type Item = U;
128 type Error = Status;
129
130 fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
131 let item = Message::decode(buf)
132 .map(Option::Some)
133 .map_err(from_decode_error)?;
134
135 Ok(item)
136 }
137
138 fn buffer_settings(&self) -> BufferSettings {
139 self.buffer_settings
140 }
141}
142
143fn from_decode_error(error: prost::DecodeError) -> Status {
144 Status::internal(error.to_string())
147}
148
149#[cfg(test)]
150mod tests {
151 use super::*;
152 use bytes::{Buf, BufMut, BytesMut};
153 use http_body::Body;
154 use http_body_util::BodyExt as _;
155 use std::pin::pin;
156 use tonic::codec::SingleMessageCompressionOverride;
157 use tonic::codec::{EncodeBody, Streaming, HEADER_SIZE};
158
/// Payload size, in bytes, of the message used by the `decode` test; also the
/// amount `MockDecoder` advances the buffer per decoded item.
const LEN: usize = 10_000;
/// Message size limit (2 MiB) handed to the max-message-size tests.
const MAX_MESSAGE_SIZE: usize = 2 << 20;
162
163 #[tokio::test]
164 async fn decode() {
165 let decoder = MockDecoder::default();
166
167 let msg = vec![0u8; LEN];
168
169 let mut buf = BytesMut::new();
170
171 buf.reserve(msg.len() + HEADER_SIZE);
172 buf.put_u8(0);
173 buf.put_u32(msg.len() as u32);
174
175 buf.put(&msg[..]);
176
177 let body = body::MockBody::new(&buf[..], 10005, 0);
178
179 let mut stream = Streaming::new_request(decoder, body, None, None);
180
181 let mut i = 0usize;
182 while let Some(output_msg) = stream.message().await.unwrap() {
183 assert_eq!(output_msg.len(), msg.len());
184 i += 1;
185 }
186 assert_eq!(i, 1);
187 }
188
189 #[tokio::test]
190 async fn decode_max_message_size_exceeded() {
191 let decoder = MockDecoder::default();
192
193 let msg = vec![0u8; MAX_MESSAGE_SIZE + 1];
194
195 let mut buf = BytesMut::new();
196
197 buf.reserve(msg.len() + HEADER_SIZE);
198 buf.put_u8(0);
199 buf.put_u32(msg.len() as u32);
200
201 buf.put(&msg[..]);
202
203 let body = body::MockBody::new(&buf[..], MAX_MESSAGE_SIZE + HEADER_SIZE + 1, 0);
204
205 let mut stream = Streaming::new_request(decoder, body, None, Some(MAX_MESSAGE_SIZE));
206
207 let actual = stream.message().await.unwrap_err();
208
209 let expected = Status::out_of_range(format!(
210 "Error, decoded message length too large: found {} bytes, the limit is: {} bytes",
211 msg.len(),
212 MAX_MESSAGE_SIZE
213 ));
214
215 assert_eq!(actual.code(), expected.code());
216 assert_eq!(actual.message(), expected.message());
217 }
218
219 #[tokio::test]
220 async fn encode() {
221 let encoder = MockEncoder::default();
222
223 let msg = Vec::from(&[0u8; 1024][..]);
224
225 let messages = std::iter::repeat_with(move || Ok::<_, Status>(msg.clone())).take(10000);
226 let source = tokio_stream::iter(messages);
227
228 let mut body = pin!(EncodeBody::new_server(
229 encoder,
230 source,
231 None,
232 SingleMessageCompressionOverride::default(),
233 None,
234 ));
235
236 while let Some(r) = body.frame().await {
237 r.unwrap();
238 }
239 }
240
241 #[tokio::test]
242 async fn encode_max_message_size_exceeded() {
243 let encoder = MockEncoder::default();
244
245 let msg = vec![0u8; MAX_MESSAGE_SIZE + 1];
246
247 let messages = std::iter::once(Ok::<_, Status>(msg));
248 let source = tokio_stream::iter(messages);
249
250 let mut body = pin!(EncodeBody::new_server(
251 encoder,
252 source,
253 None,
254 SingleMessageCompressionOverride::default(),
255 Some(MAX_MESSAGE_SIZE),
256 ));
257
258 let frame = body
259 .frame()
260 .await
261 .expect("at least one frame")
262 .expect("no error polling frame");
263 assert_eq!(
264 frame
265 .into_trailers()
266 .expect("got trailers")
267 .get(Status::GRPC_STATUS)
268 .expect("grpc-status header"),
269 "11"
270 );
271 assert!(body.is_end_stream());
272 }
273
274 #[cfg(not(target_family = "windows"))]
276 #[tokio::test]
277 async fn encode_too_big() {
278 let encoder = MockEncoder::default();
279
280 let msg = vec![0u8; u32::MAX as usize + 1];
281
282 let messages = std::iter::once(Ok::<_, Status>(msg));
283 let source = tokio_stream::iter(messages);
284
285 let mut body = pin!(EncodeBody::new_server(
286 encoder,
287 source,
288 None,
289 SingleMessageCompressionOverride::default(),
290 Some(usize::MAX),
291 ));
292
293 let frame = body
294 .frame()
295 .await
296 .expect("at least one frame")
297 .expect("no error polling frame");
298 assert_eq!(
299 frame
300 .into_trailers()
301 .expect("got trailers")
302 .get(Status::GRPC_STATUS)
303 .expect("grpc-status header"),
304 "8"
305 );
306 assert!(body.is_end_stream());
307 }
308
309 #[derive(Debug, Clone, Default)]
310 struct MockEncoder {}
311
312 impl Encoder for MockEncoder {
313 type Item = Vec<u8>;
314 type Error = Status;
315
316 fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
317 buf.put(&item[..]);
318 Ok(())
319 }
320
321 fn buffer_settings(&self) -> BufferSettings {
322 Default::default()
323 }
324 }
325
326 #[derive(Debug, Clone, Default)]
327 struct MockDecoder {}
328
329 impl Decoder for MockDecoder {
330 type Item = Vec<u8>;
331 type Error = Status;
332
333 fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
334 let out = Vec::from(buf.chunk());
335 buf.advance(LEN);
336 Ok(Some(out))
337 }
338
339 fn buffer_settings(&self) -> BufferSettings {
340 Default::default()
341 }
342 }
343
344 mod body {
345 use bytes::Bytes;
346 use http_body::{Body, Frame};
347 use std::{
348 pin::Pin,
349 task::{Context, Poll},
350 };
351 use tonic::Status;
352
353 #[derive(Debug)]
354 pub(super) struct MockBody {
355 data: Bytes,
356
357 partial_len: usize,
359
360 count: usize,
362 }
363
364 impl MockBody {
365 pub(super) fn new(b: &[u8], partial_len: usize, count: usize) -> Self {
366 MockBody {
367 data: Bytes::copy_from_slice(b),
368 partial_len,
369 count,
370 }
371 }
372 }
373
374 impl Body for MockBody {
375 type Data = Bytes;
376 type Error = Status;
377
378 fn poll_frame(
379 mut self: Pin<&mut Self>,
380 cx: &mut Context<'_>,
381 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
382 let should_send = self.count % 2 == 0;
384 let data_len = self.data.len();
385 let partial_len = self.partial_len;
386 let count = self.count;
387 if data_len > 0 {
388 let result = if should_send {
389 let response =
390 self.data
391 .split_to(if count == 0 { partial_len } else { data_len });
392 Poll::Ready(Some(Ok(Frame::data(response))))
393 } else {
394 cx.waker().wake_by_ref();
395 Poll::Pending
396 };
397 self.count += 1;
399 result
400 } else {
401 Poll::Ready(None)
402 }
403 }
404 }
405 }
406}