azure_storage/authorization/
authorization_policy.rs1use crate::{clients::ServiceType, StorageCredentials, StorageCredentialsInner};
2use azure_core::{
3 auth::Secret,
4 error::{ErrorKind, ResultExt},
5 headers::*,
6 hmac::hmac_sha256,
7 Context, Method, Policy, PolicyResult, Request, Url,
8};
9use std::{borrow::Cow, ops::Deref, sync::Arc};
10use tracing::trace;
11
/// Scope requested from a token credential when acquiring the bearer token
/// that is placed in the `Authorization` header.
const STORAGE_TOKEN_SCOPE: &str = "https://storage.azure.com/.default";
13
/// Pipeline policy that authorizes outgoing requests with the configured
/// [`StorageCredentials`] before handing them to the next policy.
#[derive(Debug, Clone)]
pub struct AuthorizationPolicy {
    /// Credentials consulted on every `send` to decide how the request is authorized.
    credentials: StorageCredentials,
}
18
19impl AuthorizationPolicy {
20 pub(crate) fn new(credentials: StorageCredentials) -> Self {
21 Self { credentials }
22 }
23}
24
25#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
26#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
27impl Policy for AuthorizationPolicy {
28 async fn send(
29 &self,
30 ctx: &Context,
31 request: &mut Request,
32 next: &[Arc<dyn Policy>],
33 ) -> PolicyResult {
34 trace!("called AuthorizationPolicy::send. self == {:#?}", self);
35
36 assert!(
37 !next.is_empty(),
38 "Authorization policies cannot be the last policy of a pipeline"
39 );
40
41 {
43 let creds = self.credentials.0.read().await;
44
45 match creds.deref() {
46 StorageCredentialsInner::Key(account, key) => {
47 if !request.url().query_pairs().any(|(k, _)| &*k == "sig") {
48 let auth = generate_authorization(
49 request.headers(),
50 request.url(),
51 *request.method(),
52 account,
53 key,
54 *ctx.get()
55 .expect("ServiceType must be in the Context at this point"),
56 )?;
57 request.insert_header(AUTHORIZATION, auth);
58 }
59 }
60 StorageCredentialsInner::SASToken(query_pairs) => {
61 if !request.url().query_pairs().any(|(k, _)| &*k == "sig") {
63 request
64 .url_mut()
65 .query_pairs_mut()
66 .extend_pairs(query_pairs);
67 }
68 }
69 StorageCredentialsInner::BearerToken(token) => {
70 request.insert_header(AUTHORIZATION, format!("Bearer {}", token.secret()));
71 }
72 StorageCredentialsInner::TokenCredential(token_credential) => {
73 let bearer_token = token_credential
74 .get_token(&[STORAGE_TOKEN_SCOPE])
75 .await
76 .context(ErrorKind::Credential, "failed to get bearer token")?;
77
78 request.insert_header(
79 AUTHORIZATION,
80 format!("Bearer {}", bearer_token.token.secret()),
81 );
82 }
83 StorageCredentialsInner::Anonymous => {}
84 }
85 };
86
87 next[0].send(ctx, request, &next[1..]).await
88 }
89}
90
91fn generate_authorization(
92 h: &Headers,
93 u: &Url,
94 method: Method,
95 account: &str,
96 key: &Secret,
97 service_type: ServiceType,
98) -> azure_core::Result<String> {
99 let str_to_sign = string_to_sign(h, u, method, account, service_type);
100 let auth = hmac_sha256(&str_to_sign, key).context(
101 azure_core::error::ErrorKind::Credential,
102 "failed to sign the hmac",
103 )?;
104 Ok(format!("SharedKey {account}:{auth}"))
105}
106
107fn add_if_exists<'a>(h: &'a Headers, key: &HeaderName) -> &'a str {
108 h.get_optional_str(key).unwrap_or_default()
109}
110
111#[allow(unknown_lints)]
112fn string_to_sign(
113 h: &Headers,
114 u: &Url,
115 method: Method,
116 account: &str,
117 service_type: ServiceType,
118) -> String {
119 if matches!(service_type, ServiceType::Table) {
120 format!(
121 "{}\n{}\n{}\n{}\n{}",
122 method.as_ref(),
123 add_if_exists(h, &CONTENT_MD5),
124 add_if_exists(h, &CONTENT_TYPE),
125 add_if_exists(h, &MS_DATE),
126 canonicalized_resource_table(account, u)
127 )
128 } else {
129 let content_length = h
132 .get_optional_str(&CONTENT_LENGTH)
133 .filter(|&v| v != "0")
134 .unwrap_or_default();
135 format!(
136 "{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}\n{}{}",
137 method.as_ref(),
138 add_if_exists(h, &CONTENT_ENCODING),
139 add_if_exists(h, &CONTENT_LANGUAGE),
140 content_length,
141 add_if_exists(h, &CONTENT_MD5),
142 add_if_exists(h, &CONTENT_TYPE),
143 add_if_exists(h, &DATE),
144 add_if_exists(h, &IF_MODIFIED_SINCE),
145 add_if_exists(h, &IF_MATCH),
146 add_if_exists(h, &IF_NONE_MATCH),
147 add_if_exists(h, &IF_UNMODIFIED_SINCE),
148 add_if_exists(h, &RANGE),
149 canonicalize_header(h),
150 canonicalized_resource(account, u)
151 )
152 }
153}
154
155fn canonicalize_header(headers: &Headers) -> String {
156 let mut names = headers
157 .iter()
158 .filter_map(|(k, _)| (k.as_str().starts_with("x-ms")).then_some(k))
159 .collect::<Vec<_>>();
160 names.sort_unstable();
161
162 let mut result = String::new();
163
164 for header_name in names {
165 let value = headers.get_optional_str(header_name).unwrap();
166 let name = header_name.as_str();
167 result = format!("{result}{name}:{value}\n");
168 }
169 result
170}
171
172fn canonicalized_resource_table(account: &str, u: &Url) -> String {
173 format!("/{}{}", account, u.path())
174}
175
176fn canonicalized_resource(account: &str, uri: &Url) -> String {
177 let mut can_res: String = String::new();
178 can_res += "/";
179 can_res += account;
180
181 for p in uri.path_segments().into_iter().flatten() {
182 can_res.push('/');
183 can_res.push_str(p);
184 }
185 can_res += "\n";
186
187 let query_pairs = uri.query_pairs();
189 {
190 let mut qps: Vec<String> = Vec::new();
191 for (q, _) in query_pairs {
192 if !(qps.iter().any(|x| x == &*q)) {
193 qps.push(q.into_owned());
194 }
195 }
196
197 qps.sort();
198
199 for qparam in qps {
200 let ret = lexy_sort(query_pairs, &qparam);
202
203 can_res = can_res + &qparam.to_lowercase() + ":";
204
205 for (i, item) in ret.iter().enumerate() {
206 if i > 0 {
207 can_res += ",";
208 }
209 can_res += item;
210 }
211
212 can_res += "\n";
213 }
214 };
215
216 can_res[0..can_res.len() - 1].to_owned()
217}
218
/// Collects the values of every pair in `vec` whose key equals `query_param`
/// exactly (case-sensitive) and returns them in ascending order.
fn lexy_sort<'a>(
    vec: impl Iterator<Item = (Cow<'a, str>, Cow<'a, str>)> + 'a,
    query_param: &str,
) -> Vec<Cow<'a, str>> {
    let mut matching: Vec<Cow<'a, str>> = vec
        .filter_map(|(name, value)| (name == query_param).then_some(value))
        .collect();
    matching.sort_unstable();
    matching
}
230
#[cfg(test)]
mod tests {
    use super::*;
    use azure_core::{BytesStream, Response};

