// WebSocket transport helpers for stanza streams (the code cites RFC 7395:
// https://datatracker.ietf.org/doc/html/rfc7395).
//
// NOTE(review): this copy of the file looks damaged. Restore it from version control
// rather than repairing it by hand:
//   * The whole file is collapsed onto three physical lines, so each `//` comment below
//     runs to the end of its line and swallows the code after it (e.g. `fn ws_cfg`,
//     the `return Ok(...)` after "already set, do nothing", and `let mut ret`).
//   * Every span that began with `<` appears to have been stripped, as an HTML tag
//     filter would do. Visible symptoms:
//       - `SplitSink>` / `SplitStream>>` with no type arguments;
//       - `-> Option {` although the body returns `Some(WebSocketConfig { .. })`;
//       - `impl AsyncReadAndWrite for T {}` with no `<T: ...>` parameter list;
//       - `stream: Box,` with no type argument;
//       - `-> Result {` / `Result> {` with no type arguments;
//       - unbalanced braces in `from_ws`;
//       - a dangling `b"', "/>"))` in `to_ws_new`;
//       - empty `r#""#` XML literals in the tests.
//     The stripped string literals cannot be recovered from this copy.
//
// Contents, as far as the surviving text shows:
//   * `WsWr` / `WsRd`: the write half (`SplitSink`) and read half (`SplitStream`) of a
//     `WebSocketStream`; the sink carries `tokio_tungstenite::tungstenite::Message`s.
//   * `ws_cfg(max_stanza_size_bytes)` builds a tungstenite `WebSocketConfig` with:
//       - no send-queue limit (`max_send_queue: None`);
//       - max frame size = `max_stanza_size_bytes`;
//       - max message size = `max_stanza_size_bytes * 4`;
//       - unmasked frames accepted.
//     NOTE(review): `max_stanza_size_bytes * 4` overflows `usize` for very large inputs
//     (panic in debug builds, silent wrap in release).
//   * `AsyncReadAndWrite`: a marker trait for types that are both `tokio::io::AsyncRead`
//     and `tokio::io::AsyncWrite`, with a blanket `impl ... for T`.
//   * `incoming_websocket_connection(stream, max_stanza_size_bytes)`:
//       - runs `tokio_tungstenite::accept_async_with_config` on `stream`, using `ws_cfg`;
//       - splits the resulting socket into its two halves;
//       - returns `(StanzaRead::WebSocketRead(rd), StanzaWrite::WebSocketClientWrite(wr))`.
//     Handshake errors are propagated with `?`. The origin and Sec-WebSocket-Protocol
//     are not checked (see the todo inside it).
//   * `from_ws(stanza)` rewrites a stanza received over WebSocket. Of its surviving
//     branches:
//       - one substitutes ` xmlns:stream="http://etherx.jabber.org/streams">` (its search
//         pattern was stripped);
//       - one replaces `/>` with `>`;
//       - the final `else` returns the stanza unchanged.
//   * `to_ws_new(buf, end_of_first_tag, is_c2s)` prepares stanza bytes for WebSocket:
//       - If `end_of_first_tag == 0`, it returns the bytes unchanged (decoded as UTF-8).
//       - If the first tag (`buf[..end_of_first_tag]`, less a trailing `/` byte) already
//         contains ` xmlns='` or ` xmlns="`, it returns the bytes unchanged.
//       - Otherwise it inserts ` xmlns='jabber:client'` (when `is_c2s`) or
//         ` xmlns='jabber:server'` at `end_of_first_tag`. The insertion point moves back
//         one byte when the preceding byte is `/` (self-closing tag).
//       - The `+ 22` capacity is exactly the byte length of either inserted attribute.
//       - Invalid UTF-8 is returned as an `Err`.
//     NOTE(review): `end_of_first_tag` is presumably the index of the first tag's `>`;
//     confirm this against the stanza reader. The function panics if a caller passes
//     `end_of_first_tag > buf.len()`.
//   * `use crate::{...}` (at the bottom) brings in `StanzaRead` / `StanzaWrite` and the
//     `SliceSubsequence` trait that provides `first_index_of`.
//   * `tests`: `test_from_ws` checks `from_ws` on literal stanzas. `to_vec_eoft` collects
//     `to_ws_new(.., true)` over every `(buf, end_of_first_tag)` yielded by
//     `StanzaReader::next_eoft`. `test_to_ws` runs it with `StanzaFilter::new(262_144)`.
use anyhow::Result; use futures::StreamExt; use futures_util::stream::{SplitSink, SplitStream}; use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream}; #[cfg(feature = "incoming")] pub mod incoming; #[cfg(feature = "outgoing")] pub mod outgoing; pub type WsWr = SplitSink>, tokio_tungstenite::tungstenite::Message>; pub type WsRd = SplitStream>>; // https://datatracker.ietf.org/doc/html/rfc7395 fn ws_cfg(max_stanza_size_bytes: usize) -> Option { Some(WebSocketConfig { max_send_queue: None, // unlimited max_frame_size: Some(max_stanza_size_bytes), // this is exactly the stanza size max_message_size: Some(max_stanza_size_bytes * 4), // this is the message size, default is 4x frame size, so I guess we'll do the same here accept_unmasked_frames: true, }) } pub trait AsyncReadAndWrite: tokio::io::AsyncRead + tokio::io::AsyncWrite {} impl AsyncReadAndWrite for T {} pub async fn incoming_websocket_connection(stream: Box, max_stanza_size_bytes: usize) -> Result<(StanzaRead, StanzaWrite)> { // accept the websocket // todo: check SEC_WEBSOCKET_PROTOCOL or ORIGIN ? 
let stream = tokio_tungstenite::accept_async_with_config(stream, ws_cfg(max_stanza_size_bytes)).await?; let (in_wr, in_rd) = stream.split(); Ok((StanzaRead::WebSocketRead(in_rd), StanzaWrite::WebSocketClientWrite(in_wr))) } pub fn from_ws(stanza: String) -> String { if stanza.starts_with("", r#" xmlns:stream="http://etherx.jabber.org/streams">"#) } else { stanza.replace("/>", ">") } } else if stanza.starts_with("".to_string() } else { stanza } } pub fn to_ws_new(buf: &[u8], mut end_of_first_tag: usize, is_c2s: bool) -> Result { if end_of_first_tag == 0 { return Ok(String::from_utf8(buf.to_vec())?); } if buf.starts_with(b"', "/>")); } if buf.starts_with(b""#.to_string()); } if buf[end_of_first_tag - 1] == b'/' { end_of_first_tag -= 1; } let first_tag_bytes = &buf[0..end_of_first_tag]; if first_tag_bytes.first_index_of(b" xmlns='").is_ok() || first_tag_bytes.first_index_of(br#" xmlns=""#).is_ok() { // already set, do nothing return Ok(String::from_utf8(buf.to_vec())?); } // otherwise add proper xmlns before end of tag let mut ret = String::with_capacity(buf.len() + 22); ret.push_str(std::str::from_utf8(first_tag_bytes)?); ret.push_str(if is_c2s { " xmlns='jabber:client'" } else { " xmlns='jabber:server'" }); ret.push_str(std::str::from_utf8(&buf[end_of_first_tag..])?); Ok(ret) } use crate::{ in_out::{StanzaRead, StanzaWrite}, slicesubsequence::SliceSubsequence, }; #[cfg(test)] mod tests { use crate::websocket::*; use std::io::Cursor; #[test] fn test_from_ws() { assert_eq!( from_ws(r#""#.to_string()), r#""#.to_string() ); assert_eq!(from_ws(r#""#.to_string()), r#""#.to_string()); assert_eq!( from_ws(r#""#.to_string()), r#""#.to_string() ); } async fn to_vec_eoft(mut stanza_reader: StanzaReader, filter: &mut StanzaFilter) -> Result> { let mut ret = Vec::new(); while let Some((buf, end_of_first_tag)) = stanza_reader.next_eoft(filter).await? 
{ ret.push(to_ws_new(buf, end_of_first_tag, true)?); } Ok(ret) } #[tokio::test] async fn test_to_ws() -> Result<()> { let mut filter = StanzaFilter::new(262_144); assert_eq!( to_vec_eoft( StanzaReader(Cursor::new( br###" PLAINSCRAM-SHA-1 test1@test.moparisthe.best/gajim.12S9XM42 test1@test.moparisthe.best/gajim.12S9XM42 "###, )), &mut filter ) .await?, vec![ r#""#, r#""#, r#""#, r#"PLAINSCRAM-SHA-1"#, r#"test1@test.moparisthe.best/gajim.12S9XM42"#, r#"test1@test.moparisthe.best/gajim.12S9XM42"#, ] ); Ok(()) } }