Browse Source

Add and define features for conditional compilation

master
Travis Burtrum 2 months ago
parent
commit
e1b11e0537
  1. 38
      Cargo.toml
  2. 41
      README.md
  3. 83
      build.rs
  4. 59
      check-all-features.sh
  5. 4
      src/in_out.rs
  6. 12
      src/lib.rs
  7. 197
      src/main.rs
  8. 11
      src/outgoing.rs
  9. 5
      src/quic.rs
  10. 43
      src/srv.rs
  11. 19
      src/tls.rs
  12. 1
      src/websocket.rs

38
Cargo.toml

@ -42,10 +42,11 @@ tokio-rustls = { version = "0.23", optional = true } @@ -42,10 +42,11 @@ tokio-rustls = { version = "0.23", optional = true }
# outgoing deps
lazy_static = { version = "1.4", optional = true }
trust-dns-resolver = { version = "0.21", optional = true }
# todo: feature+code for dns-over-rustls
#trust-dns-resolver = { version = "0.21", features = ["dns-over-rustls"], optional = true }
# todo: feature to swap between webpki-roots and rustls-native-certs
webpki-roots = { version = "0.22", optional = true }
rustls-native-certs = { version = "0.6", optional = true }
# todo: feed reqwest the roots we already have
reqwest = { version = "0.11", optional = true, default-features = false, features = ["rustls-tls-native-roots", "json", "gzip", "trust-dns"] }
# quic deps
@ -56,20 +57,39 @@ rustls = { version = "0.20.2", optional = true } @@ -56,20 +57,39 @@ rustls = { version = "0.20.2", optional = true }
rustls-pemfile = { version = "1.0.0", optional = true }
# websocket deps
# todo: fix up the situation with these roots
#tokio-tungstenite = { version = "0.17", optional = true, features = ["rustls-tls-webpki-roots"] }
tokio-tungstenite = { version = "0.17", optional = true, features = ["rustls-tls-native-roots"] }
futures-util = { version = "0.3", default-features = false, features = ["async-await", "sink", "std"], optional = true }
[features]
default = ["incoming", "outgoing", "quic", "websocket", "logging"]
incoming = ["tokio-rustls", "rustls-pemfile", "rustls"]
outgoing = ["tokio-rustls", "trust-dns-resolver", "rustls-native-certs", "lazy_static", "rustls", "reqwest", "rustls-pemfile"]
quic = ["quinn", "rustls-pemfile", "rustls"]
websocket = ["tokio-tungstenite", "futures-util", "tokio-rustls", "rustls-pemfile", "rustls"]
logging = ["rand", "env_logger"]
default = ["c2s-incoming", "c2s-outgoing", "s2s-incoming", "s2s-outgoing", "tls", "quic", "websocket", "logging", "tls-ca-roots-native"]
# you must pick one of these or the other, not both: todo: enable picking both and choosing at runtime
# don't need either of these if only doing c2s-incoming
tls-ca-roots-native = ["rustls-native-certs", "lazy_static", "tokio-rustls"] # this loads CA certs from your OS
tls-ca-roots-bundled = ["webpki-roots"] # this bundles CA certs in the binary
# internal use only, ignore
srv = ["tokio-rustls", "trust-dns-resolver", "lazy_static", "reqwest"]
incoming = ["rustls-pemfile"]
outgoing = ["srv"]
c2s = []
s2s = ["srv", "rustls-pemfile"]
# you must pick one or more of these, you may pick them all
c2s-incoming = ["incoming", "c2s",]
c2s-outgoing = ["outgoing", "c2s"]
[package.metadata.cargo-all-features]
skip_optional_dependencies = true
s2s-incoming = ["incoming", "s2s"]
s2s-outgoing = ["outgoing", "s2s"]
# protocols you want to support todo: split out tls vs starttls ?
tls = ["tokio-rustls", "rustls"]
quic = ["quinn", "rustls"]
websocket = ["tokio-tungstenite", "futures-util", "tls"] # websocket+incoming also enables incoming TLS support as it's free
logging = ["rand", "env_logger"]
[dev-dependencies]
serde_json = "1.0"

41
README.md

