Make QUIC S2S bidirectional too
This commit is contained in:
parent
6b851312d8
commit
f127c02d45
31
src/main.rs
31
src/main.rs
@ -131,23 +131,9 @@ impl Config {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(PartialEq)]
|
async fn shuffle_rd_wr<R: AsyncRead + Unpin, W: AsyncWrite + Unpin>(in_rd: R, in_wr: W, config: CloneableConfig, local_addr: SocketAddr, client_addr: SocketAddr) -> Result<()> {
|
||||||
pub enum AllowedType {
|
|
||||||
ClientOnly,
|
|
||||||
ServerOnly,
|
|
||||||
Any,
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn shuffle_rd_wr<R: AsyncRead + Unpin, W: AsyncWrite + Unpin>(
|
|
||||||
in_rd: R,
|
|
||||||
in_wr: W,
|
|
||||||
config: CloneableConfig,
|
|
||||||
local_addr: SocketAddr,
|
|
||||||
client_addr: SocketAddr,
|
|
||||||
allowed_type: AllowedType,
|
|
||||||
) -> Result<()> {
|
|
||||||
let filter = StanzaFilter::new(config.max_stanza_size_bytes);
|
let filter = StanzaFilter::new(config.max_stanza_size_bytes);
|
||||||
shuffle_rd_wr_filter(in_rd, in_wr, config, local_addr, client_addr, allowed_type, filter).await
|
shuffle_rd_wr_filter(in_rd, in_wr, config, local_addr, client_addr, filter).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn shuffle_rd_wr_filter<R: AsyncRead + Unpin, W: AsyncWrite + Unpin>(
|
async fn shuffle_rd_wr_filter<R: AsyncRead + Unpin, W: AsyncWrite + Unpin>(
|
||||||
@ -156,7 +142,6 @@ async fn shuffle_rd_wr_filter<R: AsyncRead + Unpin, W: AsyncWrite + Unpin>(
|
|||||||
config: CloneableConfig,
|
config: CloneableConfig,
|
||||||
local_addr: SocketAddr,
|
local_addr: SocketAddr,
|
||||||
client_addr: SocketAddr,
|
client_addr: SocketAddr,
|
||||||
allowed_type: AllowedType,
|
|
||||||
in_filter: StanzaFilter,
|
in_filter: StanzaFilter,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
// we naively read 1 byte at a time, which buffering significantly speeds up
|
// we naively read 1 byte at a time, which buffering significantly speeds up
|
||||||
@ -165,17 +150,7 @@ async fn shuffle_rd_wr_filter<R: AsyncRead + Unpin, W: AsyncWrite + Unpin>(
|
|||||||
// now read to figure out client vs server
|
// now read to figure out client vs server
|
||||||
let (stream_open, is_c2s, mut in_rd, mut in_filter) = stream_preamble(StanzaReader(in_rd), client_addr, in_filter).await?;
|
let (stream_open, is_c2s, mut in_rd, mut in_filter) = stream_preamble(StanzaReader(in_rd), client_addr, in_filter).await?;
|
||||||
|
|
||||||
let target = if is_c2s {
|
let target = if is_c2s { config.c2s_target } else { config.s2s_target };
|
||||||
if allowed_type == AllowedType::ServerOnly {
|
|
||||||
bail!("c2s requested when only s2s allowed");
|
|
||||||
}
|
|
||||||
config.c2s_target
|
|
||||||
} else {
|
|
||||||
if allowed_type == AllowedType::ClientOnly {
|
|
||||||
bail!("s2s requested when only c2s allowed");
|
|
||||||
}
|
|
||||||
config.s2s_target
|
|
||||||
};
|
|
||||||
|
|
||||||
println!("INFO: {} is_c2s: {}, target: {}", client_addr, is_c2s, target);
|
println!("INFO: {} is_c2s: {}, target: {}", client_addr, is_c2s, target);
|
||||||
|
|
||||||
|
35
src/quic.rs
35
src/quic.rs
@ -18,13 +18,8 @@ pub async fn quic_connect(target: SocketAddr, server_name: &str, is_c2s: bool) -
|
|||||||
let quinn::NewConnection { connection, .. } = endpoint.connect(&target, server_name).unwrap().await?;
|
let quinn::NewConnection { connection, .. } = endpoint.connect(&target, server_name).unwrap().await?;
|
||||||
debug!("[client] connected: addr={}", connection.remote_address());
|
debug!("[client] connected: addr={}", connection.remote_address());
|
||||||
|
|
||||||
if is_c2s {
|
let (wrt, rd) = connection.open_bi().await?;
|
||||||
let (wrt, rd) = connection.open_bi().await?;
|
Ok((Box::new(wrt), Box::new(rd)))
|
||||||
Ok((Box::new(wrt), Box::new(rd)))
|
|
||||||
} else {
|
|
||||||
let wrt = connection.open_uni().await?;
|
|
||||||
Ok((Box::new(wrt), Box::new(NoopIo)))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
@ -86,25 +81,13 @@ pub fn spawn_quic_listener(local_addr: SocketAddr, config: CloneableConfig, serv
|
|||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
println!("INFO: {} quic connected", client_addr);
|
println!("INFO: {} quic connected", client_addr);
|
||||||
|
|
||||||
loop {
|
while let Some(Ok((wrt, rd))) = new_conn.bi_streams.next().await {
|
||||||
tokio::select! {
|
let config = config.clone();
|
||||||
Some(Ok((wrt, rd))) = new_conn.bi_streams.next() => {
|
tokio::spawn(async move {
|
||||||
let config = config.clone();
|
if let Err(e) = shuffle_rd_wr(rd, wrt, config, local_addr, client_addr).await {
|
||||||
tokio::spawn(async move {
|
eprintln!("ERROR: {} {}", client_addr, e);
|
||||||
if let Err(e) = shuffle_rd_wr(rd, wrt, config, local_addr, client_addr, AllowedType::ClientOnly).await {
|
}
|
||||||
eprintln!("ERROR: {} {}", client_addr, e);
|
});
|
||||||
}
|
|
||||||
});
|
|
||||||
},
|
|
||||||
Some(Ok(rd)) = new_conn.uni_streams.next() => {
|
|
||||||
let config = config.clone();
|
|
||||||
tokio::spawn(async move {
|
|
||||||
if let Err(e) = shuffle_rd_wr(rd, NoopIo, config, local_addr, client_addr, AllowedType::ServerOnly).await {
|
|
||||||
eprintln!("ERROR: {} {}", client_addr, e);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
#[allow(unreachable_code)]
|
#[allow(unreachable_code)]
|
||||||
|
36
src/srv.rs
36
src/srv.rs
@ -170,27 +170,25 @@ pub async fn srv_connect(
|
|||||||
out_wr.flush().await?;
|
out_wr.flush().await?;
|
||||||
|
|
||||||
let mut server_response = Vec::new();
|
let mut server_response = Vec::new();
|
||||||
if is_c2s {
|
// let's read to first <stream:stream to make sure we are successfully connected to a real XMPP server
|
||||||
// let's read to first <stream:stream to make sure we are successfully connected to a real XMPP server
|
let mut stream_received = false;
|
||||||
let mut stream_received = false;
|
while let Ok(Some(buf)) = out_rd.next(&mut in_filter).await {
|
||||||
while let Ok(Some(buf)) = out_rd.next(&mut in_filter).await {
|
debug!("received pre-tls stanza: {} '{}'", domain, to_str(&buf));
|
||||||
debug!("received pre-tls stanza: {} '{}'", domain, to_str(&buf));
|
if buf.starts_with(b"<?xml ") {
|
||||||
if buf.starts_with(b"<?xml ") {
|
server_response.extend_from_slice(&buf);
|
||||||
server_response.extend_from_slice(&buf);
|
} else if buf.starts_with(b"<stream:stream ") {
|
||||||
} else if buf.starts_with(b"<stream:stream ") {
|
server_response.extend_from_slice(&buf);
|
||||||
server_response.extend_from_slice(&buf);
|
stream_received = true;
|
||||||
stream_received = true;
|
break;
|
||||||
break;
|
} else {
|
||||||
} else {
|
debug!("bad pre-tls stanza: {}", to_str(&buf));
|
||||||
debug!("bad pre-tls stanza: {}", to_str(&buf));
|
break;
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !stream_received {
|
|
||||||
debug!("bad server response, going to next record");
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if !stream_received {
|
||||||
|
debug!("bad server response, going to next record");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
return Ok((Box::new(out_wr), out_rd, server_response));
|
return Ok((Box::new(out_wr), out_rd, server_response));
|
||||||
}
|
}
|
||||||
|
@ -205,5 +205,5 @@ async fn handle_tls_connection(mut stream: tokio::net::TcpStream, client_addr: S
|
|||||||
|
|
||||||
let (in_rd, in_wr) = tokio::io::split(stream);
|
let (in_rd, in_wr) = tokio::io::split(stream);
|
||||||
|
|
||||||
shuffle_rd_wr_filter(in_rd, in_wr, config, local_addr, client_addr, AllowedType::Any, in_filter).await
|
shuffle_rd_wr_filter(in_rd, in_wr, config, local_addr, client_addr, in_filter).await
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user