aws_sdk_s3/
http_request_checksum.rs1#![allow(dead_code)]
8
9use aws_runtime::auth::PayloadSigningOverride;
12use aws_runtime::content_encoding::header_value::AWS_CHUNKED;
13use aws_runtime::content_encoding::{AwsChunkedBody, AwsChunkedBodyOptions};
14use aws_smithy_checksums::ChecksumAlgorithm;
15use aws_smithy_checksums::{body::calculate, http::HttpChecksum};
16use aws_smithy_runtime_api::box_error::BoxError;
17use aws_smithy_runtime_api::client::interceptors::context::{BeforeSerializationInterceptorContextRef, BeforeTransmitInterceptorContextMut, Input};
18use aws_smithy_runtime_api::client::interceptors::Intercept;
19use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
20use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
21use aws_smithy_types::body::SdkBody;
22use aws_smithy_types::config_bag::{ConfigBag, Layer, Storable, StoreReplace};
23use aws_smithy_types::error::operation::BuildError;
24use http::HeaderValue;
25use http_body::Body;
26use std::{fmt, mem};
27
/// Errors raised while attaching a checksum to an outgoing request.
#[derive(Debug)]
pub(crate) enum Error {
    /// The request body did not report an exact size, which the streaming (trailer) checksum
    /// path requires.
    UnsizedRequestBody,
    /// A checksum header was requested for a streaming body; streaming bodies must send their
    /// checksum as a trailer instead.
    ChecksumHeadersAreUnsupportedForStreamingBody,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::UnsizedRequestBody => "Only request bodies with a known size can be checksum validated.",
            Self::ChecksumHeadersAreUnsupportedForStreamingBody => {
                "Checksum header insertion is only supported for non-streaming HTTP bodies. \
                 To checksum validate a streaming body, the checksums must be sent as trailers."
            }
        };
        f.write_str(message)
    }
}

impl std::error::Error for Error {}

51#[derive(Debug)]
52struct RequestChecksumInterceptorState {
53 checksum_algorithm: Option<ChecksumAlgorithm>,
54}
55impl Storable for RequestChecksumInterceptorState {
56 type Storer = StoreReplace<Self>;
57}
58
59type CustomDefaultFn = Box<dyn Fn(Option<ChecksumAlgorithm>, &ConfigBag) -> Option<ChecksumAlgorithm> + Send + Sync + 'static>;
60
61pub(crate) struct DefaultRequestChecksumOverride {
62 custom_default: CustomDefaultFn,
63}
64impl fmt::Debug for DefaultRequestChecksumOverride {
65 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66 f.debug_struct("DefaultRequestChecksumOverride").finish()
67 }
68}
69impl Storable for DefaultRequestChecksumOverride {
70 type Storer = StoreReplace<Self>;
71}
72impl DefaultRequestChecksumOverride {
73 pub(crate) fn new<F>(custom_default: F) -> Self
74 where
75 F: Fn(Option<ChecksumAlgorithm>, &ConfigBag) -> Option<ChecksumAlgorithm> + Send + Sync + 'static,
76 {
77 Self {
78 custom_default: Box::new(custom_default),
79 }
80 }
81 pub(crate) fn custom_default(&self, original: Option<ChecksumAlgorithm>, config_bag: &ConfigBag) -> Option<ChecksumAlgorithm> {
82 (self.custom_default)(original, config_bag)
83 }
84}
85
/// Interceptor that attaches a request checksum, using `algorithm_provider` to choose the
/// algorithm from the operation input.
pub(crate) struct RequestChecksumInterceptor<AP> {
    algorithm_provider: AP,
}

impl<AP> fmt::Debug for RequestChecksumInterceptor<AP> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `AP` carries no `Debug` bound, so no fields are printed.
        f.write_str("RequestChecksumInterceptor")
    }
}

impl<AP> RequestChecksumInterceptor<AP> {
    /// Creates an interceptor that asks `algorithm_provider` which checksum algorithm to use.
    pub(crate) fn new(algorithm_provider: AP) -> Self {
        Self { algorithm_provider }
    }
}

