diff --git a/Cargo.toml b/Cargo.toml index b2fffae..c11e6e0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,10 +42,11 @@ tokio-rustls = { version = "0.23", optional = true } # outgoing deps lazy_static = { version = "1.4", optional = true } trust-dns-resolver = { version = "0.21", optional = true } +# todo: feature+code for dns-over-rustls #trust-dns-resolver = { version = "0.21", features = ["dns-over-rustls"], optional = true } -# todo: feature to swap between webpki-roots and rustls-native-certs webpki-roots = { version = "0.22", optional = true } rustls-native-certs = { version = "0.6", optional = true } +# todo: feed reqwest the roots we already have reqwest = { version = "0.11", optional = true, default-features = false, features = ["rustls-tls-native-roots", "json", "gzip", "trust-dns"] } # quic deps @@ -56,20 +57,39 @@ rustls = { version = "0.20.2", optional = true } rustls-pemfile = { version = "1.0.0", optional = true } # websocket deps +# todo: fix up the situation with these roots #tokio-tungstenite = { version = "0.17", optional = true, features = ["rustls-tls-webpki-roots"] } tokio-tungstenite = { version = "0.17", optional = true, features = ["rustls-tls-native-roots"] } futures-util = { version = "0.3", default-features = false, features = ["async-await", "sink", "std"], optional = true } [features] -default = ["incoming", "outgoing", "quic", "websocket", "logging"] -incoming = ["tokio-rustls", "rustls-pemfile", "rustls"] -outgoing = ["tokio-rustls", "trust-dns-resolver", "rustls-native-certs", "lazy_static", "rustls", "reqwest", "rustls-pemfile"] -quic = ["quinn", "rustls-pemfile", "rustls"] -websocket = ["tokio-tungstenite", "futures-util", "tokio-rustls", "rustls-pemfile", "rustls"] -logging = ["rand", "env_logger"] +default = ["c2s-incoming", "c2s-outgoing", "s2s-incoming", "s2s-outgoing", "tls", "quic", "websocket", "logging", "tls-ca-roots-native"] -[package.metadata.cargo-all-features] -skip_optional_dependencies = true +# you must pick one of these or the other, not both: todo: enable picking both and choosing at runtime +# don't need either of these if only doing c2s-incoming +tls-ca-roots-native = ["rustls-native-certs", "lazy_static", "tokio-rustls"] # this loads CA certs from your OS +tls-ca-roots-bundled = ["webpki-roots"] # this bundles CA certs in the binary + +# internal use only, ignore +srv = ["tokio-rustls", "trust-dns-resolver", "lazy_static", "reqwest"] +incoming = ["rustls-pemfile"] +outgoing = ["srv"] +c2s = [] +s2s = ["srv", "rustls-pemfile"] + +# you must pick one or more of these, you may pick them all +c2s-incoming = ["incoming", "c2s",] +c2s-outgoing = ["outgoing", "c2s"] + +s2s-incoming = ["incoming", "s2s"] +s2s-outgoing = ["outgoing", "s2s"] + +# protocols you want to support todo: split out tls vs starttls ? +tls = ["tokio-rustls", "rustls"] +quic = ["quinn", "rustls"] +websocket = ["tokio-tungstenite", "futures-util", "tls"] # websocket+incoming also enables incoming TLS support as it's free + +logging = ["rand", "env_logger"] [dev-dependencies] serde_json = "1.0" diff --git a/README.md b/README.md index 8b16451..db4be28 100644 --- a/README.md +++ b/README.md @@ -131,15 +131,38 @@ s2s_ports = {15268} If you are a grumpy power user who wants to build xmpp-proxy with exactly the features you want, nothing less, nothing more, this section is for you! -xmpp-proxy has 5 compile-time features: - 1. `incoming` - enables `incoming_listen` config option for reverse proxy STARTTLS/TLS - 2. `outgoing` - enables `outgoing_listen` config option for outgoing proxy STARTTLS/TLS - 3. `quic` - enables `quic_listen` config option for reverse proxy QUIC, and QUIC support for `outgoing` if it is enabled - 4. `websocket` - enables reverse proxy WebSocket on `incoming_listen`, and WebSocket support for `outgoing` if it is enabled - 5. `logging` - enables configurable logging +xmpp-proxy has multiple compile-time features, some of which are required, they are grouped as such: -So to build only supporting reverse proxy STARTTLS/TLS, no QUIC, run: `cargo build --release --no-default-features --features incoming` -To build a reverse proxy only, but supporting all of STARTTLS/TLS/QUIC, run: `cargo build --release --no-default-features --features incoming,quic` +choose between 1-4 directions: + 1. `c2s-incoming` - enables a server to accept incoming c2s connections + 2. `c2s-outgoing` - enables a client to make outgoing c2s connections + 3. `s2s-incoming` - enables a server to accept incoming s2s connections + 4. `s2s-outgoing` - enables a server to make outgoing s2s connections + +choose between 1-3 transport protocols: + 1. `tls` - enables STARTTLS/TLS support + 2. `quic` - enables QUIC support + 3. `websocket` - enables WebSocket support, also enables TLS incoming support if the appropriate directions are enabled + +choose exactly 1 of these methods to get trusted CA roots, not needed if only `c2s-incoming` is enabled: + 1. `tls-ca-roots-native` - reads CA roots from operating system + 2. `tls-ca-roots-bundled` - bundles CA roots into the binary from the `webpki-roots` project + +choose any of these optional features: + 1. `logging` - enables configurable logging + +So to build only supporting reverse proxy STARTTLS/TLS, no QUIC, run: `cargo build --release --no-default-features --features c2s-incoming,s2s-incoming,tls` +To build a reverse proxy only, but supporting all of STARTTLS/TLS/QUIC, run: `cargo build --release --no-default-features --features c2s-incoming,s2s-incoming,tls,quic` + +#### Development + +1. `check-all-features.sh` is used to check compilation with all supported feature permutations +2. `integration/test.sh` uses [Rootless podman](https://wiki.archlinux.org/title/Podman#Rootless_Podman) to run many tests + through xmpp-proxy on a real network with real dns, web, and xmpp servers, all of these should pass before pushing commits, + and write new tests to cover new functionality. +3. To submit code changes submit a PR on [github](https://github.com/moparisthebest/xmpp-proxy) or + [code.moparisthebest.com](https://code.moparisthebest.com/moparisthebest/xmpp-proxy) or send me a patch via email, + XMPP, fediverse, or carrier pigeon. #### License GNU/AGPLv3 - Check LICENSE.md for details diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..e4ff2e7 --- /dev/null +++ b/build.rs @@ -0,0 +1,83 @@ +use std::{env, fs::File, io::Write, path::Path}; + +fn main() { + println!("cargo:rerun-if-changed=build.rs"); + + let out_dir = env::var_os("OUT_DIR").unwrap(); + let dest_path = Path::new(&out_dir).join("version.rs"); + + let mut w = File::create(dest_path).unwrap(); + + let allowed_features = [ + "c2s-incoming", + "c2s-outgoing", + "s2s-incoming", + "s2s-outgoing", + "tls", + "quic", + "websocket", + "tls-ca-roots-native", + "tls-ca-roots-bundled", + ]; + let optional_deps = [ + "rustls", + "tokio-rustls", + "rustls-pemfile", + "quinn", + "tokio-tungstenite", + "futures-util", + "trust-dns-resolver", + "reqwest", + "lazy-static", + "rustls-native-certs", + "webpki-roots", + "env-logger", + "rand", + ]; + let mut features = Vec::new(); + let mut optional = Vec::new(); + for (mut key, value) in env::vars() { + //writeln!(&mut w, "{key}: {value}", ).unwrap(); + if value == "1" && key.starts_with("CARGO_FEATURE_") { + let mut key = key.split_off(14).replace("_", "-"); + key.make_ascii_lowercase(); + if allowed_features.contains(&key.as_str()) { + features.push(key); + } else if optional_deps.contains(&key.as_str()) { + optional.push(key); + } + } + } + features.sort_by(|a, b| { + allowed_features + .iter() + .position(|&r| r == a) + .unwrap() + .partial_cmp(&allowed_features.iter().position(|&r| r == b).unwrap()) + .unwrap() + }); + optional.sort_by(|a, b| { + optional_deps + .iter() + .position(|&r| r == a) + .unwrap() + .partial_cmp(&optional_deps.iter().position(|&r| r == b).unwrap()) + .unwrap() + }); + let features = features.join(","); + let optional = optional.join(","); + + let name = env!("CARGO_PKG_NAME"); + let version = env!("CARGO_PKG_VERSION"); + + let target = env::var("TARGET").unwrap(); + + writeln!( + &mut w, + "{{println!( +\"{name} {version} ({target}) +Features: {features} +Optional crates: {optional}\");}}" + ) + .unwrap(); +} diff --git a/check-all-features.sh b/check-all-features.sh new file mode 100644 index 0000000..06115c0 --- /dev/null +++ b/check-all-features.sh @@ -0,0 +1,59 @@ +#!/bin/bash +set -euo pipefail + +export RUSTFLAGS=-Awarnings + +show() { + local -a results=() + let idx=$2 + for (( j = 0; j < $1; j++ )); do + if (( idx % 2 )); then results=("${results[@]}" "${list[$j]}"); fi + let idx\>\>=1 + done + echo "${results[@]}" +} + +perm_lines() { + list=($@) + let n=${#list[@]} + for (( i = 1; i < 2**n; i++ )); do + show $n $i + done +} + +perms() { + perm_lines "$@" | tr ' ' ',' | sort -u | tr '\n' ' ' +} + +echo_cargo() { + #echo cargo run "$@" -- -v + #cargo run "$@" -- -v + echo cargo check "$@" + cargo check "$@" +} + +echo_cargo + +for optional in "" ",logging" +do + for proto in $(perms tls quic websocket) + do + for direction in $(perms c2s-incoming c2s-outgoing s2s-incoming s2s-outgoing) + do + for ca_roots in tls-ca-roots-native tls-ca-roots-bundled + do + echo_cargo --no-default-features --features $direction,$proto,$ca_roots$optional + done + done + done +done + +for optional in "" ",logging" +do + for proto in $(perms tls quic websocket) + do + echo_cargo --no-default-features --features c2s-incoming,$proto$optional + done +done + +echo good! diff --git a/src/in_out.rs b/src/in_out.rs index 7cde7e4..10ffb76 100644 --- a/src/in_out.rs +++ b/src/in_out.rs @@ -1,8 +1,8 @@ // Box, Box #[cfg(feature = "websocket")] -use crate::{from_ws, to_ws_new}; -use crate::{slicesubsequence::SliceSubsequence, trace, AsyncReadAndWrite, StanzaFilter, StanzaRead::*, StanzaReader, StanzaWrite::*}; +use crate::{from_ws, to_ws_new, AsyncReadAndWrite}; +use crate::{slicesubsequence::SliceSubsequence, trace, StanzaFilter, StanzaRead::*, StanzaReader, StanzaWrite::*}; use anyhow::{bail, Result}; #[cfg(feature = "websocket")] use futures_util::{ diff --git a/src/lib.rs b/src/lib.rs index d9493a8..d4cd79b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,6 +8,8 @@ use anyhow::bail; use std::net::SocketAddr; pub use log::{debug, error, info, log_enabled, trace}; + +#[cfg(feature = "s2s-incoming")] use rustls::{Certificate, ServerConnection}; pub fn to_str(buf: &[u8]) -> std::borrow::Cow<'_, str> { @@ -152,7 +154,10 @@ impl<'a> Context<'a> { } } -#[cfg(feature = "incoming")] +#[cfg(not(feature = "s2s-incoming"))] +pub type ServerCerts = (); + +#[cfg(feature = "s2s-incoming")] #[derive(Clone)] pub enum ServerCerts { Tls(&'static ServerConnection), @@ -160,10 +165,12 @@ pub enum ServerCerts { Quic(quinn::Connection), } +#[cfg(feature = "s2s-incoming")] impl ServerCerts { pub fn peer_certificates(&self) -> Option> { match self { ServerCerts::Tls(c) => c.peer_certificates().map(|c| c.to_vec()), + #[cfg(feature = "quic")] ServerCerts::Quic(c) => c.peer_identity().and_then(|v| v.downcast::>().ok()).map(|v| v.to_vec()), } } @@ -171,6 +178,7 @@ impl ServerCerts { pub fn sni(&self) -> Option { match self { ServerCerts::Tls(c) => c.sni_hostname().map(|s| s.to_string()), + #[cfg(feature = "quic")] ServerCerts::Quic(c) => c.handshake_data().and_then(|v| v.downcast::().ok()).and_then(|h| h.server_name), } } @@ -178,6 +186,7 @@ impl ServerCerts { pub fn alpn(&self) -> Option> { match self { ServerCerts::Tls(c) => c.alpn_protocol().map(|s| s.to_vec()), + #[cfg(feature = "quic")] ServerCerts::Quic(c) => c.handshake_data().and_then(|v| v.downcast::().ok()).and_then(|h| h.protocol), } } @@ -185,6 +194,7 @@ impl ServerCerts { pub fn is_tls(&self) -> bool { match self { ServerCerts::Tls(_) => true, + #[cfg(feature = "quic")] ServerCerts::Quic(_) => false, } } diff --git a/src/main.rs b/src/main.rs index c30f35e..0fa1768 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,9 +8,8 @@ use std::iter::Iterator; use std::net::SocketAddr; use std::path::Path; use std::sync::{Arc, RwLock}; -use std::time::SystemTime; -use die::Die; +use die::{die, Die}; use serde_derive::Deserialize; @@ -27,7 +26,7 @@ use rustls::{ #[cfg(feature = "tokio-rustls")] use tokio_rustls::{ webpki::{DnsNameRef, TlsServerTrustAnchors, TrustAnchor}, - TlsAcceptor, TlsConnector, + TlsConnector, }; use anyhow::{anyhow, bail, Result}; @@ -42,7 +41,9 @@ mod quic; #[cfg(feature = "quic")] use crate::quic::*; +#[cfg(feature = "tls")] mod tls; +#[cfg(feature = "tls")] use crate::tls::*; #[cfg(feature = "outgoing")] @@ -50,9 +51,9 @@ mod outgoing; #[cfg(feature = "outgoing")] use crate::outgoing::*; -#[cfg(feature = "outgoing")] +#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))] mod srv; -#[cfg(feature = "outgoing")] +#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))] use crate::srv::*; #[cfg(feature = "websocket")] @@ -60,7 +61,9 @@ mod websocket; #[cfg(feature = "websocket")] use crate::websocket::*; +#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))] mod verify; +#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))] use crate::verify::*; mod in_out; @@ -92,6 +95,7 @@ lazy_static::lazy_static! { }; } +#[cfg(any(feature = "rustls-native-certs", feature = "webpki-roots"))] pub fn root_cert_store() -> rustls::RootCertStore { use rustls::{OwnedTrustAnchor, RootCertStore}; let mut root_cert_store = RootCertStore::empty(); @@ -104,43 +108,45 @@ pub fn root_cert_store() -> rustls::RootCertStore { root_cert_store } -#[derive(Deserialize)] +#[derive(Deserialize, Default)] struct Config { tls_key: String, tls_cert: String, - incoming_listen: Option>, - quic_listen: Option>, - outgoing_listen: Option>, + incoming_listen: Vec, + quic_listen: Vec, + outgoing_listen: Vec, max_stanza_size_bytes: usize, - s2s_target: SocketAddr, - c2s_target: SocketAddr, + s2s_target: Option, + c2s_target: Option, proxy: bool, - #[cfg(feature = "logging")] log_level: Option, - #[cfg(feature = "logging")] log_style: Option, } #[derive(Clone)] pub struct CloneableConfig { max_stanza_size_bytes: usize, - s2s_target: SocketAddr, - c2s_target: SocketAddr, + #[cfg(feature = "s2s-incoming")] + s2s_target: Option, + #[cfg(feature = "c2s-incoming")] + c2s_target: Option, proxy: bool, } struct CertsKey { + #[cfg(feature = "rustls-pemfile")] inner: Result>>, } impl CertsKey { - fn new(cert_key: Result) -> Self { + fn new(main_config: &Config) -> Self { CertsKey { - inner: cert_key.map(|c| RwLock::new(Arc::new(c))), + #[cfg(feature = "rustls-pemfile")] + inner: main_config.certs_key().map(|c| RwLock::new(Arc::new(c))), } } - #[cfg(unix)] + #[cfg(all(unix, any(feature = "incoming", feature = "s2s-outgoing")))] fn spawn_refresh_task(&'static self, cfg_path: OsString) -> Option>> { if self.inner.is_err() { None @@ -169,12 +175,14 @@ impl CertsKey { } } +#[cfg(feature = "rustls-pemfile")] impl rustls::server::ResolvesServerCert for CertsKey { fn resolve(&self, _: rustls::server::ClientHello) -> Option> { self.inner.as_ref().map(|rwl| rwl.read().expect("CertKey poisoned?").clone()).ok() } } +#[cfg(feature = "rustls-pemfile")] impl rustls::client::ResolvesClientCert for CertsKey { fn resolve(&self, _: &[&[u8]], _: &[SignatureScheme]) -> Option> { self.inner.as_ref().map(|rwl| rwl.read().expect("CertKey poisoned?").clone()).ok() @@ -185,6 +193,17 @@ impl rustls::client::ResolvesClientCert for CertsKey { } } +#[cfg(not(feature = "rustls-pemfile"))] +impl rustls::client::ResolvesClientCert for CertsKey { + fn resolve(&self, _: &[&[u8]], _: &[SignatureScheme]) -> Option> { + None + } + + fn has_certs(&self) -> bool { + false + } +} + impl Config { fn parse>(path: P) -> Result { let mut f = File::open(path)?; @@ -196,7 +215,9 @@ impl Config { fn get_cloneable_cfg(&self) -> CloneableConfig { CloneableConfig { max_stanza_size_bytes: self.max_stanza_size_bytes, + #[cfg(feature = "s2s-incoming")] s2s_target: self.s2s_target, + #[cfg(feature = "c2s-incoming")] c2s_target: self.c2s_target, proxy: self.proxy, } @@ -204,6 +225,7 @@ impl Config { #[cfg(feature = "outgoing")] fn get_outgoing_cfg(&self, certs_key: Arc) -> OutgoingConfig { + #[cfg(feature = "rustls-pemfile")] if let Err(e) = &certs_key.inner { debug!("invalid key/cert for s2s client auth: {}", e); } @@ -243,21 +265,18 @@ impl Config { bail!("invalid cert/key: {}", e); } - let mut config = ServerConfig::builder() - .with_safe_defaults() - .with_client_cert_verifier(Arc::new(AllowAnonymousOrAnyCert)) - .with_cert_resolver(certs_key); + let config = ServerConfig::builder().with_safe_defaults(); + #[cfg(feature = "s2s")] + let config = config.with_client_cert_verifier(Arc::new(AllowAnonymousOrAnyCert)); + #[cfg(not(feature = "s2s"))] + let config = config.with_no_client_auth(); + let mut config = config.with_cert_resolver(certs_key); // todo: will connecting without alpn work then? config.alpn_protocols.push(ALPN_XMPP_CLIENT.to_vec()); config.alpn_protocols.push(ALPN_XMPP_SERVER.to_vec()); Ok(config) } - - #[cfg(feature = "incoming")] - fn tls_acceptor(&self, cert_key: Arc) -> Result { - Ok(TlsAcceptor::from(Arc::new(self.server_config(cert_key)?))) - } } #[derive(Clone)] @@ -310,11 +329,13 @@ pub struct OutgoingVerifierConfig { pub connector: TlsConnector, } +#[cfg(feature = "incoming")] async fn shuffle_rd_wr(in_rd: StanzaRead, in_wr: StanzaWrite, config: CloneableConfig, server_certs: ServerCerts, local_addr: SocketAddr, client_addr: &mut Context<'_>) -> Result<()> { let filter = StanzaFilter::new(config.max_stanza_size_bytes); shuffle_rd_wr_filter(in_rd, in_wr, config, server_certs, local_addr, client_addr, filter).await } +#[cfg(feature = "incoming")] async fn shuffle_rd_wr_filter( mut in_rd: StanzaRead, mut in_wr: StanzaWrite, @@ -328,26 +349,30 @@ async fn shuffle_rd_wr_filter( let (stream_open, is_c2s) = stream_preamble(&mut in_rd, &mut in_wr, client_addr.log_from(), &mut in_filter).await?; client_addr.set_c2s_stream_open(is_c2s, &stream_open); - trace!( - "{} connected: sni: {:?}, alpn: {:?}, tls-not-quic: {}", - client_addr.log_from(), - server_certs.sni(), - server_certs.alpn().map(|a| String::from_utf8_lossy(&a).to_string()), - server_certs.is_tls(), - ); + #[cfg(feature = "s2s-incoming")] + { + trace!( + "{} connected: sni: {:?}, alpn: {:?}, tls-not-quic: {}", + client_addr.log_from(), + server_certs.sni(), + server_certs.alpn().map(|a| String::from_utf8_lossy(&a).to_string()), + server_certs.is_tls(), + ); - if !is_c2s { - // for s2s we need this - let domain = stream_open - .extract_between(b" from='", b"'") - .or_else(|_| stream_open.extract_between(b" from=\"", b"\"")) - .and_then(|b| Ok(std::str::from_utf8(b)?))?; - let (_, cert_verifier) = get_xmpp_connections(domain, is_c2s).await?; - let certs = server_certs.peer_certificates().ok_or_else(|| anyhow!("no client cert auth for s2s incoming from {}", domain))?; - // todo: send stream error saying cert is invalid - cert_verifier.verify_cert(&certs[0], &certs[1..], SystemTime::now())?; + if !is_c2s { + // for s2s we need this + use std::time::SystemTime; + let domain = stream_open + .extract_between(b" from='", b"'") + .or_else(|_| stream_open.extract_between(b" from=\"", b"\"")) + .and_then(|b| Ok(std::str::from_utf8(b)?))?; + let (_, cert_verifier) = get_xmpp_connections(domain, is_c2s).await?; + let certs = server_certs.peer_certificates().ok_or_else(|| anyhow!("no client cert auth for s2s incoming from {}", domain))?; + // todo: send stream error saying cert is invalid + cert_verifier.verify_cert(&certs[0], &certs[1..], SystemTime::now())?; + } + drop(server_certs); } - drop(server_certs); let (out_rd, out_wr) = open_incoming(&config, local_addr, client_addr, &stream_open, is_c2s, &mut in_filter).await?; drop(stream_open); @@ -407,6 +432,7 @@ async fn shuffle_rd_wr_filter_only( Ok(()) } +#[cfg(feature = "incoming")] async fn open_incoming( config: &CloneableConfig, local_addr: SocketAddr, @@ -415,7 +441,18 @@ async fn open_incoming( is_c2s: bool, in_filter: &mut StanzaFilter, ) -> Result<(ReadHalf, WriteHalf)> { - let target = if is_c2s { config.c2s_target } else { config.s2s_target }; + let target = if is_c2s { + #[cfg(not(feature = "c2s-incoming"))] + bail!("incoming c2s connection but lacking compile-time support"); + #[cfg(feature = "c2s-incoming")] + config.c2s_target + } else { + #[cfg(not(feature = "s2s-incoming"))] + bail!("incoming s2s connection but lacking compile-time support"); + #[cfg(feature = "s2s-incoming")] + config.s2s_target + } + .ok_or_else(|| anyhow!("incoming connection but `{}_target` not defined", c2s(is_c2s)))?; client_addr.set_to_addr(target); let out_stream = tokio::net::TcpStream::connect(target).await?; @@ -468,7 +505,12 @@ pub async fn stream_preamble(in_rd: &mut StanzaRead, in_wr: &mut StanzaWrite, cl #[tokio::main] //#[tokio::main(flavor = "multi_thread", worker_threads = 10)] async fn main() { - let cfg_path = std::env::args_os().nth(1).unwrap_or_else(|| OsString::from("/etc/xmpp-proxy/xmpp-proxy.toml")); + let cfg_path = std::env::args_os().nth(1); + if cfg_path == Some(OsString::from("-v")) { + include!(concat!(env!("OUT_DIR"), "/version.rs")); + die!(0); + } + let cfg_path = cfg_path.unwrap_or_else(|| OsString::from("/etc/xmpp-proxy/xmpp-proxy.toml")); let main_config = Config::parse(&cfg_path).die("invalid config file"); #[cfg(feature = "logging")] @@ -486,34 +528,59 @@ async fn main() { // todo: config for this: builder.format_timestamp(None); builder.init(); } + #[cfg(not(feature = "logging"))] + if main_config.log_level.is_some() || main_config.log_style.is_some() { + die!("log_level or log_style defined in config but logging disabled at compile-time"); + } let config = main_config.get_cloneable_cfg(); - let certs_key = Arc::new(CertsKey::new(main_config.certs_key())); + let certs_key = Arc::new(CertsKey::new(&main_config)); let mut handles: Vec>> = Vec::new(); - #[cfg(feature = "incoming")] - if let Some(ref listeners) = main_config.incoming_listen { - let acceptor = main_config.tls_acceptor(certs_key.clone()).die("invalid cert/key ?"); - for listener in listeners { - handles.push(spawn_tls_listener(listener.parse().die("invalid listener address"), config.clone(), acceptor.clone())); + if !main_config.incoming_listen.is_empty() { + #[cfg(all(any(feature = "tls", feature = "websocket"), feature = "incoming"))] + { + if main_config.c2s_target.is_none() && main_config.s2s_target.is_none() { + die!("one of c2s_target/s2s_target must be defined if incoming_listen is non-empty"); + } + let acceptor = main_config.tls_acceptor(certs_key.clone()).die("invalid cert/key ?"); + for listener in main_config.incoming_listen.iter() { + handles.push(spawn_tls_listener(listener.parse().die("invalid listener address"), config.clone(), acceptor.clone())); + } } + #[cfg(not(all(any(feature = "tls", feature = "websocket"), feature = "incoming")))] + die!("incoming_listen non-empty but (tls or websocket) or (s2s-incoming and c2s-incoming) disabled at compile-time"); } - #[cfg(all(feature = "quic", feature = "incoming"))] - if let Some(ref listeners) = main_config.quic_listen { - let quic_config = main_config.quic_server_config(certs_key.clone()).die("invalid cert/key ?"); - for listener in listeners { - handles.push(spawn_quic_listener(listener.parse().die("invalid listener address"), config.clone(), quic_config.clone())); + if !main_config.quic_listen.is_empty() { + #[cfg(all(feature = "quic", feature = "incoming"))] + { + if main_config.c2s_target.is_none() && main_config.s2s_target.is_none() { + die!("one of c2s_target/s2s_target must be defined if quic_listen is non-empty"); + } + let quic_config = main_config.quic_server_config(certs_key.clone()).die("invalid cert/key ?"); + for listener in main_config.quic_listen.iter() { + handles.push(spawn_quic_listener(listener.parse().die("invalid listener address"), config.clone(), quic_config.clone())); + } } + #[cfg(not(all(feature = "quic", feature = "incoming")))] + die!("quic_listen non-empty but quic or (s2s-incoming and c2s-incoming) disabled at compile-time"); } - #[cfg(feature = "outgoing")] - if let Some(ref listeners) = main_config.outgoing_listen { - let outgoing_cfg = main_config.get_outgoing_cfg(certs_key.clone()); - for listener in listeners { - handles.push(spawn_outgoing_listener(listener.parse().die("invalid listener address"), outgoing_cfg.clone())); + if !main_config.outgoing_listen.is_empty() { + #[cfg(feature = "outgoing")] + { + let outgoing_cfg = main_config.get_outgoing_cfg(certs_key.clone()); + for listener in main_config.outgoing_listen.iter() { + handles.push(spawn_outgoing_listener(listener.parse().die("invalid listener address"), outgoing_cfg.clone())); + } } + #[cfg(not(feature = "outgoing"))] + die!("outgoing_listen non-empty but c2s-outgoing and s2s-outgoing disabled at compile-time"); } - #[cfg(unix)] + if handles.is_empty() { + die!("all of incoming_listen, quic_listen, outgoing_listen empty, nothing to do, exiting..."); + } + #[cfg(all(unix, any(feature = "incoming", feature = "s2s-outgoing")))] if let Some(refresh_task) = Box::leak(Box::new(certs_key.clone())).spawn_refresh_task(cfg_path) { handles.push(refresh_task); } diff --git a/src/outgoing.rs b/src/outgoing.rs index d7089cc..4c91216 100644 --- a/src/outgoing.rs +++ b/src/outgoing.rs @@ -5,15 +5,20 @@ async fn handle_outgoing_connection(stream: tokio::net::TcpStream, client_addr: let mut in_filter = StanzaFilter::new(config.max_stanza_size_bytes); - let is_ws = first_bytes_match(&stream, &mut in_filter.buf[0..3], |p| p == b"GET").await?; - - let (mut in_rd, mut in_wr) = if is_ws { + #[cfg(feature = "websocket")] + let (mut in_rd, mut in_wr) = if first_bytes_match(&stream, &mut in_filter.buf[0..3], |p| p == b"GET").await? { incoming_websocket_connection(Box::new(stream), config.max_stanza_size_bytes).await? } else { let (in_rd, in_wr) = tokio::io::split(stream); (StanzaRead::new(in_rd), StanzaWrite::new(in_wr)) }; + #[cfg(not(feature = "websocket"))] + let (mut in_rd, mut in_wr) = { + let (in_rd, in_wr) = tokio::io::split(stream); + (StanzaRead::new(in_rd), StanzaWrite::new(in_wr)) + }; + // now read to figure out client vs server let (stream_open, is_c2s) = stream_preamble(&mut in_rd, &mut in_wr, client_addr.log_to(), &mut in_filter).await?; client_addr.set_c2s_stream_open(is_c2s, &stream_open); diff --git a/src/quic.rs b/src/quic.rs index 74cd5f6..43b6852 100644 --- a/src/quic.rs +++ b/src/quic.rs @@ -44,7 +44,12 @@ pub fn spawn_quic_listener(local_addr: SocketAddr, config: CloneableConfig, serv tokio::spawn(async move { if let Ok(mut new_conn) = incoming_conn.await { let client_addr = crate::Context::new("quic-in", new_conn.connection.remote_address()); + + #[cfg(feature = "s2s-incoming")] let server_certs = ServerCerts::Quic(new_conn.connection); + #[cfg(not(feature = "s2s-incoming"))] + let server_certs = (); + info!("{} connected new connection", client_addr.log_from()); while let Some(Ok((wrt, rd))) = new_conn.bi_streams.next().await { diff --git a/src/srv.rs b/src/srv.rs index 31a1598..35301e8 100644 --- a/src/srv.rs +++ b/src/srv.rs @@ -29,8 +29,10 @@ fn make_resolver() -> TokioAsyncResolver { } #[derive(Clone, Debug, PartialEq, Eq)] -pub enum XmppConnectionType { +enum XmppConnectionType { + #[cfg(feature = "tls")] StartTLS, + #[cfg(feature = "tls")] DirectTLS, #[cfg(feature = "quic")] QUIC, @@ -41,9 +43,13 @@ pub enum XmppConnectionType { impl XmppConnectionType { fn idx(&self) -> u8 { match self { + #[cfg(feature = "quic")] XmppConnectionType::QUIC => 0, + #[cfg(feature = "tls")] XmppConnectionType::DirectTLS => 1, + #[cfg(feature = "tls")] XmppConnectionType::StartTLS => 2, + #[cfg(feature = "websocket")] XmppConnectionType::WebSocket(_, _) => 3, } } @@ -57,6 +63,7 @@ impl Ord for XmppConnectionType { } // so they are the same type, but WebSocket is a special case... match (self, other) { + #[cfg(feature = "websocket")] (XmppConnectionType::WebSocket(self_uri, self_origin), XmppConnectionType::WebSocket(other_uri, other_origin)) => { let cmp = self_uri.to_string().cmp(&other_uri.to_string()); if cmp != Ordering::Equal { @@ -153,6 +160,7 @@ fn sort_dedup(ret: &mut Vec) { } impl XmppConnection { + #[cfg(feature = "outgoing")] pub async fn connect( &self, domain: &str, @@ -175,10 +183,12 @@ impl XmppConnection { let to_addr = SocketAddr::new(*ip, self.port); debug!("{} trying ip {}", client_addr.log_from(), to_addr); match self.conn_type { + #[cfg(feature = "tls")] XmppConnectionType::StartTLS => match crate::starttls_connect(to_addr, domain, stream_open, in_filter, config.clone()).await { Ok((wr, rd)) => return Ok((wr, rd, to_addr, "starttls-out")), Err(e) => error!("starttls connection failed to IP {} from SRV {}, error: {}", to_addr, self.target, e), }, + #[cfg(feature = "tls")] XmppConnectionType::DirectTLS => match crate::tls_connect(to_addr, domain, config.clone()).await { Ok((wr, rd)) => return Ok((wr, rd, to_addr, "directtls-out")), Err(e) => error!("direct tls connection failed to IP {} from SRV {}, error: {}", to_addr, self.target, e), @@ -336,7 +346,9 @@ pub async fn get_xmpp_connections(domain: &str, is_c2s: bool) -> Result<(Vec Result<(Vec<Xmp if ret.is_empty() { // default starttls ports + #[cfg(feature = "tls")] ret.push(XmppConnection { priority: 0, weight: 0, @@ -369,6 +382,7 @@ pub async fn get_xmpp_connections(domain: &str, is_c2s: bool) -> Result<(Vec<Xmp ech: None, }); // by spec there are no default direct/quic ports, but we are going 443 + #[cfg(feature = "tls")] ret.push(XmppConnection { priority: 0, weight: 0, @@ -409,6 +423,7 @@ pub async fn get_xmpp_connections(domain: &str, is_c2s: bool) -> Result<(Vec<Xmp Ok((ret, cert_verifier)) } +#[cfg(feature = "outgoing")] pub async fn srv_connect( domain: &str, is_c2s: bool, @@ -417,6 +432,14 @@ pub async fn srv_connect( client_addr: &mut Context<'_>, config: OutgoingConfig, ) -> Result<(StanzaWrite, StanzaRead, Vec<u8>)> { + #[cfg(not(feature = "c2s-outgoing"))] + if is_c2s { + bail!("outgoing c2s connection but c2s-outgoing disabled at compile-time"); + } + #[cfg(not(feature = "s2s-outgoing"))] + if !is_c2s { + bail!("outgoing s2s connection but s2s-outgoing disabled at compile-time"); + } let (srvs, cert_verifier) = get_xmpp_connections(domain, is_c2s).await?; let config = config.with_custom_certificate_verifier(is_c2s, cert_verifier); for srv in srvs { @@ -448,7 +471,7 @@ pub async fn srv_connect( #[cfg(not(feature = "websocket"))] async fn collect_host_meta(ret: &mut Vec<XmppConnection>, sha256_pinnedpubkeys: &mut Vec<String>, domain: &str, is_c2s: bool) -> Result<Option<u16>> { - collect_host_meta_json(ret, sha256_pinnedpubkeys, domain, is_c2s) + collect_host_meta_json(ret, sha256_pinnedpubkeys, domain, is_c2s).await } #[cfg(feature = "websocket")] @@ -532,7 +555,7 @@ struct LinkCommon { } impl LinkCommon { - pub fn into_xmpp_connection(self, conn_type: XmppConnectionType, port: u16) -> Option<XmppConnection> { + fn into_xmpp_connection(self, conn_type: XmppConnectionType, port: u16) -> Option<XmppConnection> { if self.ips.is_empty() { error!("invalid empty ips"); return None; @@ -551,13 +574,18 @@ impl LinkCommon { } impl Link { - pub fn into_xmpp_connection(self, is_c2s: bool) -> Option<XmppConnection> { + fn into_xmpp_connection(self, is_c2s: bool) -> Option<XmppConnection> { use XmppConnectionType::*; let (srv_is_c2s, port, link, conn_type) = match self { + #[cfg(feature = "tls")] Link::DirectTLS { port, link } => (true, port, link, DirectTLS), + #[cfg(feature = "quic")] Link::Quic { port, link } => (true, port, link, QUIC), + #[cfg(feature = "tls")] Link::S2SDirectTLS { port, link } => (false, port, link, DirectTLS), + #[cfg(feature = "quic")] Link::S2SQuic { port, link } => (false, port, link, QUIC), + #[cfg(feature = "websocket")] Link::WebSocket { href, link } => { return if is_c2s { let srv = wss_to_srv(&href, true)?; @@ -570,6 +598,7 @@ impl Link { None }; } + #[cfg(feature = "websocket")] Link::S2SWebSocket { href, link } => { return if !is_c2s { let srv = wss_to_srv(&href, true)?; @@ -579,7 +608,7 @@ impl Link { }; } - Link::Unknown => return None, + _ => return None, }; if srv_is_c2s == is_c2s { @@ -591,7 +620,7 @@ impl Link { } impl HostMeta { - pub fn collect(self, ret: &mut Vec<XmppConnection>, sha256_pinnedpubkeys: &mut Vec<String>, is_c2s: bool) -> Option<u16> { + fn collect(self, ret: &mut Vec<XmppConnection>, sha256_pinnedpubkeys: &mut Vec<String>, is_c2s: bool) -> Option<u16> { for link in self.links { if let Some(srv) = link.into_xmpp_connection(is_c2s) { ret.push(srv); @@ -666,7 +695,7 @@ async fn collect_host_meta_xml(ret: &mut Vec<XmppConnection>, domain: &str, is_c } } -pub async fn https_get<T: reqwest::IntoUrl>(url: T) -> reqwest::Result<reqwest::Response> { +async fn https_get<T: reqwest::IntoUrl>(url: T) -> reqwest::Result<reqwest::Response> { // todo: resolve URL with our resolver reqwest::Client::builder().https_only(true).build()?.get(url).send().await } diff --git a/src/tls.rs b/src/tls.rs index b235340..6e40ffe 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -3,8 +3,7 @@ use rustls::ServerConnection; use std::convert::TryFrom; use tokio::io::{AsyncBufReadExt, BufStream}; -#[cfg(any(feature = "incoming", feature = "outgoing"))] -use tokio_rustls::rustls::ServerName; +use tokio_rustls::{rustls::ServerName, TlsAcceptor}; #[cfg(feature = "outgoing")] pub async fn tls_connect(target: SocketAddr, server_name: &str, config: OutgoingVerifierConfig) -> Result<(StanzaWrite, StanzaRead)> { @@ -59,6 +58,13 @@ pub async fn starttls_connect(target: SocketAddr, server_name: &str, stream_open Ok((StanzaWrite::new(wrt), StanzaRead::new(rd))) } +#[cfg(feature = "incoming")] +impl Config { + pub fn tls_acceptor(&self, cert_key: Arc<CertsKey>) -> Result<TlsAcceptor> { + Ok(TlsAcceptor::from(Arc::new(self.server_config(cert_key)?))) + } +} + #[cfg(feature = "incoming")] pub fn spawn_tls_listener(local_addr: SocketAddr, config: CloneableConfig, acceptor: TlsAcceptor) -> JoinHandle<Result<()>> { tokio::spawn(async move { @@ -159,8 +165,13 @@ async fn handle_tls_connection(mut stream: tokio::net::TcpStream, client_addr: & // where we read the first stanza, where we are guaranteed the handshake is complete, but I can't // do that without ignoring the lifetime and just pulling a C programmer and pinky promising to be // *very careful* that this reference doesn't outlive stream... - let server_connection: &'static ServerConnection = unsafe { std::mem::transmute(server_connection) }; - let server_certs = ServerCerts::Tls(server_connection); + #[cfg(feature = "s2s-incoming")] + let server_certs = { + let server_connection: &'static ServerConnection = unsafe { std::mem::transmute(server_connection) }; + ServerCerts::Tls(server_connection) + }; + #[cfg(not(feature = "s2s-incoming"))] + let server_certs = (); #[cfg(not(feature = "websocket"))] { diff --git a/src/websocket.rs b/src/websocket.rs index fafea4a..09ca3f5 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -29,6 +29,7 @@ pub async fn incoming_websocket_connection(stream: Box<dyn AsyncReadAndWrite + U Ok((StanzaRead::WebSocketRead(in_rd), StanzaWrite::WebSocketClientWrite(in_wr))) } +#[cfg(feature = "incoming")] pub async fn handle_websocket_connection( stream: Box<dyn AsyncReadAndWrite + Unpin + Send>, config: CloneableConfig,