feat: UDP support

This commit is contained in:
Yujia Qiao 2021-12-21 21:11:46 +08:00 committed by Yujia Qiao
parent 65c75da633
commit 443f763800
11 changed files with 575 additions and 127 deletions

3
Cargo.lock generated
View File

@ -88,6 +88,9 @@ name = "bytes"
version = "1.1.0" version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4872d67bab6358e59559027aa3b9157c53d9358c51423c17554809a8858e0f8" checksum = "c4872d67bab6358e59559027aa3b9157c53d9358c51423c17554809a8858e0f8"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "cc" name = "cc"

View File

@ -24,7 +24,7 @@ opt-level = "s"
[dependencies] [dependencies]
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }
bytes = { version = "1"} bytes = { version = "1", features = ["serde"] }
clap = { version = "3.0.0-rc.7", features = ["derive"] } clap = { version = "3.0.0-rc.7", features = ["derive"] }
toml = "0.5" toml = "0.5"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }

7
examples/udp/client.toml Normal file
View File

@ -0,0 +1,7 @@
[client]
remote_addr = "localhost:2333"
default_token = "123"
[client.services.foo1]
type = "udp"
local_addr = "127.0.0.1:80"

7
examples/udp/server.toml Normal file
View File

@ -0,0 +1,7 @@
[server]
bind_addr = "0.0.0.0:2333"
default_token = "123"
[server.services.foo1]
type = "udp"
bind_addr = "0.0.0.0:5202"

View File