102impl<AP> Intercept for RequestChecksumInterceptor<AP>
103where
104 AP: Fn(&Input) -> Result<Option<ChecksumAlgorithm>, BoxError> + Send + Sync,
105{
106 fn name(&self) -> &'static str {
107 "RequestChecksumInterceptor"
108 }
109
110 fn read_before_serialization(
111 &self,
112 context: &BeforeSerializationInterceptorContextRef<'_>,
113 _runtime_components: &RuntimeComponents,
114 cfg: &mut ConfigBag,
115 ) -> Result<(), BoxError> {
116 let checksum_algorithm = (self.algorithm_provider)(context.input())?;
117
118 let mut layer = Layer::new("RequestChecksumInterceptor");
119 layer.store_put(RequestChecksumInterceptorState { checksum_algorithm });
120 cfg.push_layer(layer);
121
122 Ok(())
123 }
124
125 fn modify_before_signing(
129 &self,
130 context: &mut BeforeTransmitInterceptorContextMut<'_>,
131 _runtime_components: &RuntimeComponents,
132 cfg: &mut ConfigBag,
133 ) -> Result<(), BoxError> {
134 let state = cfg.load::<RequestChecksumInterceptorState>().expect("set in `read_before_serialization`");
135
136 let checksum_algorithm = incorporate_custom_default(state.checksum_algorithm, cfg);
137 if let Some(checksum_algorithm) = checksum_algorithm {
138 let request = context.request_mut();
139 add_checksum_for_request_body(request, checksum_algorithm, cfg)?;
140 }
141
142 Ok(())
143 }
144}
145
146fn incorporate_custom_default(checksum: Option<ChecksumAlgorithm>, cfg: &ConfigBag) -> Option<ChecksumAlgorithm> {
147 match cfg.load::<DefaultRequestChecksumOverride>() {
148 Some(checksum_override) => checksum_override.custom_default(checksum, cfg),
149 None => checksum,
150 }
151}
152
153fn add_checksum_for_request_body(request: &mut HttpRequest, checksum_algorithm: ChecksumAlgorithm, cfg: &mut ConfigBag) -> Result<(), BoxError> {
154 match request.body().bytes() {
155 Some(data) => {
157 tracing::debug!("applying {checksum_algorithm:?} of the request body as a header");
158 let mut checksum = checksum_algorithm.into_impl();
159 checksum.update(data);
160
161 request.headers_mut().insert(checksum.header_name(), checksum.header_value());
162 }
163 None => {
165 tracing::debug!("applying {checksum_algorithm:?} of the request body as a trailer");
166 cfg.interceptor_state().store_put(PayloadSigningOverride::StreamingUnsignedPayloadTrailer);
167 wrap_streaming_request_body_in_checksum_calculating_body(request, checksum_algorithm)?;
168 }
169 }
170 Ok(())
171}
172
173fn wrap_streaming_request_body_in_checksum_calculating_body(
174 request: &mut HttpRequest,
175 checksum_algorithm: ChecksumAlgorithm,
176) -> Result<(), BuildError> {
177 let original_body_size = request
178 .body()
179 .size_hint()
180 .exact()
181 .ok_or_else(|| BuildError::other(Error::UnsizedRequestBody))?;
182
183 let mut body = {
184 let body = mem::replace(request.body_mut(), SdkBody::taken());
185
186 body.map(move |body| {
187 let checksum = checksum_algorithm.into_impl();
188 let trailer_len = HttpChecksum::size(checksum.as_ref());
189 let body = calculate::ChecksumBody::new(body, checksum);
190 let aws_chunked_body_options = AwsChunkedBodyOptions::new(original_body_size, vec![trailer_len]);
191
192 let body = AwsChunkedBody::new(body, aws_chunked_body_options);
193
194 SdkBody::from_body_0_4(body)
195 })
196 };
197
198 let encoded_content_length = body.size_hint().exact().ok_or_else(|| BuildError::other(Error::UnsizedRequestBody))?;
199
200 let headers = request.headers_mut();
201
202 headers.insert(
203 http::header::HeaderName::from_static("x-amz-trailer"),
204 checksum_algorithm.into_impl().header_name(),
205 );
206
207 headers.insert(http::header::CONTENT_LENGTH, HeaderValue::from(encoded_content_length));
208 headers.insert(
209 http::header::HeaderName::from_static("x-amz-decoded-content-length"),
210 HeaderValue::from(original_body_size),
211 );
212 headers.insert(
213 http::header::CONTENT_ENCODING,
214 HeaderValue::from_str(AWS_CHUNKED)
215 .map_err(BuildError::other)
216 .expect("\"aws-chunked\" will always be a valid HeaderValue"),
217 );
218
219 mem::swap(request.body_mut(), &mut body);
220
221 Ok(())
222}
223
#[cfg(test)]
mod tests {
    use crate::http_request_checksum::wrap_streaming_request_body_in_checksum_calculating_body;
    use aws_smithy_checksums::ChecksumAlgorithm;
    use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
    use aws_smithy_types::base64;
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::byte_stream::ByteStream;
    use bytes::BytesMut;
    use http_body::Body;
    use tempfile::NamedTempFile;

