diff --git a/src/main.rs b/src/main.rs
index 0c5a493..51ed9c1 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -24,10 +24,13 @@ use anyhow::{bail, Result};
mod slicesubsequence;
use slicesubsequence::*;
+mod stanzafilter;
+use stanzafilter::*;
+
const IN_BUFFER_SIZE: usize = 8192;
const OUT_BUFFER_SIZE: usize = 8192;
-const WHITESPACE: &[u8] = b" \t\n\r";
+pub const WHITESPACE: &[u8] = b" \t\n\r";
#[cfg(debug_assertions)]
fn c2s(is_c2s: bool) -> &'static str {
@@ -148,91 +151,83 @@ async fn handle_connection(mut stream: tokio::net::TcpStream, client_addr: Socke
// starttls
if !direct_tls {
- let mut stream_open = Vec::new();
+ let mut proceed_sent = false;
let (in_rd, mut in_wr) = stream.split();
// we naively read 1 byte at a time, which buffering significantly speeds up
- let mut in_rd = tokio::io::BufReader::with_capacity(IN_BUFFER_SIZE, in_rd);
+ let in_rd = tokio::io::BufReader::with_capacity(IN_BUFFER_SIZE, in_rd);
+ let mut in_rd = StanzaReader(in_rd);
- while let Ok(n) = in_rd.read(in_filter.current_buf()).await {
- if n == 0 {
- bail!("stream ended before open");
- }
- if let Some(buf) = in_filter.process_next_byte()? {
- debug!("received pre-tls stanza: {} '{}'", client_addr, to_str(&buf));
- let buf = buf.trim_start(WHITESPACE);
- if buf.starts_with(b" {} '{}'", client_addr, to_str(&stream_open));
- in_wr.write_all(&stream_open).await?;
- stream_open.clear();
-
- // gajim seems to REQUIRE an id here...
- let buf = if buf.contains_seq(b"id=") {
- buf.replace_first(b" id='", b" id='xmpp-proxy")
- .replace_first(br#" id=""#, br#" id="xmpp-proxy"#)
- .replace_first(b" to=", br#" bla toblala="#)
- .replace_first(b" from=", b" to=")
- .replace_first(br#" bla toblala="#, br#" from="#)
- } else {
- buf.replace_first(b" to=", br#" bla toblala="#)
- .replace_first(b" from=", b" to=")
- .replace_first(br#" bla toblala="#, br#" id='xmpp-proxy' from="#)
- };
-
- debug!("> {} '{}'", client_addr, to_str(&buf));
- in_wr.write_all(&buf).await?;
-
- // ejabberd never sends with the first, only the second?
- //let buf = br###""###;
- let buf = br###""###;
- debug!("> {} '{}'", client_addr, to_str(buf));
- in_wr.write_all(buf).await?;
- in_wr.flush().await?;
- } else if buf.starts_with(b""###;
- debug!("> {} '{}'", client_addr, to_str(buf));
- in_wr.write_all(buf).await?;
- in_wr.flush().await?;
- break;
+ while let Ok(Some(buf)) = in_rd.next(&mut in_filter).await {
+ debug!("received pre-tls stanza: {} '{}'", client_addr, to_str(&buf));
+ let buf = buf.trim_start(WHITESPACE);
+ if buf.starts_with(b" {} '{}'", client_addr, to_str(&buf));
+ in_wr.write_all(&buf).await?;
+ in_wr.flush().await?;
+ } else if buf.starts_with(b" {} '{}'", client_addr, to_str(&buf));
+ in_wr.write_all(&buf).await?;
+
+ // ejabberd never sends with the first, only the second?
+ //let buf = br###""###;
+ let buf = br###""###;
+ debug!("> {} '{}'", client_addr, to_str(buf));
+ in_wr.write_all(buf).await?;
+ in_wr.flush().await?;
+ } else if buf.starts_with(b""###;
+ debug!("> {} '{}'", client_addr, to_str(buf));
+ in_wr.write_all(buf).await?;
+ in_wr.flush().await?;
+ proceed_sent = true;
+ break;
+ } else {
+ bail!("bad pre-tls stanza: {}", to_str(&buf));
}
}
+ if !proceed_sent {
+ bail!("stream ended before open");
+ }
}
let stream = config.acceptor.accept(stream).await?;
let (in_rd, mut in_wr) = tokio::io::split(stream);
// we naively read 1 byte at a time, which buffering significantly speeds up
- let mut in_rd = tokio::io::BufReader::with_capacity(IN_BUFFER_SIZE, in_rd);
+ let in_rd = tokio::io::BufReader::with_capacity(IN_BUFFER_SIZE, in_rd);
+ let mut in_rd = StanzaReader(in_rd);
// now read to figure out client vs server
let (stream_open, is_c2s) = {
let mut stream_open = Vec::new();
let mut ret = None;
- while let Ok(n) = in_rd.read(in_filter.current_buf()).await {
- if n == 0 {
- bail!("stream ended before open");
- }
- if let Some(buf) = in_filter.process_next_byte()? {
- debug!("received pre- stanza: {} '{}'", client_addr, to_str(&buf));
- let buf = buf.trim_start(WHITESPACE);
- if buf.starts_with(b" stanza: {}", to_str(&buf));
- }
+ while let Ok(Some(buf)) = in_rd.next(&mut in_filter).await {
+ debug!("received pre- stanza: {} '{}'", client_addr, to_str(&buf));
+ let buf = buf.trim_start(WHITESPACE);
+ if buf.starts_with(b" stanza: {}", to_str(&buf));
}
}
if ret.is_some() {
@@ -281,14 +276,14 @@ async fn handle_connection(mut stream: tokio::net::TcpStream, client_addr: Socke
loop {
tokio::select! {
- Ok(n) = in_rd.read(in_filter.current_buf()) => {
- if n == 0 {
- break;
- }
- if let Some(buf) = in_filter.process_next_byte()? {
- debug!("< {} {} '{}'", client_addr, c2s(is_c2s), to_str(buf));
- out_wr.write_all(buf).await?;
- out_wr.flush().await?;
+ Ok(buf) = in_rd.next(&mut in_filter) => {
+ match buf {
+ None => break,
+ Some(buf) => {
+ debug!("< {} {} '{}'", client_addr, c2s(is_c2s), to_str(buf));
+ out_wr.write_all(buf).await?;
+ out_wr.flush().await?;
+ }
}
},
// we could filter outgoing from-server stanzas by size here too by doing same as above
@@ -308,6 +303,11 @@ async fn handle_connection(mut stream: tokio::net::TcpStream, client_addr: Socke
Ok(())
}
+/*
+async fn handle_connection(mut stream: tokio::net::TcpStream, client_addr: SocketAddr, local_addr: SocketAddr, config: CloneableConfig) -> Result<()> {
+ Ok(())
+}
+*/
fn spawn_listener(listener: TcpListener, config: CloneableConfig) -> JoinHandle> {
let local_addr = listener.local_addr().die("could not get local_addr?");
tokio::spawn(async move {
@@ -339,74 +339,3 @@ async fn main() {
}
futures::future::join_all(handles).await;
}
-
-struct StanzaFilter {
- buf_size: usize,
- buf: Vec,
- cnt: usize,
- tag_cnt: usize,
- last_char_was_lt: bool,
- last_char_was_backslash: bool,
-}
-
-impl StanzaFilter {
- pub fn new(buf_size: usize) -> StanzaFilter {
- StanzaFilter {
- buf_size,
- buf: vec![0u8; buf_size],
- cnt: 0,
- tag_cnt: 0,
- last_char_was_lt: false,
- last_char_was_backslash: false,
- }
- }
-
- #[inline(always)]
- pub fn current_buf(&mut self) -> &mut [u8] {
- &mut self.buf[self.cnt..(self.cnt + 1)]
- }
-
- pub fn process_next_byte(&mut self) -> Result
- if self.last_char_was_lt && self.tag_cnt >= 2 {
- // non-self-closing tag
- self.tag_cnt -= 2;
- }
- self.last_char_was_backslash = true;
- } else {
- if b == b'>' {
- if self.last_char_was_backslash {
- // self-closing tag
- self.tag_cnt -= 1;
- }
- // now special case some tags we want to send stand-alone:
- if self.tag_cnt == 1 && self.cnt >= 15 && (b" SliceSubsequence for Vec {
#[cfg(test)]
mod tests {
use crate::slicesubsequence::*;
- const WHITESPACE: &[u8] = b" \t\n\r";
+ use crate::WHITESPACE;
#[test]
fn trim_start() {
diff --git a/src/stanzafilter.rs b/src/stanzafilter.rs
new file mode 100644
index 0000000..277c895
--- /dev/null
+++ b/src/stanzafilter.rs
@@ -0,0 +1,131 @@
+use anyhow::{bail, Result};
+
+use crate::to_str;
+
+pub struct StanzaFilter {
+ buf_size: usize,
+ pub buf: Vec,
+ cnt: usize,
+ tag_cnt: usize,
+ last_char_was_lt: bool,
+ last_char_was_backslash: bool,
+}
+
+impl StanzaFilter {
+ pub fn new(buf_size: usize) -> StanzaFilter {
+ StanzaFilter {
+ buf_size,
+ buf: vec![0u8; buf_size],
+ cnt: 0,
+ tag_cnt: 0,
+ last_char_was_lt: false,
+ last_char_was_backslash: false,
+ }
+ }
+
+ #[inline(always)]
+ pub fn current_buf(&mut self) -> &mut [u8] {
+ &mut self.buf[self.cnt..(self.cnt + 1)]
+ }
+
+ #[allow(dead_code)]
+ pub fn process_next_byte(&mut self) -> Result
+ if self.last_char_was_lt && self.tag_cnt >= 2 {
+ // non-self-closing tag
+ self.tag_cnt -= 2;
+ }
+ self.last_char_was_backslash = true;
+ } else {
+ if b == b'>' {
+ if self.last_char_was_backslash {
+ // self-closing tag
+ self.tag_cnt -= 1;
+ }
+ // now special case some tags we want to send stand-alone:
+ if self.tag_cnt == 1 && self.cnt >= 15 && (b"(pub T);
+
+impl StanzaReader {
+ pub async fn next<'a>(&'a mut self, filter: &'a mut StanzaFilter) -> Result