wasmcloud_core/
http.rs

1//! Common Configuration settings for the http server provider and built in http server. This
2//! requires that the `http` feature be enabled.
3//!
4//! The "values" map in the component link definition may contain one or more of the following keys,
5//! which determine how the configuration is parsed.
6//!
//! For the key...
//! - `config_file`: load configuration from file name. Interprets file as json or toml, based
//!   on file extension.
//! - `config_b64`: Configuration is a base64-encoded json string
//! - `config_json`: Configuration is a raw json string
//!
//! If no configuration is provided, the default settings below will be used:
//! - TLS is disabled
//! - CORS allows all hosts(origins), most methods, and common headers (see constants below).
//! - Default listener is bound to 0.0.0.0 (all IPv4 interfaces) port 8000.
//!
16use core::fmt;
17use core::ops::Deref;
18use core::str::FromStr;
19
20use std::collections::HashMap;
21use std::io::ErrorKind;
22use std::net::{IpAddr, Ipv4Addr, SocketAddr};
23use std::path::Path;
24
25use base64::engine::Engine as _;
26use base64::prelude::BASE64_STANDARD_NO_PAD;
27use http::Uri;
28use serde::{de, de::Deserializer, de::Visitor, Deserialize, Serialize};
29use tracing::{instrument, trace};
30use unicase::UniCase;
31
/// Origins placed in the default [`AllowedOrigins`] list. The list is empty.
// NOTE(review): the module docs say the default CORS policy allows all origins, so an empty
// list is presumably interpreted as "any origin" by the server -- confirm against the consumer.
const CORS_ALLOWED_ORIGINS: &[&str] = &[];
/// Method names placed in the default [`AllowedMethods`] list.
const CORS_ALLOWED_METHODS: &[&str] = &["GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS"];
/// Request header names placed in the default [`AllowedHeaders`] list.
const CORS_ALLOWED_HEADERS: &[&str] = &[
    "accept",
    "accept-language",
    "content-type",
    "content-language",
];
/// Header names placed in the default [`ExposedHeaders`] list. The list is empty.
const CORS_EXPOSED_HEADERS: &[&str] = &[];
/// Default value for the CORS max-age setting, in seconds.
const CORS_DEFAULT_MAX_AGE_SECS: u64 = 300;
42
/// Address the server binds to when no address is configured anywhere else:
/// the unspecified IPv4 address (`0.0.0.0`) on port 8000.
pub fn default_listen_address() -> SocketAddr {
    SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 8000)
}
46
/// Settings for the HTTP server.
///
/// Can be deserialized from JSON or assembled from individual configuration keys.
/// The nested `tls` and `cors` fields are deprecated in favour of the flat
/// `tls_*` / `cors_*` fields.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ServiceSettings {
    /// Bind address. Falls back to [`default_listen_address`] when absent from the input.
    #[serde(default = "default_listen_address")]
    pub address: SocketAddr,
    /// Cache control options. Validation requires this to be a legal HTTP header value.
    #[serde(default)]
    pub cache_control: Option<String>,
    /// Flag for read only mode
    #[serde(default)]
    pub readonly_mode: Option<bool>,
    // CORS config (flat form)
    /// Origins allowed by CORS.
    pub cors_allowed_origins: Option<AllowedOrigins>,
    /// Request headers allowed by CORS.
    pub cors_allowed_headers: Option<AllowedHeaders>,
    /// Methods allowed by CORS. Validation requires each entry to parse as an HTTP method.
    pub cors_allowed_methods: Option<AllowedMethods>,
    /// Response headers exposed by CORS.
    pub cors_exposed_headers: Option<ExposedHeaders>,
    /// CORS max age, in seconds.
    pub cors_max_age_secs: Option<u64>,
    // TLS config (flat form). Validation requires either both files or neither.
    /// path to server X.509 cert chain file. Must be PEM-encoded
    #[serde(default)]
    pub tls_cert_file: Option<String>,
    /// Path to the server private key file.
    // NOTE(review): presumably PEM-encoded like the certificate chain -- confirm.
    #[serde(default)]
    pub tls_priv_key_file: Option<String>,
    /// Rpc timeout - how long (milliseconds) to wait for component's response
    /// before returning a status 503 to the http client
    /// If not set, uses the system-wide rpc timeout
    #[serde(default)]
    pub timeout_ms: Option<u64>,
    // DEPRECATED due to the nested struct being poorly supported by wasmCloud config
    /// Nested TLS settings, kept so that older JSON documents still deserialize.
    #[deprecated(since = "0.22.0", note = "Use top-level fields instead")]
    #[serde(default)]
    pub tls: Tls,
    /// Nested CORS settings, kept so that older JSON documents still deserialize.
    #[deprecated(since = "0.22.0", note = "Use top-level fields instead")]
    #[serde(default)]
    pub cors: Cors,
    /// Keep-alive switch.
    // NOTE(review): presumably `Some(true)` turns HTTP keep-alive off -- the consumer of this
    // flag is not in this file, confirm there.
    #[serde(default)]
    pub disable_keepalive: Option<bool>,
}
85
86impl Default for ServiceSettings {
87    fn default() -> ServiceSettings {
88        #[allow(deprecated)]
89        ServiceSettings {
90            address: default_listen_address(),
91            cors_allowed_origins: Some(AllowedOrigins::default()),
92            cors_allowed_headers: Some(AllowedHeaders::default()),
93            cors_allowed_methods: Some(AllowedMethods::default()),
94            cors_exposed_headers: Some(ExposedHeaders::default()),
95            cors_max_age_secs: Some(CORS_DEFAULT_MAX_AGE_SECS),
96            tls_cert_file: None,
97            tls_priv_key_file: None,
98            timeout_ms: None,
99            cache_control: None,
100            readonly_mode: Some(false),
101            tls: Tls::default(),
102            cors: Cors::default(),
103            disable_keepalive: None,
104        }
105    }
106}
107
108impl ServiceSettings {
109    /// load settings from json, flattening nested fields
110    fn from_json(data: &str) -> Result<Self, HttpServerError> {
111        #[allow(deprecated)]
112        serde_json::from_str(data)
113            // For backwards compatibility, we can pull the values from the `tls` and `cors` fields
114            // and merge them into the top-level fields.
115            .map(|s: ServiceSettings| ServiceSettings {
116                address: s.address,
117                cache_control: s.cache_control,
118                readonly_mode: s.readonly_mode,
119                timeout_ms: s.timeout_ms,
120                tls_cert_file: s.tls_cert_file.or(s.tls.cert_file),
121                tls_priv_key_file: s.tls_priv_key_file.or(s.tls.priv_key_file),
122                cors_allowed_origins: s.cors_allowed_origins.or(s.cors.allowed_origins),
123                cors_allowed_headers: s.cors_allowed_headers.or(s.cors.allowed_headers),
124                cors_allowed_methods: s.cors_allowed_methods.or(s.cors.allowed_methods),
125                cors_exposed_headers: s.cors_exposed_headers.or(s.cors.exposed_headers),
126                cors_max_age_secs: s.cors_max_age_secs.or(s.cors.max_age_secs),
127                tls: Tls::default(),
128                cors: Cors::default(),
129                disable_keepalive: s.disable_keepalive,
130            })
131            .map_err(|e| HttpServerError::Settings(format!("invalid json: {e}")))
132    }
133
134    /// perform additional validation checks on settings.
135    /// Several checks have already been done during deserialization.
136    /// All errors found are combined into a single error message
137    fn validate(&self) -> Result<(), HttpServerError> {
138        let mut errors = Vec::new();
139        // 1. make sure tls config is valid
140        match (&self.tls_cert_file, &self.tls_priv_key_file) {
141            (None, None) => {}
142            (Some(_), None) | (None, Some(_)) => {
143                errors.push(
144                    "for tls, both 'tls_cert_file' and 'tls_priv_key_file' must be set".to_string(),
145                );
146            }
147            (Some(cert_file), Some(key_file)) => {
148                for f in &[("cert_file", &cert_file), ("priv_key_file", &key_file)] {
149                    let path: &Path = f.1.as_ref();
150                    if !path.is_file() {
151                        errors.push(format!(
152                            "missing tls_{} '{}'{}",
153                            f.0,
154                            &path.display(),
155                            if path.is_absolute() {
156                                ""
157                            } else {
158                                " : perhaps you should make the path absolute"
159                            }
160                        ));
161                    }
162                }
163            }
164        }
165        if let Some(ref methods) = self.cors_allowed_methods {
166            for m in &methods.0 {
167                if http::Method::try_from(m.as_str()).is_err() {
168                    errors.push(format!("invalid CORS method: '{m}'"));
169                }
170            }
171        }
172        if let Some(cache_control) = self.cache_control.as_ref() {
173            if http::HeaderValue::from_str(cache_control).is_err() {
174                errors.push(format!("Invalid Cache Control header : '{cache_control}'"));
175            }
176        }
177        if !errors.is_empty() {
178            Err(HttpServerError::Settings(format!(
179                "\nInvalid httpserver settings: \n{}\n",
180                errors.join("\n")
181            )))
182        } else {
183            Ok(())
184        }
185    }
186}
187
/// Errors generated by this HTTP server
#[derive(Debug, thiserror::Error)]
pub enum HttpServerError {
    /// A single configuration value could not be parsed (for example a port, an address,
    /// or the CORS max age).
    #[error("invalid parameter: {0}")]
    InvalidParameter(String),