@ -1,23 +1,28 @@
use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType}; use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType};
use crate::helper::udp_connect;
use crate::protocol::Hello::{self, *}; use crate::protocol::Hello::{self, *};
use crate::protocol::{ use crate::protocol::{
self, read_ack, read_control_cmd, read_data_cmd, read_hello, Ack, Auth, ControlChannelCmd, self, read_ack, read_control_cmd, read_data_cmd, read_hello, Ack, Auth, ControlChannelCmd,
DataChannelCmd, CURRENT_PROTO_VRESION, HASH_WIDTH_IN_BYTES, DataChannelCmd, UdpTraffic, CURRENT_PROTO_VRESION, HASH_WIDTH_IN_BYTES,
}; };
use crate::transport::{TcpTransport, Transport}; use crate::transport::{TcpTransport, Transport};
use anyhow::{anyhow, bail, Context, Result}; use anyhow::{anyhow, bail, Context, Result};
use backoff::ExponentialBackoff; use backoff::ExponentialBackoff;
use bytes::{Bytes, BytesMut};
use std::collections::HashMap; use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use tokio::io::{copy_bidirectional, AsyncWriteExt}; use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::{TcpStream, UdpSocket};
use tokio::sync::{broadcast, oneshot}; use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
use tokio::time::{self, Duration}; use tokio::time::{self, Duration};
use tracing::{debug, error, info, instrument, Instrument, Span}; use tracing::{debug, error, info, instrument, Instrument, Span};
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
use crate::transport::TlsTransport; use crate::transport::TlsTransport;
use crate::constants::{UDP_BUFFER_SIZE, UDP_SENDQ_SIZE, UDP_TIMEOUT};
// The entrypoint of running a client // The entrypoint of running a client
pub async fn run_client(config: &Config, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> { pub async fn run_client(config: &Config, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
let config = match &config.client { let config = match &config.client {
@ -112,7 +117,9 @@ struct RunDataChannelArgs<T: Transport> {
connector: Arc<T>, connector: Arc<T>,
} }
async fn run_data_channel<T: Transport>(args: Arc<RunDataChannelArgs<T>>) -> Result<()> { async fn do_data_channel_handshake<T: Transport>(
args: Arc<RunDataChannelArgs<T>>,
) -> Result<T::Stream> {
// Retry at least every 100ms, at most for 10 seconds // Retry at least every 100ms, at most for 10 seconds
let backoff = ExponentialBackoff { let backoff = ExponentialBackoff {
max_interval: Duration::from_millis(100), max_interval: Duration::from_millis(100),
@ -135,15 +142,162 @@ async fn run_data_channel<T: Transport>(args: Arc<RunDataChannelArgs<T>>) -> Res
let hello = Hello::DataChannelHello(CURRENT_PROTO_VRESION, v.to_owned()); let hello = Hello::DataChannelHello(CURRENT_PROTO_VRESION, v.to_owned());
conn.write_all(&bincode::serialize(&hello).unwrap()).await?; conn.write_all(&bincode::serialize(&hello).unwrap()).await?;
Ok(conn)
}
async fn run_data_channel<T: Transport>(args: Arc<RunDataChannelArgs<T>>) -> Result<()> {
// Do the handshake
let mut conn = do_data_channel_handshake(args.clone()).await?;
// Forward // Forward
match read_data_cmd(&mut conn).await? { match read_data_cmd(&mut conn).await? {
DataChannelCmd::StartForward => { DataChannelCmd::StartForwardTcp => {
let mut local = TcpStream::connect(&args.local_addr) run_data_channel_for_tcp::<T>(conn, &args.local_addr).await?;
}
DataChannelCmd::StartForwardUdp => {
run_data_channel_for_udp::<T>(conn, &args.local_addr).await?;
}
}
Ok(())
}
// Simply copying back and forth for TCP
#[instrument(skip(conn))]
async fn run_data_channel_for_tcp<T: Transport>(
mut conn: T::Stream,
local_addr: &str,
) -> Result<()> {
debug!("New data channel starts forwarding");
let mut local = TcpStream::connect(local_addr)
.await .await
.with_context(|| "Failed to conenct to local_addr")?; .with_context(|| "Failed to conenct to local_addr")?;
let _ = copy_bidirectional(&mut conn, &mut local).await; let _ = copy_bidirectional(&mut conn, &mut local).await;
Ok(())
}
// Things get a little tricker when it gets to UDP because it's connectionless.
// A UdpPortMap must be maintained for recent seen incoming address, giving them
// each a local port, which is associated with a socket. So just the sender
// to the socket will work fine for the map's value.
type UdpPortMap = Arc<RwLock<HashMap<SocketAddr, mpsc::Sender<Bytes>>>>;
#[instrument(skip(conn))]
async fn run_data_channel_for_udp<T: Transport>(conn: T::Stream, local_addr: &str) -> Result<()> {
debug!("New data channel starts forwarding");
let port_map: UdpPortMap = Arc::new(RwLock::new(HashMap::new()));
// The channel stores UdpTraffic that needs to be sent to the server
let (outbound_tx, mut outbound_rx) = mpsc::channel::<UdpTraffic>(UDP_SENDQ_SIZE);
// FIXME: https://github.com/tokio-rs/tls/issues/40
// Maybe this is our concern
let (mut rd, mut wr) = io::split(conn);
// Keep sending items from the outbound channel to the server
tokio::spawn(async move {
while let Some(t) = outbound_rx.recv().await {
debug!("outbound {:?}", t);
if t.write(&mut wr).await.is_err() {
break;
} }
} }
});
loop {
// Read a packet from the server
let packet = UdpTraffic::read(&mut rd).await?;
let m = port_map.read().await;
if m.get(&packet.from).is_none() {
// This packet is from a address we don't see for a while,
// which is not in the UdpPortMap.
// So set up a mapping (and a forwarder) for it
// Drop the reader lock
drop(m);
// Grab the writer lock
// This is the only thread that will try to grab the writer lock
// So no need to worry about some other thread has already set up
// the mapping between the gap of dropping the reader lock and
// grabbing the writer lock
let mut m = port_map.write().await;
match udp_connect(local_addr).await {
Ok(s) => {
let (inbound_tx, inbound_rx) = mpsc::channel(UDP_SENDQ_SIZE);
m.insert(packet.from, inbound_tx);
tokio::spawn(run_udp_forwarder(
s,
inbound_rx,
outbound_tx.clone(),
packet.from,
port_map.clone(),
));
}
Err(e) => {
error!("{:?}", e);
}
}
}
// Now there should be a udp forwarder that can receive the packet
let m = port_map.read().await;
if let Some(tx) = m.get(&packet.from) {
let _ = tx.send(packet.data).await;
}
}
}
// Run a UdpSocket for the visitor `from`
async fn run_udp_forwarder(
s: UdpSocket,
mut inbound_rx: mpsc::Receiver<Bytes>,
outbount_tx: mpsc::Sender<UdpTraffic>,
from: SocketAddr,
port_map: UdpPortMap,
) -> Result<()> {
let mut buf = BytesMut::new();
buf.resize(UDP_BUFFER_SIZE, 0);
loop {
tokio::select! {
// Receive from the server
data = inbound_rx.recv() => {
if let Some(data) = data {
s.send(&data).await?;
} else {
break;
}
},
// Receive from the service
val = s.recv(&mut buf) => {
let len = match val {
Ok(v) => v,
Err(_) => {break;}
};
let t = UdpTraffic{
from,
data: Bytes::copy_from_slice(&buf[..len])
};
outbount_tx.send(t).await?;
},
// No traffic for the duration of UDP_TIMEOUT, clean up the state
_ = time::sleep(Duration::from_secs(UDP_TIMEOUT)) => {
break;
}
}
}
let mut port_map = port_map.write().await;
port_map.remove(&from);
Ok(()) Ok(())
} }
@ -163,7 +317,7 @@ struct ControlChannelHandle {
} }
impl<T: 'static + Transport> ControlChannel<T> { impl<T: 'static + Transport> ControlChannel<T> {
#[instrument(skip(self), fields(service=%self.service.name))] #[instrument(skip_all)]
async fn run(&mut self) -> Result<()> { async fn run(&mut self) -> Result<()> {
let mut conn = self let mut conn = self
.transport .transport

View File

@ -20,14 +20,30 @@ impl Default for TransportType {
#[derive(Debug, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ClientServiceConfig { pub struct ClientServiceConfig {
#[serde(rename = "type", default = "default_service_type")]
pub service_type: ServiceType,
#[serde(skip)] #[serde(skip)]
pub name: String, pub name: String,
pub local_addr: String, pub local_addr: String,
pub token: Option<String>, pub token: Option<String>,
} }
#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
pub enum ServiceType {
#[serde(rename = "tcp")]
Tcp,
#[serde(rename = "udp")]
Udp,
}
fn default_service_type() -> ServiceType {
ServiceType::Tcp
}
#[derive(Debug, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ServerServiceConfig { pub struct ServerServiceConfig {
#[serde(rename = "type", default = "default_service_type")]
pub service_type: ServiceType,
#[serde(skip)] #[serde(skip)]
pub name: String, pub name: String,
pub bind_addr: String, pub bind_addr: String,
@ -231,6 +247,7 @@ mod tests {
cfg.services.insert( cfg.services.insert(
"foo1".into(), "foo1".into(),
ServerServiceConfig { ServerServiceConfig {
service_type: ServiceType::Tcp,
name: "foo1".into(), name: "foo1".into(),
bind_addr: "127.0.0.1:80".into(), bind_addr: "127.0.0.1:80".into(),
token: None, token: None,
@ -277,6 +294,7 @@ mod tests {
cfg.services.insert( cfg.services.insert(
"foo1".into(), "foo1".into(),
ClientServiceConfig { ClientServiceConfig {
service_type: ServiceType::Tcp,
name: "foo1".into(), name: "foo1".into(),
local_addr: "127.0.0.1:80".into(), local_addr: "127.0.0.1:80".into(),
token: None, token: None,

15
src/constants.rs Normal file
View File

@ -0,0 +1,15 @@
use backoff::ExponentialBackoff;
use std::time::Duration;
// FIXME: Determine reasonable size
pub const UDP_BUFFER_SIZE: usize = 2048;
pub const UDP_SENDQ_SIZE: usize = 1024;
pub const UDP_TIMEOUT: u64 = 60;
pub fn listen_backoff() -> ExponentialBackoff {
ExponentialBackoff {
max_elapsed_time: None,
max_interval: Duration::from_secs(1),
..Default::default()
}
}

View File

@ -1,8 +1,13 @@
use std::time::Duration; use std::{
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
net::SocketAddr,
time::Duration,
};
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use socket2::{SockRef, TcpKeepalive}; use socket2::{SockRef, TcpKeepalive};
use tokio::net::TcpStream; use tokio::net::{TcpStream, ToSocketAddrs, UdpSocket};
// Tokio hesitates to expose this option...So we have to do it on our own :( // Tokio hesitates to expose this option...So we have to do it on our own :(
// The good news is that using socket2 it can be easily done, without losing portablity. // The good news is that using socket2 it can be easily done, without losing portablity.
@ -21,3 +26,78 @@ pub fn feature_not_compile(feature: &str) -> ! {
feature feature
) )
} }
/// Create a UDP socket and connect to `addr`
pub async fn udp_connect<A: ToSocketAddrs>(addr: A) -> Result<UdpSocket> {
// FIXME: This only works for IPv4
let s = UdpSocket::bind("0.0.0.0:0").await?;
s.connect(addr).await?;
Ok(s)
}
#[allow(dead_code)]
pub fn hash_socket_addr(a: &SocketAddr) -> u64 {
let mut hasher = DefaultHasher::new();
a.hash(&mut hasher);
hasher.finish()
}
// Wait for the stablization of https://doc.rust-lang.org/std/primitive.i64.html#method.log2
#[allow(dead_code)]
fn log2_floor(x: usize) -> u8 {
(x as f64).log2().floor() as u8
}
#[allow(dead_code)]
pub fn floor_to_pow_of_2(x: usize) -> usize {
if x == 1 {
return 1;
}
let w = log2_floor(x);
1 << w
}
#[cfg(test)]
mod test {
use crate::helper::{floor_to_pow_of_2, log2_floor};
#[test]
fn test_log2_floor() {
let t = [
(2, 1),
(3, 1),
(4, 2),
(8, 3),
(9, 3),
(15, 3),
(16, 4),
(1023, 9),
(1024, 10),
(2000, 10),
(2048, 11),
];
for t in t {
assert_eq!(log2_floor(t.0), t.1);
}
}
#[test]
fn test_floor_to_pow_of_2() {
let t = [
(1 as usize, 1 as usize),
(2, 2),
(3, 2),
(4, 4),
(5, 4),
(15, 8),
(31, 16),
(33, 32),
(1000, 512),
(1500, 1024),
(2300, 2048),
];
for t in t {
assert_eq!(floor_to_pow_of_2(t.0), t.1);
}
}
}

View File

@ -1,5 +1,6 @@
mod cli; mod cli;
mod config; mod config;
mod constants;
mod helper; mod helper;
mod multi_map; mod multi_map;
mod protocol; mod protocol;

View File

@ -1,9 +1,11 @@
pub const HASH_WIDTH_IN_BYTES: usize = 32; pub const HASH_WIDTH_IN_BYTES: usize = 32;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use bytes::{Bytes, BytesMut};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; use std::net::SocketAddr;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
type ProtocolVersion = u8; type ProtocolVersion = u8;
const PROTO_V0: u8 = 0u8; const PROTO_V0: u8 = 0u8;
@ -49,7 +51,80 @@ pub enum ControlChannelCmd {
#[derive(Deserialize, Serialize, Debug)] #[derive(Deserialize, Serialize, Debug)]
pub enum DataChannelCmd { pub enum DataChannelCmd {
StartForward, StartForwardTcp,
StartForwardUdp,
}
type UdpPacketLen = u16; // `u16` should be enough for any practical UDP traffic on the Internet
#[derive(Deserialize, Serialize, Debug)]
struct UdpHeader {
from: SocketAddr,
len: UdpPacketLen,
}
#[derive(Debug)]
pub struct UdpTraffic {
pub from: SocketAddr,
pub data: Bytes,
}
impl UdpTraffic {
pub async fn write<T: AsyncWrite + Unpin>(&self, writer: &mut T) -> Result<()> {
let v = bincode::serialize(&UdpHeader {
from: self.from,
len: self.data.len() as UdpPacketLen,
})
.unwrap();
writer.write_u16(v.len() as u16).await?;
writer.write_all(&v).await?;
writer.write_all(&self.data).await?;
Ok(())
}
#[allow(dead_code)]
pub async fn write_slice<T: AsyncWrite + Unpin>(
writer: &mut T,
from: SocketAddr,
data: &[u8],
) -> Result<()> {
let v = bincode::serialize(&UdpHeader {
from,
len: data.len() as UdpPacketLen,
})
.unwrap();
writer.write_u16(v.len() as u16).await?;
writer.write_all(&v).await?;
writer.write_all(data).await?;
Ok(())
}
pub async fn read<T: AsyncRead + Unpin>(reader: &mut T) -> Result<UdpTraffic> {
let len = reader.read_u16().await? as usize;
let mut buf = Vec::new();
buf.resize(len, 0);
reader
.read_exact(&mut buf)
.await
.with_context(|| "Failed to read udp header")?;
let header: UdpHeader =
bincode::deserialize(&buf).with_context(|| "Failed to deserialize udp header")?;
let mut data = BytesMut::new();
data.resize(header.len as usize, 0);
reader.read_exact(&mut data).await?;
Ok(UdpTraffic {
from: header.from,
data: data.freeze(),
})
}
} }
pub fn digest(data: &[u8]) -> Digest { pub fn digest(data: &[u8]) -> Digest {
@ -74,7 +149,7 @@ impl PacketLength {
.unwrap() as usize; .unwrap() as usize;
let c_cmd = let c_cmd =
bincode::serialized_size(&ControlChannelCmd::CreateDataChannel).unwrap() as usize; bincode::serialized_size(&ControlChannelCmd::CreateDataChannel).unwrap() as usize;
let d_cmd = bincode::serialized_size(&DataChannelCmd::StartForward).unwrap() as usize; let d_cmd = bincode::serialized_size(&DataChannelCmd::StartForwardTcp).unwrap() as usize;
let ack = Ack::Ok; let ack = Ack::Ok;
let ack = bincode::serialized_size(&ack).unwrap() as usize; let ack = bincode::serialized_size(&ack).unwrap() as usize;

View File

@ -1,8 +1,10 @@
use crate::config::{Config, ServerConfig, ServerServiceConfig, TransportType}; use crate::config::{Config, ServerConfig, ServerServiceConfig, ServiceType, TransportType};
use crate::constants::{listen_backoff, UDP_BUFFER_SIZE};
use crate::multi_map::MultiMap; use crate::multi_map::MultiMap;
use crate::protocol::Hello::{ControlChannelHello, DataChannelHello}; use crate::protocol::Hello::{ControlChannelHello, DataChannelHello};
use crate::protocol::{ use crate::protocol::{
self, read_auth, read_hello, Ack, ControlChannelCmd, DataChannelCmd, Hello, HASH_WIDTH_IN_BYTES, self, read_auth, read_hello, Ack, ControlChannelCmd, DataChannelCmd, Hello, UdpTraffic,
HASH_WIDTH_IN_BYTES,
}; };
#[cfg(feature = "tls")] #[cfg(feature = "tls")]
use crate::transport::TlsTransport; use crate::transport::TlsTransport;
@ -10,21 +12,23 @@ use crate::transport::{TcpTransport, Transport};
use anyhow::{anyhow, bail, Context, Result}; use anyhow::{anyhow, bail, Context, Result};
use backoff::backoff::Backoff; use backoff::backoff::Backoff;
use backoff::ExponentialBackoff; use backoff::ExponentialBackoff;
use rand::RngCore; use rand::RngCore;
use std::collections::HashMap; use std::collections::HashMap;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::io::{self, copy_bidirectional, AsyncWriteExt}; use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream, UdpSocket};
use tokio::sync::{broadcast, mpsc, oneshot, RwLock}; use tokio::sync::{broadcast, mpsc, RwLock};
use tokio::time; use tokio::time;
use tracing::{debug, error, info, info_span, warn, Instrument}; use tracing::{debug, error, info, info_span, instrument, warn, Instrument, Span};
type ServiceDigest = protocol::Digest; // SHA256 of a service name type ServiceDigest = protocol::Digest; // SHA256 of a service name
type Nonce = protocol::Digest; // Also called `session_key` type Nonce = protocol::Digest; // Also called `session_key`
const POOL_SIZE: usize = 64; // The number of cached connections const TCP_POOL_SIZE: usize = 64; // The number of cached connections for TCP servies
const UDP_POOL_SIZE: usize = 2; // The number of cached connections for UDP services
const CHAN_SIZE: usize = 2048; // The capacity of various chans const CHAN_SIZE: usize = 2048; // The capacity of various chans
// The entrypoint of running a server // The entrypoint of running a server
@ -268,7 +272,7 @@ async fn do_control_channel_handshake<T: 'static + Transport>(
Ok(()) Ok(())
} }
async fn do_data_channel_handshake<T: Transport>( async fn do_data_channel_handshake<T: 'static + Transport>(
conn: T::Stream, conn: T::Stream,
control_channels: Arc<RwLock<ControlChannelMap<T>>>, control_channels: Arc<RwLock<ControlChannelMap<T>>>,
nonce: Nonce, nonce: Nonce,
@ -276,9 +280,9 @@ async fn do_data_channel_handshake<T: Transport>(
// Validate // Validate
let control_channels_guard = control_channels.read().await; let control_channels_guard = control_channels.read().await;
match control_channels_guard.get2(&nonce) { match control_channels_guard.get2(&nonce) {
Some(c_ch) => { Some(handle) => {
// Send the data channel to the corresponding control channel // Send the data channel to the corresponding control channel
c_ch.conn_pool.data_ch_tx.send(conn).await?; handle.data_ch_tx.send(conn).await?;
} }
None => { None => {
// TODO: Maybe print IP here // TODO: Maybe print IP here
@ -288,43 +292,74 @@ async fn do_data_channel_handshake<T: Transport>(
Ok(()) Ok(())
} }
// Control channel, using T as the transport layer pub struct ControlChannelHandle<T: Transport> {
struct ControlChannel<T: Transport> {
conn: T::Stream, // The connection of control channel
service: ServerServiceConfig, // A copy of the corresponding service config
shutdown_rx: oneshot::Receiver<bool>, // Receives the shutdown signal
visitor_tx: mpsc::Sender<TcpStream>, // Receives visitor connections
}
// The handle of a control channel, along with the handle of a connection pool
// Dropping it will drop the actual control channel, because `visitor_tx`
// and `shutdown_tx` are closed
struct ControlChannelHandle<T: Transport> {
// Shutdown the control channel. // Shutdown the control channel.
// Not used for now, but can be used for hot reloading // Not used for now, but can be used for hot reloading
_shutdown_tx: oneshot::Sender<bool>, #[allow(dead_code)]
conn_pool: ConnectionPoolHandle<T>, shutdown_tx: broadcast::Sender<bool>,
//data_ch_req_tx: mpsc::Sender<bool>,
data_ch_tx: mpsc::Sender<T::Stream>,
} }
impl<T: 'static + Transport> ControlChannelHandle<T> { impl<T> ControlChannelHandle<T>
where
T: 'static + Transport,
{
// Create a control channel handle, where the control channel handling task // Create a control channel handle, where the control channel handling task
// and the connection pool task are created. // and the connection pool task are created.
#[instrument(skip_all, fields(service = %service.name))]
fn new(conn: T::Stream, service: ServerServiceConfig) -> ControlChannelHandle<T> { fn new(conn: T::Stream, service: ServerServiceConfig) -> ControlChannelHandle<T> {
// Save the name string for logging // Save the name string for logging
let name = service.name.clone(); let name = service.name.clone();
// Create a shutdown channel. The sender is not used for now, but for future use // Create a shutdown channel. The sender is not used for now, but for future use
let (_shutdown_tx, shutdown_rx) = oneshot::channel::<bool>(); let (shutdown_tx, shutdown_rx) = broadcast::channel::<bool>(1);
// Create and run the connection pool, where the visitors and data channels meet // Store data channels
let conn_pool = ConnectionPoolHandle::new(); let (data_ch_tx, data_ch_rx) = mpsc::channel(CHAN_SIZE * 2);
// Store data channel creation requests
let (data_ch_req_tx, data_ch_req_rx) = mpsc::unbounded_channel();
match service.service_type {
ServiceType::Tcp => tokio::spawn(
run_tcp_connection_pool::<T>(
service.bind_addr.clone(),
data_ch_rx,
data_ch_req_tx.clone(),
shutdown_tx.subscribe(),
)
.instrument(Span::current()),
),
ServiceType::Udp => tokio::spawn(
run_udp_connection_pool::<T>(
service.bind_addr.clone(),
data_ch_rx,
data_ch_req_tx.clone(),
shutdown_tx.subscribe(),
)
.instrument(Span::current()),
),
};
// Cache some data channels for later use
let pool_size = match service.service_type {
ServiceType::Tcp => TCP_POOL_SIZE,
ServiceType::Udp => UDP_POOL_SIZE,
};
for _i in 0..pool_size {
if let Err(e) = data_ch_req_tx.send(true) {
error!("Failed to request data channel {}", e);
};
}
// Create the control channel // Create the control channel
let ch: ControlChannel<T> = ControlChannel { let ch = ControlChannel::<T> {
conn, conn,
shutdown_rx, shutdown_rx,
service, service,
visitor_tx: conn_pool.visitor_tx.clone(), data_ch_req_rx,
}; };
// Run the control channel // Run the control channel
@ -335,52 +370,83 @@ impl<T: 'static + Transport> ControlChannelHandle<T> {
}); });
ControlChannelHandle { ControlChannelHandle {
_shutdown_tx, shutdown_tx,
conn_pool, data_ch_tx,
} }
} }
#[allow(dead_code)]
fn shutdown(self) {
let _ = self.shutdown_tx.send(true);
}
}
// Control channel, using T as the transport layer. P is TcpStream or UdpTraffic
struct ControlChannel<T: Transport> {
conn: T::Stream, // The connection of control channel
service: ServerServiceConfig, // A copy of the corresponding service config
shutdown_rx: broadcast::Receiver<bool>, // Receives the shutdown signal
data_ch_req_rx: mpsc::UnboundedReceiver<bool>, // Receives visitor connections
} }
impl<T: Transport> ControlChannel<T> { impl<T: Transport> ControlChannel<T> {
// Run a control channel // Run a control channel
#[tracing::instrument(skip(self), fields(service = %self.service.name))] #[instrument(skip(self), fields(service = %self.service.name))]
async fn run(mut self) -> Result<()> { async fn run(mut self) -> Result<()> {
// Where the service is exposed
let l = match TcpListener::bind(&self.service.bind_addr).await {
Ok(v) => v,
Err(e) => {
let duration = Duration::from_secs(1);
error!(
"Failed to listen on service.bind_addr: {}. Retry in {:?}...",
e, duration
);
time::sleep(duration).await;
TcpListener::bind(&self.service.bind_addr).await?
}
};
info!("Listening at {}", &self.service.bind_addr);
// Each `u8` in the chan indicates a data channel creation request
let (data_req_tx, mut data_req_rx) = mpsc::unbounded_channel::<u8>();
// The control channel is moved into the task, and sends CreateDataChannel
// comamnds to the client when needed
tokio::spawn(async move {
let cmd = bincode::serialize(&ControlChannelCmd::CreateDataChannel).unwrap(); let cmd = bincode::serialize(&ControlChannelCmd::CreateDataChannel).unwrap();
while data_req_rx.recv().await.is_some() {
if self.conn.write_all(&cmd).await.is_err() { // Wait for data channel requests and the shutdown signal
loop {
tokio::select! {
val = self.data_ch_req_rx.recv() => {
match val {
Some(_) => {
if let Err(e) = self.conn.write_all(&cmd).await.with_context(||"Failed to write data cmds") {
error!("{:?}", e);
break; break;
} }
} }
}); None => {
break;
// Cache some data channels for later use
for _i in 0..POOL_SIZE {
if let Err(e) = data_req_tx.send(0) {
error!("Failed to request data channel {}", e);
};
} }
}
},
// Wait for the shutdown signal
_ = self.shutdown_rx.recv() => {
break;
}
}
}
info!("Control channel shuting down");
Ok(())
}
}
fn tcp_listen_and_send(
addr: String,
data_ch_req_tx: mpsc::UnboundedSender<bool>,
mut shutdown_rx: broadcast::Receiver<bool>,
) -> mpsc::Receiver<TcpStream> {
let (tx, rx) = mpsc::channel(CHAN_SIZE);
tokio::spawn(async move {
let l = backoff::future::retry(listen_backoff(), || async {
Ok(TcpListener::bind(&addr).await?)
})
.await
.with_context(|| "Failed to listen for the service");
let l: TcpListener = match l {
Ok(v) => v,
Err(e) => {
error!("{:?}", e);
return;
}
};
info!("Listening at {}", &addr);
// Retry at least every 1s // Retry at least every 1s
let mut backoff = ExponentialBackoff { let mut backoff = ExponentialBackoff {
@ -392,7 +458,6 @@ impl<T: Transport> ControlChannel<T> {
// Wait for visitors and the shutdown signal // Wait for visitors and the shutdown signal
loop { loop {
tokio::select! { tokio::select! {
// Wait for visitors
val = l.accept() => { val = l.accept() => {
match val { match val {
Err(e) => { Err(e) => {
@ -406,73 +471,47 @@ impl<T: Transport> ControlChannel<T> {
error!("Too many retries. Aborting..."); error!("Too many retries. Aborting...");
break; break;
} }
}, }
Ok((incoming, addr)) => { Ok((incoming, addr)) => {
// For every visitor, request to create a data channel // For every visitor, request to create a data channel
if let Err(e) = data_req_tx.send(0) { if let Err(e) = data_ch_req_tx.send(true) {
// An error indicates the control channel is broken // An error indicates the control channel is broken
// So break the loop // So break the loop
error!("{}", e); error!("{}", e);
break; break;
}; }
backoff.reset(); backoff.reset();
debug!("New visitor from {}", addr); debug!("New visitor from {}", addr);
// Send the visitor to the connection pool // Send the visitor to the connection pool
let _ = self.visitor_tx.send(incoming).await; let _ = tx.send(incoming).await;
} }
} }
}, },
// Wait for the shutdown signal _ = shutdown_rx.recv() => {
_ = &mut self.shutdown_rx => {
break; break;
} }
} }
} }
info!("Service shuting down"); });
Ok(()) rx
}
} }
#[derive(Debug)] #[instrument(skip_all)]
struct ConnectionPool<T: Transport> { async fn run_tcp_connection_pool<T: Transport>(
visitor_rx: mpsc::Receiver<TcpStream>, bind_addr: String,
data_ch_rx: mpsc::Receiver<T::Stream>, mut data_ch_rx: mpsc::Receiver<T::Stream>,
} data_ch_req_tx: mpsc::UnboundedSender<bool>,
shutdown_rx: broadcast::Receiver<bool>,
struct ConnectionPoolHandle<T: Transport> { ) -> Result<()> {
visitor_tx: mpsc::Sender<TcpStream>, let mut visitor_rx = tcp_listen_and_send(bind_addr, data_ch_req_tx, shutdown_rx);
data_ch_tx: mpsc::Sender<T::Stream>, while let Some(mut visitor) = visitor_rx.recv().await {
} if let Some(mut ch) = data_ch_rx.recv().await {
impl<T: 'static + Transport> ConnectionPoolHandle<T> {
fn new() -> ConnectionPoolHandle<T> {
let (data_ch_tx, data_ch_rx) = mpsc::channel(CHAN_SIZE * 2);
let (visitor_tx, visitor_rx) = mpsc::channel(CHAN_SIZE);
let conn_pool: ConnectionPool<T> = ConnectionPool {
data_ch_rx,
visitor_rx,
};
tokio::spawn(async move { conn_pool.run().await });
ConnectionPoolHandle {
data_ch_tx,
visitor_tx,
}
}
}
impl<T: Transport> ConnectionPool<T> {
#[tracing::instrument]
async fn run(mut self) {
while let Some(mut visitor) = self.visitor_rx.recv().await {
if let Some(mut ch) = self.data_ch_rx.recv().await {
tokio::spawn(async move { tokio::spawn(async move {
let cmd = bincode::serialize(&DataChannelCmd::StartForward).unwrap(); let cmd = bincode::serialize(&DataChannelCmd::StartForwardTcp).unwrap();
if ch.write_all(&cmd).await.is_ok() { if ch.write_all(&cmd).await.is_ok() {
let _ = copy_bidirectional(&mut ch, &mut visitor).await; let _ = copy_bidirectional(&mut ch, &mut visitor).await;
} }
@ -481,5 +520,54 @@ impl<T: Transport> ConnectionPool<T> {
break; break;
} }
} }
Ok(())
}
#[instrument(skip_all)]
async fn run_udp_connection_pool<T: Transport>(
bind_addr: String,
mut data_ch_rx: mpsc::Receiver<T::Stream>,
_data_ch_req_tx: mpsc::UnboundedSender<bool>,
mut shutdown_rx: broadcast::Receiver<bool>,
) -> Result<()> {
// TODO: Load balance
let l: UdpSocket = backoff::future::retry(listen_backoff(), || async {
Ok(UdpSocket::bind(&bind_addr).await?)
})
.await
.with_context(|| "Failed to listen for the service")?;
info!("Listening at {}", &bind_addr);
let cmd = bincode::serialize(&DataChannelCmd::StartForwardUdp).unwrap();
let mut conn = data_ch_rx
.recv()
.await
.ok_or(anyhow!("No available data channels"))?;
conn.write_all(&cmd).await?;
let mut buf = [0u8; UDP_BUFFER_SIZE];
loop {
tokio::select! {
// Forward inbound traffic to the client
val = l.recv_from(&mut buf) => {
let (n, from) = val?;
UdpTraffic::write_slice(&mut conn, from, &buf[..n]).await?;
},
// Forward outbound traffic from the client to the visitor
t = UdpTraffic::read(&mut conn) => {
let t = t?;
l.send_to(&t.data, t.from).await?;
},
_ = shutdown_rx.recv() => {
break;
} }
} }
}
Ok(())
}