Make QUIC S2S bidirectional too
This commit is contained in:
parent
6b851312d8
commit
f127c02d45
31
src/main.rs
31
src/main.rs
@ -131,23 +131,9 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq)]
|
||||
pub enum AllowedType {
|
||||
ClientOnly,
|
||||
ServerOnly,
|
||||
Any,
|
||||
}
|
||||
|
||||
async fn shuffle_rd_wr<R: AsyncRead + Unpin, W: AsyncWrite + Unpin>(
|
||||
in_rd: R,
|
||||
in_wr: W,
|
||||
config: CloneableConfig,
|
||||
local_addr: SocketAddr,
|
||||
client_addr: SocketAddr,
|
||||
allowed_type: AllowedType,
|
||||
) -> Result<()> {
|
||||
async fn shuffle_rd_wr<R: AsyncRead + Unpin, W: AsyncWrite + Unpin>(in_rd: R, in_wr: W, config: CloneableConfig, local_addr: SocketAddr, client_addr: SocketAddr) -> Result<()> {
|
||||
let filter = StanzaFilter::new(config.max_stanza_size_bytes);
|
||||
shuffle_rd_wr_filter(in_rd, in_wr, config, local_addr, client_addr, allowed_type, filter).await
|
||||
shuffle_rd_wr_filter(in_rd, in_wr, config, local_addr, client_addr, filter).await
|
||||
}
|
||||
|
||||
async fn shuffle_rd_wr_filter<R: AsyncRead + Unpin, W: AsyncWrite + Unpin>(
|
||||
@ -156,7 +142,6 @@ async fn shuffle_rd_wr_filter<R: AsyncRead + Unpin, W: AsyncWrite + Unpin>(
|
||||
config: CloneableConfig,
|
||||
local_addr: SocketAddr,
|
||||
client_addr: SocketAddr,
|
||||
allowed_type: AllowedType,
|
||||
in_filter: StanzaFilter,
|
||||
) -> Result<()> {
|
||||
// we naively read 1 byte at a time, which buffering significantly speeds up
|
||||
@ -165,17 +150,7 @@ async fn shuffle_rd_wr_filter<R: AsyncRead + Unpin, W: AsyncWrite + Unpin>(
|
||||
// now read to figure out client vs server
|
||||
let (stream_open, is_c2s, mut in_rd, mut in_filter) = stream_preamble(StanzaReader(in_rd), client_addr, in_filter).await?;
|
||||
|
||||
let target = if is_c2s {
|
||||
if allowed_type == AllowedType::ServerOnly {
|
||||
bail!("c2s requested when only s2s allowed");
|
||||
}
|
||||
config.c2s_target
|
||||
} else {
|
||||
if allowed_type == AllowedType::ClientOnly {
|
||||
bail!("s2s requested when only c2s allowed");
|
||||
}
|
||||
config.s2s_target
|
||||
};
|
||||
let target = if is_c2s { config.c2s_target } else { config.s2s_target };
|
||||
|
||||
println!("INFO: {} is_c2s: {}, target: {}", client_addr, is_c2s, target);
|
||||
|
||||
|
35
src/quic.rs
35
src/quic.rs
@ -18,13 +18,8 @@ pub async fn quic_connect(target: SocketAddr, server_name: &str, is_c2s: bool) -
|
||||
let quinn::NewConnection { connection, .. } = endpoint.connect(&target, server_name).unwrap().await?;
|
||||
debug!("[client] connected: addr={}", connection.remote_address());
|
||||
|
||||
if is_c2s {
|
||||
let (wrt, rd) = connection.open_bi().await?;
|
||||
Ok((Box::new(wrt), Box::new(rd)))
|
||||
} else {
|
||||
let wrt = connection.open_uni().await?;
|
||||
Ok((Box::new(wrt), Box::new(NoopIo)))
|
||||
}
|
||||
let (wrt, rd) = connection.open_bi().await?;
|
||||
Ok((Box::new(wrt), Box::new(rd)))
|
||||
}
|
||||
|
||||
impl Config {
|
||||
@ -86,25 +81,13 @@ pub fn spawn_quic_listener(local_addr: SocketAddr, config: CloneableConfig, serv
|
||||
tokio::spawn(async move {
|
||||
println!("INFO: {} quic connected", client_addr);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
Some(Ok((wrt, rd))) = new_conn.bi_streams.next() => {
|
||||
let config = config.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = shuffle_rd_wr(rd, wrt, config, local_addr, client_addr, AllowedType::ClientOnly).await {
|
||||
eprintln!("ERROR: {} {}", client_addr, e);
|
||||
}
|
||||
});
|
||||
},
|
||||
Some(Ok(rd)) = new_conn.uni_streams.next() => {
|
||||
let config = config.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = shuffle_rd_wr(rd, NoopIo, config, local_addr, client_addr, AllowedType::ServerOnly).await {
|
||||
eprintln!("ERROR: {} {}", client_addr, e);
|
||||
}
|
||||
});
|
||||
},
|
||||
}
|
||||
while let Some(Ok((wrt, rd))) = new_conn.bi_streams.next().await {
|
||||
let config = config.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = shuffle_rd_wr(rd, wrt, config, local_addr, client_addr).await {
|
||||
eprintln!("ERROR: {} {}", client_addr, e);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
#[allow(unreachable_code)]
|
||||
|
36
src/srv.rs
36
src/srv.rs
@ -170,27 +170,25 @@ pub async fn srv_connect(
|
||||
out_wr.flush().await?;
|
||||
|
||||
let mut server_response = Vec::new();
|
||||
if is_c2s {
|
||||
// let's read to first <stream:stream to make sure we are successfully connected to a real XMPP server
|
||||
let mut stream_received = false;
|
||||
while let Ok(Some(buf)) = out_rd.next(&mut in_filter).await {
|
||||
debug!("received pre-tls stanza: {} '{}'", domain, to_str(&buf));
|
||||
if buf.starts_with(b"<?xml ") {
|
||||
server_response.extend_from_slice(&buf);
|
||||
} else if buf.starts_with(b"<stream:stream ") {
|
||||
server_response.extend_from_slice(&buf);
|
||||
stream_received = true;
|
||||
break;
|
||||
} else {
|
||||
debug!("bad pre-tls stanza: {}", to_str(&buf));
|
||||
break;
|
||||
}
|
||||
}
|
||||
if !stream_received {
|
||||
debug!("bad server response, going to next record");
|
||||
continue;
|
||||
// let's read to first <stream:stream to make sure we are successfully connected to a real XMPP server
|
||||
let mut stream_received = false;
|
||||
while let Ok(Some(buf)) = out_rd.next(&mut in_filter).await {
|
||||
debug!("received pre-tls stanza: {} '{}'", domain, to_str(&buf));
|
||||
if buf.starts_with(b"<?xml ") {
|
||||
server_response.extend_from_slice(&buf);
|
||||
} else if buf.starts_with(b"<stream:stream ") {
|
||||
server_response.extend_from_slice(&buf);
|
||||
stream_received = true;
|
||||
break;
|
||||
} else {
|
||||
debug!("bad pre-tls stanza: {}", to_str(&buf));
|
||||
break;
|
||||
}
|
||||
}
|
||||
if !stream_received {
|
||||
debug!("bad server response, going to next record");
|
||||
continue;
|
||||
}
|
||||
|
||||
return Ok((Box::new(out_wr), out_rd, server_response));
|
||||
}
|
||||
|
@ -205,5 +205,5 @@ async fn handle_tls_connection(mut stream: tokio::net::TcpStream, client_addr: S
|
||||
|
||||
let (in_rd, in_wr) = tokio::io::split(stream);
|
||||
|
||||
shuffle_rd_wr_filter(in_rd, in_wr, config, local_addr, client_addr, AllowedType::Any, in_filter).await
|
||||
shuffle_rd_wr_filter(in_rd, in_wr, config, local_addr, client_addr, in_filter).await
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user