    /// Terminal mock policy: asserts that the request URL carries exactly one
    /// `sig` query parameter and then returns an empty `Accepted` response.
    #[derive(Debug, Clone)]
    struct AssertSigHeaderUniqueMockPolicy;

    #[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
    #[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
    impl Policy for AssertSigHeaderUniqueMockPolicy {
        async fn send(
            &self,
            _ctx: &Context,
            request: &mut Request,
            _next: &[Arc<dyn Policy>],
        ) -> PolicyResult {
            let mut sig_param_count = 0usize;
            for (name, _) in request.url().query_pairs() {
                if name == "sig" {
                    sig_param_count += 1;
                }
            }
            assert_eq!(sig_param_count, 1);

            let empty_body = Box::pin(BytesStream::new(vec![]));
            Ok(Response::new(
                azure_core::StatusCode::Accepted,
                Headers::new(),
                empty_body,
            ))
        }
    }

    const SAMPLE_SAS_TOKEN: &str = "sp=r&st=1970-01-01T00:00:00Z&se=1970-01-01T00:00:00Z&spr=https&sv=1970-01-01&sr=c&sig=AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA";

    /// Shared fixture: a policy backed by the sample SAS token, a bare GET
    /// request, and a one-element pipeline tail holding the asserting mock.
    fn sas_fixture() -> (AuthorizationPolicy, Request, [Arc<dyn Policy>; 1]) {
        let credentials = StorageCredentials::sas_token(SAMPLE_SAS_TOKEN).unwrap();
        let policy = AuthorizationPolicy::new(credentials);
        let request = Request::new(Url::parse("https://example.com").unwrap(), Method::Get);
        let tail: [Arc<dyn Policy>; 1] = [Arc::new(AssertSigHeaderUniqueMockPolicy)];
        (policy, request, tail)
    }

    #[tokio::test]
    async fn authorization_policy_applies_sas_token() {
        let ctx = Context::default();
        let (policy, mut request, tail) = sas_fixture();

        policy.send(&ctx, &mut request, &tail).await.unwrap();
    }

    #[tokio::test]
    async fn authorization_policy_with_sas_token_does_not_apply_twice() {
        let ctx = Context::default();
        let (policy, mut request, tail) = sas_fixture();

        // Sending the same request twice must not append the token a second time.
        policy.send(&ctx, &mut request, &tail).await.unwrap();
        policy.send(&ctx, &mut request, &tail).await.unwrap();
    }
}