// WebSocket transport for XMPP: a TLS+WebSocket listener for inbound connections and a
// connector for outgoing ones, using the "xmpp" WebSocket subprotocol (RFC 7395, linked below).
//
// NOTE(review): this copy of the file appears damaged. Text between `<` and `>` has been
// stripped in several places:
// - generic arguments, e.g. `JoinHandle>`, `tokio_rustls::TlsStream`, `-> Result {`;
// - the XML string literals in `from_ws` and `to_ws_new`;
// - the test fixtures.
// Restore those spots from version control before relying on or editing them.
use crate::*;
use anyhow::Result;
use futures::StreamExt;
use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;

// https://datatracker.ietf.org/doc/html/rfc7395

/// Binds a TCP listener on `local_addr` and spawns a task that accepts connections in a loop.
///
/// Each accepted connection is served in its own spawned task by
/// `handle_websocket_connection`, using clones of `config` and `acceptor`. Any error that
/// task returns is logged via `error!`, prefixed with the client's `log_from()` context.
///
/// A bind failure is passed to the project helper `.die(...)`, which presumably terminates
/// the process. TODO confirm.
///
/// NOTE(review): `listener.accept().await?` returns from the accept task on the first
/// accept error, which permanently stops this listener. Consider logging the error and
/// continuing the loop instead.
///
/// NOTE(review): the return type's generic arguments appear stripped (`JoinHandle>`).
pub fn spawn_websocket_listener(local_addr: SocketAddr, config: CloneableConfig, acceptor: TlsAcceptor) -> JoinHandle> {
    tokio::spawn(async move {
        let listener = TcpListener::bind(&local_addr).await.die("cannot listen on port/interface");
        loop {
            let (stream, client_addr) = listener.accept().await?;
            let config = config.clone();
            let acceptor = acceptor.clone();
            tokio::spawn(async move {
                let mut client_addr = Context::new("websocket-in", client_addr);
                if let Err(e) = handle_websocket_connection(stream, &mut client_addr, local_addr, config, acceptor).await {
                    error!("{} {}", client_addr.log_from(), e);
                }
            });
        }
        // The loop above has no `break` and only exits via `?`. This `Ok(())` is never
        // reached; it only gives the async block its `Result` output type.
        #[allow(unreachable_code)]
        Ok(())
    })
}

/// Serves one inbound connection: TLS handshake, WebSocket upgrade, then stanza shuffling.
///
/// The WebSocket limits are derived from `config.max_stanza_size_bytes`:
/// - the maximum frame size equals it;
/// - the maximum message size is four times it;
/// - the send queue is unbounded.
///
/// The upgraded stream is split into write and read halves. Those halves, plus a fresh
/// `StanzaFilter` of the same size, are handed to `shuffle_rd_wr_filter`, whose result is
/// returned.
///
/// # Errors
/// Propagates TLS accept failures, WebSocket handshake failures, and whatever
/// `shuffle_rd_wr_filter` returns.
async fn handle_websocket_connection(stream: tokio::net::TcpStream, client_addr: &mut Context<'_>, local_addr: SocketAddr, config: CloneableConfig, acceptor: TlsAcceptor) -> Result<()> {
    info!("{} connected", client_addr.log_from());

    // start TLS
    let stream = acceptor.accept(stream).await?;
    // NOTE(review): the type parameter of `TlsStream` appears stripped in this copy.
    let stream: tokio_rustls::TlsStream = stream.into();

    // accept the websocket
    // todo: check SEC_WEBSOCKET_PROTOCOL or ORIGIN ?
    let stream = tokio_tungstenite::accept_async_with_config(
        stream,
        Some(WebSocketConfig {
            max_send_queue: None, // unlimited
            max_frame_size: Some(config.max_stanza_size_bytes), // this is exactly the stanza size
            max_message_size: Some(config.max_stanza_size_bytes * 4), // this is the message size, default is 4x frame size, so I guess we'll do the same here
            // NOTE(review): RFC 6455 §5.1 says a server MUST close the connection when it
            // receives an unmasked client frame. `true` disables that check; confirm this
            // leniency is intentional.
            accept_unmasked_frames: true,
        }),
    )
    .await?;

    // `split` yields the sink (write) half first, then the stream (read) half.
    let (in_wr, in_rd) = stream.split();

    let in_filter = StanzaFilter::new(config.max_stanza_size_bytes);

    shuffle_rd_wr_filter(StanzaRead::WebSocketRead(in_rd), StanzaWrite::WebSocketClientWrite(in_wr), config, local_addr, client_addr, in_filter).await
}

