diff --git a/src/main.rs b/src/main.rs index 29a2f48..9b2fadc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -131,23 +131,9 @@ impl Config { } } -#[derive(PartialEq)] -pub enum AllowedType { - ClientOnly, - ServerOnly, - Any, -} - -async fn shuffle_rd_wr( - in_rd: R, - in_wr: W, - config: CloneableConfig, - local_addr: SocketAddr, - client_addr: SocketAddr, - allowed_type: AllowedType, -) -> Result<()> { +async fn shuffle_rd_wr(in_rd: R, in_wr: W, config: CloneableConfig, local_addr: SocketAddr, client_addr: SocketAddr) -> Result<()> { let filter = StanzaFilter::new(config.max_stanza_size_bytes); - shuffle_rd_wr_filter(in_rd, in_wr, config, local_addr, client_addr, allowed_type, filter).await + shuffle_rd_wr_filter(in_rd, in_wr, config, local_addr, client_addr, filter).await } async fn shuffle_rd_wr_filter( @@ -156,7 +142,6 @@ async fn shuffle_rd_wr_filter( config: CloneableConfig, local_addr: SocketAddr, client_addr: SocketAddr, - allowed_type: AllowedType, in_filter: StanzaFilter, ) -> Result<()> { // we naively read 1 byte at a time, which buffering significantly speeds up @@ -165,17 +150,7 @@ async fn shuffle_rd_wr_filter( // now read to figure out client vs server let (stream_open, is_c2s, mut in_rd, mut in_filter) = stream_preamble(StanzaReader(in_rd), client_addr, in_filter).await?; - let target = if is_c2s { - if allowed_type == AllowedType::ServerOnly { - bail!("c2s requested when only s2s allowed"); - } - config.c2s_target - } else { - if allowed_type == AllowedType::ClientOnly { - bail!("s2s requested when only c2s allowed"); - } - config.s2s_target - }; + let target = if is_c2s { config.c2s_target } else { config.s2s_target }; println!("INFO: {} is_c2s: {}, target: {}", client_addr, is_c2s, target); diff --git a/src/quic.rs b/src/quic.rs index 41c2de1..ccac99d 100644 --- a/src/quic.rs +++ b/src/quic.rs @@ -18,13 +18,8 @@ pub async fn quic_connect(target: SocketAddr, server_name: &str, is_c2s: bool) - let quinn::NewConnection { connection, .. } = endpoint.connect(&target, server_name).unwrap().await?; debug!("[client] connected: addr={}", connection.remote_address()); - if is_c2s { - let (wrt, rd) = connection.open_bi().await?; - Ok((Box::new(wrt), Box::new(rd))) - } else { - let wrt = connection.open_uni().await?; - Ok((Box::new(wrt), Box::new(NoopIo))) - } + let (wrt, rd) = connection.open_bi().await?; + Ok((Box::new(wrt), Box::new(rd))) } impl Config { @@ -86,25 +81,13 @@ pub fn spawn_quic_listener(local_addr: SocketAddr, config: CloneableConfig, serv tokio::spawn(async move { println!("INFO: {} quic connected", client_addr); - loop { - tokio::select! { - Some(Ok((wrt, rd))) = new_conn.bi_streams.next() => { - let config = config.clone(); - tokio::spawn(async move { - if let Err(e) = shuffle_rd_wr(rd, wrt, config, local_addr, client_addr, AllowedType::ClientOnly).await { - eprintln!("ERROR: {} {}", client_addr, e); - } - }); - }, - Some(Ok(rd)) = new_conn.uni_streams.next() => { - let config = config.clone(); - tokio::spawn(async move { - if let Err(e) = shuffle_rd_wr(rd, NoopIo, config, local_addr, client_addr, AllowedType::ServerOnly).await { - eprintln!("ERROR: {} {}", client_addr, e); - } - }); - }, - } + while let Some(Ok((wrt, rd))) = new_conn.bi_streams.next().await { + let config = config.clone(); + tokio::spawn(async move { + if let Err(e) = shuffle_rd_wr(rd, wrt, config, local_addr, client_addr).await { + eprintln!("ERROR: {} {}", client_addr, e); + } + }); } }); #[allow(unreachable_code)] diff --git a/src/srv.rs b/src/srv.rs index b3ebc76..a75e543 100644 --- a/src/srv.rs +++ b/src/srv.rs @@ -170,27 +170,25 @@ pub async fn srv_connect( out_wr.flush().await?; let mut server_response = Vec::new(); - if is_c2s { - // let's read to first