diff --git a/src/auth.rs b/src/auth.rs index b7ce2a1..95e19a4 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -6,7 +6,7 @@ use sha2::{Digest, Sha256}; use tokio::io::{AsyncBufRead, AsyncWrite}; use uuid::Uuid; -use crate::shared::{recv_json, send_json, ClientMessage, ServerMessage}; +use crate::shared::{recv_json_timeout, send_json, ClientMessage, ServerMessage}; /// Wrapper around a MAC used for authenticating clients that have a secret. pub struct Authenticator(Hmac); @@ -54,7 +54,7 @@ impl Authenticator { ) -> Result<()> { let challenge = Uuid::new_v4(); send_json(stream, ServerMessage::Challenge(challenge)).await?; - match recv_json(stream, &mut Vec::new()).await? { + match recv_json_timeout(stream).await? { Some(ClientMessage::Authenticate(tag)) => { ensure!(self.validate(&challenge, &tag), "invalid secret"); Ok(()) @@ -68,7 +68,7 @@ impl Authenticator { &self, stream: &mut (impl AsyncBufRead + AsyncWrite + Unpin), ) -> Result<()> { - let challenge = match recv_json(stream, &mut Vec::new()).await? { + let challenge = match recv_json_timeout(stream).await? { Some(ServerMessage::Challenge(challenge)) => challenge, _ => bail!("expected authentication challenge, but no secret was required"), }; diff --git a/src/client.rs b/src/client.rs index 7ba0d0c..4190ca7 100644 --- a/src/client.rs +++ b/src/client.rs @@ -3,12 +3,15 @@ use std::sync::Arc; use anyhow::{bail, Context, Result}; -use tokio::{io::BufReader, net::TcpStream}; +use tokio::{io::BufReader, net::TcpStream, time::timeout}; use tracing::{error, info, info_span, warn, Instrument}; use uuid::Uuid; use crate::auth::Authenticator; -use crate::shared::{proxy, recv_json, send_json, ClientMessage, ServerMessage, CONTROL_PORT}; +use crate::shared::{ + proxy, recv_json, recv_json_timeout, send_json, ClientMessage, ServerMessage, CONTROL_PORT, + NETWORK_TIMEOUT, +}; /// State structure for the client. pub struct Client { @@ -31,10 +34,7 @@ pub struct Client { impl Client { /// Create a new client. pub async fn new(local_port: u16, to: &str, port: u16, secret: Option<&str>) -> Result { - let stream = TcpStream::connect((to, CONTROL_PORT)) - .await - .with_context(|| format!("could not connect to {to}:{CONTROL_PORT}"))?; - let mut stream = BufReader::new(stream); + let mut stream = BufReader::new(connect_with_timeout(to, CONTROL_PORT).await?); let auth = secret.map(Authenticator::new); if let Some(auth) = &auth { @@ -42,7 +42,7 @@ impl Client { } send_json(&mut stream, ClientMessage::Hello(port)).await?; - let remote_port = match recv_json(&mut stream, &mut Vec::new()).await? { + let remote_port = match recv_json_timeout(&mut stream).await? { Some(ServerMessage::Hello(remote_port)) => remote_port, Some(ServerMessage::Error(message)) => bail!("server error: {message}"), Some(ServerMessage::Challenge(_)) => { @@ -99,21 +99,23 @@ impl Client { } async fn handle_connection(&self, id: Uuid) -> Result<()> { - let local_conn = TcpStream::connect(("localhost", self.local_port)) - .await - .context("failed TCP connection to local port")?; - let mut remote_conn = BufReader::new( - TcpStream::connect((&self.to[..], CONTROL_PORT)) - .await - .context("failed TCP connection to remote port")?, - ); - + let mut remote_conn = + BufReader::new(connect_with_timeout(&self.to[..], CONTROL_PORT).await?); if let Some(auth) = &self.auth { auth.client_handshake(&mut remote_conn).await?; } - send_json(&mut remote_conn, ClientMessage::Accept(id)).await?; + + let local_conn = connect_with_timeout("localhost", self.local_port).await?; proxy(local_conn, remote_conn).await?; Ok(()) } } + +async fn connect_with_timeout(to: &str, port: u16) -> Result { + match timeout(NETWORK_TIMEOUT, TcpStream::connect((to, port))).await { + Ok(res) => res, + Err(err) => Err(err.into()), + } + .with_context(|| format!("could not connect to {to}:{port}")) +} diff --git a/src/server.rs b/src/server.rs index b75667a..7192bf7 100644 --- a/src/server.rs +++ b/src/server.rs @@ -13,7 +13,9 @@ use tracing::{info, info_span, warn, Instrument}; use uuid::Uuid; use crate::auth::Authenticator; -use crate::shared::{proxy, recv_json, send_json, ClientMessage, ServerMessage, CONTROL_PORT}; +use crate::shared::{ + proxy, recv_json_timeout, send_json, ClientMessage, ServerMessage, CONTROL_PORT, +}; /// State structure for the server. pub struct Server { @@ -71,10 +73,7 @@ impl Server { } } - let mut buf = Vec::new(); - let msg = recv_json(&mut stream, &mut buf).await?; - - match msg { + match recv_json_timeout(&mut stream).await? { Some(ClientMessage::Authenticate(_)) => { warn!("unexpected authenticate"); Ok(()) diff --git a/src/shared.rs b/src/shared.rs index a01fc25..8a8fce0 100644 --- a/src/shared.rs +++ b/src/shared.rs @@ -1,14 +1,21 @@ //! Shared data structures, utilities, and protocol definitions. +use std::time::Duration; + use anyhow::{Context, Result}; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use tokio::io::{self, AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt}; +use tokio::time::timeout; +use tracing::trace; use uuid::Uuid; /// TCP port used for control connections with the server. pub const CONTROL_PORT: u16 = 7835; +/// Timeout for network connections and initial protocol messages. +pub const NETWORK_TIMEOUT: Duration = Duration::from_secs(3); + /// A message from the client on the control connection. #[derive(Debug, Serialize, Deserialize)] pub enum ClientMessage { @@ -49,10 +56,10 @@ where { let (mut s1_read, mut s1_write) = io::split(stream1); let (mut s2_read, mut s2_write) = io::split(stream2); - tokio::try_join!( - io::copy(&mut s1_read, &mut s2_write), - io::copy(&mut s2_read, &mut s1_write), - )?; + tokio::select! { + res = io::copy(&mut s1_read, &mut s2_write) => res, + res = io::copy(&mut s2_read, &mut s1_write) => res, + }?; Ok(()) } @@ -61,6 +68,7 @@ pub async fn recv_json( reader: &mut (impl AsyncBufRead + Unpin), buf: &mut Vec, ) -> Result> { + trace!("waiting to receive json message"); buf.clear(); reader.read_until(0, buf).await?; if buf.is_empty() { @@ -72,8 +80,21 @@ pub async fn recv_json( Ok(serde_json::from_slice(buf).context("failed to parse JSON")?) } +/// Read the next null-delimited JSON instruction, with a default timeout. +/// +/// This is useful for parsing the initial message of a stream for handshake or +/// other protocol purposes, where we do not want to wait indefinitely. +pub async fn recv_json_timeout( + reader: &mut (impl AsyncBufRead + Unpin), +) -> Result> { + timeout(NETWORK_TIMEOUT, recv_json(reader, &mut Vec::new())) + .await + .context("timed out waiting for initial message")? +} + /// Send a null-terminated JSON instruction on a stream. pub async fn send_json(writer: &mut (impl AsyncWrite + Unpin), msg: T) -> Result<()> { + trace!("sending json message"); let msg = serde_json::to_vec(&msg)?; writer.write_all(&msg).await?; writer.write_all(&[0]).await?;