1use core::fmt;
17use core::ops::Deref;
18use core::str::FromStr;
19
20use std::collections::HashMap;
21use std::io::ErrorKind;
22use std::net::{IpAddr, Ipv4Addr, SocketAddr};
23use std::path::Path;
24
25use base64::engine::Engine as _;
26use base64::prelude::BASE64_STANDARD_NO_PAD;
27use http::Uri;
28use serde::{de, de::Deserializer, de::Visitor, Deserialize, Serialize};
29use tracing::{instrument, trace};
30use unicase::UniCase;
31
/// Default CORS allowed origins: none.
const CORS_ALLOWED_ORIGINS: &[&str] = &[];
/// Default CORS allowed methods.
const CORS_ALLOWED_METHODS: &[&str] = &[
    "GET",
    "POST",
    "PUT",
    "DELETE",
    "HEAD",
    "OPTIONS",
];
/// Default CORS allowed request headers.
const CORS_ALLOWED_HEADERS: &[&str] =
    &["accept", "accept-language", "content-type", "content-language"];
/// Default CORS exposed response headers: none.
const CORS_EXPOSED_HEADERS: &[&str] = &[];
/// Default CORS max age: five minutes, expressed in seconds.
const CORS_DEFAULT_MAX_AGE_SECS: u64 = 5 * 60;
42
/// Address used when no listen address is configured: every IPv4
/// interface (`0.0.0.0`) on port 8000.
pub fn default_listen_address() -> SocketAddr {
    SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 8000)
}
46
/// Runtime settings for the HTTP server.
///
/// There are two ways to build these settings:
/// - from JSON, via the `config_json` / `config_b64` keys handled by [`load_settings`];
/// - from individual key/value pairs applied on top of `ServiceSettings::default()`.
///
/// The nested `tls` and `cors` sections are a deprecated JSON layout. They are
/// kept for backward compatibility.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ServiceSettings {
    /// Socket address the server listens on. When omitted from JSON it
    /// defaults to [`default_listen_address`] (`0.0.0.0:8000`).
    #[serde(default = "default_listen_address")]
    pub address: SocketAddr,
    /// Cache-control value. `ServiceSettings::validate` rejects it unless it is
    /// a legal HTTP header value.
    /// NOTE(review): presumably emitted as the `Cache-Control` response header;
    /// this is not visible in this file.
    #[serde(default)]
    pub cache_control: Option<String>,
    /// Read-only mode flag. `ServiceSettings::default()` sets it to `Some(false)`.
    /// NOTE(review): its effect on request handling is not visible in this file.
    #[serde(default)]
    pub readonly_mode: Option<bool>,
    /// Origins allowed for CORS. In JSON, this takes precedence over the
    /// deprecated `cors.allowed_origins`.
    pub cors_allowed_origins: Option<AllowedOrigins>,
    /// Request headers allowed for CORS. Takes precedence over `cors.allowed_headers`.
    pub cors_allowed_headers: Option<AllowedHeaders>,
    /// Methods allowed for CORS. Takes precedence over `cors.allowed_methods`.
    /// Validation requires each entry to parse as an `http::Method`.
    pub cors_allowed_methods: Option<AllowedMethods>,
    /// Response headers exposed to CORS clients. Takes precedence over
    /// `cors.exposed_headers`.
    pub cors_exposed_headers: Option<ExposedHeaders>,
    /// CORS max age, in seconds. Takes precedence over `cors.max_age_secs`.
    pub cors_max_age_secs: Option<u64>,
    /// Path to the TLS certificate file. It must be set together with
    /// `tls_priv_key_file`, and validation requires the file to exist.
    #[serde(default)]
    pub tls_cert_file: Option<String>,
    /// Path to the TLS private-key file. It must be set together with
    /// `tls_cert_file`, and validation requires the file to exist.
    #[serde(default)]
    pub tls_priv_key_file: Option<String>,
    /// Timeout value, presumably in milliseconds (per the name).
    /// It is not interpreted in this file. TODO: confirm what it bounds.
    #[serde(default)]
    pub timeout_ms: Option<u64>,
    /// Deprecated nested TLS section. When loading JSON, it is folded into
    /// the top-level TLS fields.
    #[deprecated(since = "0.22.0", note = "Use top-level fields instead")]
    #[serde(default)]
    pub tls: Tls,
    /// Deprecated nested CORS section. When loading JSON, it is folded into
    /// the top-level `cors_*` fields.
    #[deprecated(since = "0.22.0", note = "Use top-level fields instead")]
    #[serde(default)]
    pub cors: Cors,
    /// Keep-alive switch. It is not interpreted in this file.
    /// NOTE(review): presumably `Some(true)` disables HTTP keep-alive.
    #[serde(default)]
    pub disable_keepalive: Option<bool>,
}
85
86impl Default for ServiceSettings {
87 fn default() -> ServiceSettings {
88 #[allow(deprecated)]
89 ServiceSettings {
90 address: default_listen_address(),
91 cors_allowed_origins: Some(AllowedOrigins::default()),
92 cors_allowed_headers: Some(AllowedHeaders::default()),
93 cors_allowed_methods: Some(AllowedMethods::default()),
94 cors_exposed_headers: Some(ExposedHeaders::default()),
95 cors_max_age_secs: Some(CORS_DEFAULT_MAX_AGE_SECS),
96 tls_cert_file: None,
97 tls_priv_key_file: None,
98 timeout_ms: None,
99 cache_control: None,
100 readonly_mode: Some(false),
101 tls: Tls::default(),
102 cors: Cors::default(),
103 disable_keepalive: None,
104 }
105 }
106}
107
108impl ServiceSettings {
109 fn from_json(data: &str) -> Result<Self, HttpServerError> {
111 #[allow(deprecated)]
112 serde_json::from_str(data)
113 .map(|s: ServiceSettings| ServiceSettings {
116 address: s.address,
117 cache_control: s.cache_control,
118 readonly_mode: s.readonly_mode,
119 timeout_ms: s.timeout_ms,
120 tls_cert_file: s.tls_cert_file.or(s.tls.cert_file),
121 tls_priv_key_file: s.tls_priv_key_file.or(s.tls.priv_key_file),
122 cors_allowed_origins: s.cors_allowed_origins.or(s.cors.allowed_origins),
123 cors_allowed_headers: s.cors_allowed_headers.or(s.cors.allowed_headers),
124 cors_allowed_methods: s.cors_allowed_methods.or(s.cors.allowed_methods),
125 cors_exposed_headers: s.cors_exposed_headers.or(s.cors.exposed_headers),
126 cors_max_age_secs: s.cors_max_age_secs.or(s.cors.max_age_secs),
127 tls: Tls::default(),
128 cors: Cors::default(),
129 disable_keepalive: s.disable_keepalive,
130 })
131 .map_err(|e| HttpServerError::Settings(format!("invalid json: {e}")))
132 }
133
134 fn validate(&self) -> Result<(), HttpServerError> {
138 let mut errors = Vec::new();
139 match (&self.tls_cert_file, &self.tls_priv_key_file) {
141 (None, None) => {}
142 (Some(_), None) | (None, Some(_)) => {
143 errors.push(
144 "for tls, both 'tls_cert_file' and 'tls_priv_key_file' must be set".to_string(),
145 );
146 }
147 (Some(cert_file), Some(key_file)) => {
148 for f in &[("cert_file", &cert_file), ("priv_key_file", &key_file)] {
149 let path: &Path = f.1.as_ref();
150 if !path.is_file() {
151 errors.push(format!(
152 "missing tls_{} '{}'{}",
153 f.0,
154 &path.display(),
155 if path.is_absolute() {
156 ""
157 } else {
158 " : perhaps you should make the path absolute"
159 }
160 ));
161 }
162 }
163 }
164 }
165 if let Some(ref methods) = self.cors_allowed_methods {
166 for m in &methods.0 {
167 if http::Method::try_from(m.as_str()).is_err() {
168 errors.push(format!("invalid CORS method: '{m}'"));
169 }
170 }
171 }
172 if let Some(cache_control) = self.cache_control.as_ref() {
173 if http::HeaderValue::from_str(cache_control).is_err() {
174 errors.push(format!("Invalid Cache Control header : '{cache_control}'"));
175 }
176 }
177 if !errors.is_empty() {
178 Err(HttpServerError::Settings(format!(
179 "\nInvalid httpserver settings: \n{}\n",
180 errors.join("\n")
181 )))
182 } else {
183 Ok(())
184 }
185 }
186}
187
/// Errors produced while loading or validating HTTP server settings.
#[derive(Debug, thiserror::Error)]
pub enum HttpServerError {
    /// A single key/value setting could not be parsed, for example `port`,
    /// `address` or `cors_max_age_secs`.
    #[error("invalid parameter: {0}")]
    InvalidParameter(String),

