mirror of https://github.com/ekzhang/bore.git
Improve stability by exiting immediately on common errors (#2)
* Kill connections immediately on missing or close * Add timeout to initial protocol messages * Add low-level tracing for JSON messages * Add timeout to initial TCP connections
This commit is contained in:
parent
c1efefeddf
commit
2d0dcf9889
|
@ -6,7 +6,7 @@ use sha2::{Digest, Sha256};
|
|||
use tokio::io::{AsyncBufRead, AsyncWrite};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::shared::{recv_json, send_json, ClientMessage, ServerMessage};
|
||||
use crate::shared::{recv_json_timeout, send_json, ClientMessage, ServerMessage};
|
||||
|
||||
/// Wrapper around a MAC used for authenticating clients that have a secret.
|
||||
pub struct Authenticator(Hmac<Sha256>);
|
||||
|
@ -54,7 +54,7 @@ impl Authenticator {
|
|||
) -> Result<()> {
|
||||
let challenge = Uuid::new_v4();
|
||||
send_json(stream, ServerMessage::Challenge(challenge)).await?;
|
||||
match recv_json(stream, &mut Vec::new()).await? {
|
||||
match recv_json_timeout(stream).await? {
|
||||
Some(ClientMessage::Authenticate(tag)) => {
|
||||
ensure!(self.validate(&challenge, &tag), "invalid secret");
|
||||
Ok(())
|
||||
|
@ -68,7 +68,7 @@ impl Authenticator {
|
|||
&self,
|
||||
stream: &mut (impl AsyncBufRead + AsyncWrite + Unpin),
|
||||
) -> Result<()> {
|
||||
let challenge = match recv_json(stream, &mut Vec::new()).await? {
|
||||
let challenge = match recv_json_timeout(stream).await? {
|
||||
Some(ServerMessage::Challenge(challenge)) => challenge,
|
||||
_ => bail!("expected authentication challenge, but no secret was required"),
|
||||
};
|
||||
|
|
|
@ -3,12 +3,15 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use tokio::{io::BufReader, net::TcpStream};
|
||||
use tokio::{io::BufReader, net::TcpStream, time::timeout};
|
||||
use tracing::{error, info, info_span, warn, Instrument};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::auth::Authenticator;
|
||||
use crate::shared::{proxy, recv_json, send_json, ClientMessage, ServerMessage, CONTROL_PORT};
|
||||
use crate::shared::{
|
||||
proxy, recv_json, recv_json_timeout, send_json, ClientMessage, ServerMessage, CONTROL_PORT,
|
||||
NETWORK_TIMEOUT,
|
||||
};
|
||||
|
||||
/// State structure for the client.
|
||||
pub struct Client {
|
||||
|
@ -31,10 +34,7 @@ pub struct Client {
|
|||
impl Client {
|
||||
/// Create a new client.
|
||||
pub async fn new(local_port: u16, to: &str, port: u16, secret: Option<&str>) -> Result<Self> {
|
||||
let stream = TcpStream::connect((to, CONTROL_PORT))
|
||||
.await
|
||||
.with_context(|| format!("could not connect to {to}:{CONTROL_PORT}"))?;
|
||||
let mut stream = BufReader::new(stream);
|
||||
let mut stream = BufReader::new(connect_with_timeout(to, CONTROL_PORT).await?);
|
||||
|
||||
let auth = secret.map(Authenticator::new);
|
||||
if let Some(auth) = &auth {
|
||||
|
@ -42,7 +42,7 @@ impl Client {
|
|||
}
|
||||
|
||||
send_json(&mut stream, ClientMessage::Hello(port)).await?;
|
||||
let remote_port = match recv_json(&mut stream, &mut Vec::new()).await? {
|
||||
let remote_port = match recv_json_timeout(&mut stream).await? {
|
||||
Some(ServerMessage::Hello(remote_port)) => remote_port,
|
||||
Some(ServerMessage::Error(message)) => bail!("server error: {message}"),
|
||||
Some(ServerMessage::Challenge(_)) => {
|
||||
|
@ -99,21 +99,23 @@ impl Client {
|
|||
}
|
||||
|
||||
async fn handle_connection(&self, id: Uuid) -> Result<()> {
|
||||
let local_conn = TcpStream::connect(("localhost", self.local_port))
|
||||
.await
|
||||
.context("failed TCP connection to local port")?;
|
||||
let mut remote_conn = BufReader::new(
|
||||
TcpStream::connect((&self.to[..], CONTROL_PORT))
|
||||
.await
|
||||
.context("failed TCP connection to remote port")?,
|
||||
);
|
||||
|
||||
let mut remote_conn =
|
||||
BufReader::new(connect_with_timeout(&self.to[..], CONTROL_PORT).await?);
|
||||
if let Some(auth) = &self.auth {
|
||||
auth.client_handshake(&mut remote_conn).await?;
|
||||
}
|
||||
|
||||
send_json(&mut remote_conn, ClientMessage::Accept(id)).await?;
|
||||
|
||||
let local_conn = connect_with_timeout("localhost", self.local_port).await?;
|
||||
proxy(local_conn, remote_conn).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn connect_with_timeout(to: &str, port: u16) -> Result<TcpStream> {
|
||||
match timeout(NETWORK_TIMEOUT, TcpStream::connect((to, port))).await {
|
||||
Ok(res) => res,
|
||||
Err(err) => Err(err.into()),
|
||||
}
|
||||
.with_context(|| format!("could not connect to {to}:{port}"))
|
||||
}
|
||||
|
|
|
@ -13,7 +13,9 @@ use tracing::{info, info_span, warn, Instrument};
|
|||
use uuid::Uuid;
|
||||
|
||||
use crate::auth::Authenticator;
|
||||
use crate::shared::{proxy, recv_json, send_json, ClientMessage, ServerMessage, CONTROL_PORT};
|
||||
use crate::shared::{
|
||||
proxy, recv_json_timeout, send_json, ClientMessage, ServerMessage, CONTROL_PORT,
|
||||
};
|
||||
|
||||
/// State structure for the server.
|
||||
pub struct Server {
|
||||
|
@ -71,10 +73,7 @@ impl Server {
|
|||
}
|
||||
}
|
||||
|
||||
let mut buf = Vec::new();
|
||||
let msg = recv_json(&mut stream, &mut buf).await?;
|
||||
|
||||
match msg {
|
||||
match recv_json_timeout(&mut stream).await? {
|
||||
Some(ClientMessage::Authenticate(_)) => {
|
||||
warn!("unexpected authenticate");
|
||||
Ok(())
|
||||
|
|
|
@ -1,14 +1,21 @@
|
|||
//! Shared data structures, utilities, and protocol definitions.
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::io::{self, AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::time::timeout;
|
||||
use tracing::trace;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// TCP port used for control connections with the server.
|
||||
pub const CONTROL_PORT: u16 = 7835;
|
||||
|
||||
/// Timeout for network connections and initial protocol messages.
|
||||
pub const NETWORK_TIMEOUT: Duration = Duration::from_secs(3);
|
||||
|
||||
/// A message from the client on the control connection.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub enum ClientMessage {
|
||||
|
@ -49,10 +56,10 @@ where
|
|||
{
|
||||
let (mut s1_read, mut s1_write) = io::split(stream1);
|
||||
let (mut s2_read, mut s2_write) = io::split(stream2);
|
||||
tokio::try_join!(
|
||||
io::copy(&mut s1_read, &mut s2_write),
|
||||
io::copy(&mut s2_read, &mut s1_write),
|
||||
)?;
|
||||
tokio::select! {
|
||||
res = io::copy(&mut s1_read, &mut s2_write) => res,
|
||||
res = io::copy(&mut s2_read, &mut s1_write) => res,
|
||||
}?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -61,6 +68,7 @@ pub async fn recv_json<T: DeserializeOwned>(
|
|||
reader: &mut (impl AsyncBufRead + Unpin),
|
||||
buf: &mut Vec<u8>,
|
||||
) -> Result<Option<T>> {
|
||||
trace!("waiting to receive json message");
|
||||
buf.clear();
|
||||
reader.read_until(0, buf).await?;
|
||||
if buf.is_empty() {
|
||||
|
@ -72,8 +80,21 @@ pub async fn recv_json<T: DeserializeOwned>(
|
|||
Ok(serde_json::from_slice(buf).context("failed to parse JSON")?)
|
||||
}
|
||||
|
||||
/// Read the next null-delimited JSON instruction, with a default timeout.
|
||||
///
|
||||
/// This is useful for parsing the initial message of a stream for handshake or
|
||||
/// other protocol purposes, where we do not want to wait indefinitely.
|
||||
pub async fn recv_json_timeout<T: DeserializeOwned>(
|
||||
reader: &mut (impl AsyncBufRead + Unpin),
|
||||
) -> Result<Option<T>> {
|
||||
timeout(NETWORK_TIMEOUT, recv_json(reader, &mut Vec::new()))
|
||||
.await
|
||||
.context("timed out waiting for initial message")?
|
||||
}
|
||||
|
||||
/// Send a null-terminated JSON instruction on a stream.
|
||||
pub async fn send_json<T: Serialize>(writer: &mut (impl AsyncWrite + Unpin), msg: T) -> Result<()> {
|
||||
trace!("sending json message");
|
||||
let msg = serde_json::to_vec(&msg)?;
|
||||
writer.write_all(&msg).await?;
|
||||
writer.write_all(&[0]).await?;
|
||||
|
|
Loading…
Reference in New Issue