/// Presumably converts a stanza received over WebSocket back toward the traditional XMPP
/// stream form. This is inferred from the name and the
/// `http://etherx.jabber.org/streams` literal.
///
/// What the surviving code shows:
/// - one branch appears to add an `xmlns:stream="http://etherx.jabber.org/streams"`
///   declaration;
/// - an alternative replaces every `/>` with `>`, turning a self-closing element into an
///   opening tag;
/// - another branch returns a fixed string (`.to_string()`);
/// - anything else is returned unchanged.
///
/// NOTE(review): the literals and braces in this body are damaged: `starts_with("", …)`
/// has an extra argument and the braces are unbalanced. The exact match conditions
/// cannot be read from this copy; restore from version control.
pub fn from_ws(stanza: String) -> String {
    if stanza.starts_with("", r#" xmlns:stream="http://etherx.jabber.org/streams">"#)
        } else {
            stanza.replace("/>", ">")
        }
    } else if stanza.starts_with("".to_string()
    } else {
        stanza
    }
}

/// Prepares one stanza in `buf` for sending over WebSocket.
///
/// * `end_of_first_tag`: if `0`, `buf` is returned unchanged as a `String`. Otherwise it
///   is presumably the index of the first tag's closing `>`. `buf[..end_of_first_tag]` is
///   treated as the first tag, and a `/` just before that index (a self-closing tag) is
///   kept after the inserted attribute.
/// * `is_c2s`: selects `jabber:client` (true) or `jabber:server` (false) as the default
///   namespace to add.
///
/// If the first tag already contains ` xmlns='` or ` xmlns="`, `buf` is returned as-is.
/// Otherwise ` xmlns='jabber:client'` (or ` xmlns='jabber:server'`) is inserted at the end
/// of the first tag.
///
/// NOTE(review): the two `starts_with` branches below are damaged in this copy. Their
/// byte-string patterns and return values were stripped. Restore them from version control.
///
/// # Errors
/// Returns an error if `buf`, or either part it is split into, is not valid UTF-8.
///
/// # Panics
/// Panics on out-of-bounds indexing if `end_of_first_tag > buf.len()` once the
/// `buf[end_of_first_tag - 1]` check is reached.
pub fn to_ws_new(buf: &[u8], mut end_of_first_tag: usize, is_c2s: bool) -> Result {
    if end_of_first_tag == 0 {
        return Ok(String::from_utf8(buf.to_vec())?);
    }
    if buf.starts_with(b"", "/>"));
    }
    if buf.starts_with(b""#.to_string());
    }
    // Self-closing tag: insert the attribute before the `/`, not between `/` and `>`.
    if buf[end_of_first_tag - 1] == b'/' {
        end_of_first_tag -= 1;
    }
    let first_tag_bytes = &buf[0..end_of_first_tag];
    if first_tag_bytes.first_index_of(b" xmlns='").is_ok() || first_tag_bytes.first_index_of(br#" xmlns=""#).is_ok() {
        // already set, do nothing
        return Ok(String::from_utf8(buf.to_vec())?);
    }
    // otherwise add proper xmlns before end of tag
    // 22 is the length of both " xmlns='jabber:client'" and " xmlns='jabber:server'".
    let mut ret = String::with_capacity(buf.len() + 22);
    ret.push_str(std::str::from_utf8(first_tag_bytes)?);
    ret.push_str(if is_c2s { " xmlns='jabber:client'" } else { " xmlns='jabber:server'" });
    ret.push_str(std::str::from_utf8(&buf[end_of_first_tag..])?);
    Ok(ret)
}

use rustls::ServerName;
use std::convert::TryFrom;
use tokio_rustls::TlsConnector;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::http::header::{ORIGIN, SEC_WEBSOCKET_PROTOCOL};
use tokio_tungstenite::tungstenite::http::Uri;

