diff --git a/.ci/Jenkinsfile b/.ci/Jenkinsfile index 803a002..98eec80 100644 --- a/.ci/Jenkinsfile +++ b/.ci/Jenkinsfile @@ -28,9 +28,14 @@ node('linux && docker') { stage('Build + Deploy') { sh ''' + ./check-all-features.sh || exit 1 + cargo clean mkdir -p release cp xmpp-proxy.toml release curl --compressed -sL https://code.moparisthebest.com/moparisthebest/self-ci/raw/branch/master/build-ci.sh | bash + ret=$? + docker system prune -af + exit $ret ''' } diff --git a/.ci/build.sh b/.ci/build.sh index 1a36f47..4a2fead 100755 --- a/.ci/build.sh +++ b/.ci/build.sh @@ -11,6 +11,10 @@ echo "$TARGET" | grep -E '^x86_64-pc-windows-gnu$' >/dev/null && SUFFIX=".exe" # ring fails to compile here echo "$TARGET" | grep -E '^(s390x|powerpc|mips|riscv64gc|.*solaris$)' >/dev/null && echo "$TARGET not supported in rustls" && exit 0 + +# running `docker system prune -af` after these because they are roughly every 25% through and my hard drive space is limited +echo "$TARGET" | grep -E '^(armv7-unknown-linux-gnueabihf|x86_64-linux-android|mips-unknown-linux-gnu)$' >/dev/null && docker system prune -af + # mio fails to link here echo "$TARGET" | grep -E '^x86_64-unknown-netbsd$' >/dev/null && echo "$TARGET not supported in mio" && exit 0 diff --git a/.gitignore b/.gitignore index 9621941..ae6e578 100644 --- a/.gitignore +++ b/.gitignore @@ -6,5 +6,5 @@ **/out/ **/core.* fuzz/target/ -todo.txt +*.txt conflict/ diff --git a/Cargo.toml b/Cargo.toml index cb0c1c2..75782c1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,7 @@ serde = { version = "1.0", features = ["derive"] } futures = "0.3" die = "0.2" anyhow = "1.0" -tokio = { version = "1.9", features = ["net", "rt", "rt-multi-thread", "macros", "io-util", "signal"] } +tokio = { version = "1.9", features = ["net", "rt", "rt-multi-thread", "macros", "io-util", "signal", "time"] } ring = "0.16" data-encoding = "2.3" async-trait = "0.1" @@ -101,5 +101,8 @@ websocket = ["tokio-tungstenite", "futures-util", "tls"] # websocket+incoming al logging = ["rand", "env_logger"] systemd = ["nix"] +# enables unit tests that need network and therefore may be flaky +net-test = [] + [dev-dependencies] serde_json = "1.0" diff --git a/check-all-features.sh b/check-all-features.sh index 47e0570..45d35e4 100755 --- a/check-all-features.sh +++ b/check-all-features.sh @@ -1,12 +1,20 @@ #!/bin/bash threads="$1" +shift +clean_after_num_builds="$1" set -euo pipefail # if we have access to nproc, divide that by 2, otherwise use 1 thread by default [ "$threads" == "" ] && threads=$(($(nproc || echo 2) / 2)) +# 50 is about 1.5gb, ymmv +[ "$clean_after_num_builds" == "" ] && clean_after_num_builds=50 + +export clean_after_num_builds + echo "threads: $threads" +echo "clean_after_num_builds: $clean_after_num_builds" export RUSTFLAGS=-Awarnings @@ -65,12 +73,26 @@ echo_cargo() { #echo cargo run "$@" -- -v #cargo run "$@" -- -v echo cargo check "$@" - cargo check "$@" + flock -s /tmp/xmpp-proxy-check-all-features.lock cargo check "$@" ret=$? if [ $ret -ne 0 ] then - echo "features failed: $@" + echo "command failed: cargo check $@" fi + ( + flock -x 200 + # now we are under an exclusive lock + count=$(cat /tmp/xmpp-proxy-check-all-features.count) + count=$(( count + 1 )) + if [ $count -ge $clean_after_num_builds ] + then + echo cargo clean + cargo clean + count=0 + fi + echo $count > /tmp/xmpp-proxy-check-all-features.count + + ) 200>/tmp/xmpp-proxy-check-all-features.lock return $ret } @@ -78,6 +100,8 @@ echo_cargo() { export -f echo_cargo +echo 0 > /tmp/xmpp-proxy-check-all-features.count + echo_cargo all_features | sort | xargs -n1 --max-procs=$threads bash -c 'echo_cargo --no-default-features --features "$@" || exit 255' _ diff --git a/src/common/incoming.rs b/src/common/incoming.rs index 6d6ff6d..2a410b1 100644 --- a/src/common/incoming.rs +++ b/src/common/incoming.rs @@ -1,5 +1,5 @@ use crate::{ - common::{c2s, certs_key::CertsKey, shuffle_rd_wr_filter_only, stream_preamble, to_str, ALPN_XMPP_CLIENT, ALPN_XMPP_SERVER}, + common::{c2s, certs_key::CertsKey, shuffle_rd_wr_filter_only, stream_preamble, to_str, SocketAddrPath, ALPN_XMPP_CLIENT, ALPN_XMPP_SERVER}, context::Context, in_out::{StanzaRead, StanzaWrite}, slicesubsequence::SliceSubsequence, @@ -8,16 +8,16 @@ use crate::{ use anyhow::{anyhow, bail, Result}; use log::trace; use rustls::{Certificate, ServerConfig, ServerConnection}; -use std::{io::Write, net::SocketAddr, sync::Arc}; -use tokio::io::{AsyncWriteExt, ReadHalf, WriteHalf}; -#[derive(Clone)] -pub struct CloneableConfig { +use std::{io::Write, net::SocketAddr, sync::Arc}; +use tokio::io::AsyncWriteExt; + +pub struct IncomingConfig { pub max_stanza_size_bytes: usize, #[cfg(feature = "s2s-incoming")] - pub s2s_target: Option, + pub s2s_target: Option, #[cfg(feature = "c2s-incoming")] - pub c2s_target: Option, + pub c2s_target: Option, pub proxy: bool, } @@ -85,7 +85,7 @@ impl ServerCerts { } } -pub async fn shuffle_rd_wr(in_rd: StanzaRead, in_wr: StanzaWrite, config: CloneableConfig, server_certs: ServerCerts, local_addr: SocketAddr, client_addr: &mut Context<'_>) -> Result<()> { +pub async fn shuffle_rd_wr(in_rd: StanzaRead, in_wr: StanzaWrite, config: Arc, server_certs: ServerCerts, local_addr: SocketAddr, client_addr: &mut Context<'_>) -> Result<()> { let filter = StanzaFilter::new(config.max_stanza_size_bytes); shuffle_rd_wr_filter(in_rd, in_wr, config, server_certs, local_addr, client_addr, filter).await } @@ -93,7 +93,7 @@ pub async fn shuffle_rd_wr(in_rd: StanzaRead, in_wr: StanzaWrite, config: Clonea pub async fn shuffle_rd_wr_filter( mut in_rd: StanzaRead, mut in_wr: StanzaWrite, - config: CloneableConfig, + config: Arc, server_certs: ServerCerts, local_addr: SocketAddr, client_addr: &mut Context<'_>, @@ -131,43 +131,33 @@ pub async fn shuffle_rd_wr_filter( let (out_rd, out_wr) = open_incoming(&config, local_addr, client_addr, &stream_open, is_c2s, &mut in_filter).await?; drop(stream_open); - shuffle_rd_wr_filter_only( - in_rd, - in_wr, - StanzaRead::new(out_rd), - StanzaWrite::new(out_wr), - is_c2s, - config.max_stanza_size_bytes, - client_addr, - in_filter, - ) - .await + shuffle_rd_wr_filter_only(in_rd, in_wr, out_rd, out_wr, is_c2s, config.max_stanza_size_bytes, client_addr, in_filter).await } async fn open_incoming( - config: &CloneableConfig, + config: &IncomingConfig, local_addr: SocketAddr, client_addr: &mut Context<'_>, stream_open: &[u8], is_c2s: bool, in_filter: &mut StanzaFilter, -) -> Result<(ReadHalf, WriteHalf)> { - let target: Option = if is_c2s { +) -> Result<(StanzaRead, StanzaWrite)> { + let target: &Option = if is_c2s { #[cfg(not(feature = "c2s-incoming"))] bail!("incoming c2s connection but lacking compile-time support"); #[cfg(feature = "c2s-incoming")] - config.c2s_target + &config.c2s_target } else { #[cfg(not(feature = "s2s-incoming"))] bail!("incoming s2s connection but lacking compile-time support"); #[cfg(feature = "s2s-incoming")] - config.s2s_target + &config.s2s_target }; - let target = target.ok_or_else(|| anyhow!("incoming connection but `{}_target` not defined", c2s(is_c2s)))?; - client_addr.set_to_addr(target); + let target = target.as_ref().ok_or_else(|| anyhow!("incoming connection but `{}_target` not defined", c2s(is_c2s)))?; + client_addr.set_to_addr(target.to_string()); - let out_stream = tokio::net::TcpStream::connect(target).await?; - let (out_rd, mut out_wr) = tokio::io::split(out_stream); + let (out_rd, mut out_wr) = target.connect().await?; + let out_rd = StanzaRead::new(out_rd); if config.proxy { /* @@ -194,5 +184,5 @@ async fn open_incoming( trace!("{} '{}'", client_addr.log_from(), to_str(stream_open)); out_wr.write_all(stream_open).await?; out_wr.flush().await?; - Ok((out_rd, out_wr)) + Ok((out_rd, StanzaWrite::AsyncWrite(out_wr))) } diff --git a/src/common/mod.rs b/src/common/mod.rs index f14ec8e..3f1484e 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -12,10 +12,20 @@ use rustls::{ sign::{RsaSigningKey, SigningKey}, Certificate, PrivateKey, }; -use std::{fs::File, io, sync::Arc}; +use serde::{Deserialize, Deserializer}; +use std::{ + fmt::{Display, Formatter}, + fs::File, + io, + net::{SocketAddr, UdpSocket}, + path::PathBuf, + sync::Arc, +}; +#[cfg(not(target_os = "windows"))] +use tokio::net::UnixStream; use tokio::{ - io::{AsyncRead, AsyncWrite, BufReader, BufStream}, - net::TcpStream, + io::{AsyncBufRead, AsyncRead, AsyncWrite, BufReader, BufStream}, + net::{TcpListener, TcpStream}, }; #[cfg(feature = "incoming")] @@ -29,11 +39,26 @@ pub mod ca_roots; #[cfg(feature = "rustls")] pub mod certs_key; +pub mod stream_listener; pub const IN_BUFFER_SIZE: usize = 8192; pub const ALPN_XMPP_CLIENT: &[u8] = b"xmpp-client"; pub const ALPN_XMPP_SERVER: &[u8] = b"xmpp-server"; +pub trait AsyncReadAndWrite: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send {} +impl AsyncReadAndWrite for T {} + +pub trait AsyncReadWritePeekSplit: tokio::io::AsyncRead + tokio::io::AsyncWrite + Peek + Send + 'static + Unpin + Split {} +impl AsyncReadWritePeekSplit for T {} + +pub type BoxAsyncReadWrite = Box; +pub type BufAsyncReadWrite = BufStream; + +pub fn buf_stream(stream: BoxAsyncReadWrite) -> BufAsyncReadWrite { + // todo: do we *want* a non-zero writer_capacity ? + BufStream::with_capacity(IN_BUFFER_SIZE, 0, stream) +} + pub fn to_str(buf: &[u8]) -> std::borrow::Cow<'_, str> { String::from_utf8_lossy(buf) } @@ -46,13 +71,97 @@ pub fn c2s(is_c2s: bool) -> &'static str { } } +#[derive(Clone)] +pub enum SocketAddrPath { + SocketAddr(SocketAddr), + #[cfg(not(target_os = "windows"))] + Path(PathBuf), +} + +impl SocketAddrPath { + pub async fn connect(&self) -> Result<(Box, Box)> { + Ok(match self { + SocketAddrPath::SocketAddr(sa) => TcpStream::connect(sa).await?.split_boxed(), + #[cfg(not(target_os = "windows"))] + SocketAddrPath::Path(path) => tokio::net::UnixStream::connect(path).await?.split_boxed(), + }) + } + + pub async fn bind(&self) -> Result { + Ok(match self { + SocketAddrPath::SocketAddr(sa) => Listener::Tcp(TcpListener::bind(sa).await?), + #[cfg(not(target_os = "windows"))] + SocketAddrPath::Path(path) => Listener::Unix(tokio::net::UnixListener::bind(path)?), + }) + } + + pub async fn bind_udp(&self) -> Result { + Ok(match self { + SocketAddrPath::SocketAddr(sa) => UdpListener::Udp(UdpSocket::bind(sa)?), + #[cfg(not(target_os = "windows"))] + SocketAddrPath::Path(path) => UdpListener::Unix(std::os::unix::net::UnixDatagram::bind(path)?), + }) + } +} + +impl Display for SocketAddrPath { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + SocketAddrPath::SocketAddr(x) => x.fmt(f), + #[cfg(not(target_os = "windows"))] + SocketAddrPath::Path(x) => x.display().fmt(f), + } + } +} + +impl<'de> Deserialize<'de> for SocketAddrPath { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + #[cfg(not(target_os = "windows"))] + { + let str = String::deserialize(deserializer)?; + // seems good enough, possibly could improve + Ok(if str.contains('/') { + SocketAddrPath::Path(PathBuf::from(str)) + } else { + SocketAddrPath::SocketAddr(str.parse().map_err(serde::de::Error::custom)?) + }) + } + #[cfg(target_os = "windows")] + { + Ok(SocketAddrPath::SocketAddr(SocketAddr::deserialize(deserializer)?)) + } + } +} + +pub enum Listener { + Tcp(TcpListener), + #[cfg(not(target_os = "windows"))] + Unix(tokio::net::UnixListener), +} + +pub enum UdpListener { + Udp(UdpSocket), + #[cfg(not(target_os = "windows"))] + Unix(std::os::unix::net::UnixDatagram), +} + pub trait Split: Sized { - type ReadHalf: AsyncRead + Unpin; - type WriteHalf: AsyncWrite + Unpin; + type ReadHalf: AsyncRead + Unpin + Send + 'static; + type WriteHalf: AsyncWrite + Unpin + Send + 'static; fn combine(read_half: Self::ReadHalf, write_half: Self::WriteHalf) -> Result; fn split(self) -> (Self::ReadHalf, Self::WriteHalf); + + fn stanza_rw(self) -> (StanzaRead, StanzaWrite); + + fn split_boxed(self) -> (Box, Box) { + let (rd, wr) = self.split(); + (Box::new(rd), Box::new(wr)) + } } impl Split for TcpStream { @@ -66,9 +175,56 @@ impl Split for TcpStream { fn split(self) -> (Self::ReadHalf, Self::WriteHalf) { self.into_split() } + + fn stanza_rw(self) -> (StanzaRead, StanzaWrite) { + let (in_rd, in_wr) = self.into_split(); + (StanzaRead::new(in_rd), StanzaWrite::new(in_wr)) + } } -impl Split for BufStream { +#[cfg(feature = "tokio-rustls")] +impl Split for tokio_rustls::server::TlsStream { + type ReadHalf = tokio::io::ReadHalf>; + type WriteHalf = tokio::io::WriteHalf>; + + fn combine(read_half: Self::ReadHalf, write_half: Self::WriteHalf) -> Result { + if read_half.is_pair_of(&write_half) { + Ok(read_half.unsplit(write_half)) + } else { + bail!("non-matching read/write half") + } + } + + fn split(self) -> (Self::ReadHalf, Self::WriteHalf) { + tokio::io::split(self) + } + + fn stanza_rw(self) -> (StanzaRead, StanzaWrite) { + let (in_rd, in_wr) = tokio::io::split(self); + (StanzaRead::new(in_rd), StanzaWrite::new(in_wr)) + } +} + +#[cfg(not(target_os = "windows"))] +impl Split for UnixStream { + type ReadHalf = tokio::net::unix::OwnedReadHalf; + type WriteHalf = tokio::net::unix::OwnedWriteHalf; + + fn combine(read_half: Self::ReadHalf, write_half: Self::WriteHalf) -> Result { + Ok(read_half.reunite(write_half)?) + } + + fn split(self) -> (Self::ReadHalf, Self::WriteHalf) { + self.into_split() + } + + fn stanza_rw(self) -> (StanzaRead, StanzaWrite) { + let (in_rd, in_wr) = self.into_split(); + (StanzaRead::new(in_rd), StanzaWrite::new(in_wr)) + } +} + +impl Split for BufStream { type ReadHalf = tokio::io::ReadHalf>; type WriteHalf = tokio::io::WriteHalf>; @@ -83,13 +239,18 @@ impl Split for BufStream { fn split(self) -> (Self::ReadHalf, Self::WriteHalf) { tokio::io::split(self) } + + fn stanza_rw(self) -> (StanzaRead, StanzaWrite) { + let (in_rd, in_wr) = tokio::io::split(self); + (StanzaRead::already_buffered(in_rd), StanzaWrite::new(in_wr)) + } } #[async_trait] pub trait Peek { async fn peek_bytes<'a>(&mut self, p: &'a mut [u8]) -> anyhow::Result<&'a [u8]>; - async fn first_bytes_match<'a>(&mut self, p: &'a mut [u8], matcher: fn(&'a [u8]) -> bool) -> anyhow::Result { + async fn first_bytes_match<'a>(&mut self, p: &'a mut [u8], matcher: fn(&'a [u8]) -> bool) -> Result { Ok(matcher(self.peek_bytes(p).await?)) } } @@ -170,6 +331,23 @@ impl Peek for BufReader { } } +/// Caution: this will loop forever, call timeout variant `first_bytes_match_buf_timeout` +async fn first_bytes_match_buf(stream: &mut (dyn AsyncBufRead + Send + Unpin), len: usize, matcher: fn(&[u8]) -> bool) -> Result { + use tokio::io::AsyncBufReadExt; + loop { + let buf = stream.fill_buf().await?; + if buf.len() >= len { + return Ok(matcher(&buf[0..len])); + } + } +} + +pub async fn first_bytes_match_buf_timeout(stream: &mut (dyn AsyncBufRead + Send + Unpin), len: usize, matcher: fn(&[u8]) -> bool) -> Result { + // wait up to 10 seconds until 3 bytes have been read + use std::time::Duration; + tokio::time::timeout(Duration::from_secs(10), first_bytes_match_buf(stream, len, matcher)).await? +} + pub async fn stream_preamble(in_rd: &mut StanzaRead, in_wr: &mut StanzaWrite, client_addr: &'_ str, in_filter: &mut StanzaFilter) -> Result<(Vec, bool)> { let mut stream_open = Vec::new(); while let Ok(Some((buf, _))) = in_rd.next(in_filter, client_addr, in_wr).await { diff --git a/src/common/stream_listener.rs b/src/common/stream_listener.rs new file mode 100644 index 0000000..77e70b9 --- /dev/null +++ b/src/common/stream_listener.rs @@ -0,0 +1,111 @@ +use crate::common::AsyncReadWritePeekSplit; +use anyhow::Result; +use async_trait::async_trait; +use std::io::{Error, IoSlice}; + +use std::{ + net::SocketAddr, + pin::Pin, + task::{Context, Poll}, +}; + +#[cfg(not(target_os = "windows"))] +use tokio::net::{UnixListener, UnixStream}; +use tokio::{ + io::{AsyncRead, AsyncWrite, BufStream, ReadBuf}, + net::{TcpListener, TcpStream}, +}; + +#[async_trait] +pub trait StreamListener: Send + Sync + 'static { + type Stream: AsyncReadWritePeekSplit; + async fn accept(&self) -> Result<(Self::Stream, SocketAddr)>; + fn local_addr(&self) -> Result; +} + +#[async_trait] +impl StreamListener for TcpListener { + type Stream = TcpStream; + async fn accept(&self) -> Result<(Self::Stream, SocketAddr)> { + Ok(self.accept().await?) + } + + fn local_addr(&self) -> Result { + Ok(self.local_addr()?) + } +} + +#[cfg(not(target_os = "windows"))] +#[async_trait] +impl StreamListener for UnixListener { + type Stream = BufStream; + async fn accept(&self) -> Result<(Self::Stream, SocketAddr)> { + let (stream, _client_addr) = self.accept().await?; + // todo: real SocketAddr + let client_addr: SocketAddr = "127.0.0.1:0".parse()?; + Ok((BufStream::new(stream), client_addr)) + } + + fn local_addr(&self) -> Result { + // todo: real SocketAddr + Ok("127.0.0.1:0".parse()?) + } +} + +pub enum AllStream { + Tcp(TcpStream), + #[cfg(not(target_os = "windows"))] + Unix(UnixStream), +} + +impl AsyncRead for AllStream { + fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { + match self.get_mut() { + AllStream::Tcp(s) => Pin::new(s).poll_read(cx, buf), + #[cfg(not(target_os = "windows"))] + AllStream::Unix(s) => Pin::new(s).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for AllStream { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + match self.get_mut() { + AllStream::Tcp(s) => Pin::new(s).poll_write(cx, buf), + #[cfg(not(target_os = "windows"))] + AllStream::Unix(s) => Pin::new(s).poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + AllStream::Tcp(s) => Pin::new(s).poll_flush(cx), + #[cfg(not(target_os = "windows"))] + AllStream::Unix(s) => Pin::new(s).poll_flush(cx), + } + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + AllStream::Tcp(s) => Pin::new(s).poll_shutdown(cx), + #[cfg(not(target_os = "windows"))] + AllStream::Unix(s) => Pin::new(s).poll_shutdown(cx), + } + } + + fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) -> Poll> { + match self.get_mut() { + AllStream::Tcp(s) => Pin::new(s).poll_write_vectored(cx, bufs), + #[cfg(not(target_os = "windows"))] + AllStream::Unix(s) => Pin::new(s).poll_write_vectored(cx, bufs), + } + } + + fn is_write_vectored(&self) -> bool { + match self { + AllStream::Tcp(s) => s.is_write_vectored(), + #[cfg(not(target_os = "windows"))] + AllStream::Unix(s) => s.is_write_vectored(), + } + } +} diff --git a/src/context.rs b/src/context.rs index 380e186..14b361d 100644 --- a/src/context.rs +++ b/src/context.rs @@ -13,7 +13,7 @@ pub struct Context<'a> { proto: &'a str, is_c2s: Option, to: Option, - to_addr: Option, + to_addr: Option, from: Option, client_addr: SocketAddr, } @@ -59,7 +59,7 @@ impl<'a> Context<'a> { if self.from.is_some() { self.from.as_ref().unwrap() } else { "unk" }, self.proto, if self.is_c2s.is_some() { c2s(self.is_c2s.unwrap()) } else { "unk" }, - if self.to_addr.is_some() { self.to_addr.as_ref().unwrap().to_string() } else { "unk".to_string() }, + if self.to_addr.is_some() { self.to_addr.as_ref().unwrap() } else { "unk" }, if self.to.is_some() { self.to.as_ref().unwrap() } else { "unk" }, ); self.log_to = self.log_from.replace(" -> ", " <- "); @@ -103,7 +103,7 @@ impl<'a> Context<'a> { } } - pub fn set_to_addr(&mut self, to_addr: SocketAddr) { + pub fn set_to_addr(&mut self, to_addr: String) { if log_enabled!(log::Level::Info) { self.to_addr = Some(to_addr); self.re_calc(); diff --git a/src/main.rs b/src/main.rs index 1b10038..c30f66d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,35 +3,42 @@ use anyhow::Result; use die::{die, Die}; use log::{debug, info}; use serde_derive::Deserialize; -use std::{ - ffi::OsString, - fs::File, - io::Read, - iter::Iterator, - net::{SocketAddr, UdpSocket}, - path::Path, - sync::Arc, -}; +use std::{ffi::OsString, fs::File, io::Read, iter::Iterator, path::Path, sync::Arc}; use tokio::{net::TcpListener, task::JoinHandle}; -use xmpp_proxy::common::certs_key::CertsKey; +use xmpp_proxy::common::{certs_key::CertsKey, Listener, SocketAddrPath, UdpListener}; + +#[cfg(not(target_os = "windows"))] +use tokio::net::UnixListener; + #[cfg(feature = "outgoing")] use xmpp_proxy::{common::outgoing::OutgoingConfig, outgoing::spawn_outgoing_listener}; #[derive(Deserialize, Default)] struct Config { + #[serde(default)] tls_key: String, + #[serde(default)] tls_cert: String, - incoming_listen: Vec, - quic_listen: Vec, - outgoing_listen: Vec, + #[serde(default)] + incoming_listen: Vec, + #[serde(default)] + quic_listen: Vec, + #[serde(default)] + outgoing_listen: Vec, + #[serde(default = "default_max_stanza_size_bytes")] max_stanza_size_bytes: usize, - s2s_target: Option, - c2s_target: Option, + s2s_target: Option, + c2s_target: Option, + #[serde(default)] proxy: bool, log_level: Option, log_style: Option, } +fn default_max_stanza_size_bytes() -> usize { + 262_144 +} + impl Config { fn parse>(path: P) -> Result { let mut f = File::open(path)?; @@ -41,13 +48,13 @@ impl Config { } #[cfg(feature = "incoming")] - fn get_cloneable_cfg(&self) -> xmpp_proxy::common::incoming::CloneableConfig { - xmpp_proxy::common::incoming::CloneableConfig { + fn get_incoming_cfg(&self) -> xmpp_proxy::common::incoming::IncomingConfig { + xmpp_proxy::common::incoming::IncomingConfig { max_stanza_size_bytes: self.max_stanza_size_bytes, #[cfg(feature = "s2s-incoming")] - s2s_target: self.s2s_target, + s2s_target: self.s2s_target.clone(), #[cfg(feature = "c2s-incoming")] - c2s_target: self.c2s_target, + c2s_target: self.c2s_target.clone(), proxy: self.proxy, } } @@ -137,31 +144,31 @@ async fn main() { let mut incoming_listen = Vec::new(); for a in main_config.incoming_listen.iter() { - incoming_listen.push(TcpListener::bind(a).await.die("cannot listen on port/interface")); + incoming_listen.push(a.bind().await.die("cannot listen on port/interface/socket")); } let mut quic_listen = Vec::new(); for a in main_config.quic_listen.iter() { - quic_listen.push(UdpSocket::bind(a).die("cannot listen on port/interface")); + quic_listen.push(a.bind_udp().await.die("cannot listen on port/interface/socket")); } let mut outgoing_listen = Vec::new(); for a in main_config.outgoing_listen.iter() { - outgoing_listen.push(TcpListener::bind(a).await.die("cannot listen on port/interface")); + outgoing_listen.push(a.bind().await.die("cannot listen on port/interface/socket")); } #[cfg(all(feature = "nix", not(target_os = "windows")))] if let Ok(fds) = xmpp_proxy::systemd::receive_descriptors_with_names(true) { - use xmpp_proxy::systemd::Listener; + use xmpp_proxy::systemd::SystemdListener; for fd in fds { match fd.listener() { - Listener::Tcp(tcp_listener) => { - let tcp_listener = TcpListener::from_std(tcp_listener()).die("cannot open systemd TcpListener"); + SystemdListener::Tcp(tcp_listener) => { + let listener = Listener::Tcp(TcpListener::from_std(tcp_listener()).die("cannot open systemd TcpListener")); if let Some(name) = fd.name().map(|n| n.to_ascii_lowercase()) { if name.starts_with("in") { - incoming_listen.push(tcp_listener); + incoming_listen.push(listener); } else if name.starts_with("out") { - outgoing_listen.push(tcp_listener); + outgoing_listen.push(listener); } else { - die!("systemd socket name must start with 'in' or 'out' but is '{}'", name); + die!("systemd TCP socket name must start with 'in' or 'out' but is '{}'", name); } } else { // what to do here... for now we will require names @@ -169,14 +176,29 @@ async fn main() { die!("systemd TCP socket activation requires name that starts with 'in' or 'out'"); } } - Listener::Udp(udp_socket) => quic_listen.push(udp_socket()), - _ => continue, + SystemdListener::UnixListener(unix_listener) => { + let listener = Listener::Unix(UnixListener::from_std(unix_listener()).die("cannot open systemd UnixListener")); + if let Some(name) = fd.name().map(|n| n.to_ascii_lowercase()) { + if name.starts_with("in") { + incoming_listen.push(listener); + } else if name.starts_with("out") { + outgoing_listen.push(listener); + } else { + die!("systemd Unix socket name must start with 'in' or 'out' but is '{}'", name); + } + } else { + // what to do here... for now we will require names + die!("systemd Unix socket activation requires name that starts with 'in' or 'out'"); + } + } + SystemdListener::Udp(udp_socket) => quic_listen.push(UdpListener::Udp(udp_socket())), + SystemdListener::UnixDatagram(unix_datagram) => quic_listen.push(UdpListener::Unix(unix_datagram())), } } } #[cfg(feature = "incoming")] - let config = main_config.get_cloneable_cfg(); + let config = Arc::new(main_config.get_incoming_cfg()); let certs_key = Arc::new(CertsKey::new(main_config.certs_key())); @@ -193,7 +215,13 @@ async fn main() { } let acceptor = tls_acceptor(server_config(certs_key.clone()).die("invalid cert/key ?")); for listener in incoming_listen { - handles.push(spawn_tls_listener(listener, config.clone(), acceptor.clone())); + // todo: first is slower at runtime but smaller executable size, second opposite + //handles.push(spawn_tls_listener(listener, config.clone(), acceptor.clone())); + match listener { + Listener::Tcp(listener) => handles.push(spawn_tls_listener(listener, config.clone(), acceptor.clone())), + #[cfg(not(target_os = "windows"))] + Listener::Unix(listener) => handles.push(spawn_tls_listener(listener, config.clone(), acceptor.clone())), + } } } #[cfg(not(all(any(feature = "tls", feature = "websocket"), feature = "incoming")))] @@ -211,7 +239,13 @@ async fn main() { } let quic_config = quic_server_config(server_config(certs_key.clone()).die("invalid cert/key ?")); for listener in quic_listen { - handles.push(spawn_quic_listener(listener, config.clone(), quic_config.clone())); + // todo: maybe write a way to Box this thing for smaller executable sizes + //handles.push(spawn_quic_listener(listener, config.clone(), quic_config.clone())); + match listener { + UdpListener::Udp(listener) => handles.push(spawn_quic_listener(listener, config.clone(), quic_config.clone())), + #[cfg(not(target_os = "windows"))] + UdpListener::Unix(listener) => handles.push(xmpp_proxy::quic::incoming::spawn_quic_listener_unix(listener, config.clone(), quic_config.clone())), + } } } #[cfg(not(all(feature = "quic", feature = "incoming")))] @@ -222,8 +256,15 @@ async fn main() { { let outgoing_cfg = main_config.get_outgoing_cfg(certs_key.clone()); for listener in outgoing_listen { - handles.push(spawn_outgoing_listener(listener, outgoing_cfg.clone())); + // todo: first is slower at runtime but smaller executable size, second opposite + //handles.push(spawn_outgoing_listener(listener, outgoing_cfg.clone())); + match listener { + Listener::Tcp(listener) => handles.push(spawn_outgoing_listener(listener, outgoing_cfg.clone())), + #[cfg(not(target_os = "windows"))] + Listener::Unix(listener) => handles.push(spawn_outgoing_listener(listener, outgoing_cfg.clone())), + } } + //#[cfg(not(target_os = "windows"))] } #[cfg(not(feature = "outgoing"))] die!("outgoing_listen non-empty but c2s-outgoing and s2s-outgoing disabled at compile-time"); diff --git a/src/outgoing.rs b/src/outgoing.rs index 1b8e7f8..8012d44 100644 --- a/src/outgoing.rs +++ b/src/outgoing.rs @@ -1,16 +1,16 @@ use crate::{ - common::{outgoing::OutgoingConfig, shuffle_rd_wr_filter_only, stream_preamble, Peek}, + common::{outgoing::OutgoingConfig, shuffle_rd_wr_filter_only, stream_listener::StreamListener, stream_preamble, AsyncReadWritePeekSplit}, context::Context, - in_out::{StanzaRead, StanzaWrite}, slicesubsequence::SliceSubsequence, srv::srv_connect, stanzafilter::StanzaFilter, }; use anyhow::Result; use log::{error, info}; -use tokio::{net::TcpListener, task::JoinHandle}; -async fn handle_outgoing_connection(mut stream: tokio::net::TcpStream, client_addr: &mut Context<'_>, config: OutgoingConfig) -> Result<()> { +use tokio::task::JoinHandle; + +async fn handle_outgoing_connection(mut stream: S, client_addr: &mut Context<'_>, config: OutgoingConfig) -> Result<()> { info!("{} connected", client_addr.log_from()); let mut in_filter = StanzaFilter::new(config.max_stanza_size_bytes); @@ -19,15 +19,11 @@ async fn handle_outgoing_connection(mut stream: tokio::net::TcpStream, client_ad let (mut in_rd, mut in_wr) = if stream.first_bytes_match(&mut in_filter.buf[0..3], |p| p == b"GET").await? { crate::websocket::incoming_websocket_connection(Box::new(stream), config.max_stanza_size_bytes).await? } else { - let (in_rd, in_wr) = tokio::io::split(stream); - (StanzaRead::new(in_rd), StanzaWrite::new(in_wr)) + stream.stanza_rw() }; #[cfg(not(feature = "websocket"))] - let (mut in_rd, mut in_wr) = { - let (in_rd, in_wr) = tokio::io::split(stream); - (StanzaRead::new(in_rd), StanzaWrite::new(in_wr)) - }; + let (mut in_rd, mut in_wr) = stream.stanza_rw(); // now read to figure out client vs server let (stream_open, is_c2s) = stream_preamble(&mut in_rd, &mut in_wr, client_addr.log_to(), &mut in_filter).await?; @@ -46,13 +42,13 @@ async fn handle_outgoing_connection(mut stream: tokio::net::TcpStream, client_ad shuffle_rd_wr_filter_only(in_rd, in_wr, out_rd, out_wr, is_c2s, max_stanza_size_bytes, client_addr, in_filter).await } -pub fn spawn_outgoing_listener(listener: TcpListener, config: OutgoingConfig) -> JoinHandle> { +pub fn spawn_outgoing_listener(listener: impl StreamListener, config: OutgoingConfig) -> JoinHandle> { tokio::spawn(async move { loop { let (stream, client_addr) = listener.accept().await?; + let mut client_addr = Context::new("unk-out", client_addr); let config = config.clone(); tokio::spawn(async move { - let mut client_addr = Context::new("unk-out", client_addr); if let Err(e) = handle_outgoing_connection(stream, &mut client_addr, config).await { error!("{} {}", client_addr.log_from(), e); } diff --git a/src/quic/incoming.rs b/src/quic/incoming.rs index fa36a9d..114d938 100644 --- a/src/quic/incoming.rs +++ b/src/quic/incoming.rs @@ -1,18 +1,34 @@ use crate::{ - common::incoming::{shuffle_rd_wr, CloneableConfig, ServerCerts}, + common::incoming::{shuffle_rd_wr, IncomingConfig, ServerCerts}, context::Context, in_out::{StanzaRead, StanzaWrite}, }; use anyhow::Result; use die::Die; use log::{error, info}; -use quinn::{Endpoint, EndpointConfig, ServerConfig, TokioRuntime}; -use std::{net::UdpSocket, sync::Arc}; +use quinn::{AsyncUdpSocket, Endpoint, EndpointConfig, ServerConfig, TokioRuntime}; +use std::{ + net::{SocketAddr, UdpSocket}, + sync::Arc, +}; use tokio::task::JoinHandle; -pub fn spawn_quic_listener(udp_socket: UdpSocket, config: CloneableConfig, server_config: ServerConfig) -> JoinHandle> { +#[cfg(not(target_os = "windows"))] +pub fn spawn_quic_listener_unix(udp_socket: std::os::unix::net::UnixDatagram, config: Arc, server_config: ServerConfig) -> JoinHandle> { + let udp_socket = crate::quic::unix_datagram::wrap_unix_udp_socket(udp_socket).die("cannot wrap unix udp socket"); + // todo: fake local_addr + let local_addr = udp_socket.local_addr().die("cannot get local_addr for quic socket"); + let incoming = Endpoint::new_with_abstract_socket(EndpointConfig::default(), Some(server_config), udp_socket, Arc::new(TokioRuntime)).die("cannot listen on port/interface"); + internal_spawn_quic_listener(incoming, local_addr, config) +} + +pub fn spawn_quic_listener(udp_socket: UdpSocket, config: Arc, server_config: ServerConfig) -> JoinHandle> { let local_addr = udp_socket.local_addr().die("cannot get local_addr for quic socket"); let incoming = Endpoint::new(EndpointConfig::default(), Some(server_config), udp_socket, Arc::new(TokioRuntime)).die("cannot listen on port/interface"); + internal_spawn_quic_listener(incoming, local_addr, config) +} + +fn internal_spawn_quic_listener(incoming: Endpoint, local_addr: SocketAddr, config: Arc) -> JoinHandle> { tokio::spawn(async move { // when could this return None, do we quit? while let Some(incoming_conn) = incoming.accept().await { diff --git a/src/quic/mod.rs b/src/quic/mod.rs index 3a7bba1..4af4608 100644 --- a/src/quic/mod.rs +++ b/src/quic/mod.rs @@ -1,4 +1,7 @@ -use crate::common::Split; +use crate::{ + common::Split, + in_out::{StanzaRead, StanzaWrite}, +}; use anyhow::bail; use quinn::{RecvStream, SendStream}; use std::{ @@ -14,6 +17,9 @@ pub mod incoming; #[cfg(feature = "outgoing")] pub mod outgoing; +#[cfg(all(feature = "incoming", not(target_os = "windows")))] +pub mod unix_datagram; + pub struct QuicStream { pub send: SendStream, pub recv: RecvStream, @@ -54,4 +60,8 @@ impl Split for QuicStream { fn split(self) -> (Self::ReadHalf, Self::WriteHalf) { (self.recv, self.send) } + + fn stanza_rw(self) -> (StanzaRead, StanzaWrite) { + (StanzaRead::new(self.recv), StanzaWrite::new(self.send)) + } } diff --git a/src/quic/unix_datagram.rs b/src/quic/unix_datagram.rs new file mode 100644 index 0000000..ba8b664 --- /dev/null +++ b/src/quic/unix_datagram.rs @@ -0,0 +1,63 @@ +use quinn::{udp, AsyncUdpSocket}; + +use std::{ + io, + task::{Context, Poll}, +}; +use tokio::net::UnixDatagram; + +use tokio::io::Interest; + +macro_rules! ready { + ($e:expr $(,)?) => { + match $e { + std::task::Poll::Ready(t) => t, + std::task::Poll::Pending => return std::task::Poll::Pending, + } + }; +} + +pub fn wrap_unix_udp_socket(sock: std::os::unix::net::UnixDatagram) -> io::Result { + udp::UdpSocketState::configure((&sock).into())?; + Ok(UnixUdpSocket { + io: UnixDatagram::from_std(sock)?, + inner: udp::UdpSocketState::new(), + }) +} + +#[derive(Debug)] +pub struct UnixUdpSocket { + io: UnixDatagram, + inner: udp::UdpSocketState, +} + +impl AsyncUdpSocket for UnixUdpSocket { + fn poll_send(&self, state: &udp::UdpState, cx: &mut Context, transmits: &[udp::Transmit]) -> Poll> { + let inner = &self.inner; + let io = &self.io; + loop { + ready!(io.poll_send_ready(cx))?; + if let Ok(res) = io.try_io(Interest::WRITABLE, || inner.send(io.into(), state, transmits)) { + return Poll::Ready(Ok(res)); + } + } + } + + fn poll_recv(&self, cx: &mut Context, bufs: &mut [std::io::IoSliceMut<'_>], meta: &mut [udp::RecvMeta]) -> Poll> { + loop { + ready!(self.io.poll_recv_ready(cx))?; + if let Ok(res) = self.io.try_io(Interest::READABLE, || self.inner.recv((&self.io).into(), bufs, meta)) { + return Poll::Ready(Ok(res)); + } + } + } + + fn local_addr(&self) -> io::Result { + // todo: real SocketAddr + Ok("127.0.0.1:0".parse().expect("this one is hardcoded and fine")) + } + + fn may_fragment(&self) -> bool { + udp::may_fragment() + } +} diff --git a/src/srv.rs b/src/srv.rs index 3ffbd89..ed0a68f 100644 --- a/src/srv.rs +++ b/src/srv.rs @@ -473,7 +473,7 @@ pub async fn srv_connect( let (mut out_wr, mut out_rd, to_addr, proto) = connect.unwrap(); // if any of these ? returns early with an Err, these will stay set, I think that's ok though, the connection will be closed client_addr.set_proto(proto); - client_addr.set_to_addr(to_addr); + client_addr.set_to_addr(to_addr.to_string()); debug!("{} connected", client_addr.log_from()); trace!("{} '{}'", client_addr.log_from(), to_str(stream_open)); @@ -906,7 +906,8 @@ mod tests { println!("posh: {:?}", posh); } - //#[tokio::test] + #[cfg(feature = "net-test")] + #[tokio::test] async fn posh() -> Result<()> { let domain = "posh.badxmpp.eu"; let posh = collect_posh(domain).await.unwrap(); @@ -914,7 +915,8 @@ mod tests { Ok(()) } - //#[tokio::test] + #[cfg(feature = "net-test")] + #[tokio::test] async fn srv() -> Result<()> { let domain = "burtrum.org"; let is_c2s = true; @@ -930,7 +932,8 @@ mod tests { Ok(()) } - //#[tokio::test] + #[cfg(feature = "net-test")] + #[tokio::test] async fn http() -> Result<()> { let mut hosts = Vec::new(); let mut sha256_pinnedpubkeys = Vec::new(); @@ -990,77 +993,78 @@ mod tests { #[test] fn test_dedup() { let domain = "example.org"; - let mut ret = Vec::new(); - ret.push(XmppConnection { - priority: 10, - weight: 0, - target: domain.to_string(), - conn_type: XmppConnectionType::DirectTLS, - port: 443, - secure: false, - ips: Vec::new(), - ech: None, - }); - ret.push(XmppConnection { - priority: 0, - weight: 0, - target: domain.to_string(), - conn_type: XmppConnectionType::StartTLS, - port: 5222, - secure: false, - ips: Vec::new(), - ech: None, - }); - ret.push(XmppConnection { - priority: 15, - weight: 0, - target: domain.to_string(), - conn_type: XmppConnectionType::DirectTLS, - port: 443, - secure: true, - ips: Vec::new(), - ech: None, - }); - ret.push(XmppConnection { - priority: 10, - weight: 0, - target: domain.to_string(), - conn_type: XmppConnectionType::DirectTLS, - port: 443, - secure: true, - ips: Vec::new(), - ech: None, - }); - ret.push(XmppConnection { - priority: 10, - weight: 50, - target: domain.to_string(), - conn_type: XmppConnectionType::DirectTLS, - port: 443, - secure: true, - ips: Vec::new(), - ech: None, - }); - ret.push(XmppConnection { - priority: 10, - weight: 100, - target: "example.com".to_string(), - conn_type: XmppConnectionType::DirectTLS, - port: 443, - secure: true, - ips: Vec::new(), - ech: None, - }); - ret.push(XmppConnection { - priority: 0, - weight: 100, - target: "example.com".to_string(), - conn_type: XmppConnectionType::DirectTLS, - port: 443, - secure: true, - ips: Vec::new(), - ech: None, - }); + let mut ret = vec![ + XmppConnection { + priority: 10, + weight: 0, + target: domain.to_string(), + conn_type: XmppConnectionType::DirectTLS, + port: 443, + secure: false, + ips: Vec::new(), + ech: None, + }, + XmppConnection { + priority: 0, + weight: 0, + target: domain.to_string(), + conn_type: XmppConnectionType::StartTLS, + port: 5222, + secure: false, + ips: Vec::new(), + ech: None, + }, + XmppConnection { + priority: 15, + weight: 0, + target: domain.to_string(), + conn_type: XmppConnectionType::DirectTLS, + port: 443, + secure: true, + ips: Vec::new(), + ech: None, + }, + XmppConnection { + priority: 10, + weight: 0, + target: domain.to_string(), + conn_type: XmppConnectionType::DirectTLS, + port: 443, + secure: true, + ips: Vec::new(), + ech: None, + }, + XmppConnection { + priority: 10, + weight: 50, + target: domain.to_string(), + conn_type: XmppConnectionType::DirectTLS, + port: 443, + secure: true, + ips: Vec::new(), + ech: None, + }, + XmppConnection { + priority: 10, + weight: 100, + target: "example.com".to_string(), + conn_type: XmppConnectionType::DirectTLS, + port: 443, + secure: true, + ips: Vec::new(), + ech: None, + }, + XmppConnection { + priority: 0, + weight: 100, + target: "example.com".to_string(), + conn_type: XmppConnectionType::DirectTLS, + port: 443, + secure: true, + ips: Vec::new(), + ech: None, + }, + ]; sort_dedup(&mut ret); println!("ret dedup: {:?}", ret); } diff --git a/src/systemd.rs b/src/systemd.rs index 36b05ab..e41b643 100644 --- a/src/systemd.rs +++ b/src/systemd.rs @@ -24,7 +24,7 @@ pub struct FileDescriptor { pub name: Option, } -pub enum Listener { +pub enum SystemdListener { Tcp(Box TcpListener>), Udp(Box UdpSocket>), UnixListener(Box UnixListener>), @@ -36,13 +36,13 @@ impl FileDescriptor { self.name } - pub fn listener(&self) -> Listener { + pub fn listener(&self) -> SystemdListener { let raw_fd = self.raw_fd; match (self.tcp_not_udp, self.inet_not_unix) { - (true, true) => Listener::Tcp(Box::new(move || unsafe { TcpListener::from_raw_fd(raw_fd) })), - (false, true) => Listener::Udp(Box::new(move || unsafe { UdpSocket::from_raw_fd(raw_fd) })), - (true, false) => Listener::UnixListener(Box::new(move || unsafe { UnixListener::from_raw_fd(raw_fd) })), - (false, false) => Listener::UnixDatagram(Box::new(move || unsafe { UnixDatagram::from_raw_fd(raw_fd) })), + (true, true) => SystemdListener::Tcp(Box::new(move || unsafe { TcpListener::from_raw_fd(raw_fd) })), + (false, true) => SystemdListener::Udp(Box::new(move || unsafe { UdpSocket::from_raw_fd(raw_fd) })), + (true, false) => SystemdListener::UnixListener(Box::new(move || unsafe { UnixListener::from_raw_fd(raw_fd) })), + (false, false) => SystemdListener::UnixDatagram(Box::new(move || unsafe { UnixDatagram::from_raw_fd(raw_fd) })), } } } diff --git a/src/tls/incoming.rs b/src/tls/incoming.rs index 1432424..d04c0a4 100644 --- a/src/tls/incoming.rs +++ b/src/tls/incoming.rs @@ -1,7 +1,9 @@ use crate::{ common::{ - incoming::{shuffle_rd_wr_filter, CloneableConfig, ServerCerts}, - to_str, Peek, Split, IN_BUFFER_SIZE, + first_bytes_match_buf_timeout, + incoming::{shuffle_rd_wr_filter, IncomingConfig, ServerCerts}, + stream_listener::StreamListener, + to_str, AsyncReadWritePeekSplit, Split, IN_BUFFER_SIZE, }, context::Context, in_out::{StanzaRead, StanzaWrite}, @@ -13,8 +15,7 @@ use log::{error, info, trace}; use rustls::{ServerConfig, ServerConnection}; use std::{net::SocketAddr, sync::Arc}; use tokio::{ - io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufStream}, - net::TcpListener, + io::{AsyncWriteExt, BufStream}, task::JoinHandle, }; use tokio_rustls::TlsAcceptor; @@ -23,7 +24,7 @@ pub fn tls_acceptor(server_config: ServerConfig) -> TlsAcceptor { TlsAcceptor::from(Arc::new(server_config)) } -pub fn spawn_tls_listener(listener: TcpListener, config: CloneableConfig, acceptor: TlsAcceptor) -> JoinHandle> { +pub fn spawn_tls_listener(listener: impl StreamListener, config: Arc, acceptor: TlsAcceptor) -> JoinHandle> { tokio::spawn(async move { let local_addr = listener.local_addr()?; loop { @@ -40,13 +41,7 @@ pub fn spawn_tls_listener(listener: TcpListener, config: CloneableConfig, accept }) } -pub async fn handle_tls_connection( - mut stream: S, - client_addr: &mut Context<'_>, - local_addr: SocketAddr, - config: CloneableConfig, - acceptor: TlsAcceptor, -) -> Result<()> { +pub async fn handle_tls_connection(mut stream: S, client_addr: &mut Context<'_>, local_addr: SocketAddr, config: Arc, acceptor: TlsAcceptor) -> Result<()> { info!("{} connected", client_addr.log_from()); let mut in_filter = StanzaFilter::new(config.max_stanza_size_bytes); @@ -67,11 +62,10 @@ pub async fn handle_tls_connection::combine(in_rd.0.into_inner(), in_wr)? + ::combine(in_rd.0, in_wr)? } else { stream }; @@ -143,20 +137,19 @@ pub async fn handle_tls_connection, - config: CloneableConfig, + stream: BoxAsyncReadWrite, + config: Arc, server_certs: ServerCerts, local_addr: SocketAddr, client_addr: &mut Context<'_>, diff --git a/src/websocket/mod.rs b/src/websocket/mod.rs index 3a50e3f..7d53744 100644 --- a/src/websocket/mod.rs +++ b/src/websocket/mod.rs @@ -16,8 +16,8 @@ pub mod incoming; #[cfg(feature = "outgoing")] pub mod outgoing; -pub type WsWr = SplitSink>, tokio_tungstenite::tungstenite::Message>; -pub type WsRd = SplitStream>>; +pub type WsWr = SplitSink, tokio_tungstenite::tungstenite::Message>; +pub type WsRd = SplitStream>; // https://datatracker.ietf.org/doc/html/rfc7395 @@ -30,11 +30,7 @@ fn ws_cfg(max_stanza_size_bytes: usize) -> Option { }) } -pub trait AsyncReadAndWrite: tokio::io::AsyncRead + tokio::io::AsyncWrite {} - -impl AsyncReadAndWrite for T {} - -pub async fn incoming_websocket_connection(stream: Box, max_stanza_size_bytes: usize) -> Result<(StanzaRead, StanzaWrite)> { +pub async fn incoming_websocket_connection(stream: BoxAsyncReadWrite, max_stanza_size_bytes: usize) -> Result<(StanzaRead, StanzaWrite)> { // accept the websocket let stream = tokio_tungstenite::accept_hdr_async_with_config( stream, @@ -118,6 +114,7 @@ pub fn to_ws_new(buf: &[u8], mut end_of_first_tag: usize, is_c2s: bool) -> Resul } use crate::{ + common::BoxAsyncReadWrite, in_out::{StanzaRead, StanzaWrite}, slicesubsequence::SliceSubsequence, }; diff --git a/src/websocket/outgoing.rs b/src/websocket/outgoing.rs index f06cbb5..fd3124f 100644 --- a/src/websocket/outgoing.rs +++ b/src/websocket/outgoing.rs @@ -1,7 +1,7 @@ use crate::{ - common::outgoing::OutgoingVerifierConfig, + common::{outgoing::OutgoingVerifierConfig, BoxAsyncReadWrite}, in_out::{StanzaRead, StanzaWrite}, - websocket::{ws_cfg, AsyncReadAndWrite}, + websocket::ws_cfg, }; use anyhow::Result; use futures_util::StreamExt; @@ -27,7 +27,7 @@ pub async fn websocket_connect(target: SocketAddr, server_name: &str, url: &Uri, //let stream: tokio_rustls::TlsStream = stream.into(); // todo: tokio_tungstenite seems to have a bug, if the write buffer is non-zero, it'll hang forever, even though we always flush, investigate //let stream = BufStream::with_capacity(crate::IN_BUFFER_SIZE, 0, stream); - let stream: Box = Box::new(stream); + let stream: BoxAsyncReadWrite = Box::new(stream); let (stream, _) = tokio_tungstenite::client_async_with_config(request, stream, ws_cfg(config.max_stanza_size_bytes)).await?; diff --git a/xmpp-proxy.toml b/xmpp-proxy.toml index 3e3f6e0..b1f7618 100644 --- a/xmpp-proxy.toml +++ b/xmpp-proxy.toml @@ -3,7 +3,7 @@ incoming_listen = [ "0.0.0.0:5222", "0.0.0.0:5269", "0.0.0.0:443" ] # interfaces to listen for reverse proxy QUIC XMPP connections on, should be open to the internet quic_listen = [ "0.0.0.0:443" ] -# interfaces to listen for outgoing proxy TCP or WebSocket XMPP connections on, should be localhost +# interfaces to listen for outgoing proxy TCP or WebSocket XMPP connections on, should be localhost or a path for a unix socket outgoing_listen = [ "127.0.0.1:15270" ] # these ports shouldn't do any TLS, but should assume any connection from xmpp-proxy is secure