Try safer

This commit is contained in:
Travis Burtrum 2025-02-25 00:24:27 -05:00
parent cf0eac9e0b
commit 85e7779c23
Signed by: moparisthebest
GPG Key ID: 88C93BFE27BC8229

View File

@ -245,6 +245,57 @@ impl Stream for StanzaStream {
}
}
pub struct AsyncReadWriteWs {
state: Option<AsyncReadWriteWsState>,
}
enum AsyncReadWriteWsState {
Stream(StanzaStream),
Fut(Pin<Box<dyn Future<Output = (Result<Option<&'static [u8]>, IoError>, StanzaStream)>>>),
}
impl AsyncRead for AsyncReadWriteWs {
fn poll_read(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> std::task::Poll<std::io::Result<()>> {
// todo: instead of waiting for a whole stanza, if self is AsyncRead, we could go directly to that and skip stanzafilter, problem is this would break Stream::poll_next and XmppStream::next_stanza, so maybe we need a different struct to do that?
// todo: instead of using our StanzaFilter and copying bytes from it, we could make one out of the buf?
let mut future = match self.state.take().unwrap() {
AsyncReadWriteWsState::Stream(mut stream) => {
let fut = async move {
let res = stream.next_stanza().await;
match res.map_err(|e| IoError::other(e)) {
Ok(Some((stanza, _))) => {
// this is a reference to stream which we return with it, and it's not returned from there either, so this is safe or needs pin or ?
let stanza: &'static [u8] = unsafe { std::mem::transmute(stanza) };
(Ok(Some(stanza)), stream)
}
Ok(None) => (Ok(None), stream),
Err(e) => (Err(e), stream),
}
};
Box::pin(fut)
}
AsyncReadWriteWsState::Fut(fut) => fut,
};
match future.as_mut().poll(cx) {
std::task::Poll::Ready((res, stream)) => {
self.state = AsyncReadWriteWsState::Stream(stream).into();
if let Some(stanza) = res.map_err(|e| IoError::other(e))? {
if stanza.len() >= buf.remaining() {
return std::task::Poll::Ready(Err(IoError::other(format!("stanza of length {} read but buffer of only {} supplied", stanza.len(), buf.remaining()))));
}
buf.put_slice(stanza);
}
Poll::Ready(Ok(()))
}
std::task::Poll::Pending => {
self.state = Some(AsyncReadWriteWsState::Fut(future));
std::task::Poll::Pending
}
}
}
}
impl AsyncRead for StanzaStream {
fn poll_read(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> std::task::Poll<std::io::Result<()>> {
// todo: instead of waiting for a whole stanza, if self is AsyncRead, we could go directly to that and skip stanzafilter, problem is this would break Stream::poll_next and XmppStream::next_stanza, so maybe we need a different struct to do that?