Browse Source

Massive refactoring

master
Travis Burtrum 5 months ago
parent
commit
92eaf31edc
  1. 2
      build.rs
  2. 33
      src/common/ca_roots.rs
  3. 47
      src/common/certs_key.rs
  4. 198
      src/common/incoming.rs
  5. 144
      src/common/mod.rs
  6. 54
      src/common/outgoing.rs
  7. 112
      src/context.rs
  8. 12
      src/in_out.rs
  9. 209
      src/lib.rs
  10. 499
      src/main.rs
  11. 16
      src/outgoing.rs
  12. 55
      src/quic/incoming.rs
  13. 5
      src/quic/mod.rs
  14. 23
      src/quic/outgoing.rs
  15. 51
      src/srv.rs
  16. 4
      src/stanzafilter.rs
  17. 96
      src/tls/incoming.rs
  18. 5
      src/tls/mod.rs
  19. 61
      src/tls/outgoing.rs
  20. 22
      src/verify.rs
  21. 25
      src/websocket/incoming.rs
  22. 56
      src/websocket/mod.rs
  23. 37
      src/websocket/outgoing.rs

2
build.rs

@ -39,7 +39,7 @@ fn main() { @@ -39,7 +39,7 @@ fn main() {
for (mut key, value) in env::vars() {
//writeln!(&mut w, "{key}: {value}", ).unwrap();
if value == "1" && key.starts_with("CARGO_FEATURE_") {
let mut key = key.split_off(14).replace("_", "-");
let mut key = key.split_off(14).replace('_', "-");
key.make_ascii_lowercase();
if allowed_features.contains(&key.as_str()) {
features.push(key);

33
src/common/ca_roots.rs

@ -0,0 +1,33 @@ @@ -0,0 +1,33 @@
#[cfg(feature = "tokio-rustls")]
use tokio_rustls::webpki::{TlsServerTrustAnchors, TrustAnchor};
#[cfg(all(feature = "webpki-roots", not(feature = "rustls-native-certs")))]
pub use webpki_roots::TLS_SERVER_ROOTS;
#[cfg(all(feature = "rustls-native-certs", not(feature = "webpki-roots")))]
lazy_static::lazy_static! {
pub static ref TLS_SERVER_ROOTS: TlsServerTrustAnchors<'static> = {
// we need these to stick around for 'static, this is only called once so no problem
let certs = Box::leak(Box::new(rustls_native_certs::load_native_certs().expect("could not load platform certs")));
let root_cert_store = Box::leak(Box::new(Vec::new()));
for cert in certs {
// some system CAs are invalid, ignore those
if let Ok(ta) = TrustAnchor::try_from_cert_der(&cert.0) {
root_cert_store.push(ta);
}
}
TlsServerTrustAnchors(root_cert_store)
};
}
pub fn root_cert_store() -> rustls::RootCertStore {
use rustls::{OwnedTrustAnchor, RootCertStore};
let mut root_cert_store = RootCertStore::empty();
root_cert_store.add_server_trust_anchors(
TLS_SERVER_ROOTS
.0
.iter()
.map(|ta| OwnedTrustAnchor::from_subject_spki_name_constraints(ta.subject, ta.spki, ta.name_constraints)),
);
root_cert_store
}

47
src/common/certs_key.rs

@ -0,0 +1,47 @@ @@ -0,0 +1,47 @@
use std::sync::{Arc, RwLock};
use anyhow::Result;
use rustls::{sign::CertifiedKey, SignatureScheme};
pub struct CertsKey {
#[cfg(feature = "rustls-pemfile")]
pub inner: Result<RwLock<Arc<CertifiedKey>>>,
}
impl CertsKey {
pub fn new(certified_key: Result<CertifiedKey>) -> Self {
CertsKey {
#[cfg(feature = "rustls-pemfile")]
inner: certified_key.map(|c| RwLock::new(Arc::new(c))),
}
}
}
#[cfg(feature = "rustls-pemfile")]
impl rustls::server::ResolvesServerCert for CertsKey {
fn resolve(&self, _: rustls::server::ClientHello) -> Option<Arc<rustls::sign::CertifiedKey>> {
self.inner.as_ref().map(|rwl| rwl.read().expect("CertKey poisoned?").clone()).ok()
}
}
#[cfg(feature = "rustls-pemfile")]
impl rustls::client::ResolvesClientCert for CertsKey {
fn resolve(&self, _: &[&[u8]], _: &[SignatureScheme]) -> Option<Arc<CertifiedKey>> {
self.inner.as_ref().map(|rwl| rwl.read().expect("CertKey poisoned?").clone()).ok()
}
fn has_certs(&self) -> bool {
self.inner.is_ok()
}
}
#[cfg(not(feature = "rustls-pemfile"))]
impl rustls::client::ResolvesClientCert for CertsKey {
fn resolve(&self, _: &[&[u8]], _: &[SignatureScheme]) -> Option<Arc<CertifiedKey>> {
None
}
fn has_certs(&self) -> bool {
false
}
}

198
src/common/incoming.rs

@ -0,0 +1,198 @@ @@ -0,0 +1,198 @@
use crate::{
common::{c2s, certs_key::CertsKey, shuffle_rd_wr_filter_only, stream_preamble, to_str, ALPN_XMPP_CLIENT, ALPN_XMPP_SERVER},
context::Context,
in_out::{StanzaRead, StanzaWrite},
slicesubsequence::SliceSubsequence,
stanzafilter::StanzaFilter,
};
use anyhow::{anyhow, bail, Result};
use log::trace;
use rustls::{Certificate, ServerConfig, ServerConnection};
use std::{io::Write, net::SocketAddr, sync::Arc};
use tokio::io::{AsyncWriteExt, ReadHalf, WriteHalf};
#[derive(Clone)]
pub struct CloneableConfig {
pub max_stanza_size_bytes: usize,
#[cfg(feature = "s2s-incoming")]
pub s2s_target: Option<SocketAddr>,
#[cfg(feature = "c2s-incoming")]
pub c2s_target: Option<SocketAddr>,
pub proxy: bool,
}
pub fn server_config(certs_key: Arc<CertsKey>) -> Result<ServerConfig> {
if let Err(e) = &certs_key.inner {
bail!("invalid cert/key: {}", e);
}
let config = ServerConfig::builder().with_safe_defaults();
#[cfg(feature = "s2s")]
let config = config.with_client_cert_verifier(Arc::new(crate::verify::AllowAnonymousOrAnyCert));
#[cfg(not(feature = "s2s"))]
let config = config.with_no_client_auth();
let mut config = config.with_cert_resolver(certs_key);
// todo: will connecting without alpn work then?
config.alpn_protocols.push(ALPN_XMPP_CLIENT.to_vec());
config.alpn_protocols.push(ALPN_XMPP_SERVER.to_vec());
Ok(config)
}
#[cfg(not(feature = "s2s-incoming"))]
pub type ServerCerts = ();
#[cfg(feature = "s2s-incoming")]
#[derive(Clone)]
pub enum ServerCerts {
Tls(&'static ServerConnection),
#[cfg(feature = "quic")]
Quic(quinn::Connection),
}
#[cfg(feature = "s2s-incoming")]
impl ServerCerts {
pub fn peer_certificates(&self) -> Option<Vec<Certificate>> {
match self {
ServerCerts::Tls(c) => c.peer_certificates().map(|c| c.to_vec()),
#[cfg(feature = "quic")]
ServerCerts::Quic(c) => c.peer_identity().and_then(|v| v.downcast::<Vec<Certificate>>().ok()).map(|v| v.to_vec()),
}
}
pub fn sni(&self) -> Option<String> {
match self {
ServerCerts::Tls(c) => c.sni_hostname().map(|s| s.to_string()),
#[cfg(feature = "quic")]
ServerCerts::Quic(c) => c.handshake_data().and_then(|v| v.downcast::<quinn::crypto::rustls::HandshakeData>().ok()).and_then(|h| h.server_name),
}
}
pub fn alpn(&self) -> Option<Vec<u8>> {
match self {
ServerCerts::Tls(c) => c.alpn_protocol().map(|s| s.to_vec()),
#[cfg(feature = "quic")]
ServerCerts::Quic(c) => c.handshake_data().and_then(|v| v.downcast::<quinn::crypto::rustls::HandshakeData>().ok()).and_then(|h| h.protocol),
}
}
pub fn is_tls(&self) -> bool {
match self {
ServerCerts::Tls(_) => true,
#[cfg(feature = "quic")]
ServerCerts::Quic(_) => false,
}
}
}
pub async fn shuffle_rd_wr(in_rd: StanzaRead, in_wr: StanzaWrite, config: CloneableConfig, server_certs: ServerCerts, local_addr: SocketAddr, client_addr: &mut Context<'_>) -> Result<()> {
let filter = StanzaFilter::new(config.max_stanza_size_bytes);
shuffle_rd_wr_filter(in_rd, in_wr, config, server_certs, local_addr, client_addr, filter).await
}
pub async fn shuffle_rd_wr_filter(
mut in_rd: StanzaRead,
mut in_wr: StanzaWrite,
config: CloneableConfig,
server_certs: ServerCerts,
local_addr: SocketAddr,
client_addr: &mut Context<'_>,
mut in_filter: StanzaFilter,
) -> Result<()> {
// now read to figure out client vs server
let (stream_open, is_c2s) = stream_preamble(&mut in_rd, &mut in_wr, client_addr.log_from(), &mut in_filter).await?;
client_addr.set_c2s_stream_open(is_c2s, &stream_open);
#[cfg(feature = "s2s-incoming")]
{
trace!(
"{} connected: sni: {:?}, alpn: {:?}, tls-not-quic: {}",
client_addr.log_from(),
server_certs.sni(),
server_certs.alpn().map(|a| String::from_utf8_lossy(&a).to_string()),
server_certs.is_tls(),
);
if !is_c2s {
// for s2s we need this
use std::time::SystemTime;
let domain = stream_open
.extract_between(b" from='", b"'")
.or_else(|_| stream_open.extract_between(b" from=\"", b"\""))
.and_then(|b| Ok(std::str::from_utf8(b)?))?;
let (_, cert_verifier) = crate::srv::get_xmpp_connections(domain, is_c2s).await?;
let certs = server_certs.peer_certificates().ok_or_else(|| anyhow!("no client cert auth for s2s incoming from {}", domain))?;
// todo: send stream error saying cert is invalid
cert_verifier.verify_cert(&certs[0], &certs[1..], SystemTime::now())?;
}
drop(server_certs);
}
let (out_rd, out_wr) = open_incoming(&config, local_addr, client_addr, &stream_open, is_c2s, &mut in_filter).await?;
drop(stream_open);
shuffle_rd_wr_filter_only(
in_rd,
in_wr,
StanzaRead::new(out_rd),
StanzaWrite::new(out_wr),
is_c2s,
config.max_stanza_size_bytes,
client_addr,
in_filter,
)
.await
}
async fn open_incoming(
config: &CloneableConfig,
local_addr: SocketAddr,
client_addr: &mut Context<'_>,
stream_open: &[u8],
is_c2s: bool,
in_filter: &mut StanzaFilter,
) -> Result<(ReadHalf<tokio::net::TcpStream>, WriteHalf<tokio::net::TcpStream>)> {
let target = if is_c2s {
#[cfg(not(feature = "c2s-incoming"))]
bail!("incoming c2s connection but lacking compile-time support");
#[cfg(feature = "c2s-incoming")]
config.c2s_target
} else {
#[cfg(not(feature = "s2s-incoming"))]
bail!("incoming s2s connection but lacking compile-time support");
#[cfg(feature = "s2s-incoming")]
config.s2s_target
}
.ok_or_else(|| anyhow!("incoming connection but `{}_target` not defined", c2s(is_c2s)))?;
client_addr.set_to_addr(target);
let out_stream = tokio::net::TcpStream::connect(target).await?;
let (out_rd, mut out_wr) = tokio::io::split(out_stream);
if config.proxy {
/*
https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt
PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\n
PROXY TCP6 ffff:f...f:ffff ffff:f...f:ffff 65535 65535\r\n
PROXY TCP6 SOURCE_IP DEST_IP SOURCE_PORT DEST_PORT\r\n
*/
// tokio AsyncWrite doesn't have write_fmt so have to go through this buffer for some crazy reason
//write!(out_wr, "PROXY TCP{} {} {} {} {}\r\n", if client_addr.is_ipv4() { '4' } else {'6' }, client_addr.ip(), local_addr.ip(), client_addr.port(), local_addr.port())?;
write!(
&mut in_filter.buf[0..],
"PROXY TCP{} {} {} {} {}\r\n",
if client_addr.client_addr().is_ipv4() { '4' } else { '6' },
client_addr.client_addr().ip(),
local_addr.ip(),
client_addr.client_addr().port(),
local_addr.port()
)?;
let end_idx = &(&in_filter.buf[0..]).first_index_of(b"\n")? + 1;
trace!("{} '{}'", client_addr.log_from(), to_str(&in_filter.buf[0..end_idx]));
out_wr.write_all(&in_filter.buf[0..end_idx]).await?;
}
trace!("{} '{}'", client_addr.log_from(), to_str(stream_open));
out_wr.write_all(stream_open).await?;
out_wr.flush().await?;
Ok((out_rd, out_wr))
}

144
src/common/mod.rs

@ -0,0 +1,144 @@ @@ -0,0 +1,144 @@
use crate::{
context::Context,
in_out::{StanzaRead, StanzaWrite},
slicesubsequence::SliceSubsequence,
stanzafilter::StanzaFilter,
};
use anyhow::{bail, Result};
use log::{info, trace};
use rustls::{
sign::{RsaSigningKey, SigningKey},
Certificate, PrivateKey,
};
use std::{fs::File, io, io::BufReader, sync::Arc};
#[cfg(feature = "incoming")]
pub mod incoming;
#[cfg(feature = "outgoing")]
pub mod outgoing;
#[cfg(any(feature = "rustls-native-certs", feature = "webpki-roots"))]
pub mod ca_roots;
pub mod certs_key;
pub const IN_BUFFER_SIZE: usize = 8192;
pub const ALPN_XMPP_CLIENT: &[u8] = b"xmpp-client";
pub const ALPN_XMPP_SERVER: &[u8] = b"xmpp-server";
pub fn to_str(buf: &[u8]) -> std::borrow::Cow<'_, str> {
String::from_utf8_lossy(buf)
}
pub fn c2s(is_c2s: bool) -> &'static str {
if is_c2s {
"c2s"
} else {
"s2s"
}
}
pub async fn first_bytes_match(stream: &tokio::net::TcpStream, p: &mut [u8], matcher: fn(&[u8]) -> bool) -> anyhow::Result<bool> {
// sooo... I don't think peek here can be used for > 1 byte without this timer craziness... can it?
let len = p.len();
// wait up to 10 seconds until len bytes have been read
use std::time::{Duration, Instant};
let duration = Duration::from_secs(10);
let now = Instant::now();
loop {
let n = stream.peek(p).await?;
if n == len {
break; // success
}
if n == 0 {
bail!("not enough bytes");
}
if Instant::now() - now > duration {
bail!("less than {} bytes in 10 seconds, closed connection?", len);
}
}
Ok(matcher(p))
}
pub async fn stream_preamble(in_rd: &mut StanzaRead, in_wr: &mut StanzaWrite, client_addr: &'_ str, in_filter: &mut StanzaFilter) -> Result<(Vec<u8>, bool)> {
let mut stream_open = Vec::new();
while let Ok(Some((buf, _))) = in_rd.next(in_filter, client_addr, in_wr).await {
trace!("{} received pre-<stream:stream> stanza: '{}'", client_addr, to_str(buf));
if buf.starts_with(b"<?xml ") {
stream_open.extend_from_slice(buf);
} else if buf.starts_with(b"<stream:stream ") {
stream_open.extend_from_slice(buf);
return Ok((stream_open, buf.contains_seq(br#" xmlns="jabber:client""#) || buf.contains_seq(br#" xmlns='jabber:client'"#)));
} else {
bail!("bad pre-<stream:stream> stanza: {}", to_str(buf));
}
}
bail!("stream ended before open")
}
#[allow(clippy::too_many_arguments)]
pub async fn shuffle_rd_wr_filter_only(
mut in_rd: StanzaRead,
mut in_wr: StanzaWrite,
mut out_rd: StanzaRead,
mut out_wr: StanzaWrite,
is_c2s: bool,
max_stanza_size_bytes: usize,
client_addr: &mut Context<'_>,
mut in_filter: StanzaFilter,
) -> Result<()> {
let mut out_filter = StanzaFilter::new(max_stanza_size_bytes);
loop {
tokio::select! {
Ok(ret) = in_rd.next(&mut in_filter, client_addr.log_to(), &mut in_wr) => {
match ret {
None => break,
Some((buf, eoft)) => {
trace!("{} '{}'", client_addr.log_from(), to_str(buf));
out_wr.write_all(is_c2s, buf, eoft, client_addr.log_from()).await?;
out_wr.flush().await?;
}
}
},
Ok(ret) = out_rd.next(&mut out_filter, client_addr.log_from(), &mut out_wr) => {
match ret {
None => break,
Some((buf, eoft)) => {
trace!("{} '{}'", client_addr.log_to(), to_str(buf));
in_wr.write_all(is_c2s, buf, eoft, client_addr.log_to()).await?;
in_wr.flush().await?;
}
}
},
}
}
info!("{} disconnected", client_addr.log_from());
Ok(())
}
#[cfg(feature = "rustls-pemfile")]
pub fn read_certified_key(tls_key: &str, tls_cert: &str) -> Result<rustls::sign::CertifiedKey> {
use rustls_pemfile::{certs, read_all, Item};
let tls_key = read_all(&mut BufReader::new(File::open(tls_key)?))
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))?
.into_iter()
.flat_map(|item| match item {
Item::RSAKey(der) => RsaSigningKey::new(&PrivateKey(der)).ok().map(Arc::new).map(|r| r as Arc<dyn SigningKey>),
Item::PKCS8Key(der) => rustls::sign::any_supported_type(&PrivateKey(der)).ok(),
Item::ECKey(der) => rustls::sign::any_supported_type(&PrivateKey(der)).ok(),
_ => None,
})
.next()
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))?;
let tls_certs = certs(&mut BufReader::new(File::open(tls_cert)?))
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))
.map(|mut certs| certs.drain(..).map(Certificate).collect())?;
Ok(rustls::sign::CertifiedKey::new(tls_certs, tls_key))
}

