mirror of https://github.com/rapiz1/rathole.git
feat: UDP support
This commit is contained in:
parent
65c75da633
commit
443f763800
|
@ -88,6 +88,9 @@ name = "bytes"
|
|||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c4872d67bab6358e59559027aa3b9157c53d9358c51423c17554809a8858e0f8"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
|
|
|
@ -24,7 +24,7 @@ opt-level = "s"
|
|||
|
||||
[dependencies]
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
bytes = { version = "1"}
|
||||
bytes = { version = "1", features = ["serde"] }
|
||||
clap = { version = "3.0.0-rc.7", features = ["derive"] }
|
||||
toml = "0.5"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
[client]
|
||||
remote_addr = "localhost:2333"
|
||||
default_token = "123"
|
||||
|
||||
[client.services.foo1]
|
||||
type = "udp"
|
||||
local_addr = "127.0.0.1:80"
|
|
@ -0,0 +1,7 @@
|
|||
[server]
|
||||
bind_addr = "0.0.0.0:2333"
|
||||
default_token = "123"
|
||||
|
||||
[server.services.foo1]
|
||||
type = "udp"
|
||||
bind_addr = "0.0.0.0:5202"
|
176
src/client.rs
176
src/client.rs
|
@ -1,23 +1,28 @@
|
|||
use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType};
|
||||
use crate::helper::udp_connect;
|
||||
use crate::protocol::Hello::{self, *};
|
||||
use crate::protocol::{
|
||||
self, read_ack, read_control_cmd, read_data_cmd, read_hello, Ack, Auth, ControlChannelCmd,
|
||||
DataChannelCmd, CURRENT_PROTO_VRESION, HASH_WIDTH_IN_BYTES,
|
||||
DataChannelCmd, UdpTraffic, CURRENT_PROTO_VRESION, HASH_WIDTH_IN_BYTES,
|
||||
};
|
||||
use crate::transport::{TcpTransport, Transport};
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use backoff::ExponentialBackoff;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{copy_bidirectional, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::{broadcast, oneshot};
|
||||
use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
|
||||
use tokio::net::{TcpStream, UdpSocket};
|
||||
use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
|
||||
use tokio::time::{self, Duration};
|
||||
use tracing::{debug, error, info, instrument, Instrument, Span};
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
use crate::transport::TlsTransport;
|
||||
|
||||
use crate::constants::{UDP_BUFFER_SIZE, UDP_SENDQ_SIZE, UDP_TIMEOUT};
|
||||
|
||||
// The entrypoint of running a client
|
||||
pub async fn run_client(config: &Config, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
|
||||
let config = match &config.client {
|
||||
|
@ -112,7 +117,9 @@ struct RunDataChannelArgs<T: Transport> {
|
|||
connector: Arc<T>,
|
||||
}
|
||||
|
||||
async fn run_data_channel<T: Transport>(args: Arc<RunDataChannelArgs<T>>) -> Result<()> {
|
||||
async fn do_data_channel_handshake<T: Transport>(
|
||||
args: Arc<RunDataChannelArgs<T>>,
|
||||
) -> Result<T::Stream> {
|
||||
// Retry at least every 100ms, at most for 10 seconds
|
||||
let backoff = ExponentialBackoff {
|
||||
max_interval: Duration::from_millis(100),
|
||||
|
@ -135,18 +142,165 @@ async fn run_data_channel<T: Transport>(args: Arc<RunDataChannelArgs<T>>) -> Res
|
|||
let hello = Hello::DataChannelHello(CURRENT_PROTO_VRESION, v.to_owned());
|
||||
conn.write_all(&bincode::serialize(&hello).unwrap()).await?;
|
||||
|
||||
Ok(conn)
|
||||
}
|
||||
|
||||
async fn run_data_channel<T: Transport>(args: Arc<RunDataChannelArgs<T>>) -> Result<()> {
|
||||
// Do the handshake
|
||||
let mut conn = do_data_channel_handshake(args.clone()).await?;
|
||||
|
||||
// Forward
|
||||
match read_data_cmd(&mut conn).await? {
|
||||
DataChannelCmd::StartForward => {
|
||||
let mut local = TcpStream::connect(&args.local_addr)
|
||||
.await
|
||||
.with_context(|| "Failed to conenct to local_addr")?;
|
||||
let _ = copy_bidirectional(&mut conn, &mut local).await;
|
||||
DataChannelCmd::StartForwardTcp => {
|
||||
run_data_channel_for_tcp::<T>(conn, &args.local_addr).await?;
|
||||
}
|
||||
DataChannelCmd::StartForwardUdp => {
|
||||
run_data_channel_for_udp::<T>(conn, &args.local_addr).await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Simply copying back and forth for TCP
|
||||
#[instrument(skip(conn))]
|
||||
async fn run_data_channel_for_tcp<T: Transport>(
|
||||
mut conn: T::Stream,
|
||||
local_addr: &str,
|
||||
) -> Result<()> {
|
||||
debug!("New data channel starts forwarding");
|
||||
|
||||
let mut local = TcpStream::connect(local_addr)
|
||||
.await
|
||||
.with_context(|| "Failed to conenct to local_addr")?;
|
||||
let _ = copy_bidirectional(&mut conn, &mut local).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Things get a little tricker when it gets to UDP because it's connectionless.
|
||||
// A UdpPortMap must be maintained for recent seen incoming address, giving them
|
||||
// each a local port, which is associated with a socket. So just the sender
|
||||
// to the socket will work fine for the map's value.
|
||||
type UdpPortMap = Arc<RwLock<HashMap<SocketAddr, mpsc::Sender<Bytes>>>>;
|
||||
|
||||
#[instrument(skip(conn))]
|
||||
async fn run_data_channel_for_udp<T: Transport>(conn: T::Stream, local_addr: &str) -> Result<()> {
|
||||
debug!("New data channel starts forwarding");
|
||||
|
||||
let port_map: UdpPortMap = Arc::new(RwLock::new(HashMap::new()));
|
||||
|
||||
// The channel stores UdpTraffic that needs to be sent to the server
|
||||
let (outbound_tx, mut outbound_rx) = mpsc::channel::<UdpTraffic>(UDP_SENDQ_SIZE);
|
||||
|
||||
// FIXME: https://github.com/tokio-rs/tls/issues/40
|
||||
// Maybe this is our concern
|
||||
let (mut rd, mut wr) = io::split(conn);
|
||||
|
||||
// Keep sending items from the outbound channel to the server
|
||||
tokio::spawn(async move {
|
||||
while let Some(t) = outbound_rx.recv().await {
|
||||
debug!("outbound {:?}", t);
|
||||
if t.write(&mut wr).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
loop {
|
||||
// Read a packet from the server
|
||||
let packet = UdpTraffic::read(&mut rd).await?;
|
||||
let m = port_map.read().await;
|
||||
|
||||
if m.get(&packet.from).is_none() {
|
||||
// This packet is from a address we don't see for a while,
|
||||
// which is not in the UdpPortMap.
|
||||
// So set up a mapping (and a forwarder) for it
|
||||
|
||||
// Drop the reader lock
|
||||
drop(m);
|
||||
|
||||
// Grab the writer lock
|
||||
// This is the only thread that will try to grab the writer lock
|
||||
// So no need to worry about some other thread has already set up
|
||||
// the mapping between the gap of dropping the reader lock and
|
||||
// grabbing the writer lock
|
||||
let mut m = port_map.write().await;
|
||||
|
||||
match udp_connect(local_addr).await {
|
||||
Ok(s) => {
|
||||
let (inbound_tx, inbound_rx) = mpsc::channel(UDP_SENDQ_SIZE);
|
||||
m.insert(packet.from, inbound_tx);
|
||||
tokio::spawn(run_udp_forwarder(
|
||||
s,
|
||||
inbound_rx,
|
||||
outbound_tx.clone(),
|
||||
packet.from,
|
||||
port_map.clone(),
|
||||
));
|
||||
}
|
||||
Err(e) => {
|
||||
error!("{:?}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Now there should be a udp forwarder that can receive the packet
|
||||
let m = port_map.read().await;
|
||||
if let Some(tx) = m.get(&packet.from) {
|
||||
let _ = tx.send(packet.data).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Run a UdpSocket for the visitor `from`
|
||||
async fn run_udp_forwarder(
|
||||
s: UdpSocket,
|
||||
mut inbound_rx: mpsc::Receiver<Bytes>,
|
||||
outbount_tx: mpsc::Sender<UdpTraffic>,
|
||||
from: SocketAddr,
|
||||
port_map: UdpPortMap,
|
||||
) -> Result<()> {
|
||||
let mut buf = BytesMut::new();
|
||||
buf.resize(UDP_BUFFER_SIZE, 0);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Receive from the server
|
||||
data = inbound_rx.recv() => {
|
||||
if let Some(data) = data {
|
||||
s.send(&data).await?;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
},
|
||||
|
||||
// Receive from the service
|
||||
val = s.recv(&mut buf) => {
|
||||
let len = match val {
|
||||
Ok(v) => v,
|
||||
Err(_) => {break;}
|
||||
};
|
||||
|
||||
let t = UdpTraffic{
|
||||
from,
|
||||
data: Bytes::copy_from_slice(&buf[..len])
|
||||
};
|
||||
|
||||
outbount_tx.send(t).await?;
|
||||
},
|
||||
|
||||
// No traffic for the duration of UDP_TIMEOUT, clean up the state
|
||||
_ = time::sleep(Duration::from_secs(UDP_TIMEOUT)) => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut port_map = port_map.write().await;
|
||||
port_map.remove(&from);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Control channel, using T as the transport layer
|
||||
struct ControlChannel<T: Transport> {
|
||||
digest: ServiceDigest, // SHA256 of the service name
|
||||
|
@ -163,7 +317,7 @@ struct ControlChannelHandle {
|
|||
}
|
||||
|
||||
impl<T: 'static + Transport> ControlChannel<T> {
|
||||
#[instrument(skip(self), fields(service=%self.service.name))]
|
||||
#[instrument(skip_all)]
|
||||
async fn run(&mut self) -> Result<()> {
|
||||
let mut conn = self
|
||||
.transport
|
||||
|
|
|
@ -20,14 +20,30 @@ impl Default for TransportType {
|
|||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct ClientServiceConfig {
|
||||
#[serde(rename = "type", default = "default_service_type")]
|
||||
pub service_type: ServiceType,
|
||||
#[serde(skip)]
|
||||
pub name: String,
|
||||
pub local_addr: String,
|
||||
pub token: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
|
||||
pub enum ServiceType {
|
||||
#[serde(rename = "tcp")]
|
||||
Tcp,
|
||||
#[serde(rename = "udp")]
|
||||
Udp,
|
||||
}
|
||||
|
||||
fn default_service_type() -> ServiceType {
|
||||
ServiceType::Tcp
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct ServerServiceConfig {
|
||||
#[serde(rename = "type", default = "default_service_type")]
|
||||
pub service_type: ServiceType,
|
||||
#[serde(skip)]
|
||||
pub name: String,
|
||||
pub bind_addr: String,
|
||||
|
@ -231,6 +247,7 @@ mod tests {
|
|||
cfg.services.insert(
|
||||
"foo1".into(),
|
||||
ServerServiceConfig {
|
||||
service_type: ServiceType::Tcp,
|
||||
name: "foo1".into(),
|
||||
bind_addr: "127.0.0.1:80".into(),
|
||||
token: None,
|
||||
|
@ -277,6 +294,7 @@ mod tests {
|
|||
cfg.services.insert(
|
||||
"foo1".into(),
|
||||
ClientServiceConfig {
|
||||
service_type: ServiceType::Tcp,
|
||||
name: "foo1".into(),
|
||||
local_addr: "127.0.0.1:80".into(),
|
||||
token: None,
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
use backoff::ExponentialBackoff;
|
||||
use std::time::Duration;
|
||||
|
||||
// FIXME: Determine reasonable size
|
||||
pub const UDP_BUFFER_SIZE: usize = 2048;
|
||||
pub const UDP_SENDQ_SIZE: usize = 1024;
|
||||
pub const UDP_TIMEOUT: u64 = 60;
|
||||
|
||||
pub fn listen_backoff() -> ExponentialBackoff {
|
||||
ExponentialBackoff {
|
||||
max_elapsed_time: None,
|
||||
max_interval: Duration::from_secs(1),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
|
@ -1,8 +1,13 @@
|
|||
use std::time::Duration;
|
||||
use std::{
|
||||
collections::hash_map::DefaultHasher,
|
||||
hash::{Hash, Hasher},
|
||||
net::SocketAddr,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use socket2::{SockRef, TcpKeepalive};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::net::{TcpStream, ToSocketAddrs, UdpSocket};
|
||||
|
||||
// Tokio hesitates to expose this option...So we have to do it on our own :(
|
||||
// The good news is that using socket2 it can be easily done, without losing portablity.
|
||||
|
@ -21,3 +26,78 @@ pub fn feature_not_compile(feature: &str) -> ! {
|
|||
feature
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a UDP socket and connect to `addr`
|
||||
pub async fn udp_connect<A: ToSocketAddrs>(addr: A) -> Result<UdpSocket> {
|
||||
// FIXME: This only works for IPv4
|
||||
let s = UdpSocket::bind("0.0.0.0:0").await?;
|
||||
s.connect(addr).await?;
|
||||
Ok(s)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn hash_socket_addr(a: &SocketAddr) -> u64 {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
a.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
|
||||
// Wait for the stablization of https://doc.rust-lang.org/std/primitive.i64.html#method.log2
|
||||
#[allow(dead_code)]
|
||||
fn log2_floor(x: usize) -> u8 {
|
||||
(x as f64).log2().floor() as u8
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn floor_to_pow_of_2(x: usize) -> usize {
|
||||
if x == 1 {
|
||||
return 1;
|
||||
}
|
||||
let w = log2_floor(x);
|
||||
1 << w
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::helper::{floor_to_pow_of_2, log2_floor};
|
||||
|
||||
#[test]
|
||||
fn test_log2_floor() {
|
||||
let t = [
|
||||
(2, 1),
|
||||
(3, 1),
|
||||
(4, 2),
|
||||
(8, 3),
|
||||
(9, 3),
|
||||
(15, 3),
|
||||
(16, 4),
|
||||
(1023, 9),
|
||||
(1024, 10),
|
||||
(2000, 10),
|
||||
(2048, 11),
|
||||
];
|
||||
for t in t {
|
||||
assert_eq!(log2_floor(t.0), t.1);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_floor_to_pow_of_2() {
|
||||
let t = [
|
||||
(1 as usize, 1 as usize),
|
||||
(2, 2),
|
||||
(3, 2),
|
||||
(4, 4),
|
||||
(5, 4),
|
||||
(15, 8),
|
||||
(31, 16),
|
||||
(33, 32),
|
||||
(1000, 512),
|
||||
(1500, 1024),
|
||||
(2300, 2048),
|
||||
];
|
||||
for t in t {
|
||||
assert_eq!(floor_to_pow_of_2(t.0), t.1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
mod cli;
|
||||
mod config;
|
||||
mod constants;
|
||||
mod helper;
|
||||
mod multi_map;
|
||||
mod protocol;
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
pub const HASH_WIDTH_IN_BYTES: usize = 32;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use lazy_static::lazy_static;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
|
||||
use std::net::SocketAddr;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
type ProtocolVersion = u8;
|
||||
const PROTO_V0: u8 = 0u8;
|
||||
|
@ -49,7 +51,80 @@ pub enum ControlChannelCmd {
|
|||
|
||||
#[derive(Deserialize, Serialize, Debug)]
|
||||
pub enum DataChannelCmd {
|
||||
StartForward,
|
||||
StartForwardTcp,
|
||||
StartForwardUdp,
|
||||
}
|
||||
|
||||
type UdpPacketLen = u16; // `u16` should be enough for any practical UDP traffic on the Internet
|
||||
#[derive(Deserialize, Serialize, Debug)]
|
||||
struct UdpHeader {
|
||||
from: SocketAddr,
|
||||
len: UdpPacketLen,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct UdpTraffic {
|
||||
pub from: SocketAddr,
|
||||
pub data: Bytes,
|
||||
}
|
||||
|
||||
impl UdpTraffic {
|
||||
pub async fn write<T: AsyncWrite + Unpin>(&self, writer: &mut T) -> Result<()> {
|
||||
let v = bincode::serialize(&UdpHeader {
|
||||
from: self.from,
|
||||
len: self.data.len() as UdpPacketLen,
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
writer.write_u16(v.len() as u16).await?;
|
||||
writer.write_all(&v).await?;
|
||||
|
||||
writer.write_all(&self.data).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub async fn write_slice<T: AsyncWrite + Unpin>(
|
||||
writer: &mut T,
|
||||
from: SocketAddr,
|
||||
data: &[u8],
|
||||
) -> Result<()> {
|
||||
let v = bincode::serialize(&UdpHeader {
|
||||
from,
|
||||
len: data.len() as UdpPacketLen,
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
writer.write_u16(v.len() as u16).await?;
|
||||
writer.write_all(&v).await?;
|
||||
|
||||
writer.write_all(data).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn read<T: AsyncRead + Unpin>(reader: &mut T) -> Result<UdpTraffic> {
|
||||
let len = reader.read_u16().await? as usize;
|
||||
|
||||
let mut buf = Vec::new();
|
||||
buf.resize(len, 0);
|
||||
reader
|
||||
.read_exact(&mut buf)
|
||||
.await
|
||||
.with_context(|| "Failed to read udp header")?;
|
||||
let header: UdpHeader =
|
||||
bincode::deserialize(&buf).with_context(|| "Failed to deserialize udp header")?;
|
||||
|
||||
let mut data = BytesMut::new();
|
||||
data.resize(header.len as usize, 0);
|
||||
reader.read_exact(&mut data).await?;
|
||||
|
||||
Ok(UdpTraffic {
|
||||
from: header.from,
|
||||
data: data.freeze(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn digest(data: &[u8]) -> Digest {
|
||||
|
@ -74,7 +149,7 @@ impl PacketLength {
|
|||
.unwrap() as usize;
|
||||
let c_cmd =
|
||||
bincode::serialized_size(&ControlChannelCmd::CreateDataChannel).unwrap() as usize;
|
||||
let d_cmd = bincode::serialized_size(&DataChannelCmd::StartForward).unwrap() as usize;
|
||||
let d_cmd = bincode::serialized_size(&DataChannelCmd::StartForwardTcp).unwrap() as usize;
|
||||
let ack = Ack::Ok;
|
||||
let ack = bincode::serialized_size(&ack).unwrap() as usize;
|
||||
|
||||
|
|
308
src/server.rs
308
src/server.rs
|
@ -1,8 +1,10 @@
|
|||
use crate::config::{Config, ServerConfig, ServerServiceConfig, TransportType};
|
||||
use crate::config::{Config, ServerConfig, ServerServiceConfig, ServiceType, TransportType};
|
||||
use crate::constants::{listen_backoff, UDP_BUFFER_SIZE};
|
||||
use crate::multi_map::MultiMap;
|
||||
use crate::protocol::Hello::{ControlChannelHello, DataChannelHello};
|
||||
use crate::protocol::{
|
||||
self, read_auth, read_hello, Ack, ControlChannelCmd, DataChannelCmd, Hello, HASH_WIDTH_IN_BYTES,
|
||||
self, read_auth, read_hello, Ack, ControlChannelCmd, DataChannelCmd, Hello, UdpTraffic,
|
||||
HASH_WIDTH_IN_BYTES,
|
||||
};
|
||||
#[cfg(feature = "tls")]
|
||||
use crate::transport::TlsTransport;
|
||||
|
@ -10,21 +12,23 @@ use crate::transport::{TcpTransport, Transport};
|
|||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use backoff::backoff::Backoff;
|
||||
use backoff::ExponentialBackoff;
|
||||
|
||||
use rand::RngCore;
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
|
||||
use tokio::net::{TcpListener, TcpStream, UdpSocket};
|
||||
use tokio::sync::{broadcast, mpsc, RwLock};
|
||||
use tokio::time;
|
||||
use tracing::{debug, error, info, info_span, warn, Instrument};
|
||||
use tracing::{debug, error, info, info_span, instrument, warn, Instrument, Span};
|
||||
|
||||
type ServiceDigest = protocol::Digest; // SHA256 of a service name
|
||||
type Nonce = protocol::Digest; // Also called `session_key`
|
||||
|
||||
const POOL_SIZE: usize = 64; // The number of cached connections
|
||||
const TCP_POOL_SIZE: usize = 64; // The number of cached connections for TCP servies
|
||||
const UDP_POOL_SIZE: usize = 2; // The number of cached connections for UDP services
|
||||
const CHAN_SIZE: usize = 2048; // The capacity of various chans
|
||||
|
||||
// The entrypoint of running a server
|
||||
|
@ -268,7 +272,7 @@ async fn do_control_channel_handshake<T: 'static + Transport>(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn do_data_channel_handshake<T: Transport>(
|
||||
async fn do_data_channel_handshake<T: 'static + Transport>(
|
||||
conn: T::Stream,
|
||||
control_channels: Arc<RwLock<ControlChannelMap<T>>>,
|
||||
nonce: Nonce,
|
||||
|
@ -276,9 +280,9 @@ async fn do_data_channel_handshake<T: Transport>(
|
|||
// Validate
|
||||
let control_channels_guard = control_channels.read().await;
|
||||
match control_channels_guard.get2(&nonce) {
|
||||
Some(c_ch) => {
|
||||
Some(handle) => {
|
||||
// Send the data channel to the corresponding control channel
|
||||
c_ch.conn_pool.data_ch_tx.send(conn).await?;
|
||||
handle.data_ch_tx.send(conn).await?;
|
||||
}
|
||||
None => {
|
||||
// TODO: Maybe print IP here
|
||||
|
@ -288,43 +292,74 @@ async fn do_data_channel_handshake<T: Transport>(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
// Control channel, using T as the transport layer
|
||||
struct ControlChannel<T: Transport> {
|
||||
conn: T::Stream, // The connection of control channel
|
||||
service: ServerServiceConfig, // A copy of the corresponding service config
|
||||
shutdown_rx: oneshot::Receiver<bool>, // Receives the shutdown signal
|
||||
visitor_tx: mpsc::Sender<TcpStream>, // Receives visitor connections
|
||||
}
|
||||
|
||||
// The handle of a control channel, along with the handle of a connection pool
|
||||
// Dropping it will drop the actual control channel, because `visitor_tx`
|
||||
// and `shutdown_tx` are closed
|
||||
struct ControlChannelHandle<T: Transport> {
|
||||
pub struct ControlChannelHandle<T: Transport> {
|
||||
// Shutdown the control channel.
|
||||
// Not used for now, but can be used for hot reloading
|
||||
_shutdown_tx: oneshot::Sender<bool>,
|
||||
conn_pool: ConnectionPoolHandle<T>,
|
||||
#[allow(dead_code)]
|
||||
shutdown_tx: broadcast::Sender<bool>,
|
||||
//data_ch_req_tx: mpsc::Sender<bool>,
|
||||
data_ch_tx: mpsc::Sender<T::Stream>,
|
||||
}
|
||||
|
||||
impl<T: 'static + Transport> ControlChannelHandle<T> {
|
||||
impl<T> ControlChannelHandle<T>
|
||||
where
|
||||
T: 'static + Transport,
|
||||
{
|
||||
// Create a control channel handle, where the control channel handling task
|
||||
// and the connection pool task are created.
|
||||
#[instrument(skip_all, fields(service = %service.name))]
|
||||
fn new(conn: T::Stream, service: ServerServiceConfig) -> ControlChannelHandle<T> {
|
||||
// Save the name string for logging
|
||||
let name = service.name.clone();
|
||||
|
||||
// Create a shutdown channel. The sender is not used for now, but for future use
|
||||
let (_shutdown_tx, shutdown_rx) = oneshot::channel::<bool>();
|
||||
let (shutdown_tx, shutdown_rx) = broadcast::channel::<bool>(1);
|
||||
|
||||
// Create and run the connection pool, where the visitors and data channels meet
|
||||
let conn_pool = ConnectionPoolHandle::new();
|
||||
// Store data channels
|
||||
let (data_ch_tx, data_ch_rx) = mpsc::channel(CHAN_SIZE * 2);
|
||||
|
||||
// Store data channel creation requests
|
||||
let (data_ch_req_tx, data_ch_req_rx) = mpsc::unbounded_channel();
|
||||
|
||||
match service.service_type {
|
||||
ServiceType::Tcp => tokio::spawn(
|
||||
run_tcp_connection_pool::<T>(
|
||||
service.bind_addr.clone(),
|
||||
data_ch_rx,
|
||||
data_ch_req_tx.clone(),
|
||||
shutdown_tx.subscribe(),
|
||||
)
|
||||
.instrument(Span::current()),
|
||||
),
|
||||
ServiceType::Udp => tokio::spawn(
|
||||
run_udp_connection_pool::<T>(
|
||||
service.bind_addr.clone(),
|
||||
data_ch_rx,
|
||||
data_ch_req_tx.clone(),
|
||||
shutdown_tx.subscribe(),
|
||||
)
|
||||
.instrument(Span::current()),
|
||||
),
|
||||
};
|
||||
|
||||
// Cache some data channels for later use
|
||||
let pool_size = match service.service_type {
|
||||
ServiceType::Tcp => TCP_POOL_SIZE,
|
||||
ServiceType::Udp => UDP_POOL_SIZE,
|
||||
};
|
||||
|
||||
for _i in 0..pool_size {
|
||||
if let Err(e) = data_ch_req_tx.send(true) {
|
||||
error!("Failed to request data channel {}", e);
|
||||
};
|
||||
}
|
||||
|
||||
// Create the control channel
|
||||
let ch: ControlChannel<T> = ControlChannel {
|
||||
let ch = ControlChannel::<T> {
|
||||
conn,
|
||||
shutdown_rx,
|
||||
service,
|
||||
visitor_tx: conn_pool.visitor_tx.clone(),
|
||||
data_ch_req_rx,
|
||||
};
|
||||
|
||||
// Run the control channel
|
||||
|
@ -335,53 +370,84 @@ impl<T: 'static + Transport> ControlChannelHandle<T> {
|
|||
});
|
||||
|
||||
ControlChannelHandle {
|
||||
_shutdown_tx,
|
||||
conn_pool,
|
||||
shutdown_tx,
|
||||
data_ch_tx,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn shutdown(self) {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
}
|
||||
}
|
||||
|
||||
// Control channel, using T as the transport layer. P is TcpStream or UdpTraffic
|
||||
struct ControlChannel<T: Transport> {
|
||||
conn: T::Stream, // The connection of control channel
|
||||
service: ServerServiceConfig, // A copy of the corresponding service config
|
||||
shutdown_rx: broadcast::Receiver<bool>, // Receives the shutdown signal
|
||||
data_ch_req_rx: mpsc::UnboundedReceiver<bool>, // Receives visitor connections
|
||||
}
|
||||
|
||||
impl<T: Transport> ControlChannel<T> {
|
||||
// Run a control channel
|
||||
#[tracing::instrument(skip(self), fields(service = %self.service.name))]
|
||||
#[instrument(skip(self), fields(service = %self.service.name))]
|
||||
async fn run(mut self) -> Result<()> {
|
||||
// Where the service is exposed
|
||||
let l = match TcpListener::bind(&self.service.bind_addr).await {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
let duration = Duration::from_secs(1);
|
||||
error!(
|
||||
"Failed to listen on service.bind_addr: {}. Retry in {:?}...",
|
||||
e, duration
|
||||
);
|
||||
time::sleep(duration).await;
|
||||
TcpListener::bind(&self.service.bind_addr).await?
|
||||
}
|
||||
};
|
||||
let cmd = bincode::serialize(&ControlChannelCmd::CreateDataChannel).unwrap();
|
||||
|
||||
info!("Listening at {}", &self.service.bind_addr);
|
||||
|
||||
// Each `u8` in the chan indicates a data channel creation request
|
||||
let (data_req_tx, mut data_req_rx) = mpsc::unbounded_channel::<u8>();
|
||||
|
||||
// The control channel is moved into the task, and sends CreateDataChannel
|
||||
// comamnds to the client when needed
|
||||
tokio::spawn(async move {
|
||||
let cmd = bincode::serialize(&ControlChannelCmd::CreateDataChannel).unwrap();
|
||||
while data_req_rx.recv().await.is_some() {
|
||||
if self.conn.write_all(&cmd).await.is_err() {
|
||||
// Wait for data channel requests and the shutdown signal
|
||||
loop {
|
||||
tokio::select! {
|
||||
val = self.data_ch_req_rx.recv() => {
|
||||
match val {
|
||||
Some(_) => {
|
||||
if let Err(e) = self.conn.write_all(&cmd).await.with_context(||"Failed to write data cmds") {
|
||||
error!("{:?}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
None => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
},
|
||||
// Wait for the shutdown signal
|
||||
_ = self.shutdown_rx.recv() => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Cache some data channels for later use
|
||||
for _i in 0..POOL_SIZE {
|
||||
if let Err(e) = data_req_tx.send(0) {
|
||||
error!("Failed to request data channel {}", e);
|
||||
};
|
||||
}
|
||||
|
||||
info!("Control channel shuting down");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn tcp_listen_and_send(
|
||||
addr: String,
|
||||
data_ch_req_tx: mpsc::UnboundedSender<bool>,
|
||||
mut shutdown_rx: broadcast::Receiver<bool>,
|
||||
) -> mpsc::Receiver<TcpStream> {
|
||||
let (tx, rx) = mpsc::channel(CHAN_SIZE);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let l = backoff::future::retry(listen_backoff(), || async {
|
||||
Ok(TcpListener::bind(&addr).await?)
|
||||
})
|
||||
.await
|
||||
.with_context(|| "Failed to listen for the service");
|
||||
|
||||
let l: TcpListener = match l {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
error!("{:?}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
info!("Listening at {}", &addr);
|
||||
|
||||
// Retry at least every 1s
|
||||
let mut backoff = ExponentialBackoff {
|
||||
max_interval: Duration::from_secs(1),
|
||||
|
@ -392,7 +458,6 @@ impl<T: Transport> ControlChannel<T> {
|
|||
// Wait for visitors and the shutdown signal
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Wait for visitors
|
||||
val = l.accept() => {
|
||||
match val {
|
||||
Err(e) => {
|
||||
|
@ -406,80 +471,103 @@ impl<T: Transport> ControlChannel<T> {
|
|||
error!("Too many retries. Aborting...");
|
||||
break;
|
||||
}
|
||||
},
|
||||
}
|
||||
Ok((incoming, addr)) => {
|
||||
// For every visitor, request to create a data channel
|
||||
if let Err(e) = data_req_tx.send(0) {
|
||||
if let Err(e) = data_ch_req_tx.send(true) {
|
||||
// An error indicates the control channel is broken
|
||||
// So break the loop
|
||||
error!("{}", e);
|
||||
break;
|
||||
};
|
||||
}
|
||||
|
||||
backoff.reset();
|
||||
|
||||
debug!("New visitor from {}", addr);
|
||||
|
||||
// Send the visitor to the connection pool
|
||||
let _ = self.visitor_tx.send(incoming).await;
|
||||
let _ = tx.send(incoming).await;
|
||||
}
|
||||
}
|
||||
},
|
||||
// Wait for the shutdown signal
|
||||
_ = &mut self.shutdown_rx => {
|
||||
_ = shutdown_rx.recv() => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
info!("Service shuting down");
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
rx
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ConnectionPool<T: Transport> {
|
||||
visitor_rx: mpsc::Receiver<TcpStream>,
|
||||
data_ch_rx: mpsc::Receiver<T::Stream>,
|
||||
}
|
||||
|
||||
struct ConnectionPoolHandle<T: Transport> {
|
||||
visitor_tx: mpsc::Sender<TcpStream>,
|
||||
data_ch_tx: mpsc::Sender<T::Stream>,
|
||||
}
|
||||
|
||||
impl<T: 'static + Transport> ConnectionPoolHandle<T> {
|
||||
fn new() -> ConnectionPoolHandle<T> {
|
||||
let (data_ch_tx, data_ch_rx) = mpsc::channel(CHAN_SIZE * 2);
|
||||
let (visitor_tx, visitor_rx) = mpsc::channel(CHAN_SIZE);
|
||||
let conn_pool: ConnectionPool<T> = ConnectionPool {
|
||||
data_ch_rx,
|
||||
visitor_rx,
|
||||
};
|
||||
|
||||
tokio::spawn(async move { conn_pool.run().await });
|
||||
|
||||
ConnectionPoolHandle {
|
||||
data_ch_tx,
|
||||
visitor_tx,
|
||||
#[instrument(skip_all)]
|
||||
async fn run_tcp_connection_pool<T: Transport>(
|
||||
bind_addr: String,
|
||||
mut data_ch_rx: mpsc::Receiver<T::Stream>,
|
||||
data_ch_req_tx: mpsc::UnboundedSender<bool>,
|
||||
shutdown_rx: broadcast::Receiver<bool>,
|
||||
) -> Result<()> {
|
||||
let mut visitor_rx = tcp_listen_and_send(bind_addr, data_ch_req_tx, shutdown_rx);
|
||||
while let Some(mut visitor) = visitor_rx.recv().await {
|
||||
if let Some(mut ch) = data_ch_rx.recv().await {
|
||||
tokio::spawn(async move {
|
||||
let cmd = bincode::serialize(&DataChannelCmd::StartForwardTcp).unwrap();
|
||||
if ch.write_all(&cmd).await.is_ok() {
|
||||
let _ = copy_bidirectional(&mut ch, &mut visitor).await;
|
||||
}
|
||||
});
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl<T: Transport> ConnectionPool<T> {
|
||||
#[tracing::instrument]
|
||||
async fn run(mut self) {
|
||||
while let Some(mut visitor) = self.visitor_rx.recv().await {
|
||||
if let Some(mut ch) = self.data_ch_rx.recv().await {
|
||||
tokio::spawn(async move {
|
||||
let cmd = bincode::serialize(&DataChannelCmd::StartForward).unwrap();
|
||||
if ch.write_all(&cmd).await.is_ok() {
|
||||
let _ = copy_bidirectional(&mut ch, &mut visitor).await;
|
||||
}
|
||||
});
|
||||
} else {
|
||||
#[instrument(skip_all)]
|
||||
async fn run_udp_connection_pool<T: Transport>(
|
||||
bind_addr: String,
|
||||
mut data_ch_rx: mpsc::Receiver<T::Stream>,
|
||||
_data_ch_req_tx: mpsc::UnboundedSender<bool>,
|
||||
mut shutdown_rx: broadcast::Receiver<bool>,
|
||||
) -> Result<()> {
|
||||
// TODO: Load balance
|
||||
|
||||
let l: UdpSocket = backoff::future::retry(listen_backoff(), || async {
|
||||
Ok(UdpSocket::bind(&bind_addr).await?)
|
||||
})
|
||||
.await
|
||||
.with_context(|| "Failed to listen for the service")?;
|
||||
|
||||
info!("Listening at {}", &bind_addr);
|
||||
|
||||
let cmd = bincode::serialize(&DataChannelCmd::StartForwardUdp).unwrap();
|
||||
|
||||
let mut conn = data_ch_rx
|
||||
.recv()
|
||||
.await
|
||||
.ok_or(anyhow!("No available data channels"))?;
|
||||
conn.write_all(&cmd).await?;
|
||||
|
||||
let mut buf = [0u8; UDP_BUFFER_SIZE];
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Forward inbound traffic to the client
|
||||
val = l.recv_from(&mut buf) => {
|
||||
let (n, from) = val?;
|
||||
UdpTraffic::write_slice(&mut conn, from, &buf[..n]).await?;
|
||||
},
|
||||
|
||||
// Forward outbound traffic from the client to the visitor
|
||||
t = UdpTraffic::read(&mut conn) => {
|
||||
let t = t?;
|
||||
l.send_to(&t.data, t.from).await?;
|
||||
},
|
||||
|
||||
_ = shutdown_rx.recv() => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue