Add and define features for conditional compilation

This commit is contained in:
Travis Burtrum 2022-07-16 16:27:41 -04:00
parent ca4dce14fd
commit 455f833879
12 changed files with 411 additions and 98 deletions

View File

@ -42,10 +42,11 @@ tokio-rustls = { version = "0.23", optional = true }
# outgoing deps # outgoing deps
lazy_static = { version = "1.4", optional = true } lazy_static = { version = "1.4", optional = true }
trust-dns-resolver = { version = "0.21", optional = true } trust-dns-resolver = { version = "0.21", optional = true }
# todo: feature+code for dns-over-rustls
#trust-dns-resolver = { version = "0.21", features = ["dns-over-rustls"], optional = true } #trust-dns-resolver = { version = "0.21", features = ["dns-over-rustls"], optional = true }
# todo: feature to swap between webpki-roots and rustls-native-certs
webpki-roots = { version = "0.22", optional = true } webpki-roots = { version = "0.22", optional = true }
rustls-native-certs = { version = "0.6", optional = true } rustls-native-certs = { version = "0.6", optional = true }
# todo: feed reqwest the roots we already have
reqwest = { version = "0.11", optional = true, default-features = false, features = ["rustls-tls-native-roots", "json", "gzip", "trust-dns"] } reqwest = { version = "0.11", optional = true, default-features = false, features = ["rustls-tls-native-roots", "json", "gzip", "trust-dns"] }
# quic deps # quic deps
@ -56,20 +57,39 @@ rustls = { version = "0.20.2", optional = true }
rustls-pemfile = { version = "1.0.0", optional = true } rustls-pemfile = { version = "1.0.0", optional = true }
# websocket deps # websocket deps
# todo: fix up the situation with these roots
#tokio-tungstenite = { version = "0.17", optional = true, features = ["rustls-tls-webpki-roots"] } #tokio-tungstenite = { version = "0.17", optional = true, features = ["rustls-tls-webpki-roots"] }
tokio-tungstenite = { version = "0.17", optional = true, features = ["rustls-tls-native-roots"] } tokio-tungstenite = { version = "0.17", optional = true, features = ["rustls-tls-native-roots"] }
futures-util = { version = "0.3", default-features = false, features = ["async-await", "sink", "std"], optional = true } futures-util = { version = "0.3", default-features = false, features = ["async-await", "sink", "std"], optional = true }
[features] [features]
default = ["incoming", "outgoing", "quic", "websocket", "logging"] default = ["c2s-incoming", "c2s-outgoing", "s2s-incoming", "s2s-outgoing", "tls", "quic", "websocket", "logging", "tls-ca-roots-native"]
incoming = ["tokio-rustls", "rustls-pemfile", "rustls"]
outgoing = ["tokio-rustls", "trust-dns-resolver", "rustls-native-certs", "lazy_static", "rustls", "reqwest", "rustls-pemfile"]
quic = ["quinn", "rustls-pemfile", "rustls"]
websocket = ["tokio-tungstenite", "futures-util", "tokio-rustls", "rustls-pemfile", "rustls"]
logging = ["rand", "env_logger"]
[package.metadata.cargo-all-features] # you must pick one of these or the other, not both: todo: enable picking both and choosing at runtime
skip_optional_dependencies = true # don't need either of these if only doing c2s-incoming
tls-ca-roots-native = ["rustls-native-certs", "lazy_static", "tokio-rustls"] # this loads CA certs from your OS
tls-ca-roots-bundled = ["webpki-roots"] # this bundles CA certs in the binary
# internal use only, ignore
srv = ["tokio-rustls", "trust-dns-resolver", "lazy_static", "reqwest"]
incoming = ["rustls-pemfile"]
outgoing = ["srv"]
c2s = []
s2s = ["srv", "rustls-pemfile"]
# you must pick one or more of these, you may pick them all
c2s-incoming = ["incoming", "c2s",]
c2s-outgoing = ["outgoing", "c2s"]
s2s-incoming = ["incoming", "s2s"]
s2s-outgoing = ["outgoing", "s2s"]
# protocols you want to support todo: split out tls vs starttls ?
tls = ["tokio-rustls", "rustls"]
quic = ["quinn", "rustls"]
websocket = ["tokio-tungstenite", "futures-util", "tls"] # websocket+incoming also enables incoming TLS support as it's free
logging = ["rand", "env_logger"]
[dev-dependencies] [dev-dependencies]
serde_json = "1.0" serde_json = "1.0"

View File

