diff --git a/src/constants.rs b/src/constants.rs index 472af4d..a80c6ad 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -2,6 +2,7 @@ use backoff::ExponentialBackoff; use std::time::Duration; // FIXME: Determine reasonable size +/// UDP MTU. Currently far larger than necessary pub const UDP_BUFFER_SIZE: usize = 2048; pub const UDP_SENDQ_SIZE: usize = 1024; pub const UDP_TIMEOUT: u64 = 60; diff --git a/src/lib.rs b/src/lib.rs index 5e186a8..bb474ee 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,6 +8,7 @@ mod transport; pub use cli::Cli; pub use config::Config; +pub use constants::UDP_BUFFER_SIZE; use anyhow::{anyhow, Result}; use tokio::sync::broadcast; diff --git a/tests/common/mod.rs b/tests/common/mod.rs index feae8e1..fa0793d 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -34,42 +34,77 @@ pub async fn run_rathole_client( rathole::run(&cli, shutdown_rx).await } -pub async fn echo_server(addr: A) -> Result<()> { - let l = TcpListener::bind(addr).await?; +pub mod tcp { + use super::*; - loop { - let (conn, _addr) = l.accept().await?; - tokio::spawn(async move { - let _ = echo(conn).await; - }); + pub async fn echo_server(addr: A) -> Result<()> { + let l = TcpListener::bind(addr).await?; + + loop { + let (conn, _addr) = l.accept().await?; + tokio::spawn(async move { + let _ = echo(conn).await; + }); + } + } + + pub async fn pingpong_server(addr: A) -> Result<()> { + let l = TcpListener::bind(addr).await?; + + loop { + let (conn, _addr) = l.accept().await?; + tokio::spawn(async move { + let _ = pingpong(conn).await; + }); + } + } + + async fn echo(conn: TcpStream) -> Result<()> { + let (mut rd, mut wr) = conn.into_split(); + io::copy(&mut rd, &mut wr).await?; + + Ok(()) + } + + async fn pingpong(mut conn: TcpStream) -> Result<()> { + let mut buf = [0u8; PING.len()]; + + while conn.read_exact(&mut buf).await? != 0 { + assert_eq!(buf, PING.as_bytes()); + conn.write_all(PONG.as_bytes()).await?; + } + + Ok(()) } } -pub async fn pingpong_server(addr: A) -> Result<()> { - let l = TcpListener::bind(addr).await?; +pub mod udp { + use rathole::UDP_BUFFER_SIZE; + use tokio::net::UdpSocket; + use tracing::debug; - loop { - let (conn, _addr) = l.accept().await?; - tokio::spawn(async move { - let _ = pingpong(conn).await; - }); - } -} + use super::*; -async fn echo(conn: TcpStream) -> Result<()> { - let (mut rd, mut wr) = conn.into_split(); - io::copy(&mut rd, &mut wr).await?; + pub async fn echo_server(addr: A) -> Result<()> { + let l = UdpSocket::bind(addr).await?; + debug!("UDP echo server listening"); - Ok(()) -} - -async fn pingpong(mut conn: TcpStream) -> Result<()> { - let mut buf = [0u8; PING.len()]; - - while conn.read_exact(&mut buf).await? != 0 { - assert_eq!(buf, PING.as_bytes()); - conn.write_all(PONG.as_bytes()).await?; + let mut buf = [0u8; UDP_BUFFER_SIZE]; + loop { + let (n, addr) = l.recv_from(&mut buf).await?; + debug!("Get {:?} from {}", &buf[..n], addr); + l.send_to(&buf[..n], addr).await?; + } } - Ok(()) + pub async fn pingpong_server(addr: A) -> Result<()> { + let l = UdpSocket::bind(addr).await?; + + let mut buf = [0u8; UDP_BUFFER_SIZE]; + loop { + let (n, addr) = l.recv_from(&mut buf).await?; + assert_eq!(&buf[..n], PING.as_bytes()); + l.send_to(PONG.as_bytes(), addr).await?; + } + } } diff --git a/tests/tcp_transport.toml b/tests/for_tcp/tcp_transport.toml similarity index 100% rename from tests/tcp_transport.toml rename to tests/for_tcp/tcp_transport.toml diff --git a/tests/tls_transport.toml b/tests/for_tcp/tls_transport.toml similarity index 100% rename from tests/tls_transport.toml rename to tests/for_tcp/tls_transport.toml diff --git a/tests/for_udp/tcp_transport.toml b/tests/for_udp/tcp_transport.toml new file mode 100644 index 0000000..d045568 --- /dev/null +++ b/tests/for_udp/tcp_transport.toml @@ -0,0 +1,27 @@ +[client] +remote_addr = "localhost:2332" +default_token = "default_token_if_not_specify" + +[client.transport] +type = "tcp" + +[client.services.echo] +type = "udp" +local_addr = "localhost:8080" +[client.services.pingpong] +type = "udp" +local_addr = "localhost:8081" + +[server] +bind_addr = "0.0.0.0:2332" +default_token = "default_token_if_not_specify" + +[server.transport] +type = "tcp" + +[server.services.echo] +type = "udp" +bind_addr = "0.0.0.0:2334" +[server.services.pingpong] +type = "udp" +bind_addr = "0.0.0.0:2335" diff --git a/tests/for_udp/tls_transport.toml b/tests/for_udp/tls_transport.toml new file mode 100644 index 0000000..25fe955 --- /dev/null +++ b/tests/for_udp/tls_transport.toml @@ -0,0 +1,33 @@ +[client] +remote_addr = "localhost:2332" +default_token = "default_token_if_not_specify" + +[client.transport] +type = "tls" +[client.transport.tls] +trusted_root = "examples/tls/ca-cert.pem" +hostname = "0.0.0.0" + +[client.services.echo] +type = "udp" +local_addr = "localhost:8080" +[client.services.pingpong] +type = "udp" +local_addr = "localhost:8081" + +[server] +bind_addr = "0.0.0.0:2332" +default_token = "default_token_if_not_specify" + +[server.transport] +type = "tls" +[server.transport.tls] +pkcs12 = "examples/tls/identity.pfx" +pkcs12_password = "1234" + +[server.services.echo] +type = "udp" +bind_addr = "0.0.0.0:2334" +[server.services.pingpong] +type = "udp" +bind_addr = "0.0.0.0:2335" diff --git a/tests/integration_test.rs b/tests/integration_test.rs index e86c3bd..686b011 100644 --- a/tests/integration_test.rs +++ b/tests/integration_test.rs @@ -4,48 +4,93 @@ use rand::Rng; use std::time::Duration; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, - net::TcpStream, + net::{TcpStream, UdpSocket}, sync::broadcast, time, }; +use tracing::{debug, info, instrument}; +use tracing_subscriber::EnvFilter; use crate::common::run_rathole_server; mod common; -const ECHO_SERVER_ADDR: &str = "localhost:8080"; -const PINGPONG_SERVER_ADDR: &str = "localhost:8081"; -const ECHO_SERVER_ADDR_EXPOSED: &str = "localhost:2334"; -const PINGPONG_SERVER_ADDR_EXPOSED: &str = "localhost:2335"; +const ECHO_SERVER_ADDR: &str = "0.0.0.0:8080"; +const PINGPONG_SERVER_ADDR: &str = "0.0.0.0:8081"; +const ECHO_SERVER_ADDR_EXPOSED: &str = "0.0.0.0:2334"; +const PINGPONG_SERVER_ADDR_EXPOSED: &str = "0.0.0.0:2335"; const HITTER_NUM: usize = 4; +#[derive(Clone, Copy, Debug)] +enum Type { + Tcp, + Udp, +} + +fn init() { + let level = "info"; + let _ = tracing_subscriber::fmt() + .with_env_filter( + EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::from(level)), + ) + .try_init(); +} + #[tokio::test] -async fn main() -> Result<()> { +async fn tcp() -> Result<()> { + init(); + // Spawn a echo server tokio::spawn(async move { - if let Err(e) = common::echo_server(ECHO_SERVER_ADDR).await { + if let Err(e) = common::tcp::echo_server(ECHO_SERVER_ADDR).await { panic!("Failed to run the echo server for testing: {:?}", e); } }); // Spawn a pingpong server tokio::spawn(async move { - if let Err(e) = common::pingpong_server(PINGPONG_SERVER_ADDR).await { + if let Err(e) = common::tcp::pingpong_server(PINGPONG_SERVER_ADDR).await { panic!("Failed to run the pingpong server for testing: {:?}", e); } }); - test("tests/tcp_transport.toml").await?; - test("tests/tls_transport.toml").await?; + test("tests/for_tcp/tcp_transport.toml", Type::Tcp).await?; + test("tests/for_tcp/tls_transport.toml", Type::Tcp).await?; Ok(()) } -async fn test(config_path: &'static str) -> Result<()> { +#[tokio::test] +async fn udp() -> Result<()> { + init(); + + // Spawn a echo server + tokio::spawn(async move { + if let Err(e) = common::udp::echo_server(ECHO_SERVER_ADDR).await { + panic!("Failed to run the echo server for testing: {:?}", e); + } + }); + + // Spawn a pingpong server + tokio::spawn(async move { + if let Err(e) = common::udp::pingpong_server(PINGPONG_SERVER_ADDR).await { + panic!("Failed to run the pingpong server for testing: {:?}", e); + } + }); + + test("tests/for_udp/tcp_transport.toml", Type::Udp).await?; + test("tests/for_udp/tls_transport.toml", Type::Udp).await?; + + Ok(()) +} + +#[instrument] +async fn test(config_path: &'static str, t: Type) -> Result<()> { let (client_shutdown_tx, client_shutdown_rx) = broadcast::channel(1); let (server_shutdown_tx, server_shutdown_rx) = broadcast::channel(1); // Start the client + info!("start the client"); tokio::spawn(async move { run_rathole_client(&config_path, client_shutdown_rx) .await @@ -56,6 +101,7 @@ async fn test(config_path: &'static str) -> Result<()> { time::sleep(Duration::from_secs(1)).await; // Start the server + info!("start the server"); tokio::spawn(async move { run_rathole_server(&config_path, server_shutdown_rx) .await @@ -63,27 +109,42 @@ async fn test(config_path: &'static str) -> Result<()> { }); time::sleep(Duration::from_secs(1)).await; // Wait for the client to retry - echo_hitter(ECHO_SERVER_ADDR_EXPOSED).await.unwrap(); - pingpong_hitter(PINGPONG_SERVER_ADDR_EXPOSED).await.unwrap(); + info!("echo"); + echo_hitter(ECHO_SERVER_ADDR_EXPOSED, t).await.unwrap(); + info!("pingpong"); + pingpong_hitter(PINGPONG_SERVER_ADDR_EXPOSED, t) + .await + .unwrap(); // Simulate the client crash and restart + info!("shutdown the client"); client_shutdown_tx.send(true)?; time::sleep(Duration::from_millis(500)).await; + + info!("restart the client"); let client_shutdown_rx = client_shutdown_tx.subscribe(); - tokio::spawn(async move { + let client = tokio::spawn(async move { run_rathole_client(&config_path, client_shutdown_rx) .await .unwrap(); }); + time::sleep(Duration::from_secs(1)).await; // Wait for the client to start - echo_hitter(ECHO_SERVER_ADDR_EXPOSED).await.unwrap(); - pingpong_hitter(PINGPONG_SERVER_ADDR_EXPOSED).await.unwrap(); + info!("echo"); + echo_hitter(ECHO_SERVER_ADDR_EXPOSED, t).await.unwrap(); + info!("pingpong"); + pingpong_hitter(PINGPONG_SERVER_ADDR_EXPOSED, t) + .await + .unwrap(); // Simulate the server crash and restart + info!("shutdown the server"); server_shutdown_tx.send(true)?; time::sleep(Duration::from_millis(500)).await; + + info!("restart the server"); let server_shutdown_rx = server_shutdown_tx.subscribe(); - tokio::spawn(async move { + let server = tokio::spawn(async move { run_rathole_server(&config_path, server_shutdown_rx) .await .unwrap(); @@ -91,24 +152,44 @@ async fn test(config_path: &'static str) -> Result<()> { time::sleep(Duration::from_secs(1)).await; // Wait for the client to retry // Simulate heavy load + info!("lots of echo and pingpong"); for _ in 0..HITTER_NUM / 2 { tokio::spawn(async move { - echo_hitter(ECHO_SERVER_ADDR_EXPOSED).await.unwrap(); + echo_hitter(ECHO_SERVER_ADDR_EXPOSED, t).await.unwrap(); }); tokio::spawn(async move { - pingpong_hitter(PINGPONG_SERVER_ADDR_EXPOSED).await.unwrap(); + pingpong_hitter(PINGPONG_SERVER_ADDR_EXPOSED, t) + .await + .unwrap(); }); } // Shutdown + info!("shutdown the server and the client"); server_shutdown_tx.send(true)?; client_shutdown_tx.send(true)?; + let _ = tokio::join!(server, client); + Ok(()) } -async fn echo_hitter(addr: &str) -> Result<()> { +async fn echo_hitter(addr: &'static str, t: Type) -> Result<()> { + match t { + Type::Tcp => tcp_echo_hitter(addr).await, + Type::Udp => udp_echo_hitter(addr).await, + } +} + +async fn pingpong_hitter(addr: &'static str, t: Type) -> Result<()> { + match t { + Type::Tcp => tcp_pingpong_hitter(addr).await, + Type::Udp => udp_pingpong_hitter(addr).await, + } +} + +async fn tcp_echo_hitter(addr: &'static str) -> Result<()> { let mut conn = TcpStream::connect(addr).await?; let mut wr = [0u8; 1024]; @@ -123,7 +204,27 @@ async fn echo_hitter(addr: &str) -> Result<()> { Ok(()) } -async fn pingpong_hitter(addr: &str) -> Result<()> { +async fn udp_echo_hitter(addr: &'static str) -> Result<()> { + let conn = UdpSocket::bind("0.0.0.0:0").await?; + conn.connect(addr).await?; + + let mut wr = [0u8; 128]; + let mut rd = [0u8; 128]; + for _ in 0..3 { + rand::thread_rng().fill(&mut wr); + + conn.send(&wr).await?; + debug!("send"); + + conn.recv(&mut rd).await?; + debug!("recv"); + + assert_eq!(wr, rd); + } + Ok(()) +} + +async fn tcp_pingpong_hitter(addr: &'static str) -> Result<()> { let mut conn = TcpStream::connect(addr).await?; let wr = PING.as_bytes(); @@ -137,3 +238,23 @@ async fn pingpong_hitter(addr: &str) -> Result<()> { Ok(()) } + +async fn udp_pingpong_hitter(addr: &'static str) -> Result<()> { + let conn = UdpSocket::bind("0.0.0.0:0").await?; + conn.connect(&addr).await?; + + let wr = PING.as_bytes(); + let mut rd = [0u8; PONG.len()]; + + for _ in 0..3 { + conn.send(wr).await?; + debug!("ping"); + + conn.recv(&mut rd).await?; + debug!("pong"); + + assert_eq!(rd, PONG.as_bytes()); + } + + Ok(()) +}