Browse Source

Implement optional support for systemd socket activation

master
Travis Burtrum 5 months ago
parent
commit
63718b3af3
  1. 22
      Cargo.lock
  2. 6
      Cargo.toml
  3. 3
      build.rs
  4. 60
      check-all-features.sh
  5. 5
      src/lib.rs
  6. 71
      src/main.rs
  7. 5
      src/outgoing.rs
  8. 9
      src/quic/incoming.rs
  9. 132
      src/systemd.rs
  10. 17
      src/tls/incoming.rs

22
Cargo.lock generated

@ -606,6 +606,15 @@ version = "2.5.0" @@ -606,6 +606,15 @@ version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d"
[[package]]
name = "memoffset"
version = "0.6.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5aa361d4faea93603064a027415f07bd8e1d5c88c9fbf68bf56a285428fd79ce"
dependencies = [
"autocfg",
]
[[package]]
name = "mime"
version = "0.3.16"
@ -633,6 +642,18 @@ dependencies = [ @@ -633,6 +642,18 @@ dependencies = [
"windows-sys",
]
[[package]]
name = "nix"
version = "0.24.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "195cdbc1741b8134346d515b3a56a1c94b0912758009cfd53f99ea0f57b065fc"
dependencies = [
"bitflags",
"cfg-if",
"libc",
"memoffset",
]
[[package]]
name = "num_cpus"
version = "1.13.1"
@ -1607,6 +1628,7 @@ dependencies = [ @@ -1607,6 +1628,7 @@ dependencies = [
"futures-util",
"lazy_static",
"log",
"nix",
"quinn",
"rand",
"reqwest",

6
Cargo.toml

@ -60,8 +60,11 @@ rustls-pemfile = { version = "1.0.0", optional = true } @@ -60,8 +60,11 @@ rustls-pemfile = { version = "1.0.0", optional = true }
tokio-tungstenite = { version = "0.17", optional = true, default-features = false }
futures-util = { version = "0.3", default-features = false, features = ["async-await", "sink", "std"], optional = true }
# systemd dep
nix = { version = "0.24", optional = true, default-features = false, features = ["socket"]}
[features]
default = ["c2s-incoming", "c2s-outgoing", "s2s-incoming", "s2s-outgoing", "tls", "quic", "websocket", "logging", "tls-ca-roots-native"]
default = ["c2s-incoming", "c2s-outgoing", "s2s-incoming", "s2s-outgoing", "tls", "quic", "websocket", "logging", "tls-ca-roots-native", "systemd"]
# you must pick one of these or the other, not both: todo: enable picking both and choosing at runtime
# don't need either of these if only doing c2s-incoming
@ -88,6 +91,7 @@ quic = ["quinn", "rustls"] @@ -88,6 +91,7 @@ quic = ["quinn", "rustls"]
websocket = ["tokio-tungstenite", "futures-util", "tls"] # websocket+incoming also enables incoming TLS support as it's free
logging = ["rand", "env_logger"]
systemd = ["nix"]
[dev-dependencies]
serde_json = "1.0"

3
build.rs

@ -18,6 +18,8 @@ fn main() { @@ -18,6 +18,8 @@ fn main() {
"websocket",
"tls-ca-roots-native",
"tls-ca-roots-bundled",
"logging",
"systemd",
];
let optional_deps = [
"rustls",
@ -33,6 +35,7 @@ fn main() { @@ -33,6 +35,7 @@ fn main() {
"webpki-roots",
"env-logger",
"rand",
"nix",
];
let mut features = Vec::new();
let mut optional = Vec::new();

60
check-all-features.sh

@ -1,6 +1,13 @@ @@ -1,6 +1,13 @@
#!/bin/bash
threads="$1"
set -euo pipefail
# if we have access to nproc, divide that by 2, otherwise use 1 thread by default
[ "$threads" == "" ] && threads=$(($(nproc || echo 2) / 2))
echo "threads: $threads"
export RUSTFLAGS=-Awarnings
show() {
@ -22,38 +29,51 @@ perm_lines() { @@ -22,38 +29,51 @@ perm_lines() {
}
perms() {
perm_lines "$@" | tr ' ' ',' | sort -u | tr '\n' ' '
perm_lines "$@" | tr ' ' ',' | sort -u
}
echo_cargo() {
#echo cargo run "$@" -- -v
#cargo run "$@" -- -v
echo cargo check "$@"
cargo check "$@"
perms_optional() {
perm_lines "$@" | tr ' ' ',' | sort -u | sed 's/^/,/'
}
echo_cargo
for optional in "" ",logging"
do
for proto in $(perms tls quic websocket)
all_features() {
for optional in "" $(perms_optional logging systemd)
do
for direction in $(perms c2s-incoming c2s-outgoing s2s-incoming s2s-outgoing)
for proto in $(perms tls quic websocket)
do
for ca_roots in tls-ca-roots-native tls-ca-roots-bundled
for direction in $(perms c2s-incoming c2s-outgoing s2s-incoming s2s-outgoing)
do
echo_cargo --no-default-features --features $direction,$proto,$ca_roots$optional
for ca_roots in tls-ca-roots-native tls-ca-roots-bundled
do
echo $direction,$proto,$ca_roots$optional
done
done
done
done
done
for optional in "" ",logging"
do
for proto in $(perms tls quic websocket)
for optional in "" $(perms_optional logging systemd)
do
echo_cargo --no-default-features --features c2s-incoming,$proto$optional
for proto in $(perms tls quic websocket)
do
echo c2s-incoming,$proto$optional
done
done
done
}
echo_cargo() {
set -euo pipefail
#echo cargo run "$@" -- -v
#cargo run "$@" -- -v
echo cargo check "$@"
cargo check "$@"
}
#all_features | sort -u | wc -l; exit 0
export -f echo_cargo
echo_cargo
all_features | sort | xargs -n1 --max-procs=$threads bash -c 'echo_cargo --no-default-features --features "$@" || exit 255' _
echo good!

5
src/lib.rs

@ -20,5 +20,8 @@ pub mod websocket; @@ -20,5 +20,8 @@ pub mod websocket;
#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))]
pub mod verify;
mod context;
#[cfg(feature = "nix")]
pub mod systemd;
pub mod context;
pub mod in_out;

71
src/main.rs

@ -3,10 +3,17 @@ use anyhow::Result; @@ -3,10 +3,17 @@ use anyhow::Result;
use die::{die, Die};
use log::{debug, error, info};
use serde_derive::Deserialize;
use std::{ffi::OsString, fs::File, io::Read, iter::Iterator, net::SocketAddr, path::Path, sync::Arc};
use tokio::task::JoinHandle;
use std::{
ffi::OsString,
fs::File,
io::Read,
iter::Iterator,
net::{SocketAddr, UdpSocket},
path::Path,
sync::Arc,
};
use tokio::{net::TcpListener, task::JoinHandle};
use xmpp_proxy::common::certs_key::CertsKey;
#[cfg(feature = "outgoing")]
use xmpp_proxy::{common::outgoing::OutgoingConfig, outgoing::spawn_outgoing_listener};
@ -128,13 +135,53 @@ async fn main() { @@ -128,13 +135,53 @@ async fn main() {
die!("log_level or log_style defined in config but logging disabled at compile-time");
}
let mut incoming_listen = Vec::new();
for a in main_config.incoming_listen.iter() {
incoming_listen.push(TcpListener::bind(a).await.die("cannot listen on port/interface"));
}
let mut quic_listen = Vec::new();
for a in main_config.quic_listen.iter() {
quic_listen.push(UdpSocket::bind(a).die("cannot listen on port/interface"));
}
let mut outgoing_listen = Vec::new();
for a in main_config.outgoing_listen.iter() {
outgoing_listen.push(TcpListener::bind(a).await.die("cannot listen on port/interface"));
}
#[cfg(feature = "nix")]
if let Ok(fds) = xmpp_proxy::systemd::receive_descriptors_with_names(true) {
use xmpp_proxy::systemd::Listener;
for fd in fds {
match fd.listener() {
Listener::Tcp(tcp_listener) => {
let tcp_listener = TcpListener::from_std(tcp_listener()).die("cannot open systemd TcpListener");
if let Some(name) = fd.name().map(|n| n.to_ascii_lowercase()) {
if name.starts_with("in") {
incoming_listen.push(tcp_listener);
} else if name.starts_with("out") {
outgoing_listen.push(tcp_listener);
} else {
die!("systemd socket name must start with 'in' or 'out' but is '{}'", name);
}
} else {
// what to do here... for now we will require names
// todo: possibly in future if local_addr is localhost or private ranges assume outgoing, otherwise incoming?
die!("systemd TCP socket activation requires name that starts with 'in' or 'out'");
}
}
Listener::Udp(udp_socket) => quic_listen.push(udp_socket()),
_ => continue,
}
}
}
#[cfg(feature = "incoming")]
let config = main_config.get_cloneable_cfg();
let certs_key = Arc::new(CertsKey::new(main_config.certs_key()));
let mut handles: Vec<JoinHandle<Result<()>>> = Vec::new();
if !main_config.incoming_listen.is_empty() {
if !incoming_listen.is_empty() {
#[cfg(all(any(feature = "tls", feature = "websocket"), feature = "incoming"))]
{
use xmpp_proxy::{
@ -145,14 +192,14 @@ async fn main() { @@ -145,14 +192,14 @@ async fn main() {
die!("one of c2s_target/s2s_target must be defined if incoming_listen is non-empty");
}
let acceptor = tls_acceptor(server_config(certs_key.clone()).die("invalid cert/key ?"));
for listener in main_config.incoming_listen.iter() {
handles.push(spawn_tls_listener(listener.parse().die("invalid listener address"), config.clone(), acceptor.clone()));
for listener in incoming_listen {
handles.push(spawn_tls_listener(listener, config.clone(), acceptor.clone()));
}
}
#[cfg(not(all(any(feature = "tls", feature = "websocket"), feature = "incoming")))]
die!("incoming_listen non-empty but (tls or websocket) or (s2s-incoming and c2s-incoming) disabled at compile-time");
}
if !main_config.quic_listen.is_empty() {
if !quic_listen.is_empty() {
#[cfg(all(feature = "quic", feature = "incoming"))]
{
use xmpp_proxy::{
@ -163,19 +210,19 @@ async fn main() { @@ -163,19 +210,19 @@ async fn main() {
die!("one of c2s_target/s2s_target must be defined if quic_listen is non-empty");
}
let quic_config = quic_server_config(server_config(certs_key.clone()).die("invalid cert/key ?"));
for listener in main_config.quic_listen.iter() {
handles.push(spawn_quic_listener(listener.parse().die("invalid listener address"), config.clone(), quic_config.clone()));
for listener in quic_listen {
handles.push(spawn_quic_listener(listener, config.clone(), quic_config.clone()));
}
}
#[cfg(not(all(feature = "quic", feature = "incoming")))]
die!("quic_listen non-empty but quic or (s2s-incoming and c2s-incoming) disabled at compile-time");
}
if !main_config.outgoing_listen.is_empty() {
if !outgoing_listen.is_empty() {
#[cfg(feature = "outgoing")]
{
let outgoing_cfg = main_config.get_outgoing_cfg(certs_key.clone());
for listener in main_config.outgoing_listen.iter() {
handles.push(spawn_outgoing_listener(listener.parse().die("invalid listener address"), outgoing_cfg.clone()));
for listener in outgoing_listen {
handles.push(spawn_outgoing_listener(listener, outgoing_cfg.clone()));
}
}
#[cfg(not(feature = "outgoing"))]

5
src/outgoing.rs

@ -7,9 +7,7 @@ use crate::{ @@ -7,9 +7,7 @@ use crate::{
stanzafilter::StanzaFilter,
};
use anyhow::Result;
use die::Die;
use log::{error, info};
use std::net::SocketAddr;
use tokio::{net::TcpListener, task::JoinHandle};
async fn handle_outgoing_connection(stream: tokio::net::TcpStream, client_addr: &mut Context<'_>, config: OutgoingConfig) -> Result<()> {
@ -48,9 +46,8 @@ async fn handle_outgoing_connection(stream: tokio::net::TcpStream, client_addr: @@ -48,9 +46,8 @@ async fn handle_outgoing_connection(stream: tokio::net::TcpStream, client_addr:
shuffle_rd_wr_filter_only(in_rd, in_wr, out_rd, out_wr, is_c2s, max_stanza_size_bytes, client_addr, in_filter).await
}
pub fn spawn_outgoing_listener(local_addr: SocketAddr, config: OutgoingConfig) -> JoinHandle<Result<()>> {
pub fn spawn_outgoing_listener(listener: TcpListener, config: OutgoingConfig) -> JoinHandle<Result<()>> {
tokio::spawn(async move {
let listener = TcpListener::bind(&local_addr).await.die("cannot listen on port/interface");
loop {
let (stream, client_addr) = listener.accept().await?;
let config = config.clone();

9
src/quic/incoming.rs

@ -7,12 +7,13 @@ use anyhow::Result; @@ -7,12 +7,13 @@ use anyhow::Result;
use die::Die;
use futures::StreamExt;
use log::{error, info};
use quinn::ServerConfig;
use std::{net::SocketAddr, sync::Arc};
use quinn::{Endpoint, EndpointConfig, ServerConfig};
use std::{net::UdpSocket, sync::Arc};
use tokio::task::JoinHandle;
pub fn spawn_quic_listener(local_addr: SocketAddr, config: CloneableConfig, server_config: ServerConfig) -> JoinHandle<Result<()>> {
let (_endpoint, mut incoming) = quinn::Endpoint::server(server_config, local_addr).die("cannot listen on port/interface");
pub fn spawn_quic_listener(udp_socket: UdpSocket, config: CloneableConfig, server_config: ServerConfig) -> JoinHandle<Result<()>> {
let local_addr = udp_socket.local_addr().die("cannot get local_addr for quic socket");
let (_endpoint, mut incoming) = Endpoint::new(EndpointConfig::default(), Some(server_config), udp_socket).die("cannot listen on port/interface");
tokio::spawn(async move {
// when could this return None, do we quit?
while let Some(incoming_conn) = incoming.next().await {

132
src/systemd.rs

@ -0,0 +1,132 @@ @@ -0,0 +1,132 @@
use anyhow::{anyhow, bail, Result};
use nix::sys::socket::{getsockname, getsockopt, AddressFamily, SockType, SockaddrLike, SockaddrStorage};
use std::{
env,
net::{TcpListener, UdpSocket},
os::unix::{
io::{FromRawFd, IntoRawFd, RawFd},
net::{UnixDatagram, UnixListener},
},
process,
};
/// Minimum FD number used by systemd for passing sockets.
const SD_LISTEN_FDS_START: RawFd = 3;
/// File descriptor passed by systemd to socket-activated services.
///
/// See <https://www.freedesktop.org/software/systemd/man/systemd.socket.html>.
#[derive(Debug, Clone)]
pub struct FileDescriptor {
raw_fd: RawFd,
tcp_not_udp: bool,
inet_not_unix: bool,
pub name: Option<String>,
}
pub enum Listener {
Tcp(Box<dyn FnOnce() -> TcpListener>),
Udp(Box<dyn FnOnce() -> UdpSocket>),
UnixListener(Box<dyn FnOnce() -> UnixListener>),
UnixDatagram(Box<dyn FnOnce() -> UnixDatagram>),
}
impl FileDescriptor {
pub fn name(self) -> Option<String> {
self.name
}
pub fn listener(&self) -> Listener {
let raw_fd = self.raw_fd;
match (self.tcp_not_udp, self.inet_not_unix) {
(true, true) => Listener::Tcp(Box::new(move || unsafe { TcpListener::from_raw_fd(raw_fd) })),
(false, true) => Listener::Udp(Box::new(move || unsafe { UdpSocket::from_raw_fd(raw_fd) })),
(true, false) => Listener::UnixListener(Box::new(move || unsafe { UnixListener::from_raw_fd(raw_fd) })),
(false, false) => Listener::UnixDatagram(Box::new(move || unsafe { UnixDatagram::from_raw_fd(raw_fd) })),
}
}
}
/// Check for named file descriptors passed by systemd.
///
/// Like `receive_descriptors`, but this will also return a vector of names
/// associated with each file descriptor.
pub fn receive_descriptors_with_names(unset_env: bool) -> Result<Vec<FileDescriptor>> {
let pid = env::var("LISTEN_PID");
let fds = env::var("LISTEN_FDS");
let fdnames = env::var("LISTEN_FDNAMES");
log::trace!("LISTEN_PID = {:?}; LISTEN_FDS = {:?}; LISTEN_FDNAMES = {:?}", pid, fds, fdnames);
if unset_env {
env::remove_var("LISTEN_PID");
env::remove_var("LISTEN_FDS");
env::remove_var("LISTEN_FDNAMES");
}
let pid = pid
.map_err(|e| anyhow!("failed to get LISTEN_PID: {}", e))?
.parse::<u32>()
.map_err(|e| anyhow!("failed to parse LISTEN_PID: {}", e))?;
let fds = fds
.map_err(|e| anyhow!("failed to get LISTEN_FDS: {}", e))?
.parse::<usize>()
.map_err(|e| anyhow!("failed to parse LISTEN_FDS: {}", e))?;
if process::id() != pid {
bail!("PID mismatch");
}
let names = fdnames.map(|n| n.split(':').map(String::from).collect()).unwrap_or_else(|_| Vec::new());
socks_from_fds(fds, names)
}
fn socks_from_fds(num_fds: usize, names: Vec<String>) -> Result<Vec<FileDescriptor>> {
let mut descriptors = Vec::with_capacity(num_fds);
let mut names = names.into_iter();
for fd_offset in 0..num_fds {
let name = names.next();
let raw_fd: RawFd = SD_LISTEN_FDS_START
.checked_add(fd_offset as i32)
.ok_or_else(|| anyhow!("overlarge file descriptor index: {}", num_fds))?;
if !sock_listening(raw_fd) {
continue;
}
let tcp_not_udp = match sock_type(raw_fd) {
Some(SockType::Stream) => true,
Some(SockType::Datagram) => false,
_ => continue,
};
let inet_not_unix = match address_family(raw_fd) {
Some(AddressFamily::Inet) | Some(AddressFamily::Inet6) => true,
Some(AddressFamily::Unix) => false,
_ => continue,
};
descriptors.push(FileDescriptor {
raw_fd,
tcp_not_udp,
inet_not_unix,
name,
});
}
Ok(descriptors)
}
fn sock_listening(raw_fd: RawFd) -> bool {
getsockopt(raw_fd, nix::sys::socket::sockopt::AcceptConn).unwrap_or(false)
}
fn sock_type(raw_fd: RawFd) -> Option<SockType> {
getsockopt(raw_fd, nix::sys::socket::sockopt::SockType).ok()
}
fn address_family(raw_fd: RawFd) -> Option<AddressFamily> {
getsockname::<SockaddrStorage>(raw_fd).ok().and_then(|addr| addr.family())
}
impl IntoRawFd for FileDescriptor {
fn into_raw_fd(self) -> RawFd {
self.raw_fd
}
}

17
src/tls/incoming.rs

@ -1,19 +1,18 @@ @@ -1,19 +1,18 @@
use crate::common::incoming::{shuffle_rd_wr_filter, CloneableConfig, ServerCerts};
use std::net::SocketAddr;
use crate::{
common::{first_bytes_match, to_str, IN_BUFFER_SIZE},
common::{
first_bytes_match,
incoming::{shuffle_rd_wr_filter, CloneableConfig, ServerCerts},
to_str, IN_BUFFER_SIZE,
},
context::Context,
in_out::{StanzaRead, StanzaWrite},
slicesubsequence::SliceSubsequence,
stanzafilter::{StanzaFilter, StanzaReader},
};
use anyhow::{bail, Result};
use die::Die;
use log::{error, info, trace};
use rustls::{ServerConfig, ServerConnection};
use std::sync::Arc;
use std::{net::SocketAddr, sync::Arc};
use tokio::{
io::{AsyncBufReadExt, AsyncWriteExt, BufStream},
net::TcpListener,
@ -25,9 +24,9 @@ pub fn tls_acceptor(server_config: ServerConfig) -> TlsAcceptor { @@ -25,9 +24,9 @@ pub fn tls_acceptor(server_config: ServerConfig) -> TlsAcceptor {
TlsAcceptor::from(Arc::new(server_config))
}
pub fn spawn_tls_listener(local_addr: SocketAddr, config: CloneableConfig, acceptor: TlsAcceptor) -> JoinHandle<Result<()>> {
pub fn spawn_tls_listener(listener: TcpListener, config: CloneableConfig, acceptor: TlsAcceptor) -> JoinHandle<Result<()>> {
tokio::spawn(async move {
let listener = TcpListener::bind(&local_addr).await.die("cannot listen on port/interface");
let local_addr = listener.local_addr()?;
loop {
let (stream, client_addr) = listener.accept().await?;
let config = config.clone();

Loading…
Cancel
Save