    /// The settings could not be decoded (base64 / JSON) or failed validation.
    #[error("problem reading settings: {0}")]
    Settings(String),
}
197
198/// Load settings provides a flexible means for loading configuration.
199/// Return value is any structure with Deserialize, or for example, HashMap<String,String>
200///   config_b64:  base64-encoded json string
201///   config_json: raw json string
202/// Also accept "address" (a string representing SocketAddr) and "port", a localhost port
203/// If more than one key is provided, they are processed in the order above.
204///   (later names override earlier names in the list)
205///
206#[instrument]
207pub fn load_settings(
208    default_address: Option<SocketAddr>,
209    values: &HashMap<String, String>,
210) -> Result<ServiceSettings, HttpServerError> {
211    trace!("load settings");
212    // Allow keys to be case insensitive, as an accommodation
213    // for the lost souls who prefer sPoNgEbOb CaSe variable names.
214    let values: HashMap<UniCase<&str>, &String> = values
215        .iter()
216        .map(|(k, v)| (UniCase::new(k.as_str()), v))
217        .collect();
218
219    if let Some(str) = values.get(&UniCase::new("config_b64")) {
220        let bytes = BASE64_STANDARD_NO_PAD
221            .decode(str)
222            .map_err(|e| HttpServerError::Settings(format!("invalid base64 encoding: {e}")))?;
223        return ServiceSettings::from_json(&String::from_utf8_lossy(&bytes));
224    }
225
226    if let Some(str) = values.get(&UniCase::new("config_json")) {
227        return ServiceSettings::from_json(str);
228    }
229
230    let mut settings = ServiceSettings::default();
231
232    // accept port, for compatibility with previous implementations
233    if let Some(addr) = values.get(&UniCase::new("port")) {
234        let port = addr
235            .parse::<u16>()
236            .map_err(|_| HttpServerError::InvalidParameter(format!("Invalid port: {addr}")))?;
237        settings.address = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), port);
238    }
239    // accept address as value parameter
240    settings.address = values
241        .get(&UniCase::new("address"))
242        .map(|addr| {
243            SocketAddr::from_str(addr)
244                .map_err(|_| HttpServerError::InvalidParameter(format!("invalid address: {addr}")))
245        })
246        .transpose()?
247        .or(default_address)
248        .unwrap_or_else(default_listen_address);
249
250    // accept cache-control header values
251    if let Some(cache_control) = values.get(&UniCase::new("cache_control")) {
252        settings.cache_control = Some(cache_control.to_string());
253    }
254    // accept read only mode flag
255    if let Some(readonly_mode) = values.get(&UniCase::new("readonly_mode")) {
256        settings.readonly_mode = Some(readonly_mode.to_string().parse().unwrap_or(false));
257    }
258    // accept timeout_ms flag
259    if let Some(Ok(timeout_ms)) = values.get(&UniCase::new("timeout_ms")).map(|s| s.parse()) {
260        settings.timeout_ms = Some(timeout_ms)
261    }
262
263    // TLS
264    if let Some(tls_cert_file) = values.get(&UniCase::new("tls_cert_file")) {
265        settings.tls_cert_file = Some(tls_cert_file.to_string());
266    }
267    if let Some(tls_priv_key_file) = values.get(&UniCase::new("tls_priv_key_file")) {
268        settings.tls_priv_key_file = Some(tls_priv_key_file.to_string());
269    }
270
271    // CORS
272    if let Some(cors_allowed_origins) = values.get(&UniCase::new("cors_allowed_origins")) {
273        let origins: Vec<CorsOrigin> = serde_json::from_str(cors_allowed_origins)
274            .map_err(|e| HttpServerError::Settings(format!("invalid cors_allowed_origins: {e}")))?;
275        settings.cors_allowed_origins = Some(AllowedOrigins(origins));
276    }
277    if let Some(cors_allowed_headers) = values.get(&UniCase::new("cors_allowed_headers")) {
278        let headers: Vec<String> = serde_json::from_str(cors_allowed_headers)
279            .map_err(|e| HttpServerError::Settings(format!("invalid cors_allowed_headers: {e}")))?;
280        settings.cors_allowed_headers = Some(AllowedHeaders(headers));
281    }
282    if let Some(cors_allowed_methods) = values.get(&UniCase::new("cors_allowed_methods")) {
283        let methods: Vec<String> = serde_json::from_str(cors_allowed_methods)
284            .map_err(|e| HttpServerError::Settings(format!("invalid cors_allowed_methods: {e}")))?;
285        settings.cors_allowed_methods = Some(AllowedMethods(methods));
286    }
287    if let Some(cors_exposed_headers) = values.get(&UniCase::new("cors_exposed_headers")) {
288        let headers: Vec<String> = serde_json::from_str(cors_exposed_headers)
289            .map_err(|e| HttpServerError::Settings(format!("invalid cors_exposed_headers: {e}")))?;
290        settings.cors_exposed_headers = Some(ExposedHeaders(headers));
291    }
292    if let Some(cors_max_age_secs) = values.get(&UniCase::new("cors_max_age_secs")) {
293        let max_age_secs: u64 = cors_max_age_secs.parse().map_err(|_| {
294            HttpServerError::InvalidParameter("Invalid cors_max_age_secs".to_string())
295        })?;
296        settings.cors_max_age_secs = Some(max_age_secs);
297    }
298    if let Some(disable_keepalive) = values.get(&UniCase::new("disable_keepalive")) {
299        settings.disable_keepalive = Some(disable_keepalive.parse().unwrap_or(false));
300    }
301
302    settings.validate()?;
303    Ok(settings)
304}
305
/// Nested TLS settings. This is the type of the deprecated `ServiceSettings::tls` field;
/// the flat `tls_cert_file` / `tls_priv_key_file` fields replace it.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct Tls {
    /// path to server X.509 cert chain file. Must be PEM-encoded
    pub cert_file: Option<String>,
    /// Path to the server private key file.
    // NOTE(review): presumably PEM-encoded like the certificate chain -- confirm.
    pub priv_key_file: Option<String>,
}
312
/// Nested CORS settings. This is the type of the deprecated `ServiceSettings::cors` field;
/// the flat `cors_*` fields replace it.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Cors {
    /// Origins allowed by CORS.
    pub allowed_origins: Option<AllowedOrigins>,
    /// Request headers allowed by CORS.
    pub allowed_headers: Option<AllowedHeaders>,
    /// Methods allowed by CORS.
    pub allowed_methods: Option<AllowedMethods>,
    /// Response headers exposed by CORS.
    pub exposed_headers: Option<ExposedHeaders>,
    /// CORS max age, in seconds.
    pub max_age_secs: Option<u64>,
}
321
322impl Default for Cors {
323    fn default() -> Self {
324        Cors {
325            allowed_origins: Some(AllowedOrigins::default()),
326            allowed_headers: Some(AllowedHeaders::default()),
327            allowed_methods: Some(AllowedMethods::default()),
328            exposed_headers: Some(ExposedHeaders::default()),
329            max_age_secs: Some(CORS_DEFAULT_MAX_AGE_SECS),
330        }
331    }
332}
333
/// A single CORS origin such as `https://example.com:3000`.
///
/// `Deserialize` is implemented by hand (not derived) so that deserialized values go
/// through the same checks as `FromStr`.
#[derive(Debug, Clone, Default, Serialize, PartialEq, Eq)]
pub struct CorsOrigin(String);

