From e0d2f89c1ec65b10d871530d1b353af97e6197cb Mon Sep 17 00:00:00 2001 From: moparisthebest Date: Thu, 15 Apr 2021 00:49:52 -0400 Subject: [PATCH] Re-factor stanza filter and add tests --- src/main.rs | 221 ++++++++++++++-------------------------- src/slicesubsequence.rs | 2 +- src/stanzafilter.rs | 131 ++++++++++++++++++++++++ 3 files changed, 207 insertions(+), 147 deletions(-) create mode 100644 src/stanzafilter.rs diff --git a/src/main.rs b/src/main.rs index 0c5a493..51ed9c1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -24,10 +24,13 @@ use anyhow::{bail, Result}; mod slicesubsequence; use slicesubsequence::*; +mod stanzafilter; +use stanzafilter::*; + const IN_BUFFER_SIZE: usize = 8192; const OUT_BUFFER_SIZE: usize = 8192; -const WHITESPACE: &[u8] = b" \t\n\r"; +pub const WHITESPACE: &[u8] = b" \t\n\r"; #[cfg(debug_assertions)] fn c2s(is_c2s: bool) -> &'static str { @@ -148,91 +151,83 @@ async fn handle_connection(mut stream: tokio::net::TcpStream, client_addr: Socke // starttls if !direct_tls { - let mut stream_open = Vec::new(); + let mut proceed_sent = false; let (in_rd, mut in_wr) = stream.split(); // we naively read 1 byte at a time, which buffering significantly speeds up - let mut in_rd = tokio::io::BufReader::with_capacity(IN_BUFFER_SIZE, in_rd); + let in_rd = tokio::io::BufReader::with_capacity(IN_BUFFER_SIZE, in_rd); + let mut in_rd = StanzaReader(in_rd); - while let Ok(n) = in_rd.read(in_filter.current_buf()).await { - if n == 0 { - bail!("stream ended before open"); - } - if let Some(buf) = in_filter.process_next_byte()? { - debug!("received pre-tls stanza: {} '{}'", client_addr, to_str(&buf)); - let buf = buf.trim_start(WHITESPACE); - if buf.starts_with(b" {} '{}'", client_addr, to_str(&stream_open)); - in_wr.write_all(&stream_open).await?; - stream_open.clear(); - - // gajim seems to REQUIRE an id here... - let buf = if buf.contains_seq(b"id=") { - buf.replace_first(b" id='", b" id='xmpp-proxy") - .replace_first(br#" id=""#, br#" id="xmpp-proxy"#) - .replace_first(b" to=", br#" bla toblala="#) - .replace_first(b" from=", b" to=") - .replace_first(br#" bla toblala="#, br#" from="#) - } else { - buf.replace_first(b" to=", br#" bla toblala="#) - .replace_first(b" from=", b" to=") - .replace_first(br#" bla toblala="#, br#" id='xmpp-proxy' from="#) - }; - - debug!("> {} '{}'", client_addr, to_str(&buf)); - in_wr.write_all(&buf).await?; - - // ejabberd never sends with the first, only the second? - //let buf = br###""###; - let buf = br###""###; - debug!("> {} '{}'", client_addr, to_str(buf)); - in_wr.write_all(buf).await?; - in_wr.flush().await?; - } else if buf.starts_with(b""###; - debug!("> {} '{}'", client_addr, to_str(buf)); - in_wr.write_all(buf).await?; - in_wr.flush().await?; - break; + while let Ok(Some(buf)) = in_rd.next(&mut in_filter).await { + debug!("received pre-tls stanza: {} '{}'", client_addr, to_str(&buf)); + let buf = buf.trim_start(WHITESPACE); + if buf.starts_with(b" {} '{}'", client_addr, to_str(&buf)); + in_wr.write_all(&buf).await?; + in_wr.flush().await?; + } else if buf.starts_with(b" {} '{}'", client_addr, to_str(&buf)); + in_wr.write_all(&buf).await?; + + // ejabberd never sends with the first, only the second? + //let buf = br###""###; + let buf = br###""###; + debug!("> {} '{}'", client_addr, to_str(buf)); + in_wr.write_all(buf).await?; + in_wr.flush().await?; + } else if buf.starts_with(b""###; + debug!("> {} '{}'", client_addr, to_str(buf)); + in_wr.write_all(buf).await?; + in_wr.flush().await?; + proceed_sent = true; + break; + } else { + bail!("bad pre-tls stanza: {}", to_str(&buf)); } } + if !proceed_sent { + bail!("stream ended before open"); + } } let stream = config.acceptor.accept(stream).await?; let (in_rd, mut in_wr) = tokio::io::split(stream); // we naively read 1 byte at a time, which buffering significantly speeds up - let mut in_rd = tokio::io::BufReader::with_capacity(IN_BUFFER_SIZE, in_rd); + let in_rd = tokio::io::BufReader::with_capacity(IN_BUFFER_SIZE, in_rd); + let mut in_rd = StanzaReader(in_rd); // now read to figure out client vs server let (stream_open, is_c2s) = { let mut stream_open = Vec::new(); let mut ret = None; - while let Ok(n) = in_rd.read(in_filter.current_buf()).await { - if n == 0 { - bail!("stream ended before open"); - } - if let Some(buf) = in_filter.process_next_byte()? { - debug!("received pre- stanza: {} '{}'", client_addr, to_str(&buf)); - let buf = buf.trim_start(WHITESPACE); - if buf.starts_with(b" stanza: {}", to_str(&buf)); - } + while let Ok(Some(buf)) = in_rd.next(&mut in_filter).await { + debug!("received pre- stanza: {} '{}'", client_addr, to_str(&buf)); + let buf = buf.trim_start(WHITESPACE); + if buf.starts_with(b" stanza: {}", to_str(&buf)); } } if ret.is_some() { @@ -281,14 +276,14 @@ async fn handle_connection(mut stream: tokio::net::TcpStream, client_addr: Socke loop { tokio::select! { - Ok(n) = in_rd.read(in_filter.current_buf()) => { - if n == 0 { - break; - } - if let Some(buf) = in_filter.process_next_byte()? { - debug!("< {} {} '{}'", client_addr, c2s(is_c2s), to_str(buf)); - out_wr.write_all(buf).await?; - out_wr.flush().await?; + Ok(buf) = in_rd.next(&mut in_filter) => { + match buf { + None => break, + Some(buf) => { + debug!("< {} {} '{}'", client_addr, c2s(is_c2s), to_str(buf)); + out_wr.write_all(buf).await?; + out_wr.flush().await?; + } } }, // we could filter outgoing from-server stanzas by size here too by doing same as above @@ -308,6 +303,11 @@ async fn handle_connection(mut stream: tokio::net::TcpStream, client_addr: Socke Ok(()) } +/* +async fn handle_connection(mut stream: tokio::net::TcpStream, client_addr: SocketAddr, local_addr: SocketAddr, config: CloneableConfig) -> Result<()> { + Ok(()) +} +*/ fn spawn_listener(listener: TcpListener, config: CloneableConfig) -> JoinHandle> { let local_addr = listener.local_addr().die("could not get local_addr?"); tokio::spawn(async move { @@ -339,74 +339,3 @@ async fn main() { } futures::future::join_all(handles).await; } - -struct StanzaFilter { - buf_size: usize, - buf: Vec, - cnt: usize, - tag_cnt: usize, - last_char_was_lt: bool, - last_char_was_backslash: bool, -} - -impl StanzaFilter { - pub fn new(buf_size: usize) -> StanzaFilter { - StanzaFilter { - buf_size, - buf: vec![0u8; buf_size], - cnt: 0, - tag_cnt: 0, - last_char_was_lt: false, - last_char_was_backslash: false, - } - } - - #[inline(always)] - pub fn current_buf(&mut self) -> &mut [u8] { - &mut self.buf[self.cnt..(self.cnt + 1)] - } - - pub fn process_next_byte(&mut self) -> Result> { - //println!("n: {}", n); - let b = self.buf[self.cnt]; - if b == b'<' { - self.tag_cnt += 1; - self.last_char_was_lt = true; - } else { - if b == b'/' { - // if last_char_was_lt but tag_cnt < 2, should only be - if self.last_char_was_lt && self.tag_cnt >= 2 { - // non-self-closing tag - self.tag_cnt -= 2; - } - self.last_char_was_backslash = true; - } else { - if b == b'>' { - if self.last_char_was_backslash { - // self-closing tag - self.tag_cnt -= 1; - } - // now special case some tags we want to send stand-alone: - if self.tag_cnt == 1 && self.cnt >= 15 && (b" SliceSubsequence for Vec { #[cfg(test)] mod tests { use crate::slicesubsequence::*; - const WHITESPACE: &[u8] = b" \t\n\r"; + use crate::WHITESPACE; #[test] fn trim_start() { diff --git a/src/stanzafilter.rs b/src/stanzafilter.rs new file mode 100644 index 0000000..277c895 --- /dev/null +++ b/src/stanzafilter.rs @@ -0,0 +1,131 @@ +use anyhow::{bail, Result}; + +use crate::to_str; + +pub struct StanzaFilter { + buf_size: usize, + pub buf: Vec, + cnt: usize, + tag_cnt: usize, + last_char_was_lt: bool, + last_char_was_backslash: bool, +} + +impl StanzaFilter { + pub fn new(buf_size: usize) -> StanzaFilter { + StanzaFilter { + buf_size, + buf: vec![0u8; buf_size], + cnt: 0, + tag_cnt: 0, + last_char_was_lt: false, + last_char_was_backslash: false, + } + } + + #[inline(always)] + pub fn current_buf(&mut self) -> &mut [u8] { + &mut self.buf[self.cnt..(self.cnt + 1)] + } + + #[allow(dead_code)] + pub fn process_next_byte(&mut self) -> Result> { + if let Some(idx) = self.process_next_byte_idx()? { + return Ok(Some(&self.buf[0..idx])); + } + Ok(None) + } + + pub fn process_next_byte_idx(&mut self) -> Result> { + //println!("n: {}", n); + let b = self.buf[self.cnt]; + if b == b'<' { + self.tag_cnt += 1; + self.last_char_was_lt = true; + } else { + if b == b'/' { + // if last_char_was_lt but tag_cnt < 2, should only be + if self.last_char_was_lt && self.tag_cnt >= 2 { + // non-self-closing tag + self.tag_cnt -= 2; + } + self.last_char_was_backslash = true; + } else { + if b == b'>' { + if self.last_char_was_backslash { + // self-closing tag + self.tag_cnt -= 1; + } + // now special case some tags we want to send stand-alone: + if self.tag_cnt == 1 && self.cnt >= 15 && (b"(pub T); + +impl StanzaReader { + pub async fn next<'a>(&'a mut self, filter: &'a mut StanzaFilter) -> Result> { + use tokio::io::AsyncReadExt; + + loop { + let n = self.0.read(filter.current_buf()).await?; + if n == 0 { + return Ok(None); + } + if let Some(idx) = filter.process_next_byte_idx()? { + return Ok(Some(&filter.buf[0..idx])); + } + } + } +} + +#[cfg(test)] +mod tests { + use crate::stanzafilter::*; + use std::borrow::Cow; + use std::io::Cursor; + + impl StanzaReader { + async fn next_str<'a>(&'a mut self, filter: &'a mut StanzaFilter) -> Cow<'_, str> { + to_str(self.next(filter).await.expect("was Err").expect("was None")) + } + } + + #[tokio::test] + async fn process_next_byte() -> std::result::Result<(), anyhow::Error> { + let mut filter = StanzaFilter::new(262_144); + + let xml_stream = Cursor::new(br###"inside b before cinside c"###); + + let mut stanza_reader = StanzaReader(xml_stream); + + assert_eq!(stanza_reader.next_str(&mut filter).await, ""); + assert_eq!(stanza_reader.next_str(&mut filter).await, "inside b before cinside c"); + assert_eq!(stanza_reader.next_str(&mut filter).await, ""); + assert_eq!(stanza_reader.next(&mut filter).await?, None); + + Ok(()) + } +}