use crate::{ common::{c2s, certs_key::CertsKey, shuffle_rd_wr_filter_only, stream_preamble, to_str, ALPN_XMPP_CLIENT, ALPN_XMPP_SERVER}, context::Context, in_out::{StanzaRead, StanzaWrite}, slicesubsequence::SliceSubsequence, stanzafilter::StanzaFilter, }; use anyhow::{anyhow, bail, Result}; use log::trace; use rustls::{Certificate, ServerConfig, ServerConnection}; use std::{io::Write, net::SocketAddr, sync::Arc}; use tokio::io::{AsyncWriteExt, ReadHalf, WriteHalf}; #[derive(Clone)] pub struct CloneableConfig { pub max_stanza_size_bytes: usize, #[cfg(feature = "s2s-incoming")] pub s2s_target: Option, #[cfg(feature = "c2s-incoming")] pub c2s_target: Option, pub proxy: bool, } pub fn server_config(certs_key: Arc) -> Result { if let Err(e) = &certs_key.inner { bail!("invalid cert/key: {}", e); } let config = ServerConfig::builder().with_safe_defaults(); #[cfg(feature = "s2s")] let config = config.with_client_cert_verifier(Arc::new(crate::verify::AllowAnonymousOrAnyCert)); #[cfg(not(feature = "s2s"))] let config = config.with_no_client_auth(); let mut config = config.with_cert_resolver(certs_key); // todo: will connecting without alpn work then? config.alpn_protocols.push(ALPN_XMPP_CLIENT.to_vec()); config.alpn_protocols.push(ALPN_XMPP_SERVER.to_vec()); Ok(config) } #[cfg(not(feature = "s2s-incoming"))] pub type ServerCerts = (); #[cfg(feature = "s2s-incoming")] #[derive(Clone)] pub enum ServerCerts { Tls(&'static ServerConnection), #[cfg(feature = "quic")] Quic(Arc), } #[cfg(feature = "s2s-incoming")] impl ServerCerts { pub fn peer_certificates(&self) -> Option> { match self { ServerCerts::Tls(c) => c.peer_certificates().map(|c| c.to_vec()), #[cfg(feature = "quic")] ServerCerts::Quic(c) => c.peer_identity().and_then(|v| v.downcast::>().ok()).map(|v| v.to_vec()), } } pub fn sni(&self) -> Option { match self { ServerCerts::Tls(c) => c.sni_hostname().map(|s| s.to_string()), #[cfg(feature = "quic")] ServerCerts::Quic(c) => c.handshake_data().and_then(|v| v.downcast::().ok()).and_then(|h| h.server_name), } } pub fn alpn(&self) -> Option> { match self { ServerCerts::Tls(c) => c.alpn_protocol().map(|s| s.to_vec()), #[cfg(feature = "quic")] ServerCerts::Quic(c) => c.handshake_data().and_then(|v| v.downcast::().ok()).and_then(|h| h.protocol), } } pub fn is_tls(&self) -> bool { match self { ServerCerts::Tls(_) => true, #[cfg(feature = "quic")] ServerCerts::Quic(_) => false, } } } pub async fn shuffle_rd_wr(in_rd: StanzaRead, in_wr: StanzaWrite, config: CloneableConfig, server_certs: ServerCerts, local_addr: SocketAddr, client_addr: &mut Context<'_>) -> Result<()> { let filter = StanzaFilter::new(config.max_stanza_size_bytes); shuffle_rd_wr_filter(in_rd, in_wr, config, server_certs, local_addr, client_addr, filter).await } pub async fn shuffle_rd_wr_filter( mut in_rd: StanzaRead, mut in_wr: StanzaWrite, config: CloneableConfig, server_certs: ServerCerts, local_addr: SocketAddr, client_addr: &mut Context<'_>, mut in_filter: StanzaFilter, ) -> Result<()> { // now read to figure out client vs server let (stream_open, is_c2s) = stream_preamble(&mut in_rd, &mut in_wr, client_addr.log_from(), &mut in_filter).await?; client_addr.set_c2s_stream_open(is_c2s, &stream_open); #[cfg(feature = "s2s-incoming")] { trace!( "{} connected: sni: {:?}, alpn: {:?}, tls-not-quic: {}", client_addr.log_from(), server_certs.sni(), server_certs.alpn().map(|a| String::from_utf8_lossy(&a).to_string()), server_certs.is_tls(), ); if !is_c2s { // for s2s we need this use std::time::SystemTime; let domain = stream_open .extract_between(b" from='", b"'") .or_else(|_| stream_open.extract_between(b" from=\"", b"\"")) .and_then(|b| Ok(std::str::from_utf8(b)?))?; let (_, cert_verifier) = crate::srv::get_xmpp_connections(domain, is_c2s).await?; let certs = server_certs.peer_certificates().ok_or_else(|| anyhow!("no client cert auth for s2s incoming from {}", domain))?; // todo: send stream error saying cert is invalid cert_verifier.verify_cert(&certs[0], &certs[1..], SystemTime::now())?; } drop(server_certs); } let (out_rd, out_wr) = open_incoming(&config, local_addr, client_addr, &stream_open, is_c2s, &mut in_filter).await?; drop(stream_open); shuffle_rd_wr_filter_only( in_rd, in_wr, StanzaRead::new(out_rd), StanzaWrite::new(out_wr), is_c2s, config.max_stanza_size_bytes, client_addr, in_filter, ) .await } async fn open_incoming( config: &CloneableConfig, local_addr: SocketAddr, client_addr: &mut Context<'_>, stream_open: &[u8], is_c2s: bool, in_filter: &mut StanzaFilter, ) -> Result<(ReadHalf, WriteHalf)> { let target = if is_c2s { #[cfg(not(feature = "c2s-incoming"))] bail!("incoming c2s connection but lacking compile-time support"); #[cfg(feature = "c2s-incoming")] config.c2s_target } else { #[cfg(not(feature = "s2s-incoming"))] bail!("incoming s2s connection but lacking compile-time support"); #[cfg(feature = "s2s-incoming")] config.s2s_target } .ok_or_else(|| anyhow!("incoming connection but `{}_target` not defined", c2s(is_c2s)))?; client_addr.set_to_addr(target); let out_stream = tokio::net::TcpStream::connect(target).await?; let (out_rd, mut out_wr) = tokio::io::split(out_stream); if config.proxy { /* https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\n PROXY TCP6 ffff:f...f:ffff ffff:f...f:ffff 65535 65535\r\n PROXY TCP6 SOURCE_IP DEST_IP SOURCE_PORT DEST_PORT\r\n */ // tokio AsyncWrite doesn't have write_fmt so have to go through this buffer for some crazy reason //write!(out_wr, "PROXY TCP{} {} {} {} {}\r\n", if client_addr.is_ipv4() { '4' } else {'6' }, client_addr.ip(), local_addr.ip(), client_addr.port(), local_addr.port())?; write!( &mut in_filter.buf[0..], "PROXY TCP{} {} {} {} {}\r\n", if client_addr.client_addr().is_ipv4() { '4' } else { '6' }, client_addr.client_addr().ip(), local_addr.ip(), client_addr.client_addr().port(), local_addr.port() )?; let end_idx = &(&in_filter.buf[0..]).first_index_of(b"\n")? + 1; trace!("{} '{}'", client_addr.log_from(), to_str(&in_filter.buf[0..end_idx])); out_wr.write_all(&in_filter.buf[0..end_idx]).await?; } trace!("{} '{}'", client_addr.log_from(), to_str(stream_open)); out_wr.write_all(stream_open).await?; out_wr.flush().await?; Ok((out_rd, out_wr)) }