    /// The settings as a whole were malformed or failed validation. Examples:
    /// bad JSON or base64, a bad CORS list, or failing `validate` checks.
    #[error("problem reading settings: {0}")]
    Settings(String),
}
197
198#[instrument]
207pub fn load_settings(
208 default_address: Option<SocketAddr>,
209 values: &HashMap<String, String>,
210) -> Result<ServiceSettings, HttpServerError> {
211 trace!("load settings");
212 let values: HashMap<UniCase<&str>, &String> = values
215 .iter()
216 .map(|(k, v)| (UniCase::new(k.as_str()), v))
217 .collect();
218
219 if let Some(str) = values.get(&UniCase::new("config_b64")) {
220 let bytes = BASE64_STANDARD_NO_PAD
221 .decode(str)
222 .map_err(|e| HttpServerError::Settings(format!("invalid base64 encoding: {e}")))?;
223 return ServiceSettings::from_json(&String::from_utf8_lossy(&bytes));
224 }
225
226 if let Some(str) = values.get(&UniCase::new("config_json")) {
227 return ServiceSettings::from_json(str);
228 }
229
230 let mut settings = ServiceSettings::default();
231
232 if let Some(addr) = values.get(&UniCase::new("port")) {
234 let port = addr
235 .parse::<u16>()
236 .map_err(|_| HttpServerError::InvalidParameter(format!("Invalid port: {addr}")))?;
237 settings.address = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), port);
238 }
239 settings.address = values
241 .get(&UniCase::new("address"))
242 .map(|addr| {
243 SocketAddr::from_str(addr)
244 .map_err(|_| HttpServerError::InvalidParameter(format!("invalid address: {addr}")))
245 })
246 .transpose()?
247 .or(default_address)
248 .unwrap_or_else(default_listen_address);
249
250 if let Some(cache_control) = values.get(&UniCase::new("cache_control")) {
252 settings.cache_control = Some(cache_control.to_string());
253 }
254 if let Some(readonly_mode) = values.get(&UniCase::new("readonly_mode")) {
256 settings.readonly_mode = Some(readonly_mode.to_string().parse().unwrap_or(false));
257 }
258 if let Some(Ok(timeout_ms)) = values.get(&UniCase::new("timeout_ms")).map(|s| s.parse()) {
260 settings.timeout_ms = Some(timeout_ms)
261 }
262
263 if let Some(tls_cert_file) = values.get(&UniCase::new("tls_cert_file")) {
265 settings.tls_cert_file = Some(tls_cert_file.to_string());
266 }
267 if let Some(tls_priv_key_file) = values.get(&UniCase::new("tls_priv_key_file")) {
268 settings.tls_priv_key_file = Some(tls_priv_key_file.to_string());
269 }
270
271 if let Some(cors_allowed_origins) = values.get(&UniCase::new("cors_allowed_origins")) {
273 let origins: Vec<CorsOrigin> = serde_json::from_str(cors_allowed_origins)
274 .map_err(|e| HttpServerError::Settings(format!("invalid cors_allowed_origins: {e}")))?;
275 settings.cors_allowed_origins = Some(AllowedOrigins(origins));
276 }
277 if let Some(cors_allowed_headers) = values.get(&UniCase::new("cors_allowed_headers")) {
278 let headers: Vec<String> = serde_json::from_str(cors_allowed_headers)
279 .map_err(|e| HttpServerError::Settings(format!("invalid cors_allowed_headers: {e}")))?;
280 settings.cors_allowed_headers = Some(AllowedHeaders(headers));
281 }
282 if let Some(cors_allowed_methods) = values.get(&UniCase::new("cors_allowed_methods")) {
283 let methods: Vec<String> = serde_json::from_str(cors_allowed_methods)
284 .map_err(|e| HttpServerError::Settings(format!("invalid cors_allowed_methods: {e}")))?;
285 settings.cors_allowed_methods = Some(AllowedMethods(methods));
286 }
287 if let Some(cors_exposed_headers) = values.get(&UniCase::new("cors_exposed_headers")) {
288 let headers: Vec<String> = serde_json::from_str(cors_exposed_headers)
289 .map_err(|e| HttpServerError::Settings(format!("invalid cors_exposed_headers: {e}")))?;
290 settings.cors_exposed_headers = Some(ExposedHeaders(headers));
291 }
292 if let Some(cors_max_age_secs) = values.get(&UniCase::new("cors_max_age_secs")) {
293 let max_age_secs: u64 = cors_max_age_secs.parse().map_err(|_| {
294 HttpServerError::InvalidParameter("Invalid cors_max_age_secs".to_string())
295 })?;
296 settings.cors_max_age_secs = Some(max_age_secs);
297 }
298 if let Some(disable_keepalive) = values.get(&UniCase::new("disable_keepalive")) {
299 settings.disable_keepalive = Some(disable_keepalive.parse().unwrap_or(false));
300 }
301
302 settings.validate()?;
303 Ok(settings)
304}
305
/// Deprecated nested TLS section of the JSON settings (`"tls": { ... }`).
///
/// `ServiceSettings::from_json` copies these values into the top-level
/// `tls_cert_file` / `tls_priv_key_file` fields when those are unset.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct Tls {
    /// Path to the TLS certificate file.
    pub cert_file: Option<String>,
    /// Path to the TLS private-key file.
    pub priv_key_file: Option<String>,
}
312
/// Deprecated nested CORS section of the JSON settings (`"cors": { ... }`).
///
/// `ServiceSettings::from_json` falls back to these values for any top-level
/// `cors_*` field that is unset. How the section deserializes:
/// - If the `cors` key is absent from the JSON, the serde default
///   (`Cors::default()`) supplies the module-level defaults.
/// - If the key is present, any field it omits deserializes as `None`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Cors {
    /// Origins allowed for CORS.
    pub allowed_origins: Option<AllowedOrigins>,
    /// Request headers allowed for CORS.
    pub allowed_headers: Option<AllowedHeaders>,
    /// Methods allowed for CORS.
    pub allowed_methods: Option<AllowedMethods>,
    /// Response headers exposed to CORS clients.
    pub exposed_headers: Option<ExposedHeaders>,
    /// CORS max age, in seconds.
    pub max_age_secs: Option<u64>,
}
321
322impl Default for Cors {
323 fn default() -> Self {
324 Cors {
325 allowed_origins: Some(AllowedOrigins::default()),
326 allowed_headers: Some(AllowedHeaders::default()),
327 allowed_methods: Some(AllowedMethods::default()),
328 exposed_headers: Some(ExposedHeaders::default()),
329 max_age_secs: Some(CORS_DEFAULT_MAX_AGE_SECS),
330 }
331 }
332}
333
/// A CORS origin such as `https://example.com:3000`, stored without a trailing `/`.
///
/// `Deserialize` is implemented by hand so that deserialization goes through
/// the `FromStr` validation. Note that the derived `Default` yields an empty,
/// unvalidated origin.
#[derive(Debug, Clone, Default, Serialize, PartialEq, Eq)]
pub struct CorsOrigin(String);

