aws_config/
http_credential_provider.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Generalized HTTP credential provider. Currently, this cannot be used directly and can only
7//! be used via the ECS credential provider.
8//!
9//! Future work will stabilize this interface and enable it to be used directly.
10
11use crate::json_credentials::{parse_json_credentials, JsonCredentials, RefreshableCredentials};
12use crate::provider_config::ProviderConfig;
13use aws_credential_types::attributes::AccountId;
14use aws_credential_types::credential_feature::AwsCredentialFeature;
15use aws_credential_types::provider::{self, error::CredentialsError};
16use aws_credential_types::Credentials;
17use aws_smithy_runtime::client::metrics::MetricsRuntimePlugin;
18use aws_smithy_runtime::client::orchestrator::operation::Operation;
19use aws_smithy_runtime::client::retries::classifiers::{
20    HttpStatusCodeClassifier, TransientErrorClassifier,
21};
22use aws_smithy_runtime_api::client::http::HttpConnectorSettings;
23use aws_smithy_runtime_api::client::interceptors::context::{Error, InterceptorContext};
24use aws_smithy_runtime_api::client::orchestrator::{
25    HttpResponse, Metadata, OrchestratorError, SensitiveOutput,
26};
27use aws_smithy_runtime_api::client::result::SdkError;
28use aws_smithy_runtime_api::client::retries::classifiers::ClassifyRetry;
29use aws_smithy_runtime_api::client::retries::classifiers::RetryAction;
30use aws_smithy_runtime_api::client::runtime_plugin::StaticRuntimePlugin;
31use aws_smithy_types::body::SdkBody;
32use aws_smithy_types::config_bag::Layer;
33use aws_smithy_types::retry::RetryConfig;
34use aws_smithy_types::timeout::TimeoutConfig;
35use http::header::{ACCEPT, AUTHORIZATION};
36use http::HeaderValue;
37use std::time::Duration;
38
/// Default read timeout for requests made to the credentials endpoint (five seconds).
const DEFAULT_READ_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Default connect timeout for requests made to the credentials endpoint (two seconds).
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_millis(2_000);
41
/// Input passed to the credentials operation for a single load attempt.
#[derive(Debug)]
struct HttpProviderAuth {
    /// When `Some`, the serializer sends this value as the `Authorization` request header;
    /// when `None`, no `Authorization` header is added.
    auth: Option<HeaderValue>,
}
46
47#[derive(Debug)]
48pub(crate) struct HttpCredentialProvider {
49    operation: Operation<HttpProviderAuth, Credentials, CredentialsError>,
50}
51
52impl HttpCredentialProvider {
53    pub(crate) fn builder() -> Builder {
54        Builder::default()
55    }
56
57    pub(crate) async fn credentials(&self, auth: Option<HeaderValue>) -> provider::Result {
58        let credentials =
59            self.operation
60                .invoke(HttpProviderAuth { auth })
61                .await
62                .map(|mut creds| {
63                    creds
64                        .get_property_mut_or_default::<Vec<AwsCredentialFeature>>()
65                        .push(AwsCredentialFeature::CredentialsHttp);
66                    creds
67                });
68        match credentials {
69            Ok(creds) => Ok(creds),
70            Err(SdkError::ServiceError(context)) => Err(context.into_err()),
71            Err(other) => Err(CredentialsError::unhandled(other)),
72        }
73    }
74}
75
76#[derive(Default)]
77pub(crate) struct Builder {
78    provider_config: Option<ProviderConfig>,
79    http_connector_settings: Option<HttpConnectorSettings>,
80}
81
82impl Builder {
83    pub(crate) fn configure(mut self, provider_config: &ProviderConfig) -> Self {
84        self.provider_config = Some(provider_config.clone());
85        self
86    }
87
88    pub(crate) fn http_connector_settings(
89        mut self,
90        http_connector_settings: HttpConnectorSettings,
91    ) -> Self {
92        self.http_connector_settings = Some(http_connector_settings);
93        self
94    }
95
96    pub(crate) fn build(
97        self,
98        provider_name: &'static str,
99        endpoint: &str,
100        path: impl Into<String>,
101    ) -> HttpCredentialProvider {
102        let provider_config = self.provider_config.unwrap_or_default();
103        let path = path.into();
104
105        let mut builder = Operation::builder()
106            .service_name("HttpCredentialProvider")
107            .operation_name("LoadCredentials")
108            .with_connection_poisoning()
109            .endpoint_url(endpoint)
110            .no_auth()
111            .timeout_config(
112                TimeoutConfig::builder()
113                    .connect_timeout(DEFAULT_CONNECT_TIMEOUT)
114                    .read_timeout(DEFAULT_READ_TIMEOUT)
115                    .build(),
116            )
117            .runtime_plugin(StaticRuntimePlugin::new().with_config({
118                let mut layer = Layer::new("SensitiveOutput");
119                layer.store_put(SensitiveOutput);
120                layer.freeze()
121            }))
122            .runtime_plugin(
123                MetricsRuntimePlugin::builder()
124                    .with_scope("aws_config::http_credential_provider")
125                    .with_time_source(provider_config.time_source())
126                    .with_metadata(Metadata::new(path.clone(), provider_name))
127                    .build()
128                    .expect("All required fields have been set"),
129            );
130        if let Some(http_client) = provider_config.http_client() {
131            builder = builder.http_client(http_client);
132        }
133        if let Some(sleep_impl) = provider_config.sleep_impl() {
134            builder = builder
135                .standard_retry(&RetryConfig::standard())
136                // The following errors are retryable:
137                //   - Socket errors
138                //   - Networking timeouts
139                //   - 5xx errors
140                //   - Non-parseable 200 responses.
141                .retry_classifier(HttpCredentialRetryClassifier)
142                // Socket errors and network timeouts
143                .retry_classifier(TransientErrorClassifier::<Error>::new())
144                // 5xx errors
145                .retry_classifier(HttpStatusCodeClassifier::default())
146                .sleep_impl(sleep_impl);
147        } else {
148            builder = builder.no_retry();
149        }
150        let operation = builder
151            .serializer(move |input: HttpProviderAuth| {
152                let mut http_req = http::Request::builder()
153                    .uri(path.clone())
154                    .header(ACCEPT, "application/json");
155                if let Some(auth) = input.auth {
156                    http_req = http_req.header(AUTHORIZATION, auth);
157                }
158                Ok(http_req
159                    .body(SdkBody::empty())
160                    .expect("valid request")
161                    .try_into()
162                    .unwrap())
163            })
164            .deserializer(move |response| parse_response(provider_name, response))
165            .build();
166        HttpCredentialProvider { operation }
167    }
168}
169
170fn parse_response(
171    provider_name: &'static str,
172    response: &HttpResponse,
173) -> Result<Credentials, OrchestratorError<CredentialsError>> {
174    if !response.status().is_success() {
175        return Err(OrchestratorError::operation(
176            CredentialsError::provider_error(format!(
177                "Non-success status from HTTP credential provider: {:?}",
178                response.status()
179            )),
180        ));
181    }
182    let resp_bytes = response.body().bytes().expect("non-streaming deserializer");
183    let str_resp = std::str::from_utf8(resp_bytes)
184        .map_err(|err| OrchestratorError::operation(CredentialsError::unhandled(err)))?;
185    let json_creds = parse_json_credentials(str_resp)
186        .map_err(|err| OrchestratorError::operation(CredentialsError::unhandled(err)))?;
187    match json_creds {
188        JsonCredentials::RefreshableCredentials(RefreshableCredentials {
189            access_key_id,
190            secret_access_key,
191            session_token,
192            account_id,
193            expiration,
194        }) => {
195            let mut builder = Credentials::builder()
196                .access_key_id(access_key_id)
197                .secret_access_key(secret_access_key)
198                .session_token(session_token)
199                .expiry(expiration)
200                .provider_name(provider_name);
201            builder.set_account_id(account_id.map(AccountId::from));
202            Ok(builder.build())
203        }
204        JsonCredentials::Error { code, message } => Err(OrchestratorError::operation(
205            CredentialsError::provider_error(format!(
206                "failed to load credentials [{code}]: {message}",
207            )),
208        )),
209    }
210}
211
212#[derive(Clone, Debug)]
213struct HttpCredentialRetryClassifier;
214
215impl ClassifyRetry for HttpCredentialRetryClassifier {
216    fn name(&self) -> &'static str {
217        "HttpCredentialRetryClassifier"
218    }
219
220    fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
221        let output_or_error = ctx.output_or_error();
222        let error = match output_or_error {
223            Some(Ok(_)) | None => return RetryAction::NoActionIndicated,
224            Some(Err(err)) => err,
225        };
226
227        // Retry non-parseable 200 responses
228        if let Some((err, status)) = error
229            .as_operation_error()
230            .and_then(|err| err.downcast_ref::<CredentialsError>())
231            .zip(ctx.response().map(HttpResponse::status))
232        {
233            if matches!(err, CredentialsError::Unhandled { .. }) && status.is_success() {
234                return RetryAction::server_error();
235            }
236        }
237
238        RetryAction::NoActionIndicated
239    }
240}
241
#[cfg(test)]
mod test {
    use super::*;
    use aws_credential_types::credential_feature::AwsCredentialFeature;
    use aws_credential_types::provider::error::CredentialsError;
    use aws_smithy_http_client::test_util::{ReplayEvent, StaticReplayClient};
    use aws_smithy_types::body::SdkBody;
    use http::{Request, Response, Uri};
    use std::time::SystemTime;