/// List of origins allowed by CORS.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct AllowedOrigins(Vec<CorsOrigin>);

/// List of request header names allowed by CORS.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct AllowedHeaders(Vec<String>);

/// List of method names allowed by CORS.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct AllowedMethods(Vec<String>);

/// List of response header names exposed by CORS.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ExposedHeaders(Vec<String>);
348
349impl<'de> Deserialize<'de> for CorsOrigin {
350    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
351    where
352        D: Deserializer<'de>,
353    {
354        struct CorsOriginVisitor;
355        impl Visitor<'_> for CorsOriginVisitor {
356            type Value = CorsOrigin;
357
358            fn expecting(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
359                write!(fmt, "an origin in format http[s]://example.com[:3000]",)
360            }
361
362            fn visit_str<E>(self, v: &str) -> std::result::Result<Self::Value, E>
363            where
364                E: de::Error,
365            {
366                CorsOrigin::from_str(v).map_err(E::custom)
367            }
368        }
369        deserializer.deserialize_str(CorsOriginVisitor)
370    }
371}
372
373impl FromStr for CorsOrigin {
374    type Err = std::io::Error;
375
376    fn from_str(origin: &str) -> Result<Self, Self::Err> {
377        let uri = Uri::from_str(origin).map_err(|invalid_uri| {
378            std::io::Error::new(
379                ErrorKind::InvalidInput,
380                format!("Invalid uri: {origin}.\n{invalid_uri}"),
381            )
382        })?;
383        if let Some(s) = uri.scheme_str() {
384            if s != "http" && s != "https" {
385                return Err(std::io::Error::new(
386                    ErrorKind::InvalidInput,
387                    format!(
388                        "Cors origin invalid schema {s}, only [http] and [https] are supported: ",
389                    ),
390                ));
391            }
392        } else {
393            return Err(std::io::Error::new(
394                ErrorKind::InvalidInput,
395                "Cors origin missing schema, only [http] or [https] are supported",
396            ));
397        }
398
399        if let Some(p) = uri.path_and_query() {
400            if p.as_str() != "/" {
401                return Err(std::io::Error::new(
402                    ErrorKind::InvalidInput,
403                    format!("Invalid value {} in cors schema.", p.as_str()),
404                ));
405            }
406        }
407        Ok(CorsOrigin(origin.trim_end_matches('/').to_owned()))
408    }
409}
410
411impl AsRef<str> for CorsOrigin {
412    fn as_ref(&self) -> &str {
413        &self.0
414    }
415}
416
417impl Deref for AllowedOrigins {
418    type Target = Vec<CorsOrigin>;
419
420    fn deref(&self) -> &Self::Target {
421        &self.0
422    }
423}
424
425impl Default for AllowedOrigins {
426    fn default() -> Self {
427        AllowedOrigins(
428            CORS_ALLOWED_ORIGINS
429                .iter()
430                .map(|s| CorsOrigin((*s).to_string()))
431                .collect::<Vec<_>>(),
432        )
433    }
434}
435
436impl Deref for AllowedHeaders {
437    type Target = Vec<String>;
438
439    fn deref(&self) -> &Self::Target {
440        &self.0
441    }
442}
443
444impl Default for AllowedHeaders {
445    fn default() -> Self {
446        AllowedHeaders(from_defaults(CORS_ALLOWED_HEADERS))
447    }
448}
449
450impl Default for AllowedMethods {
451    fn default() -> Self {
452        AllowedMethods(from_defaults(CORS_ALLOWED_METHODS))
453    }
454}
455
456impl Deref for AllowedMethods {
457    type Target = Vec<String>;
458
459    fn deref(&self) -> &Self::Target {
460        &self.0
461    }
462}
463
464impl Deref for ExposedHeaders {
465    type Target = Vec<String>;
466
467    fn deref(&self) -> &Self::Target {
468        &self.0
469    }
470}
471
472impl Default for ExposedHeaders {
473    fn default() -> Self {
474        ExposedHeaders(
475            CORS_EXPOSED_HEADERS
476                .iter()
477                .map(|s| (*s).to_string())
478                .collect::<Vec<_>>(),
479        )
480    }
481}
482
/// HTTP request methods.
///
/// With `rename_all = "UPPERCASE"`, each variant is (de)serialized as its upper-case name,
/// e.g. `Get` as `"GET"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "UPPERCASE")]
pub enum HttpMethod {
    Get,
    Post,
    Put,
    Delete,
    Head,
    Options,
    Connect,
    Patch,
    Trace,
}
496
497impl FromStr for HttpMethod {
498    type Err = std::io::Error;
499
500    fn from_str(s: &str) -> Result<Self, Self::Err> {
501        match s.to_uppercase().as_str() {
502            "GET" => Ok(Self::Get),
503            "PUT" => Ok(Self::Put),
504            "POST" => Ok(Self::Post),
505            "DELETE" => Ok(Self::Delete),
506            "HEAD" => Ok(Self::Head),
507            "OPTIONS" => Ok(Self::Options),
508            "CONNECT" => Ok(Self::Connect),
509            "PATCH" => Ok(Self::Patch),
510            "TRACE" => Ok(Self::Trace),
511            _ => Err(std::io::Error::new(
512                std::io::ErrorKind::InvalidData,
513                format!("{s} is not a valid http method"),
514            )),
515        }
516    }
517}
518
/// Convert a slice of string literals into a `Vec<T>` for any `T` that implements
/// `From<&str>`, keeping the order of the input. Used to build the default lists.
fn from_defaults<'d, T>(d: &[&'d str]) -> Vec<T>
where
    T: std::convert::From<&'d str>,
{
    // The conversion is infallible, so there is no error path to handle.
    d.iter().copied().map(T::from).collect()
}
527
#[cfg(test)]
mod test {
    use std::str::FromStr;

