Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 123 additions & 6 deletions apps/gateway/src/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,15 @@ use crate::vault;
#[derive(Clone)]
pub(crate) struct GatewayState {
pub ca: Arc<CertificateAuthority>,
/// Standard upstream client — validates TLS certificates.
pub http_client: reqwest::Client,
/// No-verify upstream client — skips TLS certificate validation.
/// Selected for hosts matched by `skip_verify_hosts`.
pub http_client_no_verify: reqwest::Client,
/// Hostname patterns for which TLS certificate validation is skipped.
/// Supports exact match (`internal.corp`) and wildcard prefix (`*.internal.corp`).
/// Populated from `GATEWAY_SKIP_VERIFY_HOSTS` (comma-separated).
pub skip_verify_hosts: Arc<Vec<String>>,
pub policy_engine: Arc<PolicyEngine>,
pub cache: Arc<dyn CacheStore>,
/// Provider-agnostic vault service for credential fetching.
Expand All @@ -58,15 +66,50 @@ pub struct GatewayServer {
/// Build the HTTP client used for upstream requests.
///
/// - Redirects are disabled so 3xx responses are forwarded to the client as-is.
/// - Invalid certs are optionally accepted via `GATEWAY_DANGER_ACCEPT_INVALID_CERTS`.
fn build_http_client() -> reqwest::Client {
/// - `accept_invalid_certs` skips TLS certificate validation for upstream connections.
fn build_http_client(accept_invalid_certs: bool) -> reqwest::Client {
reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.danger_accept_invalid_certs(std::env::var("GATEWAY_DANGER_ACCEPT_INVALID_CERTS").is_ok())
.danger_accept_invalid_certs(accept_invalid_certs)
.build()
.expect("build HTTP client")
}

/// Parse `GATEWAY_SKIP_VERIFY_HOSTS` into a list of hostname patterns.
///
/// Patterns support:
/// - Exact match: `internal.corp`
/// - Wildcard subdomain prefix: `*.internal.corp`
///
/// Falls back to empty (no hosts skip verification) if the variable is unset.
fn parse_skip_verify_hosts() -> Vec<String> {
std::env::var("GATEWAY_SKIP_VERIFY_HOSTS")
.unwrap_or_default()
.split(',')
.map(|s| s.trim().to_lowercase())
.filter(|s| !s.is_empty())
.collect()
}

/// Returns true if `host` matches any pattern in `patterns`.
///
/// - `*.example.com` matches `sub.example.com` but NOT `example.com` itself.
/// - `example.com` matches only `example.com`.
///
/// Patterns are pre-lowercased by `parse_skip_verify_hosts`.
/// Follows the same wildcard semantics as `connect::host_matches`.
fn host_matches_skip_verify(host: &str, patterns: &[String]) -> bool {
let host = host.to_lowercase();
patterns.iter().any(|pattern| {
if let Some(suffix) = pattern.strip_prefix('*') {
// "*.example.com" → suffix = ".example.com"
host.ends_with(suffix) && host.len() > suffix.len()
} else {
host == *pattern
}
})
}

impl GatewayServer {
pub fn new(
ca: CertificateAuthority,
Expand All @@ -75,9 +118,20 @@ impl GatewayServer {
vault_service: Arc<vault::VaultService>,
cache: Arc<dyn CacheStore>,
) -> Self {
let global_skip = std::env::var("GATEWAY_DANGER_ACCEPT_INVALID_CERTS").is_ok();
let skip_verify_hosts = Arc::new(parse_skip_verify_hosts());

if global_skip {
warn!("GATEWAY_DANGER_ACCEPT_INVALID_CERTS is set: TLS verification disabled for ALL upstream hosts");
} else if !skip_verify_hosts.is_empty() {
info!(hosts = ?skip_verify_hosts.as_ref(), "TLS verification disabled for matched hosts (GATEWAY_SKIP_VERIFY_HOSTS)");
}

let state = GatewayState {
ca: Arc::new(ca),
http_client: build_http_client(),
http_client: build_http_client(global_skip),
http_client_no_verify: build_http_client(true),
skip_verify_hosts,
policy_engine,
cache,
vault_service,
Expand Down Expand Up @@ -306,7 +360,13 @@ async fn handle_connect(
);

let ca = Arc::clone(&state.ca);
let http_client = state.http_client.clone();
let skip_verify = host_matches_skip_verify(&hostname, &state.skip_verify_hosts);
let http_client = if skip_verify {
info!(host = %hostname, "TLS verification skipped (GATEWAY_SKIP_VERIFY_HOSTS)");
state.http_client_no_verify.clone()
} else {
state.http_client.clone()
};
let cache = Arc::clone(&state.cache);
let agent_token_owned = agent_token.clone().unwrap_or_default();

Expand Down Expand Up @@ -722,7 +782,7 @@ mod tests {
});

// Act: use the same client the gateway uses in production.
let client = build_http_client();
let client = build_http_client(false);
let resp = client
.get(format!("http://{addr}/test"))
.send()
Expand Down Expand Up @@ -853,6 +913,63 @@ mod tests {
}
}

// ── host_matches_skip_verify ─────────────────────────────────────────

#[test]
fn skip_verify_exact_match() {
let patterns = vec!["internal.corp".to_string()];
assert!(host_matches_skip_verify("internal.corp", &patterns));
assert!(!host_matches_skip_verify("other.corp", &patterns));
assert!(!host_matches_skip_verify("sub.internal.corp", &patterns));
}

#[test]
fn skip_verify_wildcard_matches_subdomains_only() {
let patterns = vec!["*.internal.corp".to_string()];
assert!(host_matches_skip_verify("foo.internal.corp", &patterns));
assert!(host_matches_skip_verify("a.b.internal.corp", &patterns));
assert!(!host_matches_skip_verify("internal.corp", &patterns));
assert!(!host_matches_skip_verify("notinternal.corp", &patterns));
assert!(!host_matches_skip_verify("evil-internal.corp", &patterns));
}

#[test]
fn skip_verify_case_insensitive_host() {
// Patterns are pre-lowercased by parse_skip_verify_hosts.
// The match function lowercases the host input.
let patterns = vec!["internal.corp".to_string()];
assert!(host_matches_skip_verify("INTERNAL.CORP", &patterns));
assert!(host_matches_skip_verify("Internal.Corp", &patterns));
assert!(host_matches_skip_verify("internal.corp", &patterns));
}

#[test]
fn skip_verify_empty_patterns_never_matches() {
assert!(!host_matches_skip_verify("anything.com", &[]));
}

// ── parse_skip_verify_patterns ─────────────────────────────────────

/// Helper: parse a raw comma-separated string the same way `parse_skip_verify_hosts` does.
fn parse_patterns(input: &str) -> Vec<String> {
input
.split(',')
.map(|s| s.trim().to_lowercase())
.filter(|s| !s.is_empty())
.collect()
}

#[test]
fn parse_skip_verify_splits_and_trims() {
let hosts = parse_patterns(" foo.com , *.bar.com , baz.io ");
assert_eq!(hosts, vec!["foo.com", "*.bar.com", "baz.io"]);
}

#[test]
fn parse_skip_verify_empty_input() {
assert!(parse_patterns("").is_empty());
}

// ── is_http_proxy_request ──────────────────────────────────────────

#[test]
Expand Down
Loading