Improve stability by exiting immediately on common errors (#2)

* Kill connections immediately on missing or close

* Add timeout to initial protocol messages

* Add low-level tracing for JSON messages

* Add timeout to initial TCP connections
This commit is contained in:
Eric Zhang 2022-04-08 15:55:54 -04:00 committed by GitHub
parent c1efefeddf
commit 2d0dcf9889
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 51 additions and 29 deletions

View File

@ -6,7 +6,7 @@ use sha2::{Digest, Sha256};
use tokio::io::{AsyncBufRead, AsyncWrite}; use tokio::io::{AsyncBufRead, AsyncWrite};
use uuid::Uuid; use uuid::Uuid;
use crate::shared::{recv_json, send_json, ClientMessage, ServerMessage}; use crate::shared::{recv_json_timeout, send_json, ClientMessage, ServerMessage};
/// Wrapper around a MAC used for authenticating clients that have a secret. /// Wrapper around a MAC used for authenticating clients that have a secret.
pub struct Authenticator(Hmac<Sha256>); pub struct Authenticator(Hmac<Sha256>);
@ -54,7 +54,7 @@ impl Authenticator {
) -> Result<()> { ) -> Result<()> {
let challenge = Uuid::new_v4(); let challenge = Uuid::new_v4();
send_json(stream, ServerMessage::Challenge(challenge)).await?; send_json(stream, ServerMessage::Challenge(challenge)).await?;
match recv_json(stream, &mut Vec::new()).await? { match recv_json_timeout(stream).await? {
Some(ClientMessage::Authenticate(tag)) => { Some(ClientMessage::Authenticate(tag)) => {
ensure!(self.validate(&challenge, &tag), "invalid secret"); ensure!(self.validate(&challenge, &tag), "invalid secret");
Ok(()) Ok(())
@ -68,7 +68,7 @@ impl Authenticator {
&self, &self,
stream: &mut (impl AsyncBufRead + AsyncWrite + Unpin), stream: &mut (impl AsyncBufRead + AsyncWrite + Unpin),
) -> Result<()> { ) -> Result<()> {
let challenge = match recv_json(stream, &mut Vec::new()).await? { let challenge = match recv_json_timeout(stream).await? {
Some(ServerMessage::Challenge(challenge)) => challenge, Some(ServerMessage::Challenge(challenge)) => challenge,
_ => bail!("expected authentication challenge, but no secret was required"), _ => bail!("expected authentication challenge, but no secret was required"),
}; };

View File

@ -3,12 +3,15 @@
use std::sync::Arc; use std::sync::Arc;
use anyhow::{bail, Context, Result}; use anyhow::{bail, Context, Result};
use tokio::{io::BufReader, net::TcpStream}; use tokio::{io::BufReader, net::TcpStream, time::timeout};
use tracing::{error, info, info_span, warn, Instrument}; use tracing::{error, info, info_span, warn, Instrument};
use uuid::Uuid; use uuid::Uuid;
use crate::auth::Authenticator; use crate::auth::Authenticator;
use crate::shared::{proxy, recv_json, send_json, ClientMessage, ServerMessage, CONTROL_PORT}; use crate::shared::{
proxy, recv_json, recv_json_timeout, send_json, ClientMessage, ServerMessage, CONTROL_PORT,
NETWORK_TIMEOUT,
};
/// State structure for the client. /// State structure for the client.
pub struct Client { pub struct Client {
@ -31,10 +34,7 @@ pub struct Client {
impl Client { impl Client {
/// Create a new client. /// Create a new client.
pub async fn new(local_port: u16, to: &str, port: u16, secret: Option<&str>) -> Result<Self> { pub async fn new(local_port: u16, to: &str, port: u16, secret: Option<&str>) -> Result<Self> {
let stream = TcpStream::connect((to, CONTROL_PORT)) let mut stream = BufReader::new(connect_with_timeout(to, CONTROL_PORT).await?);
.await
.with_context(|| format!("could not connect to {to}:{CONTROL_PORT}"))?;
let mut stream = BufReader::new(stream);
let auth = secret.map(Authenticator::new); let auth = secret.map(Authenticator::new);
if let Some(auth) = &auth { if let Some(auth) = &auth {
@ -42,7 +42,7 @@ impl Client {
} }
send_json(&mut stream, ClientMessage::Hello(port)).await?; send_json(&mut stream, ClientMessage::Hello(port)).await?;
let remote_port = match recv_json(&mut stream, &mut Vec::new()).await? { let remote_port = match recv_json_timeout(&mut stream).await? {
Some(ServerMessage::Hello(remote_port)) => remote_port, Some(ServerMessage::Hello(remote_port)) => remote_port,
Some(ServerMessage::Error(message)) => bail!("server error: {message}"), Some(ServerMessage::Error(message)) => bail!("server error: {message}"),
Some(ServerMessage::Challenge(_)) => { Some(ServerMessage::Challenge(_)) => {
@ -99,21 +99,23 @@ impl Client {
} }
async fn handle_connection(&self, id: Uuid) -> Result<()> { async fn handle_connection(&self, id: Uuid) -> Result<()> {
let local_conn = TcpStream::connect(("localhost", self.local_port)) let mut remote_conn =
.await BufReader::new(connect_with_timeout(&self.to[..], CONTROL_PORT).await?);
.context("failed TCP connection to local port")?;
let mut remote_conn = BufReader::new(
TcpStream::connect((&self.to[..], CONTROL_PORT))
.await
.context("failed TCP connection to remote port")?,
);
if let Some(auth) = &self.auth { if let Some(auth) = &self.auth {
auth.client_handshake(&mut remote_conn).await?; auth.client_handshake(&mut remote_conn).await?;
} }
send_json(&mut remote_conn, ClientMessage::Accept(id)).await?; send_json(&mut remote_conn, ClientMessage::Accept(id)).await?;
let local_conn = connect_with_timeout("localhost", self.local_port).await?;
proxy(local_conn, remote_conn).await?; proxy(local_conn, remote_conn).await?;
Ok(()) Ok(())
} }
} }
async fn connect_with_timeout(to: &str, port: u16) -> Result<TcpStream> {
match timeout(NETWORK_TIMEOUT, TcpStream::connect((to, port))).await {
Ok(res) => res,
Err(err) => Err(err.into()),
}
.with_context(|| format!("could not connect to {to}:{port}"))
}

View File

@ -13,7 +13,9 @@ use tracing::{info, info_span, warn, Instrument};
use uuid::Uuid; use uuid::Uuid;
use crate::auth::Authenticator; use crate::auth::Authenticator;
use crate::shared::{proxy, recv_json, send_json, ClientMessage, ServerMessage, CONTROL_PORT}; use crate::shared::{
proxy, recv_json_timeout, send_json, ClientMessage, ServerMessage, CONTROL_PORT,
};
/// State structure for the server. /// State structure for the server.
pub struct Server { pub struct Server {
@ -71,10 +73,7 @@ impl Server {
} }
} }
let mut buf = Vec::new(); match recv_json_timeout(&mut stream).await? {
let msg = recv_json(&mut stream, &mut buf).await?;
match msg {
Some(ClientMessage::Authenticate(_)) => { Some(ClientMessage::Authenticate(_)) => {
warn!("unexpected authenticate"); warn!("unexpected authenticate");
Ok(()) Ok(())

View File

@ -1,14 +1,21 @@
//! Shared data structures, utilities, and protocol definitions. //! Shared data structures, utilities, and protocol definitions.
use std::time::Duration;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::io::{self, AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::io::{self, AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::time::timeout;
use tracing::trace;
use uuid::Uuid; use uuid::Uuid;
/// TCP port used for control connections with the server. /// TCP port used for control connections with the server.
pub const CONTROL_PORT: u16 = 7835; pub const CONTROL_PORT: u16 = 7835;
/// Timeout for network connections and initial protocol messages.
pub const NETWORK_TIMEOUT: Duration = Duration::from_secs(3);
/// A message from the client on the control connection. /// A message from the client on the control connection.
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub enum ClientMessage { pub enum ClientMessage {
@ -49,10 +56,10 @@ where
{ {
let (mut s1_read, mut s1_write) = io::split(stream1); let (mut s1_read, mut s1_write) = io::split(stream1);
let (mut s2_read, mut s2_write) = io::split(stream2); let (mut s2_read, mut s2_write) = io::split(stream2);
tokio::try_join!( tokio::select! {
io::copy(&mut s1_read, &mut s2_write), res = io::copy(&mut s1_read, &mut s2_write) => res,
io::copy(&mut s2_read, &mut s1_write), res = io::copy(&mut s2_read, &mut s1_write) => res,
)?; }?;
Ok(()) Ok(())
} }
@ -61,6 +68,7 @@ pub async fn recv_json<T: DeserializeOwned>(
reader: &mut (impl AsyncBufRead + Unpin), reader: &mut (impl AsyncBufRead + Unpin),
buf: &mut Vec<u8>, buf: &mut Vec<u8>,
) -> Result<Option<T>> { ) -> Result<Option<T>> {
trace!("waiting to receive json message");
buf.clear(); buf.clear();
reader.read_until(0, buf).await?; reader.read_until(0, buf).await?;
if buf.is_empty() { if buf.is_empty() {
@ -72,8 +80,21 @@ pub async fn recv_json<T: DeserializeOwned>(
Ok(serde_json::from_slice(buf).context("failed to parse JSON")?) Ok(serde_json::from_slice(buf).context("failed to parse JSON")?)
} }
/// Read the next null-delimited JSON instruction, with a default timeout.
///
/// This is useful for parsing the initial message of a stream for handshake or
/// other protocol purposes, where we do not want to wait indefinitely.
pub async fn recv_json_timeout<T: DeserializeOwned>(
reader: &mut (impl AsyncBufRead + Unpin),
) -> Result<Option<T>> {
timeout(NETWORK_TIMEOUT, recv_json(reader, &mut Vec::new()))
.await
.context("timed out waiting for initial message")?
}
/// Send a null-terminated JSON instruction on a stream. /// Send a null-terminated JSON instruction on a stream.
pub async fn send_json<T: Serialize>(writer: &mut (impl AsyncWrite + Unpin), msg: T) -> Result<()> { pub async fn send_json<T: Serialize>(writer: &mut (impl AsyncWrite + Unpin), msg: T) -> Result<()> {
trace!("sending json message");
let msg = serde_json::to_vec(&msg)?; let msg = serde_json::to_vec(&msg)?;
writer.write_all(&msg).await?; writer.write_all(&msg).await?;
writer.write_all(&[0]).await?; writer.write_all(&[0]).await?;