54
src/common/outgoing.rs

@ -0,0 +1,54 @@ @@ -0,0 +1,54 @@
use crate::{
common::{certs_key::CertsKey, ALPN_XMPP_CLIENT, ALPN_XMPP_SERVER},
verify::XmppServerCertVerifier,
};
use rustls::ClientConfig;
use std::sync::Arc;
use tokio_rustls::TlsConnector;
#[derive(Clone)]
pub struct OutgoingConfig {
pub max_stanza_size_bytes: usize,
pub certs_key: Arc<CertsKey>,
}
impl OutgoingConfig {
pub fn with_custom_certificate_verifier(&self, is_c2s: bool, cert_verifier: XmppServerCertVerifier) -> OutgoingVerifierConfig {
let config = match is_c2s {
false => ClientConfig::builder()
.with_safe_defaults()
.with_custom_certificate_verifier(Arc::new(cert_verifier))
.with_client_cert_resolver(self.certs_key.clone()),
_ => ClientConfig::builder()
.with_safe_defaults()
.with_custom_certificate_verifier(Arc::new(cert_verifier))
.with_no_client_auth(),
};
let mut config_alpn = config.clone();
config_alpn.alpn_protocols.push(if is_c2s { ALPN_XMPP_CLIENT } else { ALPN_XMPP_SERVER }.to_vec());
let config_alpn = Arc::new(config_alpn);
let connector_alpn: TlsConnector = config_alpn.clone().into();
let connector: TlsConnector = Arc::new(config).into();
OutgoingVerifierConfig {
max_stanza_size_bytes: self.max_stanza_size_bytes,
config_alpn,
connector_alpn,
connector,
}
}
}
#[derive(Clone)]
pub struct OutgoingVerifierConfig {
pub max_stanza_size_bytes: usize,
pub config_alpn: Arc<ClientConfig>,
pub connector_alpn: TlsConnector,
pub connector: TlsConnector,
}

