Massive refactoring
This commit is contained in:
parent
4498559c08
commit
e553b4da14
2
build.rs
2
build.rs
@ -39,7 +39,7 @@ fn main() {
|
||||
for (mut key, value) in env::vars() {
|
||||
//writeln!(&mut w, "{key}: {value}", ).unwrap();
|
||||
if value == "1" && key.starts_with("CARGO_FEATURE_") {
|
||||
let mut key = key.split_off(14).replace("_", "-");
|
||||
let mut key = key.split_off(14).replace('_', "-");
|
||||
key.make_ascii_lowercase();
|
||||
if allowed_features.contains(&key.as_str()) {
|
||||
features.push(key);
|
||||
|
33
src/common/ca_roots.rs
Normal file
33
src/common/ca_roots.rs
Normal file
@ -0,0 +1,33 @@
|
||||
#[cfg(feature = "tokio-rustls")]
|
||||
use tokio_rustls::webpki::{TlsServerTrustAnchors, TrustAnchor};
|
||||
|
||||
#[cfg(all(feature = "webpki-roots", not(feature = "rustls-native-certs")))]
|
||||
pub use webpki_roots::TLS_SERVER_ROOTS;
|
||||
|
||||
#[cfg(all(feature = "rustls-native-certs", not(feature = "webpki-roots")))]
|
||||
lazy_static::lazy_static! {
|
||||
pub static ref TLS_SERVER_ROOTS: TlsServerTrustAnchors<'static> = {
|
||||
// we need these to stick around for 'static, this is only called once so no problem
|
||||
let certs = Box::leak(Box::new(rustls_native_certs::load_native_certs().expect("could not load platform certs")));
|
||||
let root_cert_store = Box::leak(Box::new(Vec::new()));
|
||||
for cert in certs {
|
||||
// some system CAs are invalid, ignore those
|
||||
if let Ok(ta) = TrustAnchor::try_from_cert_der(&cert.0) {
|
||||
root_cert_store.push(ta);
|
||||
}
|
||||
}
|
||||
TlsServerTrustAnchors(root_cert_store)
|
||||
};
|
||||
}
|
||||
|
||||
pub fn root_cert_store() -> rustls::RootCertStore {
|
||||
use rustls::{OwnedTrustAnchor, RootCertStore};
|
||||
let mut root_cert_store = RootCertStore::empty();
|
||||
root_cert_store.add_server_trust_anchors(
|
||||
TLS_SERVER_ROOTS
|
||||
.0
|
||||
.iter()
|
||||
.map(|ta| OwnedTrustAnchor::from_subject_spki_name_constraints(ta.subject, ta.spki, ta.name_constraints)),
|
||||
);
|
||||
root_cert_store
|
||||
}
|
47
src/common/certs_key.rs
Normal file
47
src/common/certs_key.rs
Normal file
@ -0,0 +1,47 @@
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
use anyhow::Result;
|
||||
use rustls::{sign::CertifiedKey, SignatureScheme};
|
||||
|
||||
pub struct CertsKey {
|
||||
#[cfg(feature = "rustls-pemfile")]
|
||||
pub inner: Result<RwLock<Arc<CertifiedKey>>>,
|
||||
}
|
||||
|
||||
impl CertsKey {
|
||||
pub fn new(certified_key: Result<CertifiedKey>) -> Self {
|
||||
CertsKey {
|
||||
#[cfg(feature = "rustls-pemfile")]
|
||||
inner: certified_key.map(|c| RwLock::new(Arc::new(c))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "rustls-pemfile")]
|
||||
impl rustls::server::ResolvesServerCert for CertsKey {
|
||||
fn resolve(&self, _: rustls::server::ClientHello) -> Option<Arc<rustls::sign::CertifiedKey>> {
|
||||
self.inner.as_ref().map(|rwl| rwl.read().expect("CertKey poisoned?").clone()).ok()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "rustls-pemfile")]
|
||||
impl rustls::client::ResolvesClientCert for CertsKey {
|
||||
fn resolve(&self, _: &[&[u8]], _: &[SignatureScheme]) -> Option<Arc<CertifiedKey>> {
|
||||
self.inner.as_ref().map(|rwl| rwl.read().expect("CertKey poisoned?").clone()).ok()
|
||||
}
|
||||
|
||||
fn has_certs(&self) -> bool {
|
||||
self.inner.is_ok()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "rustls-pemfile"))]
|
||||
impl rustls::client::ResolvesClientCert for CertsKey {
|
||||
fn resolve(&self, _: &[&[u8]], _: &[SignatureScheme]) -> Option<Arc<CertifiedKey>> {
|
||||
None
|
||||
}
|
||||
|
||||
fn has_certs(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
198
src/common/incoming.rs
Normal file
198
src/common/incoming.rs
Normal file
@ -0,0 +1,198 @@
|
||||
use crate::{
|
||||
common::{c2s, certs_key::CertsKey, shuffle_rd_wr_filter_only, stream_preamble, to_str, ALPN_XMPP_CLIENT, ALPN_XMPP_SERVER},
|
||||
context::Context,
|
||||
in_out::{StanzaRead, StanzaWrite},
|
||||
slicesubsequence::SliceSubsequence,
|
||||
stanzafilter::StanzaFilter,
|
||||
};
|
||||
use anyhow::{anyhow, bail, Result};
|
||||
use log::trace;
|
||||
use rustls::{Certificate, ServerConfig, ServerConnection};
|
||||
use std::{io::Write, net::SocketAddr, sync::Arc};
|
||||
use tokio::io::{AsyncWriteExt, ReadHalf, WriteHalf};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CloneableConfig {
|
||||
pub max_stanza_size_bytes: usize,
|
||||
#[cfg(feature = "s2s-incoming")]
|
||||
pub s2s_target: Option<SocketAddr>,
|
||||
#[cfg(feature = "c2s-incoming")]
|
||||
pub c2s_target: Option<SocketAddr>,
|
||||
pub proxy: bool,
|
||||
}
|
||||
|
||||
pub fn server_config(certs_key: Arc<CertsKey>) -> Result<ServerConfig> {
|
||||
if let Err(e) = &certs_key.inner {
|
||||
bail!("invalid cert/key: {}", e);
|
||||
}
|
||||
|
||||
let config = ServerConfig::builder().with_safe_defaults();
|
||||
#[cfg(feature = "s2s")]
|
||||
let config = config.with_client_cert_verifier(Arc::new(crate::verify::AllowAnonymousOrAnyCert));
|
||||
#[cfg(not(feature = "s2s"))]
|
||||
let config = config.with_no_client_auth();
|
||||
let mut config = config.with_cert_resolver(certs_key);
|
||||
// todo: will connecting without alpn work then?
|
||||
config.alpn_protocols.push(ALPN_XMPP_CLIENT.to_vec());
|
||||
config.alpn_protocols.push(ALPN_XMPP_SERVER.to_vec());
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "s2s-incoming"))]
|
||||
pub type ServerCerts = ();
|
||||
|
||||
#[cfg(feature = "s2s-incoming")]
|
||||
#[derive(Clone)]
|
||||
pub enum ServerCerts {
|
||||
Tls(&'static ServerConnection),
|
||||
#[cfg(feature = "quic")]
|
||||
Quic(quinn::Connection),
|
||||
}
|
||||
|
||||
#[cfg(feature = "s2s-incoming")]
|
||||
impl ServerCerts {
|
||||
pub fn peer_certificates(&self) -> Option<Vec<Certificate>> {
|
||||
match self {
|
||||
ServerCerts::Tls(c) => c.peer_certificates().map(|c| c.to_vec()),
|
||||
#[cfg(feature = "quic")]
|
||||
ServerCerts::Quic(c) => c.peer_identity().and_then(|v| v.downcast::<Vec<Certificate>>().ok()).map(|v| v.to_vec()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sni(&self) -> Option<String> {
|
||||
match self {
|
||||
ServerCerts::Tls(c) => c.sni_hostname().map(|s| s.to_string()),
|
||||
#[cfg(feature = "quic")]
|
||||
ServerCerts::Quic(c) => c.handshake_data().and_then(|v| v.downcast::<quinn::crypto::rustls::HandshakeData>().ok()).and_then(|h| h.server_name),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn alpn(&self) -> Option<Vec<u8>> {
|
||||
match self {
|
||||
ServerCerts::Tls(c) => c.alpn_protocol().map(|s| s.to_vec()),
|
||||
#[cfg(feature = "quic")]
|
||||
ServerCerts::Quic(c) => c.handshake_data().and_then(|v| v.downcast::<quinn::crypto::rustls::HandshakeData>().ok()).and_then(|h| h.protocol),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_tls(&self) -> bool {
|
||||
match self {
|
||||
ServerCerts::Tls(_) => true,
|
||||
#[cfg(feature = "quic")]
|
||||
ServerCerts::Quic(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn shuffle_rd_wr(in_rd: StanzaRead, in_wr: StanzaWrite, config: CloneableConfig, server_certs: ServerCerts, local_addr: SocketAddr, client_addr: &mut Context<'_>) -> Result<()> {
|
||||
let filter = StanzaFilter::new(config.max_stanza_size_bytes);
|
||||
shuffle_rd_wr_filter(in_rd, in_wr, config, server_certs, local_addr, client_addr, filter).await
|
||||
}
|
||||
|
||||
pub async fn shuffle_rd_wr_filter(
|
||||
mut in_rd: StanzaRead,
|
||||
mut in_wr: StanzaWrite,
|
||||
config: CloneableConfig,
|
||||
server_certs: ServerCerts,
|
||||
local_addr: SocketAddr,
|
||||
client_addr: &mut Context<'_>,
|
||||
mut in_filter: StanzaFilter,
|
||||
) -> Result<()> {
|
||||
// now read to figure out client vs server
|
||||
let (stream_open, is_c2s) = stream_preamble(&mut in_rd, &mut in_wr, client_addr.log_from(), &mut in_filter).await?;
|
||||
client_addr.set_c2s_stream_open(is_c2s, &stream_open);
|
||||
|
||||
#[cfg(feature = "s2s-incoming")]
|
||||
{
|
||||
trace!(
|
||||
"{} connected: sni: {:?}, alpn: {:?}, tls-not-quic: {}",
|
||||
client_addr.log_from(),
|
||||
server_certs.sni(),
|
||||
server_certs.alpn().map(|a| String::from_utf8_lossy(&a).to_string()),
|
||||
server_certs.is_tls(),
|
||||
);
|
||||
|
||||
if !is_c2s {
|
||||
// for s2s we need this
|
||||
use std::time::SystemTime;
|
||||
let domain = stream_open
|
||||
.extract_between(b" from='", b"'")
|
||||
.or_else(|_| stream_open.extract_between(b" from=\"", b"\""))
|
||||
.and_then(|b| Ok(std::str::from_utf8(b)?))?;
|
||||
let (_, cert_verifier) = crate::srv::get_xmpp_connections(domain, is_c2s).await?;
|
||||
let certs = server_certs.peer_certificates().ok_or_else(|| anyhow!("no client cert auth for s2s incoming from {}", domain))?;
|
||||
// todo: send stream error saying cert is invalid
|
||||
cert_verifier.verify_cert(&certs[0], &certs[1..], SystemTime::now())?;
|
||||
}
|
||||
drop(server_certs);
|
||||
}
|
||||
|
||||
let (out_rd, out_wr) = open_incoming(&config, local_addr, client_addr, &stream_open, is_c2s, &mut in_filter).await?;
|
||||
drop(stream_open);
|
||||
|
||||
shuffle_rd_wr_filter_only(
|
||||
in_rd,
|
||||
in_wr,
|
||||
StanzaRead::new(out_rd),
|
||||
StanzaWrite::new(out_wr),
|
||||
is_c2s,
|
||||
config.max_stanza_size_bytes,
|
||||
client_addr,
|
||||
in_filter,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn open_incoming(
|
||||
config: &CloneableConfig,
|
||||
local_addr: SocketAddr,
|
||||
client_addr: &mut Context<'_>,
|
||||
stream_open: &[u8],
|
||||
is_c2s: bool,
|
||||
in_filter: &mut StanzaFilter,
|
||||
) -> Result<(ReadHalf<tokio::net::TcpStream>, WriteHalf<tokio::net::TcpStream>)> {
|
||||
let target = if is_c2s {
|
||||
#[cfg(not(feature = "c2s-incoming"))]
|
||||
bail!("incoming c2s connection but lacking compile-time support");
|
||||
#[cfg(feature = "c2s-incoming")]
|
||||
config.c2s_target
|
||||
} else {
|
||||
#[cfg(not(feature = "s2s-incoming"))]
|
||||
bail!("incoming s2s connection but lacking compile-time support");
|
||||
#[cfg(feature = "s2s-incoming")]
|
||||
config.s2s_target
|
||||
}
|
||||
.ok_or_else(|| anyhow!("incoming connection but `{}_target` not defined", c2s(is_c2s)))?;
|
||||
client_addr.set_to_addr(target);
|
||||
|
||||
let out_stream = tokio::net::TcpStream::connect(target).await?;
|
||||
let (out_rd, mut out_wr) = tokio::io::split(out_stream);
|
||||
|
||||
if config.proxy {
|
||||
/*
|
||||
https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt
|
||||
PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\n
|
||||
PROXY TCP6 ffff:f...f:ffff ffff:f...f:ffff 65535 65535\r\n
|
||||
PROXY TCP6 SOURCE_IP DEST_IP SOURCE_PORT DEST_PORT\r\n
|
||||
*/
|
||||
// tokio AsyncWrite doesn't have write_fmt so have to go through this buffer for some crazy reason
|
||||
//write!(out_wr, "PROXY TCP{} {} {} {} {}\r\n", if client_addr.is_ipv4() { '4' } else {'6' }, client_addr.ip(), local_addr.ip(), client_addr.port(), local_addr.port())?;
|
||||
write!(
|
||||
&mut in_filter.buf[0..],
|
||||
"PROXY TCP{} {} {} {} {}\r\n",
|
||||
if client_addr.client_addr().is_ipv4() { '4' } else { '6' },
|
||||
client_addr.client_addr().ip(),
|
||||
local_addr.ip(),
|
||||
client_addr.client_addr().port(),
|
||||
local_addr.port()
|
||||
)?;
|
||||
let end_idx = &(&in_filter.buf[0..]).first_index_of(b"\n")? + 1;
|
||||
trace!("{} '{}'", client_addr.log_from(), to_str(&in_filter.buf[0..end_idx]));
|
||||
out_wr.write_all(&in_filter.buf[0..end_idx]).await?;
|
||||
}
|
||||
trace!("{} '{}'", client_addr.log_from(), to_str(stream_open));
|
||||
out_wr.write_all(stream_open).await?;
|
||||
out_wr.flush().await?;
|
||||
Ok((out_rd, out_wr))
|
||||
}
|
144
src/common/mod.rs
Normal file
144
src/common/mod.rs
Normal file
@ -0,0 +1,144 @@
|
||||
use crate::{
|
||||
context::Context,
|
||||
in_out::{StanzaRead, StanzaWrite},
|
||||
slicesubsequence::SliceSubsequence,
|
||||
stanzafilter::StanzaFilter,
|
||||
};
|
||||
use anyhow::{bail, Result};
|
||||
use log::{info, trace};
|
||||
use rustls::{
|
||||
sign::{RsaSigningKey, SigningKey},
|
||||
Certificate, PrivateKey,
|
||||
};
|
||||
use std::{fs::File, io, io::BufReader, sync::Arc};
|
||||
|
||||
#[cfg(feature = "incoming")]
|
||||
pub mod incoming;
|
||||
|
||||
#[cfg(feature = "outgoing")]
|
||||
pub mod outgoing;
|
||||
|
||||
#[cfg(any(feature = "rustls-native-certs", feature = "webpki-roots"))]
|
||||
pub mod ca_roots;
|
||||
|
||||
pub mod certs_key;
|
||||
|
||||
pub const IN_BUFFER_SIZE: usize = 8192;
|
||||
pub const ALPN_XMPP_CLIENT: &[u8] = b"xmpp-client";
|
||||
pub const ALPN_XMPP_SERVER: &[u8] = b"xmpp-server";
|
||||
|
||||
pub fn to_str(buf: &[u8]) -> std::borrow::Cow<'_, str> {
|
||||
String::from_utf8_lossy(buf)
|
||||
}
|
||||
|
||||
pub fn c2s(is_c2s: bool) -> &'static str {
|
||||
if is_c2s {
|
||||
"c2s"
|
||||
} else {
|
||||
"s2s"
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn first_bytes_match(stream: &tokio::net::TcpStream, p: &mut [u8], matcher: fn(&[u8]) -> bool) -> anyhow::Result<bool> {
|
||||
// sooo... I don't think peek here can be used for > 1 byte without this timer craziness... can it?
|
||||
let len = p.len();
|
||||
// wait up to 10 seconds until len bytes have been read
|
||||
use std::time::{Duration, Instant};
|
||||
let duration = Duration::from_secs(10);
|
||||
let now = Instant::now();
|
||||
loop {
|
||||
let n = stream.peek(p).await?;
|
||||
if n == len {
|
||||
break; // success
|
||||
}
|
||||
if n == 0 {
|
||||
bail!("not enough bytes");
|
||||
}
|
||||
if Instant::now() - now > duration {
|
||||
bail!("less than {} bytes in 10 seconds, closed connection?", len);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(matcher(p))
|
||||
}
|
||||
|
||||
pub async fn stream_preamble(in_rd: &mut StanzaRead, in_wr: &mut StanzaWrite, client_addr: &'_ str, in_filter: &mut StanzaFilter) -> Result<(Vec<u8>, bool)> {
|
||||
let mut stream_open = Vec::new();
|
||||
while let Ok(Some((buf, _))) = in_rd.next(in_filter, client_addr, in_wr).await {
|
||||
trace!("{} received pre-<stream:stream> stanza: '{}'", client_addr, to_str(buf));
|
||||
if buf.starts_with(b"<?xml ") {
|
||||
stream_open.extend_from_slice(buf);
|
||||
} else if buf.starts_with(b"<stream:stream ") {
|
||||
stream_open.extend_from_slice(buf);
|
||||
return Ok((stream_open, buf.contains_seq(br#" xmlns="jabber:client""#) || buf.contains_seq(br#" xmlns='jabber:client'"#)));
|
||||
} else {
|
||||
bail!("bad pre-<stream:stream> stanza: {}", to_str(buf));
|
||||
}
|
||||
}
|
||||
bail!("stream ended before open")
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn shuffle_rd_wr_filter_only(
|
||||
mut in_rd: StanzaRead,
|
||||
mut in_wr: StanzaWrite,
|
||||
mut out_rd: StanzaRead,
|
||||
mut out_wr: StanzaWrite,
|
||||
is_c2s: bool,
|
||||
max_stanza_size_bytes: usize,
|
||||
client_addr: &mut Context<'_>,
|
||||
mut in_filter: StanzaFilter,
|
||||
) -> Result<()> {
|
||||
let mut out_filter = StanzaFilter::new(max_stanza_size_bytes);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
Ok(ret) = in_rd.next(&mut in_filter, client_addr.log_to(), &mut in_wr) => {
|
||||
match ret {
|
||||
None => break,
|
||||
Some((buf, eoft)) => {
|
||||
trace!("{} '{}'", client_addr.log_from(), to_str(buf));
|
||||
out_wr.write_all(is_c2s, buf, eoft, client_addr.log_from()).await?;
|
||||
out_wr.flush().await?;
|
||||
}
|
||||
}
|
||||
},
|
||||
Ok(ret) = out_rd.next(&mut out_filter, client_addr.log_from(), &mut out_wr) => {
|
||||
match ret {
|
||||
None => break,
|
||||
Some((buf, eoft)) => {
|
||||
trace!("{} '{}'", client_addr.log_to(), to_str(buf));
|
||||
in_wr.write_all(is_c2s, buf, eoft, client_addr.log_to()).await?;
|
||||
in_wr.flush().await?;
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
info!("{} disconnected", client_addr.log_from());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "rustls-pemfile")]
|
||||
pub fn read_certified_key(tls_key: &str, tls_cert: &str) -> Result<rustls::sign::CertifiedKey> {
|
||||
use rustls_pemfile::{certs, read_all, Item};
|
||||
|
||||
let tls_key = read_all(&mut BufReader::new(File::open(tls_key)?))
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))?
|
||||
.into_iter()
|
||||
.flat_map(|item| match item {
|
||||
Item::RSAKey(der) => RsaSigningKey::new(&PrivateKey(der)).ok().map(Arc::new).map(|r| r as Arc<dyn SigningKey>),
|
||||
Item::PKCS8Key(der) => rustls::sign::any_supported_type(&PrivateKey(der)).ok(),
|
||||
Item::ECKey(der) => rustls::sign::any_supported_type(&PrivateKey(der)).ok(),
|
||||
_ => None,
|
||||
})
|
||||
.next()
|
||||
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))?;
|
||||
|
||||
let tls_certs = certs(&mut BufReader::new(File::open(tls_cert)?))
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))
|
||||
.map(|mut certs| certs.drain(..).map(Certificate).collect())?;
|
||||
|
||||
Ok(rustls::sign::CertifiedKey::new(tls_certs, tls_key))
|
||||
}
|
54
src/common/outgoing.rs
Normal file
54
src/common/outgoing.rs
Normal file
@ -0,0 +1,54 @@
|
||||
use crate::{
|
||||
common::{certs_key::CertsKey, ALPN_XMPP_CLIENT, ALPN_XMPP_SERVER},
|
||||
verify::XmppServerCertVerifier,
|
||||
};
|
||||
use rustls::ClientConfig;
|
||||
use std::sync::Arc;
|
||||
use tokio_rustls::TlsConnector;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OutgoingConfig {
|
||||
pub max_stanza_size_bytes: usize,
|
||||
pub certs_key: Arc<CertsKey>,
|
||||
}
|
||||
|
||||
impl OutgoingConfig {
|
||||
pub fn with_custom_certificate_verifier(&self, is_c2s: bool, cert_verifier: XmppServerCertVerifier) -> OutgoingVerifierConfig {
|
||||
let config = match is_c2s {
|
||||
false => ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_custom_certificate_verifier(Arc::new(cert_verifier))
|
||||
.with_client_cert_resolver(self.certs_key.clone()),
|
||||
_ => ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_custom_certificate_verifier(Arc::new(cert_verifier))
|
||||
.with_no_client_auth(),
|
||||
};
|
||||
|
||||
let mut config_alpn = config.clone();
|
||||
config_alpn.alpn_protocols.push(if is_c2s { ALPN_XMPP_CLIENT } else { ALPN_XMPP_SERVER }.to_vec());
|
||||
|
||||
let config_alpn = Arc::new(config_alpn);
|
||||
|
||||
let connector_alpn: TlsConnector = config_alpn.clone().into();
|
||||
|
||||
let connector: TlsConnector = Arc::new(config).into();
|
||||
|
||||
OutgoingVerifierConfig {
|
||||
max_stanza_size_bytes: self.max_stanza_size_bytes,
|
||||
config_alpn,
|
||||
connector_alpn,
|
||||
connector,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct OutgoingVerifierConfig {
|
||||
pub max_stanza_size_bytes: usize,
|
||||
|
||||
pub config_alpn: Arc<ClientConfig>,
|
||||
pub connector_alpn: TlsConnector,
|
||||
|
||||
pub connector: TlsConnector,
|
||||
}
|
112
src/context.rs
Normal file
112
src/context.rs
Normal file
@ -0,0 +1,112 @@
|
||||
use crate::{
|
||||
common::{c2s, to_str},
|
||||
slicesubsequence::SliceSubsequence,
|
||||
};
|
||||
use log::{info, log_enabled};
|
||||
use std::net::SocketAddr;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Context<'a> {
|
||||
conn_id: String,
|
||||
log_from: String,
|
||||
log_to: String,
|
||||
proto: &'a str,
|
||||
is_c2s: Option<bool>,
|
||||
to: Option<String>,
|
||||
to_addr: Option<SocketAddr>,
|
||||
from: Option<String>,
|
||||
client_addr: SocketAddr,
|
||||
}
|
||||
|
||||
impl<'a> Context<'a> {
|
||||
pub fn new(proto: &'static str, client_addr: SocketAddr) -> Context {
|
||||
let (log_to, log_from, conn_id) = if log_enabled!(log::Level::Info) {
|
||||
#[cfg(feature = "logging")]
|
||||
let conn_id = {
|
||||
use rand::{distributions::Alphanumeric, thread_rng, Rng};
|
||||
thread_rng().sample_iter(&Alphanumeric).take(10).map(char::from).collect()
|
||||
};
|
||||
#[cfg(not(feature = "logging"))]
|
||||
let conn_id = "".to_string();
|
||||
(
|
||||
format!("{}: ({} <- ({}-unk)):", conn_id, client_addr, proto),
|
||||
format!("{}: ({} -> ({}-unk)):", conn_id, client_addr, proto),
|
||||
conn_id,
|
||||
)
|
||||
} else {
|
||||
("".to_string(), "".to_string(), "".to_string())
|
||||
};
|
||||
|
||||
Context {
|
||||
conn_id,
|
||||
log_from,
|
||||
log_to,
|
||||
proto,
|
||||
client_addr,
|
||||
is_c2s: None,
|
||||
to: None,
|
||||
to_addr: None,
|
||||
from: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn re_calc(&mut self) {
|
||||
// todo: make this good
|
||||
self.log_from = format!(
|
||||
"{}: ({} ({}) -> ({}-{}) -> {} ({})):",
|
||||
self.conn_id,
|
||||
self.client_addr,
|
||||
if self.from.is_some() { self.from.as_ref().unwrap() } else { "unk" },
|
||||
self.proto,
|
||||
if self.is_c2s.is_some() { c2s(self.is_c2s.unwrap()) } else { "unk" },
|
||||
if self.to_addr.is_some() { self.to_addr.as_ref().unwrap().to_string() } else { "unk".to_string() },
|
||||
if self.to.is_some() { self.to.as_ref().unwrap() } else { "unk" },
|
||||
);
|
||||
self.log_to = self.log_from.replace(" -> ", " <- ");
|
||||
}
|
||||
|
||||
pub fn log_from(&self) -> &str {
|
||||
&self.log_from
|
||||
}
|
||||
|
||||
pub fn log_to(&self) -> &str {
|
||||
&self.log_to
|
||||
}
|
||||
|
||||
pub fn client_addr(&self) -> &SocketAddr {
|
||||
&self.client_addr
|
||||
}
|
||||
|
||||
pub fn set_proto(&mut self, proto: &'static str) {
|
||||
if log_enabled!(log::Level::Info) {
|
||||
self.proto = proto;
|
||||
self.to_addr = None;
|
||||
self.re_calc();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_c2s_stream_open(&mut self, is_c2s: bool, stream_open: &[u8]) {
|
||||
if log_enabled!(log::Level::Info) {
|
||||
self.is_c2s = Some(is_c2s);
|
||||
self.from = stream_open
|
||||
.extract_between(b" from='", b"'")
|
||||
.or_else(|_| stream_open.extract_between(b" from=\"", b"\""))
|
||||
.map(|b| to_str(b).to_string())
|
||||
.ok();
|
||||
self.to = stream_open
|
||||
.extract_between(b" to='", b"'")
|
||||
.or_else(|_| stream_open.extract_between(b" to=\"", b"\""))
|
||||
.map(|b| to_str(b).to_string())
|
||||
.ok();
|
||||
self.re_calc();
|
||||
info!("{} stream data set", &self.log_from());
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_to_addr(&mut self, to_addr: SocketAddr) {
|
||||
if log_enabled!(log::Level::Info) {
|
||||
self.to_addr = Some(to_addr);
|
||||
self.re_calc();
|
||||
}
|
||||
}
|
||||
}
|
@ -1,14 +1,20 @@
|
||||
// Box<dyn AsyncWrite + Unpin + Send>, Box<dyn AsyncRead + Unpin + Send>
|
||||
|
||||
#[cfg(feature = "websocket")]
|
||||
use crate::{from_ws, to_ws_new, AsyncReadAndWrite};
|
||||
use crate::{slicesubsequence::SliceSubsequence, trace, StanzaFilter, StanzaRead::*, StanzaReader, StanzaWrite::*};
|
||||
use crate::websocket::{from_ws, to_ws_new, AsyncReadAndWrite};
|
||||
use crate::{
|
||||
common::IN_BUFFER_SIZE,
|
||||
in_out::{StanzaRead::*, StanzaWrite::*},
|
||||
slicesubsequence::SliceSubsequence,
|
||||
stanzafilter::{StanzaFilter, StanzaReader},
|
||||
};
|
||||
use anyhow::{bail, Result};
|
||||
#[cfg(feature = "websocket")]
|
||||
use futures_util::{
|
||||
stream::{SplitSink, SplitStream},
|
||||
SinkExt, TryStreamExt,
|
||||
};
|
||||
use log::trace;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader};
|
||||
#[cfg(feature = "websocket")]
|
||||
use tokio_tungstenite::{tungstenite::Message::*, WebSocketStream};
|
||||
@ -75,7 +81,7 @@ impl StanzaRead {
|
||||
#[inline(always)]
|
||||
pub fn new<T: 'static + AsyncRead + Unpin + Send>(rd: T) -> Self {
|
||||
// we naively read 1 byte at a time, which buffering significantly speeds up
|
||||
AsyncRead(StanzaReader(Box::new(BufReader::with_capacity(crate::IN_BUFFER_SIZE, rd))))
|
||||
AsyncRead(StanzaReader(Box::new(BufReader::with_capacity(IN_BUFFER_SIZE, rd))))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
|
209
src/lib.rs
209
src/lib.rs
@ -1,201 +1,28 @@
|
||||
mod stanzafilter;
|
||||
pub use stanzafilter::*;
|
||||
|
||||
mod slicesubsequence;
|
||||
use slicesubsequence::*;
|
||||
|
||||
use anyhow::bail;
|
||||
use log::info;
|
||||
use std::net::SocketAddr;
|
||||
|
||||
pub use log::{debug, error, info, log_enabled, trace};
|
||||
pub mod common;
|
||||
pub mod slicesubsequence;
|
||||
pub mod stanzafilter;
|
||||
|
||||
#[cfg(feature = "s2s-incoming")]
|
||||
use rustls::{Certificate, ServerConnection};
|
||||
#[cfg(feature = "quic")]
|
||||
pub mod quic;
|
||||
|
||||
pub fn to_str(buf: &[u8]) -> std::borrow::Cow<'_, str> {
|
||||
String::from_utf8_lossy(buf)
|
||||
}
|
||||
#[cfg(feature = "tls")]
|
||||
pub mod tls;
|
||||
|
||||
pub fn c2s(is_c2s: bool) -> &'static str {
|
||||
if is_c2s {
|
||||
"c2s"
|
||||
} else {
|
||||
"s2s"
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "outgoing")]
|
||||
pub mod outgoing;
|
||||
|
||||
pub async fn first_bytes_match(stream: &tokio::net::TcpStream, p: &mut [u8], matcher: fn(&[u8]) -> bool) -> anyhow::Result<bool> {
|
||||
// sooo... I don't think peek here can be used for > 1 byte without this timer craziness... can it?
|
||||
let len = p.len();
|
||||
// wait up to 10 seconds until len bytes have been read
|
||||
use std::time::{Duration, Instant};
|
||||
let duration = Duration::from_secs(10);
|
||||
let now = Instant::now();
|
||||
loop {
|
||||
let n = stream.peek(p).await?;
|
||||
if n == len {
|
||||
break; // success
|
||||
}
|
||||
if n == 0 {
|
||||
bail!("not enough bytes");
|
||||
}
|
||||
if Instant::now() - now > duration {
|
||||
bail!("less than {} bytes in 10 seconds, closed connection?", len);
|
||||
}
|
||||
}
|
||||
#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))]
|
||||
pub mod srv;
|
||||
|
||||
Ok(matcher(p))
|
||||
}
|
||||
#[cfg(feature = "websocket")]
|
||||
pub mod websocket;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Context<'a> {
|
||||
conn_id: String,
|
||||
log_from: String,
|
||||
log_to: String,
|
||||
proto: &'a str,
|
||||
is_c2s: Option<bool>,
|
||||
to: Option<String>,
|
||||
to_addr: Option<SocketAddr>,
|
||||
from: Option<String>,
|
||||
client_addr: SocketAddr,
|
||||
}
|
||||
#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))]
|
||||
pub mod verify;
|
||||
|
||||
impl<'a> Context<'a> {
|
||||
pub fn new(proto: &'static str, client_addr: SocketAddr) -> Context {
|
||||
let (log_to, log_from, conn_id) = if log_enabled!(log::Level::Info) {
|
||||
#[cfg(feature = "logging")]
|
||||
let conn_id = {
|
||||
use rand::distributions::Alphanumeric;
|
||||
use rand::{thread_rng, Rng};
|
||||
thread_rng().sample_iter(&Alphanumeric).take(10).map(char::from).collect()
|
||||
};
|
||||
#[cfg(not(feature = "logging"))]
|
||||
let conn_id = "".to_string();
|
||||
(
|
||||
format!("{}: ({} <- ({}-unk)):", conn_id, client_addr, proto),
|
||||
format!("{}: ({} -> ({}-unk)):", conn_id, client_addr, proto),
|
||||
conn_id,
|
||||
)
|
||||
} else {
|
||||
("".to_string(), "".to_string(), "".to_string())
|
||||
};
|
||||
|
||||
Context {
|
||||
conn_id,
|
||||
log_from,
|
||||
log_to,
|
||||
proto,
|
||||
client_addr,
|
||||
is_c2s: None,
|
||||
to: None,
|
||||
to_addr: None,
|
||||
from: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn re_calc(&mut self) {
|
||||
// todo: make this good
|
||||
self.log_from = format!(
|
||||
"{}: ({} ({}) -> ({}-{}) -> {} ({})):",
|
||||
self.conn_id,
|
||||
self.client_addr,
|
||||
if self.from.is_some() { self.from.as_ref().unwrap() } else { "unk" },
|
||||
self.proto,
|
||||
if self.is_c2s.is_some() { c2s(self.is_c2s.unwrap()) } else { "unk" },
|
||||
if self.to_addr.is_some() { self.to_addr.as_ref().unwrap().to_string() } else { "unk".to_string() },
|
||||
if self.to.is_some() { self.to.as_ref().unwrap() } else { "unk" },
|
||||
);
|
||||
self.log_to = self.log_from.replace(" -> ", " <- ");
|
||||
}
|
||||
|
||||
pub fn log_from(&self) -> &str {
|
||||
&self.log_from
|
||||
}
|
||||
|
||||
pub fn log_to(&self) -> &str {
|
||||
&self.log_to
|
||||
}
|
||||
|
||||
pub fn client_addr(&self) -> &SocketAddr {
|
||||
&self.client_addr
|
||||
}
|
||||
|
||||
pub fn set_proto(&mut self, proto: &'static str) {
|
||||
if log_enabled!(log::Level::Info) {
|
||||
self.proto = proto;
|
||||
self.to_addr = None;
|
||||
self.re_calc();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_c2s_stream_open(&mut self, is_c2s: bool, stream_open: &[u8]) {
|
||||
if log_enabled!(log::Level::Info) {
|
||||
self.is_c2s = Some(is_c2s);
|
||||
self.from = stream_open
|
||||
.extract_between(b" from='", b"'")
|
||||
.or_else(|_| stream_open.extract_between(b" from=\"", b"\""))
|
||||
.map(|b| to_str(b).to_string())
|
||||
.ok();
|
||||
self.to = stream_open
|
||||
.extract_between(b" to='", b"'")
|
||||
.or_else(|_| stream_open.extract_between(b" to=\"", b"\""))
|
||||
.map(|b| to_str(b).to_string())
|
||||
.ok();
|
||||
self.re_calc();
|
||||
info!("{} stream data set", &self.log_from());
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_to_addr(&mut self, to_addr: SocketAddr) {
|
||||
if log_enabled!(log::Level::Info) {
|
||||
self.to_addr = Some(to_addr);
|
||||
self.re_calc();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "s2s-incoming"))]
|
||||
pub type ServerCerts = ();
|
||||
|
||||
#[cfg(feature = "s2s-incoming")]
|
||||
#[derive(Clone)]
|
||||
pub enum ServerCerts {
|
||||
Tls(&'static ServerConnection),
|
||||
#[cfg(feature = "quic")]
|
||||
Quic(quinn::Connection),
|
||||
}
|
||||
|
||||
#[cfg(feature = "s2s-incoming")]
|
||||
impl ServerCerts {
|
||||
pub fn peer_certificates(&self) -> Option<Vec<Certificate>> {
|
||||
match self {
|
||||
ServerCerts::Tls(c) => c.peer_certificates().map(|c| c.to_vec()),
|
||||
#[cfg(feature = "quic")]
|
||||
ServerCerts::Quic(c) => c.peer_identity().and_then(|v| v.downcast::<Vec<Certificate>>().ok()).map(|v| v.to_vec()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sni(&self) -> Option<String> {
|
||||
match self {
|
||||
ServerCerts::Tls(c) => c.sni_hostname().map(|s| s.to_string()),
|
||||
#[cfg(feature = "quic")]
|
||||
ServerCerts::Quic(c) => c.handshake_data().and_then(|v| v.downcast::<quinn::crypto::rustls::HandshakeData>().ok()).and_then(|h| h.server_name),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn alpn(&self) -> Option<Vec<u8>> {
|
||||
match self {
|
||||
ServerCerts::Tls(c) => c.alpn_protocol().map(|s| s.to_vec()),
|
||||
#[cfg(feature = "quic")]
|
||||
ServerCerts::Quic(c) => c.handshake_data().and_then(|v| v.downcast::<quinn::crypto::rustls::HandshakeData>().ok()).and_then(|h| h.protocol),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_tls(&self) -> bool {
|
||||
match self {
|
||||
ServerCerts::Tls(_) => true,
|
||||
#[cfg(feature = "quic")]
|
||||
ServerCerts::Quic(_) => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
mod context;
|
||||
pub mod in_out;
|
||||
|
491
src/main.rs
491
src/main.rs
@ -1,112 +1,14 @@
|
||||
#![deny(clippy::all)]
|
||||
|
||||
use std::ffi::OsString;
|
||||
use std::fs::File;
|
||||
use std::io;
|
||||
use std::io::{BufReader, Read, Write};
|
||||
use std::iter::Iterator;
|
||||
use std::net::SocketAddr;
|
||||
use std::path::Path;
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
use anyhow::Result;
|
||||
use die::{die, Die};
|
||||
|
||||
use log::{debug, error, info};
|
||||
use serde_derive::Deserialize;
|
||||
|
||||
use tokio::io::{AsyncWriteExt, ReadHalf, WriteHalf};
|
||||
use tokio::net::TcpListener;
|
||||
use std::{ffi::OsString, fs::File, io::Read, iter::Iterator, net::SocketAddr, path::Path, sync::Arc};
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
#[cfg(feature = "rustls")]
|
||||
use rustls::{
|
||||
sign::{CertifiedKey, RsaSigningKey, SigningKey},
|
||||
Certificate, ClientConfig, PrivateKey, ServerConfig, SignatureScheme,
|
||||
};
|
||||
|
||||
#[cfg(feature = "tokio-rustls")]
|
||||
use tokio_rustls::{
|
||||
webpki::{DnsNameRef, TlsServerTrustAnchors, TrustAnchor},
|
||||
TlsConnector,
|
||||
};
|
||||
|
||||
use anyhow::{anyhow, bail, Result};
|
||||
|
||||
mod slicesubsequence;
|
||||
use slicesubsequence::*;
|
||||
|
||||
pub use xmpp_proxy::*;
|
||||
|
||||
#[cfg(feature = "quic")]
|
||||
mod quic;
|
||||
#[cfg(feature = "quic")]
|
||||
use crate::quic::*;
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
mod tls;
|
||||
#[cfg(feature = "tls")]
|
||||
use crate::tls::*;
|
||||
use xmpp_proxy::common::certs_key::CertsKey;
|
||||
|
||||
#[cfg(feature = "outgoing")]
|
||||
mod outgoing;
|
||||
#[cfg(feature = "outgoing")]
|
||||
use crate::outgoing::*;
|
||||
|
||||
#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))]
|
||||
mod srv;
|
||||
#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))]
|
||||
use crate::srv::*;
|
||||
|
||||
#[cfg(feature = "websocket")]
|
||||
mod websocket;
|
||||
#[cfg(feature = "websocket")]
|
||||
use crate::websocket::*;
|
||||
|
||||
#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))]
|
||||
mod verify;
|
||||
#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))]
|
||||
use crate::verify::*;
|
||||
|
||||
mod in_out;
|
||||
pub use crate::in_out::*;
|
||||
|
||||
const IN_BUFFER_SIZE: usize = 8192;
|
||||
|
||||
// todo: split these out to outgoing module
|
||||
|
||||
const ALPN_XMPP_CLIENT: &[u8] = b"xmpp-client";
|
||||
const ALPN_XMPP_SERVER: &[u8] = b"xmpp-server";
|
||||
|
||||
#[cfg(all(feature = "webpki-roots", not(feature = "rustls-native-certs")))]
|
||||
pub use webpki_roots::TLS_SERVER_ROOTS;
|
||||
|
||||
#[cfg(all(feature = "rustls-native-certs", not(feature = "webpki-roots")))]
|
||||
lazy_static::lazy_static! {
|
||||
static ref TLS_SERVER_ROOTS: TlsServerTrustAnchors<'static> = {
|
||||
// we need these to stick around for 'static, this is only called once so no problem
|
||||
let certs = Box::leak(Box::new(rustls_native_certs::load_native_certs().expect("could not load platform certs")));
|
||||
let root_cert_store = Box::leak(Box::new(Vec::new()));
|
||||
for cert in certs {
|
||||
// some system CAs are invalid, ignore those
|
||||
if let Ok(ta) = TrustAnchor::try_from_cert_der(&cert.0) {
|
||||
root_cert_store.push(ta);
|
||||
}
|
||||
}
|
||||
TlsServerTrustAnchors(root_cert_store)
|
||||
};
|
||||
}
|
||||
|
||||
#[cfg(any(feature = "rustls-native-certs", feature = "webpki-roots"))]
|
||||
pub fn root_cert_store() -> rustls::RootCertStore {
|
||||
use rustls::{OwnedTrustAnchor, RootCertStore};
|
||||
let mut root_cert_store = RootCertStore::empty();
|
||||
root_cert_store.add_server_trust_anchors(
|
||||
TLS_SERVER_ROOTS
|
||||
.0
|
||||
.iter()
|
||||
.map(|ta| OwnedTrustAnchor::from_subject_spki_name_constraints(ta.subject, ta.spki, ta.name_constraints)),
|
||||
);
|
||||
root_cert_store
|
||||
}
|
||||
use xmpp_proxy::{common::outgoing::OutgoingConfig, outgoing::spawn_outgoing_listener};
|
||||
|
||||
#[derive(Deserialize, Default)]
|
||||
struct Config {
|
||||
@ -123,87 +25,6 @@ struct Config {
|
||||
log_style: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CloneableConfig {
|
||||
max_stanza_size_bytes: usize,
|
||||
#[cfg(feature = "s2s-incoming")]
|
||||
s2s_target: Option<SocketAddr>,
|
||||
#[cfg(feature = "c2s-incoming")]
|
||||
c2s_target: Option<SocketAddr>,
|
||||
proxy: bool,
|
||||
}
|
||||
|
||||
struct CertsKey {
|
||||
#[cfg(feature = "rustls-pemfile")]
|
||||
inner: Result<RwLock<Arc<rustls::sign::CertifiedKey>>>,
|
||||
}
|
||||
|
||||
impl CertsKey {
|
||||
fn new(main_config: &Config) -> Self {
|
||||
CertsKey {
|
||||
#[cfg(feature = "rustls-pemfile")]
|
||||
inner: main_config.certs_key().map(|c| RwLock::new(Arc::new(c))),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(all(unix, any(feature = "incoming", feature = "s2s-outgoing")))]
|
||||
fn spawn_refresh_task(&'static self, cfg_path: OsString) -> Option<JoinHandle<Result<()>>> {
|
||||
if self.inner.is_err() {
|
||||
None
|
||||
} else {
|
||||
Some(tokio::spawn(async move {
|
||||
use tokio::signal::unix::{signal, SignalKind};
|
||||
let mut stream = signal(SignalKind::hangup())?;
|
||||
loop {
|
||||
stream.recv().await;
|
||||
info!("got SIGHUP");
|
||||
match Config::parse(&cfg_path).and_then(|c| c.certs_key()) {
|
||||
Ok(cert_key) => {
|
||||
if let Ok(rwl) = self.inner.as_ref() {
|
||||
let cert_key = Arc::new(cert_key);
|
||||
let mut certs_key = rwl.write().expect("CertKey poisoned?");
|
||||
*certs_key = cert_key;
|
||||
drop(certs_key);
|
||||
info!("reloaded cert/key successfully!");
|
||||
}
|
||||
}
|
||||
Err(e) => error!("invalid config/cert/key on SIGHUP: {}", e),
|
||||
};
|
||||
}
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "rustls-pemfile")]
|
||||
impl rustls::server::ResolvesServerCert for CertsKey {
|
||||
fn resolve(&self, _: rustls::server::ClientHello) -> Option<Arc<rustls::sign::CertifiedKey>> {
|
||||
self.inner.as_ref().map(|rwl| rwl.read().expect("CertKey poisoned?").clone()).ok()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "rustls-pemfile")]
|
||||
impl rustls::client::ResolvesClientCert for CertsKey {
|
||||
fn resolve(&self, _: &[&[u8]], _: &[SignatureScheme]) -> Option<Arc<CertifiedKey>> {
|
||||
self.inner.as_ref().map(|rwl| rwl.read().expect("CertKey poisoned?").clone()).ok()
|
||||
}
|
||||
|
||||
fn has_certs(&self) -> bool {
|
||||
self.inner.is_ok()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "rustls-pemfile"))]
|
||||
impl rustls::client::ResolvesClientCert for CertsKey {
|
||||
fn resolve(&self, _: &[&[u8]], _: &[SignatureScheme]) -> Option<Arc<CertifiedKey>> {
|
||||
None
|
||||
}
|
||||
|
||||
fn has_certs(&self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
fn parse<P: AsRef<Path>>(path: P) -> Result<Config> {
|
||||
let mut f = File::open(path)?;
|
||||
@ -212,8 +33,9 @@ impl Config {
|
||||
Ok(toml::from_str(&input)?)
|
||||
}
|
||||
|
||||
fn get_cloneable_cfg(&self) -> CloneableConfig {
|
||||
CloneableConfig {
|
||||
#[cfg(feature = "incoming")]
|
||||
fn get_cloneable_cfg(&self) -> xmpp_proxy::common::incoming::CloneableConfig {
|
||||
xmpp_proxy::common::incoming::CloneableConfig {
|
||||
max_stanza_size_bytes: self.max_stanza_size_bytes,
|
||||
#[cfg(feature = "s2s-incoming")]
|
||||
s2s_target: self.s2s_target,
|
||||
@ -238,268 +60,41 @@ impl Config {
|
||||
|
||||
#[cfg(feature = "rustls-pemfile")]
|
||||
fn certs_key(&self) -> Result<rustls::sign::CertifiedKey> {
|
||||
use rustls_pemfile::{certs, read_all, Item};
|
||||
|
||||
let tls_key = read_all(&mut BufReader::new(File::open(&self.tls_key)?))
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))?
|
||||
.into_iter()
|
||||
.flat_map(|item| match item {
|
||||
Item::RSAKey(der) => RsaSigningKey::new(&PrivateKey(der)).ok().map(Arc::new).map(|r| r as Arc<dyn SigningKey>),
|
||||
Item::PKCS8Key(der) => rustls::sign::any_supported_type(&PrivateKey(der)).ok(),
|
||||
Item::ECKey(der) => rustls::sign::any_supported_type(&PrivateKey(der)).ok(),
|
||||
_ => None,
|
||||
})
|
||||
.next()
|
||||
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))?;
|
||||
|
||||
let tls_certs = certs(&mut BufReader::new(File::open(&self.tls_cert)?))
|
||||
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))
|
||||
.map(|mut certs| certs.drain(..).map(Certificate).collect())?;
|
||||
|
||||
Ok(rustls::sign::CertifiedKey::new(tls_certs, tls_key))
|
||||
xmpp_proxy::common::read_certified_key(&self.tls_key, &self.tls_cert)
|
||||
}
|
||||
|
||||
#[cfg(feature = "incoming")]
|
||||
fn server_config(&self, certs_key: Arc<CertsKey>) -> Result<ServerConfig> {
|
||||
if let Err(e) = &certs_key.inner {
|
||||
bail!("invalid cert/key: {}", e);
|
||||
}
|
||||
|
||||
let config = ServerConfig::builder().with_safe_defaults();
|
||||
#[cfg(feature = "s2s")]
|
||||
let config = config.with_client_cert_verifier(Arc::new(AllowAnonymousOrAnyCert));
|
||||
#[cfg(not(feature = "s2s"))]
|
||||
let config = config.with_no_client_auth();
|
||||
let mut config = config.with_cert_resolver(certs_key);
|
||||
// todo: will connecting without alpn work then?
|
||||
config.alpn_protocols.push(ALPN_XMPP_CLIENT.to_vec());
|
||||
config.alpn_protocols.push(ALPN_XMPP_SERVER.to_vec());
|
||||
|
||||
Ok(config)
|
||||
#[cfg(not(feature = "rustls-pemfile"))]
|
||||
fn certs_key(&self) -> Result<rustls::sign::CertifiedKey> {
|
||||
anyhow::bail!("rustls-pemfile disabled at compile time")
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[cfg(feature = "outgoing")]
|
||||
pub struct OutgoingConfig {
|
||||
max_stanza_size_bytes: usize,
|
||||
certs_key: Arc<CertsKey>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "outgoing")]
|
||||
impl OutgoingConfig {
|
||||
pub fn with_custom_certificate_verifier(&self, is_c2s: bool, cert_verifier: XmppServerCertVerifier) -> OutgoingVerifierConfig {
|
||||
let config = match is_c2s {
|
||||
false => ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_custom_certificate_verifier(Arc::new(cert_verifier))
|
||||
.with_client_cert_resolver(self.certs_key.clone()),
|
||||
_ => ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_custom_certificate_verifier(Arc::new(cert_verifier))
|
||||
.with_no_client_auth(),
|
||||
};
|
||||
|
||||
let mut config_alpn = config.clone();
|
||||
config_alpn.alpn_protocols.push(if is_c2s { ALPN_XMPP_CLIENT } else { ALPN_XMPP_SERVER }.to_vec());
|
||||
|
||||
let config_alpn = Arc::new(config_alpn);
|
||||
|
||||
let connector_alpn: TlsConnector = config_alpn.clone().into();
|
||||
|
||||
let connector: TlsConnector = Arc::new(config).into();
|
||||
|
||||
OutgoingVerifierConfig {
|
||||
max_stanza_size_bytes: self.max_stanza_size_bytes,
|
||||
config_alpn,
|
||||
connector_alpn,
|
||||
connector,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[cfg(feature = "outgoing")]
|
||||
pub struct OutgoingVerifierConfig {
|
||||
pub max_stanza_size_bytes: usize,
|
||||
|
||||
pub config_alpn: Arc<ClientConfig>,
|
||||
pub connector_alpn: TlsConnector,
|
||||
|
||||
pub connector: TlsConnector,
|
||||
}
|
||||
|
||||
#[cfg(feature = "incoming")]
|
||||
async fn shuffle_rd_wr(in_rd: StanzaRead, in_wr: StanzaWrite, config: CloneableConfig, server_certs: ServerCerts, local_addr: SocketAddr, client_addr: &mut Context<'_>) -> Result<()> {
|
||||
let filter = StanzaFilter::new(config.max_stanza_size_bytes);
|
||||
shuffle_rd_wr_filter(in_rd, in_wr, config, server_certs, local_addr, client_addr, filter).await
|
||||
}
|
||||
|
||||
#[cfg(feature = "incoming")]
|
||||
async fn shuffle_rd_wr_filter(
|
||||
mut in_rd: StanzaRead,
|
||||
mut in_wr: StanzaWrite,
|
||||
config: CloneableConfig,
|
||||
server_certs: ServerCerts,
|
||||
local_addr: SocketAddr,
|
||||
client_addr: &mut Context<'_>,
|
||||
mut in_filter: StanzaFilter,
|
||||
) -> Result<()> {
|
||||
// now read to figure out client vs server
|
||||
let (stream_open, is_c2s) = stream_preamble(&mut in_rd, &mut in_wr, client_addr.log_from(), &mut in_filter).await?;
|
||||
client_addr.set_c2s_stream_open(is_c2s, &stream_open);
|
||||
|
||||
#[cfg(feature = "s2s-incoming")]
|
||||
{
|
||||
trace!(
|
||||
"{} connected: sni: {:?}, alpn: {:?}, tls-not-quic: {}",
|
||||
client_addr.log_from(),
|
||||
server_certs.sni(),
|
||||
server_certs.alpn().map(|a| String::from_utf8_lossy(&a).to_string()),
|
||||
server_certs.is_tls(),
|
||||
);
|
||||
|
||||
if !is_c2s {
|
||||
// for s2s we need this
|
||||
use std::time::SystemTime;
|
||||
let domain = stream_open
|
||||
.extract_between(b" from='", b"'")
|
||||
.or_else(|_| stream_open.extract_between(b" from=\"", b"\""))
|
||||
.and_then(|b| Ok(std::str::from_utf8(b)?))?;
|
||||
let (_, cert_verifier) = get_xmpp_connections(domain, is_c2s).await?;
|
||||
let certs = server_certs.peer_certificates().ok_or_else(|| anyhow!("no client cert auth for s2s incoming from {}", domain))?;
|
||||
// todo: send stream error saying cert is invalid
|
||||
cert_verifier.verify_cert(&certs[0], &certs[1..], SystemTime::now())?;
|
||||
}
|
||||
drop(server_certs);
|
||||
}
|
||||
|
||||
let (out_rd, out_wr) = open_incoming(&config, local_addr, client_addr, &stream_open, is_c2s, &mut in_filter).await?;
|
||||
drop(stream_open);
|
||||
|
||||
shuffle_rd_wr_filter_only(
|
||||
in_rd,
|
||||
in_wr,
|
||||
StanzaRead::new(out_rd),
|
||||
StanzaWrite::new(out_wr),
|
||||
is_c2s,
|
||||
config.max_stanza_size_bytes,
|
||||
client_addr,
|
||||
in_filter,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn shuffle_rd_wr_filter_only(
|
||||
mut in_rd: StanzaRead,
|
||||
mut in_wr: StanzaWrite,
|
||||
mut out_rd: StanzaRead,
|
||||
mut out_wr: StanzaWrite,
|
||||
is_c2s: bool,
|
||||
max_stanza_size_bytes: usize,
|
||||
client_addr: &mut Context<'_>,
|
||||
mut in_filter: StanzaFilter,
|
||||
) -> Result<()> {
|
||||
let mut out_filter = StanzaFilter::new(max_stanza_size_bytes);
|
||||
|
||||
#[cfg(all(unix, any(feature = "incoming", feature = "s2s-outgoing")))]
|
||||
fn spawn_refresh_task(certs_key: &'static CertsKey, cfg_path: OsString) -> Option<JoinHandle<Result<()>>> {
|
||||
if certs_key.inner.is_err() {
|
||||
None
|
||||
} else {
|
||||
Some(tokio::spawn(async move {
|
||||
use tokio::signal::unix::{signal, SignalKind};
|
||||
let mut stream = signal(SignalKind::hangup())?;
|
||||
loop {
|
||||
tokio::select! {
|
||||
Ok(ret) = in_rd.next(&mut in_filter, client_addr.log_to(), &mut in_wr) => {
|
||||
match ret {
|
||||
None => break,
|
||||
Some((buf, eoft)) => {
|
||||
trace!("{} '{}'", client_addr.log_from(), to_str(buf));
|
||||
out_wr.write_all(is_c2s, buf, eoft, client_addr.log_from()).await?;
|
||||
out_wr.flush().await?;
|
||||
stream.recv().await;
|
||||
info!("got SIGHUP");
|
||||
match Config::parse(&cfg_path).and_then(|c| c.certs_key()) {
|
||||
Ok(cert_key) => {
|
||||
if let Ok(rwl) = certs_key.inner.as_ref() {
|
||||
let cert_key = Arc::new(cert_key);
|
||||
let mut certs_key = rwl.write().expect("CertKey poisoned?");
|
||||
*certs_key = cert_key;
|
||||
drop(certs_key);
|
||||
info!("reloaded cert/key successfully!");
|
||||
}
|
||||
}
|
||||
},
|
||||
Ok(ret) = out_rd.next(&mut out_filter, client_addr.log_from(), &mut out_wr) => {
|
||||
match ret {
|
||||
None => break,
|
||||
Some((buf, eoft)) => {
|
||||
trace!("{} '{}'", client_addr.log_to(), to_str(buf));
|
||||
in_wr.write_all(is_c2s, buf, eoft, client_addr.log_to()).await?;
|
||||
in_wr.flush().await?;
|
||||
Err(e) => error!("invalid config/cert/key on SIGHUP: {}", e),
|
||||
};
|
||||
}
|
||||
}))
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
info!("{} disconnected", client_addr.log_from());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "incoming")]
|
||||
async fn open_incoming(
|
||||
config: &CloneableConfig,
|
||||
local_addr: SocketAddr,
|
||||
client_addr: &mut Context<'_>,
|
||||
stream_open: &[u8],
|
||||
is_c2s: bool,
|
||||
in_filter: &mut StanzaFilter,
|
||||
) -> Result<(ReadHalf<tokio::net::TcpStream>, WriteHalf<tokio::net::TcpStream>)> {
|
||||
let target = if is_c2s {
|
||||
#[cfg(not(feature = "c2s-incoming"))]
|
||||
bail!("incoming c2s connection but lacking compile-time support");
|
||||
#[cfg(feature = "c2s-incoming")]
|
||||
config.c2s_target
|
||||
} else {
|
||||
#[cfg(not(feature = "s2s-incoming"))]
|
||||
bail!("incoming s2s connection but lacking compile-time support");
|
||||
#[cfg(feature = "s2s-incoming")]
|
||||
config.s2s_target
|
||||
}
|
||||
.ok_or_else(|| anyhow!("incoming connection but `{}_target` not defined", c2s(is_c2s)))?;
|
||||
client_addr.set_to_addr(target);
|
||||
|
||||
let out_stream = tokio::net::TcpStream::connect(target).await?;
|
||||
let (out_rd, mut out_wr) = tokio::io::split(out_stream);
|
||||
|
||||
if config.proxy {
|
||||
/*
|
||||
https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt
|
||||
PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\n
|
||||
PROXY TCP6 ffff:f...f:ffff ffff:f...f:ffff 65535 65535\r\n
|
||||
PROXY TCP6 SOURCE_IP DEST_IP SOURCE_PORT DEST_PORT\r\n
|
||||
*/
|
||||
// tokio AsyncWrite doesn't have write_fmt so have to go through this buffer for some crazy reason
|
||||
//write!(out_wr, "PROXY TCP{} {} {} {} {}\r\n", if client_addr.is_ipv4() { '4' } else {'6' }, client_addr.ip(), local_addr.ip(), client_addr.port(), local_addr.port())?;
|
||||
write!(
|
||||
&mut in_filter.buf[0..],
|
||||
"PROXY TCP{} {} {} {} {}\r\n",
|
||||
if client_addr.client_addr().is_ipv4() { '4' } else { '6' },
|
||||
client_addr.client_addr().ip(),
|
||||
local_addr.ip(),
|
||||
client_addr.client_addr().port(),
|
||||
local_addr.port()
|
||||
)?;
|
||||
let end_idx = &(&in_filter.buf[0..]).first_index_of(b"\n")? + 1;
|
||||
trace!("{} '{}'", client_addr.log_from(), to_str(&in_filter.buf[0..end_idx]));
|
||||
out_wr.write_all(&in_filter.buf[0..end_idx]).await?;
|
||||
}
|
||||
trace!("{} '{}'", client_addr.log_from(), to_str(stream_open));
|
||||
out_wr.write_all(stream_open).await?;
|
||||
out_wr.flush().await?;
|
||||
Ok((out_rd, out_wr))
|
||||
}
|
||||
|
||||
pub async fn stream_preamble(in_rd: &mut StanzaRead, in_wr: &mut StanzaWrite, client_addr: &'_ str, in_filter: &mut StanzaFilter) -> Result<(Vec<u8>, bool)> {
|
||||
let mut stream_open = Vec::new();
|
||||
while let Ok(Some((buf, _))) = in_rd.next(in_filter, client_addr, in_wr).await {
|
||||
trace!("{} received pre-<stream:stream> stanza: '{}'", client_addr, to_str(buf));
|
||||
if buf.starts_with(b"<?xml ") {
|
||||
stream_open.extend_from_slice(buf);
|
||||
} else if buf.starts_with(b"<stream:stream ") {
|
||||
stream_open.extend_from_slice(buf);
|
||||
return Ok((stream_open, buf.contains_seq(br#" xmlns="jabber:client""#) || buf.contains_seq(br#" xmlns='jabber:client'"#)));
|
||||
} else {
|
||||
bail!("bad pre-<stream:stream> stanza: {}", to_str(buf));
|
||||
}
|
||||
}
|
||||
bail!("stream ended before open")
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
@ -533,18 +128,23 @@ async fn main() {
|
||||
die!("log_level or log_style defined in config but logging disabled at compile-time");
|
||||
}
|
||||
|
||||
#[cfg(feature = "incoming")]
|
||||
let config = main_config.get_cloneable_cfg();
|
||||
|
||||
let certs_key = Arc::new(CertsKey::new(&main_config));
|
||||
let certs_key = Arc::new(CertsKey::new(main_config.certs_key()));
|
||||
|
||||
let mut handles: Vec<JoinHandle<Result<()>>> = Vec::new();
|
||||
if !main_config.incoming_listen.is_empty() {
|
||||
#[cfg(all(any(feature = "tls", feature = "websocket"), feature = "incoming"))]
|
||||
{
|
||||
use xmpp_proxy::{
|
||||
common::incoming::server_config,
|
||||
tls::incoming::{spawn_tls_listener, tls_acceptor},
|
||||
};
|
||||
if main_config.c2s_target.is_none() && main_config.s2s_target.is_none() {
|
||||
die!("one of c2s_target/s2s_target must be defined if incoming_listen is non-empty");
|
||||
}
|
||||
let acceptor = main_config.tls_acceptor(certs_key.clone()).die("invalid cert/key ?");
|
||||
let acceptor = tls_acceptor(server_config(certs_key.clone()).die("invalid cert/key ?"));
|
||||
for listener in main_config.incoming_listen.iter() {
|
||||
handles.push(spawn_tls_listener(listener.parse().die("invalid listener address"), config.clone(), acceptor.clone()));
|
||||
}
|
||||
@ -555,10 +155,14 @@ async fn main() {
|
||||
if !main_config.quic_listen.is_empty() {
|
||||
#[cfg(all(feature = "quic", feature = "incoming"))]
|
||||
{
|
||||
use xmpp_proxy::{
|
||||
common::incoming::server_config,
|
||||
quic::incoming::{quic_server_config, spawn_quic_listener},
|
||||
};
|
||||
if main_config.c2s_target.is_none() && main_config.s2s_target.is_none() {
|
||||
die!("one of c2s_target/s2s_target must be defined if quic_listen is non-empty");
|
||||
}
|
||||
let quic_config = main_config.quic_server_config(certs_key.clone()).die("invalid cert/key ?");
|
||||
let quic_config = quic_server_config(server_config(certs_key.clone()).die("invalid cert/key ?"));
|
||||
for listener in main_config.quic_listen.iter() {
|
||||
handles.push(spawn_quic_listener(listener.parse().die("invalid listener address"), config.clone(), quic_config.clone()));
|
||||
}
|
||||
@ -581,9 +185,12 @@ async fn main() {
|
||||
die!("all of incoming_listen, quic_listen, outgoing_listen empty, nothing to do, exiting...");
|
||||
}
|
||||
#[cfg(all(unix, any(feature = "incoming", feature = "s2s-outgoing")))]
|
||||
if let Some(refresh_task) = Box::leak(Box::new(certs_key.clone())).spawn_refresh_task(cfg_path) {
|
||||
{
|
||||
let certs_key = Box::leak(Box::new(certs_key.clone()));
|
||||
if let Some(refresh_task) = spawn_refresh_task(certs_key, cfg_path) {
|
||||
handles.push(refresh_task);
|
||||
}
|
||||
}
|
||||
|
||||
info!("xmpp-proxy started");
|
||||
futures::future::join_all(handles).await;
|
||||
|
@ -1,4 +1,16 @@
|
||||
use crate::*;
|
||||
use crate::{
|
||||
common::{first_bytes_match, outgoing::OutgoingConfig, shuffle_rd_wr_filter_only, stream_preamble},
|
||||
context::Context,
|
||||
in_out::{StanzaRead, StanzaWrite},
|
||||
slicesubsequence::SliceSubsequence,
|
||||
srv::srv_connect,
|
||||
stanzafilter::StanzaFilter,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use die::Die;
|
||||
use log::{error, info};
|
||||
use std::net::SocketAddr;
|
||||
use tokio::{net::TcpListener, task::JoinHandle};
|
||||
|
||||
async fn handle_outgoing_connection(stream: tokio::net::TcpStream, client_addr: &mut Context<'_>, config: OutgoingConfig) -> Result<()> {
|
||||
info!("{} connected", client_addr.log_from());
|
||||
@ -7,7 +19,7 @@ async fn handle_outgoing_connection(stream: tokio::net::TcpStream, client_addr:
|
||||
|
||||
#[cfg(feature = "websocket")]
|
||||
let (mut in_rd, mut in_wr) = if first_bytes_match(&stream, &mut in_filter.buf[0..3], |p| p == b"GET").await? {
|
||||
incoming_websocket_connection(Box::new(stream), config.max_stanza_size_bytes).await?
|
||||
crate::websocket::incoming_websocket_connection(Box::new(stream), config.max_stanza_size_bytes).await?
|
||||
} else {
|
||||
let (in_rd, in_wr) = tokio::io::split(stream);
|
||||
(StanzaRead::new(in_rd), StanzaWrite::new(in_wr))
|
||||
|
@ -1,40 +1,16 @@
|
||||
use crate::*;
|
||||
use futures::StreamExt;
|
||||
use quinn::{ServerConfig, TransportConfig};
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
|
||||
use crate::{
|
||||
common::incoming::{shuffle_rd_wr, CloneableConfig, ServerCerts},
|
||||
context::Context,
|
||||
in_out::{StanzaRead, StanzaWrite},
|
||||
};
|
||||
use anyhow::Result;
|
||||
use die::Die;
|
||||
use futures::StreamExt;
|
||||
use log::{error, info};
|
||||
use quinn::ServerConfig;
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
#[cfg(feature = "outgoing")]
|
||||
pub async fn quic_connect(target: SocketAddr, server_name: &str, config: OutgoingVerifierConfig) -> Result<(StanzaWrite, StanzaRead)> {
|
||||
let bind_addr = "0.0.0.0:0".parse().unwrap();
|
||||
let client_cfg = config.config_alpn;
|
||||
|
||||
let mut endpoint = quinn::Endpoint::client(bind_addr)?;
|
||||
endpoint.set_default_client_config(quinn::ClientConfig::new(client_cfg));
|
||||
|
||||
// connect to server
|
||||
let quinn::NewConnection { connection, .. } = endpoint.connect(target, server_name)?.await?;
|
||||
trace!("quic connected: addr={}", connection.remote_address());
|
||||
|
||||
let (wrt, rd) = connection.open_bi().await?;
|
||||
Ok((StanzaWrite::new(wrt), StanzaRead::new(rd)))
|
||||
}
|
||||
|
||||
#[cfg(feature = "incoming")]
|
||||
impl Config {
|
||||
pub fn quic_server_config(&self, cert_key: Arc<CertsKey>) -> Result<ServerConfig> {
|
||||
let transport_config = TransportConfig::default();
|
||||
// todo: configure transport_config here if needed
|
||||
let server_config = self.server_config(cert_key)?;
|
||||
let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(server_config));
|
||||
server_config.transport = Arc::new(transport_config);
|
||||
|
||||
Ok(server_config)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "incoming")]
|
||||
pub fn spawn_quic_listener(local_addr: SocketAddr, config: CloneableConfig, server_config: ServerConfig) -> JoinHandle<Result<()>> {
|
||||
let (_endpoint, mut incoming) = quinn::Endpoint::server(server_config, local_addr).die("cannot listen on port/interface");
|
||||
tokio::spawn(async move {
|
||||
@ -43,7 +19,7 @@ pub fn spawn_quic_listener(local_addr: SocketAddr, config: CloneableConfig, serv
|
||||
let config = config.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Ok(mut new_conn) = incoming_conn.await {
|
||||
let client_addr = crate::Context::new("quic-in", new_conn.connection.remote_address());
|
||||
let client_addr = Context::new("quic-in", new_conn.connection.remote_address());
|
||||
|
||||
#[cfg(feature = "s2s-incoming")]
|
||||
let server_certs = ServerCerts::Quic(new_conn.connection);
|
||||
@ -70,3 +46,12 @@ pub fn spawn_quic_listener(local_addr: SocketAddr, config: CloneableConfig, serv
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn quic_server_config(server_config: rustls::ServerConfig) -> ServerConfig {
|
||||
let transport_config = quinn::TransportConfig::default();
|
||||
// todo: configure transport_config here if needed
|
||||
let mut server_config = ServerConfig::with_crypto(Arc::new(server_config));
|
||||
server_config.transport = Arc::new(transport_config);
|
||||
|
||||
server_config
|
||||
}
|
5
src/quic/mod.rs
Normal file
5
src/quic/mod.rs
Normal file
@ -0,0 +1,5 @@
|
||||
#[cfg(feature = "incoming")]
|
||||
pub mod incoming;
|
||||
|
||||
#[cfg(feature = "outgoing")]
|
||||
pub mod outgoing;
|
23
src/quic/outgoing.rs
Normal file
23
src/quic/outgoing.rs
Normal file
@ -0,0 +1,23 @@
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use crate::{
|
||||
common::outgoing::OutgoingVerifierConfig,
|
||||
in_out::{StanzaRead, StanzaWrite},
|
||||
};
|
||||
use anyhow::Result;
|
||||
use log::trace;
|
||||
|
||||
pub async fn quic_connect(target: SocketAddr, server_name: &str, config: OutgoingVerifierConfig) -> Result<(StanzaWrite, StanzaRead)> {
|
||||
let bind_addr = "0.0.0.0:0".parse().unwrap();
|
||||
let client_cfg = config.config_alpn;
|
||||
|
||||
let mut endpoint = quinn::Endpoint::client(bind_addr)?;
|
||||
endpoint.set_default_client_config(quinn::ClientConfig::new(client_cfg));
|
||||
|
||||
// connect to server
|
||||
let quinn::NewConnection { connection, .. } = endpoint.connect(target, server_name)?.await?;
|
||||
trace!("quic connected: addr={}", connection.remote_address());
|
||||
|
||||
let (wrt, rd) = connection.open_bi().await?;
|
||||
Ok((StanzaWrite::new(wrt), StanzaRead::new(rd)))
|
||||
}
|
53
src/srv.rs
53
src/srv.rs
@ -1,22 +1,33 @@
|
||||
#![allow(clippy::upper_case_acronyms)]
|
||||
|
||||
use std::cmp::Ordering;
|
||||
use std::convert::TryFrom;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
|
||||
use data_encoding::BASE64;
|
||||
use ring::digest::{Algorithm, Context as DigestContext, SHA256, SHA512};
|
||||
|
||||
use trust_dns_resolver::error::ResolveError;
|
||||
use trust_dns_resolver::lookup::{SrvLookup, TxtLookup};
|
||||
use trust_dns_resolver::{IntoName, TokioAsyncResolver};
|
||||
|
||||
#[cfg(feature = "outgoing")]
|
||||
use crate::common::outgoing::{OutgoingConfig, OutgoingVerifierConfig};
|
||||
use crate::{
|
||||
common::{stream_preamble, to_str},
|
||||
context::Context,
|
||||
in_out::{StanzaRead, StanzaWrite},
|
||||
slicesubsequence::SliceSubsequence,
|
||||
stanzafilter::{StanzaFilter, StanzaReader},
|
||||
verify::XmppServerCertVerifier,
|
||||
};
|
||||
use anyhow::{bail, Result};
|
||||
use tokio_rustls::webpki::DnsName;
|
||||
use data_encoding::BASE64;
|
||||
use log::{debug, error, trace};
|
||||
use ring::digest::{Algorithm, Context as DigestContext, SHA256, SHA512};
|
||||
use serde::Deserialize;
|
||||
use std::{
|
||||
cmp::Ordering,
|
||||
convert::TryFrom,
|
||||
net::{IpAddr, SocketAddr},
|
||||
};
|
||||
use tokio_rustls::webpki::{DnsName, DnsNameRef};
|
||||
#[cfg(feature = "websocket")]
|
||||
use tokio_tungstenite::tungstenite::http::Uri;
|
||||
|
||||
use crate::*;
|
||||
use trust_dns_resolver::{
|
||||
error::ResolveError,
|
||||
lookup::{SrvLookup, TxtLookup},
|
||||
IntoName, TokioAsyncResolver,
|
||||
};
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
static ref RESOLVER: TokioAsyncResolver = make_resolver();
|
||||
@ -165,7 +176,7 @@ impl XmppConnection {
|
||||
&self,
|
||||
domain: &str,
|
||||
stream_open: &[u8],
|
||||
in_filter: &mut crate::StanzaFilter,
|
||||
in_filter: &mut StanzaFilter,
|
||||
client_addr: &mut Context<'_>,
|
||||
config: OutgoingVerifierConfig,
|
||||
) -> Result<(StanzaWrite, StanzaRead, SocketAddr, &'static str)> {
|
||||
@ -184,28 +195,28 @@ impl XmppConnection {
|
||||
debug!("{} trying ip {}", client_addr.log_from(), to_addr);
|
||||
match self.conn_type {
|
||||
#[cfg(feature = "tls")]
|
||||
XmppConnectionType::StartTLS => match crate::starttls_connect(to_addr, domain, stream_open, in_filter, config.clone()).await {
|
||||
XmppConnectionType::StartTLS => match crate::tls::outgoing::starttls_connect(to_addr, domain, stream_open, in_filter, config.clone()).await {
|
||||
Ok((wr, rd)) => return Ok((wr, rd, to_addr, "starttls-out")),
|
||||
Err(e) => error!("starttls connection failed to IP {} from SRV {}, error: {}", to_addr, self.target, e),
|
||||
},
|
||||
#[cfg(feature = "tls")]
|
||||
XmppConnectionType::DirectTLS => match crate::tls_connect(to_addr, domain, config.clone()).await {
|
||||
XmppConnectionType::DirectTLS => match crate::tls::outgoing::tls_connect(to_addr, domain, config.clone()).await {
|
||||
Ok((wr, rd)) => return Ok((wr, rd, to_addr, "directtls-out")),
|
||||
Err(e) => error!("direct tls connection failed to IP {} from SRV {}, error: {}", to_addr, self.target, e),
|
||||
},
|
||||
#[cfg(feature = "quic")]
|
||||
XmppConnectionType::QUIC => match crate::quic_connect(to_addr, domain, config.clone()).await {
|
||||
XmppConnectionType::QUIC => match crate::quic::outgoing::quic_connect(to_addr, domain, config.clone()).await {
|
||||
Ok((wr, rd)) => return Ok((wr, rd, to_addr, "quic-out")),
|
||||
Err(e) => error!("quic connection failed to IP {} from SRV {}, error: {}", to_addr, self.target, e),
|
||||
},
|
||||
#[cfg(feature = "websocket")]
|
||||
// todo: when websocket is found via DNS, we need to validate cert against domain, *not* target, this is a security problem with XEP-0156, we are doing it the secure but likely unexpected way here for now
|
||||
XmppConnectionType::WebSocket(ref url, ref origin) => match crate::websocket_connect(to_addr, domain, url, origin, config.clone()).await {
|
||||
XmppConnectionType::WebSocket(ref url, ref origin) => match crate::websocket::outgoing::websocket_connect(to_addr, domain, url, origin, config.clone()).await {
|
||||
Ok((wr, rd)) => return Ok((wr, rd, to_addr, "websocket-out")),
|
||||
Err(e) => {
|
||||
if self.secure && self.target != orig_domain {
|
||||
// https is a special case, as target is sent in the Host: header, so we have to literally try twice in case this is set for the other on the server
|
||||
match crate::websocket_connect(to_addr, orig_domain, url, origin, config.clone()).await {
|
||||
match crate::websocket::outgoing::websocket_connect(to_addr, orig_domain, url, origin, config.clone()).await {
|
||||
Ok((wr, rd)) => return Ok((wr, rd, to_addr, "websocket-out")),
|
||||
Err(e2) => error!("websocket connection failed to IP {} from TXT {}, error try 1: {}, error try 2: {}", to_addr, url, e, e2),
|
||||
}
|
||||
@ -428,7 +439,7 @@ pub async fn srv_connect(
|
||||
domain: &str,
|
||||
is_c2s: bool,
|
||||
stream_open: &[u8],
|
||||
in_filter: &mut crate::StanzaFilter,
|
||||
in_filter: &mut StanzaFilter,
|
||||
client_addr: &mut Context<'_>,
|
||||
config: OutgoingConfig,
|
||||
) -> Result<(StanzaWrite, StanzaRead, Vec<u8>)> {
|
||||
|
@ -1,9 +1,9 @@
|
||||
#![allow(clippy::upper_case_acronyms)]
|
||||
|
||||
use crate::common::to_str;
|
||||
use anyhow::{bail, Result};
|
||||
|
||||
use crate::stanzafilter::StanzaState::*;
|
||||
use crate::to_str;
|
||||
use StanzaState::*;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum StanzaState {
|
||||
|
@ -1,71 +1,30 @@
|
||||
use crate::*;
|
||||
use rustls::ServerConnection;
|
||||
use std::convert::TryFrom;
|
||||
use tokio::io::{AsyncBufReadExt, BufStream};
|
||||
use crate::common::incoming::{shuffle_rd_wr_filter, CloneableConfig, ServerCerts};
|
||||
|
||||
use tokio_rustls::{rustls::ServerName, TlsAcceptor};
|
||||
use crate::{
|
||||
common::{first_bytes_match, to_str, IN_BUFFER_SIZE},
|
||||
context::Context,
|
||||
in_out::{StanzaRead, StanzaWrite},
|
||||
slicesubsequence::SliceSubsequence,
|
||||
stanzafilter::{StanzaFilter, StanzaReader},
|
||||
*,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use die::Die;
|
||||
use log::{error, trace};
|
||||
use rustls::{ServerConfig, ServerConnection};
|
||||
|
||||
#[cfg(feature = "outgoing")]
|
||||
pub async fn tls_connect(target: SocketAddr, server_name: &str, config: OutgoingVerifierConfig) -> Result<(StanzaWrite, StanzaRead)> {
|
||||
let dnsname = ServerName::try_from(server_name)?;
|
||||
let stream = tokio::net::TcpStream::connect(target).await?;
|
||||
let stream = config.connector_alpn.connect(dnsname, stream).await?;
|
||||
let (rd, wrt) = tokio::io::split(stream);
|
||||
Ok((StanzaWrite::new(wrt), StanzaRead::new(rd)))
|
||||
use std::sync::Arc;
|
||||
use tokio::{
|
||||
io::{AsyncBufReadExt, AsyncWriteExt, BufStream},
|
||||
net::TcpListener,
|
||||
task::JoinHandle,
|
||||
};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
|
||||
pub fn tls_acceptor(server_config: ServerConfig) -> TlsAcceptor {
|
||||
TlsAcceptor::from(Arc::new(server_config))
|
||||
}
|
||||
|
||||
#[cfg(feature = "outgoing")]
|
||||
pub async fn starttls_connect(target: SocketAddr, server_name: &str, stream_open: &[u8], in_filter: &mut StanzaFilter, config: OutgoingVerifierConfig) -> Result<(StanzaWrite, StanzaRead)> {
|
||||
let dnsname = ServerName::try_from(server_name)?;
|
||||
let mut stream = tokio::net::TcpStream::connect(target).await?;
|
||||
let (in_rd, mut in_wr) = stream.split();
|
||||
|
||||
// send the stream_open
|
||||
trace!("starttls sending: {} '{}'", server_name, to_str(stream_open));
|
||||
in_wr.write_all(stream_open).await?;
|
||||
in_wr.flush().await?;
|
||||
|
||||
// we naively read 1 byte at a time, which buffering significantly speeds up
|
||||
let in_rd = tokio::io::BufReader::with_capacity(IN_BUFFER_SIZE, in_rd);
|
||||
let mut in_rd = StanzaReader(in_rd);
|
||||
let mut proceed_received = false;
|
||||
|
||||
trace!("starttls reading stream open {}", server_name);
|
||||
while let Ok(Some(buf)) = in_rd.next(in_filter).await {
|
||||
trace!("received pre-tls stanza: {} '{}'", server_name, to_str(buf));
|
||||
if buf.starts_with(b"<?xml ") || buf.starts_with(b"<stream:stream ") {
|
||||
// ignore this
|
||||
} else if buf.starts_with(b"<stream:features") {
|
||||
// we send starttls regardless, it could have been stripped out, we don't do plaintext
|
||||
let buf = br###"<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>"###;
|
||||
trace!("> {} '{}'", server_name, to_str(buf));
|
||||
in_wr.write_all(buf).await?;
|
||||
in_wr.flush().await?;
|
||||
} else if buf.starts_with(b"<proceed ") {
|
||||
proceed_received = true;
|
||||
break;
|
||||
} else {
|
||||
bail!("bad pre-tls stanza: {}", to_str(buf));
|
||||
}
|
||||
}
|
||||
if !proceed_received {
|
||||
bail!("stream ended before proceed");
|
||||
}
|
||||
|
||||
debug!("starttls starting TLS {}", server_name);
|
||||
let stream = config.connector.connect(dnsname, stream).await?;
|
||||
let (rd, wrt) = tokio::io::split(stream);
|
||||
Ok((StanzaWrite::new(wrt), StanzaRead::new(rd)))
|
||||
}
|
||||
|
||||
#[cfg(feature = "incoming")]
|
||||
impl Config {
|
||||
pub fn tls_acceptor(&self, cert_key: Arc<CertsKey>) -> Result<TlsAcceptor> {
|
||||
Ok(TlsAcceptor::from(Arc::new(self.server_config(cert_key)?)))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "incoming")]
|
||||
pub fn spawn_tls_listener(local_addr: SocketAddr, config: CloneableConfig, acceptor: TlsAcceptor) -> JoinHandle<Result<()>> {
|
||||
tokio::spawn(async move {
|
||||
let listener = TcpListener::bind(&local_addr).await.die("cannot listen on port/interface");
|
||||
@ -83,7 +42,6 @@ pub fn spawn_tls_listener(local_addr: SocketAddr, config: CloneableConfig, accep
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(feature = "incoming")]
|
||||
async fn handle_tls_connection(mut stream: tokio::net::TcpStream, client_addr: &mut Context<'_>, local_addr: SocketAddr, config: CloneableConfig, acceptor: TlsAcceptor) -> Result<()> {
|
||||
info!("{} connected", client_addr.log_from());
|
||||
|
||||
@ -183,7 +141,7 @@ async fn handle_tls_connection(mut stream: tokio::net::TcpStream, client_addr: &
|
||||
{
|
||||
let stream: tokio_rustls::TlsStream<tokio::net::TcpStream> = stream.into();
|
||||
|
||||
let mut stream = BufStream::with_capacity(crate::IN_BUFFER_SIZE, 0, stream);
|
||||
let mut stream = BufStream::with_capacity(IN_BUFFER_SIZE, 0, stream);
|
||||
let websocket = {
|
||||
// wait up to 10 seconds until 3 bytes have been read
|
||||
use std::time::{Duration, Instant};
|
||||
@ -207,7 +165,7 @@ async fn handle_tls_connection(mut stream: tokio::net::TcpStream, client_addr: &
|
||||
};
|
||||
|
||||
if websocket {
|
||||
handle_websocket_connection(Box::new(stream), config, server_certs, local_addr, client_addr, in_filter).await
|
||||
crate::websocket::incoming::handle_websocket_connection(Box::new(stream), config, server_certs, local_addr, client_addr, in_filter).await
|
||||
} else {
|
||||
let (in_rd, in_wr) = tokio::io::split(stream);
|
||||
shuffle_rd_wr_filter(StanzaRead::already_buffered(in_rd), StanzaWrite::new(in_wr), config, server_certs, local_addr, client_addr, in_filter).await
|
5
src/tls/mod.rs
Normal file
5
src/tls/mod.rs
Normal file
@ -0,0 +1,5 @@
|
||||
#[cfg(feature = "incoming")]
|
||||
pub mod incoming;
|
||||
|
||||
#[cfg(feature = "outgoing")]
|
||||
pub mod outgoing;
|
61
src/tls/outgoing.rs
Normal file
61
src/tls/outgoing.rs
Normal file
@ -0,0 +1,61 @@
|
||||
use crate::{
|
||||
common::{outgoing::OutgoingVerifierConfig, to_str, IN_BUFFER_SIZE},
|
||||
in_out::{StanzaRead, StanzaWrite},
|
||||
stanzafilter::{StanzaFilter, StanzaReader},
|
||||
};
|
||||
use anyhow::{bail, Result};
|
||||
use log::{debug, trace};
|
||||
use rustls::ServerName;
|
||||
use std::{convert::TryFrom, net::SocketAddr};
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
pub async fn tls_connect(target: SocketAddr, server_name: &str, config: OutgoingVerifierConfig) -> Result<(StanzaWrite, StanzaRead)> {
|
||||
let dnsname = ServerName::try_from(server_name)?;
|
||||
let stream = tokio::net::TcpStream::connect(target).await?;
|
||||
let stream = config.connector_alpn.connect(dnsname, stream).await?;
|
||||
let (rd, wrt) = tokio::io::split(stream);
|
||||
Ok((StanzaWrite::new(wrt), StanzaRead::new(rd)))
|
||||
}
|
||||
|
||||
pub async fn starttls_connect(target: SocketAddr, server_name: &str, stream_open: &[u8], in_filter: &mut StanzaFilter, config: OutgoingVerifierConfig) -> Result<(StanzaWrite, StanzaRead)> {
|
||||
let dnsname = ServerName::try_from(server_name)?;
|
||||
let mut stream = tokio::net::TcpStream::connect(target).await?;
|
||||
let (in_rd, mut in_wr) = stream.split();
|
||||
|
||||
// send the stream_open
|
||||
trace!("starttls sending: {} '{}'", server_name, to_str(stream_open));
|
||||
in_wr.write_all(stream_open).await?;
|
||||
in_wr.flush().await?;
|
||||
|
||||
// we naively read 1 byte at a time, which buffering significantly speeds up
|
||||
let in_rd = tokio::io::BufReader::with_capacity(IN_BUFFER_SIZE, in_rd);
|
||||
let mut in_rd = StanzaReader(in_rd);
|
||||
let mut proceed_received = false;
|
||||
|
||||
trace!("starttls reading stream open {}", server_name);
|
||||
while let Ok(Some(buf)) = in_rd.next(in_filter).await {
|
||||
trace!("received pre-tls stanza: {} '{}'", server_name, to_str(buf));
|
||||
if buf.starts_with(b"<?xml ") || buf.starts_with(b"<stream:stream ") {
|
||||
// ignore this
|
||||
} else if buf.starts_with(b"<stream:features") {
|
||||
// we send starttls regardless, it could have been stripped out, we don't do plaintext
|
||||
let buf = br###"<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>"###;
|
||||
trace!("> {} '{}'", server_name, to_str(buf));
|
||||
in_wr.write_all(buf).await?;
|
||||
in_wr.flush().await?;
|
||||
} else if buf.starts_with(b"<proceed ") {
|
||||
proceed_received = true;
|
||||
break;
|
||||
} else {
|
||||
bail!("bad pre-tls stanza: {}", to_str(buf));
|
||||
}
|
||||
}
|
||||
if !proceed_received {
|
||||
bail!("stream ended before proceed");
|
||||
}
|
||||
|
||||
debug!("starttls starting TLS {}", server_name);
|
||||
let stream = config.connector.connect(dnsname, stream).await?;
|
||||
let (rd, wrt) = tokio::io::split(stream);
|
||||
Ok((StanzaWrite::new(wrt), StanzaRead::new(rd)))
|
||||
}
|
@ -1,13 +1,16 @@
|
||||
use crate::{digest, Posh};
|
||||
use crate::{
|
||||
common::ca_roots::TLS_SERVER_ROOTS,
|
||||
srv::{digest, Posh},
|
||||
};
|
||||
use log::debug;
|
||||
use ring::digest::SHA256;
|
||||
use rustls::client::{ServerCertVerified, ServerCertVerifier};
|
||||
use rustls::server::{ClientCertVerified, ClientCertVerifier};
|
||||
use rustls::{Certificate, DistinguishedNames, Error, ServerName};
|
||||
use std::convert::TryFrom;
|
||||
use std::time::SystemTime;
|
||||
use tokio_rustls::webpki;
|
||||
use tokio_rustls::webpki::DnsName;
|
||||
use rustls::{
|
||||
client::{ServerCertVerified, ServerCertVerifier},
|
||||
server::{ClientCertVerified, ClientCertVerifier},
|
||||
Certificate, DistinguishedNames, Error, ServerName,
|
||||
};
|
||||
use std::{convert::TryFrom, time::SystemTime};
|
||||
use tokio_rustls::{webpki, webpki::DnsName};
|
||||
|
||||
type SignatureAlgorithms = &'static [&'static webpki::SignatureAlgorithm];
|
||||
|
||||
@ -112,8 +115,7 @@ impl XmppServerCertVerifier {
|
||||
let (cert, chain) = prepare(end_entity, intermediates)?;
|
||||
let webpki_now = webpki::Time::try_from(now).map_err(|_| Error::FailedToGetCurrentTime)?;
|
||||
|
||||
cert.verify_is_valid_tls_server_cert(SUPPORTED_SIG_ALGS, &crate::TLS_SERVER_ROOTS, &chain, webpki_now)
|
||||
.map_err(pki_error)?;
|
||||
cert.verify_is_valid_tls_server_cert(SUPPORTED_SIG_ALGS, &TLS_SERVER_ROOTS, &chain, webpki_now).map_err(pki_error)?;
|
||||
|
||||
for name in &self.names {
|
||||
if cert.verify_is_valid_for_dns_name(name.as_ref()).is_ok() {
|
||||
|
25
src/websocket/incoming.rs
Normal file
25
src/websocket/incoming.rs
Normal file
@ -0,0 +1,25 @@
|
||||
use crate::{
|
||||
common::incoming::{shuffle_rd_wr_filter, CloneableConfig, ServerCerts},
|
||||
context::Context,
|
||||
stanzafilter::StanzaFilter,
|
||||
websocket::{incoming_websocket_connection, AsyncReadAndWrite},
|
||||
};
|
||||
use anyhow::Result;
|
||||
use log::info;
|
||||
use std::net::SocketAddr;
|
||||
|
||||
pub async fn handle_websocket_connection(
|
||||
stream: Box<dyn AsyncReadAndWrite + Unpin + Send>,
|
||||
config: CloneableConfig,
|
||||
server_certs: ServerCerts,
|
||||
local_addr: SocketAddr,
|
||||
client_addr: &mut Context<'_>,
|
||||
in_filter: StanzaFilter,
|
||||
) -> Result<()> {
|
||||
client_addr.set_proto("websocket-in");
|
||||
info!("{} connected", client_addr.log_from());
|
||||
|
||||
let (in_rd, in_wr) = incoming_websocket_connection(stream, config.max_stanza_size_bytes).await?;
|
||||
|
||||
shuffle_rd_wr_filter(in_rd, in_wr, config, server_certs, local_addr, client_addr, in_filter).await
|
||||
}
|
@ -1,9 +1,14 @@
|
||||
use crate::*;
|
||||
use anyhow::Result;
|
||||
use futures::StreamExt;
|
||||
|
||||
use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
|
||||
|
||||
#[cfg(feature = "incoming")]
|
||||
pub mod incoming;
|
||||
|
||||
#[cfg(feature = "outgoing")]
|
||||
pub mod outgoing;
|
||||
|
||||
// https://datatracker.ietf.org/doc/html/rfc7395
|
||||
|
||||
fn ws_cfg(max_stanza_size_bytes: usize) -> Option<WebSocketConfig> {
|
||||
@ -29,23 +34,6 @@ pub async fn incoming_websocket_connection(stream: Box<dyn AsyncReadAndWrite + U
|
||||
Ok((StanzaRead::WebSocketRead(in_rd), StanzaWrite::WebSocketClientWrite(in_wr)))
|
||||
}
|
||||
|
||||
#[cfg(feature = "incoming")]
|
||||
pub async fn handle_websocket_connection(
|
||||
stream: Box<dyn AsyncReadAndWrite + Unpin + Send>,
|
||||
config: CloneableConfig,
|
||||
server_certs: ServerCerts,
|
||||
local_addr: SocketAddr,
|
||||
client_addr: &mut Context<'_>,
|
||||
in_filter: StanzaFilter,
|
||||
) -> Result<()> {
|
||||
client_addr.set_proto("websocket-in");
|
||||
info!("{} connected", client_addr.log_from());
|
||||
|
||||
let (in_rd, in_wr) = incoming_websocket_connection(stream, config.max_stanza_size_bytes).await?;
|
||||
|
||||
shuffle_rd_wr_filter(in_rd, in_wr, config, server_certs, local_addr, client_addr, in_filter).await
|
||||
}
|
||||
|
||||
pub fn from_ws(stanza: String) -> String {
|
||||
if stanza.starts_with("<open ") {
|
||||
let stanza = stanza
|
||||
@ -97,34 +85,10 @@ pub fn to_ws_new(buf: &[u8], mut end_of_first_tag: usize, is_c2s: bool) -> Resul
|
||||
Ok(ret)
|
||||
}
|
||||
|
||||
use rustls::ServerName;
|
||||
use std::convert::TryFrom;
|
||||
|
||||
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
|
||||
use tokio_tungstenite::tungstenite::http::header::{ORIGIN, SEC_WEBSOCKET_PROTOCOL};
|
||||
use tokio_tungstenite::tungstenite::http::Uri;
|
||||
|
||||
#[cfg(feature = "outgoing")]
|
||||
pub async fn websocket_connect(target: SocketAddr, server_name: &str, url: &Uri, origin: &str, config: OutgoingVerifierConfig) -> Result<(StanzaWrite, StanzaRead)> {
|
||||
let mut request = url.into_client_request()?;
|
||||
request.headers_mut().append(SEC_WEBSOCKET_PROTOCOL, "xmpp".parse()?);
|
||||
request.headers_mut().append(ORIGIN, origin.parse()?);
|
||||
|
||||
let dnsname = ServerName::try_from(server_name)?;
|
||||
let stream = tokio::net::TcpStream::connect(target).await?;
|
||||
let stream = config.connector.connect(dnsname, stream).await?;
|
||||
|
||||
//let stream: tokio_rustls::TlsStream<tokio::net::TcpStream> = stream.into();
|
||||
// todo: tokio_tungstenite seems to have a bug, if the write buffer is non-zero, it'll hang forever, even though we always flush, investigate
|
||||
//let stream = BufStream::with_capacity(crate::IN_BUFFER_SIZE, 0, stream);
|
||||
let stream: Box<dyn AsyncReadAndWrite + Unpin + Send> = Box::new(stream);
|
||||
|
||||
let (stream, _) = tokio_tungstenite::client_async_with_config(request, stream, ws_cfg(config.max_stanza_size_bytes)).await?;
|
||||
|
||||
let (wrt, rd) = stream.split();
|
||||
|
||||
Ok((StanzaWrite::WebSocketClientWrite(wrt), StanzaRead::WebSocketRead(rd)))
|
||||
}
|
||||
use crate::{
|
||||
in_out::{StanzaRead, StanzaWrite},
|
||||
slicesubsequence::SliceSubsequence,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
37
src/websocket/outgoing.rs
Normal file
37
src/websocket/outgoing.rs
Normal file
@ -0,0 +1,37 @@
|
||||
use crate::{
|
||||
common::outgoing::OutgoingVerifierConfig,
|
||||
in_out::{StanzaRead, StanzaWrite},
|
||||
websocket::{ws_cfg, AsyncReadAndWrite},
|
||||
};
|
||||
use anyhow::Result;
|
||||
use futures_util::StreamExt;
|
||||
use rustls::ServerName;
|
||||
use std::{convert::TryFrom, net::SocketAddr};
|
||||
use tokio_tungstenite::tungstenite::{
|
||||
client::IntoClientRequest,
|
||||
http::{
|
||||
header::{ORIGIN, SEC_WEBSOCKET_PROTOCOL},
|
||||
Uri,
|
||||
},
|
||||
};
|
||||
|
||||
pub async fn websocket_connect(target: SocketAddr, server_name: &str, url: &Uri, origin: &str, config: OutgoingVerifierConfig) -> Result<(StanzaWrite, StanzaRead)> {
|
||||
let mut request = url.into_client_request()?;
|
||||
request.headers_mut().append(SEC_WEBSOCKET_PROTOCOL, "xmpp".parse()?);
|
||||
request.headers_mut().append(ORIGIN, origin.parse()?);
|
||||
|
||||
let dnsname = ServerName::try_from(server_name)?;
|
||||
let stream = tokio::net::TcpStream::connect(target).await?;
|
||||
let stream = config.connector.connect(dnsname, stream).await?;
|
||||
|
||||
//let stream: tokio_rustls::TlsStream<tokio::net::TcpStream> = stream.into();
|
||||
// todo: tokio_tungstenite seems to have a bug, if the write buffer is non-zero, it'll hang forever, even though we always flush, investigate
|
||||
//let stream = BufStream::with_capacity(crate::IN_BUFFER_SIZE, 0, stream);
|
||||
let stream: Box<dyn AsyncReadAndWrite + Unpin + Send> = Box::new(stream);
|
||||
|
||||
let (stream, _) = tokio_tungstenite::client_async_with_config(request, stream, ws_cfg(config.max_stanza_size_bytes)).await?;
|
||||
|
||||
let (wrt, rd) = stream.split();
|
||||
|
||||
Ok((StanzaWrite::WebSocketClientWrite(wrt), StanzaRead::WebSocketRead(rd)))
|
||||
}
|
Loading…
Reference in New Issue
Block a user