From 85e7779c238ff0f685df576f3741451f31cab604 Mon Sep 17 00:00:00 2001 From: moparisthebest Date: Tue, 25 Feb 2025 00:24:27 -0500 Subject: [PATCH] Try safer --- src/in_out.rs | 51 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/src/in_out.rs b/src/in_out.rs index 2ee5a8c..398f78e 100644 --- a/src/in_out.rs +++ b/src/in_out.rs @@ -245,6 +245,57 @@ impl Stream for StanzaStream { } } +pub struct AsyncReadWriteWs { + state: Option, +} + +enum AsyncReadWriteWsState { + Stream(StanzaStream), + Fut(Pin, IoError>, StanzaStream)>>>), +} + +impl AsyncRead for AsyncReadWriteWs { + fn poll_read(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> std::task::Poll> { + // todo: instead of waiting for a whole stanza, if self is AsyncRead, we could go directly to that and skip stanzafilter, problem is this would break Stream::poll_next and XmppStream::next_stanza, so maybe we need a different struct to do that? + // todo: instead of using our StanzaFilter and copying bytes from it, we could make one out of the buf? + let mut future = match self.state.take().unwrap() { + AsyncReadWriteWsState::Stream(mut stream) => { + let fut = async move { + let res = stream.next_stanza().await; + match res.map_err(|e| IoError::other(e)) { + Ok(Some((stanza, _))) => { + // this is a reference to stream which we return with it, and it's not returned from there either, so this is safe or needs pin or ? + let stanza: &'static [u8] = unsafe { std::mem::transmute(stanza) }; + (Ok(Some(stanza)), stream) + } + Ok(None) => (Ok(None), stream), + Err(e) => (Err(e), stream), + } + }; + Box::pin(fut) + } + AsyncReadWriteWsState::Fut(fut) => fut, + }; + + match future.as_mut().poll(cx) { + std::task::Poll::Ready((res, stream)) => { + self.state = AsyncReadWriteWsState::Stream(stream).into(); + if let Some(stanza) = res.map_err(|e| IoError::other(e))? { + if stanza.len() >= buf.remaining() { + return std::task::Poll::Ready(Err(IoError::other(format!("stanza of length {} read but buffer of only {} supplied", stanza.len(), buf.remaining())))); + } + buf.put_slice(stanza); + } + Poll::Ready(Ok(())) + } + std::task::Poll::Pending => { + self.state = Some(AsyncReadWriteWsState::Fut(future)); + std::task::Poll::Pending + } + } + } +} + impl AsyncRead for StanzaStream { fn poll_read(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> std::task::Poll> { // todo: instead of waiting for a whole stanza, if self is AsyncRead, we could go directly to that and skip stanzafilter, problem is this would break Stream::poll_next and XmppStream::next_stanza, so maybe we need a different struct to do that?