diff --git a/build.rs b/build.rs index e4ff2e7..06ee2d5 100644 --- a/build.rs +++ b/build.rs @@ -39,7 +39,7 @@ fn main() { for (mut key, value) in env::vars() { //writeln!(&mut w, "{key}: {value}", ).unwrap(); if value == "1" && key.starts_with("CARGO_FEATURE_") { - let mut key = key.split_off(14).replace("_", "-"); + let mut key = key.split_off(14).replace('_', "-"); key.make_ascii_lowercase(); if allowed_features.contains(&key.as_str()) { features.push(key); diff --git a/src/common/ca_roots.rs b/src/common/ca_roots.rs new file mode 100644 index 0000000..b16ae8a --- /dev/null +++ b/src/common/ca_roots.rs @@ -0,0 +1,33 @@ +#[cfg(feature = "tokio-rustls")] +use tokio_rustls::webpki::{TlsServerTrustAnchors, TrustAnchor}; + +#[cfg(all(feature = "webpki-roots", not(feature = "rustls-native-certs")))] +pub use webpki_roots::TLS_SERVER_ROOTS; + +#[cfg(all(feature = "rustls-native-certs", not(feature = "webpki-roots")))] +lazy_static::lazy_static! { + pub static ref TLS_SERVER_ROOTS: TlsServerTrustAnchors<'static> = { + // we need these to stick around for 'static, this is only called once so no problem + let certs = Box::leak(Box::new(rustls_native_certs::load_native_certs().expect("could not load platform certs"))); + let root_cert_store = Box::leak(Box::new(Vec::new())); + for cert in certs { + // some system CAs are invalid, ignore those + if let Ok(ta) = TrustAnchor::try_from_cert_der(&cert.0) { + root_cert_store.push(ta); + } + } + TlsServerTrustAnchors(root_cert_store) + }; +} + +pub fn root_cert_store() -> rustls::RootCertStore { + use rustls::{OwnedTrustAnchor, RootCertStore}; + let mut root_cert_store = RootCertStore::empty(); + root_cert_store.add_server_trust_anchors( + TLS_SERVER_ROOTS + .0 + .iter() + .map(|ta| OwnedTrustAnchor::from_subject_spki_name_constraints(ta.subject, ta.spki, ta.name_constraints)), + ); + root_cert_store +} diff --git a/src/common/certs_key.rs b/src/common/certs_key.rs new file mode 100644 index 0000000..32070a0 --- /dev/null +++ b/src/common/certs_key.rs @@ -0,0 +1,47 @@ +use std::sync::{Arc, RwLock}; + +use anyhow::Result; +use rustls::{sign::CertifiedKey, SignatureScheme}; + +pub struct CertsKey { + #[cfg(feature = "rustls-pemfile")] + pub inner: Result>>, +} + +impl CertsKey { + pub fn new(certified_key: Result) -> Self { + CertsKey { + #[cfg(feature = "rustls-pemfile")] + inner: certified_key.map(|c| RwLock::new(Arc::new(c))), + } + } +} + +#[cfg(feature = "rustls-pemfile")] +impl rustls::server::ResolvesServerCert for CertsKey { + fn resolve(&self, _: rustls::server::ClientHello) -> Option> { + self.inner.as_ref().map(|rwl| rwl.read().expect("CertKey poisoned?").clone()).ok() + } +} + +#[cfg(feature = "rustls-pemfile")] +impl rustls::client::ResolvesClientCert for CertsKey { + fn resolve(&self, _: &[&[u8]], _: &[SignatureScheme]) -> Option> { + self.inner.as_ref().map(|rwl| rwl.read().expect("CertKey poisoned?").clone()).ok() + } + + fn has_certs(&self) -> bool { + self.inner.is_ok() + } +} + +#[cfg(not(feature = "rustls-pemfile"))] +impl rustls::client::ResolvesClientCert for CertsKey { + fn resolve(&self, _: &[&[u8]], _: &[SignatureScheme]) -> Option> { + None + } + + fn has_certs(&self) -> bool { + false + } +} diff --git a/src/common/incoming.rs b/src/common/incoming.rs new file mode 100644 index 0000000..aedf17b --- /dev/null +++ b/src/common/incoming.rs @@ -0,0 +1,198 @@ +use crate::{ + common::{c2s, certs_key::CertsKey, shuffle_rd_wr_filter_only, stream_preamble, to_str, ALPN_XMPP_CLIENT, ALPN_XMPP_SERVER}, + context::Context, + in_out::{StanzaRead, StanzaWrite}, + slicesubsequence::SliceSubsequence, + stanzafilter::StanzaFilter, +}; +use anyhow::{anyhow, bail, Result}; +use log::trace; +use rustls::{Certificate, ServerConfig, ServerConnection}; +use std::{io::Write, net::SocketAddr, sync::Arc}; +use tokio::io::{AsyncWriteExt, ReadHalf, WriteHalf}; + +#[derive(Clone)] +pub struct CloneableConfig { + pub max_stanza_size_bytes: usize, + #[cfg(feature = "s2s-incoming")] + pub s2s_target: Option, + #[cfg(feature = "c2s-incoming")] + pub c2s_target: Option, + pub proxy: bool, +} + +pub fn server_config(certs_key: Arc) -> Result { + if let Err(e) = &certs_key.inner { + bail!("invalid cert/key: {}", e); + } + + let config = ServerConfig::builder().with_safe_defaults(); + #[cfg(feature = "s2s")] + let config = config.with_client_cert_verifier(Arc::new(crate::verify::AllowAnonymousOrAnyCert)); + #[cfg(not(feature = "s2s"))] + let config = config.with_no_client_auth(); + let mut config = config.with_cert_resolver(certs_key); + // todo: will connecting without alpn work then? + config.alpn_protocols.push(ALPN_XMPP_CLIENT.to_vec()); + config.alpn_protocols.push(ALPN_XMPP_SERVER.to_vec()); + + Ok(config) +} + +#[cfg(not(feature = "s2s-incoming"))] +pub type ServerCerts = (); + +#[cfg(feature = "s2s-incoming")] +#[derive(Clone)] +pub enum ServerCerts { + Tls(&'static ServerConnection), + #[cfg(feature = "quic")] + Quic(quinn::Connection), +} + +#[cfg(feature = "s2s-incoming")] +impl ServerCerts { + pub fn peer_certificates(&self) -> Option> { + match self { + ServerCerts::Tls(c) => c.peer_certificates().map(|c| c.to_vec()), + #[cfg(feature = "quic")] + ServerCerts::Quic(c) => c.peer_identity().and_then(|v| v.downcast::>().ok()).map(|v| v.to_vec()), + } + } + + pub fn sni(&self) -> Option { + match self { + ServerCerts::Tls(c) => c.sni_hostname().map(|s| s.to_string()), + #[cfg(feature = "quic")] + ServerCerts::Quic(c) => c.handshake_data().and_then(|v| v.downcast::().ok()).and_then(|h| h.server_name), + } + } + + pub fn alpn(&self) -> Option> { + match self { + ServerCerts::Tls(c) => c.alpn_protocol().map(|s| s.to_vec()), + #[cfg(feature = "quic")] + ServerCerts::Quic(c) => c.handshake_data().and_then(|v| v.downcast::().ok()).and_then(|h| h.protocol), + } + } + + pub fn is_tls(&self) -> bool { + match self { + ServerCerts::Tls(_) => true, + #[cfg(feature = "quic")] + ServerCerts::Quic(_) => false, + } + } +} + +pub async fn shuffle_rd_wr(in_rd: StanzaRead, in_wr: StanzaWrite, config: CloneableConfig, server_certs: ServerCerts, local_addr: SocketAddr, client_addr: &mut Context<'_>) -> Result<()> { + let filter = StanzaFilter::new(config.max_stanza_size_bytes); + shuffle_rd_wr_filter(in_rd, in_wr, config, server_certs, local_addr, client_addr, filter).await +} + +pub async fn shuffle_rd_wr_filter( + mut in_rd: StanzaRead, + mut in_wr: StanzaWrite, + config: CloneableConfig, + server_certs: ServerCerts, + local_addr: SocketAddr, + client_addr: &mut Context<'_>, + mut in_filter: StanzaFilter, +) -> Result<()> { + // now read to figure out client vs server + let (stream_open, is_c2s) = stream_preamble(&mut in_rd, &mut in_wr, client_addr.log_from(), &mut in_filter).await?; + client_addr.set_c2s_stream_open(is_c2s, &stream_open); + + #[cfg(feature = "s2s-incoming")] + { + trace!( + "{} connected: sni: {:?}, alpn: {:?}, tls-not-quic: {}", + client_addr.log_from(), + server_certs.sni(), + server_certs.alpn().map(|a| String::from_utf8_lossy(&a).to_string()), + server_certs.is_tls(), + ); + + if !is_c2s { + // for s2s we need this + use std::time::SystemTime; + let domain = stream_open + .extract_between(b" from='", b"'") + .or_else(|_| stream_open.extract_between(b" from=\"", b"\"")) + .and_then(|b| Ok(std::str::from_utf8(b)?))?; + let (_, cert_verifier) = crate::srv::get_xmpp_connections(domain, is_c2s).await?; + let certs = server_certs.peer_certificates().ok_or_else(|| anyhow!("no client cert auth for s2s incoming from {}", domain))?; + // todo: send stream error saying cert is invalid + cert_verifier.verify_cert(&certs[0], &certs[1..], SystemTime::now())?; + } + drop(server_certs); + } + + let (out_rd, out_wr) = open_incoming(&config, local_addr, client_addr, &stream_open, is_c2s, &mut in_filter).await?; + drop(stream_open); + + shuffle_rd_wr_filter_only( + in_rd, + in_wr, + StanzaRead::new(out_rd), + StanzaWrite::new(out_wr), + is_c2s, + config.max_stanza_size_bytes, + client_addr, + in_filter, + ) + .await +} + +async fn open_incoming( + config: &CloneableConfig, + local_addr: SocketAddr, + client_addr: &mut Context<'_>, + stream_open: &[u8], + is_c2s: bool, + in_filter: &mut StanzaFilter, +) -> Result<(ReadHalf, WriteHalf)> { + let target = if is_c2s { + #[cfg(not(feature = "c2s-incoming"))] + bail!("incoming c2s connection but lacking compile-time support"); + #[cfg(feature = "c2s-incoming")] + config.c2s_target + } else { + #[cfg(not(feature = "s2s-incoming"))] + bail!("incoming s2s connection but lacking compile-time support"); + #[cfg(feature = "s2s-incoming")] + config.s2s_target + } + .ok_or_else(|| anyhow!("incoming connection but `{}_target` not defined", c2s(is_c2s)))?; + client_addr.set_to_addr(target); + + let out_stream = tokio::net::TcpStream::connect(target).await?; + let (out_rd, mut out_wr) = tokio::io::split(out_stream); + + if config.proxy { + /* + https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt + PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\n + PROXY TCP6 ffff:f...f:ffff ffff:f...f:ffff 65535 65535\r\n + PROXY TCP6 SOURCE_IP DEST_IP SOURCE_PORT DEST_PORT\r\n + */ + // tokio AsyncWrite doesn't have write_fmt so have to go through this buffer for some crazy reason + //write!(out_wr, "PROXY TCP{} {} {} {} {}\r\n", if client_addr.is_ipv4() { '4' } else {'6' }, client_addr.ip(), local_addr.ip(), client_addr.port(), local_addr.port())?; + write!( + &mut in_filter.buf[0..], + "PROXY TCP{} {} {} {} {}\r\n", + if client_addr.client_addr().is_ipv4() { '4' } else { '6' }, + client_addr.client_addr().ip(), + local_addr.ip(), + client_addr.client_addr().port(), + local_addr.port() + )?; + let end_idx = &(&in_filter.buf[0..]).first_index_of(b"\n")? + 1; + trace!("{} '{}'", client_addr.log_from(), to_str(&in_filter.buf[0..end_idx])); + out_wr.write_all(&in_filter.buf[0..end_idx]).await?; + } + trace!("{} '{}'", client_addr.log_from(), to_str(stream_open)); + out_wr.write_all(stream_open).await?; + out_wr.flush().await?; + Ok((out_rd, out_wr)) +} diff --git a/src/common/mod.rs b/src/common/mod.rs new file mode 100644 index 0000000..9428b8d --- /dev/null +++ b/src/common/mod.rs @@ -0,0 +1,144 @@ +use crate::{ + context::Context, + in_out::{StanzaRead, StanzaWrite}, + slicesubsequence::SliceSubsequence, + stanzafilter::StanzaFilter, +}; +use anyhow::{bail, Result}; +use log::{info, trace}; +use rustls::{ + sign::{RsaSigningKey, SigningKey}, + Certificate, PrivateKey, +}; +use std::{fs::File, io, io::BufReader, sync::Arc}; + +#[cfg(feature = "incoming")] +pub mod incoming; + +#[cfg(feature = "outgoing")] +pub mod outgoing; + +#[cfg(any(feature = "rustls-native-certs", feature = "webpki-roots"))] +pub mod ca_roots; + +pub mod certs_key; + +pub const IN_BUFFER_SIZE: usize = 8192; +pub const ALPN_XMPP_CLIENT: &[u8] = b"xmpp-client"; +pub const ALPN_XMPP_SERVER: &[u8] = b"xmpp-server"; + +pub fn to_str(buf: &[u8]) -> std::borrow::Cow<'_, str> { + String::from_utf8_lossy(buf) +} + +pub fn c2s(is_c2s: bool) -> &'static str { + if is_c2s { + "c2s" + } else { + "s2s" + } +} + +pub async fn first_bytes_match(stream: &tokio::net::TcpStream, p: &mut [u8], matcher: fn(&[u8]) -> bool) -> anyhow::Result { + // sooo... I don't think peek here can be used for > 1 byte without this timer craziness... can it? + let len = p.len(); + // wait up to 10 seconds until len bytes have been read + use std::time::{Duration, Instant}; + let duration = Duration::from_secs(10); + let now = Instant::now(); + loop { + let n = stream.peek(p).await?; + if n == len { + break; // success + } + if n == 0 { + bail!("not enough bytes"); + } + if Instant::now() - now > duration { + bail!("less than {} bytes in 10 seconds, closed connection?", len); + } + } + + Ok(matcher(p)) +} + +pub async fn stream_preamble(in_rd: &mut StanzaRead, in_wr: &mut StanzaWrite, client_addr: &'_ str, in_filter: &mut StanzaFilter) -> Result<(Vec, bool)> { + let mut stream_open = Vec::new(); + while let Ok(Some((buf, _))) = in_rd.next(in_filter, client_addr, in_wr).await { + trace!("{} received pre- stanza: '{}'", client_addr, to_str(buf)); + if buf.starts_with(b" stanza: {}", to_str(buf)); + } + } + bail!("stream ended before open") +} + +#[allow(clippy::too_many_arguments)] +pub async fn shuffle_rd_wr_filter_only( + mut in_rd: StanzaRead, + mut in_wr: StanzaWrite, + mut out_rd: StanzaRead, + mut out_wr: StanzaWrite, + is_c2s: bool, + max_stanza_size_bytes: usize, + client_addr: &mut Context<'_>, + mut in_filter: StanzaFilter, +) -> Result<()> { + let mut out_filter = StanzaFilter::new(max_stanza_size_bytes); + + loop { + tokio::select! { + Ok(ret) = in_rd.next(&mut in_filter, client_addr.log_to(), &mut in_wr) => { + match ret { + None => break, + Some((buf, eoft)) => { + trace!("{} '{}'", client_addr.log_from(), to_str(buf)); + out_wr.write_all(is_c2s, buf, eoft, client_addr.log_from()).await?; + out_wr.flush().await?; + } + } + }, + Ok(ret) = out_rd.next(&mut out_filter, client_addr.log_from(), &mut out_wr) => { + match ret { + None => break, + Some((buf, eoft)) => { + trace!("{} '{}'", client_addr.log_to(), to_str(buf)); + in_wr.write_all(is_c2s, buf, eoft, client_addr.log_to()).await?; + in_wr.flush().await?; + } + } + }, + } + } + + info!("{} disconnected", client_addr.log_from()); + Ok(()) +} + +#[cfg(feature = "rustls-pemfile")] +pub fn read_certified_key(tls_key: &str, tls_cert: &str) -> Result { + use rustls_pemfile::{certs, read_all, Item}; + + let tls_key = read_all(&mut BufReader::new(File::open(tls_key)?)) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))? + .into_iter() + .flat_map(|item| match item { + Item::RSAKey(der) => RsaSigningKey::new(&PrivateKey(der)).ok().map(Arc::new).map(|r| r as Arc), + Item::PKCS8Key(der) => rustls::sign::any_supported_type(&PrivateKey(der)).ok(), + Item::ECKey(der) => rustls::sign::any_supported_type(&PrivateKey(der)).ok(), + _ => None, + }) + .next() + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))?; + + let tls_certs = certs(&mut BufReader::new(File::open(tls_cert)?)) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert")) + .map(|mut certs| certs.drain(..).map(Certificate).collect())?; + + Ok(rustls::sign::CertifiedKey::new(tls_certs, tls_key)) +} diff --git a/src/common/outgoing.rs b/src/common/outgoing.rs new file mode 100644 index 0000000..8a6ea76 --- /dev/null +++ b/src/common/outgoing.rs @@ -0,0 +1,54 @@ +use crate::{ + common::{certs_key::CertsKey, ALPN_XMPP_CLIENT, ALPN_XMPP_SERVER}, + verify::XmppServerCertVerifier, +}; +use rustls::ClientConfig; +use std::sync::Arc; +use tokio_rustls::TlsConnector; + +#[derive(Clone)] +pub struct OutgoingConfig { + pub max_stanza_size_bytes: usize, + pub certs_key: Arc, +} + +impl OutgoingConfig { + pub fn with_custom_certificate_verifier(&self, is_c2s: bool, cert_verifier: XmppServerCertVerifier) -> OutgoingVerifierConfig { + let config = match is_c2s { + false => ClientConfig::builder() + .with_safe_defaults() + .with_custom_certificate_verifier(Arc::new(cert_verifier)) + .with_client_cert_resolver(self.certs_key.clone()), + _ => ClientConfig::builder() + .with_safe_defaults() + .with_custom_certificate_verifier(Arc::new(cert_verifier)) + .with_no_client_auth(), + }; + + let mut config_alpn = config.clone(); + config_alpn.alpn_protocols.push(if is_c2s { ALPN_XMPP_CLIENT } else { ALPN_XMPP_SERVER }.to_vec()); + + let config_alpn = Arc::new(config_alpn); + + let connector_alpn: TlsConnector = config_alpn.clone().into(); + + let connector: TlsConnector = Arc::new(config).into(); + + OutgoingVerifierConfig { + max_stanza_size_bytes: self.max_stanza_size_bytes, + config_alpn, + connector_alpn, + connector, + } + } +} + +#[derive(Clone)] +pub struct OutgoingVerifierConfig { + pub max_stanza_size_bytes: usize, + + pub config_alpn: Arc, + pub connector_alpn: TlsConnector, + + pub connector: TlsConnector, +} diff --git a/src/context.rs b/src/context.rs new file mode 100644 index 0000000..380e186 --- /dev/null +++ b/src/context.rs @@ -0,0 +1,112 @@ +use crate::{ + common::{c2s, to_str}, + slicesubsequence::SliceSubsequence, +}; +use log::{info, log_enabled}; +use std::net::SocketAddr; + +#[derive(Clone)] +pub struct Context<'a> { + conn_id: String, + log_from: String, + log_to: String, + proto: &'a str, + is_c2s: Option, + to: Option, + to_addr: Option, + from: Option, + client_addr: SocketAddr, +} + +impl<'a> Context<'a> { + pub fn new(proto: &'static str, client_addr: SocketAddr) -> Context { + let (log_to, log_from, conn_id) = if log_enabled!(log::Level::Info) { + #[cfg(feature = "logging")] + let conn_id = { + use rand::{distributions::Alphanumeric, thread_rng, Rng}; + thread_rng().sample_iter(&Alphanumeric).take(10).map(char::from).collect() + }; + #[cfg(not(feature = "logging"))] + let conn_id = "".to_string(); + ( + format!("{}: ({} <- ({}-unk)):", conn_id, client_addr, proto), + format!("{}: ({} -> ({}-unk)):", conn_id, client_addr, proto), + conn_id, + ) + } else { + ("".to_string(), "".to_string(), "".to_string()) + }; + + Context { + conn_id, + log_from, + log_to, + proto, + client_addr, + is_c2s: None, + to: None, + to_addr: None, + from: None, + } + } + + fn re_calc(&mut self) { + // todo: make this good + self.log_from = format!( + "{}: ({} ({}) -> ({}-{}) -> {} ({})):", + self.conn_id, + self.client_addr, + if self.from.is_some() { self.from.as_ref().unwrap() } else { "unk" }, + self.proto, + if self.is_c2s.is_some() { c2s(self.is_c2s.unwrap()) } else { "unk" }, + if self.to_addr.is_some() { self.to_addr.as_ref().unwrap().to_string() } else { "unk".to_string() }, + if self.to.is_some() { self.to.as_ref().unwrap() } else { "unk" }, + ); + self.log_to = self.log_from.replace(" -> ", " <- "); + } + + pub fn log_from(&self) -> &str { + &self.log_from + } + + pub fn log_to(&self) -> &str { + &self.log_to + } + + pub fn client_addr(&self) -> &SocketAddr { + &self.client_addr + } + + pub fn set_proto(&mut self, proto: &'static str) { + if log_enabled!(log::Level::Info) { + self.proto = proto; + self.to_addr = None; + self.re_calc(); + } + } + + pub fn set_c2s_stream_open(&mut self, is_c2s: bool, stream_open: &[u8]) { + if log_enabled!(log::Level::Info) { + self.is_c2s = Some(is_c2s); + self.from = stream_open + .extract_between(b" from='", b"'") + .or_else(|_| stream_open.extract_between(b" from=\"", b"\"")) + .map(|b| to_str(b).to_string()) + .ok(); + self.to = stream_open + .extract_between(b" to='", b"'") + .or_else(|_| stream_open.extract_between(b" to=\"", b"\"")) + .map(|b| to_str(b).to_string()) + .ok(); + self.re_calc(); + info!("{} stream data set", &self.log_from()); + } + } + + pub fn set_to_addr(&mut self, to_addr: SocketAddr) { + if log_enabled!(log::Level::Info) { + self.to_addr = Some(to_addr); + self.re_calc(); + } + } +} diff --git a/src/in_out.rs b/src/in_out.rs index 10ffb76..0d464c4 100644 --- a/src/in_out.rs +++ b/src/in_out.rs @@ -1,14 +1,20 @@ // Box, Box #[cfg(feature = "websocket")] -use crate::{from_ws, to_ws_new, AsyncReadAndWrite}; -use crate::{slicesubsequence::SliceSubsequence, trace, StanzaFilter, StanzaRead::*, StanzaReader, StanzaWrite::*}; +use crate::websocket::{from_ws, to_ws_new, AsyncReadAndWrite}; +use crate::{ + common::IN_BUFFER_SIZE, + in_out::{StanzaRead::*, StanzaWrite::*}, + slicesubsequence::SliceSubsequence, + stanzafilter::{StanzaFilter, StanzaReader}, +}; use anyhow::{bail, Result}; #[cfg(feature = "websocket")] use futures_util::{ stream::{SplitSink, SplitStream}, SinkExt, TryStreamExt, }; +use log::trace; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; #[cfg(feature = "websocket")] use tokio_tungstenite::{tungstenite::Message::*, WebSocketStream}; @@ -75,7 +81,7 @@ impl StanzaRead { #[inline(always)] pub fn new(rd: T) -> Self { // we naively read 1 byte at a time, which buffering significantly speeds up - AsyncRead(StanzaReader(Box::new(BufReader::with_capacity(crate::IN_BUFFER_SIZE, rd)))) + AsyncRead(StanzaReader(Box::new(BufReader::with_capacity(IN_BUFFER_SIZE, rd)))) } #[inline(always)] diff --git a/src/lib.rs b/src/lib.rs index d4cd79b..b488780 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,201 +1,28 @@ -mod stanzafilter; -pub use stanzafilter::*; - -mod slicesubsequence; -use slicesubsequence::*; - use anyhow::bail; +use log::info; use std::net::SocketAddr; -pub use log::{debug, error, info, log_enabled, trace}; +pub mod common; +pub mod slicesubsequence; +pub mod stanzafilter; -#[cfg(feature = "s2s-incoming")] -use rustls::{Certificate, ServerConnection}; +#[cfg(feature = "quic")] +pub mod quic; -pub fn to_str(buf: &[u8]) -> std::borrow::Cow<'_, str> { - String::from_utf8_lossy(buf) -} +#[cfg(feature = "tls")] +pub mod tls; -pub fn c2s(is_c2s: bool) -> &'static str { - if is_c2s { - "c2s" - } else { - "s2s" - } -} +#[cfg(feature = "outgoing")] +pub mod outgoing; -pub async fn first_bytes_match(stream: &tokio::net::TcpStream, p: &mut [u8], matcher: fn(&[u8]) -> bool) -> anyhow::Result { - // sooo... I don't think peek here can be used for > 1 byte without this timer craziness... can it? - let len = p.len(); - // wait up to 10 seconds until len bytes have been read - use std::time::{Duration, Instant}; - let duration = Duration::from_secs(10); - let now = Instant::now(); - loop { - let n = stream.peek(p).await?; - if n == len { - break; // success - } - if n == 0 { - bail!("not enough bytes"); - } - if Instant::now() - now > duration { - bail!("less than {} bytes in 10 seconds, closed connection?", len); - } - } +#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))] +pub mod srv; - Ok(matcher(p)) -} +#[cfg(feature = "websocket")] +pub mod websocket; -#[derive(Clone)] -pub struct Context<'a> { - conn_id: String, - log_from: String, - log_to: String, - proto: &'a str, - is_c2s: Option, - to: Option, - to_addr: Option, - from: Option, - client_addr: SocketAddr, -} +#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))] +pub mod verify; -impl<'a> Context<'a> { - pub fn new(proto: &'static str, client_addr: SocketAddr) -> Context { - let (log_to, log_from, conn_id) = if log_enabled!(log::Level::Info) { - #[cfg(feature = "logging")] - let conn_id = { - use rand::distributions::Alphanumeric; - use rand::{thread_rng, Rng}; - thread_rng().sample_iter(&Alphanumeric).take(10).map(char::from).collect() - }; - #[cfg(not(feature = "logging"))] - let conn_id = "".to_string(); - ( - format!("{}: ({} <- ({}-unk)):", conn_id, client_addr, proto), - format!("{}: ({} -> ({}-unk)):", conn_id, client_addr, proto), - conn_id, - ) - } else { - ("".to_string(), "".to_string(), "".to_string()) - }; - - Context { - conn_id, - log_from, - log_to, - proto, - client_addr, - is_c2s: None, - to: None, - to_addr: None, - from: None, - } - } - - fn re_calc(&mut self) { - // todo: make this good - self.log_from = format!( - "{}: ({} ({}) -> ({}-{}) -> {} ({})):", - self.conn_id, - self.client_addr, - if self.from.is_some() { self.from.as_ref().unwrap() } else { "unk" }, - self.proto, - if self.is_c2s.is_some() { c2s(self.is_c2s.unwrap()) } else { "unk" }, - if self.to_addr.is_some() { self.to_addr.as_ref().unwrap().to_string() } else { "unk".to_string() }, - if self.to.is_some() { self.to.as_ref().unwrap() } else { "unk" }, - ); - self.log_to = self.log_from.replace(" -> ", " <- "); - } - - pub fn log_from(&self) -> &str { - &self.log_from - } - - pub fn log_to(&self) -> &str { - &self.log_to - } - - pub fn client_addr(&self) -> &SocketAddr { - &self.client_addr - } - - pub fn set_proto(&mut self, proto: &'static str) { - if log_enabled!(log::Level::Info) { - self.proto = proto; - self.to_addr = None; - self.re_calc(); - } - } - - pub fn set_c2s_stream_open(&mut self, is_c2s: bool, stream_open: &[u8]) { - if log_enabled!(log::Level::Info) { - self.is_c2s = Some(is_c2s); - self.from = stream_open - .extract_between(b" from='", b"'") - .or_else(|_| stream_open.extract_between(b" from=\"", b"\"")) - .map(|b| to_str(b).to_string()) - .ok(); - self.to = stream_open - .extract_between(b" to='", b"'") - .or_else(|_| stream_open.extract_between(b" to=\"", b"\"")) - .map(|b| to_str(b).to_string()) - .ok(); - self.re_calc(); - info!("{} stream data set", &self.log_from()); - } - } - - pub fn set_to_addr(&mut self, to_addr: SocketAddr) { - if log_enabled!(log::Level::Info) { - self.to_addr = Some(to_addr); - self.re_calc(); - } - } -} - -#[cfg(not(feature = "s2s-incoming"))] -pub type ServerCerts = (); - -#[cfg(feature = "s2s-incoming")] -#[derive(Clone)] -pub enum ServerCerts { - Tls(&'static ServerConnection), - #[cfg(feature = "quic")] - Quic(quinn::Connection), -} - -#[cfg(feature = "s2s-incoming")] -impl ServerCerts { - pub fn peer_certificates(&self) -> Option> { - match self { - ServerCerts::Tls(c) => c.peer_certificates().map(|c| c.to_vec()), - #[cfg(feature = "quic")] - ServerCerts::Quic(c) => c.peer_identity().and_then(|v| v.downcast::>().ok()).map(|v| v.to_vec()), - } - } - - pub fn sni(&self) -> Option { - match self { - ServerCerts::Tls(c) => c.sni_hostname().map(|s| s.to_string()), - #[cfg(feature = "quic")] - ServerCerts::Quic(c) => c.handshake_data().and_then(|v| v.downcast::().ok()).and_then(|h| h.server_name), - } - } - - pub fn alpn(&self) -> Option> { - match self { - ServerCerts::Tls(c) => c.alpn_protocol().map(|s| s.to_vec()), - #[cfg(feature = "quic")] - ServerCerts::Quic(c) => c.handshake_data().and_then(|v| v.downcast::().ok()).and_then(|h| h.protocol), - } - } - - pub fn is_tls(&self) -> bool { - match self { - ServerCerts::Tls(_) => true, - #[cfg(feature = "quic")] - ServerCerts::Quic(_) => false, - } - } -} +mod context; +pub mod in_out; diff --git a/src/main.rs b/src/main.rs index 0fa1768..333c426 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,112 +1,14 @@ #![deny(clippy::all)] - -use std::ffi::OsString; -use std::fs::File; -use std::io; -use std::io::{BufReader, Read, Write}; -use std::iter::Iterator; -use std::net::SocketAddr; -use std::path::Path; -use std::sync::{Arc, RwLock}; - +use anyhow::Result; use die::{die, Die}; - +use log::{debug, error, info}; use serde_derive::Deserialize; - -use tokio::io::{AsyncWriteExt, ReadHalf, WriteHalf}; -use tokio::net::TcpListener; +use std::{ffi::OsString, fs::File, io::Read, iter::Iterator, net::SocketAddr, path::Path, sync::Arc}; use tokio::task::JoinHandle; - -#[cfg(feature = "rustls")] -use rustls::{ - sign::{CertifiedKey, RsaSigningKey, SigningKey}, - Certificate, ClientConfig, PrivateKey, ServerConfig, SignatureScheme, -}; - -#[cfg(feature = "tokio-rustls")] -use tokio_rustls::{ - webpki::{DnsNameRef, TlsServerTrustAnchors, TrustAnchor}, - TlsConnector, -}; - -use anyhow::{anyhow, bail, Result}; - -mod slicesubsequence; -use slicesubsequence::*; - -pub use xmpp_proxy::*; - -#[cfg(feature = "quic")] -mod quic; -#[cfg(feature = "quic")] -use crate::quic::*; - -#[cfg(feature = "tls")] -mod tls; -#[cfg(feature = "tls")] -use crate::tls::*; +use xmpp_proxy::common::certs_key::CertsKey; #[cfg(feature = "outgoing")] -mod outgoing; -#[cfg(feature = "outgoing")] -use crate::outgoing::*; - -#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))] -mod srv; -#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))] -use crate::srv::*; - -#[cfg(feature = "websocket")] -mod websocket; -#[cfg(feature = "websocket")] -use crate::websocket::*; - -#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))] -mod verify; -#[cfg(any(feature = "s2s-incoming", feature = "outgoing"))] -use crate::verify::*; - -mod in_out; -pub use crate::in_out::*; - -const IN_BUFFER_SIZE: usize = 8192; - -// todo: split these out to outgoing module - -const ALPN_XMPP_CLIENT: &[u8] = b"xmpp-client"; -const ALPN_XMPP_SERVER: &[u8] = b"xmpp-server"; - -#[cfg(all(feature = "webpki-roots", not(feature = "rustls-native-certs")))] -pub use webpki_roots::TLS_SERVER_ROOTS; - -#[cfg(all(feature = "rustls-native-certs", not(feature = "webpki-roots")))] -lazy_static::lazy_static! { - static ref TLS_SERVER_ROOTS: TlsServerTrustAnchors<'static> = { - // we need these to stick around for 'static, this is only called once so no problem - let certs = Box::leak(Box::new(rustls_native_certs::load_native_certs().expect("could not load platform certs"))); - let root_cert_store = Box::leak(Box::new(Vec::new())); - for cert in certs { - // some system CAs are invalid, ignore those - if let Ok(ta) = TrustAnchor::try_from_cert_der(&cert.0) { - root_cert_store.push(ta); - } - } - TlsServerTrustAnchors(root_cert_store) - }; -} - -#[cfg(any(feature = "rustls-native-certs", feature = "webpki-roots"))] -pub fn root_cert_store() -> rustls::RootCertStore { - use rustls::{OwnedTrustAnchor, RootCertStore}; - let mut root_cert_store = RootCertStore::empty(); - root_cert_store.add_server_trust_anchors( - TLS_SERVER_ROOTS - .0 - .iter() - .map(|ta| OwnedTrustAnchor::from_subject_spki_name_constraints(ta.subject, ta.spki, ta.name_constraints)), - ); - root_cert_store -} +use xmpp_proxy::{common::outgoing::OutgoingConfig, outgoing::spawn_outgoing_listener}; #[derive(Deserialize, Default)] struct Config { @@ -123,87 +25,6 @@ struct Config { log_style: Option, } -#[derive(Clone)] -pub struct CloneableConfig { - max_stanza_size_bytes: usize, - #[cfg(feature = "s2s-incoming")] - s2s_target: Option, - #[cfg(feature = "c2s-incoming")] - c2s_target: Option, - proxy: bool, -} - -struct CertsKey { - #[cfg(feature = "rustls-pemfile")] - inner: Result>>, -} - -impl CertsKey { - fn new(main_config: &Config) -> Self { - CertsKey { - #[cfg(feature = "rustls-pemfile")] - inner: main_config.certs_key().map(|c| RwLock::new(Arc::new(c))), - } - } - - #[cfg(all(unix, any(feature = "incoming", feature = "s2s-outgoing")))] - fn spawn_refresh_task(&'static self, cfg_path: OsString) -> Option>> { - if self.inner.is_err() { - None - } else { - Some(tokio::spawn(async move { - use tokio::signal::unix::{signal, SignalKind}; - let mut stream = signal(SignalKind::hangup())?; - loop { - stream.recv().await; - info!("got SIGHUP"); - match Config::parse(&cfg_path).and_then(|c| c.certs_key()) { - Ok(cert_key) => { - if let Ok(rwl) = self.inner.as_ref() { - let cert_key = Arc::new(cert_key); - let mut certs_key = rwl.write().expect("CertKey poisoned?"); - *certs_key = cert_key; - drop(certs_key); - info!("reloaded cert/key successfully!"); - } - } - Err(e) => error!("invalid config/cert/key on SIGHUP: {}", e), - }; - } - })) - } - } -} - -#[cfg(feature = "rustls-pemfile")] -impl rustls::server::ResolvesServerCert for CertsKey { - fn resolve(&self, _: rustls::server::ClientHello) -> Option> { - self.inner.as_ref().map(|rwl| rwl.read().expect("CertKey poisoned?").clone()).ok() - } -} - -#[cfg(feature = "rustls-pemfile")] -impl rustls::client::ResolvesClientCert for CertsKey { - fn resolve(&self, _: &[&[u8]], _: &[SignatureScheme]) -> Option> { - self.inner.as_ref().map(|rwl| rwl.read().expect("CertKey poisoned?").clone()).ok() - } - - fn has_certs(&self) -> bool { - self.inner.is_ok() - } -} - -#[cfg(not(feature = "rustls-pemfile"))] -impl rustls::client::ResolvesClientCert for CertsKey { - fn resolve(&self, _: &[&[u8]], _: &[SignatureScheme]) -> Option> { - None - } - - fn has_certs(&self) -> bool { - false - } -} - impl Config { fn parse>(path: P) -> Result { let mut f = File::open(path)?; @@ -212,8 +33,9 @@ impl Config { Ok(toml::from_str(&input)?) } - fn get_cloneable_cfg(&self) -> CloneableConfig { - CloneableConfig { + #[cfg(feature = "incoming")] + fn get_cloneable_cfg(&self) -> xmpp_proxy::common::incoming::CloneableConfig { + xmpp_proxy::common::incoming::CloneableConfig { max_stanza_size_bytes: self.max_stanza_size_bytes, #[cfg(feature = "s2s-incoming")] s2s_target: self.s2s_target, @@ -238,268 +60,41 @@ impl Config { #[cfg(feature = "rustls-pemfile")] fn certs_key(&self) -> Result { - use rustls_pemfile::{certs, read_all, Item}; - - let tls_key = read_all(&mut BufReader::new(File::open(&self.tls_key)?)) - .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))? - .into_iter() - .flat_map(|item| match item { - Item::RSAKey(der) => RsaSigningKey::new(&PrivateKey(der)).ok().map(Arc::new).map(|r| r as Arc), - Item::PKCS8Key(der) => rustls::sign::any_supported_type(&PrivateKey(der)).ok(), - Item::ECKey(der) => rustls::sign::any_supported_type(&PrivateKey(der)).ok(), - _ => None, - }) - .next() - .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))?; - - let tls_certs = certs(&mut BufReader::new(File::open(&self.tls_cert)?)) - .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert")) - .map(|mut certs| certs.drain(..).map(Certificate).collect())?; - - Ok(rustls::sign::CertifiedKey::new(tls_certs, tls_key)) + xmpp_proxy::common::read_certified_key(&self.tls_key, &self.tls_cert) } - #[cfg(feature = "incoming")] - fn server_config(&self, certs_key: Arc) -> Result { - if let Err(e) = &certs_key.inner { - bail!("invalid cert/key: {}", e); - } - - let config = ServerConfig::builder().with_safe_defaults(); - #[cfg(feature = "s2s")] - let config = config.with_client_cert_verifier(Arc::new(AllowAnonymousOrAnyCert)); - #[cfg(not(feature = "s2s"))] - let config = config.with_no_client_auth(); - let mut config = config.with_cert_resolver(certs_key); - // todo: will connecting without alpn work then? - config.alpn_protocols.push(ALPN_XMPP_CLIENT.to_vec()); - config.alpn_protocols.push(ALPN_XMPP_SERVER.to_vec()); - - Ok(config) + #[cfg(not(feature = "rustls-pemfile"))] + fn certs_key(&self) -> Result { + anyhow::bail!("rustls-pemfile disabled at compile time") } } -#[derive(Clone)] -#[cfg(feature = "outgoing")] -pub struct OutgoingConfig { - max_stanza_size_bytes: usize, - certs_key: Arc, -} - -#[cfg(feature = "outgoing")] -impl OutgoingConfig { - pub fn with_custom_certificate_verifier(&self, is_c2s: bool, cert_verifier: XmppServerCertVerifier) -> OutgoingVerifierConfig { - let config = match is_c2s { - false => ClientConfig::builder() - .with_safe_defaults() - .with_custom_certificate_verifier(Arc::new(cert_verifier)) - .with_client_cert_resolver(self.certs_key.clone()), - _ => ClientConfig::builder() - .with_safe_defaults() - .with_custom_certificate_verifier(Arc::new(cert_verifier)) - .with_no_client_auth(), - }; - - let mut config_alpn = config.clone(); - config_alpn.alpn_protocols.push(if is_c2s { ALPN_XMPP_CLIENT } else { ALPN_XMPP_SERVER }.to_vec()); - - let config_alpn = Arc::new(config_alpn); - - let connector_alpn: TlsConnector = config_alpn.clone().into(); - - let connector: TlsConnector = Arc::new(config).into(); - - OutgoingVerifierConfig { - max_stanza_size_bytes: self.max_stanza_size_bytes, - config_alpn, - connector_alpn, - connector, - } - } -} - -#[derive(Clone)] -#[cfg(feature = "outgoing")] -pub struct OutgoingVerifierConfig { - pub max_stanza_size_bytes: usize, - - pub config_alpn: Arc, - pub connector_alpn: TlsConnector, - - pub connector: TlsConnector, -} - -#[cfg(feature = "incoming")] -async fn shuffle_rd_wr(in_rd: StanzaRead, in_wr: StanzaWrite, config: CloneableConfig, server_certs: ServerCerts, local_addr: SocketAddr, client_addr: &mut Context<'_>) -> Result<()> { - let filter = StanzaFilter::new(config.max_stanza_size_bytes); - shuffle_rd_wr_filter(in_rd, in_wr, config, server_certs, local_addr, client_addr, filter).await -} - -#[cfg(feature = "incoming")] -async fn shuffle_rd_wr_filter( - mut in_rd: StanzaRead, - mut in_wr: StanzaWrite, - config: CloneableConfig, - server_certs: ServerCerts, - local_addr: SocketAddr, - client_addr: &mut Context<'_>, - mut in_filter: StanzaFilter, -) -> Result<()> { - // now read to figure out client vs server - let (stream_open, is_c2s) = stream_preamble(&mut in_rd, &mut in_wr, client_addr.log_from(), &mut in_filter).await?; - client_addr.set_c2s_stream_open(is_c2s, &stream_open); - - #[cfg(feature = "s2s-incoming")] - { - trace!( - "{} connected: sni: {:?}, alpn: {:?}, tls-not-quic: {}", - client_addr.log_from(), - server_certs.sni(), - server_certs.alpn().map(|a| String::from_utf8_lossy(&a).to_string()), - server_certs.is_tls(), - ); - - if !is_c2s { - // for s2s we need this - use std::time::SystemTime; - let domain = stream_open - .extract_between(b" from='", b"'") - .or_else(|_| stream_open.extract_between(b" from=\"", b"\"")) - .and_then(|b| Ok(std::str::from_utf8(b)?))?; - let (_, cert_verifier) = get_xmpp_connections(domain, is_c2s).await?; - let certs = server_certs.peer_certificates().ok_or_else(|| anyhow!("no client cert auth for s2s incoming from {}", domain))?; - // todo: send stream error saying cert is invalid - cert_verifier.verify_cert(&certs[0], &certs[1..], SystemTime::now())?; - } - drop(server_certs); - } - - let (out_rd, out_wr) = open_incoming(&config, local_addr, client_addr, &stream_open, is_c2s, &mut in_filter).await?; - drop(stream_open); - - shuffle_rd_wr_filter_only( - in_rd, - in_wr, - StanzaRead::new(out_rd), - StanzaWrite::new(out_wr), - is_c2s, - config.max_stanza_size_bytes, - client_addr, - in_filter, - ) - .await -} - -#[allow(clippy::too_many_arguments)] -async fn shuffle_rd_wr_filter_only( - mut in_rd: StanzaRead, - mut in_wr: StanzaWrite, - mut out_rd: StanzaRead, - mut out_wr: StanzaWrite, - is_c2s: bool, - max_stanza_size_bytes: usize, - client_addr: &mut Context<'_>, - mut in_filter: StanzaFilter, -) -> Result<()> { - let mut out_filter = StanzaFilter::new(max_stanza_size_bytes); - - loop { - tokio::select! { - Ok(ret) = in_rd.next(&mut in_filter, client_addr.log_to(), &mut in_wr) => { - match ret { - None => break, - Some((buf, eoft)) => { - trace!("{} '{}'", client_addr.log_from(), to_str(buf)); - out_wr.write_all(is_c2s, buf, eoft, client_addr.log_from()).await?; - out_wr.flush().await?; - } - } - }, - Ok(ret) = out_rd.next(&mut out_filter, client_addr.log_from(), &mut out_wr) => { - match ret { - None => break, - Some((buf, eoft)) => { - trace!("{} '{}'", client_addr.log_to(), to_str(buf)); - in_wr.write_all(is_c2s, buf, eoft, client_addr.log_to()).await?; - in_wr.flush().await?; - } - } - }, - } - } - - info!("{} disconnected", client_addr.log_from()); - Ok(()) -} - -#[cfg(feature = "incoming")] -async fn open_incoming( - config: &CloneableConfig, - local_addr: SocketAddr, - client_addr: &mut Context<'_>, - stream_open: &[u8], - is_c2s: bool, - in_filter: &mut StanzaFilter, -) -> Result<(ReadHalf, WriteHalf)> { - let target = if is_c2s { - #[cfg(not(feature = "c2s-incoming"))] - bail!("incoming c2s connection but lacking compile-time support"); - #[cfg(feature = "c2s-incoming")] - config.c2s_target +#[cfg(all(unix, any(feature = "incoming", feature = "s2s-outgoing")))] +fn spawn_refresh_task(certs_key: &'static CertsKey, cfg_path: OsString) -> Option>> { + if certs_key.inner.is_err() { + None } else { - #[cfg(not(feature = "s2s-incoming"))] - bail!("incoming s2s connection but lacking compile-time support"); - #[cfg(feature = "s2s-incoming")] - config.s2s_target + Some(tokio::spawn(async move { + use tokio::signal::unix::{signal, SignalKind}; + let mut stream = signal(SignalKind::hangup())?; + loop { + stream.recv().await; + info!("got SIGHUP"); + match Config::parse(&cfg_path).and_then(|c| c.certs_key()) { + Ok(cert_key) => { + if let Ok(rwl) = certs_key.inner.as_ref() { + let cert_key = Arc::new(cert_key); + let mut certs_key = rwl.write().expect("CertKey poisoned?"); + *certs_key = cert_key; + drop(certs_key); + info!("reloaded cert/key successfully!"); + } + } + Err(e) => error!("invalid config/cert/key on SIGHUP: {}", e), + }; + } + })) } - .ok_or_else(|| anyhow!("incoming connection but `{}_target` not defined", c2s(is_c2s)))?; - client_addr.set_to_addr(target); - - let out_stream = tokio::net::TcpStream::connect(target).await?; - let (out_rd, mut out_wr) = tokio::io::split(out_stream); - - if config.proxy { - /* - https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt - PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\n - PROXY TCP6 ffff:f...f:ffff ffff:f...f:ffff 65535 65535\r\n - PROXY TCP6 SOURCE_IP DEST_IP SOURCE_PORT DEST_PORT\r\n - */ - // tokio AsyncWrite doesn't have write_fmt so have to go through this buffer for some crazy reason - //write!(out_wr, "PROXY TCP{} {} {} {} {}\r\n", if client_addr.is_ipv4() { '4' } else {'6' }, client_addr.ip(), local_addr.ip(), client_addr.port(), local_addr.port())?; - write!( - &mut in_filter.buf[0..], - "PROXY TCP{} {} {} {} {}\r\n", - if client_addr.client_addr().is_ipv4() { '4' } else { '6' }, - client_addr.client_addr().ip(), - local_addr.ip(), - client_addr.client_addr().port(), - local_addr.port() - )?; - let end_idx = &(&in_filter.buf[0..]).first_index_of(b"\n")? + 1; - trace!("{} '{}'", client_addr.log_from(), to_str(&in_filter.buf[0..end_idx])); - out_wr.write_all(&in_filter.buf[0..end_idx]).await?; - } - trace!("{} '{}'", client_addr.log_from(), to_str(stream_open)); - out_wr.write_all(stream_open).await?; - out_wr.flush().await?; - Ok((out_rd, out_wr)) -} - -pub async fn stream_preamble(in_rd: &mut StanzaRead, in_wr: &mut StanzaWrite, client_addr: &'_ str, in_filter: &mut StanzaFilter) -> Result<(Vec, bool)> { - let mut stream_open = Vec::new(); - while let Ok(Some((buf, _))) = in_rd.next(in_filter, client_addr, in_wr).await { - trace!("{} received pre- stanza: '{}'", client_addr, to_str(buf)); - if buf.starts_with(b" stanza: {}", to_str(buf)); - } - } - bail!("stream ended before open") } #[tokio::main] @@ -533,18 +128,23 @@ async fn main() { die!("log_level or log_style defined in config but logging disabled at compile-time"); } + #[cfg(feature = "incoming")] let config = main_config.get_cloneable_cfg(); - let certs_key = Arc::new(CertsKey::new(&main_config)); + let certs_key = Arc::new(CertsKey::new(main_config.certs_key())); let mut handles: Vec>> = Vec::new(); if !main_config.incoming_listen.is_empty() { #[cfg(all(any(feature = "tls", feature = "websocket"), feature = "incoming"))] { + use xmpp_proxy::{ + common::incoming::server_config, + tls::incoming::{spawn_tls_listener, tls_acceptor}, + }; if main_config.c2s_target.is_none() && main_config.s2s_target.is_none() { die!("one of c2s_target/s2s_target must be defined if incoming_listen is non-empty"); } - let acceptor = main_config.tls_acceptor(certs_key.clone()).die("invalid cert/key ?"); + let acceptor = tls_acceptor(server_config(certs_key.clone()).die("invalid cert/key ?")); for listener in main_config.incoming_listen.iter() { handles.push(spawn_tls_listener(listener.parse().die("invalid listener address"), config.clone(), acceptor.clone())); } @@ -555,10 +155,14 @@ async fn main() { if !main_config.quic_listen.is_empty() { #[cfg(all(feature = "quic", feature = "incoming"))] { + use xmpp_proxy::{ + common::incoming::server_config, + quic::incoming::{quic_server_config, spawn_quic_listener}, + }; if main_config.c2s_target.is_none() && main_config.s2s_target.is_none() { die!("one of c2s_target/s2s_target must be defined if quic_listen is non-empty"); } - let quic_config = main_config.quic_server_config(certs_key.clone()).die("invalid cert/key ?"); + let quic_config = quic_server_config(server_config(certs_key.clone()).die("invalid cert/key ?")); for listener in main_config.quic_listen.iter() { handles.push(spawn_quic_listener(listener.parse().die("invalid listener address"), config.clone(), quic_config.clone())); } @@ -581,8 +185,11 @@ async fn main() { die!("all of incoming_listen, quic_listen, outgoing_listen empty, nothing to do, exiting..."); } #[cfg(all(unix, any(feature = "incoming", feature = "s2s-outgoing")))] - if let Some(refresh_task) = Box::leak(Box::new(certs_key.clone())).spawn_refresh_task(cfg_path) { - handles.push(refresh_task); + { + let certs_key = Box::leak(Box::new(certs_key.clone())); + if let Some(refresh_task) = spawn_refresh_task(certs_key, cfg_path) { + handles.push(refresh_task); + } } info!("xmpp-proxy started"); diff --git a/src/outgoing.rs b/src/outgoing.rs index 4c91216..ab71d46 100644 --- a/src/outgoing.rs +++ b/src/outgoing.rs @@ -1,4 +1,16 @@ -use crate::*; +use crate::{ + common::{first_bytes_match, outgoing::OutgoingConfig, shuffle_rd_wr_filter_only, stream_preamble}, + context::Context, + in_out::{StanzaRead, StanzaWrite}, + slicesubsequence::SliceSubsequence, + srv::srv_connect, + stanzafilter::StanzaFilter, +}; +use anyhow::Result; +use die::Die; +use log::{error, info}; +use std::net::SocketAddr; +use tokio::{net::TcpListener, task::JoinHandle}; async fn handle_outgoing_connection(stream: tokio::net::TcpStream, client_addr: &mut Context<'_>, config: OutgoingConfig) -> Result<()> { info!("{} connected", client_addr.log_from()); @@ -7,7 +19,7 @@ async fn handle_outgoing_connection(stream: tokio::net::TcpStream, client_addr: #[cfg(feature = "websocket")] let (mut in_rd, mut in_wr) = if first_bytes_match(&stream, &mut in_filter.buf[0..3], |p| p == b"GET").await? { - incoming_websocket_connection(Box::new(stream), config.max_stanza_size_bytes).await? + crate::websocket::incoming_websocket_connection(Box::new(stream), config.max_stanza_size_bytes).await? } else { let (in_rd, in_wr) = tokio::io::split(stream); (StanzaRead::new(in_rd), StanzaWrite::new(in_wr)) diff --git a/src/quic.rs b/src/quic/incoming.rs similarity index 56% rename from src/quic.rs rename to src/quic/incoming.rs index 43b6852..2529a7e 100644 --- a/src/quic.rs +++ b/src/quic/incoming.rs @@ -1,40 +1,16 @@ -use crate::*; -use futures::StreamExt; -use quinn::{ServerConfig, TransportConfig}; -use std::{net::SocketAddr, sync::Arc}; - +use crate::{ + common::incoming::{shuffle_rd_wr, CloneableConfig, ServerCerts}, + context::Context, + in_out::{StanzaRead, StanzaWrite}, +}; use anyhow::Result; +use die::Die; +use futures::StreamExt; +use log::{error, info}; +use quinn::ServerConfig; +use std::{net::SocketAddr, sync::Arc}; +use tokio::task::JoinHandle; -#[cfg(feature = "outgoing")] -pub async fn quic_connect(target: SocketAddr, server_name: &str, config: OutgoingVerifierConfig) -> Result<(StanzaWrite, StanzaRead)> { - let bind_addr = "0.0.0.0:0".parse().unwrap(); - let client_cfg = config.config_alpn; - - let mut endpoint = quinn::Endpoint::client(bind_addr)?; - endpoint.set_default_client_config(quinn::ClientConfig::new(client_cfg)); - - // connect to server - let quinn::NewConnection { connection, .. } = endpoint.connect(target, server_name)?.await?; - trace!("quic connected: addr={}", connection.remote_address()); - - let (wrt, rd) = connection.open_bi().await?; - Ok((StanzaWrite::new(wrt), StanzaRead::new(rd))) -} - -#[cfg(feature = "incoming")] -impl Config { - pub fn quic_server_config(&self, cert_key: Arc) -> Result { - let transport_config = TransportConfig::default(); - // todo: configure transport_config here if needed - let server_config = self.server_config(cert_key)?; - let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(server_config)); - server_config.transport = Arc::new(transport_config); - - Ok(server_config) - } -} - -#[cfg(feature = "incoming")] pub fn spawn_quic_listener(local_addr: SocketAddr, config: CloneableConfig, server_config: ServerConfig) -> JoinHandle> { let (_endpoint, mut incoming) = quinn::Endpoint::server(server_config, local_addr).die("cannot listen on port/interface"); tokio::spawn(async move { @@ -43,7 +19,7 @@ pub fn spawn_quic_listener(local_addr: SocketAddr, config: CloneableConfig, serv let config = config.clone(); tokio::spawn(async move { if let Ok(mut new_conn) = incoming_conn.await { - let client_addr = crate::Context::new("quic-in", new_conn.connection.remote_address()); + let client_addr = Context::new("quic-in", new_conn.connection.remote_address()); #[cfg(feature = "s2s-incoming")] let server_certs = ServerCerts::Quic(new_conn.connection); @@ -70,3 +46,12 @@ pub fn spawn_quic_listener(local_addr: SocketAddr, config: CloneableConfig, serv Ok(()) }) } + +pub fn quic_server_config(server_config: rustls::ServerConfig) -> ServerConfig { + let transport_config = quinn::TransportConfig::default(); + // todo: configure transport_config here if needed + let mut server_config = ServerConfig::with_crypto(Arc::new(server_config)); + server_config.transport = Arc::new(transport_config); + + server_config +} diff --git a/src/quic/mod.rs b/src/quic/mod.rs new file mode 100644 index 0000000..82d4c6a --- /dev/null +++ b/src/quic/mod.rs @@ -0,0 +1,5 @@ +#[cfg(feature = "incoming")] +pub mod incoming; + +#[cfg(feature = "outgoing")] +pub mod outgoing; diff --git a/src/quic/outgoing.rs b/src/quic/outgoing.rs new file mode 100644 index 0000000..3429e9a --- /dev/null +++ b/src/quic/outgoing.rs @@ -0,0 +1,23 @@ +use std::net::SocketAddr; + +use crate::{ + common::outgoing::OutgoingVerifierConfig, + in_out::{StanzaRead, StanzaWrite}, +}; +use anyhow::Result; +use log::trace; + +pub async fn quic_connect(target: SocketAddr, server_name: &str, config: OutgoingVerifierConfig) -> Result<(StanzaWrite, StanzaRead)> { + let bind_addr = "0.0.0.0:0".parse().unwrap(); + let client_cfg = config.config_alpn; + + let mut endpoint = quinn::Endpoint::client(bind_addr)?; + endpoint.set_default_client_config(quinn::ClientConfig::new(client_cfg)); + + // connect to server + let quinn::NewConnection { connection, .. } = endpoint.connect(target, server_name)?.await?; + trace!("quic connected: addr={}", connection.remote_address()); + + let (wrt, rd) = connection.open_bi().await?; + Ok((StanzaWrite::new(wrt), StanzaRead::new(rd))) +} diff --git a/src/srv.rs b/src/srv.rs index 35301e8..d19a5b2 100644 --- a/src/srv.rs +++ b/src/srv.rs @@ -1,22 +1,33 @@ #![allow(clippy::upper_case_acronyms)] -use std::cmp::Ordering; -use std::convert::TryFrom; -use std::net::{IpAddr, SocketAddr}; - -use data_encoding::BASE64; -use ring::digest::{Algorithm, Context as DigestContext, SHA256, SHA512}; - -use trust_dns_resolver::error::ResolveError; -use trust_dns_resolver::lookup::{SrvLookup, TxtLookup}; -use trust_dns_resolver::{IntoName, TokioAsyncResolver}; - +#[cfg(feature = "outgoing")] +use crate::common::outgoing::{OutgoingConfig, OutgoingVerifierConfig}; +use crate::{ + common::{stream_preamble, to_str}, + context::Context, + in_out::{StanzaRead, StanzaWrite}, + slicesubsequence::SliceSubsequence, + stanzafilter::{StanzaFilter, StanzaReader}, + verify::XmppServerCertVerifier, +}; use anyhow::{bail, Result}; -use tokio_rustls::webpki::DnsName; +use data_encoding::BASE64; +use log::{debug, error, trace}; +use ring::digest::{Algorithm, Context as DigestContext, SHA256, SHA512}; +use serde::Deserialize; +use std::{ + cmp::Ordering, + convert::TryFrom, + net::{IpAddr, SocketAddr}, +}; +use tokio_rustls::webpki::{DnsName, DnsNameRef}; #[cfg(feature = "websocket")] use tokio_tungstenite::tungstenite::http::Uri; - -use crate::*; +use trust_dns_resolver::{ + error::ResolveError, + lookup::{SrvLookup, TxtLookup}, + IntoName, TokioAsyncResolver, +}; lazy_static::lazy_static! { static ref RESOLVER: TokioAsyncResolver = make_resolver(); @@ -165,7 +176,7 @@ impl XmppConnection { &self, domain: &str, stream_open: &[u8], - in_filter: &mut crate::StanzaFilter, + in_filter: &mut StanzaFilter, client_addr: &mut Context<'_>, config: OutgoingVerifierConfig, ) -> Result<(StanzaWrite, StanzaRead, SocketAddr, &'static str)> { @@ -184,28 +195,28 @@ impl XmppConnection { debug!("{} trying ip {}", client_addr.log_from(), to_addr); match self.conn_type { #[cfg(feature = "tls")] - XmppConnectionType::StartTLS => match crate::starttls_connect(to_addr, domain, stream_open, in_filter, config.clone()).await { + XmppConnectionType::StartTLS => match crate::tls::outgoing::starttls_connect(to_addr, domain, stream_open, in_filter, config.clone()).await { Ok((wr, rd)) => return Ok((wr, rd, to_addr, "starttls-out")), Err(e) => error!("starttls connection failed to IP {} from SRV {}, error: {}", to_addr, self.target, e), }, #[cfg(feature = "tls")] - XmppConnectionType::DirectTLS => match crate::tls_connect(to_addr, domain, config.clone()).await { + XmppConnectionType::DirectTLS => match crate::tls::outgoing::tls_connect(to_addr, domain, config.clone()).await { Ok((wr, rd)) => return Ok((wr, rd, to_addr, "directtls-out")), Err(e) => error!("direct tls connection failed to IP {} from SRV {}, error: {}", to_addr, self.target, e), }, #[cfg(feature = "quic")] - XmppConnectionType::QUIC => match crate::quic_connect(to_addr, domain, config.clone()).await { + XmppConnectionType::QUIC => match crate::quic::outgoing::quic_connect(to_addr, domain, config.clone()).await { Ok((wr, rd)) => return Ok((wr, rd, to_addr, "quic-out")), Err(e) => error!("quic connection failed to IP {} from SRV {}, error: {}", to_addr, self.target, e), }, #[cfg(feature = "websocket")] // todo: when websocket is found via DNS, we need to validate cert against domain, *not* target, this is a security problem with XEP-0156, we are doing it the secure but likely unexpected way here for now - XmppConnectionType::WebSocket(ref url, ref origin) => match crate::websocket_connect(to_addr, domain, url, origin, config.clone()).await { + XmppConnectionType::WebSocket(ref url, ref origin) => match crate::websocket::outgoing::websocket_connect(to_addr, domain, url, origin, config.clone()).await { Ok((wr, rd)) => return Ok((wr, rd, to_addr, "websocket-out")), Err(e) => { if self.secure && self.target != orig_domain { // https is a special case, as target is sent in the Host: header, so we have to literally try twice in case this is set for the other on the server - match crate::websocket_connect(to_addr, orig_domain, url, origin, config.clone()).await { + match crate::websocket::outgoing::websocket_connect(to_addr, orig_domain, url, origin, config.clone()).await { Ok((wr, rd)) => return Ok((wr, rd, to_addr, "websocket-out")), Err(e2) => error!("websocket connection failed to IP {} from TXT {}, error try 1: {}, error try 2: {}", to_addr, url, e, e2), } @@ -428,7 +439,7 @@ pub async fn srv_connect( domain: &str, is_c2s: bool, stream_open: &[u8], - in_filter: &mut crate::StanzaFilter, + in_filter: &mut StanzaFilter, client_addr: &mut Context<'_>, config: OutgoingConfig, ) -> Result<(StanzaWrite, StanzaRead, Vec)> { diff --git a/src/stanzafilter.rs b/src/stanzafilter.rs index 6d35a25..4a297f5 100644 --- a/src/stanzafilter.rs +++ b/src/stanzafilter.rs @@ -1,9 +1,9 @@ #![allow(clippy::upper_case_acronyms)] +use crate::common::to_str; use anyhow::{bail, Result}; -use crate::stanzafilter::StanzaState::*; -use crate::to_str; +use StanzaState::*; #[derive(Debug)] enum StanzaState { diff --git a/src/tls.rs b/src/tls/incoming.rs similarity index 69% rename from src/tls.rs rename to src/tls/incoming.rs index 6e40ffe..86db9f4 100644 --- a/src/tls.rs +++ b/src/tls/incoming.rs @@ -1,71 +1,30 @@ -use crate::*; -use rustls::ServerConnection; -use std::convert::TryFrom; -use tokio::io::{AsyncBufReadExt, BufStream}; +use crate::common::incoming::{shuffle_rd_wr_filter, CloneableConfig, ServerCerts}; -use tokio_rustls::{rustls::ServerName, TlsAcceptor}; +use crate::{ + common::{first_bytes_match, to_str, IN_BUFFER_SIZE}, + context::Context, + in_out::{StanzaRead, StanzaWrite}, + slicesubsequence::SliceSubsequence, + stanzafilter::{StanzaFilter, StanzaReader}, + *, +}; +use anyhow::Result; +use die::Die; +use log::{error, trace}; +use rustls::{ServerConfig, ServerConnection}; -#[cfg(feature = "outgoing")] -pub async fn tls_connect(target: SocketAddr, server_name: &str, config: OutgoingVerifierConfig) -> Result<(StanzaWrite, StanzaRead)> { - let dnsname = ServerName::try_from(server_name)?; - let stream = tokio::net::TcpStream::connect(target).await?; - let stream = config.connector_alpn.connect(dnsname, stream).await?; - let (rd, wrt) = tokio::io::split(stream); - Ok((StanzaWrite::new(wrt), StanzaRead::new(rd))) +use std::sync::Arc; +use tokio::{ + io::{AsyncBufReadExt, AsyncWriteExt, BufStream}, + net::TcpListener, + task::JoinHandle, +}; +use tokio_rustls::TlsAcceptor; + +pub fn tls_acceptor(server_config: ServerConfig) -> TlsAcceptor { + TlsAcceptor::from(Arc::new(server_config)) } -#[cfg(feature = "outgoing")] -pub async fn starttls_connect(target: SocketAddr, server_name: &str, stream_open: &[u8], in_filter: &mut StanzaFilter, config: OutgoingVerifierConfig) -> Result<(StanzaWrite, StanzaRead)> { - let dnsname = ServerName::try_from(server_name)?; - let mut stream = tokio::net::TcpStream::connect(target).await?; - let (in_rd, mut in_wr) = stream.split(); - - // send the stream_open - trace!("starttls sending: {} '{}'", server_name, to_str(stream_open)); - in_wr.write_all(stream_open).await?; - in_wr.flush().await?; - - // we naively read 1 byte at a time, which buffering significantly speeds up - let in_rd = tokio::io::BufReader::with_capacity(IN_BUFFER_SIZE, in_rd); - let mut in_rd = StanzaReader(in_rd); - let mut proceed_received = false; - - trace!("starttls reading stream open {}", server_name); - while let Ok(Some(buf)) = in_rd.next(in_filter).await { - trace!("received pre-tls stanza: {} '{}'", server_name, to_str(buf)); - if buf.starts_with(b""###; - trace!("> {} '{}'", server_name, to_str(buf)); - in_wr.write_all(buf).await?; - in_wr.flush().await?; - } else if buf.starts_with(b") -> Result { - Ok(TlsAcceptor::from(Arc::new(self.server_config(cert_key)?))) - } -} - -#[cfg(feature = "incoming")] pub fn spawn_tls_listener(local_addr: SocketAddr, config: CloneableConfig, acceptor: TlsAcceptor) -> JoinHandle> { tokio::spawn(async move { let listener = TcpListener::bind(&local_addr).await.die("cannot listen on port/interface"); @@ -83,7 +42,6 @@ pub fn spawn_tls_listener(local_addr: SocketAddr, config: CloneableConfig, accep }) } -#[cfg(feature = "incoming")] async fn handle_tls_connection(mut stream: tokio::net::TcpStream, client_addr: &mut Context<'_>, local_addr: SocketAddr, config: CloneableConfig, acceptor: TlsAcceptor) -> Result<()> { info!("{} connected", client_addr.log_from()); @@ -183,7 +141,7 @@ async fn handle_tls_connection(mut stream: tokio::net::TcpStream, client_addr: & { let stream: tokio_rustls::TlsStream = stream.into(); - let mut stream = BufStream::with_capacity(crate::IN_BUFFER_SIZE, 0, stream); + let mut stream = BufStream::with_capacity(IN_BUFFER_SIZE, 0, stream); let websocket = { // wait up to 10 seconds until 3 bytes have been read use std::time::{Duration, Instant}; @@ -207,7 +165,7 @@ async fn handle_tls_connection(mut stream: tokio::net::TcpStream, client_addr: & }; if websocket { - handle_websocket_connection(Box::new(stream), config, server_certs, local_addr, client_addr, in_filter).await + crate::websocket::incoming::handle_websocket_connection(Box::new(stream), config, server_certs, local_addr, client_addr, in_filter).await } else { let (in_rd, in_wr) = tokio::io::split(stream); shuffle_rd_wr_filter(StanzaRead::already_buffered(in_rd), StanzaWrite::new(in_wr), config, server_certs, local_addr, client_addr, in_filter).await diff --git a/src/tls/mod.rs b/src/tls/mod.rs new file mode 100644 index 0000000..82d4c6a --- /dev/null +++ b/src/tls/mod.rs @@ -0,0 +1,5 @@ +#[cfg(feature = "incoming")] +pub mod incoming; + +#[cfg(feature = "outgoing")] +pub mod outgoing; diff --git a/src/tls/outgoing.rs b/src/tls/outgoing.rs new file mode 100644 index 0000000..6b62d78 --- /dev/null +++ b/src/tls/outgoing.rs @@ -0,0 +1,61 @@ +use crate::{ + common::{outgoing::OutgoingVerifierConfig, to_str, IN_BUFFER_SIZE}, + in_out::{StanzaRead, StanzaWrite}, + stanzafilter::{StanzaFilter, StanzaReader}, +}; +use anyhow::{bail, Result}; +use log::{debug, trace}; +use rustls::ServerName; +use std::{convert::TryFrom, net::SocketAddr}; +use tokio::io::AsyncWriteExt; + +pub async fn tls_connect(target: SocketAddr, server_name: &str, config: OutgoingVerifierConfig) -> Result<(StanzaWrite, StanzaRead)> { + let dnsname = ServerName::try_from(server_name)?; + let stream = tokio::net::TcpStream::connect(target).await?; + let stream = config.connector_alpn.connect(dnsname, stream).await?; + let (rd, wrt) = tokio::io::split(stream); + Ok((StanzaWrite::new(wrt), StanzaRead::new(rd))) +} + +pub async fn starttls_connect(target: SocketAddr, server_name: &str, stream_open: &[u8], in_filter: &mut StanzaFilter, config: OutgoingVerifierConfig) -> Result<(StanzaWrite, StanzaRead)> { + let dnsname = ServerName::try_from(server_name)?; + let mut stream = tokio::net::TcpStream::connect(target).await?; + let (in_rd, mut in_wr) = stream.split(); + + // send the stream_open + trace!("starttls sending: {} '{}'", server_name, to_str(stream_open)); + in_wr.write_all(stream_open).await?; + in_wr.flush().await?; + + // we naively read 1 byte at a time, which buffering significantly speeds up + let in_rd = tokio::io::BufReader::with_capacity(IN_BUFFER_SIZE, in_rd); + let mut in_rd = StanzaReader(in_rd); + let mut proceed_received = false; + + trace!("starttls reading stream open {}", server_name); + while let Ok(Some(buf)) = in_rd.next(in_filter).await { + trace!("received pre-tls stanza: {} '{}'", server_name, to_str(buf)); + if buf.starts_with(b""###; + trace!("> {} '{}'", server_name, to_str(buf)); + in_wr.write_all(buf).await?; + in_wr.flush().await?; + } else if buf.starts_with(b", + config: CloneableConfig, + server_certs: ServerCerts, + local_addr: SocketAddr, + client_addr: &mut Context<'_>, + in_filter: StanzaFilter, +) -> Result<()> { + client_addr.set_proto("websocket-in"); + info!("{} connected", client_addr.log_from()); + + let (in_rd, in_wr) = incoming_websocket_connection(stream, config.max_stanza_size_bytes).await?; + + shuffle_rd_wr_filter(in_rd, in_wr, config, server_certs, local_addr, client_addr, in_filter).await +} diff --git a/src/websocket.rs b/src/websocket/mod.rs similarity index 77% rename from src/websocket.rs rename to src/websocket/mod.rs index 09ca3f5..89e372a 100644 --- a/src/websocket.rs +++ b/src/websocket/mod.rs @@ -1,9 +1,14 @@ -use crate::*; use anyhow::Result; use futures::StreamExt; use tokio_tungstenite::tungstenite::protocol::WebSocketConfig; +#[cfg(feature = "incoming")] +pub mod incoming; + +#[cfg(feature = "outgoing")] +pub mod outgoing; + // https://datatracker.ietf.org/doc/html/rfc7395 fn ws_cfg(max_stanza_size_bytes: usize) -> Option { @@ -29,23 +34,6 @@ pub async fn incoming_websocket_connection(stream: Box, - config: CloneableConfig, - server_certs: ServerCerts, - local_addr: SocketAddr, - client_addr: &mut Context<'_>, - in_filter: StanzaFilter, -) -> Result<()> { - client_addr.set_proto("websocket-in"); - info!("{} connected", client_addr.log_from()); - - let (in_rd, in_wr) = incoming_websocket_connection(stream, config.max_stanza_size_bytes).await?; - - shuffle_rd_wr_filter(in_rd, in_wr, config, server_certs, local_addr, client_addr, in_filter).await -} - pub fn from_ws(stanza: String) -> String { if stanza.starts_with(" Resul Ok(ret) } -use rustls::ServerName; -use std::convert::TryFrom; - -use tokio_tungstenite::tungstenite::client::IntoClientRequest; -use tokio_tungstenite::tungstenite::http::header::{ORIGIN, SEC_WEBSOCKET_PROTOCOL}; -use tokio_tungstenite::tungstenite::http::Uri; - -#[cfg(feature = "outgoing")] -pub async fn websocket_connect(target: SocketAddr, server_name: &str, url: &Uri, origin: &str, config: OutgoingVerifierConfig) -> Result<(StanzaWrite, StanzaRead)> { - let mut request = url.into_client_request()?; - request.headers_mut().append(SEC_WEBSOCKET_PROTOCOL, "xmpp".parse()?); - request.headers_mut().append(ORIGIN, origin.parse()?); - - let dnsname = ServerName::try_from(server_name)?; - let stream = tokio::net::TcpStream::connect(target).await?; - let stream = config.connector.connect(dnsname, stream).await?; - - //let stream: tokio_rustls::TlsStream = stream.into(); - // todo: tokio_tungstenite seems to have a bug, if the write buffer is non-zero, it'll hang forever, even though we always flush, investigate - //let stream = BufStream::with_capacity(crate::IN_BUFFER_SIZE, 0, stream); - let stream: Box = Box::new(stream); - - let (stream, _) = tokio_tungstenite::client_async_with_config(request, stream, ws_cfg(config.max_stanza_size_bytes)).await?; - - let (wrt, rd) = stream.split(); - - Ok((StanzaWrite::WebSocketClientWrite(wrt), StanzaRead::WebSocketRead(rd))) -} +use crate::{ + in_out::{StanzaRead, StanzaWrite}, + slicesubsequence::SliceSubsequence, +}; #[cfg(test)] mod tests { diff --git a/src/websocket/outgoing.rs b/src/websocket/outgoing.rs new file mode 100644 index 0000000..bfa8670 --- /dev/null +++ b/src/websocket/outgoing.rs @@ -0,0 +1,37 @@ +use crate::{ + common::outgoing::OutgoingVerifierConfig, + in_out::{StanzaRead, StanzaWrite}, + websocket::{ws_cfg, AsyncReadAndWrite}, +}; +use anyhow::Result; +use futures_util::StreamExt; +use rustls::ServerName; +use std::{convert::TryFrom, net::SocketAddr}; +use tokio_tungstenite::tungstenite::{ + client::IntoClientRequest, + http::{ + header::{ORIGIN, SEC_WEBSOCKET_PROTOCOL}, + Uri, + }, +}; + +pub async fn websocket_connect(target: SocketAddr, server_name: &str, url: &Uri, origin: &str, config: OutgoingVerifierConfig) -> Result<(StanzaWrite, StanzaRead)> { + let mut request = url.into_client_request()?; + request.headers_mut().append(SEC_WEBSOCKET_PROTOCOL, "xmpp".parse()?); + request.headers_mut().append(ORIGIN, origin.parse()?); + + let dnsname = ServerName::try_from(server_name)?; + let stream = tokio::net::TcpStream::connect(target).await?; + let stream = config.connector.connect(dnsname, stream).await?; + + //let stream: tokio_rustls::TlsStream = stream.into(); + // todo: tokio_tungstenite seems to have a bug, if the write buffer is non-zero, it'll hang forever, even though we always flush, investigate + //let stream = BufStream::with_capacity(crate::IN_BUFFER_SIZE, 0, stream); + let stream: Box = Box::new(stream); + + let (stream, _) = tokio_tungstenite::client_async_with_config(request, stream, ws_cfg(config.max_stanza_size_bytes)).await?; + + let (wrt, rd) = stream.split(); + + Ok((StanzaWrite::WebSocketClientWrite(wrt), StanzaRead::WebSocketRead(rd))) +}