Complete rewrite of StanzaFilter, more tests, supports CDATA and more
moparisthebest/xmpp-proxy/pipeline/head This commit looks good Details

This commit is contained in:
Travis Burtrum 2021-04-16 01:16:19 -04:00
parent 3792d2234a
commit efadaf30d2
1 changed files with 146 additions and 41 deletions

View File

@ -1,14 +1,28 @@
use anyhow::{bail, Result}; use anyhow::{bail, Result};
use crate::stanzafilter::StanzaState::*;
use crate::to_str; use crate::to_str;
/// Parser state of `StanzaFilter`, advanced one byte at a time by `process_next_byte_idx`.
#[derive(Debug)]
enum StanzaState {
/// Not inside any stanza: the initial state and the state restored after a stanza ends.
/// Every byte other than `<` is ignored here.
OutsideStanza,
/// The byte right after the `<` that opened a top-level element.
StanzaFirstChar,
/// The byte right after a `<` seen while in `BetweenTags`.
InsideTagFirstChar,
/// Inside a tag, waiting for its closing `>`.
InsideTag,
/// Inside a stanza but outside any tag (after a tag's `>`), waiting for the next `<`.
BetweenTags,
/// After `<!` inside a stanza; the payload is the buffer index at which `[CDATA[` must have been fully read.
ExclamationTag(usize),
/// Inside a CDATA section, waiting for the terminating `]]>`.
InsideCDATA,
/// After `<?` at the start of a stanza; the payload is the buffer index at which `xml ` must have been fully read.
QuestionTag(usize),
/// Inside an `<?xml ` declaration, waiting for its closing `>`.
InsideXmlTag,
/// After `</` at the start of a stanza; only `</stream:stream>` is accepted when `>` arrives.
EndStream,
}
pub struct StanzaFilter { pub struct StanzaFilter {
buf_size: usize, buf_size: usize,
pub buf: Vec<u8>, pub buf: Vec<u8>,
cnt: usize, cnt: usize,
tag_cnt: usize, tag_cnt: usize,
last_char_was_lt: bool, state: StanzaState,
last_char_was_backslash: bool,
} }
impl StanzaFilter { impl StanzaFilter {
@ -18,8 +32,7 @@ impl StanzaFilter {
buf: vec![0u8; buf_size], buf: vec![0u8; buf_size],
cnt: 0, cnt: 0,
tag_cnt: 0, tag_cnt: 0,
last_char_was_lt: false, state: OutsideStanza,
last_char_was_backslash: false,
} }
} }
@ -37,49 +50,122 @@ impl StanzaFilter {
} }
pub fn process_next_byte_idx(&mut self) -> Result<Option<usize>> { pub fn process_next_byte_idx(&mut self) -> Result<Option<usize>> {
//println!("n: {}", n);
let b = self.buf[self.cnt]; let b = self.buf[self.cnt];
if b == b'<' { //print!("b: '{}', cnt: {}, tag_cnt: {}, state: {:?}; ", b as char, self.cnt, self.tag_cnt, self.state);
self.tag_cnt += 1; match self.state {
self.last_char_was_lt = true; OutsideStanza => {
} else { if b == b'<' {
if b == b'/' { self.tag_cnt += 1;
// if last_char_was_lt but tag_cnt < 2, should only be </stream:stream> self.state = StanzaFirstChar;
if self.last_char_was_lt && self.tag_cnt >= 2 { } else {
// non-self-closing tag // outside of stanzas, let's ignore all characters except <
self.tag_cnt -= 2; // prosody does this, and since things do whitespace pings, it's good
return Ok(None);
} }
self.last_char_was_backslash = true; }
} else { BetweenTags => {
if b == b'<' {
self.tag_cnt += 1;
self.state = InsideTagFirstChar;
}
}
StanzaFirstChar => match b {
b'/' => self.state = EndStream,
b'!' => bail!("illegal stanza: {}", to_str(&self.buf[..(self.cnt + 1)])),
b'?' => self.state = QuestionTag(self.cnt + 4), // 4 is length of b"xml "
_ => self.state = InsideTag,
},
InsideTagFirstChar => match b {
b'/' => self.tag_cnt -= 2,
b'!' => self.state = ExclamationTag(self.cnt + 7), // 7 is length of b"[CDATA["
b'?' => bail!("illegal stanza: {}", to_str(&self.buf[..(self.cnt + 1)])),
_ => self.state = InsideTag,
},
InsideTag => {
if b == b'>' { if b == b'>' {
if self.last_char_was_backslash { if self.buf[self.cnt - 1] == b'/' {
// state can't be InsideTag unless we are on at least the second character, so can't go out of range
// self-closing tag // self-closing tag
self.tag_cnt -= 1; self.tag_cnt -= 1;
} }
// now special case some tags we want to send stand-alone:
if self.tag_cnt == 1 && self.cnt >= 15 && (b"<?xml" == &self.buf[0..5] || b"<stream:stream" == &self.buf[0..14] || b"</stream:stream" == &self.buf[0..15]) {
self.tag_cnt = 0; // to fall through to next logic
}
if self.tag_cnt == 0 { if self.tag_cnt == 0 {
//let ret = Ok(Some(&self.buf[0..(self.cnt + 1)])); return self.stanza_end();
let ret = Ok(Some(self.cnt + 1)); }
self.cnt = 0; // now special case <stream:stream ...> which we want to send stand-alone:
self.last_char_was_backslash = false; if self.tag_cnt == 1 && self.buf.len() >= 15 && b"<stream:stream " == &self.buf[0..15] {
self.last_char_was_lt = false; return self.stanza_end();
return ret; }
self.state = BetweenTags;
}
}
QuestionTag(idx) => {
if idx == self.cnt {
if self.last_equals(b"xml ")? {
self.state = InsideXmlTag;
} else {
bail!("illegal stanza: {}", to_str(&self.buf[..(self.cnt + 1)]));
}
}
}
InsideXmlTag => {
if b == b'>' {
return self.stanza_end();
}
}
ExclamationTag(idx) => {
if idx == self.cnt {
if self.last_equals(b"[CDATA[")? {
self.state = InsideCDATA;
self.tag_cnt -= 1; // cdata not a tag
} else {
bail!("illegal stanza: {}", to_str(&self.buf[..(self.cnt + 1)]));
}
}
}
InsideCDATA => {
if b == b'>' && self.last_equals(b"]]>")? {
self.state = BetweenTags;
}
}
EndStream => {
if b == b'>' {
if self.last_equals(b"</stream:stream>")? {
return self.stanza_end();
} else {
bail!("illegal stanza: {}", to_str(&self.buf[..(self.cnt + 1)]));
} }
} }
self.last_char_was_backslash = false;
} }
self.last_char_was_lt = false;
} }
//println!("b: '{}', cnt: {}, tag_cnt: {}, self.buf.len(): {}", b as char, self.cnt, self.tag_cnt, self.buf.len()); //println!("cnt: {}, tag_cnt: {}, state: {:?}", self.cnt, self.tag_cnt, self.state);
self.cnt += 1; self.cnt += 1;
if self.cnt == self.buf_size { if self.cnt == self.buf_size {
bail!("stanza too big: {}", to_str(&self.buf)); bail!("stanza too big: {}", to_str(&self.buf));
} }
Ok(None) Ok(None)
} }
/// Completes the stanza that ends at the current byte.
///
/// Resets the byte index, the tag counter and the parser state so the next
/// byte starts a fresh stanza, and returns the length of the finished stanza.
fn stanza_end(&mut self) -> Result<Option<usize>> {
    // `cnt` is the index of the stanza's final byte, so its length is one more.
    let stanza_len = std::mem::replace(&mut self.cnt, 0) + 1;
    self.state = OutsideStanza;
    self.tag_cnt = 0;
    Ok(Some(stanza_len))
}
/// Reports whether the bytes read so far, up to and including the current
/// byte, end with `needle`.
///
/// Any error from `last_num_bytes` is propagated unchanged.
fn last_equals(&self, needle: &[u8]) -> Result<bool> {
    let tail = self.last_num_bytes(needle.len())?;
    Ok(tail == needle)
}
/// Returns the last `num` bytes of the buffer, ending with (and including)
/// the byte at the current index `self.cnt`.
///
/// `num == 0` yields an empty slice.
///
/// # Errors
///
/// Fails when fewer than `num` bytes have been read so far; the message
/// reports the number of bytes requested and the number available.
fn last_num_bytes(&self, num: usize) -> Result<&[u8]> {
    // Bytes available so far, counting the one at index `cnt`.
    let available = self.cnt + 1;
    if num <= available {
        // `available - num` cannot underflow thanks to the check above.
        Ok(&self.buf[(available - num)..available])
    } else {
        bail!("expected {} bytes only have {} bytes", num, available)
    }
}
} }
// this would be better as an async trait, but that doesn't work yet... // this would be better as an async trait, but that doesn't work yet...
@ -104,12 +190,15 @@ impl<T: tokio::io::AsyncRead + Unpin> StanzaReader<T> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::stanzafilter::*; use crate::stanzafilter::*;
use std::borrow::Cow;
use std::io::Cursor; use std::io::Cursor;
impl<T: tokio::io::AsyncRead + Unpin> StanzaReader<T> { impl<T: tokio::io::AsyncRead + Unpin> StanzaReader<T> {
async fn next_str<'a>(&'a mut self, filter: &'a mut StanzaFilter) -> Cow<'_, str> { async fn to_vec<'a>(&'a mut self, filter: &'a mut StanzaFilter) -> Result<Vec<String>> {
to_str(self.next(filter).await.expect("was Err").expect("was None")) let mut ret = Vec::new();
while let Some(stanza) = self.next(filter).await? {
ret.push(to_str(stanza).to_string());
}
return Ok(ret);
} }
} }
@ -117,14 +206,30 @@ mod tests {
async fn process_next_byte() -> std::result::Result<(), anyhow::Error> { async fn process_next_byte() -> std::result::Result<(), anyhow::Error> {
let mut filter = StanzaFilter::new(262_144); let mut filter = StanzaFilter::new(262_144);
let xml_stream = Cursor::new(br###"<a/><b>inside b before c<c>inside c</c></b><d></d>"###); assert_eq!(
StanzaReader(Cursor::new(
let mut stanza_reader = StanzaReader(xml_stream); br###"
<?xml version='1.0'?>
assert_eq!(stanza_reader.next_str(&mut filter).await, "<a/>"); <stream:stream xmlns='jabber:server' xmlns:stream='http://etherx.jabber.org/streams' xmlns:db='jabber:server:dialback' version='1.0' to='example.org' from='example.com' xml:lang='en'>
assert_eq!(stanza_reader.next_str(&mut filter).await, "<b>inside b before c<c>inside c</c></b>"); <a/><b>inside b before c<c>inside c</c></b></stream:stream>
assert_eq!(stanza_reader.next_str(&mut filter).await, "<d></d>"); <q>bla<![CDATA[<this>is</not><xml/>]]>bloo</q>
assert_eq!(stanza_reader.next(&mut filter).await?, None); <d></d><e><![CDATA[what]>]]]]></e></stream:stream>
"###,
))
.to_vec(&mut filter)
.await?,
vec![
"<?xml version='1.0'?>",
"<stream:stream xmlns='jabber:server' xmlns:stream='http://etherx.jabber.org/streams' xmlns:db='jabber:server:dialback' version='1.0' to='example.org' from='example.com' xml:lang='en'>",
"<a/>",
"<b>inside b before c<c>inside c</c></b>",
"</stream:stream>",
"<q>bla<![CDATA[<this>is</not><xml/>]]>bloo</q>",
"<d></d>",
"<e><![CDATA[what]>]]]]></e>",
"</stream:stream>",
]
);
Ok(()) Ok(())
} }