// Shared networking helpers used by the proxy: stream/listener abstractions, "peek"
// utilities, stanza forwarding between two peers, and TLS key/cert loading.
//
// NOTE(review): in this copy of the file the line breaks appear to have been collapsed
// (several `//` comments now run into the code that followed them), generic parameter
// lists appear to be missing (e.g. `Box;`, `Result {`, `impl ... for T {}`), and part of
// the body of `stream_preamble` appears to be lost. Confirm against the canonical file
// before changing any code here.
//
// Contents, in order of appearance:
// - Sub-module declarations: `incoming`, `outgoing`, `ca_roots` and `certs_key` are gated
//   on cargo features; `stream_listener` is unconditional.
// - `IN_BUFFER_SIZE` (8192) and the ALPN byte strings `xmpp-client` / `xmpp-server`.
// - `AsyncReadAndWrite` / `AsyncReadWritePeekSplit`: bound-only traits with blanket impls,
//   plus the `BoxAsyncReadWrite` / `BufAsyncReadWrite` aliases.
// - `buf_stream`: wraps a stream via `BufStream::with_capacity(IN_BUFFER_SIZE, 0, stream)`,
//   i.e. the second (writer) capacity argument is zero.
// - `to_str`: lossy UTF-8 view of a byte slice; `c2s`: maps `true` to "c2s", `false` to "s2s".
// - `SocketAddrPath`: a `SocketAddr` or, on non-Windows targets, a filesystem path.
//   `connect` opens a `TcpStream` / `UnixStream` and returns its boxed halves, `bind`
//   produces a `Listener`, and `bind_udp` produces a `UdpListener` from the `std` socket
//   types. When deserialized on non-Windows targets, a string containing '/' becomes a
//   path and anything else is parsed as a socket address; on Windows it is always
//   deserialized as a `SocketAddr`.
// - `Split`: split a stream into read/write halves, recombine them (`combine`), or turn the
//   stream into a `(StanzaRead, StanzaWrite)` pair. Implemented for `TcpStream`, the
//   tokio-rustls server `TlsStream` (feature `tokio-rustls`), `UnixStream` (non-Windows)
//   and `BufStream`. The `BufStream` impl builds its reader with
//   `StanzaRead::already_buffered` instead of `StanzaRead::new`.
// - `Peek`: `peek_bytes` returns once the first `p.len()` bytes are available in `p`, and
//   bails with "not enough bytes" on a zero-length read / empty buffer or with a
//   "less than {} bytes in 10 seconds" error once the deadline check trips.
//   `first_bytes_match` applies a matcher function to the peeked bytes.
// - `first_bytes_match_buf` / `first_bytes_match_buf_timeout`: the same idea for any
//   `AsyncBufRead`; the `_timeout` variant wraps the loop in a 10-second
//   `tokio::time::timeout`.
// - `stream_preamble`: reads stanzas from `in_rd` in a loop and bails with
//   "stream ended before open" when the loop ends without returning.
// - `shuffle_rd_wr_filter_only`: forwards stanzas in both directions between the `in_*`
//   and `out_*` halves (write, then flush, per stanza) until either reader yields `None`,
//   then logs the disconnect.
// - `read_certified_key` (feature `rustls-pemfile`): loads the first usable signing key
//   from the key file (RSA, PKCS8 or EC items; items that fail to load are skipped) and
//   every certificate from the cert file, and returns them as a `CertifiedKey`.
//
// NOTE(review): the `Peek` loops re-call `peek` / `fill_buf` immediately when fewer than
// `len` bytes are available; if those calls return the same short data without waiting
// for more, the loop would spin until the 10-second check fires. Verify against the
// tokio documentation for `TcpStream::peek` and `AsyncBufReadExt::fill_buf`.
// NOTE(review): the 10-second check in the `Peek` impls only runs after an await returns,
// so it presumably cannot interrupt a read that never completes; confirm callers apply
// their own timeout.
// NOTE(review): both `tokio::select!` branches in `shuffle_rd_wr_filter_only` use the
// pattern `Ok(ret) = ...` and there is no `else` branch; check what `select!` does when
// every branch fails to match (read errors on both sides).
use crate::{ context::Context, in_out::{StanzaRead, StanzaWrite}, slicesubsequence::SliceSubsequence, stanzafilter::StanzaFilter, }; use anyhow::{bail, Result}; use async_trait::async_trait; use log::{info, trace}; #[cfg(feature = "rustls")] use rustls::{ sign::{RsaSigningKey, SigningKey}, Certificate, PrivateKey, }; use serde::{Deserialize, Deserializer}; use std::{ fmt::{Display, Formatter}, fs::File, io, net::{SocketAddr, UdpSocket}, path::PathBuf, sync::Arc, }; #[cfg(not(target_os = "windows"))] use tokio::net::UnixStream; use tokio::{ io::{AsyncBufRead, AsyncRead, AsyncWrite, BufReader, BufStream}, net::{TcpListener, TcpStream}, }; #[cfg(feature = "incoming")] pub mod incoming; #[cfg(feature = "outgoing")] pub mod outgoing; #[cfg(any(feature = "rustls-native-certs", feature = "webpki-roots"))] pub mod ca_roots; #[cfg(feature = "rustls")] pub mod certs_key; pub mod stream_listener; pub const IN_BUFFER_SIZE: usize = 8192; pub const ALPN_XMPP_CLIENT: &[u8] = b"xmpp-client"; pub const ALPN_XMPP_SERVER: &[u8] = b"xmpp-server"; pub trait AsyncReadAndWrite: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send {} impl AsyncReadAndWrite for T {} pub trait AsyncReadWritePeekSplit: tokio::io::AsyncRead + tokio::io::AsyncWrite + Peek + Send + 'static + Unpin + Split {} impl AsyncReadWritePeekSplit for T {} pub type BoxAsyncReadWrite = Box; pub type BufAsyncReadWrite = BufStream; pub fn buf_stream(stream: BoxAsyncReadWrite) -> BufAsyncReadWrite { // todo: do we *want* a non-zero writer_capacity ? 
BufStream::with_capacity(IN_BUFFER_SIZE, 0, stream) } pub fn to_str(buf: &[u8]) -> std::borrow::Cow<'_, str> { String::from_utf8_lossy(buf) } pub fn c2s(is_c2s: bool) -> &'static str { if is_c2s { "c2s" } else { "s2s" } } #[derive(Clone)] pub enum SocketAddrPath { SocketAddr(SocketAddr), #[cfg(not(target_os = "windows"))] Path(PathBuf), } impl SocketAddrPath { pub async fn connect(&self) -> Result<(Box, Box)> { Ok(match self { SocketAddrPath::SocketAddr(sa) => TcpStream::connect(sa).await?.split_boxed(), #[cfg(not(target_os = "windows"))] SocketAddrPath::Path(path) => tokio::net::UnixStream::connect(path).await?.split_boxed(), }) } pub async fn bind(&self) -> Result { Ok(match self { SocketAddrPath::SocketAddr(sa) => Listener::Tcp(TcpListener::bind(sa).await?), #[cfg(not(target_os = "windows"))] SocketAddrPath::Path(path) => Listener::Unix(tokio::net::UnixListener::bind(path)?), }) } pub async fn bind_udp(&self) -> Result { Ok(match self { SocketAddrPath::SocketAddr(sa) => UdpListener::Udp(UdpSocket::bind(sa)?), #[cfg(not(target_os = "windows"))] SocketAddrPath::Path(path) => UdpListener::Unix(std::os::unix::net::UnixDatagram::bind(path)?), }) } } impl Display for SocketAddrPath { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { SocketAddrPath::SocketAddr(x) => x.fmt(f), #[cfg(not(target_os = "windows"))] SocketAddrPath::Path(x) => x.display().fmt(f), } } } impl<'de> Deserialize<'de> for SocketAddrPath { fn deserialize(deserializer: D) -> Result where D: Deserializer<'de>, { #[cfg(not(target_os = "windows"))] { let str = String::deserialize(deserializer)?; // seems good enough, possibly could improve Ok(if str.contains('/') { SocketAddrPath::Path(PathBuf::from(str)) } else { SocketAddrPath::SocketAddr(str.parse().map_err(serde::de::Error::custom)?) 
}) } #[cfg(target_os = "windows")] { Ok(SocketAddrPath::SocketAddr(SocketAddr::deserialize(deserializer)?)) } } } pub enum Listener { Tcp(TcpListener), #[cfg(not(target_os = "windows"))] Unix(tokio::net::UnixListener), } pub enum UdpListener { Udp(UdpSocket), #[cfg(not(target_os = "windows"))] Unix(std::os::unix::net::UnixDatagram), } pub trait Split: Sized { type ReadHalf: AsyncRead + Unpin + Send + 'static; type WriteHalf: AsyncWrite + Unpin + Send + 'static; fn combine(read_half: Self::ReadHalf, write_half: Self::WriteHalf) -> Result; fn split(self) -> (Self::ReadHalf, Self::WriteHalf); fn stanza_rw(self) -> (StanzaRead, StanzaWrite); fn split_boxed(self) -> (Box, Box) { let (rd, wr) = self.split(); (Box::new(rd), Box::new(wr)) } } impl Split for TcpStream { type ReadHalf = tokio::net::tcp::OwnedReadHalf; type WriteHalf = tokio::net::tcp::OwnedWriteHalf; fn combine(read_half: Self::ReadHalf, write_half: Self::WriteHalf) -> Result { Ok(read_half.reunite(write_half)?) } fn split(self) -> (Self::ReadHalf, Self::WriteHalf) { self.into_split() } fn stanza_rw(self) -> (StanzaRead, StanzaWrite) { let (in_rd, in_wr) = self.into_split(); (StanzaRead::new(in_rd), StanzaWrite::new(in_wr)) } } #[cfg(feature = "tokio-rustls")] impl Split for tokio_rustls::server::TlsStream { type ReadHalf = tokio::io::ReadHalf>; type WriteHalf = tokio::io::WriteHalf>; fn combine(read_half: Self::ReadHalf, write_half: Self::WriteHalf) -> Result { if read_half.is_pair_of(&write_half) { Ok(read_half.unsplit(write_half)) } else { bail!("non-matching read/write half") } } fn split(self) -> (Self::ReadHalf, Self::WriteHalf) { tokio::io::split(self) } fn stanza_rw(self) -> (StanzaRead, StanzaWrite) { let (in_rd, in_wr) = tokio::io::split(self); (StanzaRead::new(in_rd), StanzaWrite::new(in_wr)) } } #[cfg(not(target_os = "windows"))] impl Split for UnixStream { type ReadHalf = tokio::net::unix::OwnedReadHalf; type WriteHalf = tokio::net::unix::OwnedWriteHalf; fn combine(read_half: Self::ReadHalf, 
write_half: Self::WriteHalf) -> Result { Ok(read_half.reunite(write_half)?) } fn split(self) -> (Self::ReadHalf, Self::WriteHalf) { self.into_split() } fn stanza_rw(self) -> (StanzaRead, StanzaWrite) { let (in_rd, in_wr) = self.into_split(); (StanzaRead::new(in_rd), StanzaWrite::new(in_wr)) } } impl Split for BufStream { type ReadHalf = tokio::io::ReadHalf>; type WriteHalf = tokio::io::WriteHalf>; fn combine(read_half: Self::ReadHalf, write_half: Self::WriteHalf) -> Result { if read_half.is_pair_of(&write_half) { Ok(read_half.unsplit(write_half)) } else { bail!("non-matching read/write half") } } fn split(self) -> (Self::ReadHalf, Self::WriteHalf) { tokio::io::split(self) } fn stanza_rw(self) -> (StanzaRead, StanzaWrite) { let (in_rd, in_wr) = tokio::io::split(self); (StanzaRead::already_buffered(in_rd), StanzaWrite::new(in_wr)) } } #[async_trait] pub trait Peek { async fn peek_bytes<'a>(&mut self, p: &'a mut [u8]) -> anyhow::Result<&'a [u8]>; async fn first_bytes_match<'a>(&mut self, p: &'a mut [u8], matcher: fn(&'a [u8]) -> bool) -> Result { Ok(matcher(self.peek_bytes(p).await?)) } } #[async_trait] impl Peek for TcpStream { async fn peek_bytes<'a>(&mut self, p: &'a mut [u8]) -> anyhow::Result<&'a [u8]> { // sooo... I don't think peek here can be used for > 1 byte without this timer craziness... can it? let len = p.len(); // wait up to 10 seconds until len bytes have been read use std::time::{Duration, Instant}; let duration = Duration::from_secs(10); let now = Instant::now(); loop { let n = self.peek(p).await?; if n == len { return Ok(p); // success } if n == 0 { bail!("not enough bytes"); } if Instant::now() - now > duration { bail!("less than {} bytes in 10 seconds, closed connection?", len); } } } } #[async_trait] impl Peek for BufStream { async fn peek_bytes<'a>(&mut self, p: &'a mut [u8]) -> anyhow::Result<&'a [u8]> { // sooo... I don't think peek here can be used for > 1 byte without this timer craziness... can it? 
let len = p.len(); // wait up to 10 seconds until len bytes have been read use std::time::{Duration, Instant}; use tokio::io::AsyncBufReadExt; let duration = Duration::from_secs(10); let now = Instant::now(); loop { let buf = self.fill_buf().await?; if buf.len() >= len { p.copy_from_slice(&buf[0..len]); return Ok(p); // success } if buf.is_empty() { bail!("not enough bytes"); } if Instant::now() - now > duration { bail!("less than {} bytes in 10 seconds, closed connection?", len); } } } } #[async_trait] impl Peek for BufReader { async fn peek_bytes<'a>(&mut self, p: &'a mut [u8]) -> anyhow::Result<&'a [u8]> { // sooo... I don't think peek here can be used for > 1 byte without this timer craziness... can it? let len = p.len(); // wait up to 10 seconds until len bytes have been read use std::time::{Duration, Instant}; use tokio::io::AsyncBufReadExt; let duration = Duration::from_secs(10); let now = Instant::now(); loop { let buf = self.fill_buf().await?; if buf.len() >= len { p.copy_from_slice(&buf[0..len]); return Ok(p); // success } if buf.is_empty() { bail!("not enough bytes"); } if Instant::now() - now > duration { bail!("less than {} bytes in 10 seconds, closed connection?", len); } } } } /// Caution: this will loop forever, call timeout variant `first_bytes_match_buf_timeout` async fn first_bytes_match_buf(stream: &mut (dyn AsyncBufRead + Send + Unpin), len: usize, matcher: fn(&[u8]) -> bool) -> Result { use tokio::io::AsyncBufReadExt; loop { let buf = stream.fill_buf().await?; if buf.len() >= len { return Ok(matcher(&buf[0..len])); } } } pub async fn first_bytes_match_buf_timeout(stream: &mut (dyn AsyncBufRead + Send + Unpin), len: usize, matcher: fn(&[u8]) -> bool) -> Result { // wait up to 10 seconds until `len` bytes have been read use std::time::Duration; tokio::time::timeout(Duration::from_secs(10), first_bytes_match_buf(stream, len, matcher)).await? 
} pub async fn stream_preamble(in_rd: &mut StanzaRead, in_wr: &mut StanzaWrite, client_addr: &'_ str, in_filter: &mut StanzaFilter) -> Result<(Vec, bool)> { let mut stream_open = Vec::new(); while let Ok(Some((buf, _))) = in_rd.next(in_filter, client_addr, in_wr).await { trace!("{} received pre- stanza: '{}'", client_addr, to_str(buf)); if buf.starts_with(b" stanza: {}", to_str(buf)); } } bail!("stream ended before open") } #[allow(clippy::too_many_arguments)] pub async fn shuffle_rd_wr_filter_only( mut in_rd: StanzaRead, mut in_wr: StanzaWrite, mut out_rd: StanzaRead, mut out_wr: StanzaWrite, is_c2s: bool, max_stanza_size_bytes: usize, client_addr: &mut Context<'_>, mut in_filter: StanzaFilter, ) -> Result<()> { let mut out_filter = StanzaFilter::new(max_stanza_size_bytes); loop { tokio::select! { Ok(ret) = in_rd.next(&mut in_filter, client_addr.log_to(), &mut in_wr) => { match ret { None => break, Some((buf, eoft)) => { trace!("{} '{}'", client_addr.log_from(), to_str(buf)); out_wr.write_all(is_c2s, buf, eoft, client_addr.log_from()).await?; out_wr.flush().await?; } } }, Ok(ret) = out_rd.next(&mut out_filter, client_addr.log_from(), &mut out_wr) => { match ret { None => break, Some((buf, eoft)) => { trace!("{} '{}'", client_addr.log_to(), to_str(buf)); in_wr.write_all(is_c2s, buf, eoft, client_addr.log_to()).await?; in_wr.flush().await?; } } }, } } info!("{} disconnected", client_addr.log_from()); Ok(()) } #[cfg(feature = "rustls-pemfile")] pub fn read_certified_key(tls_key: &str, tls_cert: &str) -> Result { use rustls_pemfile::{certs, read_all, Item}; let tls_key = read_all(&mut io::BufReader::new(File::open(tls_key)?)) .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))? 
.into_iter() .flat_map(|item| match item { Item::RSAKey(der) => RsaSigningKey::new(&PrivateKey(der)).ok().map(Arc::new).map(|r| r as Arc), Item::PKCS8Key(der) => rustls::sign::any_supported_type(&PrivateKey(der)).ok(), Item::ECKey(der) => rustls::sign::any_supported_type(&PrivateKey(der)).ok(), _ => None, }) .next() .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))?; let tls_certs = certs(&mut io::BufReader::new(File::open(tls_cert)?)) .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert")) .map(|mut certs| certs.drain(..).map(Certificate).collect())?; Ok(rustls::sign::CertifiedKey::new(tls_certs, tls_key)) }