Implement optional support for systemd socket activation

This commit is contained in:
Travis Burtrum 2022-07-18 01:49:56 -04:00
parent 27887c2e82
commit f179a1c526
10 changed files with 281 additions and 53 deletions

22
Cargo.lock generated
View File

@ -606,6 +606,15 @@ version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d"
[[package]]
name = "memoffset"
version = "0.6.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5aa361d4faea93603064a027415f07bd8e1d5c88c9fbf68bf56a285428fd79ce"
dependencies = [
"autocfg",
]
[[package]] [[package]]
name = "mime" name = "mime"
version = "0.3.16" version = "0.3.16"
@ -633,6 +642,18 @@ dependencies = [
"windows-sys", "windows-sys",
] ]
[[package]]
name = "nix"
version = "0.24.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "195cdbc1741b8134346d515b3a56a1c94b0912758009cfd53f99ea0f57b065fc"
dependencies = [
"bitflags",
"cfg-if",
"libc",
"memoffset",
]
[[package]] [[package]]
name = "num_cpus" name = "num_cpus"
version = "1.13.1" version = "1.13.1"
@ -1607,6 +1628,7 @@ dependencies = [
"futures-util", "futures-util",
"lazy_static", "lazy_static",
"log", "log",
"nix",
"quinn", "quinn",
"rand", "rand",
"reqwest", "reqwest",

View File

@ -60,8 +60,11 @@ rustls-pemfile = { version = "1.0.0", optional = true }
tokio-tungstenite = { version = "0.17", optional = true, default-features = false } tokio-tungstenite = { version = "0.17", optional = true, default-features = false }
futures-util = { version = "0.3", default-features = false, features = ["async-await", "sink", "std"], optional = true } futures-util = { version = "0.3", default-features = false, features = ["async-await", "sink", "std"], optional = true }
# systemd dep
nix = { version = "0.24", optional = true, default-features = false, features = ["socket"]}
[features] [features]
default = ["c2s-incoming", "c2s-outgoing", "s2s-incoming", "s2s-outgoing", "tls", "quic", "websocket", "logging", "tls-ca-roots-native"] default = ["c2s-incoming", "c2s-outgoing", "s2s-incoming", "s2s-outgoing", "tls", "quic", "websocket", "logging", "tls-ca-roots-native", "systemd"]
# you must pick one of these or the other, not both: todo: enable picking both and choosing at runtime # you must pick one of these or the other, not both: todo: enable picking both and choosing at runtime
# don't need either of these if only doing c2s-incoming # don't need either of these if only doing c2s-incoming
@ -88,6 +91,7 @@ quic = ["quinn", "rustls"]
websocket = ["tokio-tungstenite", "futures-util", "tls"] # websocket+incoming also enables incoming TLS support as it's free websocket = ["tokio-tungstenite", "futures-util", "tls"] # websocket+incoming also enables incoming TLS support as it's free
logging = ["rand", "env_logger"] logging = ["rand", "env_logger"]
systemd = ["nix"]
[dev-dependencies] [dev-dependencies]
serde_json = "1.0" serde_json = "1.0"

View File

@ -18,6 +18,8 @@ fn main() {
"websocket", "websocket",
"tls-ca-roots-native", "tls-ca-roots-native",
"tls-ca-roots-bundled", "tls-ca-roots-bundled",
"logging",
"systemd",
]; ];
let optional_deps = [ let optional_deps = [
"rustls", "rustls",
@ -33,6 +35,7 @@ fn main() {
"webpki-roots", "webpki-roots",
"env-logger", "env-logger",
"rand", "rand",
"nix",
]; ];
let mut features = Vec::new(); let mut features = Vec::new();
let mut optional = Vec::new(); let mut optional = Vec::new();

44
check-all-features.sh Normal file → Executable file
View File

@ -1,6 +1,13 @@
#!/bin/bash #!/bin/bash
threads="$1"
set -euo pipefail set -euo pipefail
# if we have access to nproc, divide that by 2, otherwise use 1 thread by default
[ "$threads" == "" ] && threads=$(($(nproc || echo 2) / 2))
echo "threads: $threads"
export RUSTFLAGS=-Awarnings export RUSTFLAGS=-Awarnings
show() { show() {
@ -22,19 +29,15 @@ perm_lines() {
} }
perms() { perms() {
perm_lines "$@" | tr ' ' ',' | sort -u | tr '\n' ' ' perm_lines "$@" | tr ' ' ',' | sort -u
} }
echo_cargo() { perms_optional() {
#echo cargo run "$@" -- -v perm_lines "$@" | tr ' ' ',' | sort -u | sed 's/^/,/'
#cargo run "$@" -- -v
echo cargo check "$@"
cargo check "$@"
} }
echo_cargo all_features() {
for optional in "" $(perms_optional logging systemd)
for optional in "" ",logging"
do do
for proto in $(perms tls quic websocket) for proto in $(perms tls quic websocket)
do do
@ -42,18 +45,35 @@ do
do do
for ca_roots in tls-ca-roots-native tls-ca-roots-bundled for ca_roots in tls-ca-roots-native tls-ca-roots-bundled
do do
echo_cargo --no-default-features --features $direction,$proto,$ca_roots$optional echo $direction,$proto,$ca_roots$optional
done done
done done
done done
done done
for optional in "" ",logging" for optional in "" $(perms_optional logging systemd)
do do
for proto in $(perms tls quic websocket) for proto in $(perms tls quic websocket)
do do
echo_cargo --no-default-features --features c2s-incoming,$proto$optional echo c2s-incoming,$proto$optional
done done
done done
}
echo_cargo() {
set -euo pipefail
#echo cargo run "$@" -- -v
#cargo run "$@" -- -v
echo cargo check "$@"
cargo check "$@"
}
#all_features | sort -u | wc -l; exit 0
export -f echo_cargo
echo_cargo
all_features | sort | xargs -n1 --max-procs=$threads bash -c 'echo_cargo --no-default-features --features "$@" || exit 255' _
echo good! echo good!

View File

@ -20,5 +20,8 @@ pub mod websocket;
#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))] #[cfg(any(feature = "s2s-incoming", feature = "outgoing"))]
pub mod verify; pub mod verify;
mod context; #[cfg(feature = "nix")]
pub mod systemd;
pub mod context;
pub mod in_out; pub mod in_out;

View File

@ -3,10 +3,17 @@ use anyhow::Result;
use die::{die, Die}; use die::{die, Die};
use log::{debug, error, info}; use log::{debug, error, info};
use serde_derive::Deserialize; use serde_derive::Deserialize;
use std::{ffi::OsString, fs::File, io::Read, iter::Iterator, net::SocketAddr, path::Path, sync::Arc}; use std::{
use tokio::task::JoinHandle; ffi::OsString,
fs::File,
io::Read,
iter::Iterator,
net::{SocketAddr, UdpSocket},
path::Path,
sync::Arc,
};
use tokio::{net::TcpListener, task::JoinHandle};
use xmpp_proxy::common::certs_key::CertsKey; use xmpp_proxy::common::certs_key::CertsKey;
#[cfg(feature = "outgoing")] #[cfg(feature = "outgoing")]
use xmpp_proxy::{common::outgoing::OutgoingConfig, outgoing::spawn_outgoing_listener}; use xmpp_proxy::{common::outgoing::OutgoingConfig, outgoing::spawn_outgoing_listener};
@ -128,13 +135,53 @@ async fn main() {
die!("log_level or log_style defined in config but logging disabled at compile-time"); die!("log_level or log_style defined in config but logging disabled at compile-time");
} }
let mut incoming_listen = Vec::new();
for a in main_config.incoming_listen.iter() {
incoming_listen.push(TcpListener::bind(a).await.die("cannot listen on port/interface"));
}
let mut quic_listen = Vec::new();
for a in main_config.quic_listen.iter() {
quic_listen.push(UdpSocket::bind(a).die("cannot listen on port/interface"));
}
let mut outgoing_listen = Vec::new();
for a in main_config.outgoing_listen.iter() {
outgoing_listen.push(TcpListener::bind(a).await.die("cannot listen on port/interface"));
}
#[cfg(feature = "nix")]
if let Ok(fds) = xmpp_proxy::systemd::receive_descriptors_with_names(true) {
use xmpp_proxy::systemd::Listener;
for fd in fds {
match fd.listener() {
Listener::Tcp(tcp_listener) => {
let tcp_listener = TcpListener::from_std(tcp_listener()).die("cannot open systemd TcpListener");
if let Some(name) = fd.name().map(|n| n.to_ascii_lowercase()) {
if name.starts_with("in") {
incoming_listen.push(tcp_listener);
} else if name.starts_with("out") {
outgoing_listen.push(tcp_listener);
} else {
die!("systemd socket name must start with 'in' or 'out' but is '{}'", name);
}
} else {
// what to do here... for now we will require names
// todo: possibly in future if local_addr is localhost or private ranges assume outgoing, otherwise incoming?
die!("systemd TCP socket activation requires name that starts with 'in' or 'out'");
}
}
Listener::Udp(udp_socket) => quic_listen.push(udp_socket()),
_ => continue,
}
}
}
#[cfg(feature = "incoming")] #[cfg(feature = "incoming")]
let config = main_config.get_cloneable_cfg(); let config = main_config.get_cloneable_cfg();
let certs_key = Arc::new(CertsKey::new(main_config.certs_key())); let certs_key = Arc::new(CertsKey::new(main_config.certs_key()));
let mut handles: Vec<JoinHandle<Result<()>>> = Vec::new(); let mut handles: Vec<JoinHandle<Result<()>>> = Vec::new();
if !main_config.incoming_listen.is_empty() { if !incoming_listen.is_empty() {
#[cfg(all(any(feature = "tls", feature = "websocket"), feature = "incoming"))] #[cfg(all(any(feature = "tls", feature = "websocket"), feature = "incoming"))]
{ {
use xmpp_proxy::{ use xmpp_proxy::{
@ -145,14 +192,14 @@ async fn main() {
die!("one of c2s_target/s2s_target must be defined if incoming_listen is non-empty"); die!("one of c2s_target/s2s_target must be defined if incoming_listen is non-empty");
} }
let acceptor = tls_acceptor(server_config(certs_key.clone()).die("invalid cert/key ?")); let acceptor = tls_acceptor(server_config(certs_key.clone()).die("invalid cert/key ?"));
for listener in main_config.incoming_listen.iter() { for listener in incoming_listen {
handles.push(spawn_tls_listener(listener.parse().die("invalid listener address"), config.clone(), acceptor.clone())); handles.push(spawn_tls_listener(listener, config.clone(), acceptor.clone()));
} }
} }
#[cfg(not(all(any(feature = "tls", feature = "websocket"), feature = "incoming")))] #[cfg(not(all(any(feature = "tls", feature = "websocket"), feature = "incoming")))]
die!("incoming_listen non-empty but (tls or websocket) or (s2s-incoming and c2s-incoming) disabled at compile-time"); die!("incoming_listen non-empty but (tls or websocket) or (s2s-incoming and c2s-incoming) disabled at compile-time");
} }
if !main_config.quic_listen.is_empty() { if !quic_listen.is_empty() {
#[cfg(all(feature = "quic", feature = "incoming"))] #[cfg(all(feature = "quic", feature = "incoming"))]
{ {
use xmpp_proxy::{ use xmpp_proxy::{
@ -163,19 +210,19 @@ async fn main() {
die!("one of c2s_target/s2s_target must be defined if quic_listen is non-empty"); die!("one of c2s_target/s2s_target must be defined if quic_listen is non-empty");
} }
let quic_config = quic_server_config(server_config(certs_key.clone()).die("invalid cert/key ?")); let quic_config = quic_server_config(server_config(certs_key.clone()).die("invalid cert/key ?"));
for listener in main_config.quic_listen.iter() { for listener in quic_listen {
handles.push(spawn_quic_listener(listener.parse().die("invalid listener address"), config.clone(), quic_config.clone())); handles.push(spawn_quic_listener(listener, config.clone(), quic_config.clone()));
} }
} }
#[cfg(not(all(feature = "quic", feature = "incoming")))] #[cfg(not(all(feature = "quic", feature = "incoming")))]
die!("quic_listen non-empty but quic or (s2s-incoming and c2s-incoming) disabled at compile-time"); die!("quic_listen non-empty but quic or (s2s-incoming and c2s-incoming) disabled at compile-time");
} }
if !main_config.outgoing_listen.is_empty() { if !outgoing_listen.is_empty() {
#[cfg(feature = "outgoing")] #[cfg(feature = "outgoing")]
{ {
let outgoing_cfg = main_config.get_outgoing_cfg(certs_key.clone()); let outgoing_cfg = main_config.get_outgoing_cfg(certs_key.clone());
for listener in main_config.outgoing_listen.iter() { for listener in outgoing_listen {
handles.push(spawn_outgoing_listener(listener.parse().die("invalid listener address"), outgoing_cfg.clone())); handles.push(spawn_outgoing_listener(listener, outgoing_cfg.clone()));
} }
} }
#[cfg(not(feature = "outgoing"))] #[cfg(not(feature = "outgoing"))]

View File

@ -7,9 +7,7 @@ use crate::{
stanzafilter::StanzaFilter, stanzafilter::StanzaFilter,
}; };
use anyhow::Result; use anyhow::Result;
use die::Die;
use log::{error, info}; use log::{error, info};
use std::net::SocketAddr;
use tokio::{net::TcpListener, task::JoinHandle}; use tokio::{net::TcpListener, task::JoinHandle};
async fn handle_outgoing_connection(stream: tokio::net::TcpStream, client_addr: &mut Context<'_>, config: OutgoingConfig) -> Result<()> { async fn handle_outgoing_connection(stream: tokio::net::TcpStream, client_addr: &mut Context<'_>, config: OutgoingConfig) -> Result<()> {
@ -48,9 +46,8 @@ async fn handle_outgoing_connection(stream: tokio::net::TcpStream, client_addr:
shuffle_rd_wr_filter_only(in_rd, in_wr, out_rd, out_wr, is_c2s, max_stanza_size_bytes, client_addr, in_filter).await shuffle_rd_wr_filter_only(in_rd, in_wr, out_rd, out_wr, is_c2s, max_stanza_size_bytes, client_addr, in_filter).await
} }
pub fn spawn_outgoing_listener(local_addr: SocketAddr, config: OutgoingConfig) -> JoinHandle<Result<()>> { pub fn spawn_outgoing_listener(listener: TcpListener, config: OutgoingConfig) -> JoinHandle<Result<()>> {
tokio::spawn(async move { tokio::spawn(async move {
let listener = TcpListener::bind(&local_addr).await.die("cannot listen on port/interface");
loop { loop {
let (stream, client_addr) = listener.accept().await?; let (stream, client_addr) = listener.accept().await?;
let config = config.clone(); let config = config.clone();

View File

@ -7,12 +7,13 @@ use anyhow::Result;
use die::Die; use die::Die;
use futures::StreamExt; use futures::StreamExt;
use log::{error, info}; use log::{error, info};
use quinn::ServerConfig; use quinn::{Endpoint, EndpointConfig, ServerConfig};
use std::{net::SocketAddr, sync::Arc}; use std::{net::UdpSocket, sync::Arc};
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
pub fn spawn_quic_listener(local_addr: SocketAddr, config: CloneableConfig, server_config: ServerConfig) -> JoinHandle<Result<()>> { pub fn spawn_quic_listener(udp_socket: UdpSocket, config: CloneableConfig, server_config: ServerConfig) -> JoinHandle<Result<()>> {
let (_endpoint, mut incoming) = quinn::Endpoint::server(server_config, local_addr).die("cannot listen on port/interface"); let local_addr = udp_socket.local_addr().die("cannot get local_addr for quic socket");
let (_endpoint, mut incoming) = Endpoint::new(EndpointConfig::default(), Some(server_config), udp_socket).die("cannot listen on port/interface");
tokio::spawn(async move { tokio::spawn(async move {
// when could this return None, do we quit? // when could this return None, do we quit?
while let Some(incoming_conn) = incoming.next().await { while let Some(incoming_conn) = incoming.next().await {

132
src/systemd.rs Normal file
View File

@ -0,0 +1,132 @@
use anyhow::{anyhow, bail, Result};
use nix::sys::socket::{getsockname, getsockopt, AddressFamily, SockType, SockaddrLike, SockaddrStorage};
use std::{
env,
net::{TcpListener, UdpSocket},
os::unix::{
io::{FromRawFd, IntoRawFd, RawFd},
net::{UnixDatagram, UnixListener},
},
process,
};
/// Minimum FD number used by systemd for passing sockets.
const SD_LISTEN_FDS_START: RawFd = 3;
/// File descriptor passed by systemd to socket-activated services.
///
/// See <https://www.freedesktop.org/software/systemd/man/systemd.socket.html>.
#[derive(Debug, Clone)]
pub struct FileDescriptor {
raw_fd: RawFd,
tcp_not_udp: bool,
inet_not_unix: bool,
pub name: Option<String>,
}
pub enum Listener {
Tcp(Box<dyn FnOnce() -> TcpListener>),
Udp(Box<dyn FnOnce() -> UdpSocket>),
UnixListener(Box<dyn FnOnce() -> UnixListener>),
UnixDatagram(Box<dyn FnOnce() -> UnixDatagram>),
}
impl FileDescriptor {
pub fn name(self) -> Option<String> {
self.name
}
pub fn listener(&self) -> Listener {
let raw_fd = self.raw_fd;
match (self.tcp_not_udp, self.inet_not_unix) {
(true, true) => Listener::Tcp(Box::new(move || unsafe { TcpListener::from_raw_fd(raw_fd) })),
(false, true) => Listener::Udp(Box::new(move || unsafe { UdpSocket::from_raw_fd(raw_fd) })),
(true, false) => Listener::UnixListener(Box::new(move || unsafe { UnixListener::from_raw_fd(raw_fd) })),
(false, false) => Listener::UnixDatagram(Box::new(move || unsafe { UnixDatagram::from_raw_fd(raw_fd) })),
}
}
}
/// Check for named file descriptors passed by systemd.
///
/// Like `receive_descriptors`, but this will also return a vector of names
/// associated with each file descriptor.
pub fn receive_descriptors_with_names(unset_env: bool) -> Result<Vec<FileDescriptor>> {
let pid = env::var("LISTEN_PID");
let fds = env::var("LISTEN_FDS");
let fdnames = env::var("LISTEN_FDNAMES");
log::trace!("LISTEN_PID = {:?}; LISTEN_FDS = {:?}; LISTEN_FDNAMES = {:?}", pid, fds, fdnames);
if unset_env {
env::remove_var("LISTEN_PID");
env::remove_var("LISTEN_FDS");
env::remove_var("LISTEN_FDNAMES");
}
let pid = pid
.map_err(|e| anyhow!("failed to get LISTEN_PID: {}", e))?
.parse::<u32>()
.map_err(|e| anyhow!("failed to parse LISTEN_PID: {}", e))?;
let fds = fds
.map_err(|e| anyhow!("failed to get LISTEN_FDS: {}", e))?
.parse::<usize>()
.map_err(|e| anyhow!("failed to parse LISTEN_FDS: {}", e))?;
if process::id() != pid {
bail!("PID mismatch");
}
let names = fdnames.map(|n| n.split(':').map(String::from).collect()).unwrap_or_else(|_| Vec::new());
socks_from_fds(fds, names)
}
fn socks_from_fds(num_fds: usize, names: Vec<String>) -> Result<Vec<FileDescriptor>> {
let mut descriptors = Vec::with_capacity(num_fds);
let mut names = names.into_iter();
for fd_offset in 0..num_fds {
let name = names.next();
let raw_fd: RawFd = SD_LISTEN_FDS_START
.checked_add(fd_offset as i32)
.ok_or_else(|| anyhow!("overlarge file descriptor index: {}", num_fds))?;
if !sock_listening(raw_fd) {
continue;
}
let tcp_not_udp = match sock_type(raw_fd) {
Some(SockType::Stream) => true,
Some(SockType::Datagram) => false,
_ => continue,
};
let inet_not_unix = match address_family(raw_fd) {
Some(AddressFamily::Inet) | Some(AddressFamily::Inet6) => true,
Some(AddressFamily::Unix) => false,
_ => continue,
};
descriptors.push(FileDescriptor {
raw_fd,
tcp_not_udp,
inet_not_unix,
name,
});
}
Ok(descriptors)
}
fn sock_listening(raw_fd: RawFd) -> bool {
getsockopt(raw_fd, nix::sys::socket::sockopt::AcceptConn).unwrap_or(false)
}
fn sock_type(raw_fd: RawFd) -> Option<SockType> {
getsockopt(raw_fd, nix::sys::socket::sockopt::SockType).ok()
}
fn address_family(raw_fd: RawFd) -> Option<AddressFamily> {
getsockname::<SockaddrStorage>(raw_fd).ok().and_then(|addr| addr.family())
}
impl IntoRawFd for FileDescriptor {
fn into_raw_fd(self) -> RawFd {
self.raw_fd
}
}

View File

@ -1,19 +1,18 @@
use crate::common::incoming::{shuffle_rd_wr_filter, CloneableConfig, ServerCerts};
use std::net::SocketAddr;
use crate::{ use crate::{
common::{first_bytes_match, to_str, IN_BUFFER_SIZE}, common::{
first_bytes_match,
incoming::{shuffle_rd_wr_filter, CloneableConfig, ServerCerts},
to_str, IN_BUFFER_SIZE,
},
context::Context, context::Context,
in_out::{StanzaRead, StanzaWrite}, in_out::{StanzaRead, StanzaWrite},
slicesubsequence::SliceSubsequence, slicesubsequence::SliceSubsequence,
stanzafilter::{StanzaFilter, StanzaReader}, stanzafilter::{StanzaFilter, StanzaReader},
}; };
use anyhow::{bail, Result}; use anyhow::{bail, Result};
use die::Die;
use log::{error, info, trace}; use log::{error, info, trace};
use rustls::{ServerConfig, ServerConnection}; use rustls::{ServerConfig, ServerConnection};
use std::{net::SocketAddr, sync::Arc};
use std::sync::Arc;
use tokio::{ use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufStream}, io::{AsyncBufReadExt, AsyncWriteExt, BufStream},
net::TcpListener, net::TcpListener,
@ -25,9 +24,9 @@ pub fn tls_acceptor(server_config: ServerConfig) -> TlsAcceptor {
TlsAcceptor::from(Arc::new(server_config)) TlsAcceptor::from(Arc::new(server_config))
} }
pub fn spawn_tls_listener(local_addr: SocketAddr, config: CloneableConfig, acceptor: TlsAcceptor) -> JoinHandle<Result<()>> { pub fn spawn_tls_listener(listener: TcpListener, config: CloneableConfig, acceptor: TlsAcceptor) -> JoinHandle<Result<()>> {
tokio::spawn(async move { tokio::spawn(async move {
let listener = TcpListener::bind(&local_addr).await.die("cannot listen on port/interface"); let local_addr = listener.local_addr()?;
loop { loop {
let (stream, client_addr) = listener.accept().await?; let (stream, client_addr) = listener.accept().await?;
let config = config.clone(); let config = config.clone();