    /// Reads every data frame of `body` into one contiguous buffer, panicking on a frame error.
    async fn drain(mut body: SdkBody) -> BytesMut {
        let mut collected = BytesMut::new();
        while let Some(frame) = body.data().await {
            collected.extend_from_slice(&frame.unwrap());
        }
        collected
    }

    /// An in-memory retryable body stays retryable after wrapping. A clone of the wrapped body
    /// yields the complete aws-chunked encoding, including the CRC32 trailer.
    #[tokio::test]
    async fn test_checksum_body_is_retryable() {
        let input_text = "Hello world";
        let chunk_len_hex = format!("{:X}", input_text.len());
        let mut request: HttpRequest = http::Request::builder()
            .body(SdkBody::retryable(move || SdkBody::from(input_text)))
            .unwrap()
            .try_into()
            .unwrap();

        // Cloneable before wrapping...
        assert!(request.body().try_clone().is_some());

        let checksum_algorithm: ChecksumAlgorithm = "crc32".parse().unwrap();
        wrap_streaming_request_body_in_checksum_calculating_body(&mut request, checksum_algorithm).unwrap();

        // ...and still cloneable afterwards.
        let cloned = request.body().try_clone().expect("body is retryable");
        let raw = drain(cloned).await;
        let body = std::str::from_utf8(&raw).unwrap();
        assert_eq!(
            format!("{chunk_len_hex}\r\n{input_text}\r\n0\r\nx-amz-checksum-crc32:i9aeUg==\r\n\r\n"),
            body
        );
    }

    /// A file-backed body stays retryable after wrapping. Its encoding ends with the last line
    /// of file data, followed by the CRC32C trailer.
    #[tokio::test]
    async fn test_checksum_body_from_file_is_retryable() {
        use std::io::Write;

        let mut file = NamedTempFile::new().unwrap();
        let checksum_algorithm: ChecksumAlgorithm = "crc32c".parse().unwrap();

        let mut hasher = checksum_algorithm.into_impl();
        for n in 0..10000 {
            let line = format!("This is a large file created for testing purposes {}", n);
            let line_bytes = line.as_bytes();
            file.as_file_mut().write_all(line_bytes).unwrap();
            hasher.update(line_bytes);
        }
        let expected_checksum = base64::encode(&hasher.finalize());

        let byte_stream = ByteStream::read_from().path(&file).buffer_size(1024).build().await.unwrap();
        let mut request = HttpRequest::new(byte_stream.into_inner());

        assert!(request.body().try_clone().is_some());

        wrap_streaming_request_body_in_checksum_calculating_body(&mut request, checksum_algorithm).unwrap();

        let cloned = request.body().try_clone().expect("body is retryable");
        let raw = drain(cloned).await;
        let body = std::str::from_utf8(&raw).unwrap();
        let expected = format!("This is a large file created for testing purposes 9999\r\n0\r\nx-amz-checksum-crc32c:{expected_checksum}\r\n\r\n");
        assert!(body.ends_with(&expected), "expected {body} to end with '{expected}'");
    }
}