Re-factor ServerCert
All checks were successful
moparisthebest/xmpp-proxy/pipeline/head This commit looks good

This commit is contained in:
Travis Burtrum 2024-01-05 01:59:26 -05:00
parent bf6500538e
commit 37df79a933
Signed by: moparisthebest
GPG Key ID: 88C93BFE27BC8229
5 changed files with 50 additions and 51 deletions

1
.gitignore vendored
View File

@ -8,3 +8,4 @@
fuzz/target/
*.txt
conflict/
*.test.toml

View File

@ -44,55 +44,63 @@ pub type ServerCerts = ();
#[cfg(any(feature = "s2s-incoming", feature = "webtransport"))]
#[derive(Clone)]
pub enum ServerCerts {
Tls(&'static ServerConnection),
#[cfg(feature = "quic")]
Quic(Option<Vec<Certificate>>, Option<String>, Option<Vec<u8>>), // todo: wrap this in arc or something now
pub struct ServerCerts {
inner: Arc<InnerServerCerts>,
is_tls: bool,
}
#[cfg(any(feature = "s2s-incoming", feature = "webtransport"))]
impl ServerCerts {
#[cfg(feature = "quic")]
pub fn quic(conn: &quinn::Connection) -> ServerCerts {
let certs = conn.peer_identity().and_then(|v| v.downcast::<Vec<Certificate>>().ok()).map(|v| v.to_vec());
struct InnerServerCerts {
peer_certificates: Option<Vec<Certificate>>,
sni: Option<String>,
alpn: Option<Vec<u8>>,
}
#[cfg(any(feature = "s2s-incoming", feature = "webtransport"))]
impl From<&ServerConnection> for ServerCerts {
fn from(conn: &ServerConnection) -> Self {
let peer_certificates = conn.peer_certificates().map(|c| c.to_vec());
let sni = conn.server_name().map(|s| s.to_string());
let alpn = conn.alpn_protocol().map(|s| s.to_vec());
Self {
inner: InnerServerCerts { peer_certificates, sni, alpn }.into(),
is_tls: true,
}
}
}
#[cfg(all(feature = "quic", any(feature = "s2s-incoming", feature = "webtransport")))]
impl From<&quinn::Connection> for ServerCerts {
fn from(conn: &quinn::Connection) -> Self {
let peer_certificates = conn.peer_identity().and_then(|v| v.downcast::<Vec<Certificate>>().ok()).map(|v| v.to_vec());
let (sni, alpn) = conn
.handshake_data()
.and_then(|v| v.downcast::<quinn::crypto::rustls::HandshakeData>().ok())
.map(|h| (h.server_name, h.protocol))
.unwrap_or_default();
ServerCerts::Quic(certs, sni, alpn)
}
pub fn peer_certificates(&self) -> Option<Vec<Certificate>> {
match self {
ServerCerts::Tls(c) => c.peer_certificates().map(|c| c.to_vec()),
#[cfg(feature = "quic")]
ServerCerts::Quic(certs, _, _) => certs.clone(),
Self {
inner: InnerServerCerts { peer_certificates, sni, alpn }.into(),
is_tls: false,
}
}
}
pub fn sni(&self) -> Option<String> {
match self {
ServerCerts::Tls(c) => c.server_name().map(|s| s.to_string()),
#[cfg(feature = "quic")]
ServerCerts::Quic(_, sni, _) => sni.clone(),
}
#[cfg(any(feature = "s2s-incoming", feature = "webtransport"))]
impl ServerCerts {
pub fn peer_certificates(&self) -> Option<&Vec<Certificate>> {
self.inner.peer_certificates.as_ref()
}
pub fn alpn(&self) -> Option<Vec<u8>> {
match self {
ServerCerts::Tls(c) => c.alpn_protocol().map(|s| s.to_vec()),
#[cfg(feature = "quic")]
ServerCerts::Quic(_, _, alpn) => alpn.clone(),
}
pub fn sni(&self) -> Option<&str> {
self.inner.sni.as_deref()
}
pub fn alpn(&self) -> Option<&Vec<u8>> {
self.inner.alpn.as_ref()
}
pub fn is_tls(&self) -> bool {
match self {
ServerCerts::Tls(_) => true,
#[cfg(feature = "quic")]
ServerCerts::Quic(_, _, _) => false,
}
self.is_tls
}
}
@ -120,7 +128,7 @@ pub async fn shuffle_rd_wr_filter(
"{} connected: sni: {:?}, alpn: {:?}, tls-not-quic: {}",
client_addr.log_from(),
server_certs.sni(),
server_certs.alpn().map(|a| String::from_utf8_lossy(&a).to_string()),
server_certs.alpn().map(|a| String::from_utf8_lossy(a).to_string()),
server_certs.is_tls(),
);

View File

@ -40,7 +40,7 @@ fn internal_spawn_quic_listener(incoming: Endpoint, local_addr: SocketAddr, conf
#[cfg(any(feature = "s2s-incoming", feature = "webtransport"))]
let server_certs = {
let server_certs = ServerCerts::quic(&new_conn);
let server_certs = ServerCerts::from(&new_conn);
#[cfg(feature = "webtransport")]
if server_certs.alpn().map(|a| a == webtransport_quinn::ALPN).unwrap_or(false) {
return crate::webtransport::incoming::handle_webtransport_session(new_conn, config, server_certs, local_addr, client_addr).await;

View File

@ -12,7 +12,7 @@ use crate::{
};
use anyhow::{bail, Result};
use log::{error, info, trace, warn};
use rustls::{ServerConfig, ServerConnection};
use rustls::ServerConfig;
use std::{net::SocketAddr, sync::Arc};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt, BufStream},
@ -117,7 +117,7 @@ pub async fn handle_tls_connection<S: AsyncReadWritePeekSplit>(mut stream: S, cl
// until we get to the first byte of the TLS handshake...
while stream.first_bytes_match(&mut in_filter.buf[0..1], |p| p[0] != 0x16).await? {
warn!("{} buggy software connecting, sent byte after <starttls: {}", client_addr.log_to(), &in_filter.buf[0]);
stream.read(&mut in_filter.buf[0..1]).await?;
stream.read_exact(&mut in_filter.buf[0..1]).await?;
}
stream
} else {
@ -127,19 +127,9 @@ pub async fn handle_tls_connection<S: AsyncReadWritePeekSplit>(mut stream: S, cl
let stream = acceptor.accept(stream).await?;
let (_, server_connection) = stream.get_ref();
// todo: find better way to do this, might require different tokio_rustls API, the problem is I can't hold this
// past stream.into() below, and I can't get it back out after, now I *could* read sni+alpn+peer_certs
// *here* instead and pass them on, but since I haven't read anything from the stream yet, I'm
// not guaranteed that the handshake is complete and these are available, yes I can call is_handshaking()
// but there is no async API to complete the handshake, so I really need to pass it down to under
// where we read the first stanza, where we are guaranteed the handshake is complete, but I can't
// do that without ignoring the lifetime and just pulling a C programmer and pinky promising to be
// *very careful* that this reference doesn't outlive stream...
#[cfg(any(feature = "s2s-incoming", feature = "webtransport"))]
let server_certs = {
let server_connection: &'static ServerConnection = unsafe { std::mem::transmute(server_connection) };
ServerCerts::Tls(server_connection)
};
let server_certs = ServerCerts::from(server_connection);
#[cfg(not(any(feature = "s2s-incoming", feature = "webtransport")))]
let server_certs = ();

View File

@ -1,8 +1,8 @@
# interfaces to listen for reverse proxy STARTTLS/Direct TLS/TLS WebSocket (wss) XMPP connections on, should be open to the internet
incoming_listen = [ "0.0.0.0:5222", "0.0.0.0:5269", "0.0.0.0:443" ]
incoming_listen = [ "[::]:5222", "[::]:5269", "[::]:443" ]
# interfaces to listen for reverse proxy QUIC/WebTransport XMPP connections on, should be open to the internet
quic_listen = [ "0.0.0.0:443" ]
quic_listen = [ "[::]:443" ]
# interfaces to listen for outgoing proxy TCP or WebSocket XMPP connections on, should be localhost or a path for a unix socket
outgoing_listen = [ "127.0.0.1:15270" ]