mirror of https://github.com/ekzhang/bore.git
130 lines
4.1 KiB
Rust
130 lines
4.1 KiB
Rust
#![allow(clippy::items_after_test_module)]
|
|
|
|
use std::net::SocketAddr;
|
|
use std::time::Duration;
|
|
|
|
use anyhow::{anyhow, Result};
|
|
use bore_cli::{client::Client, server::Server, shared::CONTROL_PORT};
|
|
use lazy_static::lazy_static;
|
|
use rstest::*;
|
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
|
use tokio::net::{TcpListener, TcpStream};
|
|
use tokio::sync::Mutex;
|
|
use tokio::time;
|
|
|
|
lazy_static! {
    /// Process-wide async lock that tests acquire for their whole duration.
    ///
    /// Holding the guard forces tests that start a server to run one at a
    /// time instead of concurrently.
    static ref SERIAL_GUARD: Mutex<()> = Mutex::new(());
}
|
|
|
|
/// Spawn the server, giving some time for the control port TcpListener to start.
|
|
async fn spawn_server(secret: Option<&str>) {
|
|
tokio::spawn(Server::new(1024..=65535, secret).listen());
|
|
time::sleep(Duration::from_millis(50)).await;
|
|
}
|
|
|
|
/// Spawns a client with randomly assigned ports, returning the listener and remote address.
|
|
async fn spawn_client(secret: Option<&str>) -> Result<(TcpListener, SocketAddr)> {
|
|
let listener = TcpListener::bind("localhost:0").await?;
|
|
let local_port = listener.local_addr()?.port();
|
|
let client = Client::new("localhost", local_port, "localhost", 0, secret).await?;
|
|
let remote_addr = ([127, 0, 0, 1], client.remote_port()).into();
|
|
tokio::spawn(client.listen());
|
|
Ok((listener, remote_addr))
|
|
}
|
|
|
|
#[rstest]
|
|
#[tokio::test]
|
|
async fn basic_proxy(#[values(None, Some(""), Some("abc"))] secret: Option<&str>) -> Result<()> {
|
|
let _guard = SERIAL_GUARD.lock().await;
|
|
|
|
spawn_server(secret).await;
|
|
let (listener, addr) = spawn_client(secret).await?;
|
|
|
|
tokio::spawn(async move {
|
|
let (mut stream, _) = listener.accept().await?;
|
|
let mut buf = [0u8; 11];
|
|
stream.read_exact(&mut buf).await?;
|
|
assert_eq!(&buf, b"hello world");
|
|
|
|
stream.write_all(b"I can send a message too!").await?;
|
|
anyhow::Ok(())
|
|
});
|
|
|
|
let mut stream = TcpStream::connect(addr).await?;
|
|
stream.write_all(b"hello world").await?;
|
|
|
|
let mut buf = [0u8; 25];
|
|
stream.read_exact(&mut buf).await?;
|
|
assert_eq!(&buf, b"I can send a message too!");
|
|
|
|
// Ensure that the client end of the stream is closed now.
|
|
assert_eq!(stream.read(&mut buf).await?, 0);
|
|
|
|
// Also ensure that additional connections do not produce any data.
|
|
let mut stream = TcpStream::connect(addr).await?;
|
|
assert_eq!(stream.read(&mut buf).await?, 0);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[rstest]
|
|
#[case(None, Some("my secret"))]
|
|
#[case(Some("my secret"), None)]
|
|
#[tokio::test]
|
|
async fn mismatched_secret(
|
|
#[case] server_secret: Option<&str>,
|
|
#[case] client_secret: Option<&str>,
|
|
) {
|
|
let _guard = SERIAL_GUARD.lock().await;
|
|
|
|
spawn_server(server_secret).await;
|
|
assert!(spawn_client(client_secret).await.is_err());
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn invalid_address() -> Result<()> {
|
|
// We don't need the serial guard for this test because it doesn't create a server.
|
|
async fn check_address(to: &str, use_secret: bool) -> Result<()> {
|
|
match Client::new("localhost", 5000, to, 0, use_secret.then_some("a secret")).await {
|
|
Ok(_) => Err(anyhow!("expected error for {to}, use_secret={use_secret}")),
|
|
Err(_) => Ok(()),
|
|
}
|
|
}
|
|
tokio::try_join!(
|
|
check_address("google.com", false),
|
|
check_address("google.com", true),
|
|
check_address("nonexistent.domain.for.demonstration", false),
|
|
check_address("nonexistent.domain.for.demonstration", true),
|
|
check_address("malformed !$uri$%", false),
|
|
check_address("malformed !$uri$%", true),
|
|
)?;
|
|
Ok(())
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn very_long_frame() -> Result<()> {
|
|
let _guard = SERIAL_GUARD.lock().await;
|
|
|
|
spawn_server(None).await;
|
|
let mut attacker = TcpStream::connect(("localhost", CONTROL_PORT)).await?;
|
|
|
|
// Slowly send a very long frame.
|
|
for _ in 0..10 {
|
|
let result = attacker.write_all(&[42u8; 100000]).await;
|
|
if result.is_err() {
|
|
return Ok(());
|
|
}
|
|
time::sleep(Duration::from_millis(10)).await;
|
|
}
|
|
panic!("did not exit after a 1 MB frame");
|
|
}
|
|
|
|
/// Constructing a server whose minimum port exceeds its maximum port is
/// expected to panic.
#[test]
#[should_panic]
fn empty_port_range() {
    // Deliberately inverted bounds: the range `5000..=3000` contains no ports.
    let (min_port, max_port) = (5000, 3000);
    let _ = Server::new(min_port..=max_port, None);
}
|