diff --git a/Cargo.lock b/Cargo.lock index a29fa12..732a45b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -54,9 +54,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "backtrace" -version = "0.3.64" +version = "0.3.65" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e121dee8023ce33ab248d9ce1493df03c3b38a659b240096fcbd7048ff9c31f" +checksum = "11a17d453482a265fd5f8479f2a3f405566e6ca627837aaddb85af8b1ab8ef61" dependencies = [ "addr2line", "cc", @@ -89,6 +89,7 @@ dependencies = [ "anyhow", "clap", "dashmap", + "futures-util", "hex", "hmac", "lazy_static", @@ -97,6 +98,7 @@ dependencies = [ "serde_json", "sha2", "tokio", + "tokio-util", "tracing", "tracing-subscriber", "uuid", @@ -122,16 +124,16 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clap" -version = "3.1.8" +version = "3.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71c47df61d9e16dc010b55dba1952a57d8c215dbb533fd13cdd13369aac73b1c" +checksum = "6aad2534fad53df1cc12519c5cda696dd3e20e6118a027e24054aea14a0bdcbe" dependencies = [ "atty", "bitflags", "clap_derive", + "clap_lex", "indexmap", "lazy_static", - "os_str_bytes", "strsim", "termcolor", "textwrap", @@ -150,6 +152,15 @@ dependencies = [ "syn", ] +[[package]] +name = "clap_lex" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "189ddd3b5d32a70b35e7686054371742a937b0d99128e76dde6340210e966669" +dependencies = [ + "os_str_bytes", +] + [[package]] name = "cpufeatures" version = "0.2.2" @@ -191,6 +202,50 @@ dependencies = [ "subtle", ] +[[package]] +name = "futures-core" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c09fd04b7e4073ac7156a9539b57a484a8ea920f79c7c675d05d289ab6110d3" + +[[package]] +name = "futures-macro" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33c1e13800337f4d4d7a316bf45a567dbcb6ffe087f16424852d97e97a91f512" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21163e139fa306126e6eedaf49ecdb4588f939600f0b1e770f4205ee4b7fa868" + +[[package]] +name = "futures-task" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c66a976bf5909d801bbef33416c41372779507e7a6b3a5e25e4749c58f776a" + +[[package]] +name = "futures-util" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8b7abd5d659d9b90c8cba917f6ec750a74e2dc23902ef9cd4cc8c8b22e6036a" +dependencies = [ + "futures-core", + "futures-macro", + "futures-sink", + "futures-task", + "pin-project-lite", + "pin-utils", + "slab", +] + [[package]] name = "generic-array" version = "0.14.5" @@ -278,9 +333,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.123" +version = "0.2.124" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb691a747a7ab48abc15c5b42066eaafde10dc427e3b6ee2a1cf43db04c763bd" +checksum = "21a41fed9d98f27ab1c6d161da622a4fa35e8a54a8adc24bbf3ddd0ef70b0e50" [[package]] name = "lock_api" @@ -309,12 +364,11 @@ checksum = "308cc39be01b73d0d18f82a0e7b2a3df85245f84af96fdddc5d202d27e47b86a" [[package]] name = "miniz_oxide" -version = "0.4.4" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a92518e98c078586bc6c934028adcca4c92a53d6a958196de835170a01d84e4b" +checksum = "d2b29bd4bc3f33391105ebee3589c19197c4271e3e5a9ec9bfe8127eeff8f082" dependencies = [ "adler", - "autocfg", ] [[package]] @@ -361,9 +415,9 @@ dependencies = [ [[package]] name = "object" -version = "0.27.1" +version = "0.28.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67ac1d3f9a1d3616fd9a60c8d74296f22406a238b6a72f5cc1e6f314df4ffbf9" +checksum = "40bec70ba014595f99f7aa110b84331ffe1ee9aece7fe6f387cc7e3ecda4d456" dependencies = [ "memchr", ] @@ -379,9 +433,6 @@ name = "os_str_bytes" version = "6.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e22443d1643a904602595ba1cd8f7d896afe56d26712531c5ff73a15b2fbf64" -dependencies = [ - "memchr", -] [[package]] name = "parking_lot" @@ -412,6 +463,12 @@ version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e280fbe77cc62c91527259e9442153f4688736748d24660126286329742b4c6c" +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "proc-macro-error" version = "1.0.4" @@ -560,6 +617,12 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "slab" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb703cfe953bccee95685111adeedb76fabe4e97549a58d16f03ea7b9367bb32" + [[package]] name = "smallvec" version = "1.8.0" @@ -652,10 +715,24 @@ dependencies = [ ] [[package]] -name = "tracing" -version = "0.1.33" +name = "tokio-util" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80b9fa4360528139bc96100c160b7ae879f5567f49f1782b0b02035b0358ebf3" +checksum = "0edfdeb067411dba2044da6d1cb2df793dd35add7888d73c16e3381ded401764" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", + "tracing", +] + +[[package]] +name = "tracing" +version = "0.1.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d0ecdcb44a79f0fe9844f0c4f33a342cbcbb5117de8001e6ba0dc2351327d09" dependencies = [ "cfg-if", "pin-project-lite", @@ -676,9 +753,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.25" +version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dfce9f3241b150f36e8e54bb561a742d5daa1a47b5dd9a5ce369fd4a4db2210" +checksum = "f54c8ca710e81886d498c2fd3331b56c93aa248d49de2222ad2742247c60072f" dependencies = [ "lazy_static", "valuable", diff --git a/Cargo.toml b/Cargo.toml index d80ccf0..ad07371 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,12 +19,14 @@ path = "src/main.rs" anyhow = { version = "1.0.56", features = ["backtrace"] } clap = { version = "3.1.8", features = ["derive", "env"] } dashmap = "5.2.0" +futures-util = { version = "0.3.21", features = ["sink"] } hex = "0.4.3" hmac = "0.12.1" serde = { version = "1.0.136", features = ["derive"] } serde_json = "1.0.79" sha2 = "0.10.2" tokio = { version = "1.17.0", features = ["rt-multi-thread", "io-util", "macros", "net", "time"] } +tokio-util = { version = "0.7.1", features = ["codec"] } tracing = "0.1.32" tracing-subscriber = "0.3.10" uuid = { version = "0.8.2", features = ["serde", "v4"] } diff --git a/src/auth.rs b/src/auth.rs index 95e19a4..ce8237c 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -3,10 +3,10 @@ use anyhow::{bail, ensure, Result}; use hmac::{Hmac, Mac}; use sha2::{Digest, Sha256}; -use tokio::io::{AsyncBufRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite}; use uuid::Uuid; -use crate::shared::{recv_json_timeout, send_json, ClientMessage, ServerMessage}; +use crate::shared::{ClientMessage, Delimited, ServerMessage}; /// Wrapper around a MAC used for authenticating clients that have a secret. pub struct Authenticator(Hmac); @@ -48,13 +48,13 @@ impl Authenticator { } /// As the server, send a challenge to the client and validate their response. - pub async fn server_handshake( + pub async fn server_handshake( &self, - stream: &mut (impl AsyncBufRead + AsyncWrite + Unpin), + stream: &mut Delimited, ) -> Result<()> { let challenge = Uuid::new_v4(); - send_json(stream, ServerMessage::Challenge(challenge)).await?; - match recv_json_timeout(stream).await? { + stream.send(ServerMessage::Challenge(challenge)).await?; + match stream.recv_timeout().await? { Some(ClientMessage::Authenticate(tag)) => { ensure!(self.validate(&challenge, &tag), "invalid secret"); Ok(()) @@ -64,16 +64,16 @@ impl Authenticator { } /// As the client, answer a challenge to attempt to authenticate with the server. - pub async fn client_handshake( + pub async fn client_handshake( &self, - stream: &mut (impl AsyncBufRead + AsyncWrite + Unpin), + stream: &mut Delimited, ) -> Result<()> { - let challenge = match recv_json_timeout(stream).await? { + let challenge = match stream.recv_timeout().await? { Some(ServerMessage::Challenge(challenge)) => challenge, _ => bail!("expected authentication challenge, but no secret was required"), }; let tag = self.answer(&challenge); - send_json(stream, ClientMessage::Authenticate(tag)).await?; + stream.send(ClientMessage::Authenticate(tag)).await?; Ok(()) } } diff --git a/src/client.rs b/src/client.rs index 2dde7d8..9ea31a4 100644 --- a/src/client.rs +++ b/src/client.rs @@ -3,20 +3,21 @@ use std::sync::Arc; use anyhow::{bail, Context, Result}; -use tokio::{io::BufReader, net::TcpStream, time::timeout}; + +use tokio::io::AsyncWriteExt; +use tokio::{net::TcpStream, time::timeout}; use tracing::{error, info, info_span, warn, Instrument}; use uuid::Uuid; use crate::auth::Authenticator; use crate::shared::{ - proxy, recv_json, recv_json_timeout, send_json, ClientMessage, ServerMessage, CONTROL_PORT, - NETWORK_TIMEOUT, + proxy, ClientMessage, Delimited, ServerMessage, CONTROL_PORT, NETWORK_TIMEOUT, }; /// State structure for the client. pub struct Client { /// Control connection to the server. - conn: Option>, + conn: Option>, /// Destination address of the server. to: String, @@ -43,15 +44,14 @@ impl Client { port: u16, secret: Option<&str>, ) -> Result { - let mut stream = BufReader::new(connect_with_timeout(to, CONTROL_PORT).await?); - + let mut stream = Delimited::new(connect_with_timeout(to, CONTROL_PORT).await?); let auth = secret.map(Authenticator::new); if let Some(auth) = &auth { auth.client_handshake(&mut stream).await?; } - send_json(&mut stream, ClientMessage::Hello(port)).await?; - let remote_port = match recv_json_timeout(&mut stream).await? { + stream.send(ClientMessage::Hello(port)).await?; + let remote_port = match stream.recv_timeout().await? { Some(ServerMessage::Hello(remote_port)) => remote_port, Some(ServerMessage::Error(message)) => bail!("server error: {message}"), Some(ServerMessage::Challenge(_)) => { @@ -82,10 +82,8 @@ impl Client { pub async fn listen(mut self) -> Result<()> { let mut conn = self.conn.take().unwrap(); let this = Arc::new(self); - let mut buf = Vec::new(); loop { - let msg = recv_json(&mut conn, &mut buf).await?; - match msg { + match conn.recv().await? { Some(ServerMessage::Hello(_)) => warn!("unexpected hello"), Some(ServerMessage::Challenge(_)) => warn!("unexpected challenge"), Some(ServerMessage::Heartbeat) => (), @@ -110,14 +108,16 @@ impl Client { async fn handle_connection(&self, id: Uuid) -> Result<()> { let mut remote_conn = - BufReader::new(connect_with_timeout(&self.to[..], CONTROL_PORT).await?); + Delimited::new(connect_with_timeout(&self.to[..], CONTROL_PORT).await?); if let Some(auth) = &self.auth { auth.client_handshake(&mut remote_conn).await?; } - send_json(&mut remote_conn, ClientMessage::Accept(id)).await?; - - let local_conn = connect_with_timeout(&self.local_host, self.local_port).await?; - proxy(local_conn, remote_conn).await?; + remote_conn.send(ClientMessage::Accept(id)).await?; + let mut local_conn = connect_with_timeout(&self.local_host, self.local_port).await?; + let parts = remote_conn.into_parts(); + debug_assert!(parts.write_buf.is_empty(), "framed write buffer not empty"); + local_conn.write_all(&parts.read_buf).await?; // mostly of the cases, this will be empty + proxy(local_conn, parts.io).await?; Ok(()) } } diff --git a/src/server.rs b/src/server.rs index f4a8b4f..ab71278 100644 --- a/src/server.rs +++ b/src/server.rs @@ -6,16 +6,14 @@ use std::time::Duration; use anyhow::Result; use dashmap::DashMap; -use tokio::io::BufReader; +use tokio::io::AsyncWriteExt; use tokio::net::{TcpListener, TcpStream}; use tokio::time::{sleep, timeout}; use tracing::{info, info_span, warn, Instrument}; use uuid::Uuid; use crate::auth::Authenticator; -use crate::shared::{ - proxy, recv_json_timeout, send_json, ClientMessage, ServerMessage, CONTROL_PORT, -}; +use crate::shared::{proxy, ClientMessage, Delimited, ServerMessage, CONTROL_PORT}; /// State structure for the server. pub struct Server { @@ -64,16 +62,16 @@ impl Server { } async fn handle_connection(&self, stream: TcpStream) -> Result<()> { - let mut stream = BufReader::new(stream); + let mut stream = Delimited::new(stream); if let Some(auth) = &self.auth { if let Err(err) = auth.server_handshake(&mut stream).await { warn!(%err, "server handshake failed"); - send_json(&mut stream, ServerMessage::Error(err.to_string())).await?; + stream.send(ServerMessage::Error(err.to_string())).await?; return Ok(()); } } - match recv_json_timeout(&mut stream).await? { + match stream.recv_timeout().await? { Some(ClientMessage::Authenticate(_)) => { warn!("unexpected authenticate"); Ok(()) @@ -88,22 +86,17 @@ impl Server { Ok(listener) => listener, Err(_) => { warn!(?port, "could not bind to local port"); - send_json( - &mut stream, - ServerMessage::Error("port already in use".into()), - ) - .await?; + stream + .send(ServerMessage::Error("port already in use".into())) + .await?; return Ok(()); } }; let port = listener.local_addr()?.port(); - send_json(&mut stream, ServerMessage::Hello(port)).await?; + stream.send(ServerMessage::Hello(port)).await?; loop { - if send_json(&mut stream, ServerMessage::Heartbeat) - .await - .is_err() - { + if stream.send(ServerMessage::Heartbeat).await.is_err() { // Assume that the TCP connection has been dropped. return Ok(()); } @@ -123,14 +116,19 @@ impl Server { warn!(%id, "removed stale connection"); } }); - send_json(&mut stream, ServerMessage::Connection(id)).await?; + stream.send(ServerMessage::Connection(id)).await?; } } } Some(ClientMessage::Accept(id)) => { info!(%id, "forwarding connection"); match self.conns.remove(&id) { - Some((_, stream2)) => proxy(stream, stream2).await?, + Some((_, mut stream2)) => { + let parts = stream.into_parts(); + debug_assert!(parts.write_buf.is_empty(), "framed write buffer not empty"); + stream2.write_all(&parts.read_buf).await?; + proxy(parts.io, stream2).await? + } None => warn!(%id, "missing connection"), } Ok(()) diff --git a/src/shared.rs b/src/shared.rs index 8a8fce0..6acbe1e 100644 --- a/src/shared.rs +++ b/src/shared.rs @@ -3,16 +3,22 @@ use std::time::Duration; use anyhow::{Context, Result}; +use futures_util::{SinkExt, StreamExt}; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; -use tokio::io::{self, AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::io::{self, AsyncRead, AsyncWrite}; + use tokio::time::timeout; +use tokio_util::codec::{AnyDelimiterCodec, Framed, FramedParts}; use tracing::trace; use uuid::Uuid; /// TCP port used for control connections with the server. pub const CONTROL_PORT: u16 = 7835; +/// Maxmium byte length for a JSON frame in the stream. +pub const MAX_FRAME_LENGTH: usize = 256; + /// Timeout for network connections and initial protocol messages. pub const NETWORK_TIMEOUT: Duration = Duration::from_secs(3); @@ -48,6 +54,52 @@ pub enum ServerMessage { Error(String), } +/// Transport stream with JSON frames delimited by null characters. +pub struct Delimited(Framed); + +impl Delimited { + /// Construct a new delimited stream. + pub fn new(stream: U) -> Self { + let codec = AnyDelimiterCodec::new_with_max_length(vec![0], vec![0], MAX_FRAME_LENGTH); + Self(Framed::new(stream, codec)) + } + + /// Read the next null-delimited JSON instruction from a stream. + pub async fn recv(&mut self) -> Result> { + trace!("waiting to receive json message"); + if let Some(next_message) = self.0.next().await { + let byte_message = next_message.context("frame error, invalid byte length")?; + let serialized_obj = serde_json::from_slice(&byte_message.to_vec()) + .context("unable to parse message")?; + Ok(serialized_obj) + } else { + Ok(None) + } + } + + /// Read the next null-delimited JSON instruction, with a default timeout. + /// + /// This is useful for parsing the initial message of a stream for handshake or + /// other protocol purposes, where we do not want to wait indefinitely. + pub async fn recv_timeout(&mut self) -> Result> { + timeout(NETWORK_TIMEOUT, self.recv()) + .await + .context("timed out waiting for initial message")? + } + + /// Send a null-terminated JSON instruction on a stream. + pub async fn send(&mut self, msg: T) -> Result<()> { + trace!("sending json message"); + self.0.send(serde_json::to_string(&msg)?).await?; + Ok(()) + } + + /// Consume this object, returning current buffers and the inner transport. + pub fn into_parts(self) -> FramedParts { + self.0.into_parts() + } +} + /// Copy data mutually between two read/write streams. pub async fn proxy(stream1: S1, stream2: S2) -> io::Result<()> where @@ -62,41 +114,3 @@ where }?; Ok(()) } - -/// Read the next null-delimited JSON instruction from a stream. -pub async fn recv_json( - reader: &mut (impl AsyncBufRead + Unpin), - buf: &mut Vec, -) -> Result> { - trace!("waiting to receive json message"); - buf.clear(); - reader.read_until(0, buf).await?; - if buf.is_empty() { - return Ok(None); - } - if buf.last() == Some(&0) { - buf.pop(); - } - Ok(serde_json::from_slice(buf).context("failed to parse JSON")?) -} - -/// Read the next null-delimited JSON instruction, with a default timeout. -/// -/// This is useful for parsing the initial message of a stream for handshake or -/// other protocol purposes, where we do not want to wait indefinitely. -pub async fn recv_json_timeout( - reader: &mut (impl AsyncBufRead + Unpin), -) -> Result> { - timeout(NETWORK_TIMEOUT, recv_json(reader, &mut Vec::new())) - .await - .context("timed out waiting for initial message")? -} - -/// Send a null-terminated JSON instruction on a stream. -pub async fn send_json(writer: &mut (impl AsyncWrite + Unpin), msg: T) -> Result<()> { - trace!("sending json message"); - let msg = serde_json::to_vec(&msg)?; - writer.write_all(&msg).await?; - writer.write_all(&[0]).await?; - Ok(()) -} diff --git a/tests/auth_test.rs b/tests/auth_test.rs index e535b01..d80a345 100644 --- a/tests/auth_test.rs +++ b/tests/auth_test.rs @@ -1,14 +1,14 @@ use anyhow::Result; -use bore_cli::auth::Authenticator; -use tokio::io::{self, BufReader}; +use bore_cli::{auth::Authenticator, shared::Delimited}; +use tokio::io::{self}; #[tokio::test] async fn auth_handshake() -> Result<()> { let auth = Authenticator::new("some secret string"); let (client, server) = io::duplex(8); // Ensure correctness with limited capacity. - let mut client = BufReader::new(client); - let mut server = BufReader::new(server); + let mut client = Delimited::new(client); + let mut server = Delimited::new(server); tokio::try_join!( auth.client_handshake(&mut client), @@ -24,8 +24,8 @@ async fn auth_handshake_fail() { let auth2 = Authenticator::new("different server secret"); let (client, server) = io::duplex(8); // Ensure correctness with limited capacity. - let mut client = BufReader::new(client); - let mut server = BufReader::new(server); + let mut client = Delimited::new(client); + let mut server = Delimited::new(server); let result = tokio::try_join!( auth.client_handshake(&mut client), diff --git a/tests/e2e_test.rs b/tests/e2e_test.rs index 958b154..bc77364 100644 --- a/tests/e2e_test.rs +++ b/tests/e2e_test.rs @@ -2,7 +2,7 @@ use std::net::SocketAddr; use std::time::Duration; use anyhow::{anyhow, Result}; -use bore_cli::{client::Client, server::Server}; +use bore_cli::{client::Client, server::Server, shared::CONTROL_PORT}; use lazy_static::lazy_static; use rstest::*; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -99,3 +99,21 @@ async fn invalid_address() -> Result<()> { )?; Ok(()) } + +#[tokio::test] +async fn very_long_frame() -> Result<()> { + let _guard = SERIAL_GUARD.lock().await; + + spawn_server(None).await; + let mut attacker = TcpStream::connect(("localhost", CONTROL_PORT)).await?; + + // Slowly send a very long frame. + for _ in 0..10 { + let result = attacker.write_all(&[42u8; 100000]).await; + if result.is_err() { + return Ok(()); + } + time::sleep(Duration::from_millis(10)).await; + } + panic!("did not exit after a 1 MB frame"); +}