1use std::env;
4use std::net::IpAddr;
5use std::str::FromStr;
6
7use crate::constants::SPIFFE_SOCKET_ENV;
8use crate::error::SocketPathError;
9use url::Url;
10
/// URI scheme accepted for workload endpoint sockets reached over TCP (`tcp://IP:port`).
const TCP_SCHEME: &str = "tcp";
/// URI scheme accepted for workload endpoint sockets reached over a Unix domain socket.
const UNIX_SCHEME: &str = "unix";
13
14pub fn get_default_socket_path() -> Option<String> {
17 env::var(SPIFFE_SOCKET_ENV).ok()
18}
19
20pub fn validate_socket_path(socket_path: &str) -> Result<(), SocketPathError> {
22 let url = Url::parse(socket_path)?;
23
24 if !url.username().is_empty() {
25 return Err(SocketPathError::HasUserInfo);
26 }
27
28 if url.query().is_some() {
29 return Err(SocketPathError::HasQueryValues);
30 }
31
32 if url.fragment().is_some() {
33 return Err(SocketPathError::HasFragment);
34 }
35
36 match url.scheme() {
37 UNIX_SCHEME => {
38 if url.path().is_empty() || url.path() == "/" {
39 return Err(SocketPathError::UnixAddressEmptyPath);
40 }
41 }
42 TCP_SCHEME => {
43 let host = url.host_str().ok_or(SocketPathError::TcpEmptyHost)?;
44
45 IpAddr::from_str(host).map_err(|_| SocketPathError::TcpAddressNoIpPort)?;
46
47 if !url.path().is_empty() && url.path() != "/" {
48 return Err(SocketPathError::TcpAddressNonEmptyPath);
49 }
50 if url.port().is_none() {
51 return Err(SocketPathError::TcpAddressNoIpPort);
52 }
53 }
54 _ => return Err(SocketPathError::InvalidScheme),
55 }
56
57 Ok(())
58}
59
#[cfg(test)]
mod tests {
    use super::*;
    use url::ParseError;

    // Accepted: unix URI with an empty authority and an absolute path.
    #[test]
    fn test_validate_correct_unix_address() {
        let socket_path = "unix:///foo";
        validate_socket_path(socket_path).unwrap();
    }

    // Accepted: unix URI written without the `//` authority marker.
    #[test]
    fn test_validate_other_correct_unix_address() {
        let socket_path = "unix:/tmp/spire-agent/public/api.sock";
        validate_socket_path(socket_path).unwrap();
    }

    // Accepted: tcp URI with an IPv4 host and an explicit port.
    #[test]
    fn test_validate_correct_tcp_address() {
        let socket_path = "tcp://1.2.3.4:80";
        validate_socket_path(socket_path).unwrap();
    }

    // Generates one `#[test]` fn per entry. Each entry has the form
    // `name: (input, expected_error, expected_message)`. The generated test
    // asserts that validation fails with exactly `expected_error`, and that
    // the error's `to_string()` output equals `expected_message`.
    macro_rules! validate_socket_path_error_tests {
        ($($name:ident: $value:expr,)*) => {
            $(
                #[test]
                fn $name() {
                    let (input, expected_error, expected_message) = $value;
                    let result = validate_socket_path(input);
                    let error = result.unwrap_err();

                    assert_eq!(error, expected_error);
                    assert_eq!(error.to_string(), expected_message);
                }
            )*
        }
    }

    validate_socket_path_error_tests! {
        // Inputs that fail URL parsing or use an unsupported scheme.
        // Note: `test_validate_empty_str` uses a single space, not "".
        test_validate_empty_str: (" ", SocketPathError::Parse(ParseError::RelativeUrlWithoutBase), "workload endpoint socket is not a valid URI"),
        test_validate_str_missing_scheme: ("foo", SocketPathError::Parse(ParseError::RelativeUrlWithoutBase), "workload endpoint socket is not a valid URI"),
        test_validate_uri_invalid_scheme: ("other:///path", SocketPathError::InvalidScheme, "workload endpoint socket URI must have a tcp:// or unix:// scheme"),

        // Rejected unix URIs.
        test_validate_unix_uri_empty_path: ("unix://", SocketPathError::UnixAddressEmptyPath, "workload endpoint unix socket URI must include a path"),
        test_validate_unix_uri_empty_path_slash: ("unix:///", SocketPathError::UnixAddressEmptyPath, "workload endpoint unix socket URI must include a path"),
        test_validate_unix_uri_with_query_values: ("unix:///foo?whatever", SocketPathError::HasQueryValues, "workload endpoint socket URI must not include query values"),
        test_validate_unix_uri_with_fragment: ("unix:///foo#whatever", SocketPathError::HasFragment, "workload endpoint socket URI must not include a fragment"),
        test_validate_unix_uri_with_user_info: ("unix://john:doe@foo/path", SocketPathError::HasUserInfo, "workload endpoint socket URI must not include user info"),

        // Rejected tcp URIs.
        test_validate_tcp_uri_non_empty_path: ("tcp://1.2.3.4:80/path", SocketPathError::TcpAddressNonEmptyPath, "workload endpoint tcp socket URI must not include a path"),
        test_validate_tcp_uri_with_query_values: ("tcp://1.2.3.4:80?whatever", SocketPathError::HasQueryValues, "workload endpoint socket URI must not include query values"),
        test_validate_tcp_uri_with_fragment: ("tcp://1.2.3.4:80#whatever", SocketPathError::HasFragment, "workload endpoint socket URI must not include a fragment"),
        test_validate_tcp_uri_with_user_info: ("tcp://john:doe@1.2.3.4:80", SocketPathError::HasUserInfo, "workload endpoint socket URI must not include user info"),
        test_validate_tcp_uri_no_ip: ("tcp://foo:80", SocketPathError::TcpAddressNoIpPort, "workload endpoint tcp socket URI host component must be an IP:port"),
        test_validate_tcp_uri_no_ip_and_port: ("tcp://foo", SocketPathError::TcpAddressNoIpPort, "workload endpoint tcp socket URI host component must be an IP:port"),
        test_validate_tcp_uri_no_port: ("tcp://1.2.3.4", SocketPathError::TcpAddressNoIpPort, "workload endpoint tcp socket URI host component must be an IP:port"),
    }
}