@ -131,15 +131,38 @@ s2s_ports = {15268} @@ -131,15 +131,38 @@ s2s_ports = {15268}
If you are a grumpy power user who wants to build xmpp-proxy with exactly the features you want, nothing less, nothing
more, this section is for you!
xmpp-proxy has 5 compile-time features:
1. `incoming` - enables `incoming_listen` config option for reverse proxy STARTTLS/TLS
2. `outgoing` - enables `outgoing_listen` config option for outgoing proxy STARTTLS/TLS
3. `quic` - enables `quic_listen` config option for reverse proxy QUIC, and QUIC support for `outgoing` if it is enabled
4. `websocket` - enables reverse proxy WebSocket on `incoming_listen`, and WebSocket support for `outgoing` if it is enabled
5. `logging` - enables configurable logging
So to build only supporting reverse proxy STARTTLS/TLS, no QUIC, run: `cargo build --release --no-default-features --features incoming`
To build a reverse proxy only, but supporting all of STARTTLS/TLS/QUIC, run: `cargo build --release --no-default-features --features incoming,quic`
xmpp-proxy has multiple compile-time features, some of which are required, they are grouped as such:
choose between 1-4 directions:
1. `c2s-incoming` - enables a server to accept incoming c2s connections
2. `c2s-outgoing` - enables a client to make outgoing c2s connections
3. `s2s-incoming` - enables a server to accept incoming s2s connections
4. `s2s-outgoing` - enables a server to make outgoing s2s connections
choose between 1-3 transport protocols:
1. `tls` - enables STARTTLS/TLS support
2. `quic` - enables QUIC support
3. `websocket` - enables WebSocket support, also enables TLS incoming support if the appropriate directions are enabled
choose exactly 1 of these methods to get trusted CA roots, not needed if only `c2s-incoming` is enabled:
1. `tls-ca-roots-native` - reads CA roots from operating system
2. `tls-ca-roots-bundled` - bundles CA roots into the binary from the `webpki-roots` project
choose any of these optional features:
1. `logging` - enables configurable logging
So to build only supporting reverse proxy STARTTLS/TLS, no QUIC, run: `cargo build --release --no-default-features --features c2s-incoming,s2s-incoming,tls`
To build a reverse proxy only, but supporting all of STARTTLS/TLS/QUIC, run: `cargo build --release --no-default-features --features c2s-incoming,s2s-incoming,tls,quic`
#### Development
1. `check-all-features.sh` is used to check compilation with all supported feature permutations
2. `integration/test.sh` uses [Rootless podman](https://wiki.archlinux.org/title/Podman#Rootless_Podman) to run many tests
through xmpp-proxy on a real network with real dns, web, and xmpp servers, all of these should pass before pushing commits,
and write new tests to cover new functionality.
3. To submit code changes submit a PR on [github](https://github.com/moparisthebest/xmpp-proxy) or
[code.moparisthebest.com](https://code.moparisthebest.com/moparisthebest/xmpp-proxy) or send me a patch via email,
XMPP, fediverse, or carrier pigeon.
#### License
GNU/AGPLv3 - Check LICENSE.md for details

83
build.rs

@ -0,0 +1,83 @@ @@ -0,0 +1,83 @@
use std::{env, fs::File, io::Write, path::Path};
fn main() {
println!("cargo:rerun-if-changed=build.rs");
let out_dir = env::var_os("OUT_DIR").unwrap();
let dest_path = Path::new(&out_dir).join("version.rs");
let mut w = File::create(dest_path).unwrap();
let allowed_features = [
"c2s-incoming",
"c2s-outgoing",
"s2s-incoming",
"s2s-outgoing",
"tls",
"quic",
"websocket",
"tls-ca-roots-native",
"tls-ca-roots-bundled",
];
let optional_deps = [
"rustls",
"tokio-rustls",
"rustls-pemfile",
"quinn",
"tokio-tungstenite",
"futures-util",
"trust-dns-resolver",
"reqwest",
"lazy-static",
"rustls-native-certs",
"webpki-roots",
"env-logger",
"rand",
];
let mut features = Vec::new();
let mut optional = Vec::new();
for (mut key, value) in env::vars() {
//writeln!(&mut w, "{key}: {value}", ).unwrap();
if value == "1" && key.starts_with("CARGO_FEATURE_") {
let mut key = key.split_off(14).replace("_", "-");
key.make_ascii_lowercase();
if allowed_features.contains(&key.as_str()) {
features.push(key);
} else if optional_deps.contains(&key.as_str()) {
optional.push(key);
}
}
}
features.sort_by(|a, b| {
allowed_features
.iter()
.position(|&r| r == a)
.unwrap()
.partial_cmp(&allowed_features.iter().position(|&r| r == b).unwrap())
.unwrap()
});
optional.sort_by(|a, b| {
optional_deps
.iter()
.position(|&r| r == a)
.unwrap()
.partial_cmp(&optional_deps.iter().position(|&r| r == b).unwrap())
.unwrap()
});
let features = features.join(",");
let optional = optional.join(",");
let name = env!("CARGO_PKG_NAME");
let version = env!("CARGO_PKG_VERSION");
let target = env::var("TARGET").unwrap();
writeln!(
&mut w,
"{{println!(
\"{name} {version} ({target})
Features: {features}
Optional crates: {optional}\");}}"
)
.unwrap();
}

59
check-all-features.sh

@ -0,0 +1,59 @@ @@ -0,0 +1,59 @@
#!/bin/bash
set -euo pipefail
export RUSTFLAGS=-Awarnings
show() {
local -a results=()
let idx=$2
for (( j = 0; j < $1; j++ )); do
if (( idx % 2 )); then results=("${results[@]}" "${list[$j]}"); fi
let idx\>\>=1
done
echo "${results[@]}"
}
perm_lines() {
list=($@)
let n=${#list[@]}
for (( i = 1; i < 2**n; i++ )); do
show $n $i
done
}
perms() {
perm_lines "$@" | tr ' ' ',' | sort -u | tr '\n' ' '
}
echo_cargo() {
#echo cargo run "$@" -- -v
#cargo run "$@" -- -v
echo cargo check "$@"
cargo check "$@"
}
echo_cargo
for optional in "" ",logging"
do
for proto in $(perms tls quic websocket)
do
for direction in $(perms c2s-incoming c2s-outgoing s2s-incoming s2s-outgoing)
do
for ca_roots in tls-ca-roots-native tls-ca-roots-bundled
do
echo_cargo --no-default-features --features $direction,$proto,$ca_roots$optional
done
done
done
done
for optional in "" ",logging"
do
for proto in $(perms tls quic websocket)
do
echo_cargo --no-default-features --features c2s-incoming,$proto$optional
done
done
echo good!

4
src/in_out.rs

@ -1,8 +1,8 @@ @@ -1,8 +1,8 @@
// Box<dyn AsyncWrite + Unpin + Send>, Box<dyn AsyncRead + Unpin + Send>
#[cfg(feature = "websocket")]
use crate::{from_ws, to_ws_new};
use crate::{slicesubsequence::SliceSubsequence, trace, AsyncReadAndWrite, StanzaFilter, StanzaRead::*, StanzaReader, StanzaWrite::*};
use crate::{from_ws, to_ws_new, AsyncReadAndWrite};
use crate::{slicesubsequence::SliceSubsequence, trace, StanzaFilter, StanzaRead::*, StanzaReader, StanzaWrite::*};
use anyhow::{bail, Result};
#[cfg(feature = "websocket")]
use futures_util::{

12
src/lib.rs

@ -8,6 +8,8 @@ use anyhow::bail; @@ -8,6 +8,8 @@ use anyhow::bail;
use std::net::SocketAddr;
pub use log::{debug, error, info, log_enabled, trace};
#[cfg(feature = "s2s-incoming")]
use rustls::{Certificate, ServerConnection};
pub fn to_str(buf: &[u8]) -> std::borrow::Cow<'_, str> {
@ -152,7 +154,10 @@ impl<'a> Context<'a> { @@ -152,7 +154,10 @@ impl<'a> Context<'a> {
}
}
#[cfg(feature = "incoming")]
#[cfg(not(feature = "s2s-incoming"))]
pub type ServerCerts = ();
#[cfg(feature = "s2s-incoming")]
#[derive(Clone)]
pub enum ServerCerts {
Tls(&'static ServerConnection),
@ -160,10 +165,12 @@ pub enum ServerCerts { @@ -160,10 +165,12 @@ pub enum ServerCerts {
Quic(quinn::Connection),
}
#[cfg(feature = "s2s-incoming")]
impl ServerCerts {
pub fn peer_certificates(&self) -> Option<Vec<Certificate>> {
match self {
ServerCerts::Tls(c) => c.peer_certificates().map(|c| c.to_vec()),
#[cfg(feature = "quic")]
ServerCerts::Quic(c) => c.peer_identity().and_then(|v| v.downcast::<Vec<Certificate>>().ok()).map(|v| v.to_vec()),
}
}
@ -171,6 +178,7 @@ impl ServerCerts { @@ -171,6 +178,7 @@ impl ServerCerts {
pub fn sni(&self) -> Option<String> {
match self {
ServerCerts::Tls(c) => c.sni_hostname().map(|s| s.to_string()),
#[cfg(feature = "quic")]
ServerCerts::Quic(c) => c.handshake_data().and_then(|v| v.downcast::<quinn::crypto::rustls::HandshakeData>().ok()).and_then(|h| h.server_name),
}
}
@ -178,6 +186,7 @@ impl ServerCerts { @@ -178,6 +186,7 @@ impl ServerCerts {
pub fn alpn(&self) -> Option<Vec<u8>> {
match self {
ServerCerts::Tls(c) => c.alpn_protocol().map(|s| s.to_vec()),
#[cfg(feature = "quic")]
ServerCerts::Quic(c) => c.handshake_data().and_then(|v| v.downcast::<quinn::crypto::rustls::HandshakeData>().ok()).and_then(|h| h.protocol),
}
}
@ -185,6 +194,7 @@ impl ServerCerts { @@ -185,6 +194,7 @@ impl ServerCerts {
pub fn is_tls(&self) -> bool {
match self {
ServerCerts::Tls(_) => true,
#[cfg(feature = "quic")]
ServerCerts::Quic(_) => false,
}
}

197
src/main.rs

@ -8,9 +8,8 @@ use std::iter::Iterator; @@ -8,9 +8,8 @@ use std::iter::Iterator;
use std::net::SocketAddr;
use std::path::Path;
use std::sync::{Arc, RwLock};
use std::time::SystemTime;
use die::Die;
use die::{die, Die};
use serde_derive::Deserialize;
@ -27,7 +26,7 @@ use rustls::{ @@ -27,7 +26,7 @@ use rustls::{
#[cfg(feature = "tokio-rustls")]
use tokio_rustls::{
webpki::{DnsNameRef, TlsServerTrustAnchors, TrustAnchor},
TlsAcceptor, TlsConnector,
TlsConnector,
};
use anyhow::{anyhow, bail, Result};
@ -42,7 +41,9 @@ mod quic; @@ -42,7 +41,9 @@ mod quic;
#[cfg(feature = "quic")]
use crate::quic::*;
#[cfg(feature = "tls")]
mod tls;
#[cfg(feature = "tls")]
use crate::tls::*;
#[cfg(feature = "outgoing")]
@ -50,9 +51,9 @@ mod outgoing; @@ -50,9 +51,9 @@ mod outgoing;
#[cfg(feature = "outgoing")]
use crate::outgoing::*;
#[cfg(feature = "outgoing")]
#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))]
mod srv;
#[cfg(feature = "outgoing")]
#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))]
use crate::srv::*;
#[cfg(feature = "websocket")]
@ -60,7 +61,9 @@ mod websocket; @@ -60,7 +61,9 @@ mod websocket;
#[cfg(feature = "websocket")]
use crate::websocket::*;
#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))]
mod verify;
#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))]
use crate::verify::*;
mod in_out;
@ -92,6 +95,7 @@ lazy_static::lazy_static! { @@ -92,6 +95,7 @@ lazy_static::lazy_static! {
};
}
#[cfg(any(feature = "rustls-native-certs", feature = "webpki-roots"))]
pub fn root_cert_store() -> rustls::RootCertStore {
use rustls::{OwnedTrustAnchor, RootCertStore};
let mut root_cert_store = RootCertStore::empty();
@ -104,43 +108,45 @@ pub fn root_cert_store() -> rustls::RootCertStore { @@ -104,43 +108,45 @@ pub fn root_cert_store() -> rustls::RootCertStore {
root_cert_store
}
#[derive(Deserialize)]
#[derive(Deserialize, Default)]
struct Config {
tls_key: String,
tls_cert: String,
incoming_listen: Option<Vec<String>>,
quic_listen: Option<Vec<String>>,
outgoing_listen: Option<Vec<String>>,
incoming_listen: Vec<String>,
quic_listen: Vec<String>,
outgoing_listen: Vec<String>,
max_stanza_size_bytes: usize,
s2s_target: SocketAddr,
c2s_target: SocketAddr,
s2s_target: Option<SocketAddr>,
c2s_target: Option<SocketAddr>,
proxy: bool,
#[cfg(feature = "logging")]
log_level: Option<String>,
#[cfg(feature = "logging")]
log_style: Option<String>,
}
#[derive(Clone)]
pub struct CloneableConfig {
max_stanza_size_bytes: usize,
s2s_target: SocketAddr,
c2s_target: SocketAddr,
#[cfg(feature = "s2s-incoming")]
s2s_target: Option<SocketAddr>,
#[cfg(feature = "c2s-incoming")]
c2s_target: Option<SocketAddr>,
proxy: bool,
}
struct CertsKey {
#[cfg(feature = "rustls-pemfile")]
inner: Result<RwLock<Arc<rustls::sign::CertifiedKey>>>,
}
impl CertsKey {
fn new(cert_key: Result<rustls::sign::CertifiedKey>) -> Self {
fn new(main_config: &Config) -> Self {
CertsKey {
inner: cert_key.map(|c| RwLock::new(Arc::new(c))),
#[cfg(feature = "rustls-pemfile")]
inner: main_config.certs_key().map(|c| RwLock::new(Arc::new(c))),
}
}
#[cfg(unix)]
#[cfg(all(unix, any(feature = "incoming", feature = "s2s-outgoing")))]
fn spawn_refresh_task(&'static self, cfg_path: OsString) -> Option<JoinHandle<Result<()>>> {
if self.inner.is_err() {
None
@ -169,12 +175,14 @@ impl CertsKey { @@ -169,12 +175,14 @@ impl CertsKey {
}
}
#[cfg(feature = "rustls-pemfile")]
impl rustls::server::ResolvesServerCert for CertsKey {
fn resolve(&self, _: rustls::server::ClientHello) -> Option<Arc<rustls::sign::CertifiedKey>> {
self.inner.as_ref().map(|rwl| rwl.read().expect("CertKey poisoned?").clone()).ok()
}
}
#[cfg(feature = "rustls-pemfile")]
impl rustls::client::ResolvesClientCert for CertsKey {
fn resolve(&self, _: &[&[u8]], _: &[SignatureScheme]) -> Option<Arc<CertifiedKey>> {
self.inner.as_ref().map(|rwl| rwl.read().expect("CertKey poisoned?").clone()).ok()
@ -185,6 +193,17 @@ impl rustls::client::ResolvesClientCert for CertsKey { @@ -185,6 +193,17 @@ impl rustls::client::ResolvesClientCert for CertsKey {
}
}
#[cfg(not(feature = "rustls-pemfile"))]
impl rustls::client::ResolvesClientCert for CertsKey {
fn resolve(&self, _: &[&[u8]], _: &[SignatureScheme]) -> Option<Arc<CertifiedKey>> {
None
}
fn has_certs(&self) -> bool {
false
}
}
impl Config {
fn parse<P: AsRef<Path>>(path: P) -> Result<Config> {
let mut f = File::open(path)?;
@ -196,7 +215,9 @@ impl Config { @@ -196,7 +215,9 @@ impl Config {
fn get_cloneable_cfg(&self) -> CloneableConfig {
CloneableConfig {
max_stanza_size_bytes: self.max_stanza_size_bytes,
#[cfg(feature = "s2s-incoming")]
s2s_target: self.s2s_target,
#[cfg(feature = "c2s-incoming")]
c2s_target: self.c2s_target,
proxy: self.proxy,
}
@ -204,6 +225,7 @@ impl Config { @@ -204,6 +225,7 @@ impl Config {
#[cfg(feature = "outgoing")]
fn get_outgoing_cfg(&self, certs_key: Arc<CertsKey>) -> OutgoingConfig {
#[cfg(feature = "rustls-pemfile")]
if let Err(e) = &certs_key.inner {
debug!("invalid key/cert for s2s client auth: {}", e);
}
@ -243,21 +265,18 @@ impl Config { @@ -243,21 +265,18 @@ impl Config {
bail!("invalid cert/key: {}", e);
}
let mut config = ServerConfig::builder()
.with_safe_defaults()
.with_client_cert_verifier(Arc::new(AllowAnonymousOrAnyCert))
.with_cert_resolver(certs_key);
let config = ServerConfig::builder().with_safe_defaults();
#[cfg(feature = "s2s")]
let config = config.with_client_cert_verifier(Arc::new(AllowAnonymousOrAnyCert));
#[cfg(not(feature = "s2s"))]
let config = config.with_no_client_auth();
let mut config = config.with_cert_resolver(certs_key);
// todo: will connecting without alpn work then?
config.alpn_protocols.push(ALPN_XMPP_CLIENT.to_vec());
config.alpn_protocols.push(ALPN_XMPP_SERVER.to_vec());
Ok(config)
}
#[cfg(feature = "incoming")]
fn tls_acceptor(&self, cert_key: Arc<CertsKey>) -> Result<TlsAcceptor> {
Ok(TlsAcceptor::from(Arc::new(self.server_config(cert_key)?)))
}
}
#[derive(Clone)]
@ -310,11 +329,13 @@ pub struct OutgoingVerifierConfig { @@ -310,11 +329,13 @@ pub struct OutgoingVerifierConfig {
pub connector: TlsConnector,
}
#[cfg(feature = "incoming")]
async fn shuffle_rd_wr(in_rd: StanzaRead, in_wr: StanzaWrite, config: CloneableConfig, server_certs: ServerCerts, local_addr: SocketAddr, client_addr: &mut Context<'_>) -> Result<()> {
let filter = StanzaFilter::new(config.max_stanza_size_bytes);
shuffle_rd_wr_filter(in_rd, in_wr, config, server_certs, local_addr, client_addr, filter).await
}
#[cfg(feature = "incoming")]
async fn shuffle_rd_wr_filter(
mut in_rd: StanzaRead,
mut in_wr: StanzaWrite,
@ -328,26 +349,30 @@ async fn shuffle_rd_wr_filter( @@ -328,26 +349,30 @@ async fn shuffle_rd_wr_filter(
let (stream_open, is_c2s) = stream_preamble(&mut in_rd, &mut in_wr, client_addr.log_from(), &mut in_filter).await?;
client_addr.set_c2s_stream_open(is_c2s, &stream_open);
trace!(
"{} connected: sni: {:?}, alpn: {:?}, tls-not-quic: {}",
client_addr.log_from(),
server_certs.sni(),
server_certs.alpn().map(|a| String::from_utf8_lossy(&a).to_string()),
server_certs.is_tls(),
);
if !is_c2s {
// for s2s we need this
let domain = stream_open
.extract_between(b" from='", b"'")
.or_else(|_| stream_open.extract_between(b" from=\"", b"\""))
.and_then(|b| Ok(std::str::from_utf8(b)?))?;
let (_, cert_verifier) = get_xmpp_connections(domain, is_c2s).await?;
let certs = server_certs.peer_certificates().ok_or_else(|| anyhow!("no client cert auth for s2s incoming from {}", domain))?;
// todo: send stream error saying cert is invalid
cert_verifier.verify_cert(&certs[0], &certs[1..], SystemTime::now())?;
#[cfg(feature = "s2s-incoming")]
{
trace!(
"{} connected: sni: {:?}, alpn: {:?}, tls-not-quic: {}",
client_addr.log_from(),
server_certs.sni(),
server_certs.alpn().map(|a| String::from_utf8_lossy(&a).to_string()),
server_certs.is_tls(),
);
if !is_c2s {
// for s2s we need this
use std::time::SystemTime;
let domain = stream_open
.extract_between(b" from='", b"'")
.or_else(|_| stream_open.extract_between(b" from=\"", b"\""))
.and_then(|b| Ok(std::str::from_utf8(b)?))?;
let (_, cert_verifier) = get_xmpp_connections(domain, is_c2s).await?;
let certs = server_certs.peer_certificates().ok_or_else(|| anyhow!("no client cert auth for s2s incoming from {}", domain))?;
// todo: send stream error saying cert is invalid
cert_verifier.verify_cert(&certs[0], &certs[1..], SystemTime::now())?;
}
drop(server_certs);
}
drop(server_certs);
let (out_rd, out_wr) = open_incoming(&config, local_addr, client_addr, &stream_open, is_c2s, &mut in_filter).await?;
drop(stream_open);
@ -407,6 +432,7 @@ async fn shuffle_rd_wr_filter_only( @@ -407,6 +432,7 @@ async fn shuffle_rd_wr_filter_only(
Ok(())
}
#[cfg(feature = "incoming")]
async fn open_incoming(
config: &CloneableConfig,
local_addr: SocketAddr,
@ -415,7 +441,18 @@ async fn open_incoming( @@ -415,7 +441,18 @@ async fn open_incoming(
is_c2s: bool,
in_filter: &mut StanzaFilter,
) -> Result<(ReadHalf<tokio::net::TcpStream>, WriteHalf<tokio::net::TcpStream>)> {
let target = if is_c2s { config.c2s_target } else { config.s2s_target };
let target = if is_c2s {
#[cfg(not(feature = "c2s-incoming"))]
bail!("incoming c2s connection but lacking compile-time support");
#[cfg(feature = "c2s-incoming")]
config.c2s_target
} else {
#[cfg(not(feature = "s2s-incoming"))]
bail!("incoming s2s connection but lacking compile-time support");
#[cfg(feature = "s2s-incoming")]
config.s2s_target
}
.ok_or_else(|| anyhow!("incoming connection but `{}_target` not defined", c2s(is_c2s)))?;
client_addr.set_to_addr(target);
let out_stream = tokio::net::TcpStream::connect(target).await?;
@ -468,7 +505,12 @@ pub async fn stream_preamble(in_rd: &mut StanzaRead, in_wr: &mut StanzaWrite, cl @@ -468,7 +505,12 @@ pub async fn stream_preamble(in_rd: &mut StanzaRead, in_wr: &mut StanzaWrite, cl
#[tokio::main]
//#[tokio::main(flavor = "multi_thread", worker_threads = 10)]
async fn main() {
let cfg_path = std::env::args_os().nth(1).unwrap_or_else(|| OsString::from("/etc/xmpp-proxy/xmpp-proxy.toml"));
let cfg_path = std::env::args_os().nth(1);
if cfg_path == Some(OsString::from("-v")) {
include!(concat!(env!("OUT_DIR"), "/version.rs"));
die!(0);
}
let cfg_path = cfg_path.unwrap_or_else(|| OsString::from("/etc/xmpp-proxy/xmpp-proxy.toml"));
let main_config = Config::parse(&cfg_path).die("invalid config file");
#[cfg(feature = "logging")]
@ -486,34 +528,59 @@ async fn main() { @@ -486,34 +528,59 @@ async fn main() {
// todo: config for this: builder.format_timestamp(None);
builder.init();
}
#[cfg(not(feature = "logging"))]
if main_config.log_level.is_some() || main_config.log_style.is_some() {
die!("log_level or log_style defined in config but logging disabled at compile-time");
}
let config = main_config.get_cloneable_cfg();
let certs_key = Arc::new(CertsKey::new(main_config.certs_key()));
let certs_key = Arc::new(CertsKey::new(&main_config));
let mut handles: Vec<JoinHandle<Result<()>>> = Vec::new();
#[cfg(feature = "incoming")]
if let Some(ref listeners) = main_config.incoming_listen {
let acceptor = main_config.tls_acceptor(certs_key.clone()).die("invalid cert/key ?");
for listener in listeners {
handles.push(spawn_tls_listener(listener.parse().die("invalid listener address"), config.clone(), acceptor.clone()));
if !main_config.incoming_listen.is_empty() {
#[cfg(all(any(feature = "tls", feature = "websocket"), feature = "incoming"))]
{
if main_config.c2s_target.is_none() && main_config.s2s_target.is_none() {
die!("one of c2s_target/s2s_target must be defined if incoming_listen is non-empty");
}
let acceptor = main_config.tls_acceptor(certs_key.clone()).die("invalid cert/key ?");
for listener in main_config.incoming_listen.iter() {
handles.push(spawn_tls_listener(listener.parse().die("invalid listener address"), config.clone(), acceptor.clone()));
}
}
#[cfg(not(all(any(feature = "tls", feature = "websocket"), feature = "incoming")))]
die!("incoming_listen non-empty but (tls or websocket) or (s2s-incoming and c2s-incoming) disabled at compile-time");
}
#[cfg(all(feature = "quic", feature = "incoming"))]
if let Some(ref listeners) = main_config.quic_listen {
let quic_config = main_config.quic_server_config(certs_key.clone()).die("invalid cert/key ?");
for listener in listeners {
handles.push(spawn_quic_listener(listener.parse().die("invalid listener address"), config.clone(), quic_config.clone()));
if !main_config.quic_listen.is_empty() {
#[cfg(all(feature = "quic", feature = "incoming"))]
{
if main_config.c2s_target.is_none() && main_config.s2s_target.is_none() {
die!("one of c2s_target/s2s_target must be defined if quic_listen is non-empty");
}
let quic_config = main_config.quic_server_config(certs_key.clone()).die("invalid cert/key ?");
for listener in main_config.quic_listen.iter() {
handles.push(spawn_quic_listener(listener.parse().die("invalid listener address"), config.clone(), quic_config.clone()));
}
}
#[cfg(not(all(feature = "quic", feature = "incoming")))]
die!("quic_listen non-empty but quic or (s2s-incoming and c2s-incoming) disabled at compile-time");
}
#[cfg(feature = "outgoing")]
if let Some(ref listeners) = main_config.outgoing_listen {
let outgoing_cfg = main_config.get_outgoing_cfg(certs_key.clone());
for listener in listeners {
handles.push(spawn_outgoing_listener(listener.parse().die("invalid listener address"), outgoing_cfg.clone()));
if !main_config.outgoing_listen.is_empty() {
#[cfg(feature = "outgoing")]
{
let outgoing_cfg = main_config.get_outgoing_cfg(certs_key.clone());
for listener in main_config.outgoing_listen.iter() {
handles.push(spawn_outgoing_listener(listener.parse().die("invalid listener address"), outgoing_cfg.clone()));
}
}
#[cfg(not(feature = "outgoing"))]
die!("outgoing_listen non-empty but c2s-outgoing and s2s-outgoing disabled at compile-time");
}
if handles.is_empty() {
die!("all of incoming_listen, quic_listen, outgoing_listen empty, nothing to do, exiting...");
}
#[cfg(unix)]
#[cfg(all(unix, any(feature = "incoming", feature = "s2s-outgoing")))]
if let Some(refresh_task) = Box::leak(Box::new(certs_key.clone())).spawn_refresh_task(cfg_path) {
handles.push(refresh_task);
}

11
src/outgoing.rs

@ -5,15 +5,20 @@ async fn handle_outgoing_connection(stream: tokio::net::TcpStream, client_addr: @@ -5,15 +5,20 @@ async fn handle_outgoing_connection(stream: tokio::net::TcpStream, client_addr:
let mut in_filter = StanzaFilter::new(config.max_stanza_size_bytes);
let is_ws = first_bytes_match(&stream, &mut in_filter.buf[0..3], |p| p == b"GET").await?;
let (mut in_rd, mut in_wr) = if is_ws {
#[cfg(feature = "websocket")]
let (mut in_rd, mut in_wr) = if first_bytes_match(&stream, &mut in_filter.buf[0..3], |p| p == b"GET").await? {
incoming_websocket_connection(Box::new(stream), config.max_stanza_size_bytes).await?
} else {
let (in_rd, in_wr) = tokio::io::split(stream);
(StanzaRead::new(in_rd), StanzaWrite::new(in_wr))
};
#[cfg(not(feature = "websocket"))]
let (mut in_rd, mut in_wr) = {
let (in_rd, in_wr) = tokio::io::split(stream);
(StanzaRead::new(in_rd), StanzaWrite::new(in_wr))
};
// now read to figure out client vs server
let (stream_open, is_c2s) = stream_preamble(&mut in_rd, &mut in_wr, client_addr.log_to(), &mut in_filter).await?;
client_addr.set_c2s_stream_open(is_c2s, &stream_open);

5
src/quic.rs

@ -44,7 +44,12 @@ pub fn spawn_quic_listener(local_addr: SocketAddr, config: CloneableConfig, serv @@ -44,7 +44,12 @@ pub fn spawn_quic_listener(local_addr: SocketAddr, config: CloneableConfig, serv
tokio::spawn(async move {
if let Ok(mut new_conn) = incoming_conn.await {
let client_addr = crate::Context::new("quic-in", new_conn.connection.remote_address());
#[cfg(feature = "s2s-incoming")]
let server_certs = ServerCerts::Quic(new_conn.connection);
#[cfg(not(feature = "s2s-incoming"))]
let server_certs = ();
info!("{} connected new connection", client_addr.log_from());
while let Some(Ok((wrt, rd))) = new_conn.bi_streams.next().await {

43
src/srv.rs

@ -29,8 +29,10 @@ fn make_resolver() -> TokioAsyncResolver { @@ -29,8 +29,10 @@ fn make_resolver() -> TokioAsyncResolver {
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum XmppConnectionType {
enum XmppConnectionType {
#[cfg(feature = "tls")]
StartTLS,
#[cfg(feature = "tls")]
DirectTLS,
#[cfg(feature = "quic")]
QUIC,
@ -41,9 +43,13 @@ pub enum XmppConnectionType { @@ -41,9 +43,13 @@ pub enum XmppConnectionType {
impl XmppConnectionType {
fn idx(&self) -> u8 {
match self {
#[cfg(feature = "quic")]
XmppConnectionType::QUIC => 0,
#[cfg(feature = "tls")]
XmppConnectionType::DirectTLS => 1,
#[cfg(feature = "tls")]
XmppConnectionType::StartTLS => 2,
#[cfg(feature = "websocket")]
XmppConnectionType::WebSocket(_, _) => 3,
}
}
@ -57,6 +63,7 @@ impl Ord for XmppConnectionType { @@ -57,6 +63,7 @@ impl Ord for XmppConnectionType {
}
// so they are the same type, but WebSocket is a special case...
match (self, other) {
#[cfg(feature = "websocket")]
(XmppConnectionType::WebSocket(self_uri, self_origin), XmppConnectionType::WebSocket(other_uri, other_origin)) => {
let cmp = self_uri.to_string().cmp(&other_uri.to_string());
if cmp != Ordering::Equal {
@ -153,6 +160,7 @@ fn sort_dedup(ret: &mut Vec<XmppConnection>) { @@ -153,6 +160,7 @@ fn sort_dedup(ret: &mut Vec<XmppConnection>) {
}
impl XmppConnection {
#[cfg(feature = "outgoing")]
pub async fn connect(
&self,
domain: &str,
@ -175,10 +183,12 @@ impl XmppConnection { @@ -175,10 +183,12 @@ impl XmppConnection {
let to_addr = SocketAddr::new(*ip, self.port);
debug!("{} trying ip {}", client_addr.log_from(), to_addr);
match self.conn_type {
#[cfg(feature = "tls")]
XmppConnectionType::StartTLS => match crate::starttls_connect(to_addr, domain, stream_open, in_filter, config.clone()).await {
Ok((wr, rd)) => return Ok((wr, rd, to_addr, "starttls-out")),
Err(e) => error!("starttls connection failed to IP {} from SRV {}, error: {}", to_addr, self.target, e),
},
#[cfg(feature = "tls")]
XmppConnectionType::DirectTLS => match crate::tls_connect(to_addr, domain, config.clone()).await {
Ok((wr, rd)) => return Ok((wr, rd, to_addr, "directtls-out")),
Err(e) => error!("direct tls connection failed to IP {} from SRV {}, error: {}", to_addr, self.target, e),
@ -336,7 +346,9 @@ pub async fn get_xmpp_connections(domain: &str, is_c2s: bool) -> Result<(Vec<Xmp @@ -336,7 +346,9 @@ pub async fn get_xmpp_connections(domain: &str, is_c2s: bool) -> Result<(Vec<Xmp
// ignore everything else if new host-meta format
#[cfg(feature = "websocket")]
collect_txts(&mut ret, websocket_txt, is_c2s);
#[cfg(feature = "tls")]
collect_srvs(&mut ret, starttls, XmppConnectionType::StartTLS);
#[cfg(feature = "tls")]
collect_srvs(&mut ret, direct_tls, XmppConnectionType::DirectTLS);
#[cfg(feature = "quic")]
collect_srvs(&mut ret, quic, XmppConnectionType::QUIC);
@ -358,6 +370,7 @@ pub async fn get_xmpp_connections(domain: &str, is_c2s: bool) -> Result<(Vec<Xmp @@ -358,6 +370,7 @@ pub async fn get_xmpp_connections(domain: &str, is_c2s: bool) -> Result<(Vec<Xmp
if ret.is_empty() {
// default starttls ports
#[cfg(feature = "tls")]
ret.push(XmppConnection {
priority: 0,
weight: 0,
@ -369,6 +382,7 @@ pub async fn get_xmpp_connections(domain: &str, is_c2s: bool) -> Result<(Vec<Xmp @@ -369,6 +382,7 @@ pub async fn get_xmpp_connections(domain: &str, is_c2s: bool) -> Result<(Vec<Xmp
ech: None,
});
// by spec there are no default direct/quic ports, but we are going 443
#[cfg(feature = "tls")]
ret.push(XmppConnection {
priority: 0,
weight: 0,
@ -409,6 +423,7 @@ pub async fn get_xmpp_connections(domain: &str, is_c2s: bool) -> Result<(Vec<Xmp @@ -409,6 +423,7 @@ pub async fn get_xmpp_connections(domain: &str, is_c2s: bool) -> Result<(Vec<Xmp
Ok((ret, cert_verifier))
}
#[cfg(feature = "outgoing")]
pub async fn srv_connect(
domain: &str,
is_c2s: bool,
@ -417,6 +432,14 @@ pub async fn srv_connect( @@ -417,6 +432,14 @@ pub async fn srv_connect(
client_addr: &mut Context<'_>,
config: OutgoingConfig,
) -> Result<(StanzaWrite, StanzaRead, Vec<u8>)> {
#[cfg(not(feature = "c2s-outgoing"))]
if is_c2s {
bail!("outgoing c2s connection but c2s-outgoing disabled at compile-time");
}
#[cfg(not(feature = "s2s-outgoing"))]
if !is_c2s {
bail!("outgoing s2s connection but s2s-outgoing disabled at compile-time");
}
let (srvs, cert_verifier) = get_xmpp_connections(domain, is_c2s).await?;
let config = config.with_custom_certificate_verifier(is_c2s, cert_verifier);
for srv in srvs {
@ -448,7 +471,7 @@ pub async fn srv_connect( @@ -448,7 +471,7 @@ pub async fn srv_connect(
#[cfg(not(feature = "websocket"))]
async fn collect_host_meta(ret: &mut Vec<XmppConnection>, sha256_pinnedpubkeys: &mut Vec<String>, domain: &str, is_c2s: bool) -> Result<Option<u16>> {
collect_host_meta_json(ret, sha256_pinnedpubkeys, domain, is_c2s)
collect_host_meta_json(ret, sha256_pinnedpubkeys, domain, is_c2s).await
}
#[cfg(feature = "websocket")]
@ -532,7 +555,7 @@ struct LinkCommon { @@ -532,7 +555,7 @@ struct LinkCommon {
}
impl LinkCommon {
pub fn into_xmpp_connection(self, conn_type: XmppConnectionType, port: u16) -> Option<XmppConnection> {
fn into_xmpp_connection(self, conn_type: XmppConnectionType, port: u16) -> Option<XmppConnection> {
if self.ips.is_empty() {
error!("invalid empty ips");
return None;
@ -551,13 +574,18 @@ impl LinkCommon { @@ -551,13 +574,18 @@ impl LinkCommon {
}
impl Link {
pub fn into_xmpp_connection(self, is_c2s: bool) -> Option<XmppConnection> {
fn into_xmpp_connection(self, is_c2s: bool) -> Option<XmppConnection> {
use XmppConnectionType::*;
let (srv_is_c2s, port, link, conn_type) = match self {
#[cfg(feature = "tls")]
Link::DirectTLS { port, link } => (true, port, link, DirectTLS),
#[cfg(feature = "quic")]
Link::Quic { port, link } => (true, port, link, QUIC),
#[cfg(feature = "tls")]
Link::S2SDirectTLS { port, link } => (false, port, link, DirectTLS),
#[cfg(feature = "quic")]
Link::S2SQuic { port, link } => (false, port, link, QUIC),
#[cfg(feature = "websocket")]
Link::WebSocket { href, link } => {
return if is_c2s {
let srv = wss_to_srv(&href, true)?;
@ -570,6 +598,7 @@ impl Link { @@ -570,6 +598,7 @@ impl Link {
None
};
}
#[cfg(feature = "websocket")]
Link::S2SWebSocket { href, link } => {
return if !is_c2s {
let srv = wss_to_srv(&href, true)?;
@ -579,7 +608,7 @@ impl Link { @@ -579,7 +608,7 @@ impl Link {
};
}
Link::Unknown => return None,
_ => return None,
};
if srv_is_c2s == is_c2s {
@ -591,7 +620,7 @@ impl Link { @@ -591,7 +620,7 @@ impl Link {
}
impl HostMeta {
pub fn collect(self, ret: &mut Vec<XmppConnection>, sha256_pinnedpubkeys: &mut Vec<String>, is_c2s: bool) -> Option<u16> {
fn collect(self, ret: &mut Vec<XmppConnection>, sha256_pinnedpubkeys: &mut Vec<String>, is_c2s: bool) -> Option<u16> {
for link in self.links {
if let Some(srv) = link.into_xmpp_connection(is_c2s) {
ret.push(srv);
@ -666,7 +695,7 @@ async fn collect_host_meta_xml(ret: &mut Vec<XmppConnection>, domain: &str, is_c @@ -666,7 +695,7 @@ async fn collect_host_meta_xml(ret: &mut Vec<XmppConnection>, domain: &str, is_c
}
}
pub async fn https_get<T: reqwest::IntoUrl>(url: T) -> reqwest::Result<reqwest::Response> {
async fn https_get<T: reqwest::IntoUrl>(url: T) -> reqwest::Result<reqwest::Response> {
// todo: resolve URL with our resolver
reqwest::Client::builder().https_only(true).build()?.get(url).send().await
}

19
src/tls.rs

@ -3,8 +3,7 @@ use rustls::ServerConnection; @@ -3,8 +3,7 @@ use rustls::ServerConnection;
use std::convert::TryFrom;
use tokio::io::{AsyncBufReadExt, BufStream};
#[cfg(any(feature = "incoming", feature = "outgoing"))]
use tokio_rustls::rustls::ServerName;
use tokio_rustls::{rustls::ServerName, TlsAcceptor};
#[cfg(feature = "outgoing")]
pub async fn tls_connect(target: SocketAddr, server_name: &str, config: OutgoingVerifierConfig) -> Result<(StanzaWrite, StanzaRead)> {
@ -59,6 +58,13 @@ pub async fn starttls_connect(target: SocketAddr, server_name: &str, stream_open @@ -59,6 +58,13 @@ pub async fn starttls_connect(target: SocketAddr, server_name: &str, stream_open
Ok((StanzaWrite::new(wrt), StanzaRead::new(rd)))
}
#[cfg(feature = "incoming")]
impl Config {
pub fn tls_acceptor(&self, cert_key: Arc<CertsKey>) -> Result<TlsAcceptor> {
Ok(TlsAcceptor::from(Arc::new(self.server_config(cert_key)?)))
}
}
#[cfg(feature = "incoming")]
pub fn spawn_tls_listener(local_addr: SocketAddr, config: CloneableConfig, acceptor: TlsAcceptor) -> JoinHandle<Result<()>> {
tokio::spawn(async move {
@ -159,8 +165,13 @@ async fn handle_tls_connection(mut stream: tokio::net::TcpStream, client_addr: & @@ -159,8 +165,13 @@ async fn handle_tls_connection(mut stream: tokio::net::TcpStream, client_addr: &
// where we read the first stanza, where we are guaranteed the handshake is complete, but I can't
// do that without ignoring the lifetime and just pulling a C programmer and pinky promising to be
// *very careful* that this reference doesn't outlive stream...
let server_connection: &'static ServerConnection = unsafe { std::mem::transmute(server_connection) };
let server_certs = ServerCerts::Tls(server_connection);
#[cfg(feature = "s2s-incoming")]
let server_certs = {
let server_connection: &'static ServerConnection = unsafe { std::mem::transmute(server_connection) };
ServerCerts::Tls(server_connection)
};
#[cfg(not(feature = "s2s-incoming"))]
let server_certs = ();
#[cfg(not(feature = "websocket"))]
{

1
src/websocket.rs

@ -29,6 +29,7 @@ pub async fn incoming_websocket_connection(stream: Box<dyn AsyncReadAndWrite + U @@ -29,6 +29,7 @@ pub async fn incoming_websocket_connection(stream: Box<dyn AsyncReadAndWrite + U
Ok((StanzaRead::WebSocketRead(in_rd), StanzaWrite::WebSocketClientWrite(in_wr)))
}
#[cfg(feature = "incoming")]
pub async fn handle_websocket_connection(
stream: Box<dyn AsyncReadAndWrite + Unpin + Send>,
config: CloneableConfig,

Loading…
Cancel
Save