diff --git a/Cargo.lock b/Cargo.lock index c9bb7d9..5706c36 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -91,6 +91,8 @@ dependencies = [ "dashmap", "hex", "hmac", + "lazy_static", + "rstest", "serde", "serde_json", "sha2", @@ -461,12 +463,34 @@ dependencies = [ "bitflags", ] +[[package]] +name = "rstest" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d912f35156a3f99a66ee3e11ac2e0b3f34ac85a07e05263d05a7e2c8810d616f" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "rustc_version", + "syn", +] + [[package]] name = "rustc-demangle" version = "0.1.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ef03e0a2b150c7a90d01faf6254c9c48a41e95fb2a8c2ac1c6f0d2b9aefc342" +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver", +] + [[package]] name = "ryu" version = "1.0.9" @@ -479,6 +503,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +[[package]] +name = "semver" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d65bd28f48be7196d222d95b9243287f48d27aca604e08497513019ff0502cc4" + [[package]] name = "serde" version = "1.0.136" diff --git a/Cargo.toml b/Cargo.toml index 4587f14..0051283 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,3 +28,7 @@ tokio = { version = "1.17.0", features = ["full"] } tracing = "0.1.32" tracing-subscriber = "0.3.10" uuid = { version = "0.8.2", features = ["serde", "v4"] } + +[dev-dependencies] +lazy_static = "1.4.0" +rstest = "0.12.0" diff --git a/tests/auth_test.rs b/tests/auth_test.rs new file mode 100644 index 0000000..e535b01 --- /dev/null +++ b/tests/auth_test.rs @@ -0,0 +1,35 @@ +use anyhow::Result; +use bore_cli::auth::Authenticator; +use tokio::io::{self, BufReader}; + +#[tokio::test] +async fn auth_handshake() -> Result<()> { + let auth = Authenticator::new("some secret string"); + + let (client, server) = io::duplex(8); // Ensure correctness with limited capacity. + let mut client = BufReader::new(client); + let mut server = BufReader::new(server); + + tokio::try_join!( + auth.client_handshake(&mut client), + auth.server_handshake(&mut server), + )?; + + Ok(()) +} + +#[tokio::test] +async fn auth_handshake_fail() { + let auth = Authenticator::new("client secret"); + let auth2 = Authenticator::new("different server secret"); + + let (client, server) = io::duplex(8); // Ensure correctness with limited capacity. + let mut client = BufReader::new(client); + let mut server = BufReader::new(server); + + let result = tokio::try_join!( + auth.client_handshake(&mut client), + auth2.server_handshake(&mut server), + ); + assert!(result.is_err()); +} diff --git a/tests/e2e_test.rs b/tests/e2e_test.rs new file mode 100644 index 0000000..50fa9d9 --- /dev/null +++ b/tests/e2e_test.rs @@ -0,0 +1,100 @@ +use std::net::SocketAddr; +use std::time::Duration; + +use anyhow::{anyhow, Result}; +use bore_cli::{client::Client, server::Server}; +use lazy_static::lazy_static; +use rstest::*; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::Mutex; +use tokio::time; + +lazy_static! { + /// Guard to make sure that tests are run serially, not concurrently. + static ref SERIAL_GUARD: Mutex<()> = Mutex::new(()); +} + +/// Spawn the server, giving some time for the control port TcpListener to start. +async fn spawn_server(secret: Option<&str>) { + tokio::spawn(Server::new(1024, secret).listen()); + time::sleep(Duration::from_millis(50)).await; +} + +/// Spawns a client with randomly assigned ports, returning the listener and remote address. +async fn spawn_client(secret: Option<&str>) -> Result<(TcpListener, SocketAddr)> { + let listener = TcpListener::bind("localhost:0").await?; + let client = Client::new(listener.local_addr()?.port(), "localhost", 0, secret).await?; + let remote_addr = ([0, 0, 0, 0], client.remote_port()).into(); + tokio::spawn(client.listen()); + Ok((listener, remote_addr)) +} + +#[rstest] +#[tokio::test] +async fn basic_proxy(#[values(None, Some(""), Some("abc"))] secret: Option<&str>) -> Result<()> { + let _guard = SERIAL_GUARD.lock().await; + + spawn_server(secret).await; + let (listener, addr) = spawn_client(secret).await?; + + tokio::spawn(async move { + let (mut stream, _) = listener.accept().await?; + let mut buf = [0u8; 11]; + stream.read_exact(&mut buf).await?; + assert_eq!(&buf, b"hello world"); + + stream.write_all(b"I can send a message too!").await?; + anyhow::Ok(()) + }); + + let mut stream = TcpStream::connect(addr).await?; + stream.write_all(b"hello world").await?; + + let mut buf = [0u8; 25]; + stream.read_exact(&mut buf).await?; + assert_eq!(&buf, b"I can send a message too!"); + + // Ensure that the client end of the stream is closed now. + assert_eq!(stream.read(&mut buf).await?, 0); + + // Also ensure that additional connections do not produce any data. + let mut stream = TcpStream::connect(addr).await?; + assert_eq!(stream.read(&mut buf).await?, 0); + + Ok(()) +} + +#[rstest] +#[case(None, Some("my secret"))] +#[case(Some("my secret"), None)] +#[tokio::test] +async fn mismatched_secret( + #[case] server_secret: Option<&str>, + #[case] client_secret: Option<&str>, +) { + let _guard = SERIAL_GUARD.lock().await; + + spawn_server(server_secret).await; + assert!(spawn_client(client_secret).await.is_err()); +} + +#[tokio::test] +async fn invalid_address() -> Result<()> { + // We don't need the serial guard for this test because it doesn't create a server. + async fn check_address(to: &str, use_secret: bool) -> Result<()> { + match Client::new(5000, to, 0, use_secret.then(|| "a secret")).await { + Ok(_) => Err(anyhow!("expected error for {to}, use_secret={use_secret}")), + Err(_) => Ok(()), + } + } + tokio::try_join!( + check_address("google.com", false), + check_address("google.com", true), + check_address("nonexistent.domain.for.demonstration", false), + check_address("nonexistent.domain.for.demonstration", true), + check_address("malformed !$uri$%", false), + check_address("malformed !$uri$%", true), + )?; + Ok(()) +}