aws_smithy_http/
query_writer.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use crate::query::fmt_string as percent_encode_query;
7use http_02x::uri::InvalidUri;
8use http_02x::Uri;
9
10/// Utility for updating the query string in a [`Uri`].
11#[allow(missing_debug_implementations)]
12pub struct QueryWriter {
13    base_uri: Uri,
14    new_path_and_query: String,
15    prefix: Option<char>,
16}
17
18impl QueryWriter {
19    /// Creates a new `QueryWriter` from a string
20    pub fn new_from_string(uri: &str) -> Result<Self, InvalidUri> {
21        Ok(Self::new(&Uri::try_from(uri)?))
22    }
23
24    /// Creates a new `QueryWriter` based off the given `uri`.
25    pub fn new(uri: &Uri) -> Self {
26        let new_path_and_query = uri
27            .path_and_query()
28            .map(|pq| pq.to_string())
29            .unwrap_or_default();
30        let prefix = if uri.query().is_none() {
31            Some('?')
32        } else if !uri.query().unwrap_or_default().is_empty() {
33            Some('&')
34        } else {
35            None
36        };
37        QueryWriter {
38            base_uri: uri.clone(),
39            new_path_and_query,
40            prefix,
41        }
42    }
43
44    /// Clears all query parameters.
45    pub fn clear_params(&mut self) {
46        if let Some(index) = self.new_path_and_query.find('?') {
47            self.new_path_and_query.truncate(index);
48            self.prefix = Some('?');
49        }
50    }
51
52    /// Inserts a new query parameter. The key and value are percent encoded
53    /// by `QueryWriter`. Passing in percent encoded values will result in double encoding.
54    pub fn insert(&mut self, k: &str, v: &str) {
55        self.insert_encoded(&percent_encode_query(k), &percent_encode_query(v));
56    }
57
58    /// Inserts a new already encoded query parameter. The key and value will be inserted
59    /// as is.
60    pub fn insert_encoded(&mut self, encoded_k: &str, encoded_v: &str) {
61        if let Some(prefix) = self.prefix {
62            self.new_path_and_query.push(prefix);
63        }
64        self.prefix = Some('&');
65        self.new_path_and_query.push_str(encoded_k);
66        self.new_path_and_query.push('=');
67        self.new_path_and_query.push_str(encoded_v)
68    }
69
70    /// Returns just the built query string.
71    pub fn build_query(self) -> String {
72        self.build_uri().query().unwrap_or_default().to_string()
73    }
74
75    /// Returns a full [`Uri`] with the query string updated.
76    pub fn build_uri(self) -> Uri {
77        let mut parts = self.base_uri.into_parts();
78        parts.path_and_query = Some(
79            self.new_path_and_query
80                .parse()
81                .expect("adding query should not invalidate URI"),
82        );
83        Uri::from_parts(parts).expect("a valid URL in should always produce a valid URL out")
84    }
85}
86
#[cfg(test)]
mod test {
    use super::QueryWriter;
    use http_02x::Uri;

    /// Creates a `QueryWriter` for the static `base` URI and inserts every `(key, value)` pair
    /// from `params`, in order, through the percent-encoding `insert` method.
    fn writer_with(base: &'static str, params: &[(&str, &str)]) -> QueryWriter {
        let mut writer = QueryWriter::new(&Uri::from_static(base));
        for &(key, value) in params {
            writer.insert(key, value);
        }
        writer
    }

    #[test]
    fn empty_uri() {
        let actual = writer_with(
            "http://www.example.com",
            &[("key", "val%ue"), ("another", "value")],
        )
        .build_uri();
        let expected = Uri::from_static("http://www.example.com?key=val%25ue&another=value");
        assert_eq!(actual, expected);
    }

    #[test]
    fn uri_with_path() {
        let actual = writer_with(
            "http://www.example.com/path",
            &[("key", "val%ue"), ("another", "value")],
        )
        .build_uri();
        let expected =
            Uri::from_static("http://www.example.com/path?key=val%25ue&another=value");
        assert_eq!(actual, expected);
    }

    #[test]
    fn uri_with_path_and_query() {
        let actual = writer_with(
            "http://www.example.com/path?original=here",
            &[("key", "val%ue"), ("another", "value")],
        )
        .build_uri();
        let expected = Uri::from_static(
            "http://www.example.com/path?original=here&key=val%25ue&another=value",
        );
        assert_eq!(actual, expected);
    }

    #[test]
    fn build_query() {
        let writer = writer_with(
            "http://www.example.com",
            &[("key", "val%ue"), ("ano%ther", "value")],
        );
        assert_eq!("key=val%25ue&ano%25ther=value", writer.build_query());
    }

    /// Percent encoding applied to query values must always produce a valid URI when the
    /// starting URI is valid. Each byte is tried as a one-byte string, so only bytes that
    /// are valid UTF-8 on their own (ASCII) are exercised.
    #[test]
    fn doesnt_panic_when_adding_query_to_valid_uri() {
        let uri = Uri::from_static("http://www.example.com");

        let problematic_chars: Vec<char> = (u8::MIN..=u8::MAX)
            .filter(|&byte| {
                let single = [byte];
                match std::str::from_utf8(&single) {
                    // Not a complete UTF-8 sequence by itself, so it can't be a `&str` input.
                    Err(_) => false,
                    Ok(value) => {
                        let mut writer = QueryWriter::new(&uri);
                        writer.insert("key", value);
                        std::panic::catch_unwind(|| writer.build_uri()).is_err()
                    }
                }
            })
            .map(char::from)
            .collect();

        if !problematic_chars.is_empty() {
            panic!("we got some bad bytes here: {problematic_chars:#?}")
        }
    }

    #[test]
    fn clear_params() {
        let mut writer = writer_with("http://www.example.com/path?original=here&foo=1", &[]);
        writer.clear_params();
        writer.insert("new", "value");
        assert_eq!("new=value", writer.build_query());
    }
}