azure_storage/authorization/authorization_policy.rs

1use crate::{clients::ServiceType, StorageCredentials, StorageCredentialsInner};
2use azure_core::{
3    auth::Secret,
4    error::{ErrorKind, ResultExt},
5    headers::*,
6    hmac::hmac_sha256,
7    Context, Method, Policy, PolicyResult, Request, Url,
8};
9use std::{borrow::Cow, ops::Deref, sync::Arc};
10use tracing::trace;
11
12const STORAGE_TOKEN_SCOPE: &str = "https://storage.azure.com/.default";
13
14#[derive(Debug, Clone)]
15pub struct AuthorizationPolicy {
16    credentials: StorageCredentials,
17}
18
19impl AuthorizationPolicy {
20    pub(crate) fn new(credentials: StorageCredentials) -> Self {
21        Self { credentials }
22    }
23}
24
25#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
26#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
27impl Policy for AuthorizationPolicy {
28    async fn send(
29        &self,
30        ctx: &Context,
31        request: &mut Request,
32        next: &[Arc<dyn Policy>],
33    ) -> PolicyResult {
34        trace!("called AuthorizationPolicy::send. self == {:#?}", self);
35
36        assert!(
37            !next.is_empty(),
38            "Authorization policies cannot be the last policy of a pipeline"
39        );
40
41        // lock the credentials within a scope so that it is released as soon as possible
42        {
43            let creds = self.credentials.0.read().await;
44
45            match creds.deref() {
46                StorageCredentialsInner::Key(account, key) => {
47                    if !request.url().query_pairs().any(|(k, _)| &*k == "sig") {
48                        let auth = generate_authorization(
49                            request.headers(),
50                            request.url(),
51                            *request.method(),
52                            account,
53                            key,
54                            *ctx.get()
55                                .expect("ServiceType must be in the Context at this point"),
56                        )?;
57                        request.insert_header(AUTHORIZATION, auth);
58                    }
59                }
60                StorageCredentialsInner::SASToken(query_pairs) => {
61                    // Ensure the signature param is not already present
62                    if !request.url().query_pairs().any(|(k, _)| &*k == "sig") {
63                        request
64                            .url_mut()
65                            .query_pairs_mut()
66                            .extend_pairs(query_pairs);
67                    }
68                }
69                StorageCredentialsInner::BearerToken(token) => {
70                    request.insert_header(AUTHORIZATION, format!("Bearer {}", token.secret()));
71                }
72                StorageCredentialsInner::TokenCredential(token_credential) => {
73                    let bearer_token = token_credential
74                        .get_token(&[STORAGE_TOKEN_SCOPE])
75                        .await
76                        .context(ErrorKind::Credential, "failed to get bearer token")?;
77
78                    request.insert_header(
79                        AUTHORIZATION,
80                        format!("Bearer {}", bearer_token.token.secret()),
81                    );
82                }
83                StorageCredentialsInner::Anonymous => {}
84            }
85        };
86
87        next[0].send(ctx, request, &next[1..]).await
88    }
89}
90
91fn generate_authorization(
92    h: &Headers,
93    u: &Url,
94    method: Method,
95    account: &str,
96    key: &Secret,
97    service_type: ServiceType,
98) -> azure_core::Result<String> {
99    let str_to_sign = string_to_sign(h, u, method, account, service_type);
100    let auth = hmac_sha256(&str_to_sign, key).context(
101        azure_core::error::ErrorKind::Credential,
102        "failed to sign the hmac",
103    )?;
104    Ok(format!("SharedKey {account}:{auth}"))
105}
106
107fn add_if_exists<'a>(h: &'a Headers, key: &HeaderName) -> &'a str {
108    h.get_optional_str(key).unwrap_or_default()
109}
110
111#[allow(unknown_lints)]
112fn string_to_sign(
113    h: &Headers,
114    u: &Url,
115    method: Method,
116    account: &str,
117    service_type: ServiceType,
118) -> String {
119    if matches!(service_type, ServiceType::Table) {
120        format!(
121            "{}\n{}\n{}\n{}\n{}",
122            method.as_ref(),
123            add_if_exists(h, &CONTENT_MD5),
124            add_if_exists(h, &CONTENT_TYPE),
125            add_if_exists(h, &MS_DATE),
126            canonicalized_resource_table(account, u)
127        )
128    } else {
129        // content length must only be specified if != 0
130        // this is valid from 2015-02-21
131        let content_length = h
132            .get_optional_str(&CONTENT_LENGTH)
133            .filter(|&v| v != "0")
134            .unwrap_or_default();
135        format!(
136            "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}{}",
137            method.as_ref(),
138            add_if_exists(h, &CONTENT_ENCODING),
139            add_if_exists(h, &CONTENT_LANGUAGE),
140            content_length,
141            add_if_exists(h, &CONTENT_MD5),
142            add_if_exists(h, &CONTENT_TYPE),
143            add_if_exists(h, &DATE),
144            add_if_exists(h, &IF_MODIFIED_SINCE),
145            add_if_exists(h, &IF_MATCH),
146            add_if_exists(h, &IF_NONE_MATCH),
147            add_if_exists(h, &IF_UNMODIFIED_SINCE),
148            add_if_exists(h, &RANGE),
149            canonicalize_header(h),
150            canonicalized_resource(account, u)
151        )
152    }
153}
154
155fn canonicalize_header(headers: &Headers) -> String {
156    let mut names = headers
157        .iter()
158        .filter_map(|(k, _)| (k.as_str().starts_with("x-ms")).then_some(k))
159        .collect::<Vec<_>>();
160    names.sort_unstable();
161
162    let mut result = String::new();
163
164    for header_name in names {
165        let value = headers.get_optional_str(header_name).unwrap();
166        let name = header_name.as_str();
167        result = format!("{result}{name}:{value}\n");
168    }
169    result
170}
171
172fn canonicalized_resource_table(account: &str, u: &Url) -> String {
173    format!("/{}{}", account, u.path())
174}
175
176fn canonicalized_resource(account: &str, uri: &Url) -> String {
177    let mut can_res: String = String::new();
178    can_res += "/";
179    can_res += account;
180
181    for p in uri.path_segments().into_iter().flatten() {
182        can_res.push('/');
183        can_res.push_str(p);
184    }
185    can_res += "\n";
186
187    // query parameters
188    let query_pairs = uri.query_pairs();
189    {
190        let mut qps: Vec<String> = Vec::new();
191        for (q, _) in query_pairs {
192            if !(qps.iter().any(|x| x == &*q)) {
193                qps.push(q.into_owned());
194            }
195        }
196
197        qps.sort();
198
199        for qparam in qps {
200            // find correct parameter
201            let ret = lexy_sort(query_pairs, &qparam);
202
203            can_res = can_res + &qparam.to_lowercase() + ":";
204
205            for (i, item) in ret.iter().enumerate() {
206                if i > 0 {
207                    can_res += ",";
208                }
209                can_res += item;
210            }
211
212            can_res += "\n";
213        }
214    };
215
216    can_res[0..can_res.len() - 1].to_owned()
217}
218
/// Returns the values of every `(name, value)` pair in `pairs` whose name equals
/// `query_param`, sorted lexicographically.
fn lexy_sort<'a>(
    pairs: impl Iterator<Item = (Cow<'a, str>, Cow<'a, str>)> + 'a,
    query_param: &str,
) -> Vec<Cow<'a, str>> {
    let mut matching: Vec<Cow<'a, str>> = pairs
        .filter_map(|(name, value)| (name == query_param).then_some(value))
        .collect();
    matching.sort_unstable();
    matching
}
230
#[cfg(test)]
mod tests {
    use super::*;
    use azure_core::{BytesStream, Response};

