diff --git a/.gitignore b/.gitignore index 1cf41aa..e83ea7b 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,5 @@ **/out/ **/core.* fuzz/target/ +todo.txt +conflict/ \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index d20698e..350a441 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -394,9 +394,9 @@ checksum = "1aab8fc367588b89dcee83ab0fd66b72b50b72fa1904d7095045ace2b0c81c35" [[package]] name = "js-sys" -version = "0.3.55" +version = "0.3.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7cc9ffccd38c451a86bf13657df244e9c3f37493cce8e5e21e940963777acc84" +checksum = "a38fc24e30fd564ce974c02bf1d337caddff65be6cc4735a1f7eab22a7440f04" dependencies = [ "wasm-bindgen", ] @@ -897,9 +897,9 @@ checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" [[package]] name = "syn" -version = "1.0.85" +version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a684ac3dcd8913827e18cd09a68384ee66c1de24157e3c556c9ab16d85695fb7" +checksum = "8a65b3f4ffa0092e9887669db0eae07941f023991ab58ea44da8fe8e2d511c6b" dependencies = [ "proc-macro2", "quote", @@ -996,8 +996,12 @@ checksum = "e80b39df6afcc12cdf752398ade96a6b9e99c903dfdc36e53ad10b9c366bca72" dependencies = [ "futures-util", "log", + "rustls", + "rustls-native-certs", "tokio", + "tokio-rustls", "tungstenite", + "webpki", ] [[package]] @@ -1099,10 +1103,12 @@ dependencies = [ "httparse", "log", "rand", + "rustls", "sha-1", "thiserror", "url", "utf-8", + "webpki", ] [[package]] @@ -1170,15 +1176,15 @@ checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" [[package]] name = "wasi" -version = "0.10.3+wasi-snapshot-preview1" +version = "0.10.2+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46a2e384a3f170b0c7543787a91411175b71afd56ba4d3a0ae5678d4e2243c0e" +checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6" [[package]] name = "wasm-bindgen" -version = "0.2.78" +version = "0.2.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "632f73e236b219150ea279196e54e610f5dbafa5d61786303d4da54f84e47fce" +checksum = "25f1af7423d8588a3d840681122e72e6a24ddbcb3f0ec385cac0d12d24256c06" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -1186,9 +1192,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.78" +version = "0.2.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a317bf8f9fba2476b4b2c85ef4c4af8ff39c3c7f0cdfeed4f82c34a880aa837b" +checksum = "8b21c0df030f5a177f3cba22e9bc4322695ec43e7257d865302900290bcdedca" dependencies = [ "bumpalo", "lazy_static", @@ -1201,9 +1207,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.78" +version = "0.2.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d56146e7c495528bf6587663bea13a8eb588d39b36b679d83972e1a2dbbdacf9" +checksum = "2f4203d69e40a52ee523b2529a773d5ffc1dc0071801c87b3d270b471b80ed01" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -1211,9 +1217,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.78" +version = "0.2.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7803e0eea25835f8abdc585cd3021b3deb11543c6fe226dcd30b228857c5c5ab" +checksum = "bfa8a30d46208db204854cadbb5d4baf5fcf8071ba5bf48190c3e59937962ebc" dependencies = [ "proc-macro2", "quote", @@ -1224,15 +1230,15 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.78" +version = "0.2.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0237232789cf037d5480773fe568aac745bfe2afbc11a863e97901780a6b47cc" +checksum = "3d958d035c4438e28c70e4321a2911302f10135ce78a9c7834c0cab4123d06a2" [[package]] name = "web-sys" -version = "0.3.55" +version = "0.3.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38eb105f1c59d9eaa6b5cdc92b859d85b926e82cb2e0945cd0c9259faa6fe9fb" +checksum = "c060b319f29dd25724f09a2ba1418f142f539b2be99fbf4d2d5a8f7330afb8eb" dependencies = [ "js-sys", "wasm-bindgen", @@ -1311,11 +1317,13 @@ dependencies = [ "die", "env_logger", "futures", + "futures-util", "lazy_static", "log", "quinn", "rand", "rustls", + "rustls-native-certs", "rustls-pemfile", "serde", "serde_derive", diff --git a/Cargo.toml b/Cargo.toml index fb32b8c..04bb59b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,9 +38,11 @@ tokio-rustls = { version = "0.23", optional = true } # outgoing deps lazy_static = { version = "1.4", optional = true } -webpki-roots = { version = "0.22", optional = true } trust-dns-resolver = { version = "0.20", optional = true } #trust-dns-resolver = { version = "0.20", features = ["dns-over-rustls"], optional = true } +# todo: feature to swap between webpki-roots and rustls-native-certs +webpki-roots = { version = "0.22", optional = true } +rustls-native-certs = { version = "0.6", optional = true } # quic deps quinn = { version = "0.8", optional = true } @@ -50,20 +52,16 @@ rustls = { version = "0.20.2", optional = true } rustls-pemfile = { version = "0.2.1", optional = true } # websocket deps -tokio-tungstenite = { version = "0.16", optional = true } +#tokio-tungstenite = { version = "0.16", optional = true, features = ["rustls-tls-webpki-roots"] } +tokio-tungstenite = { version = "0.16", optional = true, features = ["rustls-tls-native-roots"] } +futures-util = { version = "0.3", default-features = false, features = ["async-await", "sink", "std"], optional = true } [features] default = ["incoming", "outgoing", "quic", "websocket", "logging"] -#default = ["incoming", "outgoing"] -#default = ["incoming", "quic"] -#default = ["outgoing", "quic"] -#default = ["quic"] -#default = ["outgoing"] -#default = ["incoming"] incoming = ["tokio-rustls", "rustls-pemfile", "rustls"] -outgoing = ["tokio-rustls", "trust-dns-resolver", "webpki-roots", "lazy_static", "rustls"] -quic = ["quinn", "rustls-pemfile", "rustls", "webpki-roots"] -websocket = ["tokio-tungstenite", "tokio-rustls", "rustls-pemfile", "rustls"] +outgoing = ["tokio-rustls", "trust-dns-resolver", "rustls-native-certs", "lazy_static", "rustls"] +quic = ["quinn", "rustls-pemfile", "rustls", "rustls-native-certs"] +websocket = ["tokio-tungstenite", "futures-util", "tokio-rustls", "rustls-pemfile", "rustls", "rustls-native-certs"] logging = ["rand", "env_logger"] [package.metadata.cargo-all-features] diff --git a/README.md b/README.md index a292fcb..8c6b480 100644 --- a/README.md +++ b/README.md @@ -34,8 +34,7 @@ xmpp-proxy in outgoing mode will: #### Configuration * `mkdir /etc/xmpp-proxy/ && cp xmpp-proxy.toml /etc/xmpp-proxy/` * edit `/etc/xmpp-proxy/xmpp-proxy.toml` as needed, file is annotated clearly with comments - * put your TLS key/cert in `/etc/xmpp-proxy/`, if your key has "RSA PRIVATE KEY" in it, change that to "PRIVATE KEY": - `sed -i 's/RSA PRIVATE KEY/PRIVATE KEY/' /etc/xmpp-proxy/le.key` + * put your TLS key/cert in `/etc/xmpp-proxy/` * Example systemd unit is provided in xmpp-proxy.service and locks it down with bare minimum permissions. Need to set the permissions correctly: `chown -Rv 'systemd-network:' /etc/xmpp-proxy/` * start xmpp-proxy: `Usage: xmpp-proxy [/path/to/xmpp-proxy.toml (default /etc/xmpp-proxy/xmpp-proxy.toml]` @@ -53,7 +52,6 @@ use the provided `xmpp-proxy.toml` configuration as-is. Edit `/etc/prosody/prosody.cfg.lua`, Add these to modules_enabled: ``` "net_proxy"; -"secure_interfaces"; "s2s_outgoing_proxy"; ``` Until prosody-modules is updated, use my new module [mod_s2s_outgoing_proxy.lua](https://www.moparisthebest.com/mod_s2s_outgoing_proxy.lua). @@ -68,13 +66,12 @@ interfaces = { "127.0.0.1" } -- you can also remove all certificates from your config s2s_require_encryption = false s2s_secure_auth = false +c2s_require_encryption = false +allow_unencrypted_plain_auth = true -- xmpp-proxy outgoing is listening on this port, make all outgoing s2s connections directly to here s2s_outgoing_proxy = { "127.0.0.1", 15270 } --- trust connections coming from these IPs -secure_interfaces = { "127.0.0.1", "::1" } - -- handle PROXY protocol on these ports proxy_port_mappings = { [15222] = "c2s", diff --git a/src/in_out.rs b/src/in_out.rs new file mode 100644 index 0000000..370a081 --- /dev/null +++ b/src/in_out.rs @@ -0,0 +1,117 @@ +// Box, Box + +#[cfg(feature = "websocket")] +use crate::{from_ws, to_ws_new}; +use crate::{slicesubsequence::SliceSubsequence, trace, StanzaFilter, StanzaRead::*, StanzaReader, StanzaWrite::*}; +use anyhow::{bail, Result}; +#[cfg(feature = "websocket")] +use futures_util::{ + stream::{SplitSink, SplitStream}, + SinkExt, TryStreamExt, +}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; +#[cfg(feature = "websocket")] +use tokio_tungstenite::{tungstenite::Message::*, WebSocketStream}; + +#[cfg(feature = "websocket")] +type WsWr = SplitSink>, tokio_tungstenite::tungstenite::Message>; +#[cfg(feature = "websocket")] +type WsRd = SplitStream>>; + +pub enum StanzaWrite { + AsyncWrite(Box), + #[cfg(feature = "websocket")] + WebSocketClientWrite(WsWr), +} + +pub enum StanzaRead { + AsyncRead(StanzaReader>>), + #[cfg(feature = "websocket")] + WebSocketRead(WsRd), +} + +impl StanzaWrite { + pub fn new(wr: Box) -> Self { + AsyncWrite(wr) + } + + pub async fn write_all<'a>(&'a mut self, is_c2s: bool, buf: &'a [u8], end_of_first_tag: usize, client_addr: &'a str) -> Result<()> { + match self { + AsyncWrite(wr) => Ok(wr.write_all(buf).await?), + #[cfg(feature = "websocket")] + WebSocketClientWrite(in_wr) => { + let mut buf = buf; + // ignore this + if buf.starts_with(b" Result<()> { + match self { + AsyncWrite(wr) => Ok(wr.flush().await?), + #[cfg(feature = "websocket")] + WebSocketClientWrite(ws) => Ok(ws.flush().await?), + } + } +} + +impl StanzaRead { + pub fn new(rd: Box) -> Self { + // we naively read 1 byte at a time, which buffering significantly speeds up + AsyncRead(StanzaReader(BufReader::with_capacity(crate::IN_BUFFER_SIZE, rd))) + } + + pub async fn next<'a>(&'a mut self, filter: &'a mut StanzaFilter, client_addr: &'a str, wrt: &mut StanzaWrite) -> Result> { + match self { + AsyncRead(rd) => rd.next_eoft(filter).await, + #[cfg(feature = "websocket")] + WebSocketRead(rd) => { + loop { + if let Some(msg) = rd.try_next().await? { + match msg { + // actual XMPP stanzas + Text(stanza) => { + trace!("{} (before ws conversion) '{}'", client_addr, stanza); + let stanza = from_ws(stanza); + let stanza = stanza.as_bytes(); + // todo: set up websocket connection so max size cannot be bigger than filter.buf.len() + let buf = &mut filter.buf[0..stanza.len()]; + buf.copy_from_slice(stanza); + return Ok(Some((buf, 0))); // todo: 0 or None... + } + // websocket ping/pong + Ping(msg) => { + match wrt { + AsyncWrite(_) => bail!("programming error! should always send matching write pair into read, so websocket for websocket..."), + WebSocketClientWrite(ws) => { + ws.feed(Pong(msg)).await?; + ws.flush().await?; + } + } + continue; + } + // handle Close, just break from loop, hopefully client sent before + Close(cf) => bail!("websocket close: {:?}", cf), + _ => bail!("invalid websocket message: {}", msg), // Binary or Pong + } + } else { + bail!("websocket stream ended") + } + } + } + } + } +} diff --git a/src/main.rs b/src/main.rs index 4ac9da8..bdd61e5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,7 +13,7 @@ use die::Die; use serde_derive::Deserialize; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf}; +use tokio::io::{AsyncWriteExt, ReadHalf, WriteHalf}; use tokio::net::TcpListener; use tokio::task::JoinHandle; @@ -54,8 +54,10 @@ mod websocket; #[cfg(feature = "websocket")] use crate::websocket::*; +mod in_out; +pub use crate::in_out::*; + const IN_BUFFER_SIZE: usize = 8192; -const OUT_BUFFER_SIZE: usize = 8192; // todo: split these out to outgoing module @@ -75,6 +77,16 @@ pub fn root_cert_store() -> rustls::RootCertStore { root_cert_store } +#[cfg(feature = "rustls-native-certs")] +pub fn root_cert_store() -> rustls::RootCertStore { + use rustls::RootCertStore; + let mut root_cert_store = RootCertStore::empty(); + for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs") { + root_cert_store.add(&rustls::Certificate(cert.0)).unwrap(); + } + root_cert_store +} + #[derive(Deserialize)] struct Config { tls_key: String, @@ -143,52 +155,73 @@ impl Config { } } -async fn shuffle_rd_wr(in_rd: R, in_wr: W, config: CloneableConfig, local_addr: SocketAddr, client_addr: &mut Context<'_>) -> Result<()> { +async fn shuffle_rd_wr(in_rd: StanzaRead, in_wr: StanzaWrite, config: CloneableConfig, local_addr: SocketAddr, client_addr: &mut Context<'_>) -> Result<()> { let filter = StanzaFilter::new(config.max_stanza_size_bytes); shuffle_rd_wr_filter(in_rd, in_wr, config, local_addr, client_addr, filter).await } -async fn shuffle_rd_wr_filter( - in_rd: R, - mut in_wr: W, +async fn shuffle_rd_wr_filter( + mut in_rd: StanzaRead, + mut in_wr: StanzaWrite, config: CloneableConfig, local_addr: SocketAddr, client_addr: &mut Context<'_>, - in_filter: StanzaFilter, + mut in_filter: StanzaFilter, ) -> Result<()> { - // we naively read 1 byte at a time, which buffering significantly speeds up - let in_rd = tokio::io::BufReader::with_capacity(IN_BUFFER_SIZE, in_rd); - // now read to figure out client vs server - let (stream_open, is_c2s, mut in_rd, mut in_filter) = stream_preamble(StanzaReader(in_rd), client_addr, in_filter).await?; + let (stream_open, is_c2s) = stream_preamble(&mut in_rd, &mut in_wr, client_addr.log_from(), &mut in_filter).await?; - let (mut out_rd, mut out_wr) = open_incoming(config, local_addr, client_addr, &stream_open, is_c2s, &mut in_filter).await?; + let (out_rd, out_wr) = open_incoming(&config, local_addr, client_addr, &stream_open, is_c2s, &mut in_filter).await?; drop(stream_open); - let mut out_buf = [0u8; OUT_BUFFER_SIZE]; + shuffle_rd_wr_filter_only( + in_rd, + in_wr, + StanzaRead::new(Box::new(out_rd)), + StanzaWrite::new(Box::new(out_wr)), + is_c2s, + config.max_stanza_size_bytes, + client_addr, + in_filter, + ) + .await +} + +#[allow(clippy::too_many_arguments)] +async fn shuffle_rd_wr_filter_only( + mut in_rd: StanzaRead, + mut in_wr: StanzaWrite, + mut out_rd: StanzaRead, + mut out_wr: StanzaWrite, + is_c2s: bool, + max_stanza_size_bytes: usize, + client_addr: &mut Context<'_>, + mut in_filter: StanzaFilter, +) -> Result<()> { + let mut out_filter = StanzaFilter::new(max_stanza_size_bytes); loop { tokio::select! { - Ok(buf) = in_rd.next(&mut in_filter) => { - match buf { - None => break, - Some(buf) => { - trace!("{} '{}'", client_addr.log_from(), to_str(buf)); - out_wr.write_all(buf).await?; - out_wr.flush().await?; + Ok(ret) = in_rd.next(&mut in_filter, client_addr.log_to(), &mut in_wr) => { + match ret { + None => break, + Some((buf, eoft)) => { + trace!("{} '{}'", client_addr.log_from(), to_str(buf)); + out_wr.write_all(is_c2s, buf, eoft, client_addr.log_from()).await?; + out_wr.flush().await?; + } } - } - }, - // we could filter outgoing from-server stanzas by size here too by doing same as above - // but instead, we'll just send whatever the server sends as it sends it... - Ok(n) = out_rd.read(&mut out_buf) => { - if n == 0 { - break; - } - trace!("{} '{}'", client_addr.log_to(), to_str(&out_buf[0..n])); - in_wr.write_all(&out_buf[0..n]).await?; - in_wr.flush().await?; - }, + }, + Ok(ret) = out_rd.next(&mut out_filter, client_addr.log_from(), &mut out_wr) => { + match ret { + None => break, + Some((buf, eoft)) => { + trace!("{} '{}'", client_addr.log_to(), to_str(buf)); + in_wr.write_all(is_c2s, buf, eoft, client_addr.log_to()).await?; + in_wr.flush().await?; + } + } + }, } } @@ -197,7 +230,7 @@ async fn shuffle_rd_wr_filter( } async fn open_incoming( - config: CloneableConfig, + config: &CloneableConfig, local_addr: SocketAddr, client_addr: &mut Context<'_>, stream_open: &[u8], @@ -239,25 +272,20 @@ async fn open_incoming( Ok((out_rd, out_wr)) } -async fn stream_preamble(mut in_rd: StanzaReader, client_addr: &Context<'_>, mut in_filter: StanzaFilter) -> Result<(Vec, bool, StanzaReader, StanzaFilter)> { +pub async fn stream_preamble(in_rd: &mut StanzaRead, in_wr: &mut StanzaWrite, client_addr: &'_ str, in_filter: &mut StanzaFilter) -> Result<(Vec, bool)> { let mut stream_open = Vec::new(); - while let Ok(Some(buf)) = in_rd.next(&mut in_filter).await { - trace!("{} received pre- stanza: '{}'", client_addr.log_from(), to_str(buf)); + while let Ok(Some((buf, _))) = in_rd.next(in_filter, client_addr, in_wr).await { + trace!("{} received pre- stanza: '{}'", client_addr, to_str(buf)); if buf.starts_with(b" stanza: {}", to_str(buf)); } } - bail!("stream ended before open"); + bail!("stream ended before open") } #[tokio::main] @@ -277,6 +305,7 @@ async fn main() { if let Some(ref log_style) = main_config.log_style { builder.parse_write_style(log_style); } + // todo: config for this: builder.format_timestamp(None); builder.init(); } diff --git a/src/outgoing.rs b/src/outgoing.rs index a5cac46..e30e9d1 100644 --- a/src/outgoing.rs +++ b/src/outgoing.rs @@ -3,57 +3,27 @@ use crate::*; async fn handle_outgoing_connection(stream: tokio::net::TcpStream, client_addr: &mut Context<'_>, max_stanza_size_bytes: usize) -> Result<()> { info!("{} connected", client_addr.log_from()); - let in_filter = StanzaFilter::new(max_stanza_size_bytes); + let mut in_filter = StanzaFilter::new(max_stanza_size_bytes); - let (in_rd, mut in_wr) = tokio::io::split(stream); + let (in_rd, in_wr) = tokio::io::split(stream); - // we naively read 1 byte at a time, which buffering significantly speeds up - //let in_rd = tokio::io::BufReader::with_capacity(IN_BUFFER_SIZE, in_rd); + let mut in_rd = StanzaRead::new(Box::new(in_rd)); + let mut in_wr = StanzaWrite::new(Box::new(in_wr)); // now read to figure out client vs server - let (stream_open, is_c2s, in_rd, mut in_filter) = stream_preamble(StanzaReader(in_rd), client_addr, in_filter).await?; + let (stream_open, is_c2s) = stream_preamble(&mut in_rd, &mut in_wr, client_addr.log_to(), &mut in_filter).await?; client_addr.set_c2s_stream_open(is_c2s, &stream_open); - // pull raw reader back out of StanzaReader - let mut in_rd = in_rd.0; // we require a valid to= here or we fail let to = std::str::from_utf8(stream_open.extract_between(b" to='", b"'").or_else(|_| stream_open.extract_between(b" to=\"", b"\""))?)?; - let (mut out_wr, mut out_rd, stream_open) = srv_connect(to, is_c2s, &stream_open, &mut in_filter, client_addr).await?; + let (out_wr, out_rd, stream_open) = srv_connect(to, is_c2s, &stream_open, &mut in_filter, client_addr).await?; // send server response to client - in_wr.write_all(&stream_open).await?; + in_wr.write_all(is_c2s, &stream_open, 0, client_addr.log_from()).await?; in_wr.flush().await?; drop(stream_open); - let mut out_buf = [0u8; OUT_BUFFER_SIZE]; - - loop { - tokio::select! { - Ok(buf) = out_rd.next(&mut in_filter) => { - match buf { - None => break, - Some(buf) => { - trace!("{} '{}'", client_addr.log_to(), to_str(buf)); - in_wr.write_all(buf).await?; - in_wr.flush().await?; - } - } - }, - // we could filter outgoing from-client stanzas by size here too by doing same as above - // but instead, we'll just send whatever the client sends as it sends it... - Ok(n) = in_rd.read(&mut out_buf) => { - if n == 0 { - break; - } - trace!("{} '{}'", client_addr.log_from(), to_str(&out_buf[0..n])); - out_wr.write_all(&out_buf[0..n]).await?; - out_wr.flush().await?; - }, - } - } - - info!("{} disconnected", client_addr.log_from()); - Ok(()) + shuffle_rd_wr_filter_only(in_rd, in_wr, out_rd, out_wr, is_c2s, max_stanza_size_bytes, client_addr, in_filter).await } pub fn spawn_outgoing_listener(local_addr: SocketAddr, max_stanza_size_bytes: usize) -> JoinHandle> { diff --git a/src/quic.rs b/src/quic.rs index cece106..2f3e617 100644 --- a/src/quic.rs +++ b/src/quic.rs @@ -7,7 +7,7 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use anyhow::Result; -pub async fn quic_connect(target: SocketAddr, server_name: &str, is_c2s: bool) -> Result<(Box, Box)> { +pub async fn quic_connect(target: SocketAddr, server_name: &str, is_c2s: bool) -> Result<(StanzaWrite, StanzaRead)> { let bind_addr = "0.0.0.0:0".parse().unwrap(); let mut client_cfg = ClientConfig::builder().with_safe_defaults().with_root_certificates(root_cert_store()).with_no_client_auth(); // todo: for s2s do client auth client_cfg.alpn_protocols.push(if is_c2s { ALPN_XMPP_CLIENT } else { ALPN_XMPP_SERVER }.to_vec()); @@ -20,7 +20,7 @@ pub async fn quic_connect(target: SocketAddr, server_name: &str, is_c2s: bool) - trace!("quic connected: addr={}", connection.remote_address()); let (wrt, rd) = connection.open_bi().await?; - Ok((Box::new(wrt), Box::new(rd))) + Ok((StanzaWrite::AsyncWrite(Box::new(wrt)), StanzaRead::new(Box::new(rd)))) } impl Config { @@ -80,7 +80,7 @@ pub fn spawn_quic_listener(local_addr: SocketAddr, config: CloneableConfig, serv let mut client_addr = client_addr.clone(); info!("{} connected new stream", client_addr.log_from()); tokio::spawn(async move { - if let Err(e) = shuffle_rd_wr(rd, wrt, config, local_addr, &mut client_addr).await { + if let Err(e) = shuffle_rd_wr(StanzaRead::new(Box::new(rd)), StanzaWrite::new(Box::new(wrt)), config, local_addr, &mut client_addr).await { error!("{} {}", client_addr.log_from(), e); } }); diff --git a/src/srv.rs b/src/srv.rs index ef26352..404b6cc 100644 --- a/src/srv.rs +++ b/src/srv.rs @@ -1,14 +1,15 @@ #![allow(clippy::upper_case_acronyms)] +use std::convert::TryFrom; use std::net::SocketAddr; use trust_dns_resolver::error::ResolveError; -use trust_dns_resolver::lookup::SrvLookup; +use trust_dns_resolver::lookup::{SrvLookup, TxtLookup}; use trust_dns_resolver::{IntoName, TokioAsyncResolver}; use anyhow::{bail, Result}; - -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +#[cfg(feature = "websocket")] +use tokio_tungstenite::tungstenite::http::Uri; use crate::*; @@ -22,12 +23,14 @@ fn make_resolver() -> TokioAsyncResolver { TokioAsyncResolver::tokio(config, options).unwrap() } -#[derive(Copy, Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub enum XmppConnectionType { StartTLS, DirectTLS, #[cfg(feature = "quic")] QUIC, + #[cfg(feature = "websocket")] + WebSocket(Uri, String), } #[derive(Debug)] @@ -48,13 +51,14 @@ impl XmppConnection { stream_open: &[u8], in_filter: &mut crate::StanzaFilter, client_addr: &mut Context<'_>, - ) -> Result<(Box, Box, SocketAddr, &'static str)> { + ) -> Result<(StanzaWrite, StanzaRead, SocketAddr, &'static str)> { debug!("{} attempting connection to SRV: {:?}", client_addr.log_from(), self); // todo: need to set options to Ipv4AndIpv6 let ips = RESOLVER.lookup_ip(self.target.clone()).await?; for ip in ips.iter() { let to_addr = SocketAddr::new(ip, self.port); debug!("{} trying ip {}", client_addr.log_from(), to_addr); + // todo: for DNSSEC we need to optionally allow target in addition to domain, but what for SNI match self.conn_type { XmppConnectionType::StartTLS => match crate::starttls_connect(to_addr, domain, is_c2s, stream_open, in_filter).await { Ok((wr, rd)) => return Ok((wr, rd, to_addr, "starttls-out")), @@ -69,6 +73,12 @@ impl XmppConnection { Ok((wr, rd)) => return Ok((wr, rd, to_addr, "quic-out")), Err(e) => error!("quic connection failed to IP {} from SRV {}, error: {}", to_addr, self.target, e), }, + #[cfg(feature = "websocket")] + // todo: when websocket is found via DNS, we need to validate cert against domain, *not* target, this is a security problem with XEP-0156, we are doing it the secure but likely unexpected way here for now + XmppConnectionType::WebSocket(ref url, ref origin) => match crate::websocket_connect(to_addr, domain, url, origin, is_c2s).await { + Ok((wr, rd)) => return Ok((wr, rd, to_addr, "websocket-out")), + Err(e) => error!("websocket connection failed to IP {} from TXT {}, error: {}", to_addr, url, e), + }, } } bail!("cannot connect to any IPs for SRV: {}", self.target) @@ -80,7 +90,7 @@ fn collect_srvs(ret: &mut Vec, srv_records: std::result::Result< for srv in srv_records.iter() { if !srv.target().is_root() { ret.push(XmppConnection { - conn_type, + conn_type: conn_type.clone(), priority: srv.priority(), weight: srv.weight(), port: srv.port(), @@ -91,25 +101,92 @@ fn collect_srvs(ret: &mut Vec, srv_records: std::result::Result< } } +#[cfg(feature = "websocket")] +fn collect_txts(ret: &mut Vec, txt_records: std::result::Result, is_c2s: bool) { + if let Ok(txt_records) = txt_records { + for txt in txt_records.iter() { + for txt in txt.iter() { + // we only support wss and not ws (insecure) on purpose + if txt.starts_with(if is_c2s { b"_xmpp-client-websocket=wss://" } else { b"_xmpp-server-websocket=wss://" }) { + // 23 is the length of "_xmpp-client-websocket=" and "_xmpp-server-websocket=" + let url = &txt[23..]; + let url = match Uri::try_from(url) { + Ok(url) => url, + Err(e) => { + debug!("invalid TXT record '{}', {}", to_str(txt), e); + continue; + } + }; + let server_name = match url.host() { + Some(server_name) => server_name.to_string(), + None => { + debug!("invalid TXT record '{}'", to_str(txt)); + continue; + } + }; + let target = server_name.clone().to_string(); + + let mut origin = "https://".to_string(); + origin.push_str(&server_name); + let port = if let Some(port) = url.port() { + origin.push(':'); + origin.push_str(port.as_str()); + port.as_u16() + } else { + 443 + }; + ret.push(XmppConnection { + conn_type: XmppConnectionType::WebSocket(url, origin), + priority: u16::MAX, + weight: 0, + port, + target, + }); + } + } + } + } +} + pub async fn get_xmpp_connections(domain: &str, is_c2s: bool) -> Result> { - let (starttls, direct_tls, quic) = if is_c2s { - ("_xmpp-client._tcp", "_xmpps-client._tcp", "_xmppq-client._udp") + let (starttls, direct_tls, quic, websocket) = if is_c2s { + ("_xmpp-client._tcp", "_xmpps-client._tcp", "_xmppq-client._udp", "_xmppconnect") } else { - ("_xmpp-server._tcp", "_xmpps-server._tcp", "_xmppq-server._udp") + ("_xmpp-server._tcp", "_xmpps-server._tcp", "_xmppq-server._udp", "_xmppconnect-server") }; let starttls = format!("{}.{}.", starttls, domain).into_name()?; let direct_tls = format!("{}.{}.", direct_tls, domain).into_name()?; + #[cfg(feature = "quic")] let quic = format!("{}.{}.", quic, domain).into_name()?; + #[cfg(feature = "websocket")] + let websocket = format!("{}.{}.", websocket, domain).into_name()?; // this lets them run concurrently but not in parallel, could spawn parallel tasks but... worth it ? - let (starttls, direct_tls, quic) = tokio::join!(RESOLVER.srv_lookup(starttls), RESOLVER.srv_lookup(direct_tls), RESOLVER.srv_lookup(quic),); + // todo: don't look up websocket or quic records when they are disabled + let ( + starttls, + direct_tls, + //#[cfg(feature = "quic")] + quic, + //#[cfg(feature = "websocket")] + websocket, + ) = tokio::join!( + RESOLVER.srv_lookup(starttls), + RESOLVER.srv_lookup(direct_tls), + //#[cfg(feature = "quic")] + RESOLVER.srv_lookup(quic), + //#[cfg(feature = "websocket")] + RESOLVER.txt_lookup(websocket), + ); let mut ret = Vec::new(); collect_srvs(&mut ret, starttls, XmppConnectionType::StartTLS); collect_srvs(&mut ret, direct_tls, XmppConnectionType::DirectTLS); #[cfg(feature = "quic")] collect_srvs(&mut ret, quic, XmppConnectionType::QUIC); + #[cfg(feature = "websocket")] + collect_txts(&mut ret, websocket, is_c2s); ret.sort_by(|a, b| a.priority.cmp(&b.priority)); // todo: do something with weight @@ -157,54 +234,30 @@ pub async fn get_xmpp_connections(domain: &str, is_c2s: bool) -> Result, -) -> Result<(Box, StanzaReader>>, Vec)> { +pub async fn srv_connect(domain: &str, is_c2s: bool, stream_open: &[u8], in_filter: &mut crate::StanzaFilter, client_addr: &mut Context<'_>) -> Result<(StanzaWrite, StanzaRead, Vec)> { for srv in get_xmpp_connections(domain, is_c2s).await? { let connect = srv.connect(domain, is_c2s, stream_open, in_filter, client_addr).await; if connect.is_err() { continue; } - let (mut out_wr, out_rd, to_addr, proto) = connect.unwrap(); + let (mut out_wr, mut out_rd, to_addr, proto) = connect.unwrap(); // if any of these ? returns early with an Err, these will stay set, I think that's ok though, the connection will be closed client_addr.set_proto(proto); client_addr.set_to_addr(to_addr); debug!("{} connected", client_addr.log_from()); - // we naively read 1 byte at a time, which buffering significantly speeds up - let mut out_rd = StanzaReader(tokio::io::BufReader::with_capacity(crate::IN_BUFFER_SIZE, out_rd)); - trace!("{} '{}'", client_addr.log_from(), to_str(stream_open)); - out_wr.write_all(stream_open).await?; + out_wr.write_all(is_c2s, stream_open, stream_open.len(), client_addr.log_from()).await?; out_wr.flush().await?; - let mut server_response = Vec::new(); - // let's read to first return Ok((out_wr, out_rd, server_response)), + Err(e) => { + debug!("{} bad server response, going to next record, error: {}", client_addr.log_to(), e); + client_addr.set_proto("unk-out"); + continue; } } - if !stream_received { - debug!("{} bad server response, going to next record", client_addr.log_to()); - client_addr.set_proto("unk-out"); - continue; - } - - return Ok((Box::new(out_wr), out_rd, server_response)); } bail!("all connection attempts failed") } @@ -212,13 +265,18 @@ pub async fn srv_connect( #[cfg(test)] mod tests { use crate::srv::*; + #[tokio::test] async fn srv() -> Result<()> { - let domain = "moparisthebest.com"; + let domain = "burtrum.org"; let is_c2s = true; for srv in get_xmpp_connections(domain, is_c2s).await? { - let ips = RESOLVER.lookup_ip(srv.target.clone()).await?; println!("trying 1 domain {}, SRV: {:?}", domain, srv); + #[cfg(feature = "websocket")] + if srv.conn_type == XmppConnectionType::WebSocket { + continue; + } + let ips = RESOLVER.lookup_ip(srv.target.clone()).await?; for ip in ips.iter() { println!("trying domain {}, ip {}, is_c2s: {}, SRV: {:?}", domain, ip, is_c2s, srv); } diff --git a/src/stanzafilter.rs b/src/stanzafilter.rs index 737bc4a..6d35a25 100644 --- a/src/stanzafilter.rs +++ b/src/stanzafilter.rs @@ -214,7 +214,6 @@ impl StanzaReader { } } - #[cfg(feature = "websocket")] pub async fn next_eoft<'a>(&'a mut self, filter: &'a mut StanzaFilter) -> Result> { use tokio::io::AsyncReadExt; diff --git a/src/tls.rs b/src/tls.rs index 1f932be..9ca2ca0 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -28,7 +28,7 @@ lazy_static::lazy_static! { } #[cfg(feature = "outgoing")] -pub async fn tls_connect(target: SocketAddr, server_name: &str, is_c2s: bool) -> Result<(Box, Box)> { +pub async fn tls_connect(target: SocketAddr, server_name: &str, is_c2s: bool) -> Result<(StanzaWrite, StanzaRead)> { let dnsname = ServerName::try_from(server_name)?; let stream = tokio::net::TcpStream::connect(target).await?; let stream = if is_c2s { @@ -37,17 +37,11 @@ pub async fn tls_connect(target: SocketAddr, server_name: &str, is_c2s: bool) -> SERVER_TLS_CONFIG.connect(dnsname, stream).await? }; let (rd, wrt) = tokio::io::split(stream); - Ok((Box::new(wrt), Box::new(rd))) + Ok((StanzaWrite::AsyncWrite(Box::new(wrt)), StanzaRead::new(Box::new(rd)))) } #[cfg(feature = "outgoing")] -pub async fn starttls_connect( - target: SocketAddr, - server_name: &str, - is_c2s: bool, - stream_open: &[u8], - in_filter: &mut StanzaFilter, -) -> Result<(Box, Box)> { +pub async fn starttls_connect(target: SocketAddr, server_name: &str, is_c2s: bool, stream_open: &[u8], in_filter: &mut StanzaFilter) -> Result<(StanzaWrite, StanzaRead)> { let dnsname = ServerName::try_from(server_name)?; let mut stream = tokio::net::TcpStream::connect(target).await?; let (in_rd, mut in_wr) = stream.split(); @@ -91,7 +85,7 @@ pub async fn starttls_connect( SERVER_TLS_CONFIG.connect(dnsname, stream).await? }; let (rd, wrt) = tokio::io::split(stream); - Ok((Box::new(wrt), Box::new(rd))) + Ok((StanzaWrite::AsyncWrite(Box::new(wrt)), StanzaRead::new(Box::new(rd)))) } #[cfg(feature = "incoming")] @@ -213,5 +207,5 @@ async fn handle_tls_connection(mut stream: tokio::net::TcpStream, client_addr: & let (in_rd, in_wr) = tokio::io::split(stream); - shuffle_rd_wr_filter(in_rd, in_wr, config, local_addr, client_addr, in_filter).await + shuffle_rd_wr_filter(StanzaRead::new(Box::new(in_rd)), StanzaWrite::new(Box::new(in_wr)), config, local_addr, client_addr, in_filter).await } diff --git a/src/websocket.rs b/src/websocket.rs index 5ee1d29..643380e 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -1,7 +1,7 @@ use crate::*; -use futures::{SinkExt, StreamExt, TryStreamExt}; +use anyhow::Result; +use futures::StreamExt; -use tokio_tungstenite::tungstenite::protocol::Message::*; use tokio_tungstenite::tungstenite::protocol::WebSocketConfig; // https://datatracker.ietf.org/doc/html/rfc7395 @@ -31,7 +31,10 @@ async fn handle_websocket_connection(stream: tokio::net::TcpStream, client_addr: // start TLS let stream = acceptor.accept(stream).await?; + let stream: tokio_rustls::TlsStream = stream.into(); + // accept the websocket + // todo: check SEC_WEBSOCKET_PROTOCOL or ORIGIN ? let stream = tokio_tungstenite::accept_async_with_config( stream, Some(WebSocketConfig { @@ -43,86 +46,31 @@ async fn handle_websocket_connection(stream: tokio::net::TcpStream, client_addr: ) .await?; - let (mut in_wr, mut in_rd) = stream.split(); + let (in_wr, in_rd) = stream.split(); - // https://docs.rs/tungstenite/0.14.0/tungstenite/protocol/enum.Message.html - // https://datatracker.ietf.org/doc/html/rfc7395#section-3.2 Data frame messages in the XMPP subprotocol MUST be of the text type and contain UTF-8 encoded data. - let (stanza, is_c2s) = match in_rd.try_next().await? { - // todo: c2s is xmlns="urn:ietf:params:xml:ns:xmpp-framing", let's make up s2s ? xmlns="urn:ietf:params:xml:ns:xmpp-framing-server" sounds good to me - Some(Text(stanza)) => { - let is_c2s = stanza.contains(r#" xmlns="urn:ietf:params:xml:ns:xmpp-framing""#) || stanza.contains(r#" xmlns='urn:ietf:params:xml:ns:xmpp-framing'"#); - (stanza, is_c2s) - } - _ => bail!("expected first websocket frame to be open"), - }; + let in_filter = StanzaFilter::new(config.max_stanza_size_bytes); - let stanza = from_ws(stanza); - let stream_open = stanza.as_bytes(); - - // websocket frame size filters incoming stanza size from client, this is used to split the - // stanzas from the servers up so we can send them across websocket frames - let mut in_filter = StanzaFilter::new(config.max_stanza_size_bytes); - - let (out_rd, mut out_wr) = open_incoming(config, local_addr, client_addr, stream_open, is_c2s, &mut in_filter).await?; - - let mut out_rd = StanzaReader(out_rd); - - loop { - tokio::select! { - // server to client - Ok(buf) = out_rd.next_eoft(&mut in_filter) => { - match buf { - None => break, - Some((buf, end_of_first_tag)) => { - // ignore this - if buf.starts_with(b" { - match msg { - // actual XMPP stanzas - Text(stanza) => { - let stanza = from_ws(stanza); - trace!("{} '{}'", client_addr.log_from(), stanza); - out_wr.write_all(stanza.as_bytes()).await?; - out_wr.flush().await?; - } - // websocket ping/pong - Ping(msg) => { - in_wr.feed(Pong(msg)).await?; - in_wr.flush().await?; - }, - // handle Close, just break from loop, hopefully client sent before - Close(_) => break, - _ => bail!("invalid websocket message: {}", msg) // Binary or Pong - } - }, - // todo: should we also send pings to the client ourselves on a schedule? StanzaFilter strips out whitespace pings if the server uses them... - } - } - - info!("{} disconnected", client_addr.log_from()); - Ok(()) + shuffle_rd_wr_filter(StanzaRead::WebSocketRead(in_rd), StanzaWrite::WebSocketClientWrite(in_wr), config, local_addr, client_addr, in_filter).await } pub fn from_ws(stanza: String) -> String { if stanza.starts_with("", ">"); + .replace("urn:ietf:params:xml:ns:xmpp-framing", "jabber:client"); + if !stanza.contains("xmlns:stream=") { + stanza.replace("/>", r#" xmlns:stream="http://etherx.jabber.org/streams">"#) + } else { + stanza.replace("/>", ">") + } } else if stanza.starts_with("".to_string(); + "".to_string() + } else { + stanza } - stanza } pub fn to_ws_new(buf: &[u8], mut end_of_first_tag: usize, is_c2s: bool) -> Result { @@ -156,6 +104,37 @@ pub fn to_ws_new(buf: &[u8], mut end_of_first_tag: usize, is_c2s: bool) -> Resul Ok(ret) } +use rustls::ServerName; +use std::convert::TryFrom; + +use tokio_rustls::TlsConnector; +use tokio_tungstenite::tungstenite::client::IntoClientRequest; +use tokio_tungstenite::tungstenite::http::header::{ORIGIN, SEC_WEBSOCKET_PROTOCOL}; +use tokio_tungstenite::tungstenite::http::Uri; + +pub async fn websocket_connect(target: SocketAddr, server_name: &str, url: &Uri, origin: &str, _is_c2s: bool) -> Result<(StanzaWrite, StanzaRead)> { + // todo: WebSocketConfig + // todo: static ? alpn? client cert auth for server + let connector = rustls::ClientConfig::builder().with_safe_defaults().with_root_certificates(root_cert_store()).with_no_client_auth(); + + let mut request = url.into_client_request()?; + request.headers_mut().append(SEC_WEBSOCKET_PROTOCOL, "xmpp".parse()?); + request.headers_mut().append(ORIGIN, origin.parse()?); + + let dnsname = ServerName::try_from(server_name)?; + let stream = tokio::net::TcpStream::connect(target).await?; + let connector = TlsConnector::from(Arc::new(connector)); + let stream = connector.connect(dnsname, stream).await?; + + let stream: tokio_rustls::TlsStream = stream.into(); + + let (stream, _) = tokio_tungstenite::client_async_with_config(request, stream, None).await?; + + let (wrt, rd) = stream.split(); + + Ok((StanzaWrite::WebSocketClientWrite(wrt), StanzaRead::WebSocketRead(rd))) +} + #[cfg(test)] mod tests { use crate::websocket::*; @@ -165,9 +144,14 @@ mod tests { fn test_from_ws() { assert_eq!( from_ws(r#""#.to_string()), - r#""#.to_string() + r#""#.to_string() ); assert_eq!(from_ws(r#""#.to_string()), r#""#.to_string()); + + assert_eq!( + from_ws(r#""#.to_string()), + r#""#.to_string() + ); } async fn to_vec_eoft(mut stanza_reader: StanzaReader, filter: &mut StanzaFilter) -> Result> {