mirror of https://github.com/ekzhang/bore.git
Improve stability by exiting immediately on common errors (#2)
* Kill connections immediately on missing or close * Add timeout to initial protocol messages * Add low-level tracing for JSON messages * Add timeout to initial TCP connections
This commit is contained in:
parent
c1efefeddf
commit
2d0dcf9889
|
@ -6,7 +6,7 @@ use sha2::{Digest, Sha256};
|
||||||
use tokio::io::{AsyncBufRead, AsyncWrite};
|
use tokio::io::{AsyncBufRead, AsyncWrite};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::shared::{recv_json, send_json, ClientMessage, ServerMessage};
|
use crate::shared::{recv_json_timeout, send_json, ClientMessage, ServerMessage};
|
||||||
|
|
||||||
/// Wrapper around a MAC used for authenticating clients that have a secret.
|
/// Wrapper around a MAC used for authenticating clients that have a secret.
|
||||||
pub struct Authenticator(Hmac<Sha256>);
|
pub struct Authenticator(Hmac<Sha256>);
|
||||||
|
@ -54,7 +54,7 @@ impl Authenticator {
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let challenge = Uuid::new_v4();
|
let challenge = Uuid::new_v4();
|
||||||
send_json(stream, ServerMessage::Challenge(challenge)).await?;
|
send_json(stream, ServerMessage::Challenge(challenge)).await?;
|
||||||
match recv_json(stream, &mut Vec::new()).await? {
|
match recv_json_timeout(stream).await? {
|
||||||
Some(ClientMessage::Authenticate(tag)) => {
|
Some(ClientMessage::Authenticate(tag)) => {
|
||||||
ensure!(self.validate(&challenge, &tag), "invalid secret");
|
ensure!(self.validate(&challenge, &tag), "invalid secret");
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -68,7 +68,7 @@ impl Authenticator {
|
||||||
&self,
|
&self,
|
||||||
stream: &mut (impl AsyncBufRead + AsyncWrite + Unpin),
|
stream: &mut (impl AsyncBufRead + AsyncWrite + Unpin),
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let challenge = match recv_json(stream, &mut Vec::new()).await? {
|
let challenge = match recv_json_timeout(stream).await? {
|
||||||
Some(ServerMessage::Challenge(challenge)) => challenge,
|
Some(ServerMessage::Challenge(challenge)) => challenge,
|
||||||
_ => bail!("expected authentication challenge, but no secret was required"),
|
_ => bail!("expected authentication challenge, but no secret was required"),
|
||||||
};
|
};
|
||||||
|
|
|
@ -3,12 +3,15 @@
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use anyhow::{bail, Context, Result};
|
use anyhow::{bail, Context, Result};
|
||||||
use tokio::{io::BufReader, net::TcpStream};
|
use tokio::{io::BufReader, net::TcpStream, time::timeout};
|
||||||
use tracing::{error, info, info_span, warn, Instrument};
|
use tracing::{error, info, info_span, warn, Instrument};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::auth::Authenticator;
|
use crate::auth::Authenticator;
|
||||||
use crate::shared::{proxy, recv_json, send_json, ClientMessage, ServerMessage, CONTROL_PORT};
|
use crate::shared::{
|
||||||
|
proxy, recv_json, recv_json_timeout, send_json, ClientMessage, ServerMessage, CONTROL_PORT,
|
||||||
|
NETWORK_TIMEOUT,
|
||||||
|
};
|
||||||
|
|
||||||
/// State structure for the client.
|
/// State structure for the client.
|
||||||
pub struct Client {
|
pub struct Client {
|
||||||
|
@ -31,10 +34,7 @@ pub struct Client {
|
||||||
impl Client {
|
impl Client {
|
||||||
/// Create a new client.
|
/// Create a new client.
|
||||||
pub async fn new(local_port: u16, to: &str, port: u16, secret: Option<&str>) -> Result<Self> {
|
pub async fn new(local_port: u16, to: &str, port: u16, secret: Option<&str>) -> Result<Self> {
|
||||||
let stream = TcpStream::connect((to, CONTROL_PORT))
|
let mut stream = BufReader::new(connect_with_timeout(to, CONTROL_PORT).await?);
|
||||||
.await
|
|
||||||
.with_context(|| format!("could not connect to {to}:{CONTROL_PORT}"))?;
|
|
||||||
let mut stream = BufReader::new(stream);
|
|
||||||
|
|
||||||
let auth = secret.map(Authenticator::new);
|
let auth = secret.map(Authenticator::new);
|
||||||
if let Some(auth) = &auth {
|
if let Some(auth) = &auth {
|
||||||
|
@ -42,7 +42,7 @@ impl Client {
|
||||||
}
|
}
|
||||||
|
|
||||||
send_json(&mut stream, ClientMessage::Hello(port)).await?;
|
send_json(&mut stream, ClientMessage::Hello(port)).await?;
|
||||||
let remote_port = match recv_json(&mut stream, &mut Vec::new()).await? {
|
let remote_port = match recv_json_timeout(&mut stream).await? {
|
||||||
Some(ServerMessage::Hello(remote_port)) => remote_port,
|
Some(ServerMessage::Hello(remote_port)) => remote_port,
|
||||||
Some(ServerMessage::Error(message)) => bail!("server error: {message}"),
|
Some(ServerMessage::Error(message)) => bail!("server error: {message}"),
|
||||||
Some(ServerMessage::Challenge(_)) => {
|
Some(ServerMessage::Challenge(_)) => {
|
||||||
|
@ -99,21 +99,23 @@ impl Client {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_connection(&self, id: Uuid) -> Result<()> {
|
async fn handle_connection(&self, id: Uuid) -> Result<()> {
|
||||||
let local_conn = TcpStream::connect(("localhost", self.local_port))
|
let mut remote_conn =
|
||||||
.await
|
BufReader::new(connect_with_timeout(&self.to[..], CONTROL_PORT).await?);
|
||||||
.context("failed TCP connection to local port")?;
|
|
||||||
let mut remote_conn = BufReader::new(
|
|
||||||
TcpStream::connect((&self.to[..], CONTROL_PORT))
|
|
||||||
.await
|
|
||||||
.context("failed TCP connection to remote port")?,
|
|
||||||
);
|
|
||||||
|
|
||||||
if let Some(auth) = &self.auth {
|
if let Some(auth) = &self.auth {
|
||||||
auth.client_handshake(&mut remote_conn).await?;
|
auth.client_handshake(&mut remote_conn).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
send_json(&mut remote_conn, ClientMessage::Accept(id)).await?;
|
send_json(&mut remote_conn, ClientMessage::Accept(id)).await?;
|
||||||
|
|
||||||
|
let local_conn = connect_with_timeout("localhost", self.local_port).await?;
|
||||||
proxy(local_conn, remote_conn).await?;
|
proxy(local_conn, remote_conn).await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn connect_with_timeout(to: &str, port: u16) -> Result<TcpStream> {
|
||||||
|
match timeout(NETWORK_TIMEOUT, TcpStream::connect((to, port))).await {
|
||||||
|
Ok(res) => res,
|
||||||
|
Err(err) => Err(err.into()),
|
||||||
|
}
|
||||||
|
.with_context(|| format!("could not connect to {to}:{port}"))
|
||||||
|
}
|
||||||
|
|
|
@ -13,7 +13,9 @@ use tracing::{info, info_span, warn, Instrument};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::auth::Authenticator;
|
use crate::auth::Authenticator;
|
||||||
use crate::shared::{proxy, recv_json, send_json, ClientMessage, ServerMessage, CONTROL_PORT};
|
use crate::shared::{
|
||||||
|
proxy, recv_json_timeout, send_json, ClientMessage, ServerMessage, CONTROL_PORT,
|
||||||
|
};
|
||||||
|
|
||||||
/// State structure for the server.
|
/// State structure for the server.
|
||||||
pub struct Server {
|
pub struct Server {
|
||||||
|
@ -71,10 +73,7 @@ impl Server {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut buf = Vec::new();
|
match recv_json_timeout(&mut stream).await? {
|
||||||
let msg = recv_json(&mut stream, &mut buf).await?;
|
|
||||||
|
|
||||||
match msg {
|
|
||||||
Some(ClientMessage::Authenticate(_)) => {
|
Some(ClientMessage::Authenticate(_)) => {
|
||||||
warn!("unexpected authenticate");
|
warn!("unexpected authenticate");
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
|
@ -1,14 +1,21 @@
|
||||||
//! Shared data structures, utilities, and protocol definitions.
|
//! Shared data structures, utilities, and protocol definitions.
|
||||||
|
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use serde::de::DeserializeOwned;
|
use serde::de::DeserializeOwned;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tokio::io::{self, AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt};
|
use tokio::io::{self, AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||||
|
use tokio::time::timeout;
|
||||||
|
use tracing::trace;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
/// TCP port used for control connections with the server.
|
/// TCP port used for control connections with the server.
|
||||||
pub const CONTROL_PORT: u16 = 7835;
|
pub const CONTROL_PORT: u16 = 7835;
|
||||||
|
|
||||||
|
/// Timeout for network connections and initial protocol messages.
|
||||||
|
pub const NETWORK_TIMEOUT: Duration = Duration::from_secs(3);
|
||||||
|
|
||||||
/// A message from the client on the control connection.
|
/// A message from the client on the control connection.
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub enum ClientMessage {
|
pub enum ClientMessage {
|
||||||
|
@ -49,10 +56,10 @@ where
|
||||||
{
|
{
|
||||||
let (mut s1_read, mut s1_write) = io::split(stream1);
|
let (mut s1_read, mut s1_write) = io::split(stream1);
|
||||||
let (mut s2_read, mut s2_write) = io::split(stream2);
|
let (mut s2_read, mut s2_write) = io::split(stream2);
|
||||||
tokio::try_join!(
|
tokio::select! {
|
||||||
io::copy(&mut s1_read, &mut s2_write),
|
res = io::copy(&mut s1_read, &mut s2_write) => res,
|
||||||
io::copy(&mut s2_read, &mut s1_write),
|
res = io::copy(&mut s2_read, &mut s1_write) => res,
|
||||||
)?;
|
}?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -61,6 +68,7 @@ pub async fn recv_json<T: DeserializeOwned>(
|
||||||
reader: &mut (impl AsyncBufRead + Unpin),
|
reader: &mut (impl AsyncBufRead + Unpin),
|
||||||
buf: &mut Vec<u8>,
|
buf: &mut Vec<u8>,
|
||||||
) -> Result<Option<T>> {
|
) -> Result<Option<T>> {
|
||||||
|
trace!("waiting to receive json message");
|
||||||
buf.clear();
|
buf.clear();
|
||||||
reader.read_until(0, buf).await?;
|
reader.read_until(0, buf).await?;
|
||||||
if buf.is_empty() {
|
if buf.is_empty() {
|
||||||
|
@ -72,8 +80,21 @@ pub async fn recv_json<T: DeserializeOwned>(
|
||||||
Ok(serde_json::from_slice(buf).context("failed to parse JSON")?)
|
Ok(serde_json::from_slice(buf).context("failed to parse JSON")?)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Read the next null-delimited JSON instruction, with a default timeout.
|
||||||
|
///
|
||||||
|
/// This is useful for parsing the initial message of a stream for handshake or
|
||||||
|
/// other protocol purposes, where we do not want to wait indefinitely.
|
||||||
|
pub async fn recv_json_timeout<T: DeserializeOwned>(
|
||||||
|
reader: &mut (impl AsyncBufRead + Unpin),
|
||||||
|
) -> Result<Option<T>> {
|
||||||
|
timeout(NETWORK_TIMEOUT, recv_json(reader, &mut Vec::new()))
|
||||||
|
.await
|
||||||
|
.context("timed out waiting for initial message")?
|
||||||
|
}
|
||||||
|
|
||||||
/// Send a null-terminated JSON instruction on a stream.
|
/// Send a null-terminated JSON instruction on a stream.
|
||||||
pub async fn send_json<T: Serialize>(writer: &mut (impl AsyncWrite + Unpin), msg: T) -> Result<()> {
|
pub async fn send_json<T: Serialize>(writer: &mut (impl AsyncWrite + Unpin), msg: T) -> Result<()> {
|
||||||
|
trace!("sending json message");
|
||||||
let msg = serde_json::to_vec(&msg)?;
|
let msg = serde_json::to_vec(&msg)?;
|
||||||
writer.write_all(&msg).await?;
|
writer.write_all(&msg).await?;
|
||||||
writer.write_all(&[0]).await?;
|
writer.write_all(&[0]).await?;
|
||||||
|
|
Loading…
Reference in New Issue