From 9b2866c5b46b5238fe3c7d2e94dfd60c6c7d26a9 Mon Sep 17 00:00:00 2001 From: moparisthebest Date: Mon, 16 Dec 2019 11:52:17 -0500 Subject: [PATCH] Implement TLS server side, refactor to allow conditional TLS support at compile time --- src/bin/wireguard-proxy.rs | 28 +++++-- src/error.rs | 42 ++++++++++ src/lib.rs | 160 +++++++++++++++++++------------------ src/notls.rs | 49 ++++++++++++ src/openssl.rs | 102 +++++++++++++++++++++++ test.sh | 64 +++++++++++---- 6 files changed, 347 insertions(+), 98 deletions(-) create mode 100644 src/error.rs create mode 100644 src/notls.rs create mode 100644 src/openssl.rs diff --git a/src/bin/wireguard-proxy.rs b/src/bin/wireguard-proxy.rs index d070124..cb27c57 100644 --- a/src/bin/wireguard-proxy.rs +++ b/src/bin/wireguard-proxy.rs @@ -36,6 +36,10 @@ fn main() { listen on for UDP packets to send back over the TCP connection, default: 127.0.0.1:30000-40000 + -tk, --tls-key TLS key to listen with, + requires --tls-cert also + -tc, --tls-cert TLS cert to listen with, + requires --tls-key also Common Options: -h, --help print this usage text @@ -61,15 +65,18 @@ fn client(tcp_target: &str, socket_timeout: u64, args: Args) { socket_timeout, ); + let tls = args.flag("--tls"); + println!( - "udp_host: {}, tcp_target: {}, socket_timeout: {:?}", + "udp_host: {}, tcp_target: {}, socket_timeout: {:?}, tls: {}", proxy_client.udp_host, proxy_client.tcp_target, proxy_client.socket_timeout, + tls, ); - if args.flag("--tls") { - proxy_client.start_tls().expect("error running tls proxy_client"); + if tls { + proxy_client.start_tls(tcp_target.split(":").next().expect("cannot extract hostname from --tcp-target")).expect("error running tls proxy_client"); } else { proxy_client.start().expect("error running proxy_client"); } @@ -107,12 +114,23 @@ fn server(tcp_host: &str, socket_timeout: u64, args: Args) { socket_timeout, ); + let tls_key = args.get_option(&["-tk", "--tls-key"]); + let tls_cert = args.get_option(&["-tc", "--tls-cert"]); + println!( - "udp_target: {}, udp_bind_host_range: {}, socket_timeout: {:?}", + "udp_target: {}, udp_bind_host_range: {}, socket_timeout: {:?}, tls_key: {:?}, tls_cert: {:?}", proxy_server.client_handler.udp_target, udp_bind_host_range_str, proxy_server.client_handler.socket_timeout, + tls_key, + tls_cert, ); - proxy_server.start().expect("error running proxy_server"); + if tls_key.is_some() && tls_cert.is_some() { + proxy_server.start_tls(tls_key.unwrap(), tls_cert.unwrap()).expect("error running TLS proxy_server"); + } else if tls_key.is_none() && tls_cert.is_none() { + proxy_server.start().expect("error running proxy_server"); + } else { + println!("Error: if one of --tls-key or --tls-cert is specified both must be!"); + } } diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..8c97de8 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,42 @@ +use core::result; + +use std::error::Error as StdError; + +pub type IoResult = result::Result; + +pub type Result = result::Result; + +#[derive(Debug)] +pub struct Error(String); + +impl Error { + pub fn new(msg: &str) -> Error { + Error(msg.to_owned()) + } +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.write_str(&self.0) + } +} + +impl std::error::Error for Error { + fn description(&self) -> &str { + &self.0 + } +} + +impl From for Error { + fn from(value: std::io::Error) -> Self { + Error::new(value.description()) + } +} + +/* +impl From for Error { + fn from(value: std::option::NoneError) -> Self { + Error::new(value.description()) + } +} +*/ \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index c6ebef7..86d8ae3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,8 +4,25 @@ use std::str::FromStr; use std::sync::Arc; use std::thread; use std::time::Duration; -use std::cell::{UnsafeCell}; -use openssl::ssl::{SslConnector, SslMethod, SslStream, SslVerifyMode}; + +mod error; +use error::Result; + +#[cfg(feature = "tls")] +#[path = ""] +mod tls { + pub mod openssl; + pub use crate::tls::openssl::{TlsStream, TlsListener}; +} + +#[cfg(not(feature = "tls"))] +#[path = ""] +mod tls { + pub mod notls; + pub use crate::tls::notls::{TlsStream, TlsListener}; +} + +use tls::{TlsStream, TlsListener}; pub struct Args<'a> { args: &'a Vec, @@ -70,14 +87,14 @@ impl + Send + 'static> TcpUdpPipe { } } - pub fn try_clone(&self) -> std::io::Result> { + pub fn try_clone(&self) -> Result> { Ok(TcpUdpPipe::new( self.tcp_stream.try_clone()?, self.udp_socket.try_clone()?, )) } - pub fn shuffle_after_first_udp(&mut self) -> std::io::Result { + pub fn shuffle_after_first_udp(&mut self) -> Result { let (len, src_addr) = self.udp_socket.recv_from(&mut self.buf[2..])?; println!("first packet from {}, connecting to that", src_addr); @@ -88,12 +105,12 @@ impl + Send + 'static> TcpUdpPipe { self.shuffle() } - pub fn udp_to_tcp(&mut self) -> std::io::Result<()> { + pub fn udp_to_tcp(&mut self) -> Result<()> { let len = self.udp_socket.recv(&mut self.buf[2..])?; self.send_udp(len) } - fn send_udp(&mut self, len: usize) -> std::io::Result<()> { + fn send_udp(&mut self, len: usize) -> Result<()> { println!("udp got len: {}", len); self.buf[0] = ((len >> 8) & 0xFF) as u8; @@ -102,23 +119,23 @@ impl + Send + 'static> TcpUdpPipe { //let test_len = ((self.buf[0] as usize) << 8) + self.buf[1] as usize; //println!("tcp sending test_len: {}", test_len); - self.tcp_stream.write_all(&self.buf[..len + 2]) + Ok(self.tcp_stream.write_all(&self.buf[..len + 2])?) // todo: do this? self.tcp_stream.flush() } - pub fn tcp_to_udp(&mut self) -> std::io::Result { + pub fn tcp_to_udp(&mut self) -> Result { self.tcp_stream.read_exact(&mut self.buf[..2])?; let len = ((self.buf[0] as usize) << 8) + self.buf[1] as usize; println!("tcp expecting len: {}", len); self.tcp_stream.read_exact(&mut self.buf[..len])?; println!("tcp got len: {}", len); - self.udp_socket.send(&self.buf[..len]) + Ok(self.udp_socket.send(&self.buf[..len])?) //let sent = udp_socket.send_to(&buf[..len], &self.udp_target)?; //assert_eq!(sent, len); } - pub fn shuffle(&mut self) -> std::io::Result { + pub fn shuffle(&mut self) -> Result { let mut udp_pipe_clone = self.try_clone()?; thread::spawn(move || loop { udp_pipe_clone @@ -150,19 +167,19 @@ impl ProxyClient { } } - fn tcp_connect(&self) -> std::io::Result { + fn tcp_connect(&self) -> Result { let tcp_stream = TcpStream::connect(&self.tcp_target)?; tcp_stream.set_read_timeout(self.socket_timeout)?; Ok(tcp_stream) } - fn udp_connect(&self) -> std::io::Result { + fn udp_connect(&self) -> Result { let udp_socket = UdpSocket::bind(&self.udp_host)?; udp_socket.set_read_timeout(self.socket_timeout)?; Ok(udp_socket) } - pub fn start(&self) -> std::io::Result { + pub fn start(&self) -> Result { let tcp_stream = self.tcp_connect()?; let udp_socket = self.udp_connect()?; @@ -171,14 +188,10 @@ impl ProxyClient { TcpUdpPipe::new(tcp_stream, udp_socket).shuffle_after_first_udp() } - pub fn start_tls(&self) -> std::io::Result { + pub fn start_tls(&self, hostname: &str) -> Result { let tcp_stream = self.tcp_connect()?; - let mut connector = SslConnector::builder(SslMethod::tls()).unwrap().build().configure().unwrap(); - connector.set_verify_hostname(false); - connector.set_verify(SslVerifyMode::NONE); - let tcp_stream = connector.connect(self.tcp_target.split(":").next().unwrap(), tcp_stream).unwrap(); - let tcp_stream = OpensslCell { sess: Arc::new(UnsafeCell::new(tcp_stream)) }; + let tcp_stream = TlsStream::client(hostname, tcp_stream)?; let udp_socket = self.udp_connect()?; @@ -189,68 +202,18 @@ impl ProxyClient { pub trait TryClone { - fn try_clone(&self) -> std::io::Result; + fn try_clone(&self) -> Result; } impl TryClone for UdpSocket { - fn try_clone(&self) -> std::io::Result { - self.try_clone() + fn try_clone(&self) -> Result { + Ok(self.try_clone()?) } } impl TryClone for TcpStream { - fn try_clone(&self) -> std::io::Result { - self.try_clone() - } -} - -impl TryClone for OpensslCell { - fn try_clone(&self) -> std::io::Result { - Ok(self.clone()) - } -} - -pub struct OpensslCell { - sess: Arc>>, -} - -unsafe impl Sync for OpensslCell {} -unsafe impl Send for OpensslCell {} - -impl Clone for OpensslCell { - fn clone(&self) -> Self { - OpensslCell { - sess: self.sess.clone(), - } - } -} - -impl OpensslCell { - pub fn borrow(&self) -> &SslStream { - unsafe { - &*self.sess.get() - } - } - pub fn borrow_mut(&self) -> &mut SslStream { - unsafe { - &mut *self.sess.get() - } - } -} - -impl Read for OpensslCell { - fn read(&mut self, buf: &mut [u8]) -> Result { - self.borrow_mut().read(buf) - } -} - -impl Write for OpensslCell { - fn write(&mut self, buf: &[u8]) -> Result { - self.borrow_mut().write(buf) - } - - fn flush(&mut self) -> Result<(), std::io::Error> { - self.borrow_mut().flush() + fn try_clone(&self) -> Result { + Ok(self.try_clone()?) } } @@ -284,7 +247,7 @@ impl ProxyServer { } } - pub fn start(&self) -> std::io::Result<()> { + pub fn start(&self) -> Result<()> { let listener = TcpListener::bind(&self.tcp_host)?; println!("Listening for connections on {}", &self.tcp_host); @@ -292,6 +255,8 @@ impl ProxyServer { match stream { Ok(stream) => { let client_handler = self.client_handler.clone(); + client_handler.set_tcp_options(&stream).expect("cannot set tcp options"); + thread::spawn(move || { client_handler .handle_client(stream) @@ -305,6 +270,34 @@ impl ProxyServer { } Ok(()) } + + pub fn start_tls(&self, tls_key: &str, tls_cert: &str) -> Result<()> { + let tls_listener = Arc::new(TlsListener::new(tls_key, tls_cert)?); + + let listener = TcpListener::bind(&self.tcp_host)?; + println!("Listening for TLS connections on {}", &self.tcp_host); + + for stream in listener.incoming() { + match stream { + Ok(stream) => { + let client_handler = self.client_handler.clone(); + client_handler.set_tcp_options(&stream).expect("cannot set tcp options"); + + let tls_listener = tls_listener.clone(); + thread::spawn(move || { + let stream = tls_listener.wrap(stream).expect("cannot wrap with tls"); + client_handler + .handle_client_tls(stream) + .expect("error handling connection") + }); + } + Err(e) => { + println!("Unable to connect: {}", e); + } + } + } + Ok(()) + } } pub struct ProxyServerClientHandler { @@ -316,9 +309,7 @@ pub struct ProxyServerClientHandler { } impl ProxyServerClientHandler { - pub fn handle_client(&self, tcp_stream: TcpStream) -> std::io::Result { - tcp_stream.set_read_timeout(self.socket_timeout)?; - + fn udp_bind(&self) -> Result { let mut port = self.udp_low_port; let udp_socket = loop { match UdpSocket::bind((&self.udp_host[..], port)) { @@ -333,7 +324,18 @@ impl ProxyServerClientHandler { }; udp_socket.set_read_timeout(self.socket_timeout)?; udp_socket.connect(&self.udp_target)?; + Ok(udp_socket) + } - TcpUdpPipe::new(tcp_stream, udp_socket).shuffle() + pub fn set_tcp_options(&self, tcp_stream: &TcpStream) -> Result<()> { + Ok(tcp_stream.set_read_timeout(self.socket_timeout)?) + } + + pub fn handle_client(&self, tcp_stream: TcpStream) -> Result { + TcpUdpPipe::new(tcp_stream, self.udp_bind()?).shuffle() + } + + pub fn handle_client_tls(&self, tcp_stream: TlsStream) -> Result { + TcpUdpPipe::new(tcp_stream, self.udp_bind()?).shuffle() } } diff --git a/src/notls.rs b/src/notls.rs new file mode 100644 index 0000000..76bec44 --- /dev/null +++ b/src/notls.rs @@ -0,0 +1,49 @@ +use std::net::TcpStream; +use crate::TryClone; +use std::io::{Read, Write}; +use crate::error::*; + +fn err() -> Error { + Error::new("Error: compiled without TLS support") +} + +pub struct TlsStream; + +impl TlsStream { + pub fn client(_host_name: &str, _tcp_stream: TcpStream) -> Result { + Err(err()) + } +} + +impl TryClone for TlsStream { + fn try_clone(&self) -> Result { + Err(err()) + } +} + +impl Read for TlsStream { + fn read(&mut self, _buf: &mut [u8]) -> IoResult { + unimplemented!() + } +} + +impl Write for TlsStream { + fn write(&mut self, _buf: &[u8]) -> IoResult { + unimplemented!() + } + + fn flush(&mut self) -> IoResult<()> { + unimplemented!() + } +} + +pub struct TlsListener; + +impl TlsListener { + pub fn new(_tls_key: &str, _tls_cert: &str) -> Result { + Err(err()) + } + pub fn wrap(&self, _tcp_stream: TcpStream) -> Result { + Err(err()) + } +} \ No newline at end of file diff --git a/src/openssl.rs b/src/openssl.rs new file mode 100644 index 0000000..69a9d25 --- /dev/null +++ b/src/openssl.rs @@ -0,0 +1,102 @@ + +use openssl::ssl::{SslConnector, SslMethod, SslStream, SslVerifyMode, SslAcceptor, SslFiletype, HandshakeError}; +use std::sync::Arc; +use std::cell::UnsafeCell; +use std::net::TcpStream; +use crate::TryClone; +use std::io::{Read, Write}; + +use crate::error::*; +use std::error::Error as StdError; + +impl TryClone for TlsStream { + fn try_clone(&self) -> Result { + Ok(self.clone()) + } +} + +pub struct TlsStream { + sess: Arc>>, +} + +impl TlsStream { + fn new(stream: SslStream) -> TlsStream { + TlsStream { + sess: Arc::new(UnsafeCell::new(stream)) + } + } + pub fn client(host_name: &str, tcp_stream: TcpStream) -> Result { + let mut connector = SslConnector::builder(SslMethod::tls())?.build().configure()?; + connector.set_verify_hostname(false); + connector.set_verify(SslVerifyMode::NONE); + let tcp_stream = connector.connect(host_name, tcp_stream)?; + Ok(TlsStream::new(tcp_stream)) + } +} + +unsafe impl Sync for TlsStream {} +unsafe impl Send for TlsStream {} + +impl Clone for TlsStream { + fn clone(&self) -> Self { + TlsStream { + sess: self.sess.clone(), + } + } +} + +impl TlsStream { + pub fn borrow_mut(&self) -> &mut SslStream { + unsafe { + &mut *self.sess.get() + } + } +} + +impl Read for TlsStream { + fn read(&mut self, buf: &mut [u8]) -> IoResult { + self.borrow_mut().read(buf) + } +} + +impl Write for TlsStream { + fn write(&mut self, buf: &[u8]) -> IoResult { + self.borrow_mut().write(buf) + } + + fn flush(&mut self) -> IoResult<()> { + self.borrow_mut().flush() + } +} + +pub struct TlsListener { + acceptor: SslAcceptor, +} + +impl TlsListener { + pub fn new(tls_key: &str, tls_cert: &str) -> Result { + let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls())?; + acceptor.set_private_key_file(tls_key, SslFiletype::PEM)?; + acceptor.set_certificate_chain_file(tls_cert)?; + acceptor.check_private_key()?; + let acceptor = acceptor.build(); + Ok(TlsListener { + acceptor + }) + } + pub fn wrap(&self, tcp_stream: TcpStream) -> Result { + Ok(TlsStream::new(self.acceptor.accept(tcp_stream)?)) + } +} + +impl From for Error { + fn from(value: openssl::error::ErrorStack) -> Self { + Error::new(value.description()) + } +} + +impl From> for Error { + fn from(value: HandshakeError) -> Self { + Error::new(value.description()) + } +} diff --git a/test.sh b/test.sh index d3904c2..c7473d9 100755 --- a/test.sh +++ b/test.sh @@ -1,33 +1,67 @@ #!/bin/sh set -x -# always run this clean +# first run without TLS cargo clean +cargo build --release --no-default-features + +export PATH="$(pwd)/target/release:$PATH" # first make sure udp-test succeeds running against itself -cargo run --release --bin udp-test || exit 1 +udp-test || exit 1 # now run udp-test without spawning other processes -cargo run --release --bin udp-test -- -is || exit 1 +udp-test -is || exit 1 # now run proxyd pointing to udp-test -cargo run --release --bin wireguard-proxy -- -th 127.0.0.1:5555 -ut 127.0.0.1:51822 & +wireguard-proxy -th 127.0.0.1:5555 -ut 127.0.0.1:51822 & proxyd_pid=$! # wait for ports to be set up, this is fragile... -sleep 1 +sleep 5 # proxy pointing to proxyd -#cargo run --release --bin wireguard-proxy -- -tt 127.0.0.1:5555 & - -echo -e '\n\n\n\n\n\n\n' | openssl req -new -x509 -days 365 -nodes -out cert.pem -keyout cert.key -socat OPENSSL-LISTEN:5554,bind=127.0.0.1,cert=./cert.pem,key=./cert.key,verify=0 tcp4-connect:127.0.0.1:5555 & - -cargo run --release --bin wireguard-proxy -- -tt 127.0.0.1:5554 --tls & - +wireguard-proxy -tt 127.0.0.1:5555 & proxy_pid=$! # wait for ports to be set up, this is fragile... sleep 1 # and udp-test pointing to proxy, which then hops to proxyd, and finally back to udp-test -cargo run --release --bin udp-test -- -uh 127.0.0.1:51822 +udp-test -uh 127.0.0.1:51822 +udp_exit=$? + +kill $proxyd_pid $proxy_pid + +[ $udp_exit -ne 0 ] && exit $udp_exit + +# now run udp-test essentially just like the script above, but all in rust +udp-test -s || exit 1 + +echo "non-tls tests passed!" + +echo -e '\n\n\n\n\n\n\n' | openssl req -new -x509 -days 365 -nodes -out cert.pem -keyout cert.key + +# first run without TLS +cargo clean +cargo build --release + +export PATH="$(pwd)/target/release:$PATH" + +# first make sure udp-test succeeds running against itself +udp-test || exit 1 + +# now run udp-test without spawning other processes +udp-test -is || exit 1 + +# now run proxyd pointing to udp-test +wireguard-proxy -th 127.0.0.1:5555 -ut 127.0.0.1:51822 --tls-key cert.key --tls-cert cert.pem & +proxyd_pid=$! +# wait for ports to be set up, this is fragile... +sleep 5 +# proxy pointing to proxyd +wireguard-proxy -tt 127.0.0.1:5555 --tls & +proxy_pid=$! +# wait for ports to be set up, this is fragile... +sleep 1 +# and udp-test pointing to proxy, which then hops to proxyd, and finally back to udp-test +udp-test -uh 127.0.0.1:51822 udp_exit=$? kill $proxyd_pid $proxy_pid @@ -37,4 +71,6 @@ rm -f cert.pem cert.key [ $udp_exit -ne 0 ] && exit $udp_exit # now run udp-test essentially just like the script above, but all in rust -cargo run --release --bin udp-test -- -s +udp-test -s + +exit $?