Make QUIC S2S bidirectional too

This commit is contained in:
Travis Burtrum 2021-05-15 00:32:36 -04:00
parent 6b851312d8
commit f127c02d45
4 changed files with 30 additions and 74 deletions

View File

@ -131,23 +131,9 @@ impl Config {
}
}
#[derive(PartialEq)]
pub enum AllowedType {
ClientOnly,
ServerOnly,
Any,
}
async fn shuffle_rd_wr<R: AsyncRead + Unpin, W: AsyncWrite + Unpin>(
in_rd: R,
in_wr: W,
config: CloneableConfig,
local_addr: SocketAddr,
client_addr: SocketAddr,
allowed_type: AllowedType,
) -> Result<()> {
async fn shuffle_rd_wr<R: AsyncRead + Unpin, W: AsyncWrite + Unpin>(in_rd: R, in_wr: W, config: CloneableConfig, local_addr: SocketAddr, client_addr: SocketAddr) -> Result<()> {
let filter = StanzaFilter::new(config.max_stanza_size_bytes);
shuffle_rd_wr_filter(in_rd, in_wr, config, local_addr, client_addr, allowed_type, filter).await
shuffle_rd_wr_filter(in_rd, in_wr, config, local_addr, client_addr, filter).await
}
async fn shuffle_rd_wr_filter<R: AsyncRead + Unpin, W: AsyncWrite + Unpin>(
@ -156,7 +142,6 @@ async fn shuffle_rd_wr_filter<R: AsyncRead + Unpin, W: AsyncWrite + Unpin>(
config: CloneableConfig,
local_addr: SocketAddr,
client_addr: SocketAddr,
allowed_type: AllowedType,
in_filter: StanzaFilter,
) -> Result<()> {
// we naively read 1 byte at a time, which buffering significantly speeds up
@ -165,17 +150,7 @@ async fn shuffle_rd_wr_filter<R: AsyncRead + Unpin, W: AsyncWrite + Unpin>(
// now read to figure out client vs server
let (stream_open, is_c2s, mut in_rd, mut in_filter) = stream_preamble(StanzaReader(in_rd), client_addr, in_filter).await?;
let target = if is_c2s {
if allowed_type == AllowedType::ServerOnly {
bail!("c2s requested when only s2s allowed");
}
config.c2s_target
} else {
if allowed_type == AllowedType::ClientOnly {
bail!("s2s requested when only c2s allowed");
}
config.s2s_target
};
let target = if is_c2s { config.c2s_target } else { config.s2s_target };
println!("INFO: {} is_c2s: {}, target: {}", client_addr, is_c2s, target);

View File

@ -18,13 +18,8 @@ pub async fn quic_connect(target: SocketAddr, server_name: &str, is_c2s: bool) -
let quinn::NewConnection { connection, .. } = endpoint.connect(&target, server_name).unwrap().await?;
debug!("[client] connected: addr={}", connection.remote_address());
if is_c2s {
let (wrt, rd) = connection.open_bi().await?;
Ok((Box::new(wrt), Box::new(rd)))
} else {
let wrt = connection.open_uni().await?;
Ok((Box::new(wrt), Box::new(NoopIo)))
}
let (wrt, rd) = connection.open_bi().await?;
Ok((Box::new(wrt), Box::new(rd)))
}
impl Config {
@ -86,25 +81,13 @@ pub fn spawn_quic_listener(local_addr: SocketAddr, config: CloneableConfig, serv
tokio::spawn(async move {
println!("INFO: {} quic connected", client_addr);
loop {
tokio::select! {
Some(Ok((wrt, rd))) = new_conn.bi_streams.next() => {
let config = config.clone();
tokio::spawn(async move {
if let Err(e) = shuffle_rd_wr(rd, wrt, config, local_addr, client_addr, AllowedType::ClientOnly).await {
eprintln!("ERROR: {} {}", client_addr, e);
}
});
},
Some(Ok(rd)) = new_conn.uni_streams.next() => {
let config = config.clone();
tokio::spawn(async move {
if let Err(e) = shuffle_rd_wr(rd, NoopIo, config, local_addr, client_addr, AllowedType::ServerOnly).await {
eprintln!("ERROR: {} {}", client_addr, e);
}
});
},
}
while let Some(Ok((wrt, rd))) = new_conn.bi_streams.next().await {
let config = config.clone();
tokio::spawn(async move {
if let Err(e) = shuffle_rd_wr(rd, wrt, config, local_addr, client_addr).await {
eprintln!("ERROR: {} {}", client_addr, e);
}
});
}
});
#[allow(unreachable_code)]

View File

@ -170,27 +170,25 @@ pub async fn srv_connect(
out_wr.flush().await?;
let mut server_response = Vec::new();
if is_c2s {
// let's read to first <stream:stream to make sure we are successfully connected to a real XMPP server
let mut stream_received = false;
while let Ok(Some(buf)) = out_rd.next(&mut in_filter).await {
debug!("received pre-tls stanza: {} '{}'", domain, to_str(&buf));
if buf.starts_with(b"<?xml ") {
server_response.extend_from_slice(&buf);
} else if buf.starts_with(b"<stream:stream ") {
server_response.extend_from_slice(&buf);
stream_received = true;
break;
} else {
debug!("bad pre-tls stanza: {}", to_str(&buf));
break;
}
}
if !stream_received {
debug!("bad server response, going to next record");
continue;
// let's read to first <stream:stream to make sure we are successfully connected to a real XMPP server
let mut stream_received = false;
while let Ok(Some(buf)) = out_rd.next(&mut in_filter).await {
debug!("received pre-tls stanza: {} '{}'", domain, to_str(&buf));
if buf.starts_with(b"<?xml ") {
server_response.extend_from_slice(&buf);
} else if buf.starts_with(b"<stream:stream ") {
server_response.extend_from_slice(&buf);
stream_received = true;
break;
} else {
debug!("bad pre-tls stanza: {}", to_str(&buf));
break;
}
}
if !stream_received {
debug!("bad server response, going to next record");
continue;
}
return Ok((Box::new(out_wr), out_rd, server_response));
}

View File

@ -205,5 +205,5 @@ async fn handle_tls_connection(mut stream: tokio::net::TcpStream, client_addr: S
let (in_rd, in_wr) = tokio::io::split(stream);
shuffle_rd_wr_filter(in_rd, in_wr, config, local_addr, client_addr, AllowedType::Any, in_filter).await
shuffle_rd_wr_filter(in_rd, in_wr, config, local_addr, client_addr, in_filter).await
}