aws_sdk_s3/
http_request_checksum.rs

1// Code generated by software.amazon.smithy.rust.codegen.smithy-rs. DO NOT EDIT.
2/*
3 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
4 * SPDX-License-Identifier: Apache-2.0
5 */
6
7#![allow(dead_code)]
8
9//! Interceptor for handling Smithy `@httpChecksum` request checksumming with AWS SigV4
10
11use aws_runtime::auth::PayloadSigningOverride;
12use aws_runtime::content_encoding::header_value::AWS_CHUNKED;
13use aws_runtime::content_encoding::{AwsChunkedBody, AwsChunkedBodyOptions};
14use aws_smithy_checksums::ChecksumAlgorithm;
15use aws_smithy_checksums::{body::calculate, http::HttpChecksum};
16use aws_smithy_runtime_api::box_error::BoxError;
17use aws_smithy_runtime_api::client::interceptors::context::{BeforeSerializationInterceptorContextRef, BeforeTransmitInterceptorContextMut, Input};
18use aws_smithy_runtime_api::client::interceptors::Intercept;
19use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
20use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
21use aws_smithy_types::body::SdkBody;
22use aws_smithy_types::config_bag::{ConfigBag, Layer, Storable, StoreReplace};
23use aws_smithy_types::error::operation::BuildError;
24use http::HeaderValue;
25use http_body::Body;
26use std::{fmt, mem};
27
/// Errors that can occur while preparing a request whose body must carry a checksum.
#[derive(Debug)]
pub(crate) enum Error {
    /// The request body does not report an exact size, so it cannot be checksum validated.
    UnsizedRequestBody,
    /// A checksum *header* was requested for a streaming body; streaming bodies can only
    /// carry their checksum as a trailer.
    ChecksumHeadersAreUnsupportedForStreamingBody,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Error::UnsizedRequestBody => "Only request bodies with a known size can be checksum validated.",
            Error::ChecksumHeadersAreUnsupportedForStreamingBody => concat!(
                "Checksum header insertion is only supported for non-streaming HTTP bodies. ",
                "To checksum validate a streaming body, the checksums must be sent as trailers."
            ),
        };
        f.write_str(message)
    }
}

impl std::error::Error for Error {}
50
51#[derive(Debug)]
52struct RequestChecksumInterceptorState {
53    checksum_algorithm: Option<ChecksumAlgorithm>,
54}
55impl Storable for RequestChecksumInterceptorState {
56    type Storer = StoreReplace<Self>;
57}
58
59type CustomDefaultFn = Box<dyn Fn(Option<ChecksumAlgorithm>, &ConfigBag) -> Option<ChecksumAlgorithm> + Send + Sync + 'static>;
60
61pub(crate) struct DefaultRequestChecksumOverride {
62    custom_default: CustomDefaultFn,
63}
64impl fmt::Debug for DefaultRequestChecksumOverride {
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        f.debug_struct("DefaultRequestChecksumOverride").finish()
67    }
68}
69impl Storable for DefaultRequestChecksumOverride {
70    type Storer = StoreReplace<Self>;
71}
72impl DefaultRequestChecksumOverride {
73    pub(crate) fn new<F>(custom_default: F) -> Self
74    where
75        F: Fn(Option<ChecksumAlgorithm>, &ConfigBag) -> Option<ChecksumAlgorithm> + Send + Sync + 'static,
76    {
77        Self {
78            custom_default: Box::new(custom_default),
79        }
80    }
81    pub(crate) fn custom_default(&self, original: Option<ChecksumAlgorithm>, config_bag: &ConfigBag) -> Option<ChecksumAlgorithm> {
82        (self.custom_default)(original, config_bag)
83    }
84}
85
/// Interceptor that adds Smithy `@httpChecksum` request checksums.
///
/// `AP` is the algorithm provider: it inspects the operation input and returns the checksum
/// algorithm to apply, if any (the required signature is spelled out on the [`Intercept`]
/// implementation).
pub(crate) struct RequestChecksumInterceptor<AP> {
    algorithm_provider: AP,
}

impl<AP> fmt::Debug for RequestChecksumInterceptor<AP> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `AP` carries no `Debug` bound, so only the type name is printed.
        let mut builder = f.debug_struct("RequestChecksumInterceptor");
        builder.finish()
    }
}

