Implement most of certificate auth/sasl external for incoming connections

This commit is contained in:
Travis Burtrum 2022-02-22 02:51:45 -05:00
parent 7df66e2f78
commit f42498970c
6 changed files with 217 additions and 58 deletions

View File

@ -7,6 +7,8 @@ use slicesubsequence::*;
use std::net::SocketAddr;
pub use log::{debug, error, info, log_enabled, trace};
use rustls::{Certificate, ServerConnection};
use tokio_rustls::webpki::DnsNameRef;
pub fn to_str(buf: &[u8]) -> std::borrow::Cow<'_, str> {
String::from_utf8_lossy(buf)
@ -126,3 +128,55 @@ impl<'a> Context<'a> {
}
}
}
#[cfg(feature = "incoming")]
#[derive(Clone)]
pub enum ServerCerts {
Tls(&'static ServerConnection),
#[cfg(feature = "quic")]
Quic(quinn::Connection),
}
impl ServerCerts {
pub fn valid(&self, dns_name: DnsNameRef) -> bool {
use std::convert::TryFrom;
use tokio_rustls::webpki;
self.first_peer_cert()
.and_then(|c| {
if let Ok(cert) = webpki::EndEntityCert::try_from(c.0.as_ref()) {
cert.verify_is_valid_for_dns_name(dns_name).map(|_| true).ok()
} else {
Some(false)
}
})
.unwrap_or(false)
}
pub fn first_peer_cert(&self) -> Option<Certificate> {
match self {
ServerCerts::Tls(c) => c.peer_certificates().map(|c| c[0].clone()),
ServerCerts::Quic(c) => c.peer_identity().and_then(|v| v.downcast::<Vec<Certificate>>().ok()).map(|v| v[0].clone()),
}
}
pub fn sni(&self) -> Option<String> {
match self {
ServerCerts::Tls(c) => c.sni_hostname().map(|s| s.to_string()),
ServerCerts::Quic(c) => c.handshake_data().and_then(|v| v.downcast::<quinn::crypto::rustls::HandshakeData>().ok()).and_then(|h| h.server_name),
}
}
pub fn alpn(&self) -> Option<Vec<u8>> {
match self {
ServerCerts::Tls(c) => c.alpn_protocol().map(|s| s.to_vec()),
ServerCerts::Quic(c) => c.handshake_data().and_then(|v| v.downcast::<quinn::crypto::rustls::HandshakeData>().ok()).and_then(|h| h.protocol),
}
}
pub fn is_tls(&self) -> bool {
match self {
ServerCerts::Tls(_) => true,
ServerCerts::Quic(_) => false,
}
}
}

View File

@ -22,7 +22,10 @@ use rustls::{Certificate, ClientConfig, PrivateKey, ServerConfig};
#[cfg(feature = "rustls-pemfile")]
use rustls_pemfile::{certs, pkcs8_private_keys};
#[cfg(feature = "tokio-rustls")]
use tokio_rustls::{TlsAcceptor, TlsConnector};
use tokio_rustls::{
webpki::{DnsNameRef, TlsServerTrustAnchors, TrustAnchor},
TlsAcceptor, TlsConnector,
};
use anyhow::{bail, Result};
@ -54,6 +57,9 @@ mod websocket;
#[cfg(feature = "websocket")]
use crate::websocket::*;
mod verify;
use crate::verify::*;
mod in_out;
pub use crate::in_out::*;
@ -65,11 +71,29 @@ const ALPN_XMPP_CLIENT: &[u8] = b"xmpp-client";
const ALPN_XMPP_SERVER: &[u8] = b"xmpp-server";
#[cfg(all(feature = "webpki-roots", not(feature = "rustls-native-certs")))]
pub use webpki_roots::TLS_SERVER_ROOTS;
#[cfg(all(feature = "rustls-native-certs", not(feature = "webpki-roots")))]
lazy_static::lazy_static! {
static ref TLS_SERVER_ROOTS: TlsServerTrustAnchors<'static> = {
// we need these to stick around for 'static, this is only called once so no problem
let certs = Box::leak(Box::new(rustls_native_certs::load_native_certs().expect("could not load platform certs")));
let root_cert_store = Box::leak(Box::new(Vec::new()));
for cert in certs {
// some system CAs are invalid, ignore those
if let Ok(ta) = TrustAnchor::try_from_cert_der(&cert.0) {
root_cert_store.push(ta);
}
}
TlsServerTrustAnchors(root_cert_store)
};
}
pub fn root_cert_store() -> rustls::RootCertStore {
use rustls::{OwnedTrustAnchor, RootCertStore};
let mut root_cert_store = RootCertStore::empty();
root_cert_store.add_server_trust_anchors(
webpki_roots::TLS_SERVER_ROOTS
TLS_SERVER_ROOTS
.0
.iter()
.map(|ta| OwnedTrustAnchor::from_subject_spki_name_constraints(ta.subject, ta.spki, ta.name_constraints)),
@ -77,16 +101,6 @@ pub fn root_cert_store() -> rustls::RootCertStore {
root_cert_store
}
#[cfg(all(feature = "rustls-native-certs", not(feature = "webpki-roots")))]
pub fn root_cert_store() -> rustls::RootCertStore {
use rustls::RootCertStore;
let mut root_cert_store = RootCertStore::empty();
for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs") {
root_cert_store.add(&rustls::Certificate(cert.0)).unwrap();
}
root_cert_store
}
#[derive(Deserialize)]
struct Config {
tls_key: String,
@ -193,8 +207,11 @@ impl Config {
fn server_config(&self) -> Result<ServerConfig> {
let (tls_certs, tls_key) = self.certs_key()?;
// todo: request client auth here
let mut config = ServerConfig::builder().with_safe_defaults().with_no_client_auth().with_single_cert(tls_certs, tls_key)?;
let mut config = ServerConfig::builder()
.with_safe_defaults()
.with_client_cert_verifier(Arc::new(AllowAnyAnonymousOrAuthenticatedServer))
.with_single_cert(tls_certs, tls_key)?;
// todo: will connecting without alpn work then?
config.alpn_protocols.push(ALPN_XMPP_CLIENT.to_vec());
config.alpn_protocols.push(ALPN_XMPP_SERVER.to_vec());
@ -248,21 +265,44 @@ impl OutgoingConfig {
}
}
async fn shuffle_rd_wr(in_rd: StanzaRead, in_wr: StanzaWrite, config: CloneableConfig, local_addr: SocketAddr, client_addr: &mut Context<'_>) -> Result<()> {
async fn shuffle_rd_wr(in_rd: StanzaRead, in_wr: StanzaWrite, config: CloneableConfig, server_certs: ServerCerts, local_addr: SocketAddr, client_addr: &mut Context<'_>) -> Result<()> {
let filter = StanzaFilter::new(config.max_stanza_size_bytes);
shuffle_rd_wr_filter(in_rd, in_wr, config, local_addr, client_addr, filter).await
shuffle_rd_wr_filter(in_rd, in_wr, config, server_certs, local_addr, client_addr, filter).await
}
async fn shuffle_rd_wr_filter(
mut in_rd: StanzaRead,
mut in_wr: StanzaWrite,
config: CloneableConfig,
server_certs: ServerCerts,
local_addr: SocketAddr,
client_addr: &mut Context<'_>,
mut in_filter: StanzaFilter,
) -> Result<()> {
// now read to figure out client vs server
let (stream_open, is_c2s) = stream_preamble(&mut in_rd, &mut in_wr, client_addr.log_from(), &mut in_filter).await?;
client_addr.set_c2s_stream_open(is_c2s, &stream_open);
trace!(
"{} connected: sni: {:?}, alpn: {:?}, tls-not-quic: {}",
client_addr.log_from(),
server_certs.sni(),
server_certs.alpn().map(|a| String::from_utf8_lossy(&a).to_string()),
server_certs.is_tls(),
);
if !is_c2s {
// for s2s we need this
let dns_from = stream_open
.extract_between(b" from='", b"'")
.or_else(|_| stream_open.extract_between(b" from=\"", b"\""))
.and_then(|b| Ok(DnsNameRef::try_from_ascii(b)?))?;
if !server_certs.valid(dns_from) {
// todo: send stream error saying cert is invalid
bail!("server certificate invalid for {:?}", dns_from);
}
}
drop(server_certs);
let (out_rd, out_wr) = open_incoming(&config, local_addr, client_addr, &stream_open, is_c2s, &mut in_filter).await?;
drop(stream_open);
@ -332,7 +372,6 @@ async fn open_incoming(
) -> Result<(ReadHalf<tokio::net::TcpStream>, WriteHalf<tokio::net::TcpStream>)> {
let target = if is_c2s { config.c2s_target } else { config.s2s_target };
client_addr.set_to_addr(target);
client_addr.set_c2s_stream_open(is_c2s, stream_open);
let out_stream = tokio::net::TcpStream::connect(target).await?;
let (out_rd, mut out_wr) = tokio::io::split(out_stream);

View File

@ -2,7 +2,6 @@ use crate::*;
use futures::StreamExt;
use quinn::{ServerConfig, TransportConfig};
use std::{net::SocketAddr, sync::Arc};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use anyhow::Result;
@ -22,15 +21,12 @@ pub async fn quic_connect(target: SocketAddr, server_name: &str, is_c2s: bool, c
Ok((StanzaWrite::new(wrt), StanzaRead::new(rd)))
}
impl Config {
#[cfg(feature = "incoming")]
impl Config {
pub fn quic_server_config(&self) -> Result<ServerConfig> {
let transport_config = TransportConfig::default();
// todo: configure transport_config here if needed
let mut server_config = self.server_config()?;
// todo: will connecting without alpn work then?
server_config.alpn_protocols.push(ALPN_XMPP_CLIENT.to_vec());
server_config.alpn_protocols.push(ALPN_XMPP_SERVER.to_vec());
let server_config = self.server_config()?;
let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(server_config));
server_config.transport = Arc::new(transport_config);
@ -38,32 +34,7 @@ impl Config {
}
}
struct NoopIo;
use core::pin::Pin;
use core::task::{Context, Poll};
// todo: could change this to return Error and kill the stream instead, after all, s2s *should* not be receiving any bytes back
impl AsyncWrite for NoopIo {
fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
impl AsyncRead for NoopIo {
fn poll_read(self: Pin<&mut Self>, _cx: &mut Context<'_>, _buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
Poll::Pending
}
}
#[cfg(feature = "incoming")]
pub fn spawn_quic_listener(local_addr: SocketAddr, config: CloneableConfig, server_config: ServerConfig) -> JoinHandle<Result<()>> {
let (_endpoint, mut incoming) = quinn::Endpoint::server(server_config, local_addr).die("cannot listen on port/interface");
tokio::spawn(async move {
@ -73,14 +44,16 @@ pub fn spawn_quic_listener(local_addr: SocketAddr, config: CloneableConfig, serv
tokio::spawn(async move {
if let Ok(mut new_conn) = incoming_conn.await {
let client_addr = crate::Context::new("quic-in", new_conn.connection.remote_address());
let server_certs = ServerCerts::Quic(new_conn.connection);
info!("{} connected new connection", client_addr.log_from());
while let Some(Ok((wrt, rd))) = new_conn.bi_streams.next().await {
let config = config.clone();
let mut client_addr = client_addr.clone();
let server_certs = server_certs.clone();
info!("{} connected new stream", client_addr.log_from());
tokio::spawn(async move {
if let Err(e) = shuffle_rd_wr(StanzaRead::new(rd), StanzaWrite::new(wrt), config, local_addr, &mut client_addr).await {
if let Err(e) = shuffle_rd_wr(StanzaRead::new(rd), StanzaWrite::new(wrt), config, server_certs, local_addr, &mut client_addr).await {
error!("{} {}", client_addr.log_from(), e);
}
});

View File

@ -1,4 +1,5 @@
use crate::*;
use rustls::ServerConnection;
use std::convert::TryFrom;
use tokio::io::{AsyncBufReadExt, BufStream};
@ -172,16 +173,29 @@ async fn handle_tls_connection(mut stream: tokio::net::TcpStream, client_addr: &
}
let stream = acceptor.accept(stream).await?;
let (_, server_connection) = stream.get_ref();
// todo: find better way to do this, might require different tokio_rustls API, the problem is I can't hold this
// past stream.into() below, and I can't get it back out after, now I *could* read sni+alpn+peer_certs
// *here* instead and pass them on, but since I haven't read anything from the stream yet, I'm
// not guaranteed that the handshake is complete and these are available, yes I can call is_handshaking()
// but there is no async API to complete the handshake, so I really need to pass it down to under
// where we read the first stanza, where we are guaranteed the handshake is complete, but I can't
// do that without ignoring the lifetime and just pulling a C programmer and pinky promising to be
// *very careful* that this reference doesn't outlive stream...
let server_connection: &'static ServerConnection = unsafe { std::mem::transmute(server_connection) };
let server_certs = ServerCerts::Tls(server_connection);
#[cfg(not(feature = "websocket"))]
{
let (in_rd, in_wr) = tokio::io::split(stream);
shuffle_rd_wr_filter(StanzaRead::new(in_rd), StanzaWrite::new(in_wr), config, local_addr, client_addr, in_filter).await
shuffle_rd_wr_filter(StanzaRead::new(in_rd), StanzaWrite::new(in_wr), config, server_certs, local_addr, client_addr, in_filter).await
}
#[cfg(feature = "websocket")]
{
let stream: tokio_rustls::TlsStream<tokio::net::TcpStream> = stream.into();
let mut stream = BufStream::with_capacity(crate::IN_BUFFER_SIZE, 0, stream);
let websocket = {
// wait up to 10 seconds until 3 bytes have been read
@ -206,10 +220,10 @@ async fn handle_tls_connection(mut stream: tokio::net::TcpStream, client_addr: &
};
if websocket {
handle_websocket_connection(stream, client_addr, local_addr, config).await
handle_websocket_connection(stream, config, server_certs, local_addr, client_addr, in_filter).await
} else {
let (in_rd, in_wr) = tokio::io::split(stream);
shuffle_rd_wr_filter(StanzaRead::already_buffered(in_rd), StanzaWrite::new(in_wr), config, local_addr, client_addr, in_filter).await
shuffle_rd_wr_filter(StanzaRead::already_buffered(in_rd), StanzaWrite::new(in_wr), config, server_certs, local_addr, client_addr, in_filter).await
}
}
}

69
src/verify.rs Normal file
View File

@ -0,0 +1,69 @@
use rustls::server::{ClientCertVerified, ClientCertVerifier};
use rustls::{Certificate, DistinguishedNames, Error};
use std::convert::TryFrom;
use std::time::SystemTime;
use tokio_rustls::webpki;
type SignatureAlgorithms = &'static [&'static webpki::SignatureAlgorithm];
/// Which signature verification mechanisms we support. No particular
/// order.
static SUPPORTED_SIG_ALGS: SignatureAlgorithms = &[
&webpki::ECDSA_P256_SHA256,
&webpki::ECDSA_P256_SHA384,
&webpki::ECDSA_P384_SHA256,
&webpki::ECDSA_P384_SHA384,
&webpki::ED25519,
&webpki::RSA_PSS_2048_8192_SHA256_LEGACY_KEY,
&webpki::RSA_PSS_2048_8192_SHA384_LEGACY_KEY,
&webpki::RSA_PSS_2048_8192_SHA512_LEGACY_KEY,
&webpki::RSA_PKCS1_2048_8192_SHA256,
&webpki::RSA_PKCS1_2048_8192_SHA384,
&webpki::RSA_PKCS1_2048_8192_SHA512,
&webpki::RSA_PKCS1_3072_8192_SHA384,
];
pub fn pki_error(error: webpki::Error) -> Error {
use webpki::Error::*;
match error {
BadDer | BadDerTime => Error::InvalidCertificateEncoding,
InvalidSignatureForPublicKey => Error::InvalidCertificateSignature,
UnsupportedSignatureAlgorithm | UnsupportedSignatureAlgorithmForPublicKey => Error::InvalidCertificateSignatureType,
e => Error::InvalidCertificateData(format!("invalid peer certificate: {}", e)),
}
}
pub struct AllowAnyAnonymousOrAuthenticatedServer;
impl ClientCertVerifier for AllowAnyAnonymousOrAuthenticatedServer {
fn offer_client_auth(&self) -> bool {
true
}
fn client_auth_mandatory(&self) -> Option<bool> {
Some(false)
}
fn client_auth_root_subjects(&self) -> Option<DistinguishedNames> {
Some(Vec::new())
}
fn verify_client_cert(&self, end_entity: &Certificate, intermediates: &[Certificate], now: SystemTime) -> Result<ClientCertVerified, Error> {
let (cert, chain) = prepare(end_entity, intermediates)?;
let now = webpki::Time::try_from(now).map_err(|_| Error::FailedToGetCurrentTime)?;
cert.verify_is_valid_tls_server_cert(SUPPORTED_SIG_ALGS, &crate::TLS_SERVER_ROOTS, &chain, now)
.map_err(pki_error)
.map(|_| ClientCertVerified::assertion())
}
}
type CertChainAndRoots<'a, 'b> = (webpki::EndEntityCert<'a>, Vec<&'a [u8]>);
fn prepare<'a, 'b>(end_entity: &'a Certificate, intermediates: &'a [Certificate]) -> Result<CertChainAndRoots<'a, 'b>, Error> {
// EE cert must appear first.
let cert = webpki::EndEntityCert::try_from(end_entity.0.as_ref()).map_err(pki_error)?;
let intermediates: Vec<&'a [u8]> = intermediates.iter().map(|cert| cert.0.as_ref()).collect();
Ok((cert, intermediates))
}

View File

@ -17,10 +17,13 @@ fn ws_cfg(max_stanza_size_bytes: usize) -> Option<WebSocketConfig> {
pub async fn handle_websocket_connection(
stream: BufStream<tokio_rustls::TlsStream<tokio::net::TcpStream>>,
client_addr: &mut Context<'_>,
local_addr: SocketAddr,
config: CloneableConfig,
server_certs: ServerCerts,
local_addr: SocketAddr,
client_addr: &mut Context<'_>,
in_filter: StanzaFilter,
) -> Result<()> {
client_addr.set_proto("websocket-in");
info!("{} connected", client_addr.log_from());
// accept the websocket
@ -29,9 +32,16 @@ pub async fn handle_websocket_connection(
let (in_wr, in_rd) = stream.split();
let in_filter = StanzaFilter::new(config.max_stanza_size_bytes);
shuffle_rd_wr_filter(StanzaRead::WebSocketRead(in_rd), StanzaWrite::WebSocketClientWrite(in_wr), config, local_addr, client_addr, in_filter).await
shuffle_rd_wr_filter(
StanzaRead::WebSocketRead(in_rd),
StanzaWrite::WebSocketClientWrite(in_wr),
config,
server_certs,
local_addr,
client_addr,
in_filter,
)
.await
}
pub fn from_ws(stanza: String) -> String {