@ -8,9 +8,8 @@ use std::iter::Iterator;
use std ::net ::SocketAddr ;
use std ::path ::Path ;
use std ::sync ::{ Arc , RwLock } ;
use std ::time ::SystemTime ;
use die ::Die ;
use die ::{ die , Die } ;
use serde_derive ::Deserialize ;
@ -27,7 +26,7 @@ use rustls::{
#[ cfg(feature = " tokio-rustls " ) ]
use tokio_rustls ::{
webpki ::{ DnsNameRef , TlsServerTrustAnchors , TrustAnchor } ,
Tls Acceptor, Tls Connector,
Tls Connector,
} ;
use anyhow ::{ anyhow , bail , Result } ;
@ -42,7 +41,9 @@ mod quic;
#[ cfg(feature = " quic " ) ]
use crate ::quic ::* ;
#[ cfg(feature = " tls " ) ]
mod tls ;
#[ cfg(feature = " tls " ) ]
use crate ::tls ::* ;
#[ cfg(feature = " outgoing " ) ]
@ -50,9 +51,9 @@ mod outgoing;
#[ cfg(feature = " outgoing " ) ]
use crate ::outgoing ::* ;
#[ cfg( feature = " outgoing " )]
#[ cfg( any(feature = " s2s-incoming " , feature = " outgoing " ) )]
mod srv ;
#[ cfg( feature = " outgoing " )]
#[ cfg( any(feature = " s2s-incoming " , feature = " outgoing " ) )]
use crate ::srv ::* ;
#[ cfg(feature = " websocket " ) ]
@ -60,7 +61,9 @@ mod websocket;
#[ cfg(feature = " websocket " ) ]
use crate ::websocket ::* ;
#[ cfg(any(feature = " s2s-incoming " , feature = " outgoing " )) ]
mod verify ;
#[ cfg(any(feature = " s2s-incoming " , feature = " outgoing " )) ]
use crate ::verify ::* ;
mod in_out ;
@ -92,6 +95,7 @@ lazy_static::lazy_static! {
} ;
}
#[ cfg(any(feature = " rustls-native-certs " , feature = " webpki-roots " )) ]
pub fn root_cert_store ( ) -> rustls ::RootCertStore {
use rustls ::{ OwnedTrustAnchor , RootCertStore } ;
let mut root_cert_store = RootCertStore ::empty ( ) ;
@ -104,43 +108,45 @@ pub fn root_cert_store() -> rustls::RootCertStore {
root_cert_store
}
#[ derive(Deserialize )]
#[ derive(Deserialize , Default )]
struct Config {
tls_key : String ,
tls_cert : String ,
incoming_listen : Option< Vec< String > > ,
quic_listen : Option< Vec< String > > ,
outgoing_listen : Option< Vec< String > > ,
incoming_listen : Vec< String > ,
quic_listen : Vec< String > ,
outgoing_listen : Vec< String > ,
max_stanza_size_bytes : usize ,
s2s_target : SocketAddr ,
c2s_target : SocketAddr ,
s2s_target : Option < SocketAddr > ,
c2s_target : Option < SocketAddr > ,
proxy : bool ,
#[ cfg(feature = " logging " ) ]
log_level : Option < String > ,
#[ cfg(feature = " logging " ) ]
log_style : Option < String > ,
}
#[ derive(Clone) ]
pub struct CloneableConfig {
max_stanza_size_bytes : usize ,
s2s_target : SocketAddr ,
c2s_target : SocketAddr ,
#[ cfg(feature = " s2s-incoming " ) ]
s2s_target : Option < SocketAddr > ,
#[ cfg(feature = " c2s-incoming " ) ]
c2s_target : Option < SocketAddr > ,
proxy : bool ,
}
struct CertsKey {
#[ cfg(feature = " rustls-pemfile " ) ]
inner : Result < RwLock < Arc < rustls ::sign ::CertifiedKey > > > ,
}
impl CertsKey {
fn new ( cert_key: Result < rustls ::sign ::CertifiedKey > ) -> Self {
fn new ( main_config: & Config ) -> Self {
CertsKey {
inner : cert_key . map ( | c | RwLock ::new ( Arc ::new ( c ) ) ) ,
#[ cfg(feature = " rustls-pemfile " ) ]
inner : main_config . certs_key ( ) . map ( | c | RwLock ::new ( Arc ::new ( c ) ) ) ,
}
}
#[ cfg( unix)]
#[ cfg( all( unix, any(feature = " incoming " , feature = " s2s-outgoing " )) )]
fn spawn_refresh_task ( & ' static self , cfg_path : OsString ) -> Option < JoinHandle < Result < ( ) > > > {
if self . inner . is_err ( ) {
None
@ -169,12 +175,14 @@ impl CertsKey {
}
}
#[ cfg(feature = " rustls-pemfile " ) ]
impl rustls ::server ::ResolvesServerCert for CertsKey {
fn resolve ( & self , _ : rustls ::server ::ClientHello ) -> Option < Arc < rustls ::sign ::CertifiedKey > > {
self . inner . as_ref ( ) . map ( | rwl | rwl . read ( ) . expect ( "CertKey poisoned?" ) . clone ( ) ) . ok ( )
}
}
#[ cfg(feature = " rustls-pemfile " ) ]
impl rustls ::client ::ResolvesClientCert for CertsKey {
fn resolve ( & self , _ : & [ & [ u8 ] ] , _ : & [ SignatureScheme ] ) -> Option < Arc < CertifiedKey > > {
self . inner . as_ref ( ) . map ( | rwl | rwl . read ( ) . expect ( "CertKey poisoned?" ) . clone ( ) ) . ok ( )
@ -185,6 +193,17 @@ impl rustls::client::ResolvesClientCert for CertsKey {
}
}
#[ cfg(not(feature = " rustls-pemfile " )) ]
impl rustls ::client ::ResolvesClientCert for CertsKey {
fn resolve ( & self , _ : & [ & [ u8 ] ] , _ : & [ SignatureScheme ] ) -> Option < Arc < CertifiedKey > > {
None
}
fn has_certs ( & self ) -> bool {
false
}
}
impl Config {
fn parse < P : AsRef < Path > > ( path : P ) -> Result < Config > {
let mut f = File ::open ( path ) ? ;
@ -196,7 +215,9 @@ impl Config {
fn get_cloneable_cfg ( & self ) -> CloneableConfig {
CloneableConfig {
max_stanza_size_bytes : self . max_stanza_size_bytes ,
#[ cfg(feature = " s2s-incoming " ) ]
s2s_target : self . s2s_target ,
#[ cfg(feature = " c2s-incoming " ) ]
c2s_target : self . c2s_target ,
proxy : self . proxy ,
}
@ -204,6 +225,7 @@ impl Config {
#[ cfg(feature = " outgoing " ) ]
fn get_outgoing_cfg ( & self , certs_key : Arc < CertsKey > ) -> OutgoingConfig {
#[ cfg(feature = " rustls-pemfile " ) ]
if let Err ( e ) = & certs_key . inner {
debug ! ( "invalid key/cert for s2s client auth: {}" , e ) ;
}
@ -243,21 +265,18 @@ impl Config {
bail ! ( "invalid cert/key: {}" , e ) ;
}
let mut config = ServerConfig ::builder ( )
. with_safe_defaults ( )
. with_client_cert_verifier ( Arc ::new ( AllowAnonymousOrAnyCert ) )
. with_cert_resolver ( certs_key ) ;
let config = ServerConfig ::builder ( ) . with_safe_defaults ( ) ;
#[ cfg(feature = " s2s " ) ]
let config = config . with_client_cert_verifier ( Arc ::new ( AllowAnonymousOrAnyCert ) ) ;
#[ cfg(not(feature = " s2s " )) ]
let config = config . with_no_client_auth ( ) ;
let mut config = config . with_cert_resolver ( certs_key ) ;
// todo: will connecting without alpn work then?
config . alpn_protocols . push ( ALPN_XMPP_CLIENT . to_vec ( ) ) ;
config . alpn_protocols . push ( ALPN_XMPP_SERVER . to_vec ( ) ) ;
Ok ( config )
}
#[ cfg(feature = " incoming " ) ]
fn tls_acceptor ( & self , cert_key : Arc < CertsKey > ) -> Result < TlsAcceptor > {
Ok ( TlsAcceptor ::from ( Arc ::new ( self . server_config ( cert_key ) ? ) ) )
}
}
#[ derive(Clone) ]
@ -310,11 +329,13 @@ pub struct OutgoingVerifierConfig {
pub connector : TlsConnector ,
}
#[ cfg(feature = " incoming " ) ]
async fn shuffle_rd_wr ( in_rd : StanzaRead , in_wr : StanzaWrite , config : CloneableConfig , server_certs : ServerCerts , local_addr : SocketAddr , client_addr : & mut Context < ' _ > ) -> Result < ( ) > {
let filter = StanzaFilter ::new ( config . max_stanza_size_bytes ) ;
shuffle_rd_wr_filter ( in_rd , in_wr , config , server_certs , local_addr , client_addr , filter ) . await
}
#[ cfg(feature = " incoming " ) ]
async fn shuffle_rd_wr_filter (
mut in_rd : StanzaRead ,
mut in_wr : StanzaWrite ,
@ -328,26 +349,30 @@ async fn shuffle_rd_wr_filter(
let ( stream_open , is_c2s ) = stream_preamble ( & mut in_rd , & mut in_wr , client_addr . log_from ( ) , & mut in_filter ) . await ? ;
client_addr . set_c2s_stream_open ( is_c2s , & stream_open ) ;
trace ! (
"{} connected: sni: {:?}, alpn: {:?}, tls-not-quic: {}" ,
client_addr . log_from ( ) ,
server_certs . sni ( ) ,
server_certs . alpn ( ) . map ( | a | String ::from_utf8_lossy ( & a ) . to_string ( ) ) ,
server_certs . is_tls ( ) ,
) ;
if ! is_c2s {
// for s2s we need this
let domain = stream_open
. extract_between ( b" from=' " , b" ' " )
. or_else ( | _ | stream_open . extract_between ( b" from= \" " , b" \" " ) )
. and_then ( | b | Ok ( std ::str ::from_utf8 ( b ) ? ) ) ? ;
let ( _ , cert_verifier ) = get_xmpp_connections ( domain , is_c2s ) . await ? ;
let certs = server_certs . peer_certificates ( ) . ok_or_else ( | | anyhow ! ( "no client cert auth for s2s incoming from {}" , domain ) ) ? ;
// todo: send stream error saying cert is invalid
cert_verifier . verify_cert ( & certs [ 0 ] , & certs [ 1 .. ] , SystemTime ::now ( ) ) ? ;
#[ cfg(feature = " s2s-incoming " ) ]
{
trace ! (
"{} connected: sni: {:?}, alpn: {:?}, tls-not-quic: {}" ,
client_addr . log_from ( ) ,
server_certs . sni ( ) ,
server_certs . alpn ( ) . map ( | a | String ::from_utf8_lossy ( & a ) . to_string ( ) ) ,
server_certs . is_tls ( ) ,
) ;
if ! is_c2s {
// for s2s we need this
use std ::time ::SystemTime ;
let domain = stream_open
. extract_between ( b" from=' " , b" ' " )
. or_else ( | _ | stream_open . extract_between ( b" from= \" " , b" \" " ) )
. and_then ( | b | Ok ( std ::str ::from_utf8 ( b ) ? ) ) ? ;
let ( _ , cert_verifier ) = get_xmpp_connections ( domain , is_c2s ) . await ? ;
let certs = server_certs . peer_certificates ( ) . ok_or_else ( | | anyhow ! ( "no client cert auth for s2s incoming from {}" , domain ) ) ? ;
// todo: send stream error saying cert is invalid
cert_verifier . verify_cert ( & certs [ 0 ] , & certs [ 1 .. ] , SystemTime ::now ( ) ) ? ;
}
drop ( server_certs ) ;
}
drop ( server_certs ) ;
let ( out_rd , out_wr ) = open_incoming ( & config , local_addr , client_addr , & stream_open , is_c2s , & mut in_filter ) . await ? ;
drop ( stream_open ) ;
@ -407,6 +432,7 @@ async fn shuffle_rd_wr_filter_only(
Ok ( ( ) )
}
#[ cfg(feature = " incoming " ) ]
async fn open_incoming (
config : & CloneableConfig ,
local_addr : SocketAddr ,
@ -415,7 +441,18 @@ async fn open_incoming(
is_c2s : bool ,
in_filter : & mut StanzaFilter ,
) -> Result < ( ReadHalf < tokio ::net ::TcpStream > , WriteHalf < tokio ::net ::TcpStream > ) > {
let target = if is_c2s { config . c2s_target } else { config . s2s_target } ;
let target = if is_c2s {
#[ cfg(not(feature = " c2s-incoming " )) ]
bail ! ( "incoming c2s connection but lacking compile-time support" ) ;
#[ cfg(feature = " c2s-incoming " ) ]
config . c2s_target
} else {
#[ cfg(not(feature = " s2s-incoming " )) ]
bail ! ( "incoming s2s connection but lacking compile-time support" ) ;
#[ cfg(feature = " s2s-incoming " ) ]
config . s2s_target
}
. ok_or_else ( | | anyhow ! ( "incoming connection but `{}_target` not defined" , c2s ( is_c2s ) ) ) ? ;
client_addr . set_to_addr ( target ) ;
let out_stream = tokio ::net ::TcpStream ::connect ( target ) . await ? ;
@ -468,7 +505,12 @@ pub async fn stream_preamble(in_rd: &mut StanzaRead, in_wr: &mut StanzaWrite, cl
#[ tokio::main ]
//#[tokio::main(flavor = "multi_thread", worker_threads = 10)]
async fn main ( ) {
let cfg_path = std ::env ::args_os ( ) . nth ( 1 ) . unwrap_or_else ( | | OsString ::from ( "/etc/xmpp-proxy/xmpp-proxy.toml" ) ) ;
let cfg_path = std ::env ::args_os ( ) . nth ( 1 ) ;
if cfg_path = = Some ( OsString ::from ( "-v" ) ) {
include! ( concat! ( env! ( "OUT_DIR" ) , "/version.rs" ) ) ;
die ! ( 0 ) ;
}
let cfg_path = cfg_path . unwrap_or_else ( | | OsString ::from ( "/etc/xmpp-proxy/xmpp-proxy.toml" ) ) ;
let main_config = Config ::parse ( & cfg_path ) . die ( "invalid config file" ) ;
#[ cfg(feature = " logging " ) ]
@ -486,34 +528,59 @@ async fn main() {
// todo: config for this: builder.format_timestamp(None);
builder . init ( ) ;
}
#[ cfg(not(feature = " logging " )) ]
if main_config . log_level . is_some ( ) | | main_config . log_style . is_some ( ) {
die ! ( "log_level or log_style defined in config but logging disabled at compile-time" ) ;
}
let config = main_config . get_cloneable_cfg ( ) ;
let certs_key = Arc ::new ( CertsKey ::new ( main_config . certs_key ( ) ) ) ;
let certs_key = Arc ::new ( CertsKey ::new ( & main_config ) ) ;
let mut handles : Vec < JoinHandle < Result < ( ) > > > = Vec ::new ( ) ;
#[ cfg(feature = " incoming " ) ]
if let Some ( ref listeners ) = main_config . incoming_listen {
let acceptor = main_config . tls_acceptor ( certs_key . clone ( ) ) . die ( "invalid cert/key ?" ) ;
for listener in listeners {
handles . push ( spawn_tls_listener ( listener . parse ( ) . die ( "invalid listener address" ) , config . clone ( ) , acceptor . clone ( ) ) ) ;
if ! main_config . incoming_listen . is_empty ( ) {
#[ cfg(all(any(feature = " tls " , feature = " websocket " ), feature = " incoming " )) ]
{
if main_config . c2s_target . is_none ( ) & & main_config . s2s_target . is_none ( ) {
die ! ( "one of c2s_target/s2s_target must be defined if incoming_listen is non-empty" ) ;
}
let acceptor = main_config . tls_acceptor ( certs_key . clone ( ) ) . die ( "invalid cert/key ?" ) ;
for listener in main_config . incoming_listen . iter ( ) {
handles . push ( spawn_tls_listener ( listener . parse ( ) . die ( "invalid listener address" ) , config . clone ( ) , acceptor . clone ( ) ) ) ;
}
}
#[ cfg(not(all(any(feature = " tls " , feature = " websocket " ), feature = " incoming " ))) ]
die ! ( "incoming_listen non-empty but (tls or websocket) or (s2s-incoming and c2s-incoming) disabled at compile-time" ) ;
}
#[ cfg(all(feature = " quic " , feature = " incoming " )) ]
if let Some ( ref listeners ) = main_config . quic_listen {
let quic_config = main_config . quic_server_config ( certs_key . clone ( ) ) . die ( "invalid cert/key ?" ) ;
for listener in listeners {
handles . push ( spawn_quic_listener ( listener . parse ( ) . die ( "invalid listener address" ) , config . clone ( ) , quic_config . clone ( ) ) ) ;
if ! main_config . quic_listen . is_empty ( ) {
#[ cfg(all(feature = " quic " , feature = " incoming " )) ]
{
if main_config . c2s_target . is_none ( ) & & main_config . s2s_target . is_none ( ) {
die ! ( "one of c2s_target/s2s_target must be defined if quic_listen is non-empty" ) ;
}
let quic_config = main_config . quic_server_config ( certs_key . clone ( ) ) . die ( "invalid cert/key ?" ) ;
for listener in main_config . quic_listen . iter ( ) {
handles . push ( spawn_quic_listener ( listener . parse ( ) . die ( "invalid listener address" ) , config . clone ( ) , quic_config . clone ( ) ) ) ;
}
}
#[ cfg(not(all(feature = " quic " , feature = " incoming " ))) ]
die ! ( "quic_listen non-empty but quic or (s2s-incoming and c2s-incoming) disabled at compile-time" ) ;
}
#[ cfg(feature = " outgoing " ) ]
if let Some ( ref listeners ) = main_config . outgoing_listen {
let outgoing_cfg = main_config . get_outgoing_cfg ( certs_key . clone ( ) ) ;
for listener in listeners {
handles . push ( spawn_outgoing_listener ( listener . parse ( ) . die ( "invalid listener address" ) , outgoing_cfg . clone ( ) ) ) ;
if ! main_config . outgoing_listen . is_empty ( ) {
#[ cfg(feature = " outgoing " ) ]
{
let outgoing_cfg = main_config . get_outgoing_cfg ( certs_key . clone ( ) ) ;
for listener in main_config . outgoing_listen . iter ( ) {
handles . push ( spawn_outgoing_listener ( listener . parse ( ) . die ( "invalid listener address" ) , outgoing_cfg . clone ( ) ) ) ;
}
}
#[ cfg(not(feature = " outgoing " )) ]
die ! ( "outgoing_listen non-empty but c2s-outgoing and s2s-outgoing disabled at compile-time" ) ;
}
if handles . is_empty ( ) {
die ! ( "all of incoming_listen, quic_listen, outgoing_listen empty, nothing to do, exiting..." ) ;
}
#[ cfg(unix) ]
#[ cfg( all( unix, any(feature = " incoming " , feature = " s2s-outgoing " )) )]
if let Some ( refresh_task ) = Box ::leak ( Box ::new ( certs_key . clone ( ) ) ) . spawn_refresh_task ( cfg_path ) {
handles . push ( refresh_task ) ;
}