    /// Terminal mock policy. It asserts that the request URL carries exactly one `sig`
    /// query parameter, then answers `202 Accepted` with an empty body.
    #[derive(Debug, Clone)]
    struct AssertSigHeaderUniqueMockPolicy;

    #[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
    #[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
    impl Policy for AssertSigHeaderUniqueMockPolicy {
        async fn send(
            &self,
            _ctx: &Context,
            request: &mut Request,
            _next: &[Arc<dyn Policy>],
        ) -> PolicyResult {
            let sig_params = request
                .url()
                .query_pairs()
                .filter(|(name, _)| name == "sig")
                .count();
            assert_eq!(sig_params, 1);

            Ok(Response::new(
                azure_core::StatusCode::Accepted,
                Headers::new(),
                Box::pin(BytesStream::new(vec![])),
            ))
        }
    }

    const SAMPLE_SAS_TOKEN: &str = "sp=r&st=1970-01-01T00:00:00Z&se=1970-01-01T00:00:00Z&spr=https&sv=1970-01-01&sr=c&sig=AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA";

    /// Builds the policy under test, configured with `SAMPLE_SAS_TOKEN`.
    fn sas_auth_policy() -> AuthorizationPolicy {
        let credentials = StorageCredentials::sas_token(SAMPLE_SAS_TOKEN).unwrap();
        AuthorizationPolicy::new(credentials)
    }

    /// Builds a bare GET request whose URL has no query string.
    fn unsigned_request() -> Request {
        Request::new(Url::parse("https://example.com").unwrap(), Method::Get)
    }

    #[tokio::test]
    async fn authorization_policy_applies_sas_token() {
        let next: [Arc<dyn Policy>; 1] = [Arc::new(AssertSigHeaderUniqueMockPolicy)];
        let mut request = unsigned_request();

        sas_auth_policy()
            .send(&Context::default(), &mut request, &next)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn authorization_policy_with_sas_token_does_not_apply_twice() {
        let ctx = Context::default();
        let auth_policy = sas_auth_policy();
        let next: [Arc<dyn Policy>; 1] = [Arc::new(AssertSigHeaderUniqueMockPolicy)];
        let mut request = unsigned_request();

        // Run the policy twice over the same request. The mock checks for a single `sig` each time.
        for _ in 0..2 {
            auth_policy.send(&ctx, &mut request, &next).await.unwrap();
        }
    }
}