impl<AP> RequestChecksumInterceptor<AP> {
    /// Create an interceptor that asks `algorithm_provider` which checksum algorithm to use.
    pub(crate) fn new(algorithm_provider: AP) -> Self {
        RequestChecksumInterceptor { algorithm_provider }
    }
}
101
102impl<AP> Intercept for RequestChecksumInterceptor<AP>
103where
104    AP: Fn(&Input) -> Result<Option<ChecksumAlgorithm>, BoxError> + Send + Sync,
105{
106    fn name(&self) -> &'static str {
107        "RequestChecksumInterceptor"
108    }
109
110    fn read_before_serialization(
111        &self,
112        context: &BeforeSerializationInterceptorContextRef<'_>,
113        _runtime_components: &RuntimeComponents,
114        cfg: &mut ConfigBag,
115    ) -> Result<(), BoxError> {
116        let checksum_algorithm = (self.algorithm_provider)(context.input())?;
117
118        let mut layer = Layer::new("RequestChecksumInterceptor");
119        layer.store_put(RequestChecksumInterceptorState { checksum_algorithm });
120        cfg.push_layer(layer);
121
122        Ok(())
123    }
124
125    /// Calculate a checksum and modify the request to include the checksum as a header
126    /// (for in-memory request bodies) or a trailer (for streaming request bodies).
127    /// Streaming bodies must be sized or this will return an error.
128    fn modify_before_signing(
129        &self,
130        context: &mut BeforeTransmitInterceptorContextMut<'_>,
131        _runtime_components: &RuntimeComponents,
132        cfg: &mut ConfigBag,
133    ) -> Result<(), BoxError> {
134        let state = cfg.load::<RequestChecksumInterceptorState>().expect("set in `read_before_serialization`");
135
136        let checksum_algorithm = incorporate_custom_default(state.checksum_algorithm, cfg);
137        if let Some(checksum_algorithm) = checksum_algorithm {
138            let request = context.request_mut();
139            add_checksum_for_request_body(request, checksum_algorithm, cfg)?;
140        }
141
142        Ok(())
143    }
144}
145
146fn incorporate_custom_default(checksum: Option<ChecksumAlgorithm>, cfg: &ConfigBag) -> Option<ChecksumAlgorithm> {
147    match cfg.load::<DefaultRequestChecksumOverride>() {
148        Some(checksum_override) => checksum_override.custom_default(checksum, cfg),
149        None => checksum,
150    }
151}
152
153fn add_checksum_for_request_body(request: &mut HttpRequest, checksum_algorithm: ChecksumAlgorithm, cfg: &mut ConfigBag) -> Result<(), BoxError> {
154    match request.body().bytes() {
155        // Body is in-memory: read it and insert the checksum as a header.
156        Some(data) => {
157            tracing::debug!("applying {checksum_algorithm:?} of the request body as a header");
158            let mut checksum = checksum_algorithm.into_impl();
159            checksum.update(data);
160
161            request.headers_mut().insert(checksum.header_name(), checksum.header_value());
162        }
163        // Body is streaming: wrap the body so it will emit a checksum as a trailer.
164        None => {
165            tracing::debug!("applying {checksum_algorithm:?} of the request body as a trailer");
166            cfg.interceptor_state().store_put(PayloadSigningOverride::StreamingUnsignedPayloadTrailer);
167            wrap_streaming_request_body_in_checksum_calculating_body(request, checksum_algorithm)?;
168        }
169    }
170    Ok(())
171}
172
173fn wrap_streaming_request_body_in_checksum_calculating_body(
174    request: &mut HttpRequest,
175    checksum_algorithm: ChecksumAlgorithm,
176) -> Result<(), BuildError> {
177    let original_body_size = request
178        .body()
179        .size_hint()
180        .exact()
181        .ok_or_else(|| BuildError::other(Error::UnsizedRequestBody))?;
182
183    let mut body = {
184        let body = mem::replace(request.body_mut(), SdkBody::taken());
185
186        body.map(move |body| {
187            let checksum = checksum_algorithm.into_impl();
188            let trailer_len = HttpChecksum::size(checksum.as_ref());
189            let body = calculate::ChecksumBody::new(body, checksum);
190            let aws_chunked_body_options = AwsChunkedBodyOptions::new(original_body_size, vec![trailer_len]);
191
192            let body = AwsChunkedBody::new(body, aws_chunked_body_options);
193
194            SdkBody::from_body_0_4(body)
195        })
196    };
197
198    let encoded_content_length = body.size_hint().exact().ok_or_else(|| BuildError::other(Error::UnsizedRequestBody))?;
199
200    let headers = request.headers_mut();
201
202    headers.insert(
203        http::header::HeaderName::from_static("x-amz-trailer"),
204        checksum_algorithm.into_impl().header_name(),
205    );
206
207    headers.insert(http::header::CONTENT_LENGTH, HeaderValue::from(encoded_content_length));
208    headers.insert(
209        http::header::HeaderName::from_static("x-amz-decoded-content-length"),
210        HeaderValue::from(original_body_size),
211    );
212    headers.insert(
213        http::header::CONTENT_ENCODING,
214        HeaderValue::from_str(AWS_CHUNKED)
215            .map_err(BuildError::other)
216            .expect("\"aws-chunked\" will always be a valid HeaderValue"),
217    );
218
219    mem::swap(request.body_mut(), &mut body);
220
221    Ok(())
222}
223
#[cfg(test)]
mod tests {
    use crate::http_request_checksum::wrap_streaming_request_body_in_checksum_calculating_body;
    use aws_smithy_checksums::ChecksumAlgorithm;
    use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
    use aws_smithy_types::base64;
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::byte_stream::ByteStream;
    use bytes::BytesMut;
    use http_body::Body;
    use tempfile::NamedTempFile;

