Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2,034 changes: 1,957 additions & 77 deletions Cargo.lock

Large diffs are not rendered by default.

11 changes: 8 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@ axum = "0.8"
base64 = "0.22"
chrono = "0.4"
clap = { version = "4.6", features = ["derive", "env"] }
defguard_certs = { git = "https://github.com/DefGuard/defguard.git", rev = "01957186101fc105803d56f1190efbdb5102df2f" }
defguard_version = { git = "https://github.com/DefGuard/defguard.git", rev = "01957186101fc105803d56f1190efbdb5102df2f" }
defguard_certs = { git = "https://github.com/DefGuard/defguard.git", rev = "564dc72c" }
defguard_grpc_tls = { git = "https://github.com/DefGuard/defguard.git", rev = "710b1bfd" }
defguard_version = { git = "https://github.com/DefGuard/defguard.git", rev = "7d28f46e" }
rustls-webpki = { version = "0.103", features = ["ring", "std"] }
rustls-pki-types = "1"
defguard_wireguard_rs = "0.9"
env_logger = "0.11"
ipnetwork = "0.21"
Expand Down Expand Up @@ -46,7 +49,9 @@ mnl = "0.3"
nix = { version = "0.31", default-features = false, features = ["ioctl"] }

[dev-dependencies]
tokio = { version = "1", features = ["io-std", "io-util"] }
defguard_certs = { git = "https://github.com/DefGuard/defguard.git", rev = "564dc72c" }
rustls = { version = "0.23", default-features = false, features = ["ring"] }
tokio = { version = "1", features = ["sync", "time"] }
tonic = { version = "0.14", default-features = false, features = [
"codegen",
"router",
Expand Down
2 changes: 1 addition & 1 deletion proto
95 changes: 53 additions & 42 deletions src/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,24 @@ use std::{
time::{Duration, SystemTime},
};

use defguard_certs::{CertificateError, CertificateInfo};
use defguard_grpc_tls::{certs::server_tls_config, server::certificate_serial_interceptor};
use defguard_version::{
ComponentInfo, DefguardComponent, Version, get_tracing_variables, server::DefguardVersionLayer,
};
use defguard_wireguard_rs::{WireguardInterfaceApi, net::IpAddrMask};
use tokio::{
fs::remove_file,
sync::{mpsc, oneshot},
time::interval,
};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tonic::{
Request, Response, Status, Streaming,
transport::{Identity, Server, ServerTlsConfig},
};
use tonic::{Request, Response, Status, Streaming, service::InterceptorLayer, transport::Server};
use tower::ServiceBuilder;
use tracing::instrument;

use crate::{
GRPC_CERT_NAME, GRPC_KEY_NAME, VERSION,
CORE_CLIENT_CERT_NAME, GRPC_CA_CERT_NAME, GRPC_CERT_NAME, GRPC_KEY_NAME, VERSION,
config::Config,
enterprise::firewall::{
FirewallConfig, FirewallError, FirewallRule, SnatBinding,
Expand Down Expand Up @@ -144,6 +144,10 @@ type PubKey = String;
pub struct TlsConfig {
pub grpc_cert_pem: String,
pub grpc_key_pem: String,
/// PEM-encoded CA certificate used to verify Core's mTLS client certificate chain.
pub grpc_ca_cert_pem: String,
/// DER-encoded Core client certificate; used to extract and pin the expected serial.
pub core_client_cert_der: Vec<u8>,
}

pub struct Gateway {
Expand Down Expand Up @@ -558,9 +562,11 @@ impl GatewayServer {
}

/// Starts the gateway process.
/// * Retrieves configuration and configuration updates from Defguard gRPC server
/// * Manages the interface according to configuration and updates
/// * Sends interface statistics to Defguard server periodically
/// * Requires a valid mTLS configuration to be set (via `set_tls_config`) before starting;
/// returns an error if TLS configuration is absent - the gRPC server never starts in plain-text mode
/// * Retrieves configuration and configuration updates from Defguard core via a mTLS-secured gRPC server
/// * Manages the WireGuard interface according to configuration and updates
/// * Sends interface statistics to Defguard core periodically
pub async fn start(self, config: Config) -> Result<(), GatewayError> {
info!("Starting Defguard Gateway version {VERSION} with configuration: {config:?}");

Expand Down Expand Up @@ -593,36 +599,42 @@ impl GatewayServer {
execute_command(post_up)?;
}

let grpc_cert = self
let tls_config = self
.gateway
.lock()
.unwrap()
.tls_config
.as_ref()
.map(|c| c.grpc_cert_pem.clone());
let grpc_key = self
.gateway
.lock()
.unwrap()
.expect("gateway mutex poison")
.tls_config
.as_ref()
.map(|c| c.grpc_key_pem.clone());
.clone();

// Build gRPC server.
let addr = config.grpc_socket();
info!("gRPC server is listening on {addr}");
let mut builder = if let (Some(cert), Some(key)) = (grpc_cert, grpc_key) {
let identity = Identity::from_pem(cert, key);
Server::builder().tls_config(ServerTlsConfig::new().identity(identity))?
} else {
Server::builder()
};

let tls = tls_config.ok_or_else(|| {
GatewayError::SetupError(
"TLS configuration is required; gateway gRPC server cannot start without mTLS"
.into(),
)
})?;

let tls_config =
server_tls_config(&tls.grpc_cert_pem, &tls.grpc_key_pem, &tls.grpc_ca_cert_pem)
.map_err(|e| GatewayError::SetupError(e.to_string()))?;
let mut builder = Server::builder().tls_config(tls_config)?;

// Extract Core client cert serial for pinning.
let expected_serial = CertificateInfo::from_der(&tls.core_client_cert_der)
.map_err(|e: CertificateError| GatewayError::SetupError(e.to_string()))?
.serial;

// Start gRPC server. This should run indefinitely.
debug!("Serving gRPC");
builder
.add_service(
ServiceBuilder::new()
.layer(InterceptorLayer::new(certificate_serial_interceptor(
expected_serial,
)))
.layer(DefguardVersionLayer::new(Version::parse(VERSION)?))
.service(gateway_server::GatewayServer::new(self)),
)
Expand Down Expand Up @@ -760,25 +772,24 @@ impl gateway_server::Gateway for GatewayServer {
debug!("Received purge request, removing gRPC certificate files");
let cert_path = self.cert_dir.join(GRPC_CERT_NAME);
let key_path = self.cert_dir.join(GRPC_KEY_NAME);
let ca_cert_path = self.cert_dir.join(GRPC_CA_CERT_NAME);
let core_client_cert_path = self.cert_dir.join(CORE_CLIENT_CERT_NAME);

if let Err(err) = tokio::fs::remove_file(&cert_path).await
&& err.kind() != std::io::ErrorKind::NotFound
{
error!(
"Failed to remove gRPC certificate at {}: {err}",
cert_path.display()
);
return Err(Status::internal("Failed to remove gRPC certificate"));
}
info!("Removed gRPC certificate at {}", cert_path.display());
let remove_cert_file = async |path: &std::path::Path, label: &str| -> Result<(), Status> {
if let Err(err) = remove_file(path).await
&& err.kind() != std::io::ErrorKind::NotFound
{
error!("Failed to remove {label} at {}: {err}", path.display());
return Err(Status::internal(format!("Failed to remove {label}")));
}
info!("Removed {label} at {}", path.display());
Ok(())
};

if let Err(err) = tokio::fs::remove_file(&key_path).await
&& err.kind() != std::io::ErrorKind::NotFound
{
error!("Failed to remove gRPC key at {}: {err}", key_path.display());
return Err(Status::internal("Failed to remove gRPC key"));
}
info!("Removed gRPC key at {}", cert_path.display());
remove_cert_file(&cert_path, "gRPC certificate").await?;
remove_cert_file(&key_path, "gRPC key").await?;
remove_cert_file(&ca_cert_path, "CA certificate").await?;
remove_cert_file(&core_client_cert_path, "Core client certificate").await?;

// Prepare underlying `Gateway` to enter setup mode.
self.gateway
Expand Down
5 changes: 5 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,15 @@ pub mod enterprise;
pub mod logging;
pub mod setup;

#[cfg(test)]
mod tests;

pub const VERSION: &str = concat!(env!("CARGO_PKG_VERSION"), "+", env!("VERGEN_GIT_SHA"));

pub const GRPC_CERT_NAME: &str = "gateway_grpc_cert.pem";
pub const GRPC_KEY_NAME: &str = "gateway_grpc_key.pem";
pub const GRPC_CA_CERT_NAME: &str = "grpc_ca_cert.pem";
pub const CORE_CLIENT_CERT_NAME: &str = "core_client_cert.pem";

/// Masks object's field with "***" string.
/// Used to log sensitive/secret objects.
Expand Down
66 changes: 41 additions & 25 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ use std::{fs::Permissions, os::unix::fs::PermissionsExt};
use std::{
fs::{File, read_to_string},
io::Write,
path::Path,
process,
sync::{Arc, Mutex},
};

use defguard_gateway::{
GRPC_CERT_NAME, GRPC_KEY_NAME, VERSION,
CORE_CLIENT_CERT_NAME, GRPC_CA_CERT_NAME, GRPC_CERT_NAME, GRPC_KEY_NAME, VERSION,
config::get_config,
error::GatewayError,
execute_command,
Expand All @@ -24,6 +25,30 @@ use defguard_wireguard_rs::Kernel;
use defguard_wireguard_rs::{Userspace, WGApi};
use tokio::{sync::mpsc, task::JoinSet};

fn load_tls_config(cert_dir: &Path) -> Result<Option<TlsConfig>, GatewayError> {
let grpc_cert = read_to_string(cert_dir.join(GRPC_CERT_NAME)).ok();
let grpc_key = read_to_string(cert_dir.join(GRPC_KEY_NAME)).ok();
let grpc_ca_cert = read_to_string(cert_dir.join(GRPC_CA_CERT_NAME)).ok();
let core_client_cert_pem = read_to_string(cert_dir.join(CORE_CLIENT_CERT_NAME)).ok();

match (grpc_cert, grpc_key, grpc_ca_cert, core_client_cert_pem) {
(Some(cert), Some(key), Some(ca_cert), Some(client_cert_pem)) => {
let core_client_cert_der = defguard_certs::parse_pem_certificate(&client_cert_pem)
.map_err(|e| {
GatewayError::SetupError(format!("Failed to parse Core client cert: {e}"))
})?
.to_vec();
Ok(Some(TlsConfig {
grpc_cert_pem: cert,
grpc_key_pem: key,
grpc_ca_cert_pem: ca_cert,
core_client_cert_der,
}))
}
_ => Ok(None),
}
}

#[tokio::main]
async fn main() -> Result<(), GatewayError> {
// parse config
Expand All @@ -44,12 +69,7 @@ async fn main() -> Result<(), GatewayError> {
tokio::fs::set_permissions(cert_dir, Permissions::from_mode(0o700)).await?;
}

let (grpc_cert, grpc_key) = (
read_to_string(cert_dir.join(GRPC_CERT_NAME)).ok(),
read_to_string(cert_dir.join(GRPC_KEY_NAME)).ok(),
);

let needs_setup = grpc_cert.is_none() || grpc_key.is_none();
let maybe_tls_config = load_tls_config(cert_dir)?;

// TODO: The channel size may need to be adjusted or some other approach should be used
// to avoid dropping log messages.
Expand Down Expand Up @@ -108,25 +128,21 @@ async fn main() -> Result<(), GatewayError> {
let post_down_clone = config.post_down.clone();

tasks.spawn(async move {
let tls_config = if needs_setup {
log::info!(
"gRPC TLS certificates not found in {}. They will be generated during setup.",
config.cert_dir.display()
);
run_setup(&config, Arc::clone(&logs_rx)).await?
} else if let (Some(cert), Some(key)) = (grpc_cert, grpc_key) {
log::info!(
"Using existing gRPC TLS certificates from {}",
config.cert_dir.display()
);
TlsConfig {
grpc_cert_pem: cert,
grpc_key_pem: key,
let tls_config = match maybe_tls_config {
None => {
log::info!(
"gRPC TLS certificates not found in {}. They will be generated during setup.",
config.cert_dir.display()
);
run_setup(&config, Arc::clone(&logs_rx)).await?
}
Some(tls_config) => {
log::info!(
"Using existing gRPC TLS certificates from {}",
config.cert_dir.display()
);
tls_config
}
} else {
return Err(GatewayError::SetupError(
"gRPC TLS certificates are missing after setup".to_string(),
));
};

// Launch gRPC server (with purge-triggered setup loop).
Expand Down
Loading