Make QUIC S2S bidirectional too

This commit is contained in:
Travis Burtrum 2021-05-15 00:32:36 -04:00
parent 6b851312d8
commit f127c02d45
4 changed files with 30 additions and 74 deletions

View File

@ -131,23 +131,9 @@ impl Config {
} }
} }
#[derive(PartialEq)] async fn shuffle_rd_wr<R: AsyncRead + Unpin, W: AsyncWrite + Unpin>(in_rd: R, in_wr: W, config: CloneableConfig, local_addr: SocketAddr, client_addr: SocketAddr) -> Result<()> {
pub enum AllowedType {
ClientOnly,
ServerOnly,
Any,
}
async fn shuffle_rd_wr<R: AsyncRead + Unpin, W: AsyncWrite + Unpin>(
in_rd: R,
in_wr: W,
config: CloneableConfig,
local_addr: SocketAddr,
client_addr: SocketAddr,
allowed_type: AllowedType,
) -> Result<()> {
let filter = StanzaFilter::new(config.max_stanza_size_bytes); let filter = StanzaFilter::new(config.max_stanza_size_bytes);
shuffle_rd_wr_filter(in_rd, in_wr, config, local_addr, client_addr, allowed_type, filter).await shuffle_rd_wr_filter(in_rd, in_wr, config, local_addr, client_addr, filter).await
} }
async fn shuffle_rd_wr_filter<R: AsyncRead + Unpin, W: AsyncWrite + Unpin>( async fn shuffle_rd_wr_filter<R: AsyncRead + Unpin, W: AsyncWrite + Unpin>(
@ -156,7 +142,6 @@ async fn shuffle_rd_wr_filter<R: AsyncRead + Unpin, W: AsyncWrite + Unpin>(
config: CloneableConfig, config: CloneableConfig,
local_addr: SocketAddr, local_addr: SocketAddr,
client_addr: SocketAddr, client_addr: SocketAddr,
allowed_type: AllowedType,
in_filter: StanzaFilter, in_filter: StanzaFilter,
) -> Result<()> { ) -> Result<()> {
// we naively read 1 byte at a time, which buffering significantly speeds up // we naively read 1 byte at a time, which buffering significantly speeds up
@ -165,17 +150,7 @@ async fn shuffle_rd_wr_filter<R: AsyncRead + Unpin, W: AsyncWrite + Unpin>(
// now read to figure out client vs server // now read to figure out client vs server
let (stream_open, is_c2s, mut in_rd, mut in_filter) = stream_preamble(StanzaReader(in_rd), client_addr, in_filter).await?; let (stream_open, is_c2s, mut in_rd, mut in_filter) = stream_preamble(StanzaReader(in_rd), client_addr, in_filter).await?;
let target = if is_c2s { let target = if is_c2s { config.c2s_target } else { config.s2s_target };
if allowed_type == AllowedType::ServerOnly {
bail!("c2s requested when only s2s allowed");
}
config.c2s_target
} else {
if allowed_type == AllowedType::ClientOnly {
bail!("s2s requested when only c2s allowed");
}
config.s2s_target
};
println!("INFO: {} is_c2s: {}, target: {}", client_addr, is_c2s, target); println!("INFO: {} is_c2s: {}, target: {}", client_addr, is_c2s, target);

View File

@ -18,13 +18,8 @@ pub async fn quic_connect(target: SocketAddr, server_name: &str, is_c2s: bool) -
let quinn::NewConnection { connection, .. } = endpoint.connect(&target, server_name).unwrap().await?; let quinn::NewConnection { connection, .. } = endpoint.connect(&target, server_name).unwrap().await?;
debug!("[client] connected: addr={}", connection.remote_address()); debug!("[client] connected: addr={}", connection.remote_address());
if is_c2s {
let (wrt, rd) = connection.open_bi().await?; let (wrt, rd) = connection.open_bi().await?;
Ok((Box::new(wrt), Box::new(rd))) Ok((Box::new(wrt), Box::new(rd)))
} else {
let wrt = connection.open_uni().await?;
Ok((Box::new(wrt), Box::new(NoopIo)))
}
} }
impl Config { impl Config {
@ -86,25 +81,13 @@ pub fn spawn_quic_listener(local_addr: SocketAddr, config: CloneableConfig, serv
tokio::spawn(async move { tokio::spawn(async move {
println!("INFO: {} quic connected", client_addr); println!("INFO: {} quic connected", client_addr);
loop { while let Some(Ok((wrt, rd))) = new_conn.bi_streams.next().await {
tokio::select! {
Some(Ok((wrt, rd))) = new_conn.bi_streams.next() => {
let config = config.clone(); let config = config.clone();
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = shuffle_rd_wr(rd, wrt, config, local_addr, client_addr, AllowedType::ClientOnly).await { if let Err(e) = shuffle_rd_wr(rd, wrt, config, local_addr, client_addr).await {
eprintln!("ERROR: {} {}", client_addr, e); eprintln!("ERROR: {} {}", client_addr, e);
} }
}); });
},
Some(Ok(rd)) = new_conn.uni_streams.next() => {
let config = config.clone();
tokio::spawn(async move {
if let Err(e) = shuffle_rd_wr(rd, NoopIo, config, local_addr, client_addr, AllowedType::ServerOnly).await {
eprintln!("ERROR: {} {}", client_addr, e);
}
});
},
}
} }
}); });
#[allow(unreachable_code)] #[allow(unreachable_code)]

View File

@ -170,7 +170,6 @@ pub async fn srv_connect(
out_wr.flush().await?; out_wr.flush().await?;
let mut server_response = Vec::new(); let mut server_response = Vec::new();
if is_c2s {
// let's read to first <stream:stream to make sure we are successfully connected to a real XMPP server // let's read to first <stream:stream to make sure we are successfully connected to a real XMPP server
let mut stream_received = false; let mut stream_received = false;
while let Ok(Some(buf)) = out_rd.next(&mut in_filter).await { while let Ok(Some(buf)) = out_rd.next(&mut in_filter).await {
@ -190,7 +189,6 @@ pub async fn srv_connect(
debug!("bad server response, going to next record"); debug!("bad server response, going to next record");
continue; continue;
} }
}
return Ok((Box::new(out_wr), out_rd, server_response)); return Ok((Box::new(out_wr), out_rd, server_response));
} }

View File

@ -205,5 +205,5 @@ async fn handle_tls_connection(mut stream: tokio::net::TcpStream, client_addr: S
let (in_rd, in_wr) = tokio::io::split(stream); let (in_rd, in_wr) = tokio::io::split(stream);
shuffle_rd_wr_filter(in_rd, in_wr, config, local_addr, client_addr, AllowedType::Any, in_filter).await shuffle_rd_wr_filter(in_rd, in_wr, config, local_addr, client_addr, in_filter).await
} }