    use super::{CorsOrigin, ServiceSettings};

    const GOOD_ORIGINS: &[&str] = &[
        // origins that should be parsed correctly
        "https://www.example.com",
        "https://www.example.com:1000",
        "http://localhost",
        "http://localhost:8080",
        "http://127.0.0.1",
        "http://127.0.0.1:8080",
        "https://:8080",
    ];

    const BAD_ORIGINS: &[&str] = &[
        // invalid origin syntax
        "ftp://www.example.com", // only http,https allowed
        "localhost",
        "127.0.0.1",
        "127.0.0.1:8080",
        ":8080",
        "/path/file.txt",
        "http:",
        "https://",
    ];

    /// Default settings use an IPv4 address, an empty origin list and a method list.
    #[test]
    fn settings_init() {
        let s = ServiceSettings::default();
        assert!(s.address.is_ipv4());
        let allowed_origins = s
            .cors_allowed_origins
            .expect("allowed_origins should be set");
        assert!(s.cors_allowed_methods.is_some());
        assert!(allowed_origins.0.is_empty());
    }

    /// Values given in the deprecated nested `cors` object end up in the top-level fields.
    #[test]
    fn settings_json() {
        let json = r#"{
        "cors": {
            "allowed_headers": [ "X-Cookies" ]
         }
         }"#;