    /// Loads credentials once from a provider that sends its requests to
    /// `http://localhost:1234/some-creds` through `http_client`.
    async fn provide_creds(
        http_client: StaticReplayClient,
    ) -> Result<Credentials, CredentialsError> {
        let provider_config = ProviderConfig::default().with_http_client(http_client.clone());
        let provider = HttpCredentialProvider::builder()
            .configure(&provider_config)
            .build("test", "http://localhost:1234/", "/some-creds");
        provider.credentials(None).await
    }

    /// Pairs the expected credentials request with a response of the given status and body.
    fn replay_event(status: u16, body: &'static str) -> ReplayEvent {
        ReplayEvent::new(
            Request::builder()
                .uri(Uri::from_static("http://localhost:1234/some-creds"))
                .body(SdkBody::empty())
                .unwrap(),
            Response::builder()
                .status(status)
                .body(SdkBody::from(body))
                .unwrap(),
        )
    }

    fn successful_req_resp() -> ReplayEvent {
        replay_event(
            200,
            r#"{
                "AccessKeyId" : "MUA...",
                "SecretAccessKey" : "/7PC5om....",
                "Token" : "AQoDY....=",
                "Expiration" : "2016-02-25T06:03:31Z"
            }"#,
        )
    }

    #[tokio::test]
    async fn successful_response() {
        let http_client = StaticReplayClient::new(vec![successful_req_resp()]);
        let creds = provide_creds(http_client.clone()).await.expect("success");
        assert_eq!("MUA...", creds.access_key_id());
        assert_eq!("/7PC5om....", creds.secret_access_key());
        assert_eq!(Some("AQoDY....="), creds.session_token());
        assert_eq!(
            Some(SystemTime::UNIX_EPOCH + Duration::from_secs(1456380211)),
            creds.expiry()
        );
        http_client.assert_requests_match(&[]);
    }

    #[tokio::test]
    async fn retry_nonparseable_response() {
        let http_client = StaticReplayClient::new(vec![
            replay_event(200, r#"not json"#),
            successful_req_resp(),
        ]);
        let creds = provide_creds(http_client.clone()).await.expect("success");
        assert_eq!("MUA...", creds.access_key_id());
        http_client.assert_requests_match(&[]);
    }

    #[tokio::test]
    async fn retry_error_code() {
        let http_client = StaticReplayClient::new(vec![
            replay_event(500, r#"it broke"#),
            successful_req_resp(),
        ]);
        let creds = provide_creds(http_client.clone()).await.expect("success");
        assert_eq!("MUA...", creds.access_key_id());
        http_client.assert_requests_match(&[]);
    }

    #[tokio::test]
    async fn explicit_error_not_retryable() {
        let http_client = StaticReplayClient::new(vec![replay_event(
            400,
            r#"{ "Code": "Error", "Message": "There was a problem, it was your fault" }"#,
        )]);
        let err = provide_creds(http_client.clone())
            .await
            .expect_err("it should fail");
        assert!(
            matches!(err, CredentialsError::ProviderError { .. }),
            "should be CredentialsError::ProviderError: {err}",
        );
        http_client.assert_requests_match(&[]);
    }

    /// A 2xx response whose JSON explicitly reports an error is parseable, so it must surface
    /// as a provider error without being retried (only one replay event is available).
    #[tokio::test]
    async fn error_code_in_success_response_not_retryable() {
        let http_client = StaticReplayClient::new(vec![replay_event(
            200,
            r#"{ "Code": "Error", "Message": "There was a problem, it was your fault" }"#,
        )]);
        let err = provide_creds(http_client.clone())
            .await
            .expect_err("it should fail");
        assert!(
            matches!(err, CredentialsError::ProviderError { .. }),
            "should be CredentialsError::ProviderError: {err}",
        );
        http_client.assert_requests_match(&[]);
    }

    #[tokio::test]
    async fn credentials_feature() {
        let http_client = StaticReplayClient::new(vec![successful_req_resp()]);
        let creds = provide_creds(http_client.clone()).await.expect("success");
        assert_eq!(
            &vec![AwsCredentialFeature::CredentialsHttp],
            creds.get_property::<Vec<AwsCredentialFeature>>().unwrap()
        );
        http_client.assert_requests_match(&[]);
    }
}