@ -131,15 +131,38 @@ s2s_ports = {15268}
If you are a grumpy power user who wants to build xmpp-proxy with exactly the features you want, nothing less, nothing If you are a grumpy power user who wants to build xmpp-proxy with exactly the features you want, nothing less, nothing
more, this section is for you! more, this section is for you!
xmpp-proxy has 5 compile-time features: xmpp-proxy has multiple compile-time features, some of which are required, they are grouped as such:
1. `incoming` - enables `incoming_listen` config option for reverse proxy STARTTLS/TLS
2. `outgoing` - enables `outgoing_listen` config option for outgoing proxy STARTTLS/TLS
3. `quic` - enables `quic_listen` config option for reverse proxy QUIC, and QUIC support for `outgoing` if it is enabled
4. `websocket` - enables reverse proxy WebSocket on `incoming_listen`, and WebSocket support for `outgoing` if it is enabled
5. `logging` - enables configurable logging
So to build only supporting reverse proxy STARTTLS/TLS, no QUIC, run: `cargo build --release --no-default-features --features incoming` choose between 1-4 directions:
To build a reverse proxy only, but supporting all of STARTTLS/TLS/QUIC, run: `cargo build --release --no-default-features --features incoming,quic` 1. `c2s-incoming` - enables a server to accept incoming c2s connections
2. `c2s-outgoing` - enables a client to make outgoing c2s connections
3. `s2s-incoming` - enables a server to accept incoming s2s connections
4. `s2s-outgoing` - enables a server to make outgoing s2s connections
choose between 1-3 transport protocols:
1. `tls` - enables STARTTLS/TLS support
2. `quic` - enables QUIC support
3. `websocket` - enables WebSocket support, also enables TLS incoming support if the appropriate directions are enabled
choose exactly 1 of these methods to get trusted CA roots, not needed if only `c2s-incoming` is enabled:
1. `tls-ca-roots-native` - reads CA roots from operating system
2. `tls-ca-roots-bundled` - bundles CA roots into the binary from the `webpki-roots` project
choose any of these optional features:
1. `logging` - enables configurable logging
So to build only supporting reverse proxy STARTTLS/TLS, no QUIC, run: `cargo build --release --no-default-features --features c2s-incoming,s2s-incoming,tls`
To build a reverse proxy only, but supporting all of STARTTLS/TLS/QUIC, run: `cargo build --release --no-default-features --features c2s-incoming,s2s-incoming,tls,quic`
#### Development
1. `check-all-features.sh` is used to check compilation with all supported feature permutations
2. `integration/test.sh` uses [Rootless podman](https://wiki.archlinux.org/title/Podman#Rootless_Podman) to run many tests
through xmpp-proxy on a real network with real dns, web, and xmpp servers, all of these should pass before pushing commits,
and write new tests to cover new functionality.
3. To submit code changes submit a PR on [github](https://github.com/moparisthebest/xmpp-proxy) or
[code.moparisthebest.com](https://code.moparisthebest.com/moparisthebest/xmpp-proxy) or send me a patch via email,
XMPP, fediverse, or carrier pigeon.
#### License #### License
GNU/AGPLv3 - Check LICENSE.md for details GNU/AGPLv3 - Check LICENSE.md for details

83
build.rs Normal file
View File

@ -0,0 +1,83 @@
use std::{env, fs::File, io::Write, path::Path};
fn main() {
println!("cargo:rerun-if-changed=build.rs");
let out_dir = env::var_os("OUT_DIR").unwrap();
let dest_path = Path::new(&out_dir).join("version.rs");
let mut w = File::create(dest_path).unwrap();
let allowed_features = [
"c2s-incoming",
"c2s-outgoing",
"s2s-incoming",
"s2s-outgoing",
"tls",
"quic",
"websocket",
"tls-ca-roots-native",
"tls-ca-roots-bundled",
];
let optional_deps = [
"rustls",
"tokio-rustls",
"rustls-pemfile",
"quinn",
"tokio-tungstenite",
"futures-util",
"trust-dns-resolver",
"reqwest",
"lazy-static",
"rustls-native-certs",
"webpki-roots",
"env-logger",
"rand",
];
let mut features = Vec::new();
let mut optional = Vec::new();
for (mut key, value) in env::vars() {
//writeln!(&mut w, "{key}: {value}", ).unwrap();
if value == "1" && key.starts_with("CARGO_FEATURE_") {
let mut key = key.split_off(14).replace("_", "-");
key.make_ascii_lowercase();
if allowed_features.contains(&key.as_str()) {
features.push(key);
} else if optional_deps.contains(&key.as_str()) {
optional.push(key);
}
}
}
features.sort_by(|a, b| {
allowed_features
.iter()
.position(|&r| r == a)
.unwrap()
.partial_cmp(&allowed_features.iter().position(|&r| r == b).unwrap())
.unwrap()
});
optional.sort_by(|a, b| {
optional_deps
.iter()
.position(|&r| r == a)
.unwrap()
.partial_cmp(&optional_deps.iter().position(|&r| r == b).unwrap())
.unwrap()
});
let features = features.join(",");
let optional = optional.join(",");
let name = env!("CARGO_PKG_NAME");
let version = env!("CARGO_PKG_VERSION");
let target = env::var("TARGET").unwrap();
writeln!(
&mut w,
"{{println!(
\"{name} {version} ({target})
Features: {features}
Optional crates: {optional}\");}}"
)
.unwrap();
}

59
check-all-features.sh Normal file
View File

@ -0,0 +1,59 @@
#!/bin/bash
set -euo pipefail
export RUSTFLAGS=-Awarnings
show() {
local -a results=()
let idx=$2
for (( j = 0; j < $1; j++ )); do
if (( idx % 2 )); then results=("${results[@]}" "${list[$j]}"); fi
let idx\>\>=1
done
echo "${results[@]}"
}
perm_lines() {
list=($@)
let n=${#list[@]}
for (( i = 1; i < 2**n; i++ )); do
show $n $i
done
}
perms() {
perm_lines "$@" | tr ' ' ',' | sort -u | tr '\n' ' '
}
echo_cargo() {
#echo cargo run "$@" -- -v
#cargo run "$@" -- -v
echo cargo check "$@"
cargo check "$@"
}
echo_cargo
for optional in "" ",logging"
do
for proto in $(perms tls quic websocket)
do
for direction in $(perms c2s-incoming c2s-outgoing s2s-incoming s2s-outgoing)
do
for ca_roots in tls-ca-roots-native tls-ca-roots-bundled
do
echo_cargo --no-default-features --features $direction,$proto,$ca_roots$optional
done
done
done
done
for optional in "" ",logging"
do
for proto in $(perms tls quic websocket)
do
echo_cargo --no-default-features --features c2s-incoming,$proto$optional
done
done
echo good!

View File

@ -1,8 +1,8 @@
// Box<dyn AsyncWrite + Unpin + Send>, Box<dyn AsyncRead + Unpin + Send> // Box<dyn AsyncWrite + Unpin + Send>, Box<dyn AsyncRead + Unpin + Send>
#[cfg(feature = "websocket")] #[cfg(feature = "websocket")]
use crate::{from_ws, to_ws_new}; use crate::{from_ws, to_ws_new, AsyncReadAndWrite};
use crate::{slicesubsequence::SliceSubsequence, trace, AsyncReadAndWrite, StanzaFilter, StanzaRead::*, StanzaReader, StanzaWrite::*}; use crate::{slicesubsequence::SliceSubsequence, trace, StanzaFilter, StanzaRead::*, StanzaReader, StanzaWrite::*};
use anyhow::{bail, Result}; use anyhow::{bail, Result};
#[cfg(feature = "websocket")] #[cfg(feature = "websocket")]
use futures_util::{ use futures_util::{

View File

@ -8,6 +8,8 @@ use anyhow::bail;
use std::net::SocketAddr; use std::net::SocketAddr;
pub use log::{debug, error, info, log_enabled, trace}; pub use log::{debug, error, info, log_enabled, trace};
#[cfg(feature = "s2s-incoming")]
use rustls::{Certificate, ServerConnection}; use rustls::{Certificate, ServerConnection};
pub fn to_str(buf: &[u8]) -> std::borrow::Cow<'_, str> { pub fn to_str(buf: &[u8]) -> std::borrow::Cow<'_, str> {
@ -152,7 +154,10 @@ impl<'a> Context<'a> {
} }
} }
#[cfg(feature = "incoming")] #[cfg(not(feature = "s2s-incoming"))]
pub type ServerCerts = ();
#[cfg(feature = "s2s-incoming")]
#[derive(Clone)] #[derive(Clone)]
pub enum ServerCerts { pub enum ServerCerts {
Tls(&'static ServerConnection), Tls(&'static ServerConnection),
@ -160,10 +165,12 @@ pub enum ServerCerts {
Quic(quinn::Connection), Quic(quinn::Connection),
} }
#[cfg(feature = "s2s-incoming")]
impl ServerCerts { impl ServerCerts {
pub fn peer_certificates(&self) -> Option<Vec<Certificate>> { pub fn peer_certificates(&self) -> Option<Vec<Certificate>> {
match self { match self {
ServerCerts::Tls(c) => c.peer_certificates().map(|c| c.to_vec()), ServerCerts::Tls(c) => c.peer_certificates().map(|c| c.to_vec()),
#[cfg(feature = "quic")]
ServerCerts::Quic(c) => c.peer_identity().and_then(|v| v.downcast::<Vec<Certificate>>().ok()).map(|v| v.to_vec()), ServerCerts::Quic(c) => c.peer_identity().and_then(|v| v.downcast::<Vec<Certificate>>().ok()).map(|v| v.to_vec()),
} }
} }
@ -171,6 +178,7 @@ impl ServerCerts {
pub fn sni(&self) -> Option<String> { pub fn sni(&self) -> Option<String> {
match self { match self {
ServerCerts::Tls(c) => c.sni_hostname().map(|s| s.to_string()), ServerCerts::Tls(c) => c.sni_hostname().map(|s| s.to_string()),
#[cfg(feature = "quic")]
ServerCerts::Quic(c) => c.handshake_data().and_then(|v| v.downcast::<quinn::crypto::rustls::HandshakeData>().ok()).and_then(|h| h.server_name), ServerCerts::Quic(c) => c.handshake_data().and_then(|v| v.downcast::<quinn::crypto::rustls::HandshakeData>().ok()).and_then(|h| h.server_name),
} }
} }
@ -178,6 +186,7 @@ impl ServerCerts {
pub fn alpn(&self) -> Option<Vec<u8>> { pub fn alpn(&self) -> Option<Vec<u8>> {
match self { match self {
ServerCerts::Tls(c) => c.alpn_protocol().map(|s| s.to_vec()), ServerCerts::Tls(c) => c.alpn_protocol().map(|s| s.to_vec()),
#[cfg(feature = "quic")]
ServerCerts::Quic(c) => c.handshake_data().and_then(|v| v.downcast::<quinn::crypto::rustls::HandshakeData>().ok()).and_then(|h| h.protocol), ServerCerts::Quic(c) => c.handshake_data().and_then(|v| v.downcast::<quinn::crypto::rustls::HandshakeData>().ok()).and_then(|h| h.protocol),
} }
} }
@ -185,6 +194,7 @@ impl ServerCerts {
pub fn is_tls(&self) -> bool { pub fn is_tls(&self) -> bool {
match self { match self {
ServerCerts::Tls(_) => true, ServerCerts::Tls(_) => true,
#[cfg(feature = "quic")]
ServerCerts::Quic(_) => false, ServerCerts::Quic(_) => false,
} }
} }

View File

@ -8,9 +8,8 @@ use std::iter::Iterator;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::path::Path; use std::path::Path;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use std::time::SystemTime;
use die::Die; use die::{die, Die};
use serde_derive::Deserialize; use serde_derive::Deserialize;
@ -27,7 +26,7 @@ use rustls::{
#[cfg(feature = "tokio-rustls")] #[cfg(feature = "tokio-rustls")]
use tokio_rustls::{ use tokio_rustls::{
webpki::{DnsNameRef, TlsServerTrustAnchors, TrustAnchor}, webpki::{DnsNameRef, TlsServerTrustAnchors, TrustAnchor},
TlsAcceptor, TlsConnector, TlsConnector,
}; };
use anyhow::{anyhow, bail, Result}; use anyhow::{anyhow, bail, Result};
@ -42,7 +41,9 @@ mod quic;
#[cfg(feature = "quic")] #[cfg(feature = "quic")]
use crate::quic::*; use crate::quic::*;
#[cfg(feature = "tls")]
mod tls; mod tls;
#[cfg(feature = "tls")]
use crate::tls::*; use crate::tls::*;
#[cfg(feature = "outgoing")] #[cfg(feature = "outgoing")]
@ -50,9 +51,9 @@ mod outgoing;
#[cfg(feature = "outgoing")] #[cfg(feature = "outgoing")]
use crate::outgoing::*; use crate::outgoing::*;
#[cfg(feature = "outgoing")] #[cfg(any(feature = "s2s-incoming", feature = "outgoing"))]
mod srv; mod srv;
#[cfg(feature = "outgoing")] #[cfg(any(feature = "s2s-incoming", feature = "outgoing"))]
use crate::srv::*; use crate::srv::*;
#[cfg(feature = "websocket")] #[cfg(feature = "websocket")]
@ -60,7 +61,9 @@ mod websocket;
#[cfg(feature = "websocket")] #[cfg(feature = "websocket")]
use crate::websocket::*; use crate::websocket::*;
#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))]
mod verify; mod verify;
#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))]
use crate::verify::*; use crate::verify::*;
mod in_out; mod in_out;
@ -92,6 +95,7 @@ lazy_static::lazy_static! {
}; };
} }
#[cfg(any(feature = "rustls-native-certs", feature = "webpki-roots"))]
pub fn root_cert_store() -> rustls::RootCertStore { pub fn root_cert_store() -> rustls::RootCertStore {
use rustls::{OwnedTrustAnchor, RootCertStore}; use rustls::{OwnedTrustAnchor, RootCertStore};
let mut root_cert_store = RootCertStore::empty(); let mut root_cert_store = RootCertStore::empty();
@ -104,43 +108,45 @@ pub fn root_cert_store() -> rustls::RootCertStore {
root_cert_store root_cert_store
} }
#[derive(Deserialize)] #[derive(Deserialize, Default)]
struct Config { struct Config {
tls_key: String, tls_key: String,
tls_cert: String, tls_cert: String,
incoming_listen: Option<Vec<String>>, incoming_listen: Vec<String>,
quic_listen: Option<Vec<String>>, quic_listen: Vec<String>,
outgoing_listen: Option<Vec<String>>, outgoing_listen: Vec<String>,
max_stanza_size_bytes: usize, max_stanza_size_bytes: usize,
s2s_target: SocketAddr, s2s_target: Option<SocketAddr>,
c2s_target: SocketAddr, c2s_target: Option<SocketAddr>,
proxy: bool, proxy: bool,
#[cfg(feature = "logging")]
log_level: Option<String>, log_level: Option<String>,
#[cfg(feature = "logging")]
log_style: Option<String>, log_style: Option<String>,
} }
#[derive(Clone)] #[derive(Clone)]
pub struct CloneableConfig { pub struct CloneableConfig {
max_stanza_size_bytes: usize, max_stanza_size_bytes: usize,
s2s_target: SocketAddr, #[cfg(feature = "s2s-incoming")]
c2s_target: SocketAddr, s2s_target: Option<SocketAddr>,
#[cfg(feature = "c2s-incoming")]
c2s_target: Option<SocketAddr>,
proxy: bool, proxy: bool,
} }
struct CertsKey { struct CertsKey {
#[cfg(feature = "rustls-pemfile")]
inner: Result<RwLock<Arc<rustls::sign::CertifiedKey>>>, inner: Result<RwLock<Arc<rustls::sign::CertifiedKey>>>,
} }
impl CertsKey { impl CertsKey {
fn new(cert_key: Result<rustls::sign::CertifiedKey>) -> Self { fn new(main_config: &Config) -> Self {
CertsKey { CertsKey {
inner: cert_key.map(|c| RwLock::new(Arc::new(c))), #[cfg(feature = "rustls-pemfile")]
inner: main_config.certs_key().map(|c| RwLock::new(Arc::new(c))),
} }
} }
#[cfg(unix)] #[cfg(all(unix, any(feature = "incoming", feature = "s2s-outgoing")))]
fn spawn_refresh_task(&'static self, cfg_path: OsString) -> Option<JoinHandle<Result<()>>> { fn spawn_refresh_task(&'static self, cfg_path: OsString) -> Option<JoinHandle<Result<()>>> {
if self.inner.is_err() { if self.inner.is_err() {
None None
@ -169,12 +175,14 @@ impl CertsKey {
} }
} }
#[cfg(feature = "rustls-pemfile")]
impl rustls::server::ResolvesServerCert for CertsKey { impl rustls::server::ResolvesServerCert for CertsKey {
fn resolve(&self, _: rustls::server::ClientHello) -> Option<Arc<rustls::sign::CertifiedKey>> { fn resolve(&self, _: rustls::server::ClientHello) -> Option<Arc<rustls::sign::CertifiedKey>> {
self.inner.as_ref().map(|rwl| rwl.read().expect("CertKey poisoned?").clone()).ok() self.inner.as_ref().map(|rwl| rwl.read().expect("CertKey poisoned?").clone()).ok()
} }
} }
#[cfg(feature = "rustls-pemfile")]
impl rustls::client::ResolvesClientCert for CertsKey { impl rustls::client::ResolvesClientCert for CertsKey {
fn resolve(&self, _: &[&[u8]], _: &[SignatureScheme]) -> Option<Arc<CertifiedKey>> { fn resolve(&self, _: &[&[u8]], _: &[SignatureScheme]) -> Option<Arc<CertifiedKey>> {
self.inner.as_ref().map(|rwl| rwl.read().expect("CertKey poisoned?").clone()).ok() self.inner.as_ref().map(|rwl| rwl.read().expect("CertKey poisoned?").clone()).ok()
@ -185,6 +193,17 @@ impl rustls::client::ResolvesClientCert for CertsKey {
} }
} }
#[cfg(not(feature = "rustls-pemfile"))]
impl rustls::client::ResolvesClientCert for CertsKey {
fn resolve(&self, _: &[&[u8]], _: &[SignatureScheme]) -> Option<Arc<CertifiedKey>> {
None
}
fn has_certs(&self) -> bool {
false
}
}
impl Config { impl Config {
fn parse<P: AsRef<Path>>(path: P) -> Result<Config> { fn parse<P: AsRef<Path>>(path: P) -> Result<Config> {
let mut f = File::open(path)?; let mut f = File::open(path)?;
@ -196,7 +215,9 @@ impl Config {
fn get_cloneable_cfg(&self) -> CloneableConfig { fn get_cloneable_cfg(&self) -> CloneableConfig {
CloneableConfig { CloneableConfig {
max_stanza_size_bytes: self.max_stanza_size_bytes, max_stanza_size_bytes: self.max_stanza_size_bytes,
#[cfg(feature = "s2s-incoming")]
s2s_target: self.s2s_target, s2s_target: self.s2s_target,
#[cfg(feature = "c2s-incoming")]
c2s_target: self.c2s_target, c2s_target: self.c2s_target,
proxy: self.proxy, proxy: self.proxy,
} }
@ -204,6 +225,7 @@ impl Config {
#[cfg(feature = "outgoing")] #[cfg(feature = "outgoing")]
fn get_outgoing_cfg(&self, certs_key: Arc<CertsKey>) -> OutgoingConfig { fn get_outgoing_cfg(&self, certs_key: Arc<CertsKey>) -> OutgoingConfig {
#[cfg(feature = "rustls-pemfile")]
if let Err(e) = &certs_key.inner { if let Err(e) = &certs_key.inner {
debug!("invalid key/cert for s2s client auth: {}", e); debug!("invalid key/cert for s2s client auth: {}", e);
} }
@ -243,21 +265,18 @@ impl Config {
bail!("invalid cert/key: {}", e); bail!("invalid cert/key: {}", e);
} }
let mut config = ServerConfig::builder() let config = ServerConfig::builder().with_safe_defaults();
.with_safe_defaults() #[cfg(feature = "s2s")]
.with_client_cert_verifier(Arc::new(AllowAnonymousOrAnyCert)) let config = config.with_client_cert_verifier(Arc::new(AllowAnonymousOrAnyCert));
.with_cert_resolver(certs_key); #[cfg(not(feature = "s2s"))]
let config = config.with_no_client_auth();
let mut config = config.with_cert_resolver(certs_key);
// todo: will connecting without alpn work then? // todo: will connecting without alpn work then?
config.alpn_protocols.push(ALPN_XMPP_CLIENT.to_vec()); config.alpn_protocols.push(ALPN_XMPP_CLIENT.to_vec());
config.alpn_protocols.push(ALPN_XMPP_SERVER.to_vec()); config.alpn_protocols.push(ALPN_XMPP_SERVER.to_vec());
Ok(config) Ok(config)
} }
#[cfg(feature = "incoming")]
fn tls_acceptor(&self, cert_key: Arc<CertsKey>) -> Result<TlsAcceptor> {
Ok(TlsAcceptor::from(Arc::new(self.server_config(cert_key)?)))
}
} }
#[derive(Clone)] #[derive(Clone)]
@ -310,11 +329,13 @@ pub struct OutgoingVerifierConfig {
pub connector: TlsConnector, pub connector: TlsConnector,
} }
#[cfg(feature = "incoming")]
async fn shuffle_rd_wr(in_rd: StanzaRead, in_wr: StanzaWrite, config: CloneableConfig, server_certs: ServerCerts, local_addr: SocketAddr, client_addr: &mut Context<'_>) -> Result<()> { async fn shuffle_rd_wr(in_rd: StanzaRead, in_wr: StanzaWrite, config: CloneableConfig, server_certs: ServerCerts, local_addr: SocketAddr, client_addr: &mut Context<'_>) -> Result<()> {
let filter = StanzaFilter::new(config.max_stanza_size_bytes); let filter = StanzaFilter::new(config.max_stanza_size_bytes);
shuffle_rd_wr_filter(in_rd, in_wr, config, server_certs, local_addr, client_addr, filter).await shuffle_rd_wr_filter(in_rd, in_wr, config, server_certs, local_addr, client_addr, filter).await
} }
#[cfg(feature = "incoming")]
async fn shuffle_rd_wr_filter( async fn shuffle_rd_wr_filter(
mut in_rd: StanzaRead, mut in_rd: StanzaRead,
mut in_wr: StanzaWrite, mut in_wr: StanzaWrite,
@ -328,26 +349,30 @@ async fn shuffle_rd_wr_filter(
let (stream_open, is_c2s) = stream_preamble(&mut in_rd, &mut in_wr, client_addr.log_from(), &mut in_filter).await?; let (stream_open, is_c2s) = stream_preamble(&mut in_rd, &mut in_wr, client_addr.log_from(), &mut in_filter).await?;
client_addr.set_c2s_stream_open(is_c2s, &stream_open); client_addr.set_c2s_stream_open(is_c2s, &stream_open);
trace!( #[cfg(feature = "s2s-incoming")]
"{} connected: sni: {:?}, alpn: {:?}, tls-not-quic: {}", {
client_addr.log_from(), trace!(
server_certs.sni(), "{} connected: sni: {:?}, alpn: {:?}, tls-not-quic: {}",
server_certs.alpn().map(|a| String::from_utf8_lossy(&a).to_string()), client_addr.log_from(),
server_certs.is_tls(), server_certs.sni(),
); server_certs.alpn().map(|a| String::from_utf8_lossy(&a).to_string()),
server_certs.is_tls(),
);
if !is_c2s { if !is_c2s {
// for s2s we need this // for s2s we need this
let domain = stream_open use std::time::SystemTime;
.extract_between(b" from='", b"'") let domain = stream_open
.or_else(|_| stream_open.extract_between(b" from=\"", b"\"")) .extract_between(b" from='", b"'")
.and_then(|b| Ok(std::str::from_utf8(b)?))?; .or_else(|_| stream_open.extract_between(b" from=\"", b"\""))
let (_, cert_verifier) = get_xmpp_connections(domain, is_c2s).await?; .and_then(|b| Ok(std::str::from_utf8(b)?))?;
let certs = server_certs.peer_certificates().ok_or_else(|| anyhow!("no client cert auth for s2s incoming from {}", domain))?; let (_, cert_verifier) = get_xmpp_connections(domain, is_c2s).await?;
// todo: send stream error saying cert is invalid let certs = server_certs.peer_certificates().ok_or_else(|| anyhow!("no client cert auth for s2s incoming from {}", domain))?;
cert_verifier.verify_cert(&certs[0], &certs[1..], SystemTime::now())?; // todo: send stream error saying cert is invalid
cert_verifier.verify_cert(&certs[0], &certs[1..], SystemTime::now())?;
}
drop(server_certs);
} }
drop(server_certs);
let (out_rd, out_wr) = open_incoming(&config, local_addr, client_addr, &stream_open, is_c2s, &mut in_filter).await?; let (out_rd, out_wr) = open_incoming(&config, local_addr, client_addr, &stream_open, is_c2s, &mut in_filter).await?;
drop(stream_open); drop(stream_open);
@ -407,6 +432,7 @@ async fn shuffle_rd_wr_filter_only(
Ok(()) Ok(())
} }
#[cfg(feature = "incoming")]
async fn open_incoming( async fn open_incoming(
config: &CloneableConfig, config: &CloneableConfig,
local_addr: SocketAddr, local_addr: SocketAddr,
@ -415,7 +441,18 @@ async fn open_incoming(
is_c2s: bool, is_c2s: bool,
in_filter: &mut StanzaFilter, in_filter: &mut StanzaFilter,
) -> Result<(ReadHalf<tokio::net::TcpStream>, WriteHalf<tokio::net::TcpStream>)> { ) -> Result<(ReadHalf<tokio::net::TcpStream>, WriteHalf<tokio::net::TcpStream>)> {
let target = if is_c2s { config.c2s_target } else { config.s2s_target }; let target = if is_c2s {
#[cfg(not(feature = "c2s-incoming"))]
bail!("incoming c2s connection but lacking compile-time support");
#[cfg(feature = "c2s-incoming")]
config.c2s_target
} else {
#[cfg(not(feature = "s2s-incoming"))]
bail!("incoming s2s connection but lacking compile-time support");
#[cfg(feature = "s2s-incoming")]
config.s2s_target
}
.ok_or_else(|| anyhow!("incoming connection but `{}_target` not defined", c2s(is_c2s)))?;
client_addr.set_to_addr(target); client_addr.set_to_addr(target);
let out_stream = tokio::net::TcpStream::connect(target).await?; let out_stream = tokio::net::TcpStream::connect(target).await?;
@ -468,7 +505,12 @@ pub async fn stream_preamble(in_rd: &mut StanzaRead, in_wr: &mut StanzaWrite, cl
#[tokio::main] #[tokio::main]
//#[tokio::main(flavor = "multi_thread", worker_threads = 10)] //#[tokio::main(flavor = "multi_thread", worker_threads = 10)]
async fn main() { async fn main() {
let cfg_path = std::env::args_os().nth(1).unwrap_or_else(|| OsString::from("/etc/xmpp-proxy/xmpp-proxy.toml")); let cfg_path = std::env::args_os().nth(1);
if cfg_path == Some(OsString::from("-v")) {
include!(concat!(env!("OUT_DIR"), "/version.rs"));
die!(0);
}
let cfg_path = cfg_path.unwrap_or_else(|| OsString::from("/etc/xmpp-proxy/xmpp-proxy.toml"));
let main_config = Config::parse(&cfg_path).die("invalid config file"); let main_config = Config::parse(&cfg_path).die("invalid config file");
#[cfg(feature = "logging")] #[cfg(feature = "logging")]
@ -486,34 +528,59 @@ async fn main() {
// todo: config for this: builder.format_timestamp(None); // todo: config for this: builder.format_timestamp(None);
builder.init(); builder.init();
} }
#[cfg(not(feature = "logging"))]
if main_config.log_level.is_some() || main_config.log_style.is_some() {
die!("log_level or log_style defined in config but logging disabled at compile-time");
}
let config = main_config.get_cloneable_cfg(); let config = main_config.get_cloneable_cfg();
let certs_key = Arc::new(CertsKey::new(main_config.certs_key())); let certs_key = Arc::new(CertsKey::new(&main_config));
let mut handles: Vec<JoinHandle<Result<()>>> = Vec::new(); let mut handles: Vec<JoinHandle<Result<()>>> = Vec::new();
#[cfg(feature = "incoming")] if !main_config.incoming_listen.is_empty() {
if let Some(ref listeners) = main_config.incoming_listen { #[cfg(all(any(feature = "tls", feature = "websocket"), feature = "incoming"))]
let acceptor = main_config.tls_acceptor(certs_key.clone()).die("invalid cert/key ?"); {
for listener in listeners { if main_config.c2s_target.is_none() && main_config.s2s_target.is_none() {
handles.push(spawn_tls_listener(listener.parse().die("invalid listener address"), config.clone(), acceptor.clone())); die!("one of c2s_target/s2s_target must be defined if incoming_listen is non-empty");
}
let acceptor = main_config.tls_acceptor(certs_key.clone()).die("invalid cert/key ?");
for listener in main_config.incoming_listen.iter() {
handles.push(spawn_tls_listener(listener.parse().die("invalid listener address"), config.clone(), acceptor.clone()));
}
} }
#[cfg(not(all(any(feature = "tls", feature = "websocket"), feature = "incoming")))]
die!("incoming_listen non-empty but (tls or websocket) or (s2s-incoming and c2s-incoming) disabled at compile-time");
} }
#[cfg(all(feature = "quic", feature = "incoming"))] if !main_config.quic_listen.is_empty() {
if let Some(ref listeners) = main_config.quic_listen { #[cfg(all(feature = "quic", feature = "incoming"))]
let quic_config = main_config.quic_server_config(certs_key.clone()).die("invalid cert/key ?"); {
for listener in listeners { if main_config.c2s_target.is_none() && main_config.s2s_target.is_none() {
handles.push(spawn_quic_listener(listener.parse().die("invalid listener address"), config.clone(), quic_config.clone())); die!("one of c2s_target/s2s_target must be defined if quic_listen is non-empty");
}
let quic_config = main_config.quic_server_config(certs_key.clone()).die("invalid cert/key ?");
for listener in main_config.quic_listen.iter() {
handles.push(spawn_quic_listener(listener.parse().die("invalid listener address"), config.clone(), quic_config.clone()));
}
} }
#[cfg(not(all(feature = "quic", feature = "incoming")))]
die!("quic_listen non-empty but quic or (s2s-incoming and c2s-incoming) disabled at compile-time");
} }
#[cfg(feature = "outgoing")] if !main_config.outgoing_listen.is_empty() {
if let Some(ref listeners) = main_config.outgoing_listen { #[cfg(feature = "outgoing")]
let outgoing_cfg = main_config.get_outgoing_cfg(certs_key.clone()); {
for listener in listeners { let outgoing_cfg = main_config.get_outgoing_cfg(certs_key.clone());
handles.push(spawn_outgoing_listener(listener.parse().die("invalid listener address"), outgoing_cfg.clone())); for listener in main_config.outgoing_listen.iter() {
handles.push(spawn_outgoing_listener(listener.parse().die("invalid listener address"), outgoing_cfg.clone()));
}
} }
#[cfg(not(feature = "outgoing"))]
die!("outgoing_listen non-empty but c2s-outgoing and s2s-outgoing disabled at compile-time");
} }
#[cfg(unix)] if handles.is_empty() {
die!("all of incoming_listen, quic_listen, outgoing_listen empty, nothing to do, exiting...");
}
#[cfg(all(unix, any(feature = "incoming", feature = "s2s-outgoing")))]
if let Some(refresh_task) = Box::leak(Box::new(certs_key.clone())).spawn_refresh_task(cfg_path) { if let Some(refresh_task) = Box::leak(Box::new(certs_key.clone())).spawn_refresh_task(cfg_path) {
handles.push(refresh_task); handles.push(refresh_task);
} }

View File

@ -5,15 +5,20 @@ async fn handle_outgoing_connection(stream: tokio::net::TcpStream, client_addr:
let mut in_filter = StanzaFilter::new(config.max_stanza_size_bytes); let mut in_filter = StanzaFilter::new(config.max_stanza_size_bytes);
let is_ws = first_bytes_match(&stream, &mut in_filter.buf[0..3], |p| p == b"GET").await?; #[cfg(feature = "websocket")]
let (mut in_rd, mut in_wr) = if first_bytes_match(&stream, &mut in_filter.buf[0..3], |p| p == b"GET").await? {
let (mut in_rd, mut in_wr) = if is_ws {
incoming_websocket_connection(Box::new(stream), config.max_stanza_size_bytes).await? incoming_websocket_connection(Box::new(stream), config.max_stanza_size_bytes).await?
} else { } else {
let (in_rd, in_wr) = tokio::io::split(stream); let (in_rd, in_wr) = tokio::io::split(stream);
(StanzaRead::new(in_rd), StanzaWrite::new(in_wr)) (StanzaRead::new(in_rd), StanzaWrite::new(in_wr))
}; };
#[cfg(not(feature = "websocket"))]
let (mut in_rd, mut in_wr) = {
let (in_rd, in_wr) = tokio::io::split(stream);
(StanzaRead::new(in_rd), StanzaWrite::new(in_wr))
};
// now read to figure out client vs server // now read to figure out client vs server
let (stream_open, is_c2s) = stream_preamble(&mut in_rd, &mut in_wr, client_addr.log_to(), &mut in_filter).await?; let (stream_open, is_c2s) = stream_preamble(&mut in_rd, &mut in_wr, client_addr.log_to(), &mut in_filter).await?;
client_addr.set_c2s_stream_open(is_c2s, &stream_open); client_addr.set_c2s_stream_open(is_c2s, &stream_open);

View File

@ -44,7 +44,12 @@ pub fn spawn_quic_listener(local_addr: SocketAddr, config: CloneableConfig, serv
tokio::spawn(async move { tokio::spawn(async move {
if let Ok(mut new_conn) = incoming_conn.await { if let Ok(mut new_conn) = incoming_conn.await {
let client_addr = crate::Context::new("quic-in", new_conn.connection.remote_address()); let client_addr = crate::Context::new("quic-in", new_conn.connection.remote_address());
#[cfg(feature = "s2s-incoming")]
let server_certs = ServerCerts::Quic(new_conn.connection); let server_certs = ServerCerts::Quic(new_conn.connection);
#[cfg(not(feature = "s2s-incoming"))]
let server_certs = ();
info!("{} connected new connection", client_addr.log_from()); info!("{} connected new connection", client_addr.log_from());
while let Some(Ok((wrt, rd))) = new_conn.bi_streams.next().await { while let Some(Ok((wrt, rd))) = new_conn.bi_streams.next().await {

View File

@ -29,8 +29,10 @@ fn make_resolver() -> TokioAsyncResolver {
} }
#[derive(Clone, Debug, PartialEq, Eq)] #[derive(Clone, Debug, PartialEq, Eq)]
pub enum XmppConnectionType { enum XmppConnectionType {
#[cfg(feature = "tls")]
StartTLS, StartTLS,
#[cfg(feature = "tls")]
DirectTLS, DirectTLS,
#[cfg(feature = "quic")] #[cfg(feature = "quic")]
QUIC, QUIC,
@ -41,9 +43,13 @@ pub enum XmppConnectionType {
impl XmppConnectionType { impl XmppConnectionType {
fn idx(&self) -> u8 { fn idx(&self) -> u8 {
match self { match self {
#[cfg(feature = "quic")]
XmppConnectionType::QUIC => 0, XmppConnectionType::QUIC => 0,
#[cfg(feature = "tls")]
XmppConnectionType::DirectTLS => 1, XmppConnectionType::DirectTLS => 1,
#[cfg(feature = "tls")]
XmppConnectionType::StartTLS => 2, XmppConnectionType::StartTLS => 2,
#[cfg(feature = "websocket")]
XmppConnectionType::WebSocket(_, _) => 3, XmppConnectionType::WebSocket(_, _) => 3,
} }
} }
@ -57,6 +63,7 @@ impl Ord for XmppConnectionType {
} }
// so they are the same type, but WebSocket is a special case... // so they are the same type, but WebSocket is a special case...
match (self, other) { match (self, other) {
#[cfg(feature = "websocket")]
(XmppConnectionType::WebSocket(self_uri, self_origin), XmppConnectionType::WebSocket(other_uri, other_origin)) => { (XmppConnectionType::WebSocket(self_uri, self_origin), XmppConnectionType::WebSocket(other_uri, other_origin)) => {
let cmp = self_uri.to_string().cmp(&other_uri.to_string()); let cmp = self_uri.to_string().cmp(&other_uri.to_string());
if cmp != Ordering::Equal { if cmp != Ordering::Equal {
@ -153,6 +160,7 @@ fn sort_dedup(ret: &mut Vec<XmppConnection>) {
} }
impl XmppConnection { impl XmppConnection {
#[cfg(feature = "outgoing")]
pub async fn connect( pub async fn connect(
&self, &self,
domain: &str, domain: &str,
@ -175,10 +183,12 @@ impl XmppConnection {
let to_addr = SocketAddr::new(*ip, self.port); let to_addr = SocketAddr::new(*ip, self.port);
debug!("{} trying ip {}", client_addr.log_from(), to_addr); debug!("{} trying ip {}", client_addr.log_from(), to_addr);
match self.conn_type { match self.conn_type {
#[cfg(feature = "tls")]
XmppConnectionType::StartTLS => match crate::starttls_connect(to_addr, domain, stream_open, in_filter, config.clone()).await { XmppConnectionType::StartTLS => match crate::starttls_connect(to_addr, domain, stream_open, in_filter, config.clone()).await {
Ok((wr, rd)) => return Ok((wr, rd, to_addr, "starttls-out")), Ok((wr, rd)) => return Ok((wr, rd, to_addr, "starttls-out")),
Err(e) => error!("starttls connection failed to IP {} from SRV {}, error: {}", to_addr, self.target, e), Err(e) => error!("starttls connection failed to IP {} from SRV {}, error: {}", to_addr, self.target, e),
}, },
#[cfg(feature = "tls")]
XmppConnectionType::DirectTLS => match crate::tls_connect(to_addr, domain, config.clone()).await { XmppConnectionType::DirectTLS => match crate::tls_connect(to_addr, domain, config.clone()).await {
Ok((wr, rd)) => return Ok((wr, rd, to_addr, "directtls-out")), Ok((wr, rd)) => return Ok((wr, rd, to_addr, "directtls-out")),
Err(e) => error!("direct tls connection failed to IP {} from SRV {}, error: {}", to_addr, self.target, e), Err(e) => error!("direct tls connection failed to IP {} from SRV {}, error: {}", to_addr, self.target, e),
@ -336,7 +346,9 @@ pub async fn get_xmpp_connections(domain: &str, is_c2s: bool) -> Result<(Vec<Xmp
// ignore everything else if new host-meta format // ignore everything else if new host-meta format
#[cfg(feature = "websocket")] #[cfg(feature = "websocket")]
collect_txts(&mut ret, websocket_txt, is_c2s); collect_txts(&mut ret, websocket_txt, is_c2s);
#[cfg(feature = "tls")]
collect_srvs(&mut ret, starttls, XmppConnectionType::StartTLS); collect_srvs(&mut ret, starttls, XmppConnectionType::StartTLS);
#[cfg(feature = "tls")]
collect_srvs(&mut ret, direct_tls, XmppConnectionType::DirectTLS); collect_srvs(&mut ret, direct_tls, XmppConnectionType::DirectTLS);
#[cfg(feature = "quic")] #[cfg(feature = "quic")]
collect_srvs(&mut ret, quic, XmppConnectionType::QUIC); collect_srvs(&mut ret, quic, XmppConnectionType::QUIC);
@ -358,6 +370,7 @@ pub async fn get_xmpp_connections(domain: &str, is_c2s: bool) -> Result<(Vec<Xmp
if ret.is_empty() { if ret.is_empty() {
// default starttls ports // default starttls ports
#[cfg(feature = "tls")]
ret.push(XmppConnection { ret.push(XmppConnection {
priority: 0, priority: 0,
weight: 0, weight: 0,
@ -369,6 +382,7 @@ pub async fn get_xmpp_connections(domain: &str, is_c2s: bool) -> Result<(Vec<Xmp
ech: None, ech: None,
}); });
// by spec there are no default direct/quic ports, but we are going 443 // by spec there are no default direct/quic ports, but we are going 443
#[cfg(feature = "tls")]
ret.push(XmppConnection { ret.push(XmppConnection {
priority: 0, priority: 0,
weight: 0, weight: 0,
@ -409,6 +423,7 @@ pub async fn get_xmpp_connections(domain: &str, is_c2s: bool) -> Result<(Vec<Xmp
Ok((ret, cert_verifier)) Ok((ret, cert_verifier))
} }
#[cfg(feature = "outgoing")]
pub async fn srv_connect( pub async fn srv_connect(
domain: &str, domain: &str,
is_c2s: bool, is_c2s: bool,
@ -417,6 +432,14 @@ pub async fn srv_connect(
client_addr: &mut Context<'_>, client_addr: &mut Context<'_>,
config: OutgoingConfig, config: OutgoingConfig,
) -> Result<(StanzaWrite, StanzaRead, Vec<u8>)> { ) -> Result<(StanzaWrite, StanzaRead, Vec<u8>)> {
#[cfg(not(feature = "c2s-outgoing"))]
if is_c2s {
bail!("outgoing c2s connection but c2s-outgoing disabled at compile-time");
}
#[cfg(not(feature = "s2s-outgoing"))]
if !is_c2s {
bail!("outgoing s2s connection but s2s-outgoing disabled at compile-time");
}
let (srvs, cert_verifier) = get_xmpp_connections(domain, is_c2s).await?; let (srvs, cert_verifier) = get_xmpp_connections(domain, is_c2s).await?;
let config = config.with_custom_certificate_verifier(is_c2s, cert_verifier); let config = config.with_custom_certificate_verifier(is_c2s, cert_verifier);
for srv in srvs { for srv in srvs {
@ -448,7 +471,7 @@ pub async fn srv_connect(
#[cfg(not(feature = "websocket"))] #[cfg(not(feature = "websocket"))]
async fn collect_host_meta(ret: &mut Vec<XmppConnection>, sha256_pinnedpubkeys: &mut Vec<String>, domain: &str, is_c2s: bool) -> Result<Option<u16>> { async fn collect_host_meta(ret: &mut Vec<XmppConnection>, sha256_pinnedpubkeys: &mut Vec<String>, domain: &str, is_c2s: bool) -> Result<Option<u16>> {
collect_host_meta_json(ret, sha256_pinnedpubkeys, domain, is_c2s) collect_host_meta_json(ret, sha256_pinnedpubkeys, domain, is_c2s).await
} }
#[cfg(feature = "websocket")] #[cfg(feature = "websocket")]
@ -532,7 +555,7 @@ struct LinkCommon {
} }
impl LinkCommon { impl LinkCommon {
pub fn into_xmpp_connection(self, conn_type: XmppConnectionType, port: u16) -> Option<XmppConnection> { fn into_xmpp_connection(self, conn_type: XmppConnectionType, port: u16) -> Option<XmppConnection> {
if self.ips.is_empty() { if self.ips.is_empty() {
error!("invalid empty ips"); error!("invalid empty ips");
return None; return None;
@ -551,13 +574,18 @@ impl LinkCommon {
} }
impl Link { impl Link {
pub fn into_xmpp_connection(self, is_c2s: bool) -> Option<XmppConnection> { fn into_xmpp_connection(self, is_c2s: bool) -> Option<XmppConnection> {
use XmppConnectionType::*; use XmppConnectionType::*;
let (srv_is_c2s, port, link, conn_type) = match self { let (srv_is_c2s, port, link, conn_type) = match self {
#[cfg(feature = "tls")]
Link::DirectTLS { port, link } => (true, port, link, DirectTLS), Link::DirectTLS { port, link } => (true, port, link, DirectTLS),
#[cfg(feature = "quic")]
Link::Quic { port, link } => (true, port, link, QUIC), Link::Quic { port, link } => (true, port, link, QUIC),
#[cfg(feature = "tls")]
Link::S2SDirectTLS { port, link } => (false, port, link, DirectTLS), Link::S2SDirectTLS { port, link } => (false, port, link, DirectTLS),
#[cfg(feature = "quic")]
Link::S2SQuic { port, link } => (false, port, link, QUIC), Link::S2SQuic { port, link } => (false, port, link, QUIC),
#[cfg(feature = "websocket")]
Link::WebSocket { href, link } => { Link::WebSocket { href, link } => {
return if is_c2s { return if is_c2s {
let srv = wss_to_srv(&href, true)?; let srv = wss_to_srv(&href, true)?;
@ -570,6 +598,7 @@ impl Link {
None None
}; };
} }
#[cfg(feature = "websocket")]
Link::S2SWebSocket { href, link } => { Link::S2SWebSocket { href, link } => {
return if !is_c2s { return if !is_c2s {
let srv = wss_to_srv(&href, true)?; let srv = wss_to_srv(&href, true)?;
@ -579,7 +608,7 @@ impl Link {
}; };
} }
Link::Unknown => return None, _ => return None,
}; };
if srv_is_c2s == is_c2s { if srv_is_c2s == is_c2s {
@ -591,7 +620,7 @@ impl Link {
} }
impl HostMeta { impl HostMeta {
pub fn collect(self, ret: &mut Vec<XmppConnection>, sha256_pinnedpubkeys: &mut Vec<String>, is_c2s: bool) -> Option<u16> { fn collect(self, ret: &mut Vec<XmppConnection>, sha256_pinnedpubkeys: &mut Vec<String>, is_c2s: bool) -> Option<u16> {
for link in self.links { for link in self.links {
if let Some(srv) = link.into_xmpp_connection(is_c2s) { if let Some(srv) = link.into_xmpp_connection(is_c2s) {
ret.push(srv); ret.push(srv);
@ -666,7 +695,7 @@ async fn collect_host_meta_xml(ret: &mut Vec<XmppConnection>, domain: &str, is_c
} }
} }
pub async fn https_get<T: reqwest::IntoUrl>(url: T) -> reqwest::Result<reqwest::Response> { async fn https_get<T: reqwest::IntoUrl>(url: T) -> reqwest::Result<reqwest::Response> {
// todo: resolve URL with our resolver // todo: resolve URL with our resolver
reqwest::Client::builder().https_only(true).build()?.get(url).send().await reqwest::Client::builder().https_only(true).build()?.get(url).send().await
} }

View File

@ -3,8 +3,7 @@ use rustls::ServerConnection;
use std::convert::TryFrom; use std::convert::TryFrom;
use tokio::io::{AsyncBufReadExt, BufStream}; use tokio::io::{AsyncBufReadExt, BufStream};
#[cfg(any(feature = "incoming", feature = "outgoing"))] use tokio_rustls::{rustls::ServerName, TlsAcceptor};
use tokio_rustls::rustls::ServerName;
#[cfg(feature = "outgoing")] #[cfg(feature = "outgoing")]
pub async fn tls_connect(target: SocketAddr, server_name: &str, config: OutgoingVerifierConfig) -> Result<(StanzaWrite, StanzaRead)> { pub async fn tls_connect(target: SocketAddr, server_name: &str, config: OutgoingVerifierConfig) -> Result<(StanzaWrite, StanzaRead)> {
@ -59,6 +58,13 @@ pub async fn starttls_connect(target: SocketAddr, server_name: &str, stream_open
Ok((StanzaWrite::new(wrt), StanzaRead::new(rd))) Ok((StanzaWrite::new(wrt), StanzaRead::new(rd)))
} }
#[cfg(feature = "incoming")]
impl Config {
pub fn tls_acceptor(&self, cert_key: Arc<CertsKey>) -> Result<TlsAcceptor> {
Ok(TlsAcceptor::from(Arc::new(self.server_config(cert_key)?)))
}
}
#[cfg(feature = "incoming")] #[cfg(feature = "incoming")]
pub fn spawn_tls_listener(local_addr: SocketAddr, config: CloneableConfig, acceptor: TlsAcceptor) -> JoinHandle<Result<()>> { pub fn spawn_tls_listener(local_addr: SocketAddr, config: CloneableConfig, acceptor: TlsAcceptor) -> JoinHandle<Result<()>> {
tokio::spawn(async move { tokio::spawn(async move {
@ -159,8 +165,13 @@ async fn handle_tls_connection(mut stream: tokio::net::TcpStream, client_addr: &
// where we read the first stanza, where we are guaranteed the handshake is complete, but I can't // where we read the first stanza, where we are guaranteed the handshake is complete, but I can't
// do that without ignoring the lifetime and just pulling a C programmer and pinky promising to be // do that without ignoring the lifetime and just pulling a C programmer and pinky promising to be
// *very careful* that this reference doesn't outlive stream... // *very careful* that this reference doesn't outlive stream...
let server_connection: &'static ServerConnection = unsafe { std::mem::transmute(server_connection) }; #[cfg(feature = "s2s-incoming")]
let server_certs = ServerCerts::Tls(server_connection); let server_certs = {
let server_connection: &'static ServerConnection = unsafe { std::mem::transmute(server_connection) };
ServerCerts::Tls(server_connection)
};
#[cfg(not(feature = "s2s-incoming"))]
let server_certs = ();
#[cfg(not(feature = "websocket"))] #[cfg(not(feature = "websocket"))]
{ {

View File

@ -29,6 +29,7 @@ pub async fn incoming_websocket_connection(stream: Box<dyn AsyncReadAndWrite + U
Ok((StanzaRead::WebSocketRead(in_rd), StanzaWrite::WebSocketClientWrite(in_wr))) Ok((StanzaRead::WebSocketRead(in_rd), StanzaWrite::WebSocketClientWrite(in_wr)))
} }
#[cfg(feature = "incoming")]
pub async fn handle_websocket_connection( pub async fn handle_websocket_connection(
stream: Box<dyn AsyncReadAndWrite + Unpin + Send>, stream: Box<dyn AsyncReadAndWrite + Unpin + Send>,
config: CloneableConfig, config: CloneableConfig,