112
src/context.rs

@ -0,0 +1,112 @@ @@ -0,0 +1,112 @@
use crate::{
common::{c2s, to_str},
slicesubsequence::SliceSubsequence,
};
use log::{info, log_enabled};
use std::net::SocketAddr;
#[derive(Clone)]
pub struct Context<'a> {
conn_id: String,
log_from: String,
log_to: String,
proto: &'a str,
is_c2s: Option<bool>,
to: Option<String>,
to_addr: Option<SocketAddr>,
from: Option<String>,
client_addr: SocketAddr,
}
impl<'a> Context<'a> {
pub fn new(proto: &'static str, client_addr: SocketAddr) -> Context {
let (log_to, log_from, conn_id) = if log_enabled!(log::Level::Info) {
#[cfg(feature = "logging")]
let conn_id = {
use rand::{distributions::Alphanumeric, thread_rng, Rng};
thread_rng().sample_iter(&Alphanumeric).take(10).map(char::from).collect()
};
#[cfg(not(feature = "logging"))]
let conn_id = "".to_string();
(
format!("{}: ({} <- ({}-unk)):", conn_id, client_addr, proto),
format!("{}: ({} -> ({}-unk)):", conn_id, client_addr, proto),
conn_id,
)
} else {
("".to_string(), "".to_string(), "".to_string())
};
Context {
conn_id,
log_from,
log_to,
proto,
client_addr,
is_c2s: None,
to: None,
to_addr: None,
from: None,
}
}
fn re_calc(&mut self) {
// todo: make this good
self.log_from = format!(
"{}: ({} ({}) -> ({}-{}) -> {} ({})):",
self.conn_id,
self.client_addr,
if self.from.is_some() { self.from.as_ref().unwrap() } else { "unk" },
self.proto,
if self.is_c2s.is_some() { c2s(self.is_c2s.unwrap()) } else { "unk" },
if self.to_addr.is_some() { self.to_addr.as_ref().unwrap().to_string() } else { "unk".to_string() },
if self.to.is_some() { self.to.as_ref().unwrap() } else { "unk" },
);
self.log_to = self.log_from.replace(" -> ", " <- ");
}
pub fn log_from(&self) -> &str {
&self.log_from
}
pub fn log_to(&self) -> &str {
&self.log_to
}
pub fn client_addr(&self) -> &SocketAddr {
&self.client_addr
}
pub fn set_proto(&mut self, proto: &'static str) {
if log_enabled!(log::Level::Info) {
self.proto = proto;
self.to_addr = None;
self.re_calc();
}
}
pub fn set_c2s_stream_open(&mut self, is_c2s: bool, stream_open: &[u8]) {
if log_enabled!(log::Level::Info) {
self.is_c2s = Some(is_c2s);
self.from = stream_open
.extract_between(b" from='", b"'")
.or_else(|_| stream_open.extract_between(b" from=\"", b"\""))
.map(|b| to_str(b).to_string())
.ok();
self.to = stream_open
.extract_between(b" to='", b"'")
.or_else(|_| stream_open.extract_between(b" to=\"", b"\""))
.map(|b| to_str(b).to_string())
.ok();
self.re_calc();
info!("{} stream data set", &self.log_from());
}
}
pub fn set_to_addr(&mut self, to_addr: SocketAddr) {
if log_enabled!(log::Level::Info) {
self.to_addr = Some(to_addr);
self.re_calc();
}
}
}

