Re-factor stanza filter and add tests

This commit is contained in:
Travis Burtrum 2021-04-15 00:49:52 -04:00
parent 2ef391c224
commit e0d2f89c1e
3 changed files with 207 additions and 147 deletions

View File

@ -24,10 +24,13 @@ use anyhow::{bail, Result};
mod slicesubsequence;
use slicesubsequence::*;
mod stanzafilter;
use stanzafilter::*;
const IN_BUFFER_SIZE: usize = 8192;
const OUT_BUFFER_SIZE: usize = 8192;
const WHITESPACE: &[u8] = b" \t\n\r";
pub const WHITESPACE: &[u8] = b" \t\n\r";
#[cfg(debug_assertions)]
fn c2s(is_c2s: bool) -> &'static str {
@ -148,91 +151,83 @@ async fn handle_connection(mut stream: tokio::net::TcpStream, client_addr: Socke
// starttls
if !direct_tls {
let mut stream_open = Vec::new();
let mut proceed_sent = false;
let (in_rd, mut in_wr) = stream.split();
// we naively read 1 byte at a time, which buffering significantly speeds up
let mut in_rd = tokio::io::BufReader::with_capacity(IN_BUFFER_SIZE, in_rd);
let in_rd = tokio::io::BufReader::with_capacity(IN_BUFFER_SIZE, in_rd);
let mut in_rd = StanzaReader(in_rd);
while let Ok(n) = in_rd.read(in_filter.current_buf()).await {
if n == 0 {
bail!("stream ended before open");
}
if let Some(buf) = in_filter.process_next_byte()? {
debug!("received pre-tls stanza: {} '{}'", client_addr, to_str(&buf));
let buf = buf.trim_start(WHITESPACE);
if buf.starts_with(b"<?xml ") {
stream_open.extend_from_slice(buf);
continue;
} else if buf.starts_with(b"<stream:stream ") {
debug!("> {} '{}'", client_addr, to_str(&stream_open));
in_wr.write_all(&stream_open).await?;
stream_open.clear();
// gajim seems to REQUIRE an id here...
let buf = if buf.contains_seq(b"id=") {
buf.replace_first(b" id='", b" id='xmpp-proxy")
.replace_first(br#" id=""#, br#" id="xmpp-proxy"#)
.replace_first(b" to=", br#" bla toblala="#)
.replace_first(b" from=", b" to=")
.replace_first(br#" bla toblala="#, br#" from="#)
} else {
buf.replace_first(b" to=", br#" bla toblala="#)
.replace_first(b" from=", b" to=")
.replace_first(br#" bla toblala="#, br#" id='xmpp-proxy' from="#)
};
debug!("> {} '{}'", client_addr, to_str(&buf));
in_wr.write_all(&buf).await?;
// ejabberd never sends <starttls/> with the first, only the second?
//let buf = br###"<features xmlns="http://etherx.jabber.org/streams"><starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls"><required/></starttls></features>"###;
let buf = br###"<stream:features><starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls"><required/></starttls></stream:features>"###;
debug!("> {} '{}'", client_addr, to_str(buf));
in_wr.write_all(buf).await?;
in_wr.flush().await?;
} else if buf.starts_with(b"<starttls ") {
let buf = br###"<proceed xmlns="urn:ietf:params:xml:ns:xmpp-tls" />"###;
debug!("> {} '{}'", client_addr, to_str(buf));
in_wr.write_all(buf).await?;
in_wr.flush().await?;
break;
while let Ok(Some(buf)) = in_rd.next(&mut in_filter).await {
debug!("received pre-tls stanza: {} '{}'", client_addr, to_str(&buf));
let buf = buf.trim_start(WHITESPACE);
if buf.starts_with(b"<?xml ") {
debug!("> {} '{}'", client_addr, to_str(&buf));
in_wr.write_all(&buf).await?;
in_wr.flush().await?;
} else if buf.starts_with(b"<stream:stream ") {
// gajim seems to REQUIRE an id here...
let buf = if buf.contains_seq(b"id=") {
buf.replace_first(b" id='", b" id='xmpp-proxy")
.replace_first(br#" id=""#, br#" id="xmpp-proxy"#)
.replace_first(b" to=", br#" bla toblala="#)
.replace_first(b" from=", b" to=")
.replace_first(br#" bla toblala="#, br#" from="#)
} else {
bail!("bad pre-tls stanza: {}", to_str(&buf));
}
buf.replace_first(b" to=", br#" bla toblala="#)
.replace_first(b" from=", b" to=")
.replace_first(br#" bla toblala="#, br#" id='xmpp-proxy' from="#)
};
debug!("> {} '{}'", client_addr, to_str(&buf));
in_wr.write_all(&buf).await?;
// ejabberd never sends <starttls/> with the first, only the second?
//let buf = br###"<features xmlns="http://etherx.jabber.org/streams"><starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls"><required/></starttls></features>"###;
let buf = br###"<stream:features><starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls"><required/></starttls></stream:features>"###;
debug!("> {} '{}'", client_addr, to_str(buf));
in_wr.write_all(buf).await?;
in_wr.flush().await?;
} else if buf.starts_with(b"<starttls ") {
let buf = br###"<proceed xmlns="urn:ietf:params:xml:ns:xmpp-tls" />"###;
debug!("> {} '{}'", client_addr, to_str(buf));
in_wr.write_all(buf).await?;
in_wr.flush().await?;
proceed_sent = true;
break;
} else {
bail!("bad pre-tls stanza: {}", to_str(&buf));
}
}
if !proceed_sent {
bail!("stream ended before open");
}
}
let stream = config.acceptor.accept(stream).await?;
let (in_rd, mut in_wr) = tokio::io::split(stream);
// we naively read 1 byte at a time, which buffering significantly speeds up
let mut in_rd = tokio::io::BufReader::with_capacity(IN_BUFFER_SIZE, in_rd);
let in_rd = tokio::io::BufReader::with_capacity(IN_BUFFER_SIZE, in_rd);
let mut in_rd = StanzaReader(in_rd);
// now read to figure out client vs server
let (stream_open, is_c2s) = {
let mut stream_open = Vec::new();
let mut ret = None;
while let Ok(n) = in_rd.read(in_filter.current_buf()).await {
if n == 0 {
bail!("stream ended before open");
}
if let Some(buf) = in_filter.process_next_byte()? {
debug!("received pre-<stream:stream> stanza: {} '{}'", client_addr, to_str(&buf));
let buf = buf.trim_start(WHITESPACE);
if buf.starts_with(b"<?xml ") {
stream_open.extend_from_slice(buf);
continue;
} else if buf.starts_with(b"<stream:stream ") {
stream_open.extend_from_slice(buf);
//return (stream_open, stanza.contains(r#" xmlns="jabber:client""#) || stanza.contains(r#" xmlns='jabber:client'"#));
ret = Some((stream_open, buf.contains_seq(br#" xmlns="jabber:client""#) || buf.contains_seq(br#" xmlns='jabber:client'"#)));
break;
} else {
bail!("bad pre-<stream:stream> stanza: {}", to_str(&buf));
}
while let Ok(Some(buf)) = in_rd.next(&mut in_filter).await {
debug!("received pre-<stream:stream> stanza: {} '{}'", client_addr, to_str(&buf));
let buf = buf.trim_start(WHITESPACE);
if buf.starts_with(b"<?xml ") {
stream_open.extend_from_slice(buf);
} else if buf.starts_with(b"<stream:stream ") {
stream_open.extend_from_slice(buf);
//return (stream_open, stanza.contains(r#" xmlns="jabber:client""#) || stanza.contains(r#" xmlns='jabber:client'"#));
ret = Some((stream_open, buf.contains_seq(br#" xmlns="jabber:client""#) || buf.contains_seq(br#" xmlns='jabber:client'"#)));
break;
} else {
bail!("bad pre-<stream:stream> stanza: {}", to_str(&buf));
}
}
if ret.is_some() {
@ -281,14 +276,14 @@ async fn handle_connection(mut stream: tokio::net::TcpStream, client_addr: Socke
loop {
tokio::select! {
Ok(n) = in_rd.read(in_filter.current_buf()) => {
if n == 0 {
break;
}
if let Some(buf) = in_filter.process_next_byte()? {
debug!("< {} {} '{}'", client_addr, c2s(is_c2s), to_str(buf));
out_wr.write_all(buf).await?;
out_wr.flush().await?;
Ok(buf) = in_rd.next(&mut in_filter) => {
match buf {
None => break,
Some(buf) => {
debug!("< {} {} '{}'", client_addr, c2s(is_c2s), to_str(buf));
out_wr.write_all(buf).await?;
out_wr.flush().await?;
}
}
},
// we could filter outgoing from-server stanzas by size here too by doing same as above
@ -308,6 +303,11 @@ async fn handle_connection(mut stream: tokio::net::TcpStream, client_addr: Socke
Ok(())
}
/*
async fn handle_connection(mut stream: tokio::net::TcpStream, client_addr: SocketAddr, local_addr: SocketAddr, config: CloneableConfig) -> Result<()> {
Ok(())
}
*/
fn spawn_listener(listener: TcpListener, config: CloneableConfig) -> JoinHandle<Result<()>> {
let local_addr = listener.local_addr().die("could not get local_addr?");
tokio::spawn(async move {
@ -339,74 +339,3 @@ async fn main() {
}
futures::future::join_all(handles).await;
}
struct StanzaFilter {
buf_size: usize,
buf: Vec<u8>,
cnt: usize,
tag_cnt: usize,
last_char_was_lt: bool,
last_char_was_backslash: bool,
}
impl StanzaFilter {
pub fn new(buf_size: usize) -> StanzaFilter {
StanzaFilter {
buf_size,
buf: vec![0u8; buf_size],
cnt: 0,
tag_cnt: 0,
last_char_was_lt: false,
last_char_was_backslash: false,
}
}
#[inline(always)]
pub fn current_buf(&mut self) -> &mut [u8] {
&mut self.buf[self.cnt..(self.cnt + 1)]
}
pub fn process_next_byte(&mut self) -> Result<Option<&[u8]>> {
//println!("n: {}", n);
let b = self.buf[self.cnt];
if b == b'<' {
self.tag_cnt += 1;
self.last_char_was_lt = true;
} else {
if b == b'/' {
// if last_char_was_lt but tag_cnt < 2, should only be </stream:stream>
if self.last_char_was_lt && self.tag_cnt >= 2 {
// non-self-closing tag
self.tag_cnt -= 2;
}
self.last_char_was_backslash = true;
} else {
if b == b'>' {
if self.last_char_was_backslash {
// self-closing tag
self.tag_cnt -= 1;
}
// now special case some tags we want to send stand-alone:
if self.tag_cnt == 1 && self.cnt >= 15 && (b"<?xml" == &self.buf[0..5] || b"<stream:stream" == &self.buf[0..14] || b"</stream:stream" == &self.buf[0..15]) {
self.tag_cnt = 0; // to fall through to next logic
}
if self.tag_cnt == 0 {
let ret = Ok(Some(&self.buf[0..(self.cnt + 1)]));
self.cnt = 0;
self.last_char_was_backslash = false;
self.last_char_was_lt = false;
return ret;
}
}
self.last_char_was_backslash = false;
}
self.last_char_was_lt = false;
}
//println!("b: '{}', cnt: {}, tag_cnt: {}, self.buf.len(): {}", b as char, self.cnt, self.tag_cnt, self.buf.len());
self.cnt += 1;
if self.cnt == self.buf_size {
bail!("stanza too big: {}", to_str(&self.buf));
}
Ok(None)
}
}

View File

@ -67,7 +67,7 @@ impl<T: PartialEq + Clone> SliceSubsequence<T> for Vec<T> {
#[cfg(test)]
mod tests {
use crate::slicesubsequence::*;
const WHITESPACE: &[u8] = b" \t\n\r";
use crate::WHITESPACE;
#[test]
fn trim_start() {

131
src/stanzafilter.rs Normal file
View File

@ -0,0 +1,131 @@
use anyhow::{bail, Result};
use crate::to_str;
pub struct StanzaFilter {
buf_size: usize,
pub buf: Vec<u8>,
cnt: usize,
tag_cnt: usize,
last_char_was_lt: bool,
last_char_was_backslash: bool,
}
impl StanzaFilter {
pub fn new(buf_size: usize) -> StanzaFilter {
StanzaFilter {
buf_size,
buf: vec![0u8; buf_size],
cnt: 0,
tag_cnt: 0,
last_char_was_lt: false,
last_char_was_backslash: false,
}
}
#[inline(always)]
pub fn current_buf(&mut self) -> &mut [u8] {
&mut self.buf[self.cnt..(self.cnt + 1)]
}
#[allow(dead_code)]
pub fn process_next_byte(&mut self) -> Result<Option<&[u8]>> {
if let Some(idx) = self.process_next_byte_idx()? {
return Ok(Some(&self.buf[0..idx]));
}
Ok(None)
}
pub fn process_next_byte_idx(&mut self) -> Result<Option<usize>> {
//println!("n: {}", n);
let b = self.buf[self.cnt];
if b == b'<' {
self.tag_cnt += 1;
self.last_char_was_lt = true;
} else {
if b == b'/' {
// if last_char_was_lt but tag_cnt < 2, should only be </stream:stream>
if self.last_char_was_lt && self.tag_cnt >= 2 {
// non-self-closing tag
self.tag_cnt -= 2;
}
self.last_char_was_backslash = true;
} else {
if b == b'>' {
if self.last_char_was_backslash {
// self-closing tag
self.tag_cnt -= 1;
}
// now special case some tags we want to send stand-alone:
if self.tag_cnt == 1 && self.cnt >= 15 && (b"<?xml" == &self.buf[0..5] || b"<stream:stream" == &self.buf[0..14] || b"</stream:stream" == &self.buf[0..15]) {
self.tag_cnt = 0; // to fall through to next logic
}
if self.tag_cnt == 0 {
//let ret = Ok(Some(&self.buf[0..(self.cnt + 1)]));
let ret = Ok(Some(self.cnt + 1));
self.cnt = 0;
self.last_char_was_backslash = false;
self.last_char_was_lt = false;
return ret;
}
}
self.last_char_was_backslash = false;
}
self.last_char_was_lt = false;
}
//println!("b: '{}', cnt: {}, tag_cnt: {}, self.buf.len(): {}", b as char, self.cnt, self.tag_cnt, self.buf.len());
self.cnt += 1;
if self.cnt == self.buf_size {
bail!("stanza too big: {}", to_str(&self.buf));
}
Ok(None)
}
}
// this would be better as an async trait, but that doesn't work yet...
pub struct StanzaReader<T>(pub T);
impl<T: tokio::io::AsyncRead + Unpin> StanzaReader<T> {
pub async fn next<'a>(&'a mut self, filter: &'a mut StanzaFilter) -> Result<Option<&'a [u8]>> {
use tokio::io::AsyncReadExt;
loop {
let n = self.0.read(filter.current_buf()).await?;
if n == 0 {
return Ok(None);
}
if let Some(idx) = filter.process_next_byte_idx()? {
return Ok(Some(&filter.buf[0..idx]));
}
}
}
}
#[cfg(test)]
mod tests {
use crate::stanzafilter::*;
use std::borrow::Cow;
use std::io::Cursor;
impl<T: tokio::io::AsyncRead + Unpin> StanzaReader<T> {
async fn next_str<'a>(&'a mut self, filter: &'a mut StanzaFilter) -> Cow<'_, str> {
to_str(self.next(filter).await.expect("was Err").expect("was None"))
}
}
#[tokio::test]
async fn process_next_byte() -> std::result::Result<(), anyhow::Error> {
let mut filter = StanzaFilter::new(262_144);
let xml_stream = Cursor::new(br###"<a/><b>inside b before c<c>inside c</c></b><d></d>"###);
let mut stanza_reader = StanzaReader(xml_stream);
assert_eq!(stanza_reader.next_str(&mut filter).await, "<a/>");
assert_eq!(stanza_reader.next_str(&mut filter).await, "<b>inside b before c<c>inside c</c></b>");
assert_eq!(stanza_reader.next_str(&mut filter).await, "<d></d>");
assert_eq!(stanza_reader.next(&mut filter).await?, None);
Ok(())
}
}