Implement TLS server side, refactor to allow conditional TLS support at compile time

This commit is contained in:
Travis Burtrum 2019-12-16 11:52:17 -05:00
parent 4c6f4258b4
commit 9b2866c5b4
6 changed files with 347 additions and 98 deletions

View File

@ -36,6 +36,10 @@ fn main() {
listen on for UDP packets to send listen on for UDP packets to send
back over the TCP connection, back over the TCP connection,
default: 127.0.0.1:30000-40000 default: 127.0.0.1:30000-40000
-tk, --tls-key <ip:port> TLS key to listen with,
requires --tls-cert also
-tc, --tls-cert <ip:port> TLS cert to listen with,
requires --tls-key also
Common Options: Common Options:
-h, --help print this usage text -h, --help print this usage text
@ -61,15 +65,18 @@ fn client(tcp_target: &str, socket_timeout: u64, args: Args) {
socket_timeout, socket_timeout,
); );
let tls = args.flag("--tls");
println!( println!(
"udp_host: {}, tcp_target: {}, socket_timeout: {:?}", "udp_host: {}, tcp_target: {}, socket_timeout: {:?}, tls: {}",
proxy_client.udp_host, proxy_client.udp_host,
proxy_client.tcp_target, proxy_client.tcp_target,
proxy_client.socket_timeout, proxy_client.socket_timeout,
tls,
); );
if args.flag("--tls") { if tls {
proxy_client.start_tls().expect("error running tls proxy_client"); proxy_client.start_tls(tcp_target.split(":").next().expect("cannot extract hostname from --tcp-target")).expect("error running tls proxy_client");
} else { } else {
proxy_client.start().expect("error running proxy_client"); proxy_client.start().expect("error running proxy_client");
} }
@ -107,12 +114,23 @@ fn server(tcp_host: &str, socket_timeout: u64, args: Args) {
socket_timeout, socket_timeout,
); );
let tls_key = args.get_option(&["-tk", "--tls-key"]);
let tls_cert = args.get_option(&["-tc", "--tls-cert"]);
println!( println!(
"udp_target: {}, udp_bind_host_range: {}, socket_timeout: {:?}", "udp_target: {}, udp_bind_host_range: {}, socket_timeout: {:?}, tls_key: {:?}, tls_cert: {:?}",
proxy_server.client_handler.udp_target, proxy_server.client_handler.udp_target,
udp_bind_host_range_str, udp_bind_host_range_str,
proxy_server.client_handler.socket_timeout, proxy_server.client_handler.socket_timeout,
tls_key,
tls_cert,
); );
if tls_key.is_some() && tls_cert.is_some() {
proxy_server.start_tls(tls_key.unwrap(), tls_cert.unwrap()).expect("error running TLS proxy_server");
} else if tls_key.is_none() && tls_cert.is_none() {
proxy_server.start().expect("error running proxy_server"); proxy_server.start().expect("error running proxy_server");
} else {
println!("Error: if one of --tls-key or --tls-cert is specified both must be!");
}
} }

42
src/error.rs Normal file
View File

@ -0,0 +1,42 @@
use core::result;
use std::error::Error as StdError;
pub type IoResult<T> = result::Result<T, std::io::Error>;
pub type Result<T> = result::Result<T, Error>;
#[derive(Debug)]
pub struct Error(String);
impl Error {
pub fn new(msg: &str) -> Error {
Error(msg.to_owned())
}
}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.write_str(&self.0)
}
}
impl std::error::Error for Error {
fn description(&self) -> &str {
&self.0
}
}
impl From<std::io::Error> for Error {
fn from(value: std::io::Error) -> Self {
Error::new(value.description())
}
}
/*
impl From<std::option::NoneError> for Error {
fn from(value: std::option::NoneError) -> Self {
Error::new(value.description())
}
}
*/

View File