12
src/in_out.rs

@ -1,14 +1,20 @@ @@ -1,14 +1,20 @@
// Box<dyn AsyncWrite + Unpin + Send>, Box<dyn AsyncRead + Unpin + Send>
#[cfg(feature = "websocket")]
use crate::{from_ws, to_ws_new, AsyncReadAndWrite};
use crate::{slicesubsequence::SliceSubsequence, trace, StanzaFilter, StanzaRead::*, StanzaReader, StanzaWrite::*};
use crate::websocket::{from_ws, to_ws_new, AsyncReadAndWrite};
use crate::{
common::IN_BUFFER_SIZE,
in_out::{StanzaRead::*, StanzaWrite::*},
slicesubsequence::SliceSubsequence,
stanzafilter::{StanzaFilter, StanzaReader},
};
use anyhow::{bail, Result};
#[cfg(feature = "websocket")]
use futures_util::{
stream::{SplitSink, SplitStream},
SinkExt, TryStreamExt,
};
use log::trace;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
#[cfg(feature = "websocket")]
use tokio_tungstenite::{tungstenite::Message::*, WebSocketStream};
@ -75,7 +81,7 @@ impl StanzaRead { @@ -75,7 +81,7 @@ impl StanzaRead {
#[inline(always)]
pub fn new<T: 'static + AsyncRead + Unpin + Send>(rd: T) -> Self {
// we naively read 1 byte at a time, which buffering significantly speeds up
AsyncRead(StanzaReader(Box::new(BufReader::with_capacity(crate::IN_BUFFER_SIZE, rd))))
AsyncRead(StanzaReader(Box::new(BufReader::with_capacity(IN_BUFFER_SIZE, rd))))
}
#[inline(always)]

209
src/lib.rs

@ -1,201 +1,28 @@ @@ -1,201 +1,28 @@
mod stanzafilter;
pub use stanzafilter::*;
mod slicesubsequence;
use slicesubsequence::*;
use anyhow::bail;
use log::info;
use std::net::SocketAddr;
pub use log::{debug, error, info, log_enabled, trace};
#[cfg(feature = "s2s-incoming")]
use rustls::{Certificate, ServerConnection};
pub fn to_str(buf: &[u8]) -> std::borrow::Cow<'_, str> {
String::from_utf8_lossy(buf)
}
pub fn c2s(is_c2s: bool) -> &'static str {
if is_c2s {
"c2s"
} else {
"s2s"
}
}
pub async fn first_bytes_match(stream: &tokio::net::TcpStream, p: &mut [u8], matcher: fn(&[u8]) -> bool) -> anyhow::Result<bool> {
// sooo... I don't think peek here can be used for > 1 byte without this timer craziness... can it?
let len = p.len();
// wait up to 10 seconds until len bytes have been read
use std::time::{Duration, Instant};
let duration = Duration::from_secs(10);
let now = Instant::now();
loop {
let n = stream.peek(p).await?;
if n == len {
break; // success
}
if n == 0 {
bail!("not enough bytes");
}
if Instant::now() - now > duration {
bail!("less than {} bytes in 10 seconds, closed connection?", len);
}
}
Ok(matcher(p))
}
#[derive(Clone)]
pub struct Context<'a> {
conn_id: String,
log_from: String,
log_to: String,
proto: &'a str,
is_c2s: Option<bool>,
to: Option<String>,
to_addr: Option<SocketAddr>,
from: Option<String>,
client_addr: SocketAddr,
}
impl<'a> Context<'a> {
pub fn new(proto: &'static str, client_addr: SocketAddr) -> Context {
let (log_to, log_from, conn_id) = if log_enabled!(log::Level::Info) {
#[cfg(feature = "logging")]
let conn_id = {
use rand::distributions::Alphanumeric;
use rand::{thread_rng, Rng};
thread_rng().sample_iter(&Alphanumeric).take(10).map(char::from).collect()
};
#[cfg(not(feature = "logging"))]
let conn_id = "".to_string();
(
format!("{}: ({} <- ({}-unk)):", conn_id, client_addr, proto),
format!("{}: ({} -> ({}-unk)):", conn_id, client_addr, proto),
conn_id,
)
} else {
("".to_string(), "".to_string(), "".to_string())
};
Context {
conn_id,
log_from,
log_to,
proto,
client_addr,
is_c2s: None,
to: None,
to_addr: None,
from: None,
}
}
fn re_calc(&mut self) {
// todo: make this good
self.log_from = format!(
"{}: ({} ({}) -> ({}-{}) -> {} ({})):",
self.conn_id,
self.client_addr,
if self.from.is_some() { self.from.as_ref().unwrap() } else { "unk" },
self.proto,
if self.is_c2s.is_some() { c2s(self.is_c2s.unwrap()) } else { "unk" },
if self.to_addr.is_some() { self.to_addr.as_ref().unwrap().to_string() } else { "unk".to_string() },
if self.to.is_some() { self.to.as_ref().unwrap() } else { "unk" },
);
self.log_to = self.log_from.replace(" -> ", " <- ");
}
pub fn log_from(&self) -> &str {
&self.log_from
}
pub fn log_to(&self) -> &str {
&self.log_to
}
pub fn client_addr(&self) -> &SocketAddr {
&self.client_addr
}
pub fn set_proto(&mut self, proto: &'static str) {
if log_enabled!(log::Level::Info) {
self.proto = proto;
self.to_addr = None;
self.re_calc();
}
}
pub fn set_c2s_stream_open(&mut self, is_c2s: bool, stream_open: &[u8]) {
if log_enabled!(log::Level::Info) {
self.is_c2s = Some(is_c2s);
self.from = stream_open
.extract_between(b" from='", b"'")
.or_else(|_| stream_open.extract_between(b" from=\"", b"\""))
.map(|b| to_str(b).to_string())
.ok();
self.to = stream_open
.extract_between(b" to='", b"'")
.or_else(|_| stream_open.extract_between(b" to=\"", b"\""))
.map(|b| to_str(b).to_string())
.ok();
self.re_calc();
info!("{} stream data set", &self.log_from());
}
}
pub mod common;
pub mod slicesubsequence;
pub mod stanzafilter;
pub fn set_to_addr(&mut self, to_addr: SocketAddr) {
if log_enabled!(log::Level::Info) {
self.to_addr = Some(to_addr);
self.re_calc();
}
}
}
#[cfg(feature = "quic")]
pub mod quic;
#[cfg(not(feature = "s2s-incoming"))]
pub type ServerCerts = ();
#[cfg(feature = "tls")]
pub mod tls;
#[cfg(feature = "s2s-incoming")]
#[derive(Clone)]
pub enum ServerCerts {
Tls(&'static ServerConnection),
#[cfg(feature = "quic")]
Quic(quinn::Connection),
}
#[cfg(feature = "outgoing")]
pub mod outgoing;
#[cfg(feature = "s2s-incoming")]
impl ServerCerts {
pub fn peer_certificates(&self) -> Option<Vec<Certificate>> {
match self {
ServerCerts::Tls(c) => c.peer_certificates().map(|c| c.to_vec()),
#[cfg(feature = "quic")]
ServerCerts::Quic(c) => c.peer_identity().and_then(|v| v.downcast::<Vec<Certificate>>().ok()).map(|v| v.to_vec()),
}
}
#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))]
pub mod srv;
pub fn sni(&self) -> Option<String> {
match self {
ServerCerts::Tls(c) => c.sni_hostname().map(|s| s.to_string()),
#[cfg(feature = "quic")]
ServerCerts::Quic(c) => c.handshake_data().and_then(|v| v.downcast::<quinn::crypto::rustls::HandshakeData>().ok()).and_then(|h| h.server_name),
}
}
#[cfg(feature = "websocket")]
pub mod websocket;
pub fn alpn(&self) -> Option<Vec<u8>> {
match self {
ServerCerts::Tls(c) => c.alpn_protocol().map(|s| s.to_vec()),
#[cfg(feature = "quic")]
ServerCerts::Quic(c) => c.handshake_data().and_then(|v| v.downcast::<quinn::crypto::rustls::HandshakeData>().ok()).and_then(|h| h.protocol),
}
}
#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))]
pub mod verify;
pub fn is_tls(&self) -> bool {
match self {
ServerCerts::Tls(_) => true,
#[cfg(feature = "quic")]
ServerCerts::Quic(_) => false,
}
}
}
mod context;
pub mod in_out;

499
src/main.rs

@ -1,112 +1,14 @@ @@ -1,112 +1,14 @@
#![deny(clippy::all)]
use std::ffi::OsString;
use std::fs::File;
use std::io;
use std::io::{BufReader, Read, Write};
use std::iter::Iterator;
use std::net::SocketAddr;
use std::path::Path;
use std::sync::{Arc, RwLock};
use anyhow::Result;
use die::{die, Die};
use log::{debug, error, info};
use serde_derive::Deserialize;
use tokio::io::{AsyncWriteExt, ReadHalf, WriteHalf};
use tokio::net::TcpListener;
use std::{ffi::OsString, fs::File, io::Read, iter::Iterator, net::SocketAddr, path::Path, sync::Arc};
use tokio::task::JoinHandle;
use xmpp_proxy::common::certs_key::CertsKey;
#[cfg(feature = "rustls")]
use rustls::{
sign::{CertifiedKey, RsaSigningKey, SigningKey},
Certificate, ClientConfig, PrivateKey, ServerConfig, SignatureScheme,
};
#[cfg(feature = "tokio-rustls")]
use tokio_rustls::{
webpki::{DnsNameRef, TlsServerTrustAnchors, TrustAnchor},
TlsConnector,
};
use anyhow::{anyhow, bail, Result};
mod slicesubsequence;
use slicesubsequence::*;
pub use xmpp_proxy::*;
#[cfg(feature = "quic")]
mod quic;
#[cfg(feature = "quic")]
use crate::quic::*;
#[cfg(feature = "tls")]
mod tls;
#[cfg(feature = "tls")]
use crate::tls::*;
#[cfg(feature = "outgoing")]
mod outgoing;
#[cfg(feature = "outgoing")]
use crate::outgoing::*;
#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))]
mod srv;
#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))]
use crate::srv::*;
#[cfg(feature = "websocket")]
mod websocket;
#[cfg(feature = "websocket")]
use crate::websocket::*;
#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))]
mod verify;
#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))]
use crate::verify::*;
mod in_out;
pub use crate::in_out::*;
const IN_BUFFER_SIZE: usize = 8192;
// todo: split these out to outgoing module
const ALPN_XMPP_CLIENT: &[u8] = b"xmpp-client";
const ALPN_XMPP_SERVER: &[u8] = b"xmpp-server";
#[cfg(all(feature = "webpki-roots", not(feature = "rustls-native-certs")))]
pub use webpki_roots::TLS_SERVER_ROOTS;
#[cfg(all(feature = "rustls-native-certs", not(feature = "webpki-roots")))]
lazy_static::lazy_static! {
static ref TLS_SERVER_ROOTS: TlsServerTrustAnchors<'static> = {
// we need these to stick around for 'static, this is only called once so no problem
let certs = Box::leak(Box::new(rustls_native_certs::load_native_certs().expect("could not load platform certs")));
let root_cert_store = Box::leak(Box::new(Vec::new()));
for cert in certs {
// some system CAs are invalid, ignore those
if let Ok(ta) = TrustAnchor::try_from_cert_der(&cert.0) {
root_cert_store.push(ta);
}
}
TlsServerTrustAnchors(root_cert_store)
};
}
#[cfg(any(feature = "rustls-native-certs", feature = "webpki-roots"))]
pub fn root_cert_store() -> rustls::RootCertStore {
use rustls::{OwnedTrustAnchor, RootCertStore};
let mut root_cert_store = RootCertStore::empty();
root_cert_store.add_server_trust_anchors(
TLS_SERVER_ROOTS
.0
.iter()
.map(|ta| OwnedTrustAnchor::from_subject_spki_name_constraints(ta.subject, ta.spki, ta.name_constraints)),
);
root_cert_store
}
use xmpp_proxy::{common::outgoing::OutgoingConfig, outgoing::spawn_outgoing_listener};
#[derive(Deserialize, Default)]
struct Config {
@ -123,87 +25,6 @@ struct Config { @@ -123,87 +25,6 @@ struct Config {
log_style: Option<String>,
}
#[derive(Clone)]
pub struct CloneableConfig {
max_stanza_size_bytes: usize,
#[cfg(feature = "s2s-incoming")]
s2s_target: Option<SocketAddr>,
#[cfg(feature = "c2s-incoming")]
c2s_target: Option<SocketAddr>,
proxy: bool,
}
struct CertsKey {
#[cfg(feature = "rustls-pemfile")]
inner: Result<RwLock<Arc<rustls::sign::CertifiedKey>>>,
}
impl CertsKey {
fn new(main_config: &Config) -> Self {
CertsKey {
#[cfg(feature = "rustls-pemfile")]
inner: main_config.certs_key().map(|c| RwLock::new(Arc::new(c))),
}
}
#[cfg(all(unix, any(feature = "incoming", feature = "s2s-outgoing")))]
fn spawn_refresh_task(&'static self, cfg_path: OsString) -> Option<JoinHandle<Result<()>>> {
if self.inner.is_err() {
None
} else {
Some(tokio::spawn(async move {
use tokio::signal::unix::{signal, SignalKind};
let mut stream = signal(SignalKind::hangup())?;
loop {
stream.recv().await;
info!("got SIGHUP");
match Config::parse(&cfg_path).and_then(|c| c.certs_key()) {
Ok(cert_key) => {
if let Ok(rwl) = self.inner.as_ref() {
let cert_key = Arc::new(cert_key);
let mut certs_key = rwl.write().expect("CertKey poisoned?");
*certs_key = cert_key;
drop(certs_key);
info!("reloaded cert/key successfully!");
}
}
Err(e) => error!("invalid config/cert/key on SIGHUP: {}", e),
};
}
}))
}
}
}
#[cfg(feature = "rustls-pemfile")]
impl rustls::server::ResolvesServerCert for CertsKey {
fn resolve(&self, _: rustls::server::ClientHello) -> Option<Arc<rustls::sign::CertifiedKey>> {
self.inner.as_ref().map(|rwl| rwl.read().expect("CertKey poisoned?").clone()).ok()
}
}
#[cfg(feature = "rustls-pemfile")]
impl rustls::client::ResolvesClientCert for CertsKey {
fn resolve(&self, _: &[&[u8]], _: &[SignatureScheme]) -> Option<Arc<CertifiedKey>> {
self.inner.as_ref().map(|rwl| rwl.read().expect("CertKey poisoned?").clone()).ok()
}
fn has_certs(&self) -> bool {
self.inner.is_ok()
}
}
#[cfg(not(feature = "rustls-pemfile"))]
impl rustls::client::ResolvesClientCert for CertsKey {
fn resolve(&self, _: &[&[u8]], _: &[SignatureScheme]) -> Option<Arc<CertifiedKey>> {
None
}
fn has_certs(&self) -> bool {
false
}
}
impl Config {
fn parse<P: AsRef<Path>>(path: P) -> Result<Config> {
let mut f = File::open(path)?;
@ -212,8 +33,9 @@ impl Config { @@ -212,8 +33,9 @@ impl Config {
Ok(toml::from_str(&input)?)
}
fn get_cloneable_cfg(&self) -> CloneableConfig {
CloneableConfig {
#[cfg(feature = "incoming")]
fn get_cloneable_cfg(&self) -> xmpp_proxy::common::incoming::CloneableConfig {
xmpp_proxy::common::incoming::CloneableConfig {
max_stanza_size_bytes: self.max_stanza_size_bytes,
#[cfg(feature = "s2s-incoming")]
s2s_target: self.s2s_target,
@ -238,268 +60,41 @@ impl Config { @@ -238,268 +60,41 @@ impl Config {
#[cfg(feature = "rustls-pemfile")]
fn certs_key(&self) -> Result<rustls::sign::CertifiedKey> {
use rustls_pemfile::{certs, read_all, Item};
let tls_key = read_all(&mut BufReader::new(File::open(&self.tls_key)?))
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))?
.into_iter()
.flat_map(|item| match item {
Item::RSAKey(der) => RsaSigningKey::new(&PrivateKey(der)).ok().map(Arc::new).map(|r| r as Arc<dyn SigningKey>),
Item::PKCS8Key(der) => rustls::sign::any_supported_type(&PrivateKey(der)).ok(),
Item::ECKey(der) => rustls::sign::any_supported_type(&PrivateKey(der)).ok(),
_ => None,
})
.next()
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))?;
let tls_certs = certs(&mut BufReader::new(File::open(&self.tls_cert)?))
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))
.map(|mut certs| certs.drain(..).map(Certificate).collect())?;
Ok(rustls::sign::CertifiedKey::new(tls_certs, tls_key))
}
#[cfg(feature = "incoming")]
fn server_config(&self, certs_key: Arc<CertsKey>) -> Result<ServerConfig> {
if let Err(e) = &certs_key.inner {
bail!("invalid cert/key: {}", e);
}
let config = ServerConfig::builder().with_safe_defaults();
#[cfg(feature = "s2s")]
let config = config.with_client_cert_verifier(Arc::new(AllowAnonymousOrAnyCert));
#[cfg(not(feature = "s2s"))]
let config = config.with_no_client_auth();
let mut config = config.with_cert_resolver(certs_key);
// todo: will connecting without alpn work then?
config.alpn_protocols.push(ALPN_XMPP_CLIENT.to_vec());
config.alpn_protocols.push(ALPN_XMPP_SERVER.to_vec());
Ok(config)
xmpp_proxy::common::read_certified_key(&self.tls_key, &self.tls_cert)
}
}
#[derive(Clone)]
#[cfg(feature = "outgoing")]
pub struct OutgoingConfig {
max_stanza_size_bytes: usize,
certs_key: Arc<CertsKey>,
}
#[cfg(feature = "outgoing")]
impl OutgoingConfig {
pub fn with_custom_certificate_verifier(&self, is_c2s: bool, cert_verifier: XmppServerCertVerifier) -> OutgoingVerifierConfig {
let config = match is_c2s {
false => ClientConfig::builder()
.with_safe_defaults()
.with_custom_certificate_verifier(Arc::new(cert_verifier))
.with_client_cert_resolver(self.certs_key.clone()),
_ => ClientConfig::builder()
.with_safe_defaults()
.with_custom_certificate_verifier(Arc::new(cert_verifier))
.with_no_client_auth(),
};
let mut config_alpn = config.clone();
config_alpn.alpn_protocols.push(if is_c2s { ALPN_XMPP_CLIENT } else { ALPN_XMPP_SERVER }.to_vec());
let config_alpn = Arc::new(config_alpn);
let connector_alpn: TlsConnector = config_alpn.clone().into();
let connector: TlsConnector = Arc::new(config).into();
OutgoingVerifierConfig {
max_stanza_size_bytes: self.max_stanza_size_bytes,
config_alpn,
connector_alpn,
connector,
}
}
}
#[derive(Clone)]
#[cfg(feature = "outgoing")]
pub struct OutgoingVerifierConfig {
pub max_stanza_size_bytes: usize,
pub config_alpn: Arc<ClientConfig>,
pub connector_alpn: TlsConnector,
pub connector: TlsConnector,
}
#[cfg(feature = "incoming")]
async fn shuffle_rd_wr(in_rd: StanzaRead, in_wr: StanzaWrite, config: CloneableConfig, server_certs: ServerCerts, local_addr: SocketAddr, client_addr: &mut Context<'_>) -> Result<()> {
let filter = StanzaFilter::new(config.max_stanza_size_bytes);
shuffle_rd_wr_filter(in_rd, in_wr, config, server_certs, local_addr, client_addr, filter).await
}
#[cfg(feature = "incoming")]
async fn shuffle_rd_wr_filter(
mut in_rd: StanzaRead,
mut in_wr: StanzaWrite,
config: CloneableConfig,
server_certs: ServerCerts,
local_addr: SocketAddr,
client_addr: &mut Context<'_>,
mut in_filter: StanzaFilter,
) -> Result<()> {
// now read to figure out client vs server
let (stream_open, is_c2s) = stream_preamble(&mut in_rd, &mut in_wr, client_addr.log_from(), &mut in_filter).await?;
client_addr.set_c2s_stream_open(is_c2s, &stream_open);
#[cfg(feature = "s2s-incoming")]
{
trace!(
"{} connected: sni: {:?}, alpn: {:?}, tls-not-quic: {}",
client_addr.log_from(),
server_certs.sni(),
server_certs.alpn().map(|a| String::from_utf8_lossy(&a).to_string()),
server_certs.is_tls(),
);
if !is_c2s {
// for s2s we need this
use std::time::SystemTime;
let domain = stream_open
.extract_between(b" from='", b"'")
.or_else(|_| stream_open.extract_between(b" from=\"", b"\""))
.and_then(|b| Ok(std::str::from_utf8(b)?))?;
let (_, cert_verifier) = get_xmpp_connections(domain, is_c2s).await?;
let certs = server_certs.peer_certificates().ok_or_else(|| anyhow!("no client cert auth for s2s incoming from {}", domain))?;
// todo: send stream error saying cert is invalid
cert_verifier.verify_cert(&certs[0], &certs[1..], SystemTime::now())?;
}
drop(server_certs);
}
let (out_rd, out_wr) = open_incoming(&config, local_addr, client_addr, &stream_open, is_c2s, &mut in_filter).await?;
drop(stream_open);
shuffle_rd_wr_filter_only(
in_rd,
in_wr,
StanzaRead::new(out_rd),
StanzaWrite::new(out_wr),
is_c2s,
config.max_stanza_size_bytes,
client_addr,
in_filter,
)
.await
}
#[allow(clippy::too_many_arguments)]
async fn shuffle_rd_wr_filter_only(
mut in_rd: StanzaRead,
mut in_wr: StanzaWrite,
mut out_rd: StanzaRead,
mut out_wr: StanzaWrite,
is_c2s: bool,
max_stanza_size_bytes: usize,
client_addr: &mut Context<'_>,
mut in_filter: StanzaFilter,
) -> Result<()> {
let mut out_filter = StanzaFilter::new(max_stanza_size_bytes);
loop {
tokio::select! {
Ok(ret) = in_rd.next(&mut in_filter, client_addr.log_to(), &mut in_wr) => {
match ret {
None => break,
Some