mirror of https://github.com/rapiz1/rathole.git
feat: UDP support
This commit is contained in:
parent
65c75da633
commit
443f763800
|
@ -88,6 +88,9 @@ name = "bytes"
|
||||||
version = "1.1.0"
|
version = "1.1.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "c4872d67bab6358e59559027aa3b9157c53d9358c51423c17554809a8858e0f8"
|
checksum = "c4872d67bab6358e59559027aa3b9157c53d9358c51423c17554809a8858e0f8"
|
||||||
|
dependencies = [
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cc"
|
name = "cc"
|
||||||
|
|
|
@ -24,7 +24,7 @@ opt-level = "s"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
tokio = { version = "1", features = ["full"] }
|
tokio = { version = "1", features = ["full"] }
|
||||||
bytes = { version = "1"}
|
bytes = { version = "1", features = ["serde"] }
|
||||||
clap = { version = "3.0.0-rc.7", features = ["derive"] }
|
clap = { version = "3.0.0-rc.7", features = ["derive"] }
|
||||||
toml = "0.5"
|
toml = "0.5"
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
|
|
|
@ -0,0 +1,7 @@
|
||||||
|
[client]
|
||||||
|
remote_addr = "localhost:2333"
|
||||||
|
default_token = "123"
|
||||||
|
|
||||||
|
[client.services.foo1]
|
||||||
|
type = "udp"
|
||||||
|
local_addr = "127.0.0.1:80"
|
|
@ -0,0 +1,7 @@
|
||||||
|
[server]
|
||||||
|
bind_addr = "0.0.0.0:2333"
|
||||||
|
default_token = "123"
|
||||||
|
|
||||||
|
[server.services.foo1]
|
||||||
|
type = "udp"
|
||||||
|
bind_addr = "0.0.0.0:5202"
|
170
src/client.rs
170
src/client.rs
|
@ -1,23 +1,28 @@
|
||||||
use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType};
|
use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType};
|
||||||
|
use crate::helper::udp_connect;
|
||||||
use crate::protocol::Hello::{self, *};
|
use crate::protocol::Hello::{self, *};
|
||||||
use crate::protocol::{
|
use crate::protocol::{
|
||||||
self, read_ack, read_control_cmd, read_data_cmd, read_hello, Ack, Auth, ControlChannelCmd,
|
self, read_ack, read_control_cmd, read_data_cmd, read_hello, Ack, Auth, ControlChannelCmd,
|
||||||
DataChannelCmd, CURRENT_PROTO_VRESION, HASH_WIDTH_IN_BYTES,
|
DataChannelCmd, UdpTraffic, CURRENT_PROTO_VRESION, HASH_WIDTH_IN_BYTES,
|
||||||
};
|
};
|
||||||
use crate::transport::{TcpTransport, Transport};
|
use crate::transport::{TcpTransport, Transport};
|
||||||
use anyhow::{anyhow, bail, Context, Result};
|
use anyhow::{anyhow, bail, Context, Result};
|
||||||
use backoff::ExponentialBackoff;
|
use backoff::ExponentialBackoff;
|
||||||
|
use bytes::{Bytes, BytesMut};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::net::SocketAddr;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::io::{copy_bidirectional, AsyncWriteExt};
|
use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::{TcpStream, UdpSocket};
|
||||||
use tokio::sync::{broadcast, oneshot};
|
use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
|
||||||
use tokio::time::{self, Duration};
|
use tokio::time::{self, Duration};
|
||||||
use tracing::{debug, error, info, instrument, Instrument, Span};
|
use tracing::{debug, error, info, instrument, Instrument, Span};
|
||||||
|
|
||||||
#[cfg(feature = "tls")]
|
#[cfg(feature = "tls")]
|
||||||
use crate::transport::TlsTransport;
|
use crate::transport::TlsTransport;
|
||||||
|
|
||||||
|
use crate::constants::{UDP_BUFFER_SIZE, UDP_SENDQ_SIZE, UDP_TIMEOUT};
|
||||||
|
|
||||||
// The entrypoint of running a client
|
// The entrypoint of running a client
|
||||||
pub async fn run_client(config: &Config, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
|
pub async fn run_client(config: &Config, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
|
||||||
let config = match &config.client {
|
let config = match &config.client {
|
||||||
|
@ -112,7 +117,9 @@ struct RunDataChannelArgs<T: Transport> {
|
||||||
connector: Arc<T>,
|
connector: Arc<T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn run_data_channel<T: Transport>(args: Arc<RunDataChannelArgs<T>>) -> Result<()> {
|
async fn do_data_channel_handshake<T: Transport>(
|
||||||
|
args: Arc<RunDataChannelArgs<T>>,
|
||||||
|
) -> Result<T::Stream> {
|
||||||
// Retry at least every 100ms, at most for 10 seconds
|
// Retry at least every 100ms, at most for 10 seconds
|
||||||
let backoff = ExponentialBackoff {
|
let backoff = ExponentialBackoff {
|
||||||
max_interval: Duration::from_millis(100),
|
max_interval: Duration::from_millis(100),
|
||||||
|
@ -135,15 +142,162 @@ async fn run_data_channel<T: Transport>(args: Arc<RunDataChannelArgs<T>>) -> Res
|
||||||
let hello = Hello::DataChannelHello(CURRENT_PROTO_VRESION, v.to_owned());
|
let hello = Hello::DataChannelHello(CURRENT_PROTO_VRESION, v.to_owned());
|
||||||
conn.write_all(&bincode::serialize(&hello).unwrap()).await?;
|
conn.write_all(&bincode::serialize(&hello).unwrap()).await?;
|
||||||
|
|
||||||
|
Ok(conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn run_data_channel<T: Transport>(args: Arc<RunDataChannelArgs<T>>) -> Result<()> {
|
||||||
|
// Do the handshake
|
||||||
|
let mut conn = do_data_channel_handshake(args.clone()).await?;
|
||||||
|
|
||||||
// Forward
|
// Forward
|
||||||
match read_data_cmd(&mut conn).await? {
|
match read_data_cmd(&mut conn).await? {
|
||||||
DataChannelCmd::StartForward => {
|
DataChannelCmd::StartForwardTcp => {
|
||||||
let mut local = TcpStream::connect(&args.local_addr)
|
run_data_channel_for_tcp::<T>(conn, &args.local_addr).await?;
|
||||||
|
}
|
||||||
|
DataChannelCmd::StartForwardUdp => {
|
||||||
|
run_data_channel_for_udp::<T>(conn, &args.local_addr).await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simply copying back and forth for TCP
|
||||||
|
#[instrument(skip(conn))]
|
||||||
|
async fn run_data_channel_for_tcp<T: Transport>(
|
||||||
|
mut conn: T::Stream,
|
||||||
|
local_addr: &str,
|
||||||
|
) -> Result<()> {
|
||||||
|
debug!("New data channel starts forwarding");
|
||||||
|
|
||||||
|
let mut local = TcpStream::connect(local_addr)
|
||||||
.await
|
.await
|
||||||
.with_context(|| "Failed to conenct to local_addr")?;
|
.with_context(|| "Failed to conenct to local_addr")?;
|
||||||
let _ = copy_bidirectional(&mut conn, &mut local).await;
|
let _ = copy_bidirectional(&mut conn, &mut local).await;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Things get a little tricker when it gets to UDP because it's connectionless.
|
||||||
|
// A UdpPortMap must be maintained for recent seen incoming address, giving them
|
||||||
|
// each a local port, which is associated with a socket. So just the sender
|
||||||
|
// to the socket will work fine for the map's value.
|
||||||
|
type UdpPortMap = Arc<RwLock<HashMap<SocketAddr, mpsc::Sender<Bytes>>>>;
|
||||||
|
|
||||||
|
#[instrument(skip(conn))]
|
||||||
|
async fn run_data_channel_for_udp<T: Transport>(conn: T::Stream, local_addr: &str) -> Result<()> {
|
||||||
|
debug!("New data channel starts forwarding");
|
||||||
|
|
||||||
|
let port_map: UdpPortMap = Arc::new(RwLock::new(HashMap::new()));
|
||||||
|
|
||||||
|
// The channel stores UdpTraffic that needs to be sent to the server
|
||||||
|
let (outbound_tx, mut outbound_rx) = mpsc::channel::<UdpTraffic>(UDP_SENDQ_SIZE);
|
||||||
|
|
||||||
|
// FIXME: https://github.com/tokio-rs/tls/issues/40
|
||||||
|
// Maybe this is our concern
|
||||||
|
let (mut rd, mut wr) = io::split(conn);
|
||||||
|
|
||||||
|
// Keep sending items from the outbound channel to the server
|
||||||
|
tokio::spawn(async move {
|
||||||
|
while let Some(t) = outbound_rx.recv().await {
|
||||||
|
debug!("outbound {:?}", t);
|
||||||
|
if t.write(&mut wr).await.is_err() {
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
loop {
|
||||||
|
// Read a packet from the server
|
||||||
|
let packet = UdpTraffic::read(&mut rd).await?;
|
||||||
|
let m = port_map.read().await;
|
||||||
|
|
||||||
|
if m.get(&packet.from).is_none() {
|
||||||
|
// This packet is from a address we don't see for a while,
|
||||||
|
// which is not in the UdpPortMap.
|
||||||
|
// So set up a mapping (and a forwarder) for it
|
||||||
|
|
||||||
|
// Drop the reader lock
|
||||||
|
drop(m);
|
||||||
|
|
||||||
|
// Grab the writer lock
|
||||||
|
// This is the only thread that will try to grab the writer lock
|
||||||
|
// So no need to worry about some other thread has already set up
|
||||||
|
// the mapping between the gap of dropping the reader lock and
|
||||||
|
// grabbing the writer lock
|
||||||
|
let mut m = port_map.write().await;
|
||||||
|
|
||||||
|
match udp_connect(local_addr).await {
|
||||||
|
Ok(s) => {
|
||||||
|
let (inbound_tx, inbound_rx) = mpsc::channel(UDP_SENDQ_SIZE);
|
||||||
|
m.insert(packet.from, inbound_tx);
|
||||||
|
tokio::spawn(run_udp_forwarder(
|
||||||
|
s,
|
||||||
|
inbound_rx,
|
||||||
|
outbound_tx.clone(),
|
||||||
|
packet.from,
|
||||||
|
port_map.clone(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("{:?}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now there should be a udp forwarder that can receive the packet
|
||||||
|
let m = port_map.read().await;
|
||||||
|
if let Some(tx) = m.get(&packet.from) {
|
||||||
|
let _ = tx.send(packet.data).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run a UdpSocket for the visitor `from`
|
||||||
|
async fn run_udp_forwarder(
|
||||||
|
s: UdpSocket,
|
||||||
|
mut inbound_rx: mpsc::Receiver<Bytes>,
|
||||||
|
outbount_tx: mpsc::Sender<UdpTraffic>,
|
||||||
|
from: SocketAddr,
|
||||||
|
port_map: UdpPortMap,
|
||||||
|
) -> Result<()> {
|
||||||
|
let mut buf = BytesMut::new();
|
||||||
|
buf.resize(UDP_BUFFER_SIZE, 0);
|
||||||
|
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
// Receive from the server
|
||||||
|
data = inbound_rx.recv() => {
|
||||||
|
if let Some(data) = data {
|
||||||
|
s.send(&data).await?;
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
// Receive from the service
|
||||||
|
val = s.recv(&mut buf) => {
|
||||||
|
let len = match val {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(_) => {break;}
|
||||||
|
};
|
||||||
|
|
||||||
|
let t = UdpTraffic{
|
||||||
|
from,
|
||||||
|
data: Bytes::copy_from_slice(&buf[..len])
|
||||||
|
};
|
||||||
|
|
||||||
|
outbount_tx.send(t).await?;
|
||||||
|
},
|
||||||
|
|
||||||
|
// No traffic for the duration of UDP_TIMEOUT, clean up the state
|
||||||
|
_ = time::sleep(Duration::from_secs(UDP_TIMEOUT)) => {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut port_map = port_map.write().await;
|
||||||
|
port_map.remove(&from);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -163,7 +317,7 @@ struct ControlChannelHandle {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: 'static + Transport> ControlChannel<T> {
|
impl<T: 'static + Transport> ControlChannel<T> {
|
||||||
#[instrument(skip(self), fields(service=%self.service.name))]
|
#[instrument(skip_all)]
|
||||||
async fn run(&mut self) -> Result<()> {
|
async fn run(&mut self) -> Result<()> {
|
||||||
let mut conn = self
|
let mut conn = self
|
||||||
.transport
|
.transport
|
||||||
|
|
|
@ -20,14 +20,30 @@ impl Default for TransportType {
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
pub struct ClientServiceConfig {
|
pub struct ClientServiceConfig {
|
||||||
|
#[serde(rename = "type", default = "default_service_type")]
|
||||||
|
pub service_type: ServiceType,
|
||||||
#[serde(skip)]
|
#[serde(skip)]
|
||||||
pub name: String,
|
pub name: String,
|
||||||
pub local_addr: String,
|
pub local_addr: String,
|
||||||
pub token: Option<String>,
|
pub token: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
|
||||||
|
pub enum ServiceType {
|
||||||
|
#[serde(rename = "tcp")]
|
||||||
|
Tcp,
|
||||||
|
#[serde(rename = "udp")]
|
||||||
|
Udp,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_service_type() -> ServiceType {
|
||||||
|
ServiceType::Tcp
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||||
pub struct ServerServiceConfig {
|
pub struct ServerServiceConfig {
|
||||||
|
#[serde(rename = "type", default = "default_service_type")]
|
||||||
|
pub service_type: ServiceType,
|
||||||
#[serde(skip)]
|
#[serde(skip)]
|
||||||
pub name: String,
|
pub name: String,
|
||||||
pub bind_addr: String,
|
pub bind_addr: String,
|
||||||
|
@ -231,6 +247,7 @@ mod tests {
|
||||||
cfg.services.insert(
|
cfg.services.insert(
|
||||||
"foo1".into(),
|
"foo1".into(),
|
||||||
ServerServiceConfig {
|
ServerServiceConfig {
|
||||||
|
service_type: ServiceType::Tcp,
|
||||||
name: "foo1".into(),
|
name: "foo1".into(),
|
||||||
bind_addr: "127.0.0.1:80".into(),
|
bind_addr: "127.0.0.1:80".into(),
|
||||||
token: None,
|
token: None,
|
||||||
|
@ -277,6 +294,7 @@ mod tests {
|
||||||
cfg.services.insert(
|
cfg.services.insert(
|
||||||
"foo1".into(),
|
"foo1".into(),
|
||||||
ClientServiceConfig {
|
ClientServiceConfig {
|
||||||
|
service_type: ServiceType::Tcp,
|
||||||
name: "foo1".into(),
|
name: "foo1".into(),
|
||||||
local_addr: "127.0.0.1:80".into(),
|
local_addr: "127.0.0.1:80".into(),
|
||||||
token: None,
|
token: None,
|
||||||
|
|
|
@ -0,0 +1,15 @@
|
||||||
|
use backoff::ExponentialBackoff;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
// FIXME: Determine reasonable size
|
||||||
|
pub const UDP_BUFFER_SIZE: usize = 2048;
|
||||||
|
pub const UDP_SENDQ_SIZE: usize = 1024;
|
||||||
|
pub const UDP_TIMEOUT: u64 = 60;
|
||||||
|
|
||||||
|
pub fn listen_backoff() -> ExponentialBackoff {
|
||||||
|
ExponentialBackoff {
|
||||||
|
max_elapsed_time: None,
|
||||||
|
max_interval: Duration::from_secs(1),
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,8 +1,13 @@
|
||||||
use std::time::Duration;
|
use std::{
|
||||||
|
collections::hash_map::DefaultHasher,
|
||||||
|
hash::{Hash, Hasher},
|
||||||
|
net::SocketAddr,
|
||||||
|
time::Duration,
|
||||||
|
};
|
||||||
|
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use socket2::{SockRef, TcpKeepalive};
|
use socket2::{SockRef, TcpKeepalive};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::{TcpStream, ToSocketAddrs, UdpSocket};
|
||||||
|
|
||||||
// Tokio hesitates to expose this option...So we have to do it on our own :(
|
// Tokio hesitates to expose this option...So we have to do it on our own :(
|
||||||
// The good news is that using socket2 it can be easily done, without losing portablity.
|
// The good news is that using socket2 it can be easily done, without losing portablity.
|
||||||
|
@ -21,3 +26,78 @@ pub fn feature_not_compile(feature: &str) -> ! {
|
||||||
feature
|
feature
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create a UDP socket and connect to `addr`
|
||||||
|
pub async fn udp_connect<A: ToSocketAddrs>(addr: A) -> Result<UdpSocket> {
|
||||||
|
// FIXME: This only works for IPv4
|
||||||
|
let s = UdpSocket::bind("0.0.0.0:0").await?;
|
||||||
|
s.connect(addr).await?;
|
||||||
|
Ok(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn hash_socket_addr(a: &SocketAddr) -> u64 {
|
||||||
|
let mut hasher = DefaultHasher::new();
|
||||||
|
a.hash(&mut hasher);
|
||||||
|
hasher.finish()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the stablization of https://doc.rust-lang.org/std/primitive.i64.html#method.log2
|
||||||
|
#[allow(dead_code)]
|
||||||
|
fn log2_floor(x: usize) -> u8 {
|
||||||
|
(x as f64).log2().floor() as u8
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn floor_to_pow_of_2(x: usize) -> usize {
|
||||||
|
if x == 1 {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
let w = log2_floor(x);
|
||||||
|
1 << w
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test {
|
||||||
|
use crate::helper::{floor_to_pow_of_2, log2_floor};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_log2_floor() {
|
||||||
|
let t = [
|
||||||
|
(2, 1),
|
||||||
|
(3, 1),
|
||||||
|
(4, 2),
|
||||||
|
(8, 3),
|
||||||
|
(9, 3),
|
||||||
|
(15, 3),
|
||||||
|
(16, 4),
|
||||||
|
(1023, 9),
|
||||||
|
(1024, 10),
|
||||||
|
(2000, 10),
|
||||||
|
(2048, 11),
|
||||||
|
];
|
||||||
|
for t in t {
|
||||||
|
assert_eq!(log2_floor(t.0), t.1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_floor_to_pow_of_2() {
|
||||||
|
let t = [
|
||||||
|
(1 as usize, 1 as usize),
|
||||||
|
(2, 2),
|
||||||
|
(3, 2),
|
||||||
|
(4, 4),
|
||||||
|
(5, 4),
|
||||||
|
(15, 8),
|
||||||
|
(31, 16),
|
||||||
|
(33, 32),
|
||||||
|
(1000, 512),
|
||||||
|
(1500, 1024),
|
||||||
|
(2300, 2048),
|
||||||
|
];
|
||||||
|
for t in t {
|
||||||
|
assert_eq!(floor_to_pow_of_2(t.0), t.1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
mod cli;
|
mod cli;
|
||||||
mod config;
|
mod config;
|
||||||
|
mod constants;
|
||||||
mod helper;
|
mod helper;
|
||||||
mod multi_map;
|
mod multi_map;
|
||||||
mod protocol;
|
mod protocol;
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
pub const HASH_WIDTH_IN_BYTES: usize = 32;
|
pub const HASH_WIDTH_IN_BYTES: usize = 32;
|
||||||
|
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
|
use bytes::{Bytes, BytesMut};
|
||||||
use lazy_static::lazy_static;
|
use lazy_static::lazy_static;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
|
use std::net::SocketAddr;
|
||||||
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
|
|
||||||
type ProtocolVersion = u8;
|
type ProtocolVersion = u8;
|
||||||
const PROTO_V0: u8 = 0u8;
|
const PROTO_V0: u8 = 0u8;
|
||||||
|
@ -49,7 +51,80 @@ pub enum ControlChannelCmd {
|
||||||
|
|
||||||
#[derive(Deserialize, Serialize, Debug)]
|
#[derive(Deserialize, Serialize, Debug)]
|
||||||
pub enum DataChannelCmd {
|
pub enum DataChannelCmd {
|
||||||
StartForward,
|
StartForwardTcp,
|
||||||
|
StartForwardUdp,
|
||||||
|
}
|
||||||
|
|
||||||
|
type UdpPacketLen = u16; // `u16` should be enough for any practical UDP traffic on the Internet
|
||||||
|
#[derive(Deserialize, Serialize, Debug)]
|
||||||
|
struct UdpHeader {
|
||||||
|
from: SocketAddr,
|
||||||
|
len: UdpPacketLen,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct UdpTraffic {
|
||||||
|
pub from: SocketAddr,
|
||||||
|
pub data: Bytes,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UdpTraffic {
|
||||||
|
pub async fn write<T: AsyncWrite + Unpin>(&self, writer: &mut T) -> Result<()> {
|
||||||
|
let v = bincode::serialize(&UdpHeader {
|
||||||
|
from: self.from,
|
||||||
|
len: self.data.len() as UdpPacketLen,
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
writer.write_u16(v.len() as u16).await?;
|
||||||
|
writer.write_all(&v).await?;
|
||||||
|
|
||||||
|
writer.write_all(&self.data).await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub async fn write_slice<T: AsyncWrite + Unpin>(
|
||||||
|
writer: &mut T,
|
||||||
|
from: SocketAddr,
|
||||||
|
data: &[u8],
|
||||||
|
) -> Result<()> {
|
||||||
|
let v = bincode::serialize(&UdpHeader {
|
||||||
|
from,
|
||||||
|
len: data.len() as UdpPacketLen,
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
writer.write_u16(v.len() as u16).await?;
|
||||||
|
writer.write_all(&v).await?;
|
||||||
|
|
||||||
|
writer.write_all(data).await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn read<T: AsyncRead + Unpin>(reader: &mut T) -> Result<UdpTraffic> {
|
||||||
|
let len = reader.read_u16().await? as usize;
|
||||||
|
|
||||||
|
let mut buf = Vec::new();
|
||||||
|
buf.resize(len, 0);
|
||||||
|
reader
|
||||||
|
.read_exact(&mut buf)
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to read udp header")?;
|
||||||
|
let header: UdpHeader =
|
||||||
|
bincode::deserialize(&buf).with_context(|| "Failed to deserialize udp header")?;
|
||||||
|
|
||||||
|
let mut data = BytesMut::new();
|
||||||
|
data.resize(header.len as usize, 0);
|
||||||
|
reader.read_exact(&mut data).await?;
|
||||||
|
|
||||||
|
Ok(UdpTraffic {
|
||||||
|
from: header.from,
|
||||||
|
data: data.freeze(),
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn digest(data: &[u8]) -> Digest {
|
pub fn digest(data: &[u8]) -> Digest {
|
||||||
|
@ -74,7 +149,7 @@ impl PacketLength {
|
||||||
.unwrap() as usize;
|
.unwrap() as usize;
|
||||||
let c_cmd =
|
let c_cmd =
|
||||||
bincode::serialized_size(&ControlChannelCmd::CreateDataChannel).unwrap() as usize;
|
bincode::serialized_size(&ControlChannelCmd::CreateDataChannel).unwrap() as usize;
|
||||||
let d_cmd = bincode::serialized_size(&DataChannelCmd::StartForward).unwrap() as usize;
|
let d_cmd = bincode::serialized_size(&DataChannelCmd::StartForwardTcp).unwrap() as usize;
|
||||||
let ack = Ack::Ok;
|
let ack = Ack::Ok;
|
||||||
let ack = bincode::serialized_size(&ack).unwrap() as usize;
|
let ack = bincode::serialized_size(&ack).unwrap() as usize;
|
||||||
|
|
||||||
|
|
304
src/server.rs
304
src/server.rs
|
@ -1,8 +1,10 @@
|
||||||
use crate::config::{Config, ServerConfig, ServerServiceConfig, TransportType};
|
use crate::config::{Config, ServerConfig, ServerServiceConfig, ServiceType, TransportType};
|
||||||
|
use crate::constants::{listen_backoff, UDP_BUFFER_SIZE};
|
||||||
use crate::multi_map::MultiMap;
|
use crate::multi_map::MultiMap;
|
||||||
use crate::protocol::Hello::{ControlChannelHello, DataChannelHello};
|
use crate::protocol::Hello::{ControlChannelHello, DataChannelHello};
|
||||||
use crate::protocol::{
|
use crate::protocol::{
|
||||||
self, read_auth, read_hello, Ack, ControlChannelCmd, DataChannelCmd, Hello, HASH_WIDTH_IN_BYTES,
|
self, read_auth, read_hello, Ack, ControlChannelCmd, DataChannelCmd, Hello, UdpTraffic,
|
||||||
|
HASH_WIDTH_IN_BYTES,
|
||||||
};
|
};
|
||||||
#[cfg(feature = "tls")]
|
#[cfg(feature = "tls")]
|
||||||
use crate::transport::TlsTransport;
|
use crate::transport::TlsTransport;
|
||||||
|
@ -10,21 +12,23 @@ use crate::transport::{TcpTransport, Transport};
|
||||||
use anyhow::{anyhow, bail, Context, Result};
|
use anyhow::{anyhow, bail, Context, Result};
|
||||||
use backoff::backoff::Backoff;
|
use backoff::backoff::Backoff;
|
||||||
use backoff::ExponentialBackoff;
|
use backoff::ExponentialBackoff;
|
||||||
|
|
||||||
use rand::RngCore;
|
use rand::RngCore;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
|
use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
|
||||||
use tokio::net::{TcpListener, TcpStream};
|
use tokio::net::{TcpListener, TcpStream, UdpSocket};
|
||||||
use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
|
use tokio::sync::{broadcast, mpsc, RwLock};
|
||||||
use tokio::time;
|
use tokio::time;
|
||||||
use tracing::{debug, error, info, info_span, warn, Instrument};
|
use tracing::{debug, error, info, info_span, instrument, warn, Instrument, Span};
|
||||||
|
|
||||||
type ServiceDigest = protocol::Digest; // SHA256 of a service name
|
type ServiceDigest = protocol::Digest; // SHA256 of a service name
|
||||||
type Nonce = protocol::Digest; // Also called `session_key`
|
type Nonce = protocol::Digest; // Also called `session_key`
|
||||||
|
|
||||||
const POOL_SIZE: usize = 64; // The number of cached connections
|
const TCP_POOL_SIZE: usize = 64; // The number of cached connections for TCP servies
|
||||||
|
const UDP_POOL_SIZE: usize = 2; // The number of cached connections for UDP services
|
||||||
const CHAN_SIZE: usize = 2048; // The capacity of various chans
|
const CHAN_SIZE: usize = 2048; // The capacity of various chans
|
||||||
|
|
||||||
// The entrypoint of running a server
|
// The entrypoint of running a server
|
||||||
|
@ -268,7 +272,7 @@ async fn do_control_channel_handshake<T: 'static + Transport>(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn do_data_channel_handshake<T: Transport>(
|
async fn do_data_channel_handshake<T: 'static + Transport>(
|
||||||
conn: T::Stream,
|
conn: T::Stream,
|
||||||
control_channels: Arc<RwLock<ControlChannelMap<T>>>,
|
control_channels: Arc<RwLock<ControlChannelMap<T>>>,
|
||||||
nonce: Nonce,
|
nonce: Nonce,
|
||||||
|
@ -276,9 +280,9 @@ async fn do_data_channel_handshake<T: Transport>(
|
||||||
// Validate
|
// Validate
|
||||||
let control_channels_guard = control_channels.read().await;
|
let control_channels_guard = control_channels.read().await;
|
||||||
match control_channels_guard.get2(&nonce) {
|
match control_channels_guard.get2(&nonce) {
|
||||||
Some(c_ch) => {
|
Some(handle) => {
|
||||||
// Send the data channel to the corresponding control channel
|
// Send the data channel to the corresponding control channel
|
||||||
c_ch.conn_pool.data_ch_tx.send(conn).await?;
|
handle.data_ch_tx.send(conn).await?;
|
||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
// TODO: Maybe print IP here
|
// TODO: Maybe print IP here
|
||||||
|
@ -288,43 +292,74 @@ async fn do_data_channel_handshake<T: Transport>(
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Control channel, using T as the transport layer
|
pub struct ControlChannelHandle<T: Transport> {
|
||||||
struct ControlChannel<T: Transport> {
|
|
||||||
conn: T::Stream, // The connection of control channel
|
|
||||||
service: ServerServiceConfig, // A copy of the corresponding service config
|
|
||||||
shutdown_rx: oneshot::Receiver<bool>, // Receives the shutdown signal
|
|
||||||
visitor_tx: mpsc::Sender<TcpStream>, // Receives visitor connections
|
|
||||||
}
|
|
||||||
|
|
||||||
// The handle of a control channel, along with the handle of a connection pool
|
|
||||||
// Dropping it will drop the actual control channel, because `visitor_tx`
|
|
||||||
// and `shutdown_tx` are closed
|
|
||||||
struct ControlChannelHandle<T: Transport> {
|
|
||||||
// Shutdown the control channel.
|
// Shutdown the control channel.
|
||||||
// Not used for now, but can be used for hot reloading
|
// Not used for now, but can be used for hot reloading
|
||||||
_shutdown_tx: oneshot::Sender<bool>,
|
#[allow(dead_code)]
|
||||||
conn_pool: ConnectionPoolHandle<T>,
|
shutdown_tx: broadcast::Sender<bool>,
|
||||||
|
//data_ch_req_tx: mpsc::Sender<bool>,
|
||||||
|
data_ch_tx: mpsc::Sender<T::Stream>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: 'static + Transport> ControlChannelHandle<T> {
|
impl<T> ControlChannelHandle<T>
|
||||||
|
where
|
||||||
|
T: 'static + Transport,
|
||||||
|
{
|
||||||
// Create a control channel handle, where the control channel handling task
|
// Create a control channel handle, where the control channel handling task
|
||||||
// and the connection pool task are created.
|
// and the connection pool task are created.
|
||||||
|
#[instrument(skip_all, fields(service = %service.name))]
|
||||||
fn new(conn: T::Stream, service: ServerServiceConfig) -> ControlChannelHandle<T> {
|
fn new(conn: T::Stream, service: ServerServiceConfig) -> ControlChannelHandle<T> {
|
||||||
// Save the name string for logging
|
// Save the name string for logging
|
||||||
let name = service.name.clone();
|
let name = service.name.clone();
|
||||||
|
|
||||||
// Create a shutdown channel. The sender is not used for now, but for future use
|
// Create a shutdown channel. The sender is not used for now, but for future use
|
||||||
let (_shutdown_tx, shutdown_rx) = oneshot::channel::<bool>();
|
let (shutdown_tx, shutdown_rx) = broadcast::channel::<bool>(1);
|
||||||
|
|
||||||
// Create and run the connection pool, where the visitors and data channels meet
|
// Store data channels
|
||||||
let conn_pool = ConnectionPoolHandle::new();
|
let (data_ch_tx, data_ch_rx) = mpsc::channel(CHAN_SIZE * 2);
|
||||||
|
|
||||||
|
// Store data channel creation requests
|
||||||
|
let (data_ch_req_tx, data_ch_req_rx) = mpsc::unbounded_channel();
|
||||||
|
|
||||||
|
match service.service_type {
|
||||||
|
ServiceType::Tcp => tokio::spawn(
|
||||||
|
run_tcp_connection_pool::<T>(
|
||||||
|
service.bind_addr.clone(),
|
||||||
|
data_ch_rx,
|
||||||
|
data_ch_req_tx.clone(),
|
||||||
|
shutdown_tx.subscribe(),
|
||||||
|
)
|
||||||
|
.instrument(Span::current()),
|
||||||
|
),
|
||||||
|
ServiceType::Udp => tokio::spawn(
|
||||||
|
run_udp_connection_pool::<T>(
|
||||||
|
service.bind_addr.clone(),
|
||||||
|
data_ch_rx,
|
||||||
|
data_ch_req_tx.clone(),
|
||||||
|
shutdown_tx.subscribe(),
|
||||||
|
)
|
||||||
|
.instrument(Span::current()),
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Cache some data channels for later use
|
||||||
|
let pool_size = match service.service_type {
|
||||||
|
ServiceType::Tcp => TCP_POOL_SIZE,
|
||||||
|
ServiceType::Udp => UDP_POOL_SIZE,
|
||||||
|
};
|
||||||
|
|
||||||
|
for _i in 0..pool_size {
|
||||||
|
if let Err(e) = data_ch_req_tx.send(true) {
|
||||||
|
error!("Failed to request data channel {}", e);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
// Create the control channel
|
// Create the control channel
|
||||||
let ch: ControlChannel<T> = ControlChannel {
|
let ch = ControlChannel::<T> {
|
||||||
conn,
|
conn,
|
||||||
shutdown_rx,
|
shutdown_rx,
|
||||||
service,
|
service,
|
||||||
visitor_tx: conn_pool.visitor_tx.clone(),
|
data_ch_req_rx,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Run the control channel
|
// Run the control channel
|
||||||
|
@ -335,52 +370,83 @@ impl<T: 'static + Transport> ControlChannelHandle<T> {
|
||||||
});
|
});
|
||||||
|
|
||||||
ControlChannelHandle {
|
ControlChannelHandle {
|
||||||
_shutdown_tx,
|
shutdown_tx,
|
||||||
conn_pool,
|
data_ch_tx,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
fn shutdown(self) {
|
||||||
|
let _ = self.shutdown_tx.send(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Control channel, using T as the transport layer. P is TcpStream or UdpTraffic
|
||||||
|
struct ControlChannel<T: Transport> {
|
||||||
|
conn: T::Stream, // The connection of control channel
|
||||||
|
service: ServerServiceConfig, // A copy of the corresponding service config
|
||||||
|
shutdown_rx: broadcast::Receiver<bool>, // Receives the shutdown signal
|
||||||
|
data_ch_req_rx: mpsc::UnboundedReceiver<bool>, // Receives visitor connections
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Transport> ControlChannel<T> {
|
impl<T: Transport> ControlChannel<T> {
|
||||||
// Run a control channel
|
// Run a control channel
|
||||||
#[tracing::instrument(skip(self), fields(service = %self.service.name))]
|
#[instrument(skip(self), fields(service = %self.service.name))]
|
||||||
async fn run(mut self) -> Result<()> {
|
async fn run(mut self) -> Result<()> {
|
||||||
// Where the service is exposed
|
|
||||||
let l = match TcpListener::bind(&self.service.bind_addr).await {
|
|
||||||
Ok(v) => v,
|
|
||||||
Err(e) => {
|
|
||||||
let duration = Duration::from_secs(1);
|
|
||||||
error!(
|
|
||||||
"Failed to listen on service.bind_addr: {}. Retry in {:?}...",
|
|
||||||
e, duration
|
|
||||||
);
|
|
||||||
time::sleep(duration).await;
|
|
||||||
TcpListener::bind(&self.service.bind_addr).await?
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
info!("Listening at {}", &self.service.bind_addr);
|
|
||||||
|
|
||||||
// Each `u8` in the chan indicates a data channel creation request
|
|
||||||
let (data_req_tx, mut data_req_rx) = mpsc::unbounded_channel::<u8>();
|
|
||||||
|
|
||||||
// The control channel is moved into the task, and sends CreateDataChannel
|
|
||||||
// comamnds to the client when needed
|
|
||||||
tokio::spawn(async move {
|
|
||||||
let cmd = bincode::serialize(&ControlChannelCmd::CreateDataChannel).unwrap();
|
let cmd = bincode::serialize(&ControlChannelCmd::CreateDataChannel).unwrap();
|
||||||
while data_req_rx.recv().await.is_some() {
|
|
||||||
if self.conn.write_all(&cmd).await.is_err() {
|
// Wait for data channel requests and the shutdown signal
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
val = self.data_ch_req_rx.recv() => {
|
||||||
|
match val {
|
||||||
|
Some(_) => {
|
||||||
|
if let Err(e) = self.conn.write_all(&cmd).await.with_context(||"Failed to write data cmds") {
|
||||||
|
error!("{:?}", e);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
None => {
|
||||||
|
break;
|
||||||
// Cache some data channels for later use
|
|
||||||
for _i in 0..POOL_SIZE {
|
|
||||||
if let Err(e) = data_req_tx.send(0) {
|
|
||||||
error!("Failed to request data channel {}", e);
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
// Wait for the shutdown signal
|
||||||
|
_ = self.shutdown_rx.recv() => {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
info!("Control channel shuting down");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn tcp_listen_and_send(
|
||||||
|
addr: String,
|
||||||
|
data_ch_req_tx: mpsc::UnboundedSender<bool>,
|
||||||
|
mut shutdown_rx: broadcast::Receiver<bool>,
|
||||||
|
) -> mpsc::Receiver<TcpStream> {
|
||||||
|
let (tx, rx) = mpsc::channel(CHAN_SIZE);
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let l = backoff::future::retry(listen_backoff(), || async {
|
||||||
|
Ok(TcpListener::bind(&addr).await?)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to listen for the service");
|
||||||
|
|
||||||
|
let l: TcpListener = match l {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(e) => {
|
||||||
|
error!("{:?}", e);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
info!("Listening at {}", &addr);
|
||||||
|
|
||||||
// Retry at least every 1s
|
// Retry at least every 1s
|
||||||
let mut backoff = ExponentialBackoff {
|
let mut backoff = ExponentialBackoff {
|
||||||
|
@ -392,7 +458,6 @@ impl<T: Transport> ControlChannel<T> {
|
||||||
// Wait for visitors and the shutdown signal
|
// Wait for visitors and the shutdown signal
|
||||||
loop {
|
loop {
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
// Wait for visitors
|
|
||||||
val = l.accept() => {
|
val = l.accept() => {
|
||||||
match val {
|
match val {
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
@ -406,73 +471,47 @@ impl<T: Transport> ControlChannel<T> {
|
||||||
error!("Too many retries. Aborting...");
|
error!("Too many retries. Aborting...");
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
Ok((incoming, addr)) => {
|
Ok((incoming, addr)) => {
|
||||||
// For every visitor, request to create a data channel
|
// For every visitor, request to create a data channel
|
||||||
if let Err(e) = data_req_tx.send(0) {
|
if let Err(e) = data_ch_req_tx.send(true) {
|
||||||
// An error indicates the control channel is broken
|
// An error indicates the control channel is broken
|
||||||
// So break the loop
|
// So break the loop
|
||||||
error!("{}", e);
|
error!("{}", e);
|
||||||
break;
|
break;
|
||||||
};
|
}
|
||||||
|
|
||||||
backoff.reset();
|
backoff.reset();
|
||||||
|
|
||||||
debug!("New visitor from {}", addr);
|
debug!("New visitor from {}", addr);
|
||||||
|
|
||||||
// Send the visitor to the connection pool
|
// Send the visitor to the connection pool
|
||||||
let _ = self.visitor_tx.send(incoming).await;
|
let _ = tx.send(incoming).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
// Wait for the shutdown signal
|
_ = shutdown_rx.recv() => {
|
||||||
_ = &mut self.shutdown_rx => {
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
info!("Service shuting down");
|
});
|
||||||
|
|
||||||
Ok(())
|
rx
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[instrument(skip_all)]
|
||||||
struct ConnectionPool<T: Transport> {
|
async fn run_tcp_connection_pool<T: Transport>(
|
||||||
visitor_rx: mpsc::Receiver<TcpStream>,
|
bind_addr: String,
|
||||||
data_ch_rx: mpsc::Receiver<T::Stream>,
|
mut data_ch_rx: mpsc::Receiver<T::Stream>,
|
||||||
}
|
data_ch_req_tx: mpsc::UnboundedSender<bool>,
|
||||||
|
shutdown_rx: broadcast::Receiver<bool>,
|
||||||
struct ConnectionPoolHandle<T: Transport> {
|
) -> Result<()> {
|
||||||
visitor_tx: mpsc::Sender<TcpStream>,
|
let mut visitor_rx = tcp_listen_and_send(bind_addr, data_ch_req_tx, shutdown_rx);
|
||||||
data_ch_tx: mpsc::Sender<T::Stream>,
|
while let Some(mut visitor) = visitor_rx.recv().await {
|
||||||
}
|
if let Some(mut ch) = data_ch_rx.recv().await {
|
||||||
|
|
||||||
impl<T: 'static + Transport> ConnectionPoolHandle<T> {
|
|
||||||
fn new() -> ConnectionPoolHandle<T> {
|
|
||||||
let (data_ch_tx, data_ch_rx) = mpsc::channel(CHAN_SIZE * 2);
|
|
||||||
let (visitor_tx, visitor_rx) = mpsc::channel(CHAN_SIZE);
|
|
||||||
let conn_pool: ConnectionPool<T> = ConnectionPool {
|
|
||||||
data_ch_rx,
|
|
||||||
visitor_rx,
|
|
||||||
};
|
|
||||||
|
|
||||||
tokio::spawn(async move { conn_pool.run().await });
|
|
||||||
|
|
||||||
ConnectionPoolHandle {
|
|
||||||
data_ch_tx,
|
|
||||||
visitor_tx,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: Transport> ConnectionPool<T> {
|
|
||||||
#[tracing::instrument]
|
|
||||||
async fn run(mut self) {
|
|
||||||
while let Some(mut visitor) = self.visitor_rx.recv().await {
|
|
||||||
if let Some(mut ch) = self.data_ch_rx.recv().await {
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let cmd = bincode::serialize(&DataChannelCmd::StartForward).unwrap();
|
let cmd = bincode::serialize(&DataChannelCmd::StartForwardTcp).unwrap();
|
||||||
if ch.write_all(&cmd).await.is_ok() {
|
if ch.write_all(&cmd).await.is_ok() {
|
||||||
let _ = copy_bidirectional(&mut ch, &mut visitor).await;
|
let _ = copy_bidirectional(&mut ch, &mut visitor).await;
|
||||||
}
|
}
|
||||||
|
@ -481,5 +520,54 @@ impl<T: Transport> ConnectionPool<T> {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[instrument(skip_all)]
|
||||||
|
async fn run_udp_connection_pool<T: Transport>(
|
||||||
|
bind_addr: String,
|
||||||
|
mut data_ch_rx: mpsc::Receiver<T::Stream>,
|
||||||
|
_data_ch_req_tx: mpsc::UnboundedSender<bool>,
|
||||||
|
mut shutdown_rx: broadcast::Receiver<bool>,
|
||||||
|
) -> Result<()> {
|
||||||
|
// TODO: Load balance
|
||||||
|
|
||||||
|
let l: UdpSocket = backoff::future::retry(listen_backoff(), || async {
|
||||||
|
Ok(UdpSocket::bind(&bind_addr).await?)
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.with_context(|| "Failed to listen for the service")?;
|
||||||
|
|
||||||
|
info!("Listening at {}", &bind_addr);
|
||||||
|
|
||||||
|
let cmd = bincode::serialize(&DataChannelCmd::StartForwardUdp).unwrap();
|
||||||
|
|
||||||
|
let mut conn = data_ch_rx
|
||||||
|
.recv()
|
||||||
|
.await
|
||||||
|
.ok_or(anyhow!("No available data channels"))?;
|
||||||
|
conn.write_all(&cmd).await?;
|
||||||
|
|
||||||
|
let mut buf = [0u8; UDP_BUFFER_SIZE];
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
// Forward inbound traffic to the client
|
||||||
|
val = l.recv_from(&mut buf) => {
|
||||||
|
let (n, from) = val?;
|
||||||
|
UdpTraffic::write_slice(&mut conn, from, &buf[..n]).await?;
|
||||||
|
},
|
||||||
|
|
||||||
|
// Forward outbound traffic from the client to the visitor
|
||||||
|
t = UdpTraffic::read(&mut conn) => {
|
||||||
|
let t = t?;
|
||||||
|
l.send_to(&t.data, t.from).await?;
|
||||||
|
},
|
||||||
|
|
||||||
|
_ = shutdown_rx.recv() => {
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue