Re-factor stanza filter and add tests
This commit is contained in:
parent
2ef391c224
commit
e0d2f89c1e
221
src/main.rs
221
src/main.rs
@ -24,10 +24,13 @@ use anyhow::{bail, Result};
|
||||
mod slicesubsequence;
|
||||
use slicesubsequence::*;
|
||||
|
||||
mod stanzafilter;
|
||||
use stanzafilter::*;
|
||||
|
||||
const IN_BUFFER_SIZE: usize = 8192;
|
||||
const OUT_BUFFER_SIZE: usize = 8192;
|
||||
|
||||
const WHITESPACE: &[u8] = b" \t\n\r";
|
||||
pub const WHITESPACE: &[u8] = b" \t\n\r";
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
fn c2s(is_c2s: bool) -> &'static str {
|
||||
@ -148,91 +151,83 @@ async fn handle_connection(mut stream: tokio::net::TcpStream, client_addr: Socke
|
||||
|
||||
// starttls
|
||||
if !direct_tls {
|
||||
let mut stream_open = Vec::new();
|
||||
let mut proceed_sent = false;
|
||||
|
||||
let (in_rd, mut in_wr) = stream.split();
|
||||
// we naively read 1 byte at a time, which buffering significantly speeds up
|
||||
let mut in_rd = tokio::io::BufReader::with_capacity(IN_BUFFER_SIZE, in_rd);
|
||||
let in_rd = tokio::io::BufReader::with_capacity(IN_BUFFER_SIZE, in_rd);
|
||||
let mut in_rd = StanzaReader(in_rd);
|
||||
|
||||
while let Ok(n) = in_rd.read(in_filter.current_buf()).await {
|
||||
if n == 0 {
|
||||
bail!("stream ended before open");
|
||||
}
|
||||
if let Some(buf) = in_filter.process_next_byte()? {
|
||||
debug!("received pre-tls stanza: {} '{}'", client_addr, to_str(&buf));
|
||||
let buf = buf.trim_start(WHITESPACE);
|
||||
if buf.starts_with(b"<?xml ") {
|
||||
stream_open.extend_from_slice(buf);
|
||||
continue;
|
||||
} else if buf.starts_with(b"<stream:stream ") {
|
||||
debug!("> {} '{}'", client_addr, to_str(&stream_open));
|
||||
in_wr.write_all(&stream_open).await?;
|
||||
stream_open.clear();
|
||||
|
||||
// gajim seems to REQUIRE an id here...
|
||||
let buf = if buf.contains_seq(b"id=") {
|
||||
buf.replace_first(b" id='", b" id='xmpp-proxy")
|
||||
.replace_first(br#" id=""#, br#" id="xmpp-proxy"#)
|
||||
.replace_first(b" to=", br#" bla toblala="#)
|
||||
.replace_first(b" from=", b" to=")
|
||||
.replace_first(br#" bla toblala="#, br#" from="#)
|
||||
} else {
|
||||
buf.replace_first(b" to=", br#" bla toblala="#)
|
||||
.replace_first(b" from=", b" to=")
|
||||
.replace_first(br#" bla toblala="#, br#" id='xmpp-proxy' from="#)
|
||||
};
|
||||
|
||||
debug!("> {} '{}'", client_addr, to_str(&buf));
|
||||
in_wr.write_all(&buf).await?;
|
||||
|
||||
// ejabberd never sends <starttls/> with the first, only the second?
|
||||
//let buf = br###"<features xmlns="http://etherx.jabber.org/streams"><starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls"><required/></starttls></features>"###;
|
||||
let buf = br###"<stream:features><starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls"><required/></starttls></stream:features>"###;
|
||||
debug!("> {} '{}'", client_addr, to_str(buf));
|
||||
in_wr.write_all(buf).await?;
|
||||
in_wr.flush().await?;
|
||||
} else if buf.starts_with(b"<starttls ") {
|
||||
let buf = br###"<proceed xmlns="urn:ietf:params:xml:ns:xmpp-tls" />"###;
|
||||
debug!("> {} '{}'", client_addr, to_str(buf));
|
||||
in_wr.write_all(buf).await?;
|
||||
in_wr.flush().await?;
|
||||
break;
|
||||
while let Ok(Some(buf)) = in_rd.next(&mut in_filter).await {
|
||||
debug!("received pre-tls stanza: {} '{}'", client_addr, to_str(&buf));
|
||||
let buf = buf.trim_start(WHITESPACE);
|
||||
if buf.starts_with(b"<?xml ") {
|
||||
debug!("> {} '{}'", client_addr, to_str(&buf));
|
||||
in_wr.write_all(&buf).await?;
|
||||
in_wr.flush().await?;
|
||||
} else if buf.starts_with(b"<stream:stream ") {
|
||||
// gajim seems to REQUIRE an id here...
|
||||
let buf = if buf.contains_seq(b"id=") {
|
||||
buf.replace_first(b" id='", b" id='xmpp-proxy")
|
||||
.replace_first(br#" id=""#, br#" id="xmpp-proxy"#)
|
||||
.replace_first(b" to=", br#" bla toblala="#)
|
||||
.replace_first(b" from=", b" to=")
|
||||
.replace_first(br#" bla toblala="#, br#" from="#)
|
||||
} else {
|
||||
bail!("bad pre-tls stanza: {}", to_str(&buf));
|
||||
}
|
||||
buf.replace_first(b" to=", br#" bla toblala="#)
|
||||
.replace_first(b" from=", b" to=")
|
||||
.replace_first(br#" bla toblala="#, br#" id='xmpp-proxy' from="#)
|
||||
};
|
||||
|
||||
debug!("> {} '{}'", client_addr, to_str(&buf));
|
||||
in_wr.write_all(&buf).await?;
|
||||
|
||||
// ejabberd never sends <starttls/> with the first, only the second?
|
||||
//let buf = br###"<features xmlns="http://etherx.jabber.org/streams"><starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls"><required/></starttls></features>"###;
|
||||
let buf = br###"<stream:features><starttls xmlns="urn:ietf:params:xml:ns:xmpp-tls"><required/></starttls></stream:features>"###;
|
||||
debug!("> {} '{}'", client_addr, to_str(buf));
|
||||
in_wr.write_all(buf).await?;
|
||||
in_wr.flush().await?;
|
||||
} else if buf.starts_with(b"<starttls ") {
|
||||
let buf = br###"<proceed xmlns="urn:ietf:params:xml:ns:xmpp-tls" />"###;
|
||||
debug!("> {} '{}'", client_addr, to_str(buf));
|
||||
in_wr.write_all(buf).await?;
|
||||
in_wr.flush().await?;
|
||||
proceed_sent = true;
|
||||
break;
|
||||
} else {
|
||||
bail!("bad pre-tls stanza: {}", to_str(&buf));
|
||||
}
|
||||
}
|
||||
if !proceed_sent {
|
||||
bail!("stream ended before open");
|
||||
}
|
||||
}
|
||||
|
||||
let stream = config.acceptor.accept(stream).await?;
|
||||
|
||||
let (in_rd, mut in_wr) = tokio::io::split(stream);
|
||||
// we naively read 1 byte at a time, which buffering significantly speeds up
|
||||
let mut in_rd = tokio::io::BufReader::with_capacity(IN_BUFFER_SIZE, in_rd);
|
||||
let in_rd = tokio::io::BufReader::with_capacity(IN_BUFFER_SIZE, in_rd);
|
||||
let mut in_rd = StanzaReader(in_rd);
|
||||
|
||||
// now read to figure out client vs server
|
||||
let (stream_open, is_c2s) = {
|
||||
let mut stream_open = Vec::new();
|
||||
let mut ret = None;
|
||||
|
||||
while let Ok(n) = in_rd.read(in_filter.current_buf()).await {
|
||||
if n == 0 {
|
||||
bail!("stream ended before open");
|
||||
}
|
||||
if let Some(buf) = in_filter.process_next_byte()? {
|
||||
debug!("received pre-<stream:stream> stanza: {} '{}'", client_addr, to_str(&buf));
|
||||
let buf = buf.trim_start(WHITESPACE);
|
||||
if buf.starts_with(b"<?xml ") {
|
||||
stream_open.extend_from_slice(buf);
|
||||
continue;
|
||||
} else if buf.starts_with(b"<stream:stream ") {
|
||||
stream_open.extend_from_slice(buf);
|
||||
//return (stream_open, stanza.contains(r#" xmlns="jabber:client""#) || stanza.contains(r#" xmlns='jabber:client'"#));
|
||||
ret = Some((stream_open, buf.contains_seq(br#" xmlns="jabber:client""#) || buf.contains_seq(br#" xmlns='jabber:client'"#)));
|
||||
break;
|
||||
} else {
|
||||
bail!("bad pre-<stream:stream> stanza: {}", to_str(&buf));
|
||||
}
|
||||
while let Ok(Some(buf)) = in_rd.next(&mut in_filter).await {
|
||||
debug!("received pre-<stream:stream> stanza: {} '{}'", client_addr, to_str(&buf));
|
||||
let buf = buf.trim_start(WHITESPACE);
|
||||
if buf.starts_with(b"<?xml ") {
|
||||
stream_open.extend_from_slice(buf);
|
||||
} else if buf.starts_with(b"<stream:stream ") {
|
||||
stream_open.extend_from_slice(buf);
|
||||
//return (stream_open, stanza.contains(r#" xmlns="jabber:client""#) || stanza.contains(r#" xmlns='jabber:client'"#));
|
||||
ret = Some((stream_open, buf.contains_seq(br#" xmlns="jabber:client""#) || buf.contains_seq(br#" xmlns='jabber:client'"#)));
|
||||
break;
|
||||
} else {
|
||||
bail!("bad pre-<stream:stream> stanza: {}", to_str(&buf));
|
||||
}
|
||||
}
|
||||
if ret.is_some() {
|
||||
@ -281,14 +276,14 @@ async fn handle_connection(mut stream: tokio::net::TcpStream, client_addr: Socke
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
Ok(n) = in_rd.read(in_filter.current_buf()) => {
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
if let Some(buf) = in_filter.process_next_byte()? {
|
||||
debug!("< {} {} '{}'", client_addr, c2s(is_c2s), to_str(buf));
|
||||
out_wr.write_all(buf).await?;
|
||||
out_wr.flush().await?;
|
||||
Ok(buf) = in_rd.next(&mut in_filter) => {
|
||||
match buf {
|
||||
None => break,
|
||||
Some(buf) => {
|
||||
debug!("< {} {} '{}'", client_addr, c2s(is_c2s), to_str(buf));
|
||||
out_wr.write_all(buf).await?;
|
||||
out_wr.flush().await?;
|
||||
}
|
||||
}
|
||||
},
|
||||
// we could filter outgoing from-server stanzas by size here too by doing same as above
|
||||
@ -308,6 +303,11 @@ async fn handle_connection(mut stream: tokio::net::TcpStream, client_addr: Socke
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/*
|
||||
async fn handle_connection(mut stream: tokio::net::TcpStream, client_addr: SocketAddr, local_addr: SocketAddr, config: CloneableConfig) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
*/
|
||||
fn spawn_listener(listener: TcpListener, config: CloneableConfig) -> JoinHandle<Result<()>> {
|
||||
let local_addr = listener.local_addr().die("could not get local_addr?");
|
||||
tokio::spawn(async move {
|
||||
@ -339,74 +339,3 @@ async fn main() {
|
||||
}
|
||||
futures::future::join_all(handles).await;
|
||||
}
|
||||
|
||||
struct StanzaFilter {
|
||||
buf_size: usize,
|
||||
buf: Vec<u8>,
|
||||
cnt: usize,
|
||||
tag_cnt: usize,
|
||||
last_char_was_lt: bool,
|
||||
last_char_was_backslash: bool,
|
||||
}
|
||||
|
||||
impl StanzaFilter {
|
||||
pub fn new(buf_size: usize) -> StanzaFilter {
|
||||
StanzaFilter {
|
||||
buf_size,
|
||||
buf: vec![0u8; buf_size],
|
||||
cnt: 0,
|
||||
tag_cnt: 0,
|
||||
last_char_was_lt: false,
|
||||
last_char_was_backslash: false,
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn current_buf(&mut self) -> &mut [u8] {
|
||||
&mut self.buf[self.cnt..(self.cnt + 1)]
|
||||
}
|
||||
|
||||
pub fn process_next_byte(&mut self) -> Result<Option<&[u8]>> {
|
||||
//println!("n: {}", n);
|
||||
let b = self.buf[self.cnt];
|
||||
if b == b'<' {
|
||||
self.tag_cnt += 1;
|
||||
self.last_char_was_lt = true;
|
||||
} else {
|
||||
if b == b'/' {
|
||||
// if last_char_was_lt but tag_cnt < 2, should only be </stream:stream>
|
||||
if self.last_char_was_lt && self.tag_cnt >= 2 {
|
||||
// non-self-closing tag
|
||||
self.tag_cnt -= 2;
|
||||
}
|
||||
self.last_char_was_backslash = true;
|
||||
} else {
|
||||
if b == b'>' {
|
||||
if self.last_char_was_backslash {
|
||||
// self-closing tag
|
||||
self.tag_cnt -= 1;
|
||||
}
|
||||
// now special case some tags we want to send stand-alone:
|
||||
if self.tag_cnt == 1 && self.cnt >= 15 && (b"<?xml" == &self.buf[0..5] || b"<stream:stream" == &self.buf[0..14] || b"</stream:stream" == &self.buf[0..15]) {
|
||||
self.tag_cnt = 0; // to fall through to next logic
|
||||
}
|
||||
if self.tag_cnt == 0 {
|
||||
let ret = Ok(Some(&self.buf[0..(self.cnt + 1)]));
|
||||
self.cnt = 0;
|
||||
self.last_char_was_backslash = false;
|
||||
self.last_char_was_lt = false;
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
self.last_char_was_backslash = false;
|
||||
}
|
||||
self.last_char_was_lt = false;
|
||||
}
|
||||
//println!("b: '{}', cnt: {}, tag_cnt: {}, self.buf.len(): {}", b as char, self.cnt, self.tag_cnt, self.buf.len());
|
||||
self.cnt += 1;
|
||||
if self.cnt == self.buf_size {
|
||||
bail!("stanza too big: {}", to_str(&self.buf));
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
@ -67,7 +67,7 @@ impl<T: PartialEq + Clone> SliceSubsequence<T> for Vec<T> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::slicesubsequence::*;
|
||||
const WHITESPACE: &[u8] = b" \t\n\r";
|
||||
use crate::WHITESPACE;
|
||||
|
||||
#[test]
|
||||
fn trim_start() {
|
||||
|
131
src/stanzafilter.rs
Normal file
131
src/stanzafilter.rs
Normal file
@ -0,0 +1,131 @@
|
||||
use anyhow::{bail, Result};
|
||||
|
||||
use crate::to_str;
|
||||
|
||||
pub struct StanzaFilter {
|
||||
buf_size: usize,
|
||||
pub buf: Vec<u8>,
|
||||
cnt: usize,
|
||||
tag_cnt: usize,
|
||||
last_char_was_lt: bool,
|
||||
last_char_was_backslash: bool,
|
||||
}
|
||||
|
||||
impl StanzaFilter {
|
||||
pub fn new(buf_size: usize) -> StanzaFilter {
|
||||
StanzaFilter {
|
||||
buf_size,
|
||||
buf: vec![0u8; buf_size],
|
||||
cnt: 0,
|
||||
tag_cnt: 0,
|
||||
last_char_was_lt: false,
|
||||
last_char_was_backslash: false,
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub fn current_buf(&mut self) -> &mut [u8] {
|
||||
&mut self.buf[self.cnt..(self.cnt + 1)]
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn process_next_byte(&mut self) -> Result<Option<&[u8]>> {
|
||||
if let Some(idx) = self.process_next_byte_idx()? {
|
||||
return Ok(Some(&self.buf[0..idx]));
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
pub fn process_next_byte_idx(&mut self) -> Result<Option<usize>> {
|
||||
//println!("n: {}", n);
|
||||
let b = self.buf[self.cnt];
|
||||
if b == b'<' {
|
||||
self.tag_cnt += 1;
|
||||
self.last_char_was_lt = true;
|
||||
} else {
|
||||
if b == b'/' {
|
||||
// if last_char_was_lt but tag_cnt < 2, should only be </stream:stream>
|
||||
if self.last_char_was_lt && self.tag_cnt >= 2 {
|
||||
// non-self-closing tag
|
||||
self.tag_cnt -= 2;
|
||||
}
|
||||
self.last_char_was_backslash = true;
|
||||
} else {
|
||||
if b == b'>' {
|
||||
if self.last_char_was_backslash {
|
||||
// self-closing tag
|
||||
self.tag_cnt -= 1;
|
||||
}
|
||||
// now special case some tags we want to send stand-alone:
|
||||
if self.tag_cnt == 1 && self.cnt >= 15 && (b"<?xml" == &self.buf[0..5] || b"<stream:stream" == &self.buf[0..14] || b"</stream:stream" == &self.buf[0..15]) {
|
||||
self.tag_cnt = 0; // to fall through to next logic
|
||||
}
|
||||
if self.tag_cnt == 0 {
|
||||
//let ret = Ok(Some(&self.buf[0..(self.cnt + 1)]));
|
||||
let ret = Ok(Some(self.cnt + 1));
|
||||
self.cnt = 0;
|
||||
self.last_char_was_backslash = false;
|
||||
self.last_char_was_lt = false;
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
self.last_char_was_backslash = false;
|
||||
}
|
||||
self.last_char_was_lt = false;
|
||||
}
|
||||
//println!("b: '{}', cnt: {}, tag_cnt: {}, self.buf.len(): {}", b as char, self.cnt, self.tag_cnt, self.buf.len());
|
||||
self.cnt += 1;
|
||||
if self.cnt == self.buf_size {
|
||||
bail!("stanza too big: {}", to_str(&self.buf));
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
// this would be better as an async trait, but that doesn't work yet...
|
||||
pub struct StanzaReader<T>(pub T);
|
||||
|
||||
impl<T: tokio::io::AsyncRead + Unpin> StanzaReader<T> {
|
||||
pub async fn next<'a>(&'a mut self, filter: &'a mut StanzaFilter) -> Result<Option<&'a [u8]>> {
|
||||
use tokio::io::AsyncReadExt;
|
||||
|
||||
loop {
|
||||
let n = self.0.read(filter.current_buf()).await?;
|
||||
if n == 0 {
|
||||
return Ok(None);
|
||||
}
|
||||
if let Some(idx) = filter.process_next_byte_idx()? {
|
||||
return Ok(Some(&filter.buf[0..idx]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use crate::stanzafilter::*;
    use std::borrow::Cow;
    use std::io::Cursor;

    impl<T: tokio::io::AsyncRead + Unpin> StanzaReader<T> {
        /// Test helper: reads the next stanza and converts it to text,
        /// panicking on a read/filter error or on end of stream.
        async fn next_str<'a>(&'a mut self, filter: &'a mut StanzaFilter) -> Cow<'_, str> {
            let stanza = self.next(filter).await.expect("was Err").expect("was None");
            to_str(stanza)
        }
    }

    /// Feeds three consecutive stanzas (self-closing, nested, empty element)
    /// through the reader and checks that each comes back whole and in order,
    /// followed by end of stream.
    #[tokio::test]
    async fn process_next_byte() -> std::result::Result<(), anyhow::Error> {
        let xml_stream = Cursor::new(br###"<a/><b>inside b before c<c>inside c</c></b><d></d>"###);
        let mut filter = StanzaFilter::new(262_144);
        let mut stanza_reader = StanzaReader(xml_stream);

        let expected_stanzas = ["<a/>", "<b>inside b before c<c>inside c</c></b>", "<d></d>"];
        for expected in expected_stanzas.iter() {
            assert_eq!(stanza_reader.next_str(&mut filter).await, *expected);
        }

        // nothing left: the reader must report a clean end of stream
        assert_eq!(stanza_reader.next(&mut filter).await?, None);

        Ok(())
    }
}
|
Loading…
Reference in New Issue
Block a user