Implement unix socket support in all directions
All checks were successful
moparisthebest/xmpp-proxy/pipeline/head This commit looks good
All checks were successful
moparisthebest/xmpp-proxy/pipeline/head This commit looks good
This commit is contained in:
parent
2459576f52
commit
3bdb461142
5
.ci/Jenkinsfile
vendored
5
.ci/Jenkinsfile
vendored
@ -28,9 +28,14 @@ node('linux && docker') {
|
|||||||
|
|
||||||
stage('Build + Deploy') {
|
stage('Build + Deploy') {
|
||||||
sh '''
|
sh '''
|
||||||
|
./check-all-features.sh || exit 1
|
||||||
|
cargo clean
|
||||||
mkdir -p release
|
mkdir -p release
|
||||||
cp xmpp-proxy.toml release
|
cp xmpp-proxy.toml release
|
||||||
curl --compressed -sL https://code.moparisthebest.com/moparisthebest/self-ci/raw/branch/master/build-ci.sh | bash
|
curl --compressed -sL https://code.moparisthebest.com/moparisthebest/self-ci/raw/branch/master/build-ci.sh | bash
|
||||||
|
ret=$?
|
||||||
|
docker system prune -af
|
||||||
|
exit $ret
|
||||||
'''
|
'''
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -11,6 +11,10 @@ echo "$TARGET" | grep -E '^x86_64-pc-windows-gnu$' >/dev/null && SUFFIX=".exe"
|
|||||||
|
|
||||||
# ring fails to compile here
|
# ring fails to compile here
|
||||||
echo "$TARGET" | grep -E '^(s390x|powerpc|mips|riscv64gc|.*solaris$)' >/dev/null && echo "$TARGET not supported in rustls" && exit 0
|
echo "$TARGET" | grep -E '^(s390x|powerpc|mips|riscv64gc|.*solaris$)' >/dev/null && echo "$TARGET not supported in rustls" && exit 0
|
||||||
|
|
||||||
|
# running `docker system prune -af` after these because they are roughly every 25% through and my hard drive space is limited
|
||||||
|
echo "$TARGET" | grep -E '^(armv7-unknown-linux-gnueabihf|x86_64-linux-android|mips-unknown-linux-gnu)$' >/dev/null && docker system prune -af
|
||||||
|
|
||||||
# mio fails to link here
|
# mio fails to link here
|
||||||
echo "$TARGET" | grep -E '^x86_64-unknown-netbsd$' >/dev/null && echo "$TARGET not supported in mio" && exit 0
|
echo "$TARGET" | grep -E '^x86_64-unknown-netbsd$' >/dev/null && echo "$TARGET not supported in mio" && exit 0
|
||||||
|
|
||||||
|
2
.gitignore
vendored
2
.gitignore
vendored
@ -6,5 +6,5 @@
|
|||||||
**/out/
|
**/out/
|
||||||
**/core.*
|
**/core.*
|
||||||
fuzz/target/
|
fuzz/target/
|
||||||
todo.txt
|
*.txt
|
||||||
conflict/
|
conflict/
|
||||||
|
@ -32,7 +32,7 @@ serde = { version = "1.0", features = ["derive"] }
|
|||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
die = "0.2"
|
die = "0.2"
|
||||||
anyhow = "1.0"
|
anyhow = "1.0"
|
||||||
tokio = { version = "1.9", features = ["net", "rt", "rt-multi-thread", "macros", "io-util", "signal"] }
|
tokio = { version = "1.9", features = ["net", "rt", "rt-multi-thread", "macros", "io-util", "signal", "time"] }
|
||||||
ring = "0.16"
|
ring = "0.16"
|
||||||
data-encoding = "2.3"
|
data-encoding = "2.3"
|
||||||
async-trait = "0.1"
|
async-trait = "0.1"
|
||||||
@ -101,5 +101,8 @@ websocket = ["tokio-tungstenite", "futures-util", "tls"] # websocket+incoming al
|
|||||||
logging = ["rand", "env_logger"]
|
logging = ["rand", "env_logger"]
|
||||||
systemd = ["nix"]
|
systemd = ["nix"]
|
||||||
|
|
||||||
|
# enables unit tests that need network and therefore may be flaky
|
||||||
|
net-test = []
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
|
@ -1,12 +1,20 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
threads="$1"
|
threads="$1"
|
||||||
|
shift
|
||||||
|
clean_after_num_builds="$1"
|
||||||
|
|
||||||
set -euo pipefail
|
set -euo pipefail
|
||||||
|
|
||||||
# if we have access to nproc, divide that by 2, otherwise use 1 thread by default
|
# if we have access to nproc, divide that by 2, otherwise use 1 thread by default
|
||||||
[ "$threads" == "" ] && threads=$(($(nproc || echo 2) / 2))
|
[ "$threads" == "" ] && threads=$(($(nproc || echo 2) / 2))
|
||||||
|
|
||||||
|
# 50 is about 1.5gb, ymmv
|
||||||
|
[ "$clean_after_num_builds" == "" ] && clean_after_num_builds=50
|
||||||
|
|
||||||
|
export clean_after_num_builds
|
||||||
|
|
||||||
echo "threads: $threads"
|
echo "threads: $threads"
|
||||||
|
echo "clean_after_num_builds: $clean_after_num_builds"
|
||||||
|
|
||||||
export RUSTFLAGS=-Awarnings
|
export RUSTFLAGS=-Awarnings
|
||||||
|
|
||||||
@ -65,12 +73,26 @@ echo_cargo() {
|
|||||||
#echo cargo run "$@" -- -v
|
#echo cargo run "$@" -- -v
|
||||||
#cargo run "$@" -- -v
|
#cargo run "$@" -- -v
|
||||||
echo cargo check "$@"
|
echo cargo check "$@"
|
||||||
cargo check "$@"
|
flock -s /tmp/xmpp-proxy-check-all-features.lock cargo check "$@"
|
||||||
ret=$?
|
ret=$?
|
||||||
if [ $ret -ne 0 ]
|
if [ $ret -ne 0 ]
|
||||||
then
|
then
|
||||||
echo "features failed: $@"
|
echo "command failed: cargo check $@"
|
||||||
fi
|
fi
|
||||||
|
(
|
||||||
|
flock -x 200
|
||||||
|
# now we are under an exclusive lock
|
||||||
|
count=$(cat /tmp/xmpp-proxy-check-all-features.count)
|
||||||
|
count=$(( count + 1 ))
|
||||||
|
if [ $count -ge $clean_after_num_builds ]
|
||||||
|
then
|
||||||
|
echo cargo clean
|
||||||
|
cargo clean
|
||||||
|
count=0
|
||||||
|
fi
|
||||||
|
echo $count > /tmp/xmpp-proxy-check-all-features.count
|
||||||
|
|
||||||
|
) 200>/tmp/xmpp-proxy-check-all-features.lock
|
||||||
return $ret
|
return $ret
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -78,6 +100,8 @@ echo_cargo() {
|
|||||||
|
|
||||||
export -f echo_cargo
|
export -f echo_cargo
|
||||||
|
|
||||||
|
echo 0 > /tmp/xmpp-proxy-check-all-features.count
|
||||||
|
|
||||||
echo_cargo
|
echo_cargo
|
||||||
|
|
||||||
all_features | sort | xargs -n1 --max-procs=$threads bash -c 'echo_cargo --no-default-features --features "$@" || exit 255' _
|
all_features | sort | xargs -n1 --max-procs=$threads bash -c 'echo_cargo --no-default-features --features "$@" || exit 255' _
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
common::{c2s, certs_key::CertsKey, shuffle_rd_wr_filter_only, stream_preamble, to_str, ALPN_XMPP_CLIENT, ALPN_XMPP_SERVER},
|
common::{c2s, certs_key::CertsKey, shuffle_rd_wr_filter_only, stream_preamble, to_str, SocketAddrPath, ALPN_XMPP_CLIENT, ALPN_XMPP_SERVER},
|
||||||
context::Context,
|
context::Context,
|
||||||
in_out::{StanzaRead, StanzaWrite},
|
in_out::{StanzaRead, StanzaWrite},
|
||||||
slicesubsequence::SliceSubsequence,
|
slicesubsequence::SliceSubsequence,
|
||||||
@ -8,16 +8,16 @@ use crate::{
|
|||||||
use anyhow::{anyhow, bail, Result};
|
use anyhow::{anyhow, bail, Result};
|
||||||
use log::trace;
|
use log::trace;
|
||||||
use rustls::{Certificate, ServerConfig, ServerConnection};
|
use rustls::{Certificate, ServerConfig, ServerConnection};
|
||||||
use std::{io::Write, net::SocketAddr, sync::Arc};
|
|
||||||
use tokio::io::{AsyncWriteExt, ReadHalf, WriteHalf};
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
use std::{io::Write, net::SocketAddr, sync::Arc};
|
||||||
pub struct CloneableConfig {
|
use tokio::io::AsyncWriteExt;
|
||||||
|
|
||||||
|
pub struct IncomingConfig {
|
||||||
pub max_stanza_size_bytes: usize,
|
pub max_stanza_size_bytes: usize,
|
||||||
#[cfg(feature = "s2s-incoming")]
|
#[cfg(feature = "s2s-incoming")]
|
||||||
pub s2s_target: Option<SocketAddr>,
|
pub s2s_target: Option<SocketAddrPath>,
|
||||||
#[cfg(feature = "c2s-incoming")]
|
#[cfg(feature = "c2s-incoming")]
|
||||||
pub c2s_target: Option<SocketAddr>,
|
pub c2s_target: Option<SocketAddrPath>,
|
||||||
pub proxy: bool,
|
pub proxy: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -85,7 +85,7 @@ impl ServerCerts {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn shuffle_rd_wr(in_rd: StanzaRead, in_wr: StanzaWrite, config: CloneableConfig, server_certs: ServerCerts, local_addr: SocketAddr, client_addr: &mut Context<'_>) -> Result<()> {
|
pub async fn shuffle_rd_wr(in_rd: StanzaRead, in_wr: StanzaWrite, config: Arc<IncomingConfig>, server_certs: ServerCerts, local_addr: SocketAddr, client_addr: &mut Context<'_>) -> Result<()> {
|
||||||
let filter = StanzaFilter::new(config.max_stanza_size_bytes);
|
let filter = StanzaFilter::new(config.max_stanza_size_bytes);
|
||||||
shuffle_rd_wr_filter(in_rd, in_wr, config, server_certs, local_addr, client_addr, filter).await
|
shuffle_rd_wr_filter(in_rd, in_wr, config, server_certs, local_addr, client_addr, filter).await
|
||||||
}
|
}
|
||||||
@ -93,7 +93,7 @@ pub async fn shuffle_rd_wr(in_rd: StanzaRead, in_wr: StanzaWrite, config: Clonea
|
|||||||
pub async fn shuffle_rd_wr_filter(
|
pub async fn shuffle_rd_wr_filter(
|
||||||
mut in_rd: StanzaRead,
|
mut in_rd: StanzaRead,
|
||||||
mut in_wr: StanzaWrite,
|
mut in_wr: StanzaWrite,
|
||||||
config: CloneableConfig,
|
config: Arc<IncomingConfig>,
|
||||||
server_certs: ServerCerts,
|
server_certs: ServerCerts,
|
||||||
local_addr: SocketAddr,
|
local_addr: SocketAddr,
|
||||||
client_addr: &mut Context<'_>,
|
client_addr: &mut Context<'_>,
|
||||||
@ -131,43 +131,33 @@ pub async fn shuffle_rd_wr_filter(
|
|||||||
let (out_rd, out_wr) = open_incoming(&config, local_addr, client_addr, &stream_open, is_c2s, &mut in_filter).await?;
|
let (out_rd, out_wr) = open_incoming(&config, local_addr, client_addr, &stream_open, is_c2s, &mut in_filter).await?;
|
||||||
drop(stream_open);
|
drop(stream_open);
|
||||||
|
|
||||||
shuffle_rd_wr_filter_only(
|
shuffle_rd_wr_filter_only(in_rd, in_wr, out_rd, out_wr, is_c2s, config.max_stanza_size_bytes, client_addr, in_filter).await
|
||||||
in_rd,
|
|
||||||
in_wr,
|
|
||||||
StanzaRead::new(out_rd),
|
|
||||||
StanzaWrite::new(out_wr),
|
|
||||||
is_c2s,
|
|
||||||
config.max_stanza_size_bytes,
|
|
||||||
client_addr,
|
|
||||||
in_filter,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn open_incoming(
|
async fn open_incoming(
|
||||||
config: &CloneableConfig,
|
config: &IncomingConfig,
|
||||||
local_addr: SocketAddr,
|
local_addr: SocketAddr,
|
||||||
client_addr: &mut Context<'_>,
|
client_addr: &mut Context<'_>,
|
||||||
stream_open: &[u8],
|
stream_open: &[u8],
|
||||||
is_c2s: bool,
|
is_c2s: bool,
|
||||||
in_filter: &mut StanzaFilter,
|
in_filter: &mut StanzaFilter,
|
||||||
) -> Result<(ReadHalf<tokio::net::TcpStream>, WriteHalf<tokio::net::TcpStream>)> {
|
) -> Result<(StanzaRead, StanzaWrite)> {
|
||||||
let target: Option<SocketAddr> = if is_c2s {
|
let target: &Option<SocketAddrPath> = if is_c2s {
|
||||||
#[cfg(not(feature = "c2s-incoming"))]
|
#[cfg(not(feature = "c2s-incoming"))]
|
||||||
bail!("incoming c2s connection but lacking compile-time support");
|
bail!("incoming c2s connection but lacking compile-time support");
|
||||||
#[cfg(feature = "c2s-incoming")]
|
#[cfg(feature = "c2s-incoming")]
|
||||||
config.c2s_target
|
&config.c2s_target
|
||||||
} else {
|
} else {
|
||||||
#[cfg(not(feature = "s2s-incoming"))]
|
#[cfg(not(feature = "s2s-incoming"))]
|
||||||
bail!("incoming s2s connection but lacking compile-time support");
|
bail!("incoming s2s connection but lacking compile-time support");
|
||||||
#[cfg(feature = "s2s-incoming")]
|
#[cfg(feature = "s2s-incoming")]
|
||||||
config.s2s_target
|
&config.s2s_target
|
||||||
};
|
};
|
||||||
let target = target.ok_or_else(|| anyhow!("incoming connection but `{}_target` not defined", c2s(is_c2s)))?;
|
let target = target.as_ref().ok_or_else(|| anyhow!("incoming connection but `{}_target` not defined", c2s(is_c2s)))?;
|
||||||
client_addr.set_to_addr(target);
|
client_addr.set_to_addr(target.to_string());
|
||||||
|
|
||||||
let out_stream = tokio::net::TcpStream::connect(target).await?;
|
let (out_rd, mut out_wr) = target.connect().await?;
|
||||||
let (out_rd, mut out_wr) = tokio::io::split(out_stream);
|
let out_rd = StanzaRead::new(out_rd);
|
||||||
|
|
||||||
if config.proxy {
|
if config.proxy {
|
||||||
/*
|
/*
|
||||||
@ -194,5 +184,5 @@ async fn open_incoming(
|
|||||||
trace!("{} '{}'", client_addr.log_from(), to_str(stream_open));
|
trace!("{} '{}'", client_addr.log_from(), to_str(stream_open));
|
||||||
out_wr.write_all(stream_open).await?;
|
out_wr.write_all(stream_open).await?;
|
||||||
out_wr.flush().await?;
|
out_wr.flush().await?;
|
||||||
Ok((out_rd, out_wr))
|
Ok((out_rd, StanzaWrite::AsyncWrite(out_wr)))
|
||||||
}
|
}
|
||||||
|
@ -12,10 +12,20 @@ use rustls::{
|
|||||||
sign::{RsaSigningKey, SigningKey},
|
sign::{RsaSigningKey, SigningKey},
|
||||||
Certificate, PrivateKey,
|
Certificate, PrivateKey,
|
||||||
};
|
};
|
||||||
use std::{fs::File, io, sync::Arc};
|
use serde::{Deserialize, Deserializer};
|
||||||
|
use std::{
|
||||||
|
fmt::{Display, Formatter},
|
||||||
|
fs::File,
|
||||||
|
io,
|
||||||
|
net::{SocketAddr, UdpSocket},
|
||||||
|
path::PathBuf,
|
||||||
|
sync::Arc,
|
||||||
|
};
|
||||||
|
#[cfg(not(target_os = "windows"))]
|
||||||
|
use tokio::net::UnixStream;
|
||||||
use tokio::{
|
use tokio::{
|
||||||
io::{AsyncRead, AsyncWrite, BufReader, BufStream},
|
io::{AsyncBufRead, AsyncRead, AsyncWrite, BufReader, BufStream},
|
||||||
net::TcpStream,
|
net::{TcpListener, TcpStream},
|
||||||
};
|
};
|
||||||
|
|
||||||
#[cfg(feature = "incoming")]
|
#[cfg(feature = "incoming")]
|
||||||
@ -29,11 +39,26 @@ pub mod ca_roots;
|
|||||||
|
|
||||||
#[cfg(feature = "rustls")]
|
#[cfg(feature = "rustls")]
|
||||||
pub mod certs_key;
|
pub mod certs_key;
|
||||||
|
pub mod stream_listener;
|
||||||
|
|
||||||
pub const IN_BUFFER_SIZE: usize = 8192;
|
pub const IN_BUFFER_SIZE: usize = 8192;
|
||||||
pub const ALPN_XMPP_CLIENT: &[u8] = b"xmpp-client";
|
pub const ALPN_XMPP_CLIENT: &[u8] = b"xmpp-client";
|
||||||
pub const ALPN_XMPP_SERVER: &[u8] = b"xmpp-server";
|
pub const ALPN_XMPP_SERVER: &[u8] = b"xmpp-server";
|
||||||
|
|
||||||
|
pub trait AsyncReadAndWrite: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send {}
|
||||||
|
impl<T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send> AsyncReadAndWrite for T {}
|
||||||
|
|
||||||
|
pub trait AsyncReadWritePeekSplit: tokio::io::AsyncRead + tokio::io::AsyncWrite + Peek + Send + 'static + Unpin + Split {}
|
||||||
|
impl<T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Peek + Send + 'static + Unpin + Split> AsyncReadWritePeekSplit for T {}
|
||||||
|
|
||||||
|
pub type BoxAsyncReadWrite = Box<dyn AsyncReadAndWrite>;
|
||||||
|
pub type BufAsyncReadWrite = BufStream<BoxAsyncReadWrite>;
|
||||||
|
|
||||||
|
pub fn buf_stream(stream: BoxAsyncReadWrite) -> BufAsyncReadWrite {
|
||||||
|
// todo: do we *want* a non-zero writer_capacity ?
|
||||||
|
BufStream::with_capacity(IN_BUFFER_SIZE, 0, stream)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn to_str(buf: &[u8]) -> std::borrow::Cow<'_, str> {
|
pub fn to_str(buf: &[u8]) -> std::borrow::Cow<'_, str> {
|
||||||
String::from_utf8_lossy(buf)
|
String::from_utf8_lossy(buf)
|
||||||
}
|
}
|
||||||
@ -46,13 +71,97 @@ pub fn c2s(is_c2s: bool) -> &'static str {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub enum SocketAddrPath {
|
||||||
|
SocketAddr(SocketAddr),
|
||||||
|
#[cfg(not(target_os = "windows"))]
|
||||||
|
Path(PathBuf),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SocketAddrPath {
|
||||||
|
pub async fn connect(&self) -> Result<(Box<dyn AsyncRead + Unpin + Send>, Box<dyn AsyncWrite + Unpin + Send>)> {
|
||||||
|
Ok(match self {
|
||||||
|
SocketAddrPath::SocketAddr(sa) => TcpStream::connect(sa).await?.split_boxed(),
|
||||||
|
#[cfg(not(target_os = "windows"))]
|
||||||
|
SocketAddrPath::Path(path) => tokio::net::UnixStream::connect(path).await?.split_boxed(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn bind(&self) -> Result<Listener> {
|
||||||
|
Ok(match self {
|
||||||
|
SocketAddrPath::SocketAddr(sa) => Listener::Tcp(TcpListener::bind(sa).await?),
|
||||||
|
#[cfg(not(target_os = "windows"))]
|
||||||
|
SocketAddrPath::Path(path) => Listener::Unix(tokio::net::UnixListener::bind(path)?),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn bind_udp(&self) -> Result<UdpListener> {
|
||||||
|
Ok(match self {
|
||||||
|
SocketAddrPath::SocketAddr(sa) => UdpListener::Udp(UdpSocket::bind(sa)?),
|
||||||
|
#[cfg(not(target_os = "windows"))]
|
||||||
|
SocketAddrPath::Path(path) => UdpListener::Unix(std::os::unix::net::UnixDatagram::bind(path)?),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for SocketAddrPath {
|
||||||
|
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
SocketAddrPath::SocketAddr(x) => x.fmt(f),
|
||||||
|
#[cfg(not(target_os = "windows"))]
|
||||||
|
SocketAddrPath::Path(x) => x.display().fmt(f),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de> Deserialize<'de> for SocketAddrPath {
|
||||||
|
fn deserialize<D>(deserializer: D) -> Result<SocketAddrPath, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
#[cfg(not(target_os = "windows"))]
|
||||||
|
{
|
||||||
|
let str = String::deserialize(deserializer)?;
|
||||||
|
// seems good enough, possibly could improve
|
||||||
|
Ok(if str.contains('/') {
|
||||||
|
SocketAddrPath::Path(PathBuf::from(str))
|
||||||
|
} else {
|
||||||
|
SocketAddrPath::SocketAddr(str.parse().map_err(serde::de::Error::custom)?)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
#[cfg(target_os = "windows")]
|
||||||
|
{
|
||||||
|
Ok(SocketAddrPath::SocketAddr(SocketAddr::deserialize(deserializer)?))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum Listener {
|
||||||
|
Tcp(TcpListener),
|
||||||
|
#[cfg(not(target_os = "windows"))]
|
||||||
|
Unix(tokio::net::UnixListener),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum UdpListener {
|
||||||
|
Udp(UdpSocket),
|
||||||
|
#[cfg(not(target_os = "windows"))]
|
||||||
|
Unix(std::os::unix::net::UnixDatagram),
|
||||||
|
}
|
||||||
|
|
||||||
pub trait Split: Sized {
|
pub trait Split: Sized {
|
||||||
type ReadHalf: AsyncRead + Unpin;
|
type ReadHalf: AsyncRead + Unpin + Send + 'static;
|
||||||
type WriteHalf: AsyncWrite + Unpin;
|
type WriteHalf: AsyncWrite + Unpin + Send + 'static;
|
||||||
|
|
||||||
fn combine(read_half: Self::ReadHalf, write_half: Self::WriteHalf) -> Result<Self>;
|
fn combine(read_half: Self::ReadHalf, write_half: Self::WriteHalf) -> Result<Self>;
|
||||||
|
|
||||||
fn split(self) -> (Self::ReadHalf, Self::WriteHalf);
|
fn split(self) -> (Self::ReadHalf, Self::WriteHalf);
|
||||||
|
|
||||||
|
fn stanza_rw(self) -> (StanzaRead, StanzaWrite);
|
||||||
|
|
||||||
|
fn split_boxed(self) -> (Box<dyn AsyncRead + Unpin + Send>, Box<dyn AsyncWrite + Unpin + Send>) {
|
||||||
|
let (rd, wr) = self.split();
|
||||||
|
(Box::new(rd), Box::new(wr))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Split for TcpStream {
|
impl Split for TcpStream {
|
||||||
@ -66,9 +175,56 @@ impl Split for TcpStream {
|
|||||||
fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
|
fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
|
||||||
self.into_split()
|
self.into_split()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn stanza_rw(self) -> (StanzaRead, StanzaWrite) {
|
||||||
|
let (in_rd, in_wr) = self.into_split();
|
||||||
|
(StanzaRead::new(in_rd), StanzaWrite::new(in_wr))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: AsyncRead + AsyncWrite + Unpin + Send> Split for BufStream<T> {
|
#[cfg(feature = "tokio-rustls")]
|
||||||
|
impl<T: AsyncRead + AsyncWrite + Unpin + Send + 'static> Split for tokio_rustls::server::TlsStream<T> {
|
||||||
|
type ReadHalf = tokio::io::ReadHalf<tokio_rustls::server::TlsStream<T>>;
|
||||||
|
type WriteHalf = tokio::io::WriteHalf<tokio_rustls::server::TlsStream<T>>;
|
||||||
|
|
||||||
|
fn combine(read_half: Self::ReadHalf, write_half: Self::WriteHalf) -> Result<Self> {
|
||||||
|
if read_half.is_pair_of(&write_half) {
|
||||||
|
Ok(read_half.unsplit(write_half))
|
||||||
|
} else {
|
||||||
|
bail!("non-matching read/write half")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
|
||||||
|
tokio::io::split(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn stanza_rw(self) -> (StanzaRead, StanzaWrite) {
|
||||||
|
let (in_rd, in_wr) = tokio::io::split(self);
|
||||||
|
(StanzaRead::new(in_rd), StanzaWrite::new(in_wr))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(target_os = "windows"))]
|
||||||
|
impl Split for UnixStream {
|
||||||
|
type ReadHalf = tokio::net::unix::OwnedReadHalf;
|
||||||
|
type WriteHalf = tokio::net::unix::OwnedWriteHalf;
|
||||||
|
|
||||||
|
fn combine(read_half: Self::ReadHalf, write_half: Self::WriteHalf) -> Result<Self> {
|
||||||
|
Ok(read_half.reunite(write_half)?)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
|
||||||
|
self.into_split()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn stanza_rw(self) -> (StanzaRead, StanzaWrite) {
|
||||||
|
let (in_rd, in_wr) = self.into_split();
|
||||||
|
(StanzaRead::new(in_rd), StanzaWrite::new(in_wr))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: AsyncRead + AsyncWrite + Unpin + Send + 'static> Split for BufStream<T> {
|
||||||
type ReadHalf = tokio::io::ReadHalf<BufStream<T>>;
|
type ReadHalf = tokio::io::ReadHalf<BufStream<T>>;
|
||||||
type WriteHalf = tokio::io::WriteHalf<BufStream<T>>;
|
type WriteHalf = tokio::io::WriteHalf<BufStream<T>>;
|
||||||
|
|
||||||
@ -83,13 +239,18 @@ impl<T: AsyncRead + AsyncWrite + Unpin + Send> Split for BufStream<T> {
|
|||||||
fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
|
fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
|
||||||
tokio::io::split(self)
|
tokio::io::split(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn stanza_rw(self) -> (StanzaRead, StanzaWrite) {
|
||||||
|
let (in_rd, in_wr) = tokio::io::split(self);
|
||||||
|
(StanzaRead::already_buffered(in_rd), StanzaWrite::new(in_wr))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait Peek {
|
pub trait Peek {
|
||||||
async fn peek_bytes<'a>(&mut self, p: &'a mut [u8]) -> anyhow::Result<&'a [u8]>;
|
async fn peek_bytes<'a>(&mut self, p: &'a mut [u8]) -> anyhow::Result<&'a [u8]>;
|
||||||
|
|
||||||
async fn first_bytes_match<'a>(&mut self, p: &'a mut [u8], matcher: fn(&'a [u8]) -> bool) -> anyhow::Result<bool> {
|
async fn first_bytes_match<'a>(&mut self, p: &'a mut [u8], matcher: fn(&'a [u8]) -> bool) -> Result<bool> {
|
||||||
Ok(matcher(self.peek_bytes(p).await?))
|
Ok(matcher(self.peek_bytes(p).await?))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -170,6 +331,23 @@ impl<T: AsyncRead + Unpin + Send> Peek for BufReader<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Caution: this will loop forever, call timeout variant `first_bytes_match_buf_timeout`
|
||||||
|
async fn first_bytes_match_buf(stream: &mut (dyn AsyncBufRead + Send + Unpin), len: usize, matcher: fn(&[u8]) -> bool) -> Result<bool> {
|
||||||
|
use tokio::io::AsyncBufReadExt;
|
||||||
|
loop {
|
||||||
|
let buf = stream.fill_buf().await?;
|
||||||
|
if buf.len() >= len {
|
||||||
|
return Ok(matcher(&buf[0..len]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn first_bytes_match_buf_timeout(stream: &mut (dyn AsyncBufRead + Send + Unpin), len: usize, matcher: fn(&[u8]) -> bool) -> Result<bool> {
|
||||||
|
// wait up to 10 seconds until 3 bytes have been read
|
||||||
|
use std::time::Duration;
|
||||||
|
tokio::time::timeout(Duration::from_secs(10), first_bytes_match_buf(stream, len, matcher)).await?
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn stream_preamble(in_rd: &mut StanzaRead, in_wr: &mut StanzaWrite, client_addr: &'_ str, in_filter: &mut StanzaFilter) -> Result<(Vec<u8>, bool)> {
|
pub async fn stream_preamble(in_rd: &mut StanzaRead, in_wr: &mut StanzaWrite, client_addr: &'_ str, in_filter: &mut StanzaFilter) -> Result<(Vec<u8>, bool)> {
|
||||||
let mut stream_open = Vec::new();
|
let mut stream_open = Vec::new();
|
||||||
while let Ok(Some((buf, _))) = in_rd.next(in_filter, client_addr, in_wr).await {
|
while let Ok(Some((buf, _))) = in_rd.next(in_filter, client_addr, in_wr).await {
|
||||||
|
111
src/common/stream_listener.rs
Normal file
111
src/common/stream_listener.rs
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
use crate::common::AsyncReadWritePeekSplit;
|
||||||
|
use anyhow::Result;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use std::io::{Error, IoSlice};
|
||||||
|
|
||||||
|
use std::{
|
||||||
|
net::SocketAddr,
|
||||||
|
pin::Pin,
|
||||||
|
task::{Context, Poll},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[cfg(not(target_os = "windows"))]
|
||||||
|
use tokio::net::{UnixListener, UnixStream};
|
||||||
|
use tokio::{
|
||||||
|
io::{AsyncRead, AsyncWrite, BufStream, ReadBuf},
|
||||||
|
net::{TcpListener, TcpStream},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait StreamListener: Send + Sync + 'static {
|
||||||
|
type Stream: AsyncReadWritePeekSplit;
|
||||||
|
async fn accept(&self) -> Result<(Self::Stream, SocketAddr)>;
|
||||||
|
fn local_addr(&self) -> Result<SocketAddr>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl StreamListener for TcpListener {
|
||||||
|
type Stream = TcpStream;
|
||||||
|
async fn accept(&self) -> Result<(Self::Stream, SocketAddr)> {
|
||||||
|
Ok(self.accept().await?)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn local_addr(&self) -> Result<SocketAddr> {
|
||||||
|
Ok(self.local_addr()?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(target_os = "windows"))]
|
||||||
|
#[async_trait]
|
||||||
|
impl StreamListener for UnixListener {
|
||||||
|
type Stream = BufStream<UnixStream>;
|
||||||
|
async fn accept(&self) -> Result<(Self::Stream, SocketAddr)> {
|
||||||
|
let (stream, _client_addr) = self.accept().await?;
|
||||||
|
// todo: real SocketAddr
|
||||||
|
let client_addr: SocketAddr = "127.0.0.1:0".parse()?;
|
||||||
|
Ok((BufStream::new(stream), client_addr))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn local_addr(&self) -> Result<SocketAddr> {
|
||||||
|
// todo: real SocketAddr
|
||||||
|
Ok("127.0.0.1:0".parse()?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum AllStream {
|
||||||
|
Tcp(TcpStream),
|
||||||
|
#[cfg(not(target_os = "windows"))]
|
||||||
|
Unix(UnixStream),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncRead for AllStream {
|
||||||
|
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
|
||||||
|
match self.get_mut() {
|
||||||
|
AllStream::Tcp(s) => Pin::new(s).poll_read(cx, buf),
|
||||||
|
#[cfg(not(target_os = "windows"))]
|
||||||
|
AllStream::Unix(s) => Pin::new(s).poll_read(cx, buf),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncWrite for AllStream {
|
||||||
|
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::result::Result<usize, Error>> {
|
||||||
|
match self.get_mut() {
|
||||||
|
AllStream::Tcp(s) => Pin::new(s).poll_write(cx, buf),
|
||||||
|
#[cfg(not(target_os = "windows"))]
|
||||||
|
AllStream::Unix(s) => Pin::new(s).poll_write(cx, buf),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Error>> {
|
||||||
|
match self.get_mut() {
|
||||||
|
AllStream::Tcp(s) => Pin::new(s).poll_flush(cx),
|
||||||
|
#[cfg(not(target_os = "windows"))]
|
||||||
|
AllStream::Unix(s) => Pin::new(s).poll_flush(cx),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::result::Result<(), Error>> {
|
||||||
|
match self.get_mut() {
|
||||||
|
AllStream::Tcp(s) => Pin::new(s).poll_shutdown(cx),
|
||||||
|
#[cfg(not(target_os = "windows"))]
|
||||||
|
AllStream::Unix(s) => Pin::new(s).poll_shutdown(cx),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) -> Poll<std::result::Result<usize, Error>> {
|
||||||
|
match self.get_mut() {
|
||||||
|
AllStream::Tcp(s) => Pin::new(s).poll_write_vectored(cx, bufs),
|
||||||
|
#[cfg(not(target_os = "windows"))]
|
||||||
|
AllStream::Unix(s) => Pin::new(s).poll_write_vectored(cx, bufs),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_write_vectored(&self) -> bool {
|
||||||
|
match self {
|
||||||
|
AllStream::Tcp(s) => s.is_write_vectored(),
|
||||||
|
#[cfg(not(target_os = "windows"))]
|
||||||
|
AllStream::Unix(s) => s.is_write_vectored(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -13,7 +13,7 @@ pub struct Context<'a> {
|
|||||||
proto: &'a str,
|
proto: &'a str,
|
||||||
is_c2s: Option<bool>,
|
is_c2s: Option<bool>,
|
||||||
to: Option<String>,
|
to: Option<String>,
|
||||||
to_addr: Option<SocketAddr>,
|
to_addr: Option<String>,
|
||||||
from: Option<String>,
|
from: Option<String>,
|
||||||
client_addr: SocketAddr,
|
client_addr: SocketAddr,
|
||||||
}
|
}
|
||||||
@ -59,7 +59,7 @@ impl<'a> Context<'a> {
|
|||||||
if self.from.is_some() { self.from.as_ref().unwrap() } else { "unk" },
|
if self.from.is_some() { self.from.as_ref().unwrap() } else { "unk" },
|
||||||
self.proto,
|
self.proto,
|
||||||
if self.is_c2s.is_some() { c2s(self.is_c2s.unwrap()) } else { "unk" },
|
if self.is_c2s.is_some() { c2s(self.is_c2s.unwrap()) } else { "unk" },
|
||||||
if self.to_addr.is_some() { self.to_addr.as_ref().unwrap().to_string() } else { "unk".to_string() },
|
if self.to_addr.is_some() { self.to_addr.as_ref().unwrap() } else { "unk" },
|
||||||
if self.to.is_some() { self.to.as_ref().unwrap() } else { "unk" },
|
if self.to.is_some() { self.to.as_ref().unwrap() } else { "unk" },
|
||||||
);
|
);
|
||||||
self.log_to = self.log_from.replace(" -> ", " <- ");
|
self.log_to = self.log_from.replace(" -> ", " <- ");
|
||||||
@ -103,7 +103,7 @@ impl<'a> Context<'a> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_to_addr(&mut self, to_addr: SocketAddr) {
|
pub fn set_to_addr(&mut self, to_addr: String) {
|
||||||
if log_enabled!(log::Level::Info) {
|
if log_enabled!(log::Level::Info) {
|
||||||
self.to_addr = Some(to_addr);
|
self.to_addr = Some(to_addr);
|
||||||
self.re_calc();
|
self.re_calc();
|
||||||
|
109
src/main.rs
109
src/main.rs
@ -3,35 +3,42 @@ use anyhow::Result;
|
|||||||
use die::{die, Die};
|
use die::{die, Die};
|
||||||
use log::{debug, info};
|
use log::{debug, info};
|
||||||
use serde_derive::Deserialize;
|
use serde_derive::Deserialize;
|
||||||
use std::{
|
use std::{ffi::OsString, fs::File, io::Read, iter::Iterator, path::Path, sync::Arc};
|
||||||
ffi::OsString,
|
|
||||||
fs::File,
|
|
||||||
io::Read,
|
|
||||||
iter::Iterator,
|
|
||||||
net::{SocketAddr, UdpSocket},
|
|
||||||
path::Path,
|
|
||||||
sync::Arc,
|
|
||||||
};
|
|
||||||
use tokio::{net::TcpListener, task::JoinHandle};
|
use tokio::{net::TcpListener, task::JoinHandle};
|
||||||
use xmpp_proxy::common::certs_key::CertsKey;
|
use xmpp_proxy::common::{certs_key::CertsKey, Listener, SocketAddrPath, UdpListener};
|
||||||
|
|
||||||
|
#[cfg(not(target_os = "windows"))]
|
||||||
|
use tokio::net::UnixListener;
|
||||||
|
|
||||||
#[cfg(feature = "outgoing")]
|
#[cfg(feature = "outgoing")]
|
||||||
use xmpp_proxy::{common::outgoing::OutgoingConfig, outgoing::spawn_outgoing_listener};
|
use xmpp_proxy::{common::outgoing::OutgoingConfig, outgoing::spawn_outgoing_listener};
|
||||||
|
|
||||||
#[derive(Deserialize, Default)]
|
#[derive(Deserialize, Default)]
|
||||||
struct Config {
|
struct Config {
|
||||||
|
#[serde(default)]
|
||||||
tls_key: String,
|
tls_key: String,
|
||||||
|
#[serde(default)]
|
||||||
tls_cert: String,
|
tls_cert: String,
|
||||||
incoming_listen: Vec<String>,
|
#[serde(default)]
|
||||||
quic_listen: Vec<String>,
|
incoming_listen: Vec<SocketAddrPath>,
|
||||||
outgoing_listen: Vec<String>,
|
#[serde(default)]
|
||||||
|
quic_listen: Vec<SocketAddrPath>,
|
||||||
|
#[serde(default)]
|
||||||
|
outgoing_listen: Vec<SocketAddrPath>,
|
||||||
|
#[serde(default = "default_max_stanza_size_bytes")]
|
||||||
max_stanza_size_bytes: usize,
|
max_stanza_size_bytes: usize,
|
||||||
s2s_target: Option<SocketAddr>,
|
s2s_target: Option<SocketAddrPath>,
|
||||||
c2s_target: Option<SocketAddr>,
|
c2s_target: Option<SocketAddrPath>,
|
||||||
|
#[serde(default)]
|
||||||
proxy: bool,
|
proxy: bool,
|
||||||
log_level: Option<String>,
|
log_level: Option<String>,
|
||||||
log_style: Option<String>,
|
log_style: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn default_max_stanza_size_bytes() -> usize {
|
||||||
|
262_144
|
||||||
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
fn parse<P: AsRef<Path>>(path: P) -> Result<Config> {
|
fn parse<P: AsRef<Path>>(path: P) -> Result<Config> {
|
||||||
let mut f = File::open(path)?;
|
let mut f = File::open(path)?;
|
||||||
@ -41,13 +48,13 @@ impl Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "incoming")]
|
#[cfg(feature = "incoming")]
|
||||||
fn get_cloneable_cfg(&self) -> xmpp_proxy::common::incoming::CloneableConfig {
|
fn get_incoming_cfg(&self) -> xmpp_proxy::common::incoming::IncomingConfig {
|
||||||
xmpp_proxy::common::incoming::CloneableConfig {
|
xmpp_proxy::common::incoming::IncomingConfig {
|
||||||
max_stanza_size_bytes: self.max_stanza_size_bytes,
|
max_stanza_size_bytes: self.max_stanza_size_bytes,
|
||||||
#[cfg(feature = "s2s-incoming")]
|
#[cfg(feature = "s2s-incoming")]
|
||||||
s2s_target: self.s2s_target,
|
s2s_target: self.s2s_target.clone(),
|
||||||
#[cfg(feature = "c2s-incoming")]
|
#[cfg(feature = "c2s-incoming")]
|
||||||
c2s_target: self.c2s_target,
|
c2s_target: self.c2s_target.clone(),
|
||||||
proxy: self.proxy,
|
proxy: self.proxy,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -137,31 +144,31 @@ async fn main() {
|
|||||||
|
|
||||||
let mut incoming_listen = Vec::new();
|
let mut incoming_listen = Vec::new();
|
||||||
for a in main_config.incoming_listen.iter() {
|
for a in main_config.incoming_listen.iter() {
|
||||||
incoming_listen.push(TcpListener::bind(a).await.die("cannot listen on port/interface"));
|
incoming_listen.push(a.bind().await.die("cannot listen on port/interface/socket"));
|
||||||
}
|
}
|
||||||
let mut quic_listen = Vec::new();
|
let mut quic_listen = Vec::new();
|
||||||
for a in main_config.quic_listen.iter() {
|
for a in main_config.quic_listen.iter() {
|
||||||
quic_listen.push(UdpSocket::bind(a).die("cannot listen on port/interface"));
|
quic_listen.push(a.bind_udp().await.die("cannot listen on port/interface/socket"));
|
||||||
}
|
}
|
||||||
let mut outgoing_listen = Vec::new();
|
let mut outgoing_listen = Vec::new();
|
||||||
for a in main_config.outgoing_listen.iter() {
|
for a in main_config.outgoing_listen.iter() {
|
||||||
outgoing_listen.push(TcpListener::bind(a).await.die("cannot listen on port/interface"));
|
outgoing_listen.push(a.bind().await.die("cannot listen on port/interface/socket"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(all(feature = "nix", not(target_os = "windows")))]
|
#[cfg(all(feature = "nix", not(target_os = "windows")))]
|
||||||
if let Ok(fds) = xmpp_proxy::systemd::receive_descriptors_with_names(true) {
|
if let Ok(fds) = xmpp_proxy::systemd::receive_descriptors_with_names(true) {
|
||||||
use xmpp_proxy::systemd::Listener;
|
use xmpp_proxy::systemd::SystemdListener;
|
||||||
for fd in fds {
|
for fd in fds {
|
||||||
match fd.listener() {
|
match fd.listener() {
|
||||||
Listener::Tcp(tcp_listener) => {
|
SystemdListener::Tcp(tcp_listener) => {
|
||||||
let tcp_listener = TcpListener::from_std(tcp_listener()).die("cannot open systemd TcpListener");
|
let listener = Listener::Tcp(TcpListener::from_std(tcp_listener()).die("cannot open systemd TcpListener"));
|
||||||
if let Some(name) = fd.name().map(|n| n.to_ascii_lowercase()) {
|
if let Some(name) = fd.name().map(|n| n.to_ascii_lowercase()) {
|
||||||
if name.starts_with("in") {
|
if name.starts_with("in") {
|
||||||
incoming_listen.push(tcp_listener);
|
incoming_listen.push(listener);
|
||||||
} else if name.starts_with("out") {
|
} else if name.starts_with("out") {
|
||||||
outgoing_listen.push(tcp_listener);
|
outgoing_listen.push(listener);
|
||||||
} else {
|
} else {
|
||||||
die!("systemd socket name must start with 'in' or 'out' but is '{}'", name);
|
die!("systemd TCP socket name must start with 'in' or 'out' but is '{}'", name);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// what to do here... for now we will require names
|
// what to do here... for now we will require names
|
||||||
@ -169,14 +176,29 @@ async fn main() {
|
|||||||
die!("systemd TCP socket activation requires name that starts with 'in' or 'out'");
|
die!("systemd TCP socket activation requires name that starts with 'in' or 'out'");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Listener::Udp(udp_socket) => quic_listen.push(udp_socket()),
|
SystemdListener::UnixListener(unix_listener) => {
|
||||||
_ => continue,
|
let listener = Listener::Unix(UnixListener::from_std(unix_listener()).die("cannot open systemd UnixListener"));
|
||||||
|
if let Some(name) = fd.name().map(|n| n.to_ascii_lowercase()) {
|
||||||
|
if name.starts_with("in") {
|
||||||
|
incoming_listen.push(listener);
|
||||||
|
} else if name.starts_with("out") {
|
||||||
|
outgoing_listen.push(listener);
|
||||||
|
} else {
|
||||||
|
die!("systemd Unix socket name must start with 'in' or 'out' but is '{}'", name);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// what to do here... for now we will require names
|
||||||
|
die!("systemd Unix socket activation requires name that starts with 'in' or 'out'");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
SystemdListener::Udp(udp_socket) => quic_listen.push(UdpListener::Udp(udp_socket())),
|
||||||
|
SystemdListener::UnixDatagram(unix_datagram) => quic_listen.push(UdpListener::Unix(unix_datagram())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "incoming")]
|
#[cfg(feature = "incoming")]
|
||||||
let config = main_config.get_cloneable_cfg();
|
let config = Arc::new(main_config.get_incoming_cfg());
|
||||||
|
|
||||||
let certs_key = Arc::new(CertsKey::new(main_config.certs_key()));
|
let certs_key = Arc::new(CertsKey::new(main_config.certs_key()));
|
||||||
|
|
||||||
@ -193,7 +215,13 @@ async fn main() {
|
|||||||
}
|
}
|
||||||
let acceptor = tls_acceptor(server_config(certs_key.clone()).die("invalid cert/key ?"));
|
let acceptor = tls_acceptor(server_config(certs_key.clone()).die("invalid cert/key ?"));
|
||||||
for listener in incoming_listen {
|
for listener in incoming_listen {
|
||||||
handles.push(spawn_tls_listener(listener, config.clone(), acceptor.clone()));
|
// todo: first is slower at runtime but smaller executable size, second opposite
|
||||||
|
//handles.push(spawn_tls_listener(listener, config.clone(), acceptor.clone()));
|
||||||
|
match listener {
|
||||||
|
Listener::Tcp(listener) => handles.push(spawn_tls_listener(listener, config.clone(), acceptor.clone())),
|
||||||
|
#[cfg(not(target_os = "windows"))]
|
||||||
|
Listener::Unix(listener) => handles.push(spawn_tls_listener(listener, config.clone(), acceptor.clone())),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#[cfg(not(all(any(feature = "tls", feature = "websocket"), feature = "incoming")))]
|
#[cfg(not(all(any(feature = "tls", feature = "websocket"), feature = "incoming")))]
|
||||||
@ -211,7 +239,13 @@ async fn main() {
|
|||||||
}
|
}
|
||||||
let quic_config = quic_server_config(server_config(certs_key.clone()).die("invalid cert/key ?"));
|
let quic_config = quic_server_config(server_config(certs_key.clone()).die("invalid cert/key ?"));
|
||||||
for listener in quic_listen {
|
for listener in quic_listen {
|
||||||
handles.push(spawn_quic_listener(listener, config.clone(), quic_config.clone()));
|
// todo: maybe write a way to Box<dyn> this thing for smaller executable sizes
|
||||||
|
//handles.push(spawn_quic_listener(listener, config.clone(), quic_config.clone()));
|
||||||
|
match listener {
|
||||||
|
UdpListener::Udp(listener) => handles.push(spawn_quic_listener(listener, config.clone(), quic_config.clone())),
|
||||||
|
#[cfg(not(target_os = "windows"))]
|
||||||
|
UdpListener::Unix(listener) => handles.push(xmpp_proxy::quic::incoming::spawn_quic_listener_unix(listener, config.clone(), quic_config.clone())),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#[cfg(not(all(feature = "quic", feature = "incoming")))]
|
#[cfg(not(all(feature = "quic", feature = "incoming")))]
|
||||||
@ -222,8 +256,15 @@ async fn main() {
|
|||||||
{
|
{
|
||||||
let outgoing_cfg = main_config.get_outgoing_cfg(certs_key.clone());
|
let outgoing_cfg = main_config.get_outgoing_cfg(certs_key.clone());
|
||||||
for listener in outgoing_listen {
|
for listener in outgoing_listen {
|
||||||
handles.push(spawn_outgoing_listener(listener, outgoing_cfg.clone()));
|
// todo: first is slower at runtime but smaller executable size, second opposite
|
||||||
|
//handles.push(spawn_outgoing_listener(listener, outgoing_cfg.clone()));
|
||||||
|
match listener {
|
||||||
|
Listener::Tcp(listener) => handles.push(spawn_outgoing_listener(listener, outgoing_cfg.clone())),
|
||||||
|
#[cfg(not(target_os = "windows"))]
|
||||||
|
Listener::Unix(listener) => handles.push(spawn_outgoing_listener(listener, outgoing_cfg.clone())),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
//#[cfg(not(target_os = "windows"))]
|
||||||
}
|
}
|
||||||
#[cfg(not(feature = "outgoing"))]
|
#[cfg(not(feature = "outgoing"))]
|
||||||
die!("outgoing_listen non-empty but c2s-outgoing and s2s-outgoing disabled at compile-time");
|
die!("outgoing_listen non-empty but c2s-outgoing and s2s-outgoing disabled at compile-time");
|
||||||
|
@ -1,16 +1,16 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
common::{outgoing::OutgoingConfig, shuffle_rd_wr_filter_only, stream_preamble, Peek},
|
common::{outgoing::OutgoingConfig, shuffle_rd_wr_filter_only, stream_listener::StreamListener, stream_preamble, AsyncReadWritePeekSplit},
|
||||||
context::Context,
|
context::Context,
|
||||||
in_out::{StanzaRead, StanzaWrite},
|
|
||||||
slicesubsequence::SliceSubsequence,
|
slicesubsequence::SliceSubsequence,
|
||||||
srv::srv_connect,
|
srv::srv_connect,
|
||||||
stanzafilter::StanzaFilter,
|
stanzafilter::StanzaFilter,
|
||||||
};
|
};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use log::{error, info};
|
use log::{error, info};
|
||||||
use tokio::{net::TcpListener, task::JoinHandle};
|
|
||||||
|
|
||||||
async fn handle_outgoing_connection(mut stream: tokio::net::TcpStream, client_addr: &mut Context<'_>, config: OutgoingConfig) -> Result<()> {
|
use tokio::task::JoinHandle;
|
||||||
|
|
||||||
|
async fn handle_outgoing_connection<S: AsyncReadWritePeekSplit>(mut stream: S, client_addr: &mut Context<'_>, config: OutgoingConfig) -> Result<()> {
|
||||||
info!("{} connected", client_addr.log_from());
|
info!("{} connected", client_addr.log_from());
|
||||||
|
|
||||||
let mut in_filter = StanzaFilter::new(config.max_stanza_size_bytes);
|
let mut in_filter = StanzaFilter::new(config.max_stanza_size_bytes);
|
||||||
@ -19,15 +19,11 @@ async fn handle_outgoing_connection(mut stream: tokio::net::TcpStream, client_ad
|
|||||||
let (mut in_rd, mut in_wr) = if stream.first_bytes_match(&mut in_filter.buf[0..3], |p| p == b"GET").await? {
|
let (mut in_rd, mut in_wr) = if stream.first_bytes_match(&mut in_filter.buf[0..3], |p| p == b"GET").await? {
|
||||||
crate::websocket::incoming_websocket_connection(Box::new(stream), config.max_stanza_size_bytes).await?
|
crate::websocket::incoming_websocket_connection(Box::new(stream), config.max_stanza_size_bytes).await?
|
||||||
} else {
|
} else {
|
||||||
let (in_rd, in_wr) = tokio::io::split(stream);
|
stream.stanza_rw()
|
||||||
(StanzaRead::new(in_rd), StanzaWrite::new(in_wr))
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#[cfg(not(feature = "websocket"))]
|
#[cfg(not(feature = "websocket"))]
|
||||||
let (mut in_rd, mut in_wr) = {
|
let (mut in_rd, mut in_wr) = stream.stanza_rw();
|
||||||
let (in_rd, in_wr) = tokio::io::split(stream);
|
|
||||||
(StanzaRead::new(in_rd), StanzaWrite::new(in_wr))
|
|
||||||
};
|
|
||||||
|
|
||||||
// now read to figure out client vs server
|
// now read to figure out client vs server
|
||||||
let (stream_open, is_c2s) = stream_preamble(&mut in_rd, &mut in_wr, client_addr.log_to(), &mut in_filter).await?;
|
let (stream_open, is_c2s) = stream_preamble(&mut in_rd, &mut in_wr, client_addr.log_to(), &mut in_filter).await?;
|
||||||
@ -46,13 +42,13 @@ async fn handle_outgoing_connection(mut stream: tokio::net::TcpStream, client_ad
|
|||||||
shuffle_rd_wr_filter_only(in_rd, in_wr, out_rd, out_wr, is_c2s, max_stanza_size_bytes, client_addr, in_filter).await
|
shuffle_rd_wr_filter_only(in_rd, in_wr, out_rd, out_wr, is_c2s, max_stanza_size_bytes, client_addr, in_filter).await
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn spawn_outgoing_listener(listener: TcpListener, config: OutgoingConfig) -> JoinHandle<Result<()>> {
|
pub fn spawn_outgoing_listener(listener: impl StreamListener, config: OutgoingConfig) -> JoinHandle<Result<()>> {
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
loop {
|
loop {
|
||||||
let (stream, client_addr) = listener.accept().await?;
|
let (stream, client_addr) = listener.accept().await?;
|
||||||
|
let mut client_addr = Context::new("unk-out", client_addr);
|
||||||
let config = config.clone();
|
let config = config.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let mut client_addr = Context::new("unk-out", client_addr);
|
|
||||||
if let Err(e) = handle_outgoing_connection(stream, &mut client_addr, config).await {
|
if let Err(e) = handle_outgoing_connection(stream, &mut client_addr, config).await {
|
||||||
error!("{} {}", client_addr.log_from(), e);
|
error!("{} {}", client_addr.log_from(), e);
|
||||||
}
|
}
|
||||||
|
@ -1,18 +1,34 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
common::incoming::{shuffle_rd_wr, CloneableConfig, ServerCerts},
|
common::incoming::{shuffle_rd_wr, IncomingConfig, ServerCerts},
|
||||||
context::Context,
|
context::Context,
|
||||||
in_out::{StanzaRead, StanzaWrite},
|
in_out::{StanzaRead, StanzaWrite},
|
||||||
};
|
};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use die::Die;
|
use die::Die;
|
||||||
use log::{error, info};
|
use log::{error, info};
|
||||||
use quinn::{Endpoint, EndpointConfig, ServerConfig, TokioRuntime};
|
use quinn::{AsyncUdpSocket, Endpoint, EndpointConfig, ServerConfig, TokioRuntime};
|
||||||
use std::{net::UdpSocket, sync::Arc};
|
use std::{
|
||||||
|
net::{SocketAddr, UdpSocket},
|
||||||
|
sync::Arc,
|
||||||
|
};
|
||||||
use tokio::task::JoinHandle;
|
use tokio::task::JoinHandle;
|
||||||
|
|
||||||
pub fn spawn_quic_listener(udp_socket: UdpSocket, config: CloneableConfig, server_config: ServerConfig) -> JoinHandle<Result<()>> {
|
#[cfg(not(target_os = "windows"))]
|
||||||
|
pub fn spawn_quic_listener_unix(udp_socket: std::os::unix::net::UnixDatagram, config: Arc<IncomingConfig>, server_config: ServerConfig) -> JoinHandle<Result<()>> {
|
||||||
|
let udp_socket = crate::quic::unix_datagram::wrap_unix_udp_socket(udp_socket).die("cannot wrap unix udp socket");
|
||||||
|
// todo: fake local_addr
|
||||||
|
let local_addr = udp_socket.local_addr().die("cannot get local_addr for quic socket");
|
||||||
|
let incoming = Endpoint::new_with_abstract_socket(EndpointConfig::default(), Some(server_config), udp_socket, Arc::new(TokioRuntime)).die("cannot listen on port/interface");
|
||||||
|
internal_spawn_quic_listener(incoming, local_addr, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn spawn_quic_listener(udp_socket: UdpSocket, config: Arc<IncomingConfig>, server_config: ServerConfig) -> JoinHandle<Result<()>> {
|
||||||
let local_addr = udp_socket.local_addr().die("cannot get local_addr for quic socket");
|
let local_addr = udp_socket.local_addr().die("cannot get local_addr for quic socket");
|
||||||
let incoming = Endpoint::new(EndpointConfig::default(), Some(server_config), udp_socket, Arc::new(TokioRuntime)).die("cannot listen on port/interface");
|
let incoming = Endpoint::new(EndpointConfig::default(), Some(server_config), udp_socket, Arc::new(TokioRuntime)).die("cannot listen on port/interface");
|
||||||
|
internal_spawn_quic_listener(incoming, local_addr, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn internal_spawn_quic_listener(incoming: Endpoint, local_addr: SocketAddr, config: Arc<IncomingConfig>) -> JoinHandle<Result<()>> {
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
// when could this return None, do we quit?
|
// when could this return None, do we quit?
|
||||||
while let Some(incoming_conn) = incoming.accept().await {
|
while let Some(incoming_conn) = incoming.accept().await {
|
||||||
|
@ -1,4 +1,7 @@
|
|||||||
use crate::common::Split;
|
use crate::{
|
||||||
|
common::Split,
|
||||||
|
in_out::{StanzaRead, StanzaWrite},
|
||||||
|
};
|
||||||
use anyhow::bail;
|
use anyhow::bail;
|
||||||
use quinn::{RecvStream, SendStream};
|
use quinn::{RecvStream, SendStream};
|
||||||
use std::{
|
use std::{
|
||||||
@ -14,6 +17,9 @@ pub mod incoming;
|
|||||||
#[cfg(feature = "outgoing")]
|
#[cfg(feature = "outgoing")]
|
||||||
pub mod outgoing;
|
pub mod outgoing;
|
||||||
|
|
||||||
|
#[cfg(all(feature = "incoming", not(target_os = "windows")))]
|
||||||
|
pub mod unix_datagram;
|
||||||
|
|
||||||
pub struct QuicStream {
|
pub struct QuicStream {
|
||||||
pub send: SendStream,
|
pub send: SendStream,
|
||||||
pub recv: RecvStream,
|
pub recv: RecvStream,
|
||||||
@ -54,4 +60,8 @@ impl Split for QuicStream {
|
|||||||
fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
|
fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
|
||||||
(self.recv, self.send)
|
(self.recv, self.send)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn stanza_rw(self) -> (StanzaRead, StanzaWrite) {
|
||||||
|
(StanzaRead::new(self.recv), StanzaWrite::new(self.send))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
63
src/quic/unix_datagram.rs
Normal file
63
src/quic/unix_datagram.rs
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
use quinn::{udp, AsyncUdpSocket};
|
||||||
|
|
||||||
|
use std::{
|
||||||
|
io,
|
||||||
|
task::{Context, Poll},
|
||||||
|
};
|
||||||
|
use tokio::net::UnixDatagram;
|
||||||
|
|
||||||
|
use tokio::io::Interest;
|
||||||
|
|
||||||
|
macro_rules! ready {
|
||||||
|
($e:expr $(,)?) => {
|
||||||
|
match $e {
|
||||||
|
std::task::Poll::Ready(t) => t,
|
||||||
|
std::task::Poll::Pending => return std::task::Poll::Pending,
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn wrap_unix_udp_socket(sock: std::os::unix::net::UnixDatagram) -> io::Result<UnixUdpSocket> {
|
||||||
|
udp::UdpSocketState::configure((&sock).into())?;
|
||||||
|
Ok(UnixUdpSocket {
|
||||||
|
io: UnixDatagram::from_std(sock)?,
|
||||||
|
inner: udp::UdpSocketState::new(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct UnixUdpSocket {
|
||||||
|
io: UnixDatagram,
|
||||||
|
inner: udp::UdpSocketState,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncUdpSocket for UnixUdpSocket {
|
||||||
|
fn poll_send(&self, state: &udp::UdpState, cx: &mut Context, transmits: &[udp::Transmit]) -> Poll<io::Result<usize>> {
|
||||||
|
let inner = &self.inner;
|
||||||
|
let io = &self.io;
|
||||||
|
loop {
|
||||||
|
ready!(io.poll_send_ready(cx))?;
|
||||||
|
if let Ok(res) = io.try_io(Interest::WRITABLE, || inner.send(io.into(), state, transmits)) {
|
||||||
|
return Poll::Ready(Ok(res));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_recv(&self, cx: &mut Context, bufs: &mut [std::io::IoSliceMut<'_>], meta: &mut [udp::RecvMeta]) -> Poll<io::Result<usize>> {
|
||||||
|
loop {
|
||||||
|
ready!(self.io.poll_recv_ready(cx))?;
|
||||||
|
if let Ok(res) = self.io.try_io(Interest::READABLE, || self.inner.recv((&self.io).into(), bufs, meta)) {
|
||||||
|
return Poll::Ready(Ok(res));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn local_addr(&self) -> io::Result<std::net::SocketAddr> {
|
||||||
|
// todo: real SocketAddr
|
||||||
|
Ok("127.0.0.1:0".parse().expect("this one is hardcoded and fine"))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn may_fragment(&self) -> bool {
|
||||||
|
udp::may_fragment()
|
||||||
|
}
|
||||||
|
}
|
154
src/srv.rs
154
src/srv.rs
@ -473,7 +473,7 @@ pub async fn srv_connect(
|
|||||||
let (mut out_wr, mut out_rd, to_addr, proto) = connect.unwrap();
|
let (mut out_wr, mut out_rd, to_addr, proto) = connect.unwrap();
|
||||||
// if any of these ? returns early with an Err, these will stay set, I think that's ok though, the connection will be closed
|
// if any of these ? returns early with an Err, these will stay set, I think that's ok though, the connection will be closed
|
||||||
client_addr.set_proto(proto);
|
client_addr.set_proto(proto);
|
||||||
client_addr.set_to_addr(to_addr);
|
client_addr.set_to_addr(to_addr.to_string());
|
||||||
debug!("{} connected", client_addr.log_from());
|
debug!("{} connected", client_addr.log_from());
|
||||||
|
|
||||||
trace!("{} '{}'", client_addr.log_from(), to_str(stream_open));
|
trace!("{} '{}'", client_addr.log_from(), to_str(stream_open));
|
||||||
@ -906,7 +906,8 @@ mod tests {
|
|||||||
println!("posh: {:?}", posh);
|
println!("posh: {:?}", posh);
|
||||||
}
|
}
|
||||||
|
|
||||||
//#[tokio::test]
|
#[cfg(feature = "net-test")]
|
||||||
|
#[tokio::test]
|
||||||
async fn posh() -> Result<()> {
|
async fn posh() -> Result<()> {
|
||||||
let domain = "posh.badxmpp.eu";
|
let domain = "posh.badxmpp.eu";
|
||||||
let posh = collect_posh(domain).await.unwrap();
|
let posh = collect_posh(domain).await.unwrap();
|
||||||
@ -914,7 +915,8 @@ mod tests {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
//#[tokio::test]
|
#[cfg(feature = "net-test")]
|
||||||
|
#[tokio::test]
|
||||||
async fn srv() -> Result<()> {
|
async fn srv() -> Result<()> {
|
||||||
let domain = "burtrum.org";
|
let domain = "burtrum.org";
|
||||||
let is_c2s = true;
|
let is_c2s = true;
|
||||||
@ -930,7 +932,8 @@ mod tests {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
//#[tokio::test]
|
#[cfg(feature = "net-test")]
|
||||||
|
#[tokio::test]
|
||||||
async fn http() -> Result<()> {
|
async fn http() -> Result<()> {
|
||||||
let mut hosts = Vec::new();
|
let mut hosts = Vec::new();
|
||||||
let mut sha256_pinnedpubkeys = Vec::new();
|
let mut sha256_pinnedpubkeys = Vec::new();
|
||||||
@ -990,77 +993,78 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_dedup() {
|
fn test_dedup() {
|
||||||
let domain = "example.org";
|
let domain = "example.org";
|
||||||
let mut ret = Vec::new();
|
let mut ret = vec![
|
||||||
ret.push(XmppConnection {
|
XmppConnection {
|
||||||
priority: 10,
|
priority: 10,
|
||||||
weight: 0,
|
weight: 0,
|
||||||
target: domain.to_string(),
|
target: domain.to_string(),
|
||||||
conn_type: XmppConnectionType::DirectTLS,
|
conn_type: XmppConnectionType::DirectTLS,
|
||||||
port: 443,
|
port: 443,
|
||||||
secure: false,
|
secure: false,
|
||||||
ips: Vec::new(),
|
ips: Vec::new(),
|
||||||
ech: None,
|
ech: None,
|
||||||
});
|
},
|
||||||
ret.push(XmppConnection {
|
XmppConnection {
|
||||||
priority: 0,
|
priority: 0,
|
||||||
weight: 0,
|
weight: 0,
|
||||||
target: domain.to_string(),
|
target: domain.to_string(),
|
||||||
conn_type: XmppConnectionType::StartTLS,
|
conn_type: XmppConnectionType::StartTLS,
|
||||||
port: 5222,
|
port: 5222,
|
||||||
secure: false,
|
secure: false,
|
||||||
ips: Vec::new(),
|
ips: Vec::new(),
|
||||||
ech: None,
|
ech: None,
|
||||||
});
|
},
|
||||||
ret.push(XmppConnection {
|
XmppConnection {
|
||||||
priority: 15,
|
priority: 15,
|
||||||
weight: 0,
|
weight: 0,
|
||||||
target: domain.to_string(),
|
target: domain.to_string(),
|
||||||
conn_type: XmppConnectionType::DirectTLS,
|
conn_type: XmppConnectionType::DirectTLS,
|
||||||
port: 443,
|
port: 443,
|
||||||
secure: true,
|
secure: true,
|
||||||
ips: Vec::new(),
|
ips: Vec::new(),
|
||||||
ech: None,
|
ech: None,
|
||||||
});
|
},
|
||||||
ret.push(XmppConnection {
|
XmppConnection {
|
||||||
priority: 10,
|
priority: 10,
|
||||||
weight: 0,
|
weight: 0,
|
||||||
target: domain.to_string(),
|
target: domain.to_string(),
|
||||||
conn_type: XmppConnectionType::DirectTLS,
|
conn_type: XmppConnectionType::DirectTLS,
|
||||||
port: 443,
|
port: 443,
|
||||||
secure: true,
|
secure: true,
|
||||||
ips: Vec::new(),
|
ips: Vec::new(),
|
||||||
ech: None,
|
ech: None,
|
||||||
});
|
},
|
||||||
ret.push(XmppConnection {
|
XmppConnection {
|
||||||
priority: 10,
|
priority: 10,
|
||||||
weight: 50,
|
weight: 50,
|
||||||
target: domain.to_string(),
|
target: domain.to_string(),
|
||||||
conn_type: XmppConnectionType::DirectTLS,
|
conn_type: XmppConnectionType::DirectTLS,
|
||||||
port: 443,
|
port: 443,
|
||||||
secure: true,
|
secure: true,
|
||||||
ips: Vec::new(),
|
ips: Vec::new(),
|
||||||
ech: None,
|
ech: None,
|
||||||
});
|
},
|
||||||
ret.push(XmppConnection {
|
XmppConnection {
|
||||||
priority: 10,
|
priority: 10,
|
||||||
weight: 100,
|
weight: 100,
|
||||||
target: "example.com".to_string(),
|
target: "example.com".to_string(),
|
||||||
conn_type: XmppConnectionType::DirectTLS,
|
conn_type: XmppConnectionType::DirectTLS,
|
||||||
port: 443,
|
port: 443,
|
||||||
secure: true,
|
secure: true,
|
||||||
ips: Vec::new(),
|
ips: Vec::new(),
|
||||||
ech: None,
|
ech: None,
|
||||||
});
|
},
|
||||||
ret.push(XmppConnection {
|
XmppConnection {
|
||||||
priority: 0,
|
priority: 0,
|
||||||
weight: 100,
|
weight: 100,
|
||||||
target: "example.com".to_string(),
|
target: "example.com".to_string(),
|
||||||
conn_type: XmppConnectionType::DirectTLS,
|
conn_type: XmppConnectionType::DirectTLS,
|
||||||
port: 443,
|
port: 443,
|
||||||
secure: true,
|
secure: true,
|
||||||
ips: Vec::new(),
|
ips: Vec::new(),
|
||||||
ech: None,
|
ech: None,
|
||||||
});
|
},
|
||||||
|
];
|
||||||
sort_dedup(&mut ret);
|
sort_dedup(&mut ret);
|
||||||
println!("ret dedup: {:?}", ret);
|
println!("ret dedup: {:?}", ret);
|
||||||
}
|
}
|
||||||
|
@ -24,7 +24,7 @@ pub struct FileDescriptor {
|
|||||||
pub name: Option<String>,
|
pub name: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub enum Listener {
|
pub enum SystemdListener {
|
||||||
Tcp(Box<dyn FnOnce() -> TcpListener>),
|
Tcp(Box<dyn FnOnce() -> TcpListener>),
|
||||||
Udp(Box<dyn FnOnce() -> UdpSocket>),
|
Udp(Box<dyn FnOnce() -> UdpSocket>),
|
||||||
UnixListener(Box<dyn FnOnce() -> UnixListener>),
|
UnixListener(Box<dyn FnOnce() -> UnixListener>),
|
||||||
@ -36,13 +36,13 @@ impl FileDescriptor {
|
|||||||
self.name
|
self.name
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn listener(&self) -> Listener {
|
pub fn listener(&self) -> SystemdListener {
|
||||||
let raw_fd = self.raw_fd;
|
let raw_fd = self.raw_fd;
|
||||||
match (self.tcp_not_udp, self.inet_not_unix) {
|
match (self.tcp_not_udp, self.inet_not_unix) {
|
||||||
(true, true) => Listener::Tcp(Box::new(move || unsafe { TcpListener::from_raw_fd(raw_fd) })),
|
(true, true) => SystemdListener::Tcp(Box::new(move || unsafe { TcpListener::from_raw_fd(raw_fd) })),
|
||||||
(false, true) => Listener::Udp(Box::new(move || unsafe { UdpSocket::from_raw_fd(raw_fd) })),
|
(false, true) => SystemdListener::Udp(Box::new(move || unsafe { UdpSocket::from_raw_fd(raw_fd) })),
|
||||||
(true, false) => Listener::UnixListener(Box::new(move || unsafe { UnixListener::from_raw_fd(raw_fd) })),
|
(true, false) => SystemdListener::UnixListener(Box::new(move || unsafe { UnixListener::from_raw_fd(raw_fd) })),
|
||||||
(false, false) => Listener::UnixDatagram(Box::new(move || unsafe { UnixDatagram::from_raw_fd(raw_fd) })),
|
(false, false) => SystemdListener::UnixDatagram(Box::new(move || unsafe { UnixDatagram::from_raw_fd(raw_fd) })),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
common::{
|
common::{
|
||||||
incoming::{shuffle_rd_wr_filter, CloneableConfig, ServerCerts},
|
first_bytes_match_buf_timeout,
|
||||||
to_str, Peek, Split, IN_BUFFER_SIZE,
|
incoming::{shuffle_rd_wr_filter, IncomingConfig, ServerCerts},
|
||||||
|
stream_listener::StreamListener,
|
||||||
|
to_str, AsyncReadWritePeekSplit, Split, IN_BUFFER_SIZE,
|
||||||
},
|
},
|
||||||
context::Context,
|
context::Context,
|
||||||
in_out::{StanzaRead, StanzaWrite},
|
in_out::{StanzaRead, StanzaWrite},
|
||||||
@ -13,8 +15,7 @@ use log::{error, info, trace};
|
|||||||
use rustls::{ServerConfig, ServerConnection};
|
use rustls::{ServerConfig, ServerConnection};
|
||||||
use std::{net::SocketAddr, sync::Arc};
|
use std::{net::SocketAddr, sync::Arc};
|
||||||
use tokio::{
|
use tokio::{
|
||||||
io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufStream},
|
io::{AsyncWriteExt, BufStream},
|
||||||
net::TcpListener,
|
|
||||||
task::JoinHandle,
|
task::JoinHandle,
|
||||||
};
|
};
|
||||||
use tokio_rustls::TlsAcceptor;
|
use tokio_rustls::TlsAcceptor;
|
||||||
@ -23,7 +24,7 @@ pub fn tls_acceptor(server_config: ServerConfig) -> TlsAcceptor {
|
|||||||
TlsAcceptor::from(Arc::new(server_config))
|
TlsAcceptor::from(Arc::new(server_config))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn spawn_tls_listener(listener: TcpListener, config: CloneableConfig, acceptor: TlsAcceptor) -> JoinHandle<Result<()>> {
|
pub fn spawn_tls_listener(listener: impl StreamListener, config: Arc<IncomingConfig>, acceptor: TlsAcceptor) -> JoinHandle<Result<()>> {
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let local_addr = listener.local_addr()?;
|
let local_addr = listener.local_addr()?;
|
||||||
loop {
|
loop {
|
||||||
@ -40,13 +41,7 @@ pub fn spawn_tls_listener(listener: TcpListener, config: CloneableConfig, accept
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn handle_tls_connection<S: AsyncRead + AsyncWrite + Unpin + Send + Sync + Peek + Split + 'static>(
|
pub async fn handle_tls_connection<S: AsyncReadWritePeekSplit>(mut stream: S, client_addr: &mut Context<'_>, local_addr: SocketAddr, config: Arc<IncomingConfig>, acceptor: TlsAcceptor) -> Result<()> {
|
||||||
mut stream: S,
|
|
||||||
client_addr: &mut Context<'_>,
|
|
||||||
local_addr: SocketAddr,
|
|
||||||
config: CloneableConfig,
|
|
||||||
acceptor: TlsAcceptor,
|
|
||||||
) -> Result<()> {
|
|
||||||
info!("{} connected", client_addr.log_from());
|
info!("{} connected", client_addr.log_from());
|
||||||
|
|
||||||
let mut in_filter = StanzaFilter::new(config.max_stanza_size_bytes);
|
let mut in_filter = StanzaFilter::new(config.max_stanza_size_bytes);
|
||||||
@ -67,11 +62,10 @@ pub async fn handle_tls_connection<S: AsyncRead + AsyncWrite + Unpin + Send + Sy
|
|||||||
let mut proceed_sent = false;
|
let mut proceed_sent = false;
|
||||||
|
|
||||||
let (in_rd, mut in_wr) = stream.split();
|
let (in_rd, mut in_wr) = stream.split();
|
||||||
// todo: more efficient version for TCP:
|
|
||||||
//let (in_rd, mut in_wr) = stream.split();
|
|
||||||
|
|
||||||
// we naively read 1 byte at a time, which buffering significantly speeds up
|
// we naively read 1 byte at a time, which buffering significantly speeds up
|
||||||
let in_rd = tokio::io::BufReader::with_capacity(IN_BUFFER_SIZE, in_rd);
|
// todo: I don't think we can buffer here, because then we throw away the data left in the buffer? yet it's been working... am I losing my mind?
|
||||||
|
//let in_rd = tokio::io::BufReader::with_capacity(IN_BUFFER_SIZE, in_rd);
|
||||||
let mut in_rd = StanzaReader(in_rd);
|
let mut in_rd = StanzaReader(in_rd);
|
||||||
|
|
||||||
while let Ok(Some(buf)) = in_rd.next(&mut in_filter).await {
|
while let Ok(Some(buf)) = in_rd.next(&mut in_filter).await {
|
||||||
@ -117,7 +111,7 @@ pub async fn handle_tls_connection<S: AsyncRead + AsyncWrite + Unpin + Send + Sy
|
|||||||
if !proceed_sent {
|
if !proceed_sent {
|
||||||
bail!("stream ended before open");
|
bail!("stream ended before open");
|
||||||
}
|
}
|
||||||
<S as Split>::combine(in_rd.0.into_inner(), in_wr)?
|
<S as Split>::combine(in_rd.0, in_wr)?
|
||||||
} else {
|
} else {
|
||||||
stream
|
stream
|
||||||
};
|
};
|
||||||
@ -143,20 +137,19 @@ pub async fn handle_tls_connection<S: AsyncRead + AsyncWrite + Unpin + Send + Sy
|
|||||||
|
|
||||||
#[cfg(not(feature = "websocket"))]
|
#[cfg(not(feature = "websocket"))]
|
||||||
{
|
{
|
||||||
let (in_rd, in_wr) = tokio::io::split(stream);
|
let (in_rd, in_wr) = stream.split();
|
||||||
shuffle_rd_wr_filter(StanzaRead::new(in_rd), StanzaWrite::new(in_wr), config, server_certs, local_addr, client_addr, in_filter).await
|
shuffle_rd_wr_filter(StanzaRead::new(in_rd), StanzaWrite::new(in_wr), config, server_certs, local_addr, client_addr, in_filter).await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "websocket")]
|
#[cfg(feature = "websocket")]
|
||||||
{
|
{
|
||||||
let mut stream = BufStream::with_capacity(IN_BUFFER_SIZE, 0, stream);
|
let mut stream = BufStream::with_capacity(IN_BUFFER_SIZE, 0, stream);
|
||||||
|
let websocket = first_bytes_match_buf_timeout(&mut stream, 3, |p| p == b"GET").await?;
|
||||||
let websocket = stream.first_bytes_match(&mut in_filter.buf[0..3], |b| b == b"GET").await?;
|
|
||||||
|
|
||||||
if websocket {
|
if websocket {
|
||||||
crate::websocket::incoming::handle_websocket_connection(Box::new(stream), config, server_certs, local_addr, client_addr, in_filter).await
|
crate::websocket::incoming::handle_websocket_connection(Box::new(stream), config, server_certs, local_addr, client_addr, in_filter).await
|
||||||
} else {
|
} else {
|
||||||
let (in_rd, in_wr) = tokio::io::split(stream);
|
let (in_rd, in_wr) = stream.split();
|
||||||
shuffle_rd_wr_filter(StanzaRead::already_buffered(in_rd), StanzaWrite::new(in_wr), config, server_certs, local_addr, client_addr, in_filter).await
|
shuffle_rd_wr_filter(StanzaRead::already_buffered(in_rd), StanzaWrite::new(in_wr), config, server_certs, local_addr, client_addr, in_filter).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,16 +1,19 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
common::incoming::{shuffle_rd_wr_filter, CloneableConfig, ServerCerts},
|
common::{
|
||||||
|
incoming::{shuffle_rd_wr_filter, IncomingConfig, ServerCerts},
|
||||||
|
BoxAsyncReadWrite,
|
||||||
|
},
|
||||||
context::Context,
|
context::Context,
|
||||||
stanzafilter::StanzaFilter,
|
stanzafilter::StanzaFilter,
|
||||||
websocket::{incoming_websocket_connection, AsyncReadAndWrite},
|
websocket::incoming_websocket_connection,
|
||||||
};
|
};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use log::info;
|
use log::info;
|
||||||
use std::net::SocketAddr;
|
use std::{net::SocketAddr, sync::Arc};
|
||||||
|
|
||||||
pub async fn handle_websocket_connection(
|
pub async fn handle_websocket_connection(
|
||||||
stream: Box<dyn AsyncReadAndWrite + Unpin + Send>,
|
stream: BoxAsyncReadWrite,
|
||||||
config: CloneableConfig,
|
config: Arc<IncomingConfig>,
|
||||||
server_certs: ServerCerts,
|
server_certs: ServerCerts,
|
||||||
local_addr: SocketAddr,
|
local_addr: SocketAddr,
|
||||||
client_addr: &mut Context<'_>,
|
client_addr: &mut Context<'_>,
|
||||||
|
@ -16,8 +16,8 @@ pub mod incoming;
|
|||||||
#[cfg(feature = "outgoing")]
|
#[cfg(feature = "outgoing")]
|
||||||
pub mod outgoing;
|
pub mod outgoing;
|
||||||
|
|
||||||
pub type WsWr = SplitSink<WebSocketStream<Box<dyn AsyncReadAndWrite + Unpin + Send>>, tokio_tungstenite::tungstenite::Message>;
|
pub type WsWr = SplitSink<WebSocketStream<BoxAsyncReadWrite>, tokio_tungstenite::tungstenite::Message>;
|
||||||
pub type WsRd = SplitStream<WebSocketStream<Box<dyn AsyncReadAndWrite + Unpin + Send>>>;
|
pub type WsRd = SplitStream<WebSocketStream<BoxAsyncReadWrite>>;
|
||||||
|
|
||||||
// https://datatracker.ietf.org/doc/html/rfc7395
|
// https://datatracker.ietf.org/doc/html/rfc7395
|
||||||
|
|
||||||
@ -30,11 +30,7 @@ fn ws_cfg(max_stanza_size_bytes: usize) -> Option<WebSocketConfig> {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait AsyncReadAndWrite: tokio::io::AsyncRead + tokio::io::AsyncWrite {}
|
pub async fn incoming_websocket_connection(stream: BoxAsyncReadWrite, max_stanza_size_bytes: usize) -> Result<(StanzaRead, StanzaWrite)> {
|
||||||
|
|
||||||
impl<T: tokio::io::AsyncRead + tokio::io::AsyncWrite> AsyncReadAndWrite for T {}
|
|
||||||
|
|
||||||
pub async fn incoming_websocket_connection(stream: Box<dyn AsyncReadAndWrite + Unpin + Send>, max_stanza_size_bytes: usize) -> Result<(StanzaRead, StanzaWrite)> {
|
|
||||||
// accept the websocket
|
// accept the websocket
|
||||||
let stream = tokio_tungstenite::accept_hdr_async_with_config(
|
let stream = tokio_tungstenite::accept_hdr_async_with_config(
|
||||||
stream,
|
stream,
|
||||||
@ -118,6 +114,7 @@ pub fn to_ws_new(buf: &[u8], mut end_of_first_tag: usize, is_c2s: bool) -> Resul
|
|||||||
}
|
}
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
common::BoxAsyncReadWrite,
|
||||||
in_out::{StanzaRead, StanzaWrite},
|
in_out::{StanzaRead, StanzaWrite},
|
||||||
slicesubsequence::SliceSubsequence,
|
slicesubsequence::SliceSubsequence,
|
||||||
};
|
};
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
common::outgoing::OutgoingVerifierConfig,
|
common::{outgoing::OutgoingVerifierConfig, BoxAsyncReadWrite},
|
||||||
in_out::{StanzaRead, StanzaWrite},
|
in_out::{StanzaRead, StanzaWrite},
|
||||||
websocket::{ws_cfg, AsyncReadAndWrite},
|
websocket::ws_cfg,
|
||||||
};
|
};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use futures_util::StreamExt;
|
use futures_util::StreamExt;
|
||||||
@ -27,7 +27,7 @@ pub async fn websocket_connect(target: SocketAddr, server_name: &str, url: &Uri,
|
|||||||
//let stream: tokio_rustls::TlsStream<tokio::net::TcpStream> = stream.into();
|
//let stream: tokio_rustls::TlsStream<tokio::net::TcpStream> = stream.into();
|
||||||
// todo: tokio_tungstenite seems to have a bug, if the write buffer is non-zero, it'll hang forever, even though we always flush, investigate
|
// todo: tokio_tungstenite seems to have a bug, if the write buffer is non-zero, it'll hang forever, even though we always flush, investigate
|
||||||
//let stream = BufStream::with_capacity(crate::IN_BUFFER_SIZE, 0, stream);
|
//let stream = BufStream::with_capacity(crate::IN_BUFFER_SIZE, 0, stream);
|
||||||
let stream: Box<dyn AsyncReadAndWrite + Unpin + Send> = Box::new(stream);
|
let stream: BoxAsyncReadWrite = Box::new(stream);
|
||||||
|
|
||||||
let (stream, _) = tokio_tungstenite::client_async_with_config(request, stream, ws_cfg(config.max_stanza_size_bytes)).await?;
|
let (stream, _) = tokio_tungstenite::client_async_with_config(request, stream, ws_cfg(config.max_stanza_size_bytes)).await?;
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
incoming_listen = [ "0.0.0.0:5222", "0.0.0.0:5269", "0.0.0.0:443" ]
|
incoming_listen = [ "0.0.0.0:5222", "0.0.0.0:5269", "0.0.0.0:443" ]
|
||||||
# interfaces to listen for reverse proxy QUIC XMPP connections on, should be open to the internet
|
# interfaces to listen for reverse proxy QUIC XMPP connections on, should be open to the internet
|
||||||
quic_listen = [ "0.0.0.0:443" ]
|
quic_listen = [ "0.0.0.0:443" ]
|
||||||
# interfaces to listen for outgoing proxy TCP or WebSocket XMPP connections on, should be localhost
|
# interfaces to listen for outgoing proxy TCP or WebSocket XMPP connections on, should be localhost or a path for a unix socket
|
||||||
outgoing_listen = [ "127.0.0.1:15270" ]
|
outgoing_listen = [ "127.0.0.1:15270" ]
|
||||||
|
|
||||||
# these ports shouldn't do any TLS, but should assume any connection from xmpp-proxy is secure
|
# these ports shouldn't do any TLS, but should assume any connection from xmpp-proxy is secure
|
||||||
|
Loading…
Reference in New Issue
Block a user