diff --git a/src/helper.rs b/src/helper.rs index 8be2f4b..4ba8cab 100644 --- a/src/helper.rs +++ b/src/helper.rs @@ -8,17 +8,27 @@ use std::{ use anyhow::{Context, Result}; use socket2::{SockRef, TcpKeepalive}; use tokio::net::{TcpStream, ToSocketAddrs, UdpSocket}; +use tracing::error; // Tokio hesitates to expose this option...So we have to do it on our own :( // The good news is that using socket2 it can be easily done, without losing portablity. // See https://github.com/tokio-rs/tokio/issues/3082 -pub fn set_tcp_keepalive(conn: &TcpStream) -> Result<()> { +pub fn try_set_tcp_keepalive(conn: &TcpStream) -> Result<()> { let s = SockRef::from(conn); let keepalive = TcpKeepalive::new().with_time(Duration::from_secs(60)); s.set_tcp_keepalive(&keepalive) .with_context(|| "Failed to set keepalive") } +pub fn set_tcp_keepalive(conn: &TcpStream) { + if let Err(e) = try_set_tcp_keepalive(conn) { + error!( + "Failed to set TCP keepalive. The connection maybe unstable: {:?}", + e + ); + } +} + #[allow(dead_code)] pub fn feature_not_compile(feature: &str) -> ! { panic!( diff --git a/src/transport/noise.rs b/src/transport/noise.rs index 30d5b0e..4825b44 100644 --- a/src/transport/noise.rs +++ b/src/transport/noise.rs @@ -9,7 +9,6 @@ use anyhow::{anyhow, Context, Result}; use async_trait::async_trait; use snowstorm::{Builder, NoiseParams, NoiseStream}; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; -use tracing::error; pub struct NoiseTransport { config: NoiseConfig, @@ -74,18 +73,16 @@ impl Transport for NoiseTransport { async fn accept(&self, a: &Self::Acceptor) -> Result<(Self::Stream, SocketAddr)> { let (conn, addr) = a.accept().await?; + set_tcp_keepalive(&conn); + let conn = NoiseStream::handshake(conn, self.builder().build_responder()?).await?; Ok((conn, addr)) } async fn connect(&self, addr: &str) -> Result { let conn = TcpStream::connect(addr).await?; - if let Err(e) = set_tcp_keepalive(&conn) { - error!( - "Failed to set TCP keepalive. The connection maybe unstable: {:?}", - e - ); - } + set_tcp_keepalive(&conn); + let conn = NoiseStream::handshake(conn, self.builder().build_initiator()?).await?; return Ok(conn); } diff --git a/src/transport/tcp.rs b/src/transport/tcp.rs index c44210e..81d098a 100644 --- a/src/transport/tcp.rs +++ b/src/transport/tcp.rs @@ -6,7 +6,6 @@ use anyhow::Result; use async_trait::async_trait; use std::net::SocketAddr; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; -use tracing::error; #[derive(Debug)] pub struct TcpTransport {} @@ -26,17 +25,13 @@ impl Transport for TcpTransport { async fn accept(&self, a: &Self::Acceptor) -> Result<(Self::Stream, SocketAddr)> { let (s, addr) = a.accept().await?; + set_tcp_keepalive(&s); Ok((s, addr)) } async fn connect(&self, addr: &str) -> Result { let s = TcpStream::connect(addr).await?; - if let Err(e) = set_tcp_keepalive(&s) { - error!( - "Failed to set TCP keepalive. The connection maybe unstable: {:?}", - e - ); - } + set_tcp_keepalive(&s); Ok(s) } } diff --git a/src/transport/tls.rs b/src/transport/tls.rs index 517122a..652b5ca 100644 --- a/src/transport/tls.rs +++ b/src/transport/tls.rs @@ -9,7 +9,6 @@ use tokio::fs; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; use tokio_native_tls::native_tls::{self, Certificate, Identity}; use tokio_native_tls::{TlsAcceptor, TlsConnector, TlsStream}; -use tracing::error; #[derive(Debug)] pub struct TlsTransport { @@ -66,6 +65,8 @@ impl Transport for TlsTransport { async fn accept(&self, a: &Self::Acceptor) -> Result<(Self::Stream, SocketAddr)> { let (conn, addr) = a.0.accept().await?; + set_tcp_keepalive(&conn); + let conn = a.1.accept(conn).await?; Ok((conn, addr)) @@ -73,12 +74,8 @@ impl Transport for TlsTransport { async fn connect(&self, addr: &str) -> Result { let conn = TcpStream::connect(&addr).await?; - if let Err(e) = set_tcp_keepalive(&conn) { - error!( - "Failed to set TCP keepalive. The connection maybe unstable: {:?}", - e - ); - } + set_tcp_keepalive(&conn); + let connector = self.connector.as_ref().unwrap(); Ok(connector .connect(