From 9644f881445bb93dff2cf3031c7f36fa5b435af6 Mon Sep 17 00:00:00 2001 From: moparisthebest Date: Mon, 2 Sep 2024 22:18:03 -0400 Subject: [PATCH] Implement tokio_xmpp::ServerConnector --- .helix/config.toml | 2 + Cargo.lock | 260 +++++++++++++++++++++++++++++++++++++++++++- Cargo.toml | 8 +- src/common/mod.rs | 1 + src/in_out.rs | 240 +++++++++++++++++++++++++++++++++++++++- src/lib.rs | 3 + src/main.rs | 4 +- src/stanzafilter.rs | 5 +- src/tokio_xmpp.rs | 77 +++++++++++++ 9 files changed, 589 insertions(+), 11 deletions(-) create mode 100644 .helix/config.toml create mode 100644 src/tokio_xmpp.rs diff --git a/.helix/config.toml b/.helix/config.toml new file mode 100644 index 0000000..0379adb --- /dev/null +++ b/.helix/config.toml @@ -0,0 +1,2 @@ +[editor] +workspace-lsp-roots = ["."] diff --git a/Cargo.lock b/Cargo.lock index 0bcd7b9..e263ccd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -95,6 +95,12 @@ version = "0.21.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "35636a1494ede3b646cc98f74f8e62c773a38a659ebc777a2cf26b9b74171df9" +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + [[package]] name = "bitflags" version = "1.3.2" @@ -107,6 +113,15 @@ version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" +[[package]] +name = "blake2" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" +dependencies = [ + "digest", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -134,6 +149,15 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" +[[package]] +name = "castaway" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0abae9be0aaf9ea96a3b1b8b1b55c602ca751eba1b1500220cea4ecbafe7c0d5" +dependencies = [ + "rustversion", +] + [[package]] name = "cc" version = "1.0.83" @@ -149,6 +173,28 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chrono" +version = "0.4.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" +dependencies = [ + "num-traits", +] + +[[package]] +name = "compact_str" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f86b9c4c00838774a6d902ef931eff7470720c51d90c2e32cfe15dc304737b3f" +dependencies = [ + "castaway", + "cfg-if", + "itoa", + "ryu", + "static_assertions", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -213,6 +259,7 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", "crypto-common", + "subtle", ] [[package]] @@ -384,7 +431,7 @@ name = "fuzz" version = "0.1.0" dependencies = [ "afl", - "rxml", + "rxml 0.9.1", "sha256", "tokio", "xmpp-proxy", @@ -460,6 +507,15 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "home" version = "0.5.9" @@ -634,6 +690,17 @@ version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" +[[package]] +name = "jid" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2cc0defda507f1e140ce2c1c670565c7c0b9bda8ae994a3c0478a53b5279b46b" +dependencies = [ + "memchr", + "minidom", + "stringprep", +] + [[package]] name = "js-sys" version = "0.3.66" @@ -643,6 +710,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "keccak" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc2af9a1119c51f12a14607e783cb977bde58bc069ff0c3da1095e635d70654" +dependencies = [ + "cpufeatures", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -719,6 +795,15 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "minidom" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e394a0e3c7ccc2daea3dffabe82f09857b6b510cb25af87d54bf3e910ac1642d" +dependencies = [ + "rxml 0.11.1", +] + [[package]] name = "miniz_oxide" version = "0.7.1" @@ -751,6 +836,15 @@ dependencies = [ "memoffset", ] +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + [[package]] name = "num_cpus" version = "1.16.0" @@ -805,6 +899,15 @@ dependencies = [ "windows-targets 0.48.5", ] +[[package]] +name = "pbkdf2" +version = "0.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ed6a7761f76e3b9f92dfb0a60a6a6477c61024b775147ff0973a02653abaf2" +dependencies = [ + "digest", +] + [[package]] name = "percent-encoding" version = "2.3.1" @@ -976,7 +1079,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37b1ae8d9ac08420c66222fb9096fc5de435c3c48542bc5336c51892cffafb41" dependencies = [ "async-compression", - "base64", + "base64 0.21.5", "bytes", "encoding_rs", "futures-core", @@ -1115,7 +1218,7 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" dependencies = [ - "base64", + "base64 0.21.5", ] [[package]] @@ -1128,6 +1231,12 @@ dependencies = [ "untrusted 0.9.0", ] +[[package]] +name = "rustversion" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "955d28af4278de8121b7ebeb796b6a45735dc01436d898801014aced2773a3d6" + [[package]] name = "rxml" version = "0.9.1" @@ -1136,23 +1245,58 @@ checksum = "a98f186c7a2f3abbffb802984b7f1dfd65dac8be1aafdaabbca4137f53f0dff7" dependencies = [ "bytes", "pin-project-lite", - "rxml_validation", + "rxml_validation 0.9.1", "smartstring", "tokio", ] +[[package]] +name = "rxml" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc94b580d0f5a6b7a2d604e597513d3c673154b52ddeccd1d5c32360d945ee" +dependencies = [ + "bytes", + "pin-project-lite", + "rxml_validation 0.11.0", + "tokio", +] + [[package]] name = "rxml_validation" version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22a197350ece202f19a166d1ad6d9d6de145e1d2a8ef47db299abe164dbd7530" +[[package]] +name = "rxml_validation" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "826e80413b9a35e9d33217b3dcac04cf95f6559d15944b93887a08be5496c4a4" +dependencies = [ + "compact_str", +] + [[package]] name = "ryu" version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" +[[package]] +name = "sasl" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6777dddc8108d9f36afbb008bc15b18edab2d17a8664ef58380b9398460e4e30" +dependencies = [ + "base64 0.22.1", + "getrandom", + "hmac", + "pbkdf2", + "sha1", + "sha2", +] + [[package]] name = "schannel" version = "0.1.22" @@ -1294,6 +1438,16 @@ dependencies = [ "tokio", ] +[[package]] +name = "sha3" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75872d278a8f37ef87fa0ddbda7802605cb18344497949862c0d4dcb291eba60" +dependencies = [ + "digest", + "keccak", +] + [[package]] name = "signal-hook-registry" version = "1.4.1" @@ -1357,6 +1511,23 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "stringprep" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b4df3d392d81bd458a8a621b8bffbd2302a12ffe288a9d931670948749463b1" +dependencies = [ + "unicode-bidi", + "unicode-normalization", + "unicode-properties", +] + +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + [[package]] name = "syn" version = "2.0.41" @@ -1472,6 +1643,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-stream" +version = "0.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-tungstenite" version = "0.21.0" @@ -1498,6 +1680,25 @@ dependencies = [ "tracing", ] +[[package]] +name = "tokio-xmpp" +version = "4.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8f8ac72971d5c993971490239c6fa7306bf5e699d39c6bc2775bdb98d590c7c" +dependencies = [ + "bytes", + "futures", + "log", + "minidom", + "rand", + "rxml 0.11.1", + "sasl", + "tokio", + "tokio-stream", + "tokio-util", + "xmpp-parsers", +] + [[package]] name = "toml" version = "0.8.8" @@ -1668,6 +1869,12 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-properties" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52ea75f83c0137a9b98608359a5f1af8144876eb67bcb1ce837368e906a9f524" + [[package]] name = "untrusted" version = "0.7.1" @@ -2031,6 +2238,24 @@ version = "2.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "213b7324336b53d2414b2db8537e56544d981803139155afa84f76eeebb7a546" +[[package]] +name = "xmpp-parsers" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab9890d7a3df540a6a8a7384fa3db58be3a4685799e0271756b26213d3f67903" +dependencies = [ + "base64 0.22.1", + "blake2", + "chrono", + "digest", + "jid", + "minidom", + "sha1", + "sha2", + "sha3", + "xso", +] + [[package]] name = "xmpp-proxy" version = "1.0.0" @@ -2059,8 +2284,35 @@ dependencies = [ "tokio", "tokio-rustls", "tokio-tungstenite", + "tokio-util", + "tokio-xmpp", "toml", "trust-dns-resolver", "webpki-roots", "webtransport-quinn", ] + +[[package]] +name = "xso" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c1e554e5e6689a0ec1b62a3b5ed450a1bb45118553a9665f4ee2277d135ba83" +dependencies = [ + "base64 0.22.1", + "jid", + "minidom", + "rxml 0.11.1", + "xso_proc", +] + +[[package]] +name = "xso_proc" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef695815a751e37fc0d78b4b1c1d8b669db7cf826f39d5a8c95df396a54f09d6" +dependencies = [ + "proc-macro2", + "quote", + "rxml_validation 0.11.0", + "syn", +] diff --git a/Cargo.toml b/Cargo.toml index 372fe86..73d79be 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -74,8 +74,12 @@ webtransport-quinn = { version = "0.6", optional = true } # systemd dep nix = { version = "0.27", optional = true, default-features = false, features = ["socket"]} +# tokio-xmpp if you want a ServerConnector impl +tokio-xmpp = { version = "4.0", optional = true, default-features = false } +tokio-util = { version = "0.7", optional = true, features = ["codec"] } + [features] -default = ["c2s-incoming", "c2s-outgoing", "s2s-incoming", "s2s-outgoing", "tls", "quic", "websocket", "webtransport", "logging", "tls-ca-roots-native", "systemd"] +default = ["c2s-incoming", "c2s-outgoing", "s2s-incoming", "s2s-outgoing", "tls", "quic", "websocket", "webtransport", "logging", "tls-ca-roots-native", "systemd", "tokio-xmpp"] # you must pick one of these or the other, not both: todo: enable picking both and choosing at runtime # don't need either of these if only doing c2s-incoming @@ -105,6 +109,8 @@ webtransport = ["webtransport-quinn", "quic"] # webtransport requires quic logging = ["rand", "env_logger"] systemd = ["nix"] +tokio-xmpp = ["dep:tokio-xmpp", "dep:tokio-util", "outgoing"] + # enables unit tests that need network and therefore may be flaky net-test = [] diff --git a/src/common/mod.rs b/src/common/mod.rs index 3f1484e..2b52a4a 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -41,6 +41,7 @@ pub mod ca_roots; pub mod certs_key; pub mod stream_listener; +pub const DEFAULT_MAX_STANZA_SIZE_BYTES: usize = 262_144; pub const IN_BUFFER_SIZE: usize = 8192; pub const ALPN_XMPP_CLIENT: &[u8] = b"xmpp-client"; pub const ALPN_XMPP_SERVER: &[u8] = b"xmpp-server"; diff --git a/src/in_out.rs b/src/in_out.rs index 3988c0a..d411ef8 100644 --- a/src/in_out.rs +++ b/src/in_out.rs @@ -1,3 +1,6 @@ +use std::io::{Cursor, Error as IoError}; +use std::{pin::Pin, task::Poll}; + #[cfg(feature = "websocket")] use crate::websocket::{from_ws, to_ws_new, WsRd, WsWr}; use crate::{ @@ -7,6 +10,8 @@ use crate::{ stanzafilter::{StanzaFilter, StanzaReader}, }; use anyhow::{bail, Result}; +use futures_util::Future; +use futures_util::Stream; #[cfg(feature = "websocket")] use futures_util::{SinkExt, TryStreamExt}; use log::trace; @@ -34,7 +39,7 @@ impl StanzaWrite { //AsyncWrite(Box::new(tokio::io::BufWriter::with_capacity(8192, wr))) } - pub async fn write_all<'a>(&'a mut self, is_c2s: bool, buf: &'a [u8], end_of_first_tag: usize, client_addr: &'a str) -> Result<()> { + pub async fn write_all(&mut self, is_c2s: bool, buf: &[u8], end_of_first_tag: usize, client_addr: &str) -> Result<()> { match self { AsyncWrite(wr) => Ok(wr.write_all(buf).await?), #[cfg(feature = "websocket")] @@ -65,6 +70,14 @@ impl StanzaWrite { WebSocketClientWrite(ws) => Ok(ws.flush().await?), } } + + pub async fn shutdown(&mut self) -> Result<()> { + match self { + AsyncWrite(wr) => Ok(wr.shutdown().await?), + #[cfg(feature = "websocket")] + WebSocketClientWrite(ws) => Ok(ws.close().await?), + } + } } impl StanzaRead { @@ -114,10 +127,233 @@ impl StanzaRead { _ => bail!("invalid websocket message: {}", msg), // Binary or Pong } } else { - bail!("websocket stream ended") + return Ok(None); } } } } } } + +pub struct StanzaStream { + wr: StanzaWrite, + rd: StanzaRead, + + fut_next_stanza: Option, + + send_stream_open: bool, + stream_open: Vec, + + client_addr: String, + is_c2s: bool, + + filter: StanzaFilter, + wr_filter: Option, +} + +impl StanzaStream { + #[cfg(feature = "outgoing")] + pub async fn connect(domain: &str, is_c2s: bool) -> Result { + let ns = if is_c2s { "jabber:client" } else { "jabber:server" }; + let stream_open = format!(""); + Self::connect_open(domain, is_c2s, stream_open.as_bytes()).await + } + + #[cfg(feature = "outgoing")] + pub async fn connect_open(domain: &str, is_c2s: bool, stream_open: &[u8]) -> Result { + use crate::{ + common::{certs_key::CertsKey, outgoing::OutgoingConfig, DEFAULT_MAX_STANZA_SIZE_BYTES}, + context::Context, + srv::srv_connect, + }; + const ADDR: &str = "127.0.0.1"; + let mut context = Context::new("StanzaStream", ADDR.parse().expect("valid")); + + let mut in_filter = StanzaFilter::new(DEFAULT_MAX_STANZA_SIZE_BYTES); + let config = OutgoingConfig { + max_stanza_size_bytes: DEFAULT_MAX_STANZA_SIZE_BYTES, + certs_key: CertsKey::new(Err(anyhow::anyhow!("StanzaStream doesn't support client certs yet"))).into(), + }; + let (wr, rd, stream_open) = srv_connect(domain, is_c2s, stream_open, &mut in_filter, &mut context, config).await?; + Ok(StanzaStream::new(wr, rd, stream_open, ADDR.to_string(), is_c2s, in_filter)) + } + + pub fn new(wr: StanzaWrite, rd: StanzaRead, stream_open: Vec, client_addr: String, is_c2s: bool, filter: StanzaFilter) -> Self { + let async_write = matches!(wr, StanzaWrite::AsyncWrite(_)); + let wr_filter = if async_write { None } else { Some(filter.clone()) }; + Self { + wr, + rd, + send_stream_open: !stream_open.is_empty(), + stream_open, + client_addr, + is_c2s, + filter, + wr_filter, + fut_next_stanza: None, + } + } + + pub async fn next_stanza<'a>(&'a mut self) -> Result> { + if self.send_stream_open { + self.send_stream_open = false; + return Ok(Some((self.stream_open.as_slice(), 0))); + } + self.rd.next(&mut self.filter, self.client_addr.as_str(), &mut self.wr).await + } + + pub async fn write_stanzas(&mut self, buf: &[u8]) -> Result { + match self.wr_filter.as_mut() { + None => { + // we don't care about how many stanzas or anything + self.wr.write_all(self.is_c2s, buf, 0, self.client_addr.as_str()).await?; + Ok(buf.len()) + } + Some(wr_filter) => { + let mut rd = StanzaReader(Cursor::new(buf)); + let mut wrote = 0; + while let Some((buf, eoft)) = rd.next_eoft(wr_filter).await? { + self.wr.write_all(self.is_c2s, buf, eoft, self.client_addr.as_str()).await?; + wrote += buf.len(); + } + Ok(wrote) + } + } + } +} + +// todo: using Arc and .make_mut() and a wrapping struct can still return slices safely, and clone will only happen if someone keeps a reference, which is ideal +impl Stream for StanzaStream { + type Item = Result<(Vec, usize)>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + if self.send_stream_open { + self.send_stream_open = false; + // swap in an empty vec and send ours + let stream_open = std::mem::replace(&mut self.stream_open, Vec::new()); + return std::task::Poll::Ready(Some(Ok((stream_open, 0)))); + } + let future = self.next_stanza(); + let future = std::pin::pin!(future); + match future.poll(cx) { + std::task::Poll::Ready(res) => std::task::Poll::Ready(res.map(|r| r.map(|r| (r.0.to_vec(), r.1))).transpose()), + std::task::Poll::Pending => std::task::Poll::Pending, + } + } +} + +impl AsyncRead for StanzaStream { + fn poll_read(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> std::task::Poll> { + // todo: instead of waiting for a whole stanza, if self is AsyncRead, we could go directly to that and skip stanzafilter, problem is this would break Stream::poll_next and XmppStream::next_stanza, so maybe we need a different struct to do that? + // todo: instead of using our StanzaFilter and copying bytes from it, we could make one out of the buf? + let future = self.next_stanza(); + // self.fut_next_stanza = Some(future); + let future = std::pin::pin!(future); + match future.poll(cx) { + std::task::Poll::Ready(res) => { + if let Some((stanza, _)) = res.map_err(|e| IoError::other(e))? { + if stanza.len() >= buf.remaining() { + return std::task::Poll::Ready(Err(IoError::other(format!("stanza of length {} read but buffer of only {} supplied", stanza.len(), buf.remaining())))); + } + buf.put_slice(stanza); + } + return Poll::Ready(Ok(())); + } + std::task::Poll::Pending => { + // self.fut_next_stanza = Some(future); + std::task::Poll::Pending + } + } + } +} + +impl AsyncWrite for StanzaStream { + fn poll_write(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &[u8]) -> Poll> { + let future = self.write_stanzas(buf); + let future = std::pin::pin!(future); + match future.poll(cx) { + Poll::Ready(r) => r.map_err(|e| IoError::other(e)).into(), + Poll::Pending => Poll::Pending, + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + let future = self.wr.flush(); + let future = std::pin::pin!(future); + match future.poll(cx) { + Poll::Ready(r) => r.map_err(|e| IoError::other(e)).into(), + Poll::Pending => Poll::Pending, + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + let future = self.wr.shutdown(); + let future = std::pin::pin!(future); + match future.poll(cx) { + Poll::Ready(r) => r.map_err(|e| IoError::other(e)).into(), + Poll::Pending => Poll::Pending, + } + } +} + +#[cfg(test)] +mod test { + use anyhow::Result; + use std::{ + any::{Any, TypeId}, + io::Cursor, + }; + use tokio::io::AsyncReadExt; + + use crate::{common::to_str, stanzafilter::StanzaFilter}; + + use super::*; + + #[tokio::test] + async fn async_read() -> Result<()> { + let stream_open = br###" + "###; + let orig = br###"woo"###; + let rd = Cursor::new(orig.clone()); + // let wr = Cursor::new(&mut written[..]); + let mut wr = Cursor::new(Vec::new()); + + let mut stream = StanzaStream::new( + StanzaWrite::new(wr.clone()), + StanzaRead::new(rd), + stream_open.to_vec(), + "client-addr".to_string(), + true, + StanzaFilter::new(262_144), + ); + + let mut buf = [0u8; 262_144]; + let mut _total_size = 0; + while let Ok(n) = stream.read(&mut buf[..]).await { + if n == 0 { + break; + } + wr.write(&buf[0..n]).await?; + } + // match stream.wr { + // StanzaWrite::AsyncWrite(a) => { + // // let a = &a.as_ref() as &dyn Any; + // // let a = Box::leak(a); + // let a = &a as &dyn Any; + // println!("woo"); + // println!("typeid: '{:?}', cursor: '{:?}", a.type_id(), TypeId::of::>>()); + // let out = a.downcast_ref::>>().expect("must be Cursor>"); + // assert_eq!(out.get_ref(), orig); + // } + // WebSocketClientWrite(_) => panic!("impossible"), + // }; + drop(stream); + + let mut expected = stream_open.to_vec(); + expected.extend_from_slice(orig); + // assert_eq!(&wr.get_ref()[..], &expected[..]); + assert_eq!(to_str(&wr.get_ref()[..]), to_str(&expected[..])); + + Ok(()) + } +} diff --git a/src/lib.rs b/src/lib.rs index 4aab4aa..13054f1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -26,5 +26,8 @@ pub mod verify; #[cfg(all(feature = "nix", not(target_os = "windows")))] pub mod systemd; +#[cfg(feature = "tokio-xmpp")] +pub mod tokio_xmpp; + pub mod context; pub mod in_out; diff --git a/src/main.rs b/src/main.rs index c30f66d..a278d8d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,7 +5,7 @@ use log::{debug, info}; use serde_derive::Deserialize; use std::{ffi::OsString, fs::File, io::Read, iter::Iterator, path::Path, sync::Arc}; use tokio::{net::TcpListener, task::JoinHandle}; -use xmpp_proxy::common::{certs_key::CertsKey, Listener, SocketAddrPath, UdpListener}; +use xmpp_proxy::common::{certs_key::CertsKey, Listener, SocketAddrPath, UdpListener, DEFAULT_MAX_STANZA_SIZE_BYTES}; #[cfg(not(target_os = "windows"))] use tokio::net::UnixListener; @@ -36,7 +36,7 @@ struct Config { } fn default_max_stanza_size_bytes() -> usize { - 262_144 + DEFAULT_MAX_STANZA_SIZE_BYTES } impl Config { diff --git a/src/stanzafilter.rs b/src/stanzafilter.rs index 4a297f5..0230060 100644 --- a/src/stanzafilter.rs +++ b/src/stanzafilter.rs @@ -5,7 +5,7 @@ use anyhow::{bail, Result}; use StanzaState::*; -#[derive(Debug)] +#[derive(Debug, Clone)] enum StanzaState { OutsideStanza, StanzaFirstChar, @@ -20,6 +20,7 @@ enum StanzaState { EndStream, } +#[derive(Clone)] pub struct StanzaFilter { buf_size: usize, pub buf: Vec, @@ -214,7 +215,7 @@ impl StanzaReader { } } - pub async fn next_eoft<'a>(&'a mut self, filter: &'a mut StanzaFilter) -> Result> { + pub async fn next_eoft<'a>(&mut self, filter: &'a mut StanzaFilter) -> Result> { use tokio::io::AsyncReadExt; loop { diff --git a/src/tokio_xmpp.rs b/src/tokio_xmpp.rs new file mode 100644 index 0000000..c9335b8 --- /dev/null +++ b/src/tokio_xmpp.rs @@ -0,0 +1,77 @@ +use std::{fmt::Display, sync::Arc}; + +use futures_util::StreamExt; +use tokio_util::codec::Framed; +pub use tokio_xmpp::*; +use xmpp_stream::XMPPStream; + +use crate::{common::certs_key::CertsKey, context::Context, in_out::StanzaStream, srv::srv_connect, stanzafilter::StanzaFilter}; + +#[derive(Clone, Debug)] +pub struct XmppProxyServerConnectorError(Arc); + +impl Display for XmppProxyServerConnectorError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +impl From for XmppProxyServerConnectorError { + fn from(value: anyhow::Error) -> Self { + Self(value.into()) + } +} + +impl From for XmppProxyServerConnectorError { + fn from(value: tokio_xmpp::Error) -> Self { + Self(anyhow::Error::from(value).into()) + } +} + +impl std::error::Error for XmppProxyServerConnectorError {} + +impl connect::ServerConnectorError for XmppProxyServerConnectorError {} + +#[derive(Clone, Debug)] +pub struct XmppProxyServerConnector; + +impl connect::ServerConnector for XmppProxyServerConnector { + type Stream = StanzaStream; + type Error = XmppProxyServerConnectorError; + + async fn connect(&self, jid: &jid::Jid, ns: &str) -> Result, Self::Error> { + let domain = jid.domain(); + let is_c2s = ns == "jabber:client"; + let stanza_stream = StanzaStream::connect(domain, is_c2s).await?; + let mut stanza_stream = Framed::new(stanza_stream, XmppCodec::new()); + let stream_attrs; + loop { + match stanza_stream.next().await { + Some(Ok(Packet::StreamStart(attrs))) => { + stream_attrs = attrs; + break; + } + Some(Ok(_)) => {} + Some(Err(e)) => return Err(e.into()), + None => return Err(Error::Disconnected.into()), + } + } + + let stream_id = stream_attrs.get("id").ok_or(ProtocolError::NoStreamId).unwrap().clone(); + let stream_features; + loop { + match stanza_stream.next().await { + Some(Ok(Packet::Stanza(stanza))) if stanza.is("features", tokio_xmpp::parsers::ns::STREAM) => { + stream_features = stanza; + break; + } + Some(Ok(_)) => {} + Some(Err(e)) => return Err(e.into()), + None => return Err(Error::Disconnected.into()), + } + } + let xmpp_stream = XMPPStream::new(jid.clone(), stanza_stream, ns.to_string(), stream_id, stream_features); + + Ok(xmpp_stream) + } +}