    /// Read every data frame of `body` into one buffer, panicking on a read error.
    async fn collect_body(mut body: SdkBody) -> BytesMut {
        let mut collected = BytesMut::new();
        while let Some(frame) = body.data().await {
            collected.extend_from_slice(&frame.unwrap());
        }
        collected
    }

    /// Wrapping a retryable in-memory body keeps it retryable, and a clone of the wrapped body
    /// yields one `aws-chunked` chunk, the terminating chunk, and a CRC32 trailer.
    #[tokio::test]
    async fn test_checksum_body_is_retryable() {
        let input_text = "Hello world";
        let retryable_body = SdkBody::retryable(move || SdkBody::from(input_text));
        let mut request: HttpRequest = http::Request::builder().body(retryable_body).unwrap().try_into().unwrap();

        assert!(request.body().try_clone().is_some(), "original body should be retryable");

        let crc32: ChecksumAlgorithm = "crc32".parse().unwrap();
        wrap_streaming_request_body_in_checksum_calculating_body(&mut request, crc32).unwrap();

        let cloned = request.body().try_clone().expect("body is retryable");
        let collected = collect_body(cloned).await;
        let encoded = std::str::from_utf8(&collected).unwrap();

        let expected = format!(
            "{:X}\r\n{input_text}\r\n0\r\nx-amz-checksum-crc32:i9aeUg==\r\n\r\n",
            input_text.len()
        );
        assert_eq!(expected, encoded);
    }

    /// A file-backed body read with a 1024-byte buffer stays retryable after wrapping. The
    /// encoded output ends with the last line of data, the terminating chunk, and a CRC32C
    /// trailer equal to a checksum computed independently over the same bytes.
    #[tokio::test]
    async fn test_checksum_body_from_file_is_retryable() {
        use std::io::Write;

        let mut file = NamedTempFile::new().unwrap();
        let crc32c: ChecksumAlgorithm = "crc32c".parse().unwrap();

        let mut reference_checksum = crc32c.into_impl();
        for i in 0..10000 {
            let line = format!("This is a large file created for testing purposes {i}");
            file.as_file_mut().write_all(line.as_bytes()).unwrap();
            reference_checksum.update(line.as_bytes());
        }
        let reference_checksum = reference_checksum.finalize();

        let file_body = ByteStream::read_from()
            .path(&file)
            .buffer_size(1024)
            .build()
            .await
            .unwrap()
            .into_inner();
        let mut request = HttpRequest::new(file_body);

        assert!(request.body().try_clone().is_some(), "original body should be retryable");

        wrap_streaming_request_body_in_checksum_calculating_body(&mut request, crc32c).unwrap();

        let cloned = request.body().try_clone().expect("body is retryable");
        let collected = collect_body(cloned).await;
        let encoded = std::str::from_utf8(&collected).unwrap();

        let expected_checksum = base64::encode(&reference_checksum);
        let expected = format!("This is a large file created for testing purposes 9999\r\n0\r\nx-amz-checksum-crc32c:{expected_checksum}\r\n\r\n");
        assert!(encoded.ends_with(&expected), "expected {encoded} to end with '{expected}'");
    }
}