Implement most of certificate auth/sasl external for incoming connections
This commit is contained in:
parent
7df66e2f78
commit
f42498970c
54
src/lib.rs
54
src/lib.rs
@ -7,6 +7,8 @@ use slicesubsequence::*;
|
|||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
|
|
||||||
pub use log::{debug, error, info, log_enabled, trace};
|
pub use log::{debug, error, info, log_enabled, trace};
|
||||||
|
use rustls::{Certificate, ServerConnection};
|
||||||
|
use tokio_rustls::webpki::DnsNameRef;
|
||||||
|
|
||||||
pub fn to_str(buf: &[u8]) -> std::borrow::Cow<'_, str> {
|
pub fn to_str(buf: &[u8]) -> std::borrow::Cow<'_, str> {
|
||||||
String::from_utf8_lossy(buf)
|
String::from_utf8_lossy(buf)
|
||||||
@ -126,3 +128,55 @@ impl<'a> Context<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "incoming")]
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub enum ServerCerts {
|
||||||
|
Tls(&'static ServerConnection),
|
||||||
|
#[cfg(feature = "quic")]
|
||||||
|
Quic(quinn::Connection),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ServerCerts {
|
||||||
|
pub fn valid(&self, dns_name: DnsNameRef) -> bool {
|
||||||
|
use std::convert::TryFrom;
|
||||||
|
use tokio_rustls::webpki;
|
||||||
|
self.first_peer_cert()
|
||||||
|
.and_then(|c| {
|
||||||
|
if let Ok(cert) = webpki::EndEntityCert::try_from(c.0.as_ref()) {
|
||||||
|
cert.verify_is_valid_for_dns_name(dns_name).map(|_| true).ok()
|
||||||
|
} else {
|
||||||
|
Some(false)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.unwrap_or(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn first_peer_cert(&self) -> Option<Certificate> {
|
||||||
|
match self {
|
||||||
|
ServerCerts::Tls(c) => c.peer_certificates().map(|c| c[0].clone()),
|
||||||
|
ServerCerts::Quic(c) => c.peer_identity().and_then(|v| v.downcast::<Vec<Certificate>>().ok()).map(|v| v[0].clone()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn sni(&self) -> Option<String> {
|
||||||
|
match self {
|
||||||
|
ServerCerts::Tls(c) => c.sni_hostname().map(|s| s.to_string()),
|
||||||
|
ServerCerts::Quic(c) => c.handshake_data().and_then(|v| v.downcast::<quinn::crypto::rustls::HandshakeData>().ok()).and_then(|h| h.server_name),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn alpn(&self) -> Option<Vec<u8>> {
|
||||||
|
match self {
|
||||||
|
ServerCerts::Tls(c) => c.alpn_protocol().map(|s| s.to_vec()),
|
||||||
|
ServerCerts::Quic(c) => c.handshake_data().and_then(|v| v.downcast::<quinn::crypto::rustls::HandshakeData>().ok()).and_then(|h| h.protocol),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_tls(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
ServerCerts::Tls(_) => true,
|
||||||
|
ServerCerts::Quic(_) => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
73
src/main.rs
73
src/main.rs
@ -22,7 +22,10 @@ use rustls::{Certificate, ClientConfig, PrivateKey, ServerConfig};
|
|||||||
#[cfg(feature = "rustls-pemfile")]
|
#[cfg(feature = "rustls-pemfile")]
|
||||||
use rustls_pemfile::{certs, pkcs8_private_keys};
|
use rustls_pemfile::{certs, pkcs8_private_keys};
|
||||||
#[cfg(feature = "tokio-rustls")]
|
#[cfg(feature = "tokio-rustls")]
|
||||||
use tokio_rustls::{TlsAcceptor, TlsConnector};
|
use tokio_rustls::{
|
||||||
|
webpki::{DnsNameRef, TlsServerTrustAnchors, TrustAnchor},
|
||||||
|
TlsAcceptor, TlsConnector,
|
||||||
|
};
|
||||||
|
|
||||||
use anyhow::{bail, Result};
|
use anyhow::{bail, Result};
|
||||||
|
|
||||||
@ -54,6 +57,9 @@ mod websocket;
|
|||||||
#[cfg(feature = "websocket")]
|
#[cfg(feature = "websocket")]
|
||||||
use crate::websocket::*;
|
use crate::websocket::*;
|
||||||
|
|
||||||
|
mod verify;
|
||||||
|
use crate::verify::*;
|
||||||
|
|
||||||
mod in_out;
|
mod in_out;
|
||||||
pub use crate::in_out::*;
|
pub use crate::in_out::*;
|
||||||
|
|
||||||
@ -65,11 +71,29 @@ const ALPN_XMPP_CLIENT: &[u8] = b"xmpp-client";
|
|||||||
const ALPN_XMPP_SERVER: &[u8] = b"xmpp-server";
|
const ALPN_XMPP_SERVER: &[u8] = b"xmpp-server";
|
||||||
|
|
||||||
#[cfg(all(feature = "webpki-roots", not(feature = "rustls-native-certs")))]
|
#[cfg(all(feature = "webpki-roots", not(feature = "rustls-native-certs")))]
|
||||||
|
pub use webpki_roots::TLS_SERVER_ROOTS;
|
||||||
|
|
||||||
|
#[cfg(all(feature = "rustls-native-certs", not(feature = "webpki-roots")))]
|
||||||
|
lazy_static::lazy_static! {
|
||||||
|
static ref TLS_SERVER_ROOTS: TlsServerTrustAnchors<'static> = {
|
||||||
|
// we need these to stick around for 'static, this is only called once so no problem
|
||||||
|
let certs = Box::leak(Box::new(rustls_native_certs::load_native_certs().expect("could not load platform certs")));
|
||||||
|
let root_cert_store = Box::leak(Box::new(Vec::new()));
|
||||||
|
for cert in certs {
|
||||||
|
// some system CAs are invalid, ignore those
|
||||||
|
if let Ok(ta) = TrustAnchor::try_from_cert_der(&cert.0) {
|
||||||
|
root_cert_store.push(ta);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TlsServerTrustAnchors(root_cert_store)
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
pub fn root_cert_store() -> rustls::RootCertStore {
|
pub fn root_cert_store() -> rustls::RootCertStore {
|
||||||
use rustls::{OwnedTrustAnchor, RootCertStore};
|
use rustls::{OwnedTrustAnchor, RootCertStore};
|
||||||
let mut root_cert_store = RootCertStore::empty();
|
let mut root_cert_store = RootCertStore::empty();
|
||||||
root_cert_store.add_server_trust_anchors(
|
root_cert_store.add_server_trust_anchors(
|
||||||
webpki_roots::TLS_SERVER_ROOTS
|
TLS_SERVER_ROOTS
|
||||||
.0
|
.0
|
||||||
.iter()
|
.iter()
|
||||||
.map(|ta| OwnedTrustAnchor::from_subject_spki_name_constraints(ta.subject, ta.spki, ta.name_constraints)),
|
.map(|ta| OwnedTrustAnchor::from_subject_spki_name_constraints(ta.subject, ta.spki, ta.name_constraints)),
|
||||||
@ -77,16 +101,6 @@ pub fn root_cert_store() -> rustls::RootCertStore {
|
|||||||
root_cert_store
|
root_cert_store
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(all(feature = "rustls-native-certs", not(feature = "webpki-roots")))]
|
|
||||||
pub fn root_cert_store() -> rustls::RootCertStore {
|
|
||||||
use rustls::RootCertStore;
|
|
||||||
let mut root_cert_store = RootCertStore::empty();
|
|
||||||
for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs") {
|
|
||||||
root_cert_store.add(&rustls::Certificate(cert.0)).unwrap();
|
|
||||||
}
|
|
||||||
root_cert_store
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct Config {
|
struct Config {
|
||||||
tls_key: String,
|
tls_key: String,
|
||||||
@ -193,8 +207,11 @@ impl Config {
|
|||||||
fn server_config(&self) -> Result<ServerConfig> {
|
fn server_config(&self) -> Result<ServerConfig> {
|
||||||
let (tls_certs, tls_key) = self.certs_key()?;
|
let (tls_certs, tls_key) = self.certs_key()?;
|
||||||
|
|
||||||
// todo: request client auth here
|
let mut config = ServerConfig::builder()
|
||||||
let mut config = ServerConfig::builder().with_safe_defaults().with_no_client_auth().with_single_cert(tls_certs, tls_key)?;
|
.with_safe_defaults()
|
||||||
|
.with_client_cert_verifier(Arc::new(AllowAnyAnonymousOrAuthenticatedServer))
|
||||||
|
.with_single_cert(tls_certs, tls_key)?;
|
||||||
|
// todo: will connecting without alpn work then?
|
||||||
config.alpn_protocols.push(ALPN_XMPP_CLIENT.to_vec());
|
config.alpn_protocols.push(ALPN_XMPP_CLIENT.to_vec());
|
||||||
config.alpn_protocols.push(ALPN_XMPP_SERVER.to_vec());
|
config.alpn_protocols.push(ALPN_XMPP_SERVER.to_vec());
|
||||||
|
|
||||||
@ -248,21 +265,44 @@ impl OutgoingConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn shuffle_rd_wr(in_rd: StanzaRead, in_wr: StanzaWrite, config: CloneableConfig, local_addr: SocketAddr, client_addr: &mut Context<'_>) -> Result<()> {
|
async fn shuffle_rd_wr(in_rd: StanzaRead, in_wr: StanzaWrite, config: CloneableConfig, server_certs: ServerCerts, local_addr: SocketAddr, client_addr: &mut Context<'_>) -> Result<()> {
|
||||||
let filter = StanzaFilter::new(config.max_stanza_size_bytes);
|
let filter = StanzaFilter::new(config.max_stanza_size_bytes);
|
||||||
shuffle_rd_wr_filter(in_rd, in_wr, config, local_addr, client_addr, filter).await
|
shuffle_rd_wr_filter(in_rd, in_wr, config, server_certs, local_addr, client_addr, filter).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn shuffle_rd_wr_filter(
|
async fn shuffle_rd_wr_filter(
|
||||||
mut in_rd: StanzaRead,
|
mut in_rd: StanzaRead,
|
||||||
mut in_wr: StanzaWrite,
|
mut in_wr: StanzaWrite,
|
||||||
config: CloneableConfig,
|
config: CloneableConfig,
|
||||||
|
server_certs: ServerCerts,
|
||||||
local_addr: SocketAddr,
|
local_addr: SocketAddr,
|
||||||
client_addr: &mut Context<'_>,
|
client_addr: &mut Context<'_>,
|
||||||
mut in_filter: StanzaFilter,
|
mut in_filter: StanzaFilter,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
// now read to figure out client vs server
|
// now read to figure out client vs server
|
||||||
let (stream_open, is_c2s) = stream_preamble(&mut in_rd, &mut in_wr, client_addr.log_from(), &mut in_filter).await?;
|
let (stream_open, is_c2s) = stream_preamble(&mut in_rd, &mut in_wr, client_addr.log_from(), &mut in_filter).await?;
|
||||||
|
client_addr.set_c2s_stream_open(is_c2s, &stream_open);
|
||||||
|
|
||||||
|
trace!(
|
||||||
|
"{} connected: sni: {:?}, alpn: {:?}, tls-not-quic: {}",
|
||||||
|
client_addr.log_from(),
|
||||||
|
server_certs.sni(),
|
||||||
|
server_certs.alpn().map(|a| String::from_utf8_lossy(&a).to_string()),
|
||||||
|
server_certs.is_tls(),
|
||||||
|
);
|
||||||
|
|
||||||
|
if !is_c2s {
|
||||||
|
// for s2s we need this
|
||||||
|
let dns_from = stream_open
|
||||||
|
.extract_between(b" from='", b"'")
|
||||||
|
.or_else(|_| stream_open.extract_between(b" from=\"", b"\""))
|
||||||
|
.and_then(|b| Ok(DnsNameRef::try_from_ascii(b)?))?;
|
||||||
|
if !server_certs.valid(dns_from) {
|
||||||
|
// todo: send stream error saying cert is invalid
|
||||||
|
bail!("server certificate invalid for {:?}", dns_from);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
drop(server_certs);
|
||||||
|
|
||||||
let (out_rd, out_wr) = open_incoming(&config, local_addr, client_addr, &stream_open, is_c2s, &mut in_filter).await?;
|
let (out_rd, out_wr) = open_incoming(&config, local_addr, client_addr, &stream_open, is_c2s, &mut in_filter).await?;
|
||||||
drop(stream_open);
|
drop(stream_open);
|
||||||
@ -332,7 +372,6 @@ async fn open_incoming(
|
|||||||
) -> Result<(ReadHalf<tokio::net::TcpStream>, WriteHalf<tokio::net::TcpStream>)> {
|
) -> Result<(ReadHalf<tokio::net::TcpStream>, WriteHalf<tokio::net::TcpStream>)> {
|
||||||
let target = if is_c2s { config.c2s_target } else { config.s2s_target };
|
let target = if is_c2s { config.c2s_target } else { config.s2s_target };
|
||||||
client_addr.set_to_addr(target);
|
client_addr.set_to_addr(target);
|
||||||
client_addr.set_c2s_stream_open(is_c2s, stream_open);
|
|
||||||
|
|
||||||
let out_stream = tokio::net::TcpStream::connect(target).await?;
|
let out_stream = tokio::net::TcpStream::connect(target).await?;
|
||||||
let (out_rd, mut out_wr) = tokio::io::split(out_stream);
|
let (out_rd, mut out_wr) = tokio::io::split(out_stream);
|
||||||
|
39
src/quic.rs
39
src/quic.rs
@ -2,7 +2,6 @@ use crate::*;
|
|||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
use quinn::{ServerConfig, TransportConfig};
|
use quinn::{ServerConfig, TransportConfig};
|
||||||
use std::{net::SocketAddr, sync::Arc};
|
use std::{net::SocketAddr, sync::Arc};
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
|
||||||
@ -22,15 +21,12 @@ pub async fn quic_connect(target: SocketAddr, server_name: &str, is_c2s: bool, c
|
|||||||
Ok((StanzaWrite::new(wrt), StanzaRead::new(rd)))
|
Ok((StanzaWrite::new(wrt), StanzaRead::new(rd)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "incoming")]
|
||||||
impl Config {
|
impl Config {
|
||||||
#[cfg(feature = "incoming")]
|
|
||||||
pub fn quic_server_config(&self) -> Result<ServerConfig> {
|
pub fn quic_server_config(&self) -> Result<ServerConfig> {
|
||||||
let transport_config = TransportConfig::default();
|
let transport_config = TransportConfig::default();
|
||||||
// todo: configure transport_config here if needed
|
// todo: configure transport_config here if needed
|
||||||
let mut server_config = self.server_config()?;
|
let server_config = self.server_config()?;
|
||||||
// todo: will connecting without alpn work then?
|
|
||||||
server_config.alpn_protocols.push(ALPN_XMPP_CLIENT.to_vec());
|
|
||||||
server_config.alpn_protocols.push(ALPN_XMPP_SERVER.to_vec());
|
|
||||||
let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(server_config));
|
let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(server_config));
|
||||||
server_config.transport = Arc::new(transport_config);
|
server_config.transport = Arc::new(transport_config);
|
||||||
|
|
||||||
@ -38,32 +34,7 @@ impl Config {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct NoopIo;
|
#[cfg(feature = "incoming")]
|
||||||
|
|
||||||
use core::pin::Pin;
|
|
||||||
use core::task::{Context, Poll};
|
|
||||||
|
|
||||||
// todo: could change this to return Error and kill the stream instead, after all, s2s *should* not be receiving any bytes back
|
|
||||||
impl AsyncWrite for NoopIo {
|
|
||||||
fn poll_write(self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
|
|
||||||
Poll::Ready(Ok(buf.len()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
||||||
Poll::Ready(Ok(()))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
|
||||||
Poll::Ready(Ok(()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AsyncRead for NoopIo {
|
|
||||||
fn poll_read(self: Pin<&mut Self>, _cx: &mut Context<'_>, _buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
|
|
||||||
Poll::Pending
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn spawn_quic_listener(local_addr: SocketAddr, config: CloneableConfig, server_config: ServerConfig) -> JoinHandle<Result<()>> {
|
pub fn spawn_quic_listener(local_addr: SocketAddr, config: CloneableConfig, server_config: ServerConfig) -> JoinHandle<Result<()>> {
|
||||||
let (_endpoint, mut incoming) = quinn::Endpoint::server(server_config, local_addr).die("cannot listen on port/interface");
|
let (_endpoint, mut incoming) = quinn::Endpoint::server(server_config, local_addr).die("cannot listen on port/interface");
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
@ -73,14 +44,16 @@ pub fn spawn_quic_listener(local_addr: SocketAddr, config: CloneableConfig, serv
|
|||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Ok(mut new_conn) = incoming_conn.await {
|
if let Ok(mut new_conn) = incoming_conn.await {
|
||||||
let client_addr = crate::Context::new("quic-in", new_conn.connection.remote_address());
|
let client_addr = crate::Context::new("quic-in", new_conn.connection.remote_address());
|
||||||
|
let server_certs = ServerCerts::Quic(new_conn.connection);
|
||||||
info!("{} connected new connection", client_addr.log_from());
|
info!("{} connected new connection", client_addr.log_from());
|
||||||
|
|
||||||
while let Some(Ok((wrt, rd))) = new_conn.bi_streams.next().await {
|
while let Some(Ok((wrt, rd))) = new_conn.bi_streams.next().await {
|
||||||
let config = config.clone();
|
let config = config.clone();
|
||||||
let mut client_addr = client_addr.clone();
|
let mut client_addr = client_addr.clone();
|
||||||
|
let server_certs = server_certs.clone();
|
||||||
info!("{} connected new stream", client_addr.log_from());
|
info!("{} connected new stream", client_addr.log_from());
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(e) = shuffle_rd_wr(StanzaRead::new(rd), StanzaWrite::new(wrt), config, local_addr, &mut client_addr).await {
|
if let Err(e) = shuffle_rd_wr(StanzaRead::new(rd), StanzaWrite::new(wrt), config, server_certs, local_addr, &mut client_addr).await {
|
||||||
error!("{} {}", client_addr.log_from(), e);
|
error!("{} {}", client_addr.log_from(), e);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
20
src/tls.rs
20
src/tls.rs
@ -1,4 +1,5 @@
|
|||||||
use crate::*;
|
use crate::*;
|
||||||
|
use rustls::ServerConnection;
|
||||||
use std::convert::TryFrom;
|
use std::convert::TryFrom;
|
||||||
use tokio::io::{AsyncBufReadExt, BufStream};
|
use tokio::io::{AsyncBufReadExt, BufStream};
|
||||||
|
|
||||||
@ -172,16 +173,29 @@ async fn handle_tls_connection(mut stream: tokio::net::TcpStream, client_addr: &
|
|||||||
}
|
}
|
||||||
|
|
||||||
let stream = acceptor.accept(stream).await?;
|
let stream = acceptor.accept(stream).await?;
|
||||||
|
let (_, server_connection) = stream.get_ref();
|
||||||
|
|
||||||
|
// todo: find better way to do this, might require different tokio_rustls API, the problem is I can't hold this
|
||||||
|
// past stream.into() below, and I can't get it back out after, now I *could* read sni+alpn+peer_certs
|
||||||
|
// *here* instead and pass them on, but since I haven't read anything from the stream yet, I'm
|
||||||
|
// not guaranteed that the handshake is complete and these are available, yes I can call is_handshaking()
|
||||||
|
// but there is no async API to complete the handshake, so I really need to pass it down to under
|
||||||
|
// where we read the first stanza, where we are guaranteed the handshake is complete, but I can't
|
||||||
|
// do that without ignoring the lifetime and just pulling a C programmer and pinky promising to be
|
||||||
|
// *very careful* that this reference doesn't outlive stream...
|
||||||
|
let server_connection: &'static ServerConnection = unsafe { std::mem::transmute(server_connection) };
|
||||||
|
let server_certs = ServerCerts::Tls(server_connection);
|
||||||
|
|
||||||
#[cfg(not(feature = "websocket"))]
|
#[cfg(not(feature = "websocket"))]
|
||||||
{
|
{
|
||||||
let (in_rd, in_wr) = tokio::io::split(stream);
|
let (in_rd, in_wr) = tokio::io::split(stream);
|
||||||
shuffle_rd_wr_filter(StanzaRead::new(in_rd), StanzaWrite::new(in_wr), config, local_addr, client_addr, in_filter).await
|
shuffle_rd_wr_filter(StanzaRead::new(in_rd), StanzaWrite::new(in_wr), config, server_certs, local_addr, client_addr, in_filter).await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "websocket")]
|
#[cfg(feature = "websocket")]
|
||||||
{
|
{
|
||||||
let stream: tokio_rustls::TlsStream<tokio::net::TcpStream> = stream.into();
|
let stream: tokio_rustls::TlsStream<tokio::net::TcpStream> = stream.into();
|
||||||
|
|
||||||
let mut stream = BufStream::with_capacity(crate::IN_BUFFER_SIZE, 0, stream);
|
let mut stream = BufStream::with_capacity(crate::IN_BUFFER_SIZE, 0, stream);
|
||||||
let websocket = {
|
let websocket = {
|
||||||
// wait up to 10 seconds until 3 bytes have been read
|
// wait up to 10 seconds until 3 bytes have been read
|
||||||
@ -206,10 +220,10 @@ async fn handle_tls_connection(mut stream: tokio::net::TcpStream, client_addr: &
|
|||||||
};
|
};
|
||||||
|
|
||||||
if websocket {
|
if websocket {
|
||||||
handle_websocket_connection(stream, client_addr, local_addr, config).await
|
handle_websocket_connection(stream, config, server_certs, local_addr, client_addr, in_filter).await
|
||||||
} else {
|
} else {
|
||||||
let (in_rd, in_wr) = tokio::io::split(stream);
|
let (in_rd, in_wr) = tokio::io::split(stream);
|
||||||
shuffle_rd_wr_filter(StanzaRead::already_buffered(in_rd), StanzaWrite::new(in_wr), config, local_addr, client_addr, in_filter).await
|
shuffle_rd_wr_filter(StanzaRead::already_buffered(in_rd), StanzaWrite::new(in_wr), config, server_certs, local_addr, client_addr, in_filter).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
69
src/verify.rs
Normal file
69
src/verify.rs
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
use rustls::server::{ClientCertVerified, ClientCertVerifier};
|
||||||
|
use rustls::{Certificate, DistinguishedNames, Error};
|
||||||
|
use std::convert::TryFrom;
|
||||||
|
use std::time::SystemTime;
|
||||||
|
use tokio_rustls::webpki;
|
||||||
|
|
||||||
|
type SignatureAlgorithms = &'static [&'static webpki::SignatureAlgorithm];
|
||||||
|
|
||||||
|
/// Which signature verification mechanisms we support. No particular
|
||||||
|
/// order.
|
||||||
|
static SUPPORTED_SIG_ALGS: SignatureAlgorithms = &[
|
||||||
|
&webpki::ECDSA_P256_SHA256,
|
||||||
|
&webpki::ECDSA_P256_SHA384,
|
||||||
|
&webpki::ECDSA_P384_SHA256,
|
||||||
|
&webpki::ECDSA_P384_SHA384,
|
||||||
|
&webpki::ED25519,
|
||||||
|
&webpki::RSA_PSS_2048_8192_SHA256_LEGACY_KEY,
|
||||||
|
&webpki::RSA_PSS_2048_8192_SHA384_LEGACY_KEY,
|
||||||
|
&webpki::RSA_PSS_2048_8192_SHA512_LEGACY_KEY,
|
||||||
|
&webpki::RSA_PKCS1_2048_8192_SHA256,
|
||||||
|
&webpki::RSA_PKCS1_2048_8192_SHA384,
|
||||||
|
&webpki::RSA_PKCS1_2048_8192_SHA512,
|
||||||
|
&webpki::RSA_PKCS1_3072_8192_SHA384,
|
||||||
|
];
|
||||||
|
|
||||||
|
pub fn pki_error(error: webpki::Error) -> Error {
|
||||||
|
use webpki::Error::*;
|
||||||
|
match error {
|
||||||
|
BadDer | BadDerTime => Error::InvalidCertificateEncoding,
|
||||||
|
InvalidSignatureForPublicKey => Error::InvalidCertificateSignature,
|
||||||
|
UnsupportedSignatureAlgorithm | UnsupportedSignatureAlgorithmForPublicKey => Error::InvalidCertificateSignatureType,
|
||||||
|
e => Error::InvalidCertificateData(format!("invalid peer certificate: {}", e)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct AllowAnyAnonymousOrAuthenticatedServer;
|
||||||
|
|
||||||
|
impl ClientCertVerifier for AllowAnyAnonymousOrAuthenticatedServer {
|
||||||
|
fn offer_client_auth(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
fn client_auth_mandatory(&self) -> Option<bool> {
|
||||||
|
Some(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn client_auth_root_subjects(&self) -> Option<DistinguishedNames> {
|
||||||
|
Some(Vec::new())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn verify_client_cert(&self, end_entity: &Certificate, intermediates: &[Certificate], now: SystemTime) -> Result<ClientCertVerified, Error> {
|
||||||
|
let (cert, chain) = prepare(end_entity, intermediates)?;
|
||||||
|
let now = webpki::Time::try_from(now).map_err(|_| Error::FailedToGetCurrentTime)?;
|
||||||
|
cert.verify_is_valid_tls_server_cert(SUPPORTED_SIG_ALGS, &crate::TLS_SERVER_ROOTS, &chain, now)
|
||||||
|
.map_err(pki_error)
|
||||||
|
.map(|_| ClientCertVerified::assertion())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type CertChainAndRoots<'a, 'b> = (webpki::EndEntityCert<'a>, Vec<&'a [u8]>);
|
||||||
|
|
||||||
|
fn prepare<'a, 'b>(end_entity: &'a Certificate, intermediates: &'a [Certificate]) -> Result<CertChainAndRoots<'a, 'b>, Error> {
|
||||||
|
// EE cert must appear first.
|
||||||
|
let cert = webpki::EndEntityCert::try_from(end_entity.0.as_ref()).map_err(pki_error)?;
|
||||||
|
|
||||||
|
let intermediates: Vec<&'a [u8]> = intermediates.iter().map(|cert| cert.0.as_ref()).collect();
|
||||||
|
|
||||||
|
Ok((cert, intermediates))
|
||||||
|
}
|
@ -17,10 +17,13 @@ fn ws_cfg(max_stanza_size_bytes: usize) -> Option<WebSocketConfig> {
|
|||||||
|
|
||||||
pub async fn handle_websocket_connection(
|
pub async fn handle_websocket_connection(
|
||||||
stream: BufStream<tokio_rustls::TlsStream<tokio::net::TcpStream>>,
|
stream: BufStream<tokio_rustls::TlsStream<tokio::net::TcpStream>>,
|
||||||
client_addr: &mut Context<'_>,
|
|
||||||
local_addr: SocketAddr,
|
|
||||||
config: CloneableConfig,
|
config: CloneableConfig,
|
||||||
|
server_certs: ServerCerts,
|
||||||
|
local_addr: SocketAddr,
|
||||||
|
client_addr: &mut Context<'_>,
|
||||||
|
in_filter: StanzaFilter,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
|
client_addr.set_proto("websocket-in");
|
||||||
info!("{} connected", client_addr.log_from());
|
info!("{} connected", client_addr.log_from());
|
||||||
|
|
||||||
// accept the websocket
|
// accept the websocket
|
||||||
@ -29,9 +32,16 @@ pub async fn handle_websocket_connection(
|
|||||||
|
|
||||||
let (in_wr, in_rd) = stream.split();
|
let (in_wr, in_rd) = stream.split();
|
||||||
|
|
||||||
let in_filter = StanzaFilter::new(config.max_stanza_size_bytes);
|
shuffle_rd_wr_filter(
|
||||||
|
StanzaRead::WebSocketRead(in_rd),
|
||||||
shuffle_rd_wr_filter(StanzaRead::WebSocketRead(in_rd), StanzaWrite::WebSocketClientWrite(in_wr), config, local_addr, client_addr, in_filter).await
|
StanzaWrite::WebSocketClientWrite(in_wr),
|
||||||
|
config,
|
||||||
|
server_certs,
|
||||||
|
local_addr,
|
||||||
|
client_addr,
|
||||||
|
in_filter,
|
||||||
|
)
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_ws(stanza: String) -> String {
|
pub fn from_ws(stanza: String) -> String {
|
||||||
|
Loading…
Reference in New Issue
Block a user