From f179a1c5264a92f77eba509d1e527be50e4cff41 Mon Sep 17 00:00:00 2001 From: moparisthebest Date: Mon, 18 Jul 2022 01:49:56 -0400 Subject: [PATCH] Implement optional support for systemd socket activation --- Cargo.lock | 22 +++++++ Cargo.toml | 6 +- build.rs | 3 + check-all-features.sh | 64 +++++++++++++------- src/lib.rs | 5 +- src/main.rs | 71 +++++++++++++++++++---- src/outgoing.rs | 5 +- src/quic/incoming.rs | 9 +-- src/systemd.rs | 132 ++++++++++++++++++++++++++++++++++++++++++ src/tls/incoming.rs | 17 +++--- 10 files changed, 281 insertions(+), 53 deletions(-) mode change 100644 => 100755 check-all-features.sh create mode 100644 src/systemd.rs diff --git a/Cargo.lock b/Cargo.lock index e340a51..3a6ecf7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -606,6 +606,15 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +[[package]] +name = "memoffset" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aa361d4faea93603064a027415f07bd8e1d5c88c9fbf68bf56a285428fd79ce" +dependencies = [ + "autocfg", +] + [[package]] name = "mime" version = "0.3.16" @@ -633,6 +642,18 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "nix" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "195cdbc1741b8134346d515b3a56a1c94b0912758009cfd53f99ea0f57b065fc" +dependencies = [ + "bitflags", + "cfg-if", + "libc", + "memoffset", +] + [[package]] name = "num_cpus" version = "1.13.1" @@ -1607,6 +1628,7 @@ dependencies = [ "futures-util", "lazy_static", "log", + "nix", "quinn", "rand", "reqwest", diff --git a/Cargo.toml b/Cargo.toml index 33b4be2..53f622e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,8 +60,11 @@ rustls-pemfile = { version = "1.0.0", optional = true } tokio-tungstenite = { version = "0.17", optional = true, default-features = false } futures-util = { version = "0.3", default-features = false, features = ["async-await", "sink", "std"], optional = true } +# systemd dep +nix = { version = "0.24", optional = true, default-features = false, features = ["socket"]} + [features] -default = ["c2s-incoming", "c2s-outgoing", "s2s-incoming", "s2s-outgoing", "tls", "quic", "websocket", "logging", "tls-ca-roots-native"] +default = ["c2s-incoming", "c2s-outgoing", "s2s-incoming", "s2s-outgoing", "tls", "quic", "websocket", "logging", "tls-ca-roots-native", "systemd"] # you must pick one of these or the other, not both: todo: enable picking both and choosing at runtime # don't need either of these if only doing c2s-incoming @@ -88,6 +91,7 @@ quic = ["quinn", "rustls"] websocket = ["tokio-tungstenite", "futures-util", "tls"] # websocket+incoming also enables incoming TLS support as it's free logging = ["rand", "env_logger"] +systemd = ["nix"] [dev-dependencies] serde_json = "1.0" diff --git a/build.rs b/build.rs index 06ee2d5..99431ba 100644 --- a/build.rs +++ b/build.rs @@ -18,6 +18,8 @@ fn main() { "websocket", "tls-ca-roots-native", "tls-ca-roots-bundled", + "logging", + "systemd", ]; let optional_deps = [ "rustls", @@ -33,6 +35,7 @@ fn main() { "webpki-roots", "env-logger", "rand", + "nix", ]; let mut features = Vec::new(); let mut optional = Vec::new(); diff --git a/check-all-features.sh b/check-all-features.sh old mode 100644 new mode 100755 index 06115c0..9663de4 --- a/check-all-features.sh +++ b/check-all-features.sh @@ -1,6 +1,13 @@ #!/bin/bash +threads="$1" + set -euo pipefail +# if we have access to nproc, divide that by 2, otherwise use 1 thread by default +[ "$threads" == "" ] && threads=$(($(nproc || echo 2) / 2)) + +echo "threads: $threads" + export RUSTFLAGS=-Awarnings show() { @@ -22,38 +29,51 @@ perm_lines() { } perms() { - perm_lines "$@" | tr ' ' ',' | sort -u | tr '\n' ' ' + perm_lines "$@" | tr ' ' ',' | sort -u +} + +perms_optional() { + perm_lines "$@" | tr ' ' ',' | sort -u | sed 's/^/,/' +} + +all_features() { + for optional in "" $(perms_optional logging systemd) + do + for proto in $(perms tls quic websocket) + do + for direction in $(perms c2s-incoming c2s-outgoing s2s-incoming s2s-outgoing) + do + for ca_roots in tls-ca-roots-native tls-ca-roots-bundled + do + echo $direction,$proto,$ca_roots$optional + done + done + done + done + + for optional in "" $(perms_optional logging systemd) + do + for proto in $(perms tls quic websocket) + do + echo c2s-incoming,$proto$optional + done + done } echo_cargo() { + set -euo pipefail #echo cargo run "$@" -- -v #cargo run "$@" -- -v echo cargo check "$@" cargo check "$@" } +#all_features | sort -u | wc -l; exit 0 + +export -f echo_cargo + echo_cargo -for optional in "" ",logging" -do - for proto in $(perms tls quic websocket) - do - for direction in $(perms c2s-incoming c2s-outgoing s2s-incoming s2s-outgoing) - do - for ca_roots in tls-ca-roots-native tls-ca-roots-bundled - do - echo_cargo --no-default-features --features $direction,$proto,$ca_roots$optional - done - done - done -done - -for optional in "" ",logging" -do - for proto in $(perms tls quic websocket) - do - echo_cargo --no-default-features --features c2s-incoming,$proto$optional - done -done +all_features | sort | xargs -n1 --max-procs=$threads bash -c 'echo_cargo --no-default-features --features "$@" || exit 255' _ echo good! diff --git a/src/lib.rs b/src/lib.rs index cb0f98e..b02473b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,5 +20,8 @@ pub mod websocket; #[cfg(any(feature = "s2s-incoming", feature = "outgoing"))] pub mod verify; -mod context; +#[cfg(feature = "nix")] +pub mod systemd; + +pub mod context; pub mod in_out; diff --git a/src/main.rs b/src/main.rs index 333c426..3f24fea 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,10 +3,17 @@ use anyhow::Result; use die::{die, Die}; use log::{debug, error, info}; use serde_derive::Deserialize; -use std::{ffi::OsString, fs::File, io::Read, iter::Iterator, net::SocketAddr, path::Path, sync::Arc}; -use tokio::task::JoinHandle; +use std::{ + ffi::OsString, + fs::File, + io::Read, + iter::Iterator, + net::{SocketAddr, UdpSocket}, + path::Path, + sync::Arc, +}; +use tokio::{net::TcpListener, task::JoinHandle}; use xmpp_proxy::common::certs_key::CertsKey; - #[cfg(feature = "outgoing")] use xmpp_proxy::{common::outgoing::OutgoingConfig, outgoing::spawn_outgoing_listener}; @@ -128,13 +135,53 @@ async fn main() { die!("log_level or log_style defined in config but logging disabled at compile-time"); } + let mut incoming_listen = Vec::new(); + for a in main_config.incoming_listen.iter() { + incoming_listen.push(TcpListener::bind(a).await.die("cannot listen on port/interface")); + } + let mut quic_listen = Vec::new(); + for a in main_config.quic_listen.iter() { + quic_listen.push(UdpSocket::bind(a).die("cannot listen on port/interface")); + } + let mut outgoing_listen = Vec::new(); + for a in main_config.outgoing_listen.iter() { + outgoing_listen.push(TcpListener::bind(a).await.die("cannot listen on port/interface")); + } + + #[cfg(feature = "nix")] + if let Ok(fds) = xmpp_proxy::systemd::receive_descriptors_with_names(true) { + use xmpp_proxy::systemd::Listener; + for fd in fds { + match fd.listener() { + Listener::Tcp(tcp_listener) => { + let tcp_listener = TcpListener::from_std(tcp_listener()).die("cannot open systemd TcpListener"); + if let Some(name) = fd.name().map(|n| n.to_ascii_lowercase()) { + if name.starts_with("in") { + incoming_listen.push(tcp_listener); + } else if name.starts_with("out") { + outgoing_listen.push(tcp_listener); + } else { + die!("systemd socket name must start with 'in' or 'out' but is '{}'", name); + } + } else { + // what to do here... for now we will require names + // todo: possibly in future if local_addr is localhost or private ranges assume outgoing, otherwise incoming? + die!("systemd TCP socket activation requires name that starts with 'in' or 'out'"); + } + } + Listener::Udp(udp_socket) => quic_listen.push(udp_socket()), + _ => continue, + } + } + } + #[cfg(feature = "incoming")] let config = main_config.get_cloneable_cfg(); let certs_key = Arc::new(CertsKey::new(main_config.certs_key())); let mut handles: Vec>> = Vec::new(); - if !main_config.incoming_listen.is_empty() { + if !incoming_listen.is_empty() { #[cfg(all(any(feature = "tls", feature = "websocket"), feature = "incoming"))] { use xmpp_proxy::{ @@ -145,14 +192,14 @@ async fn main() { die!("one of c2s_target/s2s_target must be defined if incoming_listen is non-empty"); } let acceptor = tls_acceptor(server_config(certs_key.clone()).die("invalid cert/key ?")); - for listener in main_config.incoming_listen.iter() { - handles.push(spawn_tls_listener(listener.parse().die("invalid listener address"), config.clone(), acceptor.clone())); + for listener in incoming_listen { + handles.push(spawn_tls_listener(listener, config.clone(), acceptor.clone())); } } #[cfg(not(all(any(feature = "tls", feature = "websocket"), feature = "incoming")))] die!("incoming_listen non-empty but (tls or websocket) or (s2s-incoming and c2s-incoming) disabled at compile-time"); } - if !main_config.quic_listen.is_empty() { + if !quic_listen.is_empty() { #[cfg(all(feature = "quic", feature = "incoming"))] { use xmpp_proxy::{ @@ -163,19 +210,19 @@ async fn main() { die!("one of c2s_target/s2s_target must be defined if quic_listen is non-empty"); } let quic_config = quic_server_config(server_config(certs_key.clone()).die("invalid cert/key ?")); - for listener in main_config.quic_listen.iter() { - handles.push(spawn_quic_listener(listener.parse().die("invalid listener address"), config.clone(), quic_config.clone())); + for listener in quic_listen { + handles.push(spawn_quic_listener(listener, config.clone(), quic_config.clone())); } } #[cfg(not(all(feature = "quic", feature = "incoming")))] die!("quic_listen non-empty but quic or (s2s-incoming and c2s-incoming) disabled at compile-time"); } - if !main_config.outgoing_listen.is_empty() { + if !outgoing_listen.is_empty() { #[cfg(feature = "outgoing")] { let outgoing_cfg = main_config.get_outgoing_cfg(certs_key.clone()); - for listener in main_config.outgoing_listen.iter() { - handles.push(spawn_outgoing_listener(listener.parse().die("invalid listener address"), outgoing_cfg.clone())); + for listener in outgoing_listen { + handles.push(spawn_outgoing_listener(listener, outgoing_cfg.clone())); } } #[cfg(not(feature = "outgoing"))] diff --git a/src/outgoing.rs b/src/outgoing.rs index ab71d46..125669e 100644 --- a/src/outgoing.rs +++ b/src/outgoing.rs @@ -7,9 +7,7 @@ use crate::{ stanzafilter::StanzaFilter, }; use anyhow::Result; -use die::Die; use log::{error, info}; -use std::net::SocketAddr; use tokio::{net::TcpListener, task::JoinHandle}; async fn handle_outgoing_connection(stream: tokio::net::TcpStream, client_addr: &mut Context<'_>, config: OutgoingConfig) -> Result<()> { @@ -48,9 +46,8 @@ async fn handle_outgoing_connection(stream: tokio::net::TcpStream, client_addr: shuffle_rd_wr_filter_only(in_rd, in_wr, out_rd, out_wr, is_c2s, max_stanza_size_bytes, client_addr, in_filter).await } -pub fn spawn_outgoing_listener(local_addr: SocketAddr, config: OutgoingConfig) -> JoinHandle> { +pub fn spawn_outgoing_listener(listener: TcpListener, config: OutgoingConfig) -> JoinHandle> { tokio::spawn(async move { - let listener = TcpListener::bind(&local_addr).await.die("cannot listen on port/interface"); loop { let (stream, client_addr) = listener.accept().await?; let config = config.clone(); diff --git a/src/quic/incoming.rs b/src/quic/incoming.rs index 2529a7e..28aef6d 100644 --- a/src/quic/incoming.rs +++ b/src/quic/incoming.rs @@ -7,12 +7,13 @@ use anyhow::Result; use die::Die; use futures::StreamExt; use log::{error, info}; -use quinn::ServerConfig; -use std::{net::SocketAddr, sync::Arc}; +use quinn::{Endpoint, EndpointConfig, ServerConfig}; +use std::{net::UdpSocket, sync::Arc}; use tokio::task::JoinHandle; -pub fn spawn_quic_listener(local_addr: SocketAddr, config: CloneableConfig, server_config: ServerConfig) -> JoinHandle> { - let (_endpoint, mut incoming) = quinn::Endpoint::server(server_config, local_addr).die("cannot listen on port/interface"); +pub fn spawn_quic_listener(udp_socket: UdpSocket, config: CloneableConfig, server_config: ServerConfig) -> JoinHandle> { + let local_addr = udp_socket.local_addr().die("cannot get local_addr for quic socket"); + let (_endpoint, mut incoming) = Endpoint::new(EndpointConfig::default(), Some(server_config), udp_socket).die("cannot listen on port/interface"); tokio::spawn(async move { // when could this return None, do we quit? while let Some(incoming_conn) = incoming.next().await { diff --git a/src/systemd.rs b/src/systemd.rs new file mode 100644 index 0000000..36b05ab --- /dev/null +++ b/src/systemd.rs @@ -0,0 +1,132 @@ +use anyhow::{anyhow, bail, Result}; +use nix::sys::socket::{getsockname, getsockopt, AddressFamily, SockType, SockaddrLike, SockaddrStorage}; +use std::{ + env, + net::{TcpListener, UdpSocket}, + os::unix::{ + io::{FromRawFd, IntoRawFd, RawFd}, + net::{UnixDatagram, UnixListener}, + }, + process, +}; + +/// Minimum FD number used by systemd for passing sockets. +const SD_LISTEN_FDS_START: RawFd = 3; + +/// File descriptor passed by systemd to socket-activated services. +/// +/// See . +#[derive(Debug, Clone)] +pub struct FileDescriptor { + raw_fd: RawFd, + tcp_not_udp: bool, + inet_not_unix: bool, + pub name: Option, +} + +pub enum Listener { + Tcp(Box TcpListener>), + Udp(Box UdpSocket>), + UnixListener(Box UnixListener>), + UnixDatagram(Box UnixDatagram>), +} + +impl FileDescriptor { + pub fn name(self) -> Option { + self.name + } + + pub fn listener(&self) -> Listener { + let raw_fd = self.raw_fd; + match (self.tcp_not_udp, self.inet_not_unix) { + (true, true) => Listener::Tcp(Box::new(move || unsafe { TcpListener::from_raw_fd(raw_fd) })), + (false, true) => Listener::Udp(Box::new(move || unsafe { UdpSocket::from_raw_fd(raw_fd) })), + (true, false) => Listener::UnixListener(Box::new(move || unsafe { UnixListener::from_raw_fd(raw_fd) })), + (false, false) => Listener::UnixDatagram(Box::new(move || unsafe { UnixDatagram::from_raw_fd(raw_fd) })), + } + } +} + +/// Check for named file descriptors passed by systemd. +/// +/// Like `receive_descriptors`, but this will also return a vector of names +/// associated with each file descriptor. +pub fn receive_descriptors_with_names(unset_env: bool) -> Result> { + let pid = env::var("LISTEN_PID"); + let fds = env::var("LISTEN_FDS"); + let fdnames = env::var("LISTEN_FDNAMES"); + log::trace!("LISTEN_PID = {:?}; LISTEN_FDS = {:?}; LISTEN_FDNAMES = {:?}", pid, fds, fdnames); + + if unset_env { + env::remove_var("LISTEN_PID"); + env::remove_var("LISTEN_FDS"); + env::remove_var("LISTEN_FDNAMES"); + } + + let pid = pid + .map_err(|e| anyhow!("failed to get LISTEN_PID: {}", e))? + .parse::() + .map_err(|e| anyhow!("failed to parse LISTEN_PID: {}", e))?; + let fds = fds + .map_err(|e| anyhow!("failed to get LISTEN_FDS: {}", e))? + .parse::() + .map_err(|e| anyhow!("failed to parse LISTEN_FDS: {}", e))?; + + if process::id() != pid { + bail!("PID mismatch"); + } + + let names = fdnames.map(|n| n.split(':').map(String::from).collect()).unwrap_or_else(|_| Vec::new()); + + socks_from_fds(fds, names) +} + +fn socks_from_fds(num_fds: usize, names: Vec) -> Result> { + let mut descriptors = Vec::with_capacity(num_fds); + let mut names = names.into_iter(); + for fd_offset in 0..num_fds { + let name = names.next(); + let raw_fd: RawFd = SD_LISTEN_FDS_START + .checked_add(fd_offset as i32) + .ok_or_else(|| anyhow!("overlarge file descriptor index: {}", num_fds))?; + if !sock_listening(raw_fd) { + continue; + } + let tcp_not_udp = match sock_type(raw_fd) { + Some(SockType::Stream) => true, + Some(SockType::Datagram) => false, + _ => continue, + }; + let inet_not_unix = match address_family(raw_fd) { + Some(AddressFamily::Inet) | Some(AddressFamily::Inet6) => true, + Some(AddressFamily::Unix) => false, + _ => continue, + }; + descriptors.push(FileDescriptor { + raw_fd, + tcp_not_udp, + inet_not_unix, + name, + }); + } + + Ok(descriptors) +} + +fn sock_listening(raw_fd: RawFd) -> bool { + getsockopt(raw_fd, nix::sys::socket::sockopt::AcceptConn).unwrap_or(false) +} + +fn sock_type(raw_fd: RawFd) -> Option { + getsockopt(raw_fd, nix::sys::socket::sockopt::SockType).ok() +} + +fn address_family(raw_fd: RawFd) -> Option { + getsockname::(raw_fd).ok().and_then(|addr| addr.family()) +} + +impl IntoRawFd for FileDescriptor { + fn into_raw_fd(self) -> RawFd { + self.raw_fd + } +} diff --git a/src/tls/incoming.rs b/src/tls/incoming.rs index dca3fcb..98e5f56 100644 --- a/src/tls/incoming.rs +++ b/src/tls/incoming.rs @@ -1,19 +1,18 @@ -use crate::common::incoming::{shuffle_rd_wr_filter, CloneableConfig, ServerCerts}; -use std::net::SocketAddr; - use crate::{ - common::{first_bytes_match, to_str, IN_BUFFER_SIZE}, + common::{ + first_bytes_match, + incoming::{shuffle_rd_wr_filter, CloneableConfig, ServerCerts}, + to_str, IN_BUFFER_SIZE, + }, context::Context, in_out::{StanzaRead, StanzaWrite}, slicesubsequence::SliceSubsequence, stanzafilter::{StanzaFilter, StanzaReader}, }; use anyhow::{bail, Result}; -use die::Die; use log::{error, info, trace}; use rustls::{ServerConfig, ServerConnection}; - -use std::sync::Arc; +use std::{net::SocketAddr, sync::Arc}; use tokio::{ io::{AsyncBufReadExt, AsyncWriteExt, BufStream}, net::TcpListener, @@ -25,9 +24,9 @@ pub fn tls_acceptor(server_config: ServerConfig) -> TlsAcceptor { TlsAcceptor::from(Arc::new(server_config)) } -pub fn spawn_tls_listener(local_addr: SocketAddr, config: CloneableConfig, acceptor: TlsAcceptor) -> JoinHandle> { +pub fn spawn_tls_listener(listener: TcpListener, config: CloneableConfig, acceptor: TlsAcceptor) -> JoinHandle> { tokio::spawn(async move { - let listener = TcpListener::bind(&local_addr).await.die("cannot listen on port/interface"); + let local_addr = listener.local_addr()?; loop { let (stream, client_addr) = listener.accept().await?; let config = config.clone();