From 30ca8609d5a83990d3be48caf40c6fe8fa9fa07e Mon Sep 17 00:00:00 2001 From: moparisthebest Date: Fri, 30 Sep 2022 21:22:01 -0400 Subject: [PATCH] Fix websocket conversion code --- src/srv.rs | 2 +- src/websocket/mod.rs | 35 +++++++++++++++++++++++++++-------- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/src/srv.rs b/src/srv.rs index 0ac80b8..3b32e0f 100644 --- a/src/srv.rs +++ b/src/srv.rs @@ -831,7 +831,7 @@ pub fn digest(algorithm: &'static Algorithm, buf: &[u8]) -> String { #[cfg(test)] mod tests { use crate::srv::*; - use std::path::PathBuf; + use std::{fs::File, io::Read, path::PathBuf}; fn valid_posh(posh: &[u8], cert: &[u8]) -> bool { let posh: PoshJson = serde_json::from_slice(posh).unwrap(); diff --git a/src/websocket/mod.rs b/src/websocket/mod.rs index d5e2c9e..a407bd3 100644 --- a/src/websocket/mod.rs +++ b/src/websocket/mod.rs @@ -58,9 +58,6 @@ pub fn from_ws(stanza: String) -> String { } pub fn to_ws_new(buf: &[u8], mut end_of_first_tag: usize, is_c2s: bool) -> Result { - if end_of_first_tag == 0 { - return Ok(String::from_utf8(buf.to_vec())?); - } if buf.starts_with(b" Resul .replace("jabber:client", "urn:ietf:params:xml:ns:xmpp-framing") .replace('>', "/>")); } + if end_of_first_tag == 0 { + return Ok(String::from_utf8(buf.to_vec())?); + } if buf.starts_with(b""#.to_string()); } @@ -76,14 +76,28 @@ pub fn to_ws_new(buf: &[u8], mut end_of_first_tag: usize, is_c2s: bool) -> Resul end_of_first_tag -= 1; } let first_tag_bytes = &buf[0..end_of_first_tag]; - if first_tag_bytes.first_index_of(b" xmlns='").is_ok() || first_tag_bytes.first_index_of(br#" xmlns=""#).is_ok() { + let has_xmlns = first_tag_bytes.first_index_of(b" xmlns='").is_ok() || first_tag_bytes.first_index_of(br#" xmlns=""#).is_ok(); + let has_xmlns_stream = !first_tag_bytes.contains_seq(b"stream:") || (first_tag_bytes.first_index_of(b" xmlns:stream='").is_ok() || first_tag_bytes.first_index_of(br#" xmlns:stream=""#).is_ok()); + if has_xmlns && has_xmlns_stream { // already set, do nothing return Ok(String::from_utf8(buf.to_vec())?); } // otherwise add proper xmlns before end of tag - let mut ret = String::with_capacity(buf.len() + 22); + let mut capacity = 0; + if !has_xmlns { + capacity += 22; + } + if !has_xmlns_stream { + capacity += 48; + } + let mut ret = String::with_capacity(buf.len() + capacity); ret.push_str(std::str::from_utf8(first_tag_bytes)?); - ret.push_str(if is_c2s { " xmlns='jabber:client'" } else { " xmlns='jabber:server'" }); + if !has_xmlns { + ret.push_str(if is_c2s { " xmlns='jabber:client'" } else { " xmlns='jabber:server'" }); + } + if !has_xmlns_stream { + ret.push_str(" xmlns:stream='http://etherx.jabber.org/streams'"); + } ret.push_str(std::str::from_utf8(&buf[end_of_first_tag..])?); Ok(ret) } @@ -95,7 +109,10 @@ use crate::{ #[cfg(test)] mod tests { - use crate::websocket::*; + use crate::{ + stanzafilter::{StanzaFilter, StanzaReader}, + websocket::*, + }; use std::io::Cursor; #[test] @@ -128,6 +145,7 @@ mod tests { to_vec_eoft( StanzaReader(Cursor::new( br###" + @@ -140,10 +158,11 @@ mod tests { ) .await?, vec![ + r#""#, r#""#, r#""#, r#""#, - r#"PLAINSCRAM-SHA-1"#, + r#"PLAINSCRAM-SHA-1"#, r#"test1@test.moparisthe.best/gajim.12S9XM42"#, r#"test1@test.moparisthe.best/gajim.12S9XM42"#, ]