From 37df79a9331575274d698edc344f380e02f933e4 Mon Sep 17 00:00:00 2001 From: moparisthebest Date: Fri, 5 Jan 2024 01:59:26 -0500 Subject: [PATCH] Re-factor ServerCert --- .gitignore | 1 + src/common/incoming.rs | 76 +++++++++++++++++++++++------------------- src/quic/incoming.rs | 2 +- src/tls/incoming.rs | 18 +++------- xmpp-proxy.toml | 4 +-- 5 files changed, 50 insertions(+), 51 deletions(-) diff --git a/.gitignore b/.gitignore index ae6e578..f3ef7a0 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ fuzz/target/ *.txt conflict/ +*.test.toml diff --git a/src/common/incoming.rs b/src/common/incoming.rs index e304ab5..281c5ee 100644 --- a/src/common/incoming.rs +++ b/src/common/incoming.rs @@ -44,55 +44,63 @@ pub type ServerCerts = (); #[cfg(any(feature = "s2s-incoming", feature = "webtransport"))] #[derive(Clone)] -pub enum ServerCerts { - Tls(&'static ServerConnection), - #[cfg(feature = "quic")] - Quic(Option>, Option, Option>), // todo: wrap this in arc or something now +pub struct ServerCerts { + inner: Arc, + is_tls: bool, } #[cfg(any(feature = "s2s-incoming", feature = "webtransport"))] -impl ServerCerts { - #[cfg(feature = "quic")] - pub fn quic(conn: &quinn::Connection) -> ServerCerts { - let certs = conn.peer_identity().and_then(|v| v.downcast::>().ok()).map(|v| v.to_vec()); +struct InnerServerCerts { + peer_certificates: Option>, + sni: Option, + alpn: Option>, +} + +#[cfg(any(feature = "s2s-incoming", feature = "webtransport"))] +impl From<&ServerConnection> for ServerCerts { + fn from(conn: &ServerConnection) -> Self { + let peer_certificates = conn.peer_certificates().map(|c| c.to_vec()); + let sni = conn.server_name().map(|s| s.to_string()); + let alpn = conn.alpn_protocol().map(|s| s.to_vec()); + Self { + inner: InnerServerCerts { peer_certificates, sni, alpn }.into(), + is_tls: true, + } + } +} + +#[cfg(all(feature = "quic", any(feature = "s2s-incoming", feature = "webtransport")))] +impl From<&quinn::Connection> for ServerCerts { + fn from(conn: &quinn::Connection) -> Self { + let peer_certificates = conn.peer_identity().and_then(|v| v.downcast::>().ok()).map(|v| v.to_vec()); let (sni, alpn) = conn .handshake_data() .and_then(|v| v.downcast::().ok()) .map(|h| (h.server_name, h.protocol)) .unwrap_or_default(); - ServerCerts::Quic(certs, sni, alpn) - } - - pub fn peer_certificates(&self) -> Option> { - match self { - ServerCerts::Tls(c) => c.peer_certificates().map(|c| c.to_vec()), - #[cfg(feature = "quic")] - ServerCerts::Quic(certs, _, _) => certs.clone(), + Self { + inner: InnerServerCerts { peer_certificates, sni, alpn }.into(), + is_tls: false, } } +} - pub fn sni(&self) -> Option { - match self { - ServerCerts::Tls(c) => c.server_name().map(|s| s.to_string()), - #[cfg(feature = "quic")] - ServerCerts::Quic(_, sni, _) => sni.clone(), - } +#[cfg(any(feature = "s2s-incoming", feature = "webtransport"))] +impl ServerCerts { + pub fn peer_certificates(&self) -> Option<&Vec> { + self.inner.peer_certificates.as_ref() } - pub fn alpn(&self) -> Option> { - match self { - ServerCerts::Tls(c) => c.alpn_protocol().map(|s| s.to_vec()), - #[cfg(feature = "quic")] - ServerCerts::Quic(_, _, alpn) => alpn.clone(), - } + pub fn sni(&self) -> Option<&str> { + self.inner.sni.as_deref() + } + + pub fn alpn(&self) -> Option<&Vec> { + self.inner.alpn.as_ref() } pub fn is_tls(&self) -> bool { - match self { - ServerCerts::Tls(_) => true, - #[cfg(feature = "quic")] - ServerCerts::Quic(_, _, _) => false, - } + self.is_tls } } @@ -120,7 +128,7 @@ pub async fn shuffle_rd_wr_filter( "{} connected: sni: {:?}, alpn: {:?}, tls-not-quic: {}", client_addr.log_from(), server_certs.sni(), - server_certs.alpn().map(|a| String::from_utf8_lossy(&a).to_string()), + server_certs.alpn().map(|a| String::from_utf8_lossy(a).to_string()), server_certs.is_tls(), ); diff --git a/src/quic/incoming.rs b/src/quic/incoming.rs index 2b7cd6c..2a7510b 100644 --- a/src/quic/incoming.rs +++ b/src/quic/incoming.rs @@ -40,7 +40,7 @@ fn internal_spawn_quic_listener(incoming: Endpoint, local_addr: SocketAddr, conf #[cfg(any(feature = "s2s-incoming", feature = "webtransport"))] let server_certs = { - let server_certs = ServerCerts::quic(&new_conn); + let server_certs = ServerCerts::from(&new_conn); #[cfg(feature = "webtransport")] if server_certs.alpn().map(|a| a == webtransport_quinn::ALPN).unwrap_or(false) { return crate::webtransport::incoming::handle_webtransport_session(new_conn, config, server_certs, local_addr, client_addr).await; diff --git a/src/tls/incoming.rs b/src/tls/incoming.rs index 5eca4fa..b527207 100644 --- a/src/tls/incoming.rs +++ b/src/tls/incoming.rs @@ -12,7 +12,7 @@ use crate::{ }; use anyhow::{bail, Result}; use log::{error, info, trace, warn}; -use rustls::{ServerConfig, ServerConnection}; +use rustls::ServerConfig; use std::{net::SocketAddr, sync::Arc}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt, BufStream}, @@ -117,7 +117,7 @@ pub async fn handle_tls_connection(mut stream: S, cl // until we get to the first byte of the TLS handshake... while stream.first_bytes_match(&mut in_filter.buf[0..1], |p| p[0] != 0x16).await? { warn!("{} buggy software connecting, sent byte after (mut stream: S, cl let stream = acceptor.accept(stream).await?; let (_, server_connection) = stream.get_ref(); - // todo: find better way to do this, might require different tokio_rustls API, the problem is I can't hold this - // past stream.into() below, and I can't get it back out after, now I *could* read sni+alpn+peer_certs - // *here* instead and pass them on, but since I haven't read anything from the stream yet, I'm - // not guaranteed that the handshake is complete and these are available, yes I can call is_handshaking() - // but there is no async API to complete the handshake, so I really need to pass it down to under - // where we read the first stanza, where we are guaranteed the handshake is complete, but I can't - // do that without ignoring the lifetime and just pulling a C programmer and pinky promising to be - // *very careful* that this reference doesn't outlive stream... #[cfg(any(feature = "s2s-incoming", feature = "webtransport"))] - let server_certs = { - let server_connection: &'static ServerConnection = unsafe { std::mem::transmute(server_connection) }; - ServerCerts::Tls(server_connection) - }; + let server_certs = ServerCerts::from(server_connection); + #[cfg(not(any(feature = "s2s-incoming", feature = "webtransport")))] let server_certs = (); diff --git a/xmpp-proxy.toml b/xmpp-proxy.toml index 8884b2b..4ab1e77 100644 --- a/xmpp-proxy.toml +++ b/xmpp-proxy.toml @@ -1,8 +1,8 @@ # interfaces to listen for reverse proxy STARTTLS/Direct TLS/TLS WebSocket (wss) XMPP connections on, should be open to the internet -incoming_listen = [ "0.0.0.0:5222", "0.0.0.0:5269", "0.0.0.0:443" ] +incoming_listen = [ "[::]:5222", "[::]:5269", "[::]:443" ] # interfaces to listen for reverse proxy QUIC/WebTransport XMPP connections on, should be open to the internet -quic_listen = [ "0.0.0.0:443" ] +quic_listen = [ "[::]:443" ] # interfaces to listen for outgoing proxy TCP or WebSocket XMPP connections on, should be localhost or a path for a unix socket outgoing_listen = [ "127.0.0.1:15270" ]