/// Opens an outgoing XMPP-over-WebSocket connection.
///
/// Steps:
/// 1. Connect TCP to `target`.
/// 2. Run a rustls TLS handshake for `server_name`, trusting `root_cert_store()`. No
///    client certificate is sent.
/// 3. Perform the WebSocket handshake for `url`, sending `Sec-WebSocket-Protocol: xmpp`
///    and `Origin: <origin>`. This uses `client_async_with_config` with `None`, i.e. no
///    custom `WebSocketConfig`.
///
/// Returns the split halves wrapped as `StanzaWrite::WebSocketClientWrite` and
/// `StanzaRead::WebSocketRead`. `_is_c2s` is currently unused.
///
/// # Errors
/// Fails in any of these cases:
/// - `url` cannot be turned into a client request;
/// - `origin` is not a valid header value;
/// - `server_name` is not a valid `ServerName`;
/// - the TCP connect, TLS handshake or WebSocket handshake fails.
///
/// NOTE(review): the handshake response is discarded (`let (stream, _)`), so the
/// subprotocol the server selected is not checked here. Verify whether tungstenite
/// enforces it.
pub async fn websocket_connect(target: SocketAddr, server_name: &str, url: &Uri, origin: &str, _is_c2s: bool) -> Result<(StanzaWrite, StanzaRead)> {
    // todo: WebSocketConfig
    // todo: static ? alpn? client cert auth for server
    let connector = rustls::ClientConfig::builder().with_safe_defaults().with_root_certificates(root_cert_store()).with_no_client_auth();

    let mut request = url.into_client_request()?;
    request.headers_mut().append(SEC_WEBSOCKET_PROTOCOL, "xmpp".parse()?);
    request.headers_mut().append(ORIGIN, origin.parse()?);

    let dnsname = ServerName::try_from(server_name)?;
    let stream = tokio::net::TcpStream::connect(target).await?;
    let connector = TlsConnector::from(Arc::new(connector));
    let stream = connector.connect(dnsname, stream).await?;

    // NOTE(review): the type parameter of `TlsStream` appears stripped in this copy.
    let stream: tokio_rustls::TlsStream = stream.into();

    let (stream, _) = tokio_tungstenite::client_async_with_config(request, stream, None).await?;

    let (wrt, rd) = stream.split();

    Ok((StanzaWrite::WebSocketClientWrite(wrt), StanzaRead::WebSocketRead(rd)))
}

// NOTE(review): the XML fixtures in these tests were stripped in this copy. The
// `from_ws` cases compare `r#""#` with `r#""#`, and the `br###"…"###` input and several
// expected values have lost their tags. As written they do not exercise the intended
// conversions; restore them from version control.
#[cfg(test)]
mod tests {
    use crate::websocket::*;
    use std::io::Cursor;

    #[test]
    fn test_from_ws() {
        assert_eq!(
            from_ws(r#""#.to_string()),
            r#""#.to_string()
        );
        assert_eq!(from_ws(r#""#.to_string()), r#""#.to_string());
        assert_eq!(
            from_ws(r#""#.to_string()),
            r#""#.to_string()
        );
    }

    /// Drains `stanza_reader` through `filter` and collects the results.
    ///
    /// Each stanza is converted with `to_ws_new(buf, end_of_first_tag, true)`, i.e. treated
    /// as client-to-server.
    async fn to_vec_eoft(mut stanza_reader: StanzaReader, filter: &mut StanzaFilter) -> Result> {
        let mut ret = Vec::new();
        while let Some((buf, end_of_first_tag)) = stanza_reader.next_eoft(filter).await? {
            ret.push(to_ws_new(buf, end_of_first_tag, true)?);
        }
        Ok(ret)
    }

    #[tokio::test]
    async fn test_to_ws() -> Result<()> {
        let mut filter = StanzaFilter::new(262_144);

        assert_eq!(
            to_vec_eoft(
                StanzaReader(Cursor::new(
                    br###" PLAINSCRAM-SHA-1 test1@test.moparisthe.best/gajim.12S9XM42 test1@test.moparisthe.best/gajim.12S9XM42 "###,
                )),
                &mut filter
            )
            .await?,
            vec![
                r#""#,
                r#""#,
                r#""#,
                r#"PLAINSCRAM-SHA-1"#,
                r#"test1@test.moparisthe.best/gajim.12S9XM42"#,
                r#"test1@test.moparisthe.best/gajim.12S9XM42"#,
            ]
        );

        Ok(())
    }
}