        let s = ServiceSettings::from_json(json).expect("parse_json");
        let headers = s
            .cors_allowed_headers
            .as_ref()
            .expect("allowed headers should be set");
        assert_eq!(headers.0.len(), 1);
        assert_eq!(headers.0.first(), Some(&"X-Cookies".into()));
    }

    /// Valid origins deserialize and are stored unchanged.
    #[test]
    fn origins_deserialize() {
        for valid in GOOD_ORIGINS {
            let o = serde_json::from_value::<CorsOrigin>(serde_json::Value::String(
                (*valid).to_string(),
            ))
            .expect("deserialize should succeed");

            // test as_ref()
            let origin: &str = o.as_ref();
            assert_eq!(origin, *valid);
        }
    }

    /// Valid origins parse through `FromStr` and are stored unchanged.
    #[test]
    fn origins_from_str() {
        for &valid in GOOD_ORIGINS {
            let o = CorsOrigin::from_str(valid).expect("from_str should succeed");

            // test as_ref()
            let origin: &str = o.as_ref();
            assert_eq!(origin, valid);
        }
    }

    /// Invalid origins are rejected through every entry point.
    #[test]
    fn origins_negative() {
        for bad in BAD_ORIGINS {
            let o =
                serde_json::from_value::<CorsOrigin>(serde_json::Value::String((*bad).to_string()));
            assert!(o.is_err(), "from_value '{bad}' (expect err)");

            // Encode the origin as a JSON string first. Passing the bare text to
            // `serde_json::from_str` fails on JSON syntax and never reaches the origin
            // parser, so the assertion would hold for any input.
            let quoted = serde_json::to_string(bad).expect("a str serializes to json");
            let o = serde_json::from_str::<CorsOrigin>(&quoted);
            assert!(o.is_err(), "from_str '{bad}' (expect err)");

            let o = CorsOrigin::from_str(bad);
            assert!(o.is_err(), "CorsOrigin::from_str '{bad}' (expect err)");
        }
    }
}