@ -4,8 +4,25 @@ use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
use std::thread; use std::thread;
use std::time::Duration; use std::time::Duration;
use std::cell::{UnsafeCell};
use openssl::ssl::{SslConnector, SslMethod, SslStream, SslVerifyMode}; mod error;
use error::Result;
#[cfg(feature = "tls")]
#[path = ""]
mod tls {
pub mod openssl;
pub use crate::tls::openssl::{TlsStream, TlsListener};
}
#[cfg(not(feature = "tls"))]
#[path = ""]
mod tls {
pub mod notls;
pub use crate::tls::notls::{TlsStream, TlsListener};
}
use tls::{TlsStream, TlsListener};
pub struct Args<'a> { pub struct Args<'a> {
args: &'a Vec<String>, args: &'a Vec<String>,
@ -70,14 +87,14 @@ impl<T: Write + Read + TryClone<T> + Send + 'static> TcpUdpPipe<T> {
} }
} }
pub fn try_clone(&self) -> std::io::Result<TcpUdpPipe<T>> { pub fn try_clone(&self) -> Result<TcpUdpPipe<T>> {
Ok(TcpUdpPipe::new( Ok(TcpUdpPipe::new(
self.tcp_stream.try_clone()?, self.tcp_stream.try_clone()?,
self.udp_socket.try_clone()?, self.udp_socket.try_clone()?,
)) ))
} }
pub fn shuffle_after_first_udp(&mut self) -> std::io::Result<usize> { pub fn shuffle_after_first_udp(&mut self) -> Result<usize> {
let (len, src_addr) = self.udp_socket.recv_from(&mut self.buf[2..])?; let (len, src_addr) = self.udp_socket.recv_from(&mut self.buf[2..])?;
println!("first packet from {}, connecting to that", src_addr); println!("first packet from {}, connecting to that", src_addr);
@ -88,12 +105,12 @@ impl<T: Write + Read + TryClone<T> + Send + 'static> TcpUdpPipe<T> {
self.shuffle() self.shuffle()
} }
pub fn udp_to_tcp(&mut self) -> std::io::Result<()> { pub fn udp_to_tcp(&mut self) -> Result<()> {
let len = self.udp_socket.recv(&mut self.buf[2..])?; let len = self.udp_socket.recv(&mut self.buf[2..])?;
self.send_udp(len) self.send_udp(len)
} }
fn send_udp(&mut self, len: usize) -> std::io::Result<()> { fn send_udp(&mut self, len: usize) -> Result<()> {
println!("udp got len: {}", len); println!("udp got len: {}", len);
self.buf[0] = ((len >> 8) & 0xFF) as u8; self.buf[0] = ((len >> 8) & 0xFF) as u8;
@ -102,23 +119,23 @@ impl<T: Write + Read + TryClone<T> + Send + 'static> TcpUdpPipe<T> {
//let test_len = ((self.buf[0] as usize) << 8) + self.buf[1] as usize; //let test_len = ((self.buf[0] as usize) << 8) + self.buf[1] as usize;
//println!("tcp sending test_len: {}", test_len); //println!("tcp sending test_len: {}", test_len);
self.tcp_stream.write_all(&self.buf[..len + 2]) Ok(self.tcp_stream.write_all(&self.buf[..len + 2])?)
// todo: do this? self.tcp_stream.flush() // todo: do this? self.tcp_stream.flush()
} }
pub fn tcp_to_udp(&mut self) -> std::io::Result<usize> { pub fn tcp_to_udp(&mut self) -> Result<usize> {
self.tcp_stream.read_exact(&mut self.buf[..2])?; self.tcp_stream.read_exact(&mut self.buf[..2])?;
let len = ((self.buf[0] as usize) << 8) + self.buf[1] as usize; let len = ((self.buf[0] as usize) << 8) + self.buf[1] as usize;
println!("tcp expecting len: {}", len); println!("tcp expecting len: {}", len);
self.tcp_stream.read_exact(&mut self.buf[..len])?; self.tcp_stream.read_exact(&mut self.buf[..len])?;
println!("tcp got len: {}", len); println!("tcp got len: {}", len);
self.udp_socket.send(&self.buf[..len]) Ok(self.udp_socket.send(&self.buf[..len])?)
//let sent = udp_socket.send_to(&buf[..len], &self.udp_target)?; //let sent = udp_socket.send_to(&buf[..len], &self.udp_target)?;
//assert_eq!(sent, len); //assert_eq!(sent, len);
} }
pub fn shuffle(&mut self) -> std::io::Result<usize> { pub fn shuffle(&mut self) -> Result<usize> {
let mut udp_pipe_clone = self.try_clone()?; let mut udp_pipe_clone = self.try_clone()?;
thread::spawn(move || loop { thread::spawn(move || loop {
udp_pipe_clone udp_pipe_clone
@ -150,19 +167,19 @@ impl ProxyClient {
} }
} }
fn tcp_connect(&self) -> std::io::Result<TcpStream> { fn tcp_connect(&self) -> Result<TcpStream> {
let tcp_stream = TcpStream::connect(&self.tcp_target)?; let tcp_stream = TcpStream::connect(&self.tcp_target)?;
tcp_stream.set_read_timeout(self.socket_timeout)?; tcp_stream.set_read_timeout(self.socket_timeout)?;
Ok(tcp_stream) Ok(tcp_stream)
} }
fn udp_connect(&self) -> std::io::Result<UdpSocket> { fn udp_connect(&self) -> Result<UdpSocket> {
let udp_socket = UdpSocket::bind(&self.udp_host)?; let udp_socket = UdpSocket::bind(&self.udp_host)?;
udp_socket.set_read_timeout(self.socket_timeout)?; udp_socket.set_read_timeout(self.socket_timeout)?;
Ok(udp_socket) Ok(udp_socket)
} }
pub fn start(&self) -> std::io::Result<usize> { pub fn start(&self) -> Result<usize> {
let tcp_stream = self.tcp_connect()?; let tcp_stream = self.tcp_connect()?;
let udp_socket = self.udp_connect()?; let udp_socket = self.udp_connect()?;
@ -171,14 +188,10 @@ impl ProxyClient {
TcpUdpPipe::new(tcp_stream, udp_socket).shuffle_after_first_udp() TcpUdpPipe::new(tcp_stream, udp_socket).shuffle_after_first_udp()
} }
pub fn start_tls(&self) -> std::io::Result<usize> { pub fn start_tls(&self, hostname: &str) -> Result<usize> {
let tcp_stream = self.tcp_connect()?; let tcp_stream = self.tcp_connect()?;
let mut connector = SslConnector::builder(SslMethod::tls()).unwrap().build().configure().unwrap(); let tcp_stream = TlsStream::client(hostname, tcp_stream)?;
connector.set_verify_hostname(false);
connector.set_verify(SslVerifyMode::NONE);
let tcp_stream = connector.connect(self.tcp_target.split(":").next().unwrap(), tcp_stream).unwrap();
let tcp_stream = OpensslCell { sess: Arc::new(UnsafeCell::new(tcp_stream)) };
let udp_socket = self.udp_connect()?; let udp_socket = self.udp_connect()?;
@ -189,68 +202,18 @@ impl ProxyClient {
pub trait TryClone<T> { pub trait TryClone<T> {
fn try_clone(&self) -> std::io::Result<T>; fn try_clone(&self) -> Result<T>;
} }
impl TryClone<UdpSocket> for UdpSocket { impl TryClone<UdpSocket> for UdpSocket {
fn try_clone(&self) -> std::io::Result<UdpSocket> { fn try_clone(&self) -> Result<UdpSocket> {
self.try_clone() Ok(self.try_clone()?)
} }
} }
impl TryClone<TcpStream> for TcpStream { impl TryClone<TcpStream> for TcpStream {
fn try_clone(&self) -> std::io::Result<TcpStream> { fn try_clone(&self) -> Result<TcpStream> {
self.try_clone() Ok(self.try_clone()?)
}
}
impl TryClone<OpensslCell> for OpensslCell {
fn try_clone(&self) -> std::io::Result<OpensslCell> {
Ok(self.clone())
}
}
pub struct OpensslCell {
sess: Arc<UnsafeCell<SslStream<TcpStream>>>,
}
unsafe impl Sync for OpensslCell {}
unsafe impl Send for OpensslCell {}
impl Clone for OpensslCell {
fn clone(&self) -> Self {
OpensslCell {
sess: self.sess.clone(),
}
}
}
impl OpensslCell {
pub fn borrow(&self) -> &SslStream<TcpStream> {
unsafe {
&*self.sess.get()
}
}
pub fn borrow_mut(&self) -> &mut SslStream<TcpStream> {
unsafe {
&mut *self.sess.get()
}
}
}
impl Read for OpensslCell {
fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
self.borrow_mut().read(buf)
}
}
impl Write for OpensslCell {
fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> {
self.borrow_mut().write(buf)
}
fn flush(&mut self) -> Result<(), std::io::Error> {
self.borrow_mut().flush()
} }
} }
@ -284,7 +247,7 @@ impl ProxyServer {
} }
} }
pub fn start(&self) -> std::io::Result<()> { pub fn start(&self) -> Result<()> {
let listener = TcpListener::bind(&self.tcp_host)?; let listener = TcpListener::bind(&self.tcp_host)?;
println!("Listening for connections on {}", &self.tcp_host); println!("Listening for connections on {}", &self.tcp_host);
@ -292,6 +255,8 @@ impl ProxyServer {
match stream { match stream {
Ok(stream) => { Ok(stream) => {
let client_handler = self.client_handler.clone(); let client_handler = self.client_handler.clone();
client_handler.set_tcp_options(&stream).expect("cannot set tcp options");
thread::spawn(move || { thread::spawn(move || {
client_handler client_handler
.handle_client(stream) .handle_client(stream)
@ -305,6 +270,34 @@ impl ProxyServer {
} }
Ok(()) Ok(())
} }
pub fn start_tls(&self, tls_key: &str, tls_cert: &str) -> Result<()> {
let tls_listener = Arc::new(TlsListener::new(tls_key, tls_cert)?);
let listener = TcpListener::bind(&self.tcp_host)?;
println!("Listening for TLS connections on {}", &self.tcp_host);
for stream in listener.incoming() {
match stream {
Ok(stream) => {
let client_handler = self.client_handler.clone();
client_handler.set_tcp_options(&stream).expect("cannot set tcp options");
let tls_listener = tls_listener.clone();
thread::spawn(move || {
let stream = tls_listener.wrap(stream).expect("cannot wrap with tls");
client_handler
.handle_client_tls(stream)
.expect("error handling connection")
});
}
Err(e) => {
println!("Unable to connect: {}", e);
}
}
}
Ok(())
}
} }
pub struct ProxyServerClientHandler { pub struct ProxyServerClientHandler {
@ -316,9 +309,7 @@ pub struct ProxyServerClientHandler {
} }
impl ProxyServerClientHandler { impl ProxyServerClientHandler {
pub fn handle_client(&self, tcp_stream: TcpStream) -> std::io::Result<usize> { fn udp_bind(&self) -> Result<UdpSocket> {
tcp_stream.set_read_timeout(self.socket_timeout)?;
let mut port = self.udp_low_port; let mut port = self.udp_low_port;
let udp_socket = loop { let udp_socket = loop {
match UdpSocket::bind((&self.udp_host[..], port)) { match UdpSocket::bind((&self.udp_host[..], port)) {
@ -333,7 +324,18 @@ impl ProxyServerClientHandler {
}; };
udp_socket.set_read_timeout(self.socket_timeout)?; udp_socket.set_read_timeout(self.socket_timeout)?;
udp_socket.connect(&self.udp_target)?; udp_socket.connect(&self.udp_target)?;
Ok(udp_socket)
}
TcpUdpPipe::new(tcp_stream, udp_socket).shuffle() pub fn set_tcp_options(&self, tcp_stream: &TcpStream) -> Result<()> {
Ok(tcp_stream.set_read_timeout(self.socket_timeout)?)
}
pub fn handle_client(&self, tcp_stream: TcpStream) -> Result<usize> {
TcpUdpPipe::new(tcp_stream, self.udp_bind()?).shuffle()
}
pub fn handle_client_tls(&self, tcp_stream: TlsStream) -> Result<usize> {
TcpUdpPipe::new(tcp_stream, self.udp_bind()?).shuffle()
} }
} }

49
src/notls.rs Normal file
View File

@ -0,0 +1,49 @@
use std::net::TcpStream;
use crate::TryClone;
use std::io::{Read, Write};
use crate::error::*;
fn err() -> Error {
Error::new("Error: compiled without TLS support")
}
pub struct TlsStream;
impl TlsStream {
pub fn client(_host_name: &str, _tcp_stream: TcpStream) -> Result<TlsStream> {
Err(err())
}
}
impl TryClone<TlsStream> for TlsStream {
fn try_clone(&self) -> Result<TlsStream> {
Err(err())
}
}
impl Read for TlsStream {
fn read(&mut self, _buf: &mut [u8]) -> IoResult<usize> {
unimplemented!()
}
}
impl Write for TlsStream {
fn write(&mut self, _buf: &[u8]) -> IoResult<usize> {
unimplemented!()
}
fn flush(&mut self) -> IoResult<()> {
unimplemented!()
}
}
pub struct TlsListener;
impl TlsListener {
pub fn new(_tls_key: &str, _tls_cert: &str) -> Result<TlsListener> {
Err(err())
}
pub fn wrap(&self, _tcp_stream: TcpStream) -> Result<TlsStream> {
Err(err())
}
}

102
src/openssl.rs Normal file
View File

@ -0,0 +1,102 @@
use openssl::ssl::{SslConnector, SslMethod, SslStream, SslVerifyMode, SslAcceptor, SslFiletype, HandshakeError};
use std::sync::Arc;
use std::cell::UnsafeCell;
use std::net::TcpStream;
use crate::TryClone;
use std::io::{Read, Write};
use crate::error::*;
use std::error::Error as StdError;
impl TryClone<TlsStream> for TlsStream {
fn try_clone(&self) -> Result<TlsStream> {
Ok(self.clone())
}
}
pub struct TlsStream {
sess: Arc<UnsafeCell<SslStream<TcpStream>>>,
}
impl TlsStream {
fn new(stream: SslStream<TcpStream>) -> TlsStream {
TlsStream {
sess: Arc::new(UnsafeCell::new(stream))
}
}
pub fn client(host_name: &str, tcp_stream: TcpStream) -> Result<TlsStream> {
let mut connector = SslConnector::builder(SslMethod::tls())?.build().configure()?;
connector.set_verify_hostname(false);
connector.set_verify(SslVerifyMode::NONE);
let tcp_stream = connector.connect(host_name, tcp_stream)?;
Ok(TlsStream::new(tcp_stream))
}
}
unsafe impl Sync for TlsStream {}
unsafe impl Send for TlsStream {}
impl Clone for TlsStream {
fn clone(&self) -> Self {
TlsStream {
sess: self.sess.clone(),
}
}
}
impl TlsStream {
pub fn borrow_mut(&self) -> &mut SslStream<TcpStream> {
unsafe {
&mut *self.sess.get()
}
}
}
impl Read for TlsStream {
fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
self.borrow_mut().read(buf)
}
}
impl Write for TlsStream {
fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
self.borrow_mut().write(buf)
}
fn flush(&mut self) -> IoResult<()> {
self.borrow_mut().flush()
}
}
pub struct TlsListener {
acceptor: SslAcceptor,
}
impl TlsListener {
pub fn new(tls_key: &str, tls_cert: &str) -> Result<TlsListener> {
let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls())?;
acceptor.set_private_key_file(tls_key, SslFiletype::PEM)?;
acceptor.set_certificate_chain_file(tls_cert)?;
acceptor.check_private_key()?;
let acceptor = acceptor.build();
Ok(TlsListener {
acceptor
})
}
pub fn wrap(&self, tcp_stream: TcpStream) -> Result<TlsStream> {
Ok(TlsStream::new(self.acceptor.accept(tcp_stream)?))
}
}
impl From<openssl::error::ErrorStack> for Error {
fn from(value: openssl::error::ErrorStack) -> Self {
Error::new(value.description())
}
}
impl From<HandshakeError<std::net::TcpStream>> for Error {
fn from(value: HandshakeError<std::net::TcpStream>) -> Self {
Error::new(value.description())
}
}

64
test.sh
View File

@ -1,33 +1,67 @@
#!/bin/sh #!/bin/sh
set -x set -x
# always run this clean # first run without TLS
cargo clean cargo clean
cargo build --release --no-default-features
export PATH="$(pwd)/target/release:$PATH"
# first make sure udp-test succeeds running against itself # first make sure udp-test succeeds running against itself
cargo run --release --bin udp-test || exit 1 udp-test || exit 1
# now run udp-test without spawning other processes # now run udp-test without spawning other processes
cargo run --release --bin udp-test -- -is || exit 1 udp-test -is || exit 1
# now run proxyd pointing to udp-test # now run proxyd pointing to udp-test
cargo run --release --bin wireguard-proxy -- -th 127.0.0.1:5555 -ut 127.0.0.1:51822 & wireguard-proxy -th 127.0.0.1:5555 -ut 127.0.0.1:51822 &
proxyd_pid=$! proxyd_pid=$!
# wait for ports to be set up, this is fragile... # wait for ports to be set up, this is fragile...
sleep 1 sleep 5
# proxy pointing to proxyd # proxy pointing to proxyd
#cargo run --release --bin wireguard-proxy -- -tt 127.0.0.1:5555 & wireguard-proxy -tt 127.0.0.1:5555 &
echo -e '\n\n\n\n\n\n\n' | openssl req -new -x509 -days 365 -nodes -out cert.pem -keyout cert.key
socat OPENSSL-LISTEN:5554,bind=127.0.0.1,cert=./cert.pem,key=./cert.key,verify=0 tcp4-connect:127.0.0.1:5555 &
cargo run --release --bin wireguard-proxy -- -tt 127.0.0.1:5554 --tls &
proxy_pid=$! proxy_pid=$!
# wait for ports to be set up, this is fragile... # wait for ports to be set up, this is fragile...
sleep 1 sleep 1
# and udp-test pointing to proxy, which then hops to proxyd, and finally back to udp-test # and udp-test pointing to proxy, which then hops to proxyd, and finally back to udp-test
cargo run --release --bin udp-test -- -uh 127.0.0.1:51822 udp-test -uh 127.0.0.1:51822
udp_exit=$?
kill $proxyd_pid $proxy_pid
[ $udp_exit -ne 0 ] && exit $udp_exit
# now run udp-test essentially just like the script above, but all in rust
udp-test -s || exit 1
echo "non-tls tests passed!"
echo -e '\n\n\n\n\n\n\n' | openssl req -new -x509 -days 365 -nodes -out cert.pem -keyout cert.key
# first run without TLS
cargo clean
cargo build --release
export PATH="$(pwd)/target/release:$PATH"
# first make sure udp-test succeeds running against itself
udp-test || exit 1
# now run udp-test without spawning other processes
udp-test -is || exit 1
# now run proxyd pointing to udp-test
wireguard-proxy -th 127.0.0.1:5555 -ut 127.0.0.1:51822 --tls-key cert.key --tls-cert cert.pem &
proxyd_pid=$!
# wait for ports to be set up, this is fragile...
sleep 5
# proxy pointing to proxyd
wireguard-proxy -tt 127.0.0.1:5555 --tls &
proxy_pid=$!
# wait for ports to be set up, this is fragile...
sleep 1
# and udp-test pointing to proxy, which then hops to proxyd, and finally back to udp-test
udp-test -uh 127.0.0.1:51822
udp_exit=$? udp_exit=$?
kill $proxyd_pid $proxy_pid kill $proxyd_pid $proxy_pid
@ -37,4 +71,6 @@ rm -f cert.pem cert.key
[ $udp_exit -ne 0 ] && exit $udp_exit [ $udp_exit -ne 0 ] && exit $udp_exit
# now run udp-test essentially just like the script above, but all in rust # now run udp-test essentially just like the script above, but all in rust
cargo run --release --bin udp-test -- -s udp-test -s
exit $?