/// List of origins permitted for CORS requests.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct AllowedOrigins(Vec<CorsOrigin>);

/// Request header names permitted for CORS requests.
/// The derived `Deserialize` does not validate the entries.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct AllowedHeaders(Vec<String>);

/// HTTP method names permitted for CORS requests.
/// `ServiceSettings::validate` checks that each entry parses as an `http::Method`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct AllowedMethods(Vec<String>);

/// Response header names exposed to CORS clients.
/// The derived `Deserialize` does not validate the entries.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ExposedHeaders(Vec<String>);
348
349impl<'de> Deserialize<'de> for CorsOrigin {
350 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
351 where
352 D: Deserializer<'de>,
353 {
354 struct CorsOriginVisitor;
355 impl Visitor<'_> for CorsOriginVisitor {
356 type Value = CorsOrigin;
357
358 fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
359 write!(fmt, "an origin in format http[s]://example.com[:3000]",)
360 }
361
362 fn visit_str<E>(self, v: &str) -> std::result::Result<Self::Value, E>
363 where
364 E: de::Error,
365 {
366 CorsOrigin::from_str(v).map_err(E::custom)
367 }
368 }
369 deserializer.deserialize_str(CorsOriginVisitor)
370 }
371}
372
373impl FromStr for CorsOrigin {
374 type Err = std::io::Error;
375
376 fn from_str(origin: &str) -> Result<Self, Self::Err> {
377 let uri = Uri::from_str(origin).map_err(|invalid_uri| {
378 std::io::Error::new(
379 ErrorKind::InvalidInput,
380 format!("Invalid uri: {origin}.\n{invalid_uri}"),
381 )
382 })?;
383 if let Some(s) = uri.scheme_str() {
384 if s != "http" && s != "https" {
385 return Err(std::io::Error::new(
386 ErrorKind::InvalidInput,
387 format!(
388 "Cors origin invalid schema {s}, only [http] and [https] are supported: ",
389 ),
390 ));
391 }
392 } else {
393 return Err(std::io::Error::new(
394 ErrorKind::InvalidInput,
395 "Cors origin missing schema, only [http] or [https] are supported",
396 ));
397 }
398
399 if let Some(p) = uri.path_and_query() {
400 if p.as_str() != "/" {
401 return Err(std::io::Error::new(
402 ErrorKind::InvalidInput,
403 format!("Invalid value {} in cors schema.", p.as_str()),
404 ));
405 }
406 }
407 Ok(CorsOrigin(origin.trim_end_matches('/').to_owned()))
408 }
409}
410
411impl AsRef<str> for CorsOrigin {
412 fn as_ref(&self) -> &str {
413 &self.0
414 }
415}
416
417impl Deref for AllowedOrigins {
418 type Target = Vec<CorsOrigin>;
419
420 fn deref(&self) -> &Self::Target {
421 &self.0
422 }
423}
424
425impl Default for AllowedOrigins {
426 fn default() -> Self {
427 AllowedOrigins(
428 CORS_ALLOWED_ORIGINS
429 .iter()
430 .map(|s| CorsOrigin((*s).to_string()))
431 .collect::<Vec<_>>(),
432 )
433 }
434}
435
436impl Deref for AllowedHeaders {
437 type Target = Vec<String>;
438
439 fn deref(&self) -> &Self::Target {
440 &self.0
441 }
442}
443
444impl Default for AllowedHeaders {
445 fn default() -> Self {
446 AllowedHeaders(from_defaults(CORS_ALLOWED_HEADERS))
447 }
448}
449
450impl Default for AllowedMethods {
451 fn default() -> Self {
452 AllowedMethods(from_defaults(CORS_ALLOWED_METHODS))
453 }
454}
455
456impl Deref for AllowedMethods {
457 type Target = Vec<String>;
458
459 fn deref(&self) -> &Self::Target {
460 &self.0
461 }
462}
463
464impl Deref for ExposedHeaders {
465 type Target = Vec<String>;
466
467 fn deref(&self) -> &Self::Target {
468 &self.0
469 }
470}
471
472impl Default for ExposedHeaders {
473 fn default() -> Self {
474 ExposedHeaders(
475 CORS_EXPOSED_HEADERS
476 .iter()
477 .map(|s| (*s).to_string())
478 .collect::<Vec<_>>(),
479 )
480 }
481}
482
/// An HTTP request method.
///
/// With serde, variants are (de)serialized in upper case (`"GET"`, `"POST"`, ...).
/// The `FromStr` impl additionally accepts other letter cases.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "UPPERCASE")]
pub enum HttpMethod {
    Get,
    Post,
    Put,
    Delete,
    Head,
    Options,
    Connect,
    Patch,
    Trace,
}
496
497impl FromStr for HttpMethod {
498 type Err = std::io::Error;
499
500 fn from_str(s: &str) -> Result<Self, Self::Err> {
501 match s.to_uppercase().as_str() {
502 "GET" => Ok(Self::Get),
503 "PUT" => Ok(Self::Put),
504 "POST" => Ok(Self::Post),
505 "DELETE" => Ok(Self::Delete),
506 "HEAD" => Ok(Self::Head),
507 "OPTIONS" => Ok(Self::Options),
508 "CONNECT" => Ok(Self::Connect),
509 "PATCH" => Ok(Self::Patch),
510 "TRACE" => Ok(Self::Trace),
511 _ => Err(std::io::Error::new(
512 std::io::ErrorKind::InvalidData,
513 format!("{s} is not a valid http method"),
514 )),
515 }
516 }
517}
518
/// Converts a slice of string defaults into a `Vec` of owned values of type `T`.
fn from_defaults<'d, T>(d: &[&'d str]) -> Vec<T>
where
    T: From<&'d str>,
{
    d.iter().copied().map(T::from).collect()
}
527
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use super::{CorsOrigin, HttpMethod, ServiceSettings};

    /// Origins that must be accepted and stored exactly as written.
    const GOOD_ORIGINS: &[&str] = &[
        "https://www.example.com",
        "https://www.example.com:1000",
        "http://localhost",
        "http://localhost:8080",
        "http://127.0.0.1",
        "http://127.0.0.1:8080",
        "https://:8080",
    ];

    /// Strings that must be rejected as CORS origins. They cover non-http(s)
    /// schemes, scheme-less hosts and addresses, bare paths, and incomplete URIs.
    const BAD_ORIGINS: &[&str] = &[
        "ftp://www.example.com",
        "localhost",
        "127.0.0.1",
        "127.0.0.1:8080",
        ":8080",
        "/path/file.txt",
        "http:",
        "https://",
    ];

    #[test]
    fn settings_init() {
        let s = ServiceSettings::default();
        assert!(s.address.is_ipv4());
        let allowed_origins = s
            .cors_allowed_origins
            .expect("allowed_origins should be set");
        assert!(s.cors_allowed_methods.is_some());
        assert!(allowed_origins.0.is_empty());
    }

    #[test]
    fn settings_json() {
        let json = r#"{
            "cors": {
                "allowed_headers": [ "X-Cookies" ]
            }
        }"#;

        let s = ServiceSettings::from_json(json).expect("parse_json");
        assert_eq!(
            s.cors_allowed_headers
                .as_ref()
                .expect("allowed headers should be set")
                .0
                .len(),
            1
        );
        assert_eq!(
            s.cors_allowed_headers
                .as_ref()
                .expect("allowed headers should be set")
                .0
                .first(),
            Some(&"X-Cookies".into())
        );
    }

    #[test]
    fn origins_deserialize() {
        for valid in GOOD_ORIGINS {
            let o = serde_json::from_value::<CorsOrigin>(serde_json::Value::String(
                (*valid).to_string(),
            ))
            .expect("deserialize should succeed");

            assert_eq!(&o.0, valid);
        }
    }

    #[test]
    fn origins_from_str() {
        for &valid in GOOD_ORIGINS {
            let o = CorsOrigin::from_str(valid).expect("deserialize should succeed");

            assert_eq!(&o.0, valid);
        }
    }

    #[test]
    fn origins_trailing_slash_is_stripped() {
        let o = CorsOrigin::from_str("https://www.example.com/")
            .expect("a trailing slash should be accepted");
        assert_eq!(o.0, "https://www.example.com");
    }

    #[test]
    fn origins_negative() {
        for &bad in BAD_ORIGINS {
            // Through the serde `Deserialize` impl.
            let o =
                serde_json::from_value::<CorsOrigin>(serde_json::Value::String(bad.to_string()));
            assert!(o.is_err(), "from_value '{bad}' (expect err)");

            // Directly through `FromStr`. `serde_json::from_str` would reject
            // every unquoted input as malformed JSON before origin validation
            // ever ran, so it cannot exercise this path.
            let o = CorsOrigin::from_str(bad);
            assert!(o.is_err(), "from_str '{bad}' (expect err)");
        }
    }

    #[test]
    fn http_method_from_str_ignores_case() {
        assert_eq!(HttpMethod::from_str("get").expect("get"), HttpMethod::Get);
        assert_eq!(HttpMethod::from_str("Patch").expect("patch"), HttpMethod::Patch);
        assert!(HttpMethod::from_str("FETCH").is_err());
    }
}