feat(transport): add http2 transport (#392)

This commit is contained in:
rucciva 2024-10-17 17:16:58 +07:00
parent be14d124a2
commit ddc97cff78
19 changed files with 1391 additions and 482 deletions

View File

@ -128,7 +128,7 @@ jobs:
version: v4.0.2
files: target/${{ matrix.target }}/release/${{ matrix.exe }}
args: -q --best --lzma
- uses: actions/upload-artifact@v2
- uses: actions/upload-artifact@v4
with:
name: rathole-${{ matrix.target }}
path: target/${{ matrix.target }}/release/${{ matrix.exe }}

View File

@ -34,7 +34,7 @@ jobs:
- name: Check all features
run: >
cargo hack check --feature-powerset --no-dev-deps
--mutually-exclusive-features default,native-tls,websocket-native-tls,rustls,websocket-rustls
--mutually-exclusive-features default,native-tls,websocket-native-tls,http2-native-tls,rustls,websocket-rustls,http2-rustls
build:
name: Build for ${{ matrix.target }}
@ -67,8 +67,8 @@ jobs:
- name: Run tests with native-tls
run: cargo test --verbose
- name: Run tests with rustls
run: cargo test --verbose --no-default-features --features server,client,rustls,noise,websocket-rustls,hot-reload
- uses: actions/upload-artifact@v2
run: cargo test --verbose --no-default-features --features server,client,rustls,noise,websocket-rustls,http2-rustls,hot-reload
- uses: actions/upload-artifact@v4
with:
name: rathole-${{ matrix.target }}
path: target/debug/${{ matrix.exe }}

993
Cargo.lock generated Normal file → Executable file

File diff suppressed because it is too large Load Diff

28
Cargo.toml Normal file → Executable file
View File

@ -17,6 +17,7 @@ default = [
"native-tls",
"noise",
"websocket-native-tls",
"http2-native-tls",
"hot-reload",
]
@ -53,6 +54,28 @@ websocket-rustls = [
"rustls",
]
# HTTP2 support
http2-native-tls = [
"hyper",
"hyper-util",
"http",
"http-body-util",
"futures-core",
"tokio-util",
"tower-service",
"native-tls",
]
http2-rustls = [
"hyper",
"hyper-util",
"http",
"http-body-util",
"futures-core",
"tokio-util",
"tower-service",
"rustls",
]
# Configuration hot-reload support
hot-reload = ["notify"]
@ -117,6 +140,11 @@ async-http-proxy = { version = "1.2", features = [
async-socks5 = "0.5"
url = { version = "2.2", features = ["serde"] }
tokio-tungstenite = { version = "0.20.1", optional = true }
http = { version = "1.1.0", optional = true }
hyper = { version = "1.4.1", optional = true , features = ["client","server","http2"] }
hyper-util = { version = "0.1.9", optional = true , features = ["full"]}
http-body-util = { version = "0.1.2", optional = true }
tower-service = { version = "0.3.3", optional = true }
tokio-util = { version = "0.7.9", optional = true, features = ["io"] }
futures-core = { version = "0.3.28", optional = true }
futures-sink = { version = "0.3.28", optional = true }

View File

@ -111,7 +111,7 @@ heartbeat_timeout = 40 # Optional. Set to 0 to disable the application-layer hea
retry_interval = 1 # Optional. The interval between retry to connect to the server. Default: 1 second
[client.transport] # The whole block is optional. Specify which transport to use
type = "tcp" # Optional. Possible values: ["tcp", "tls", "noise"]. Default: "tcp"
type = "tcp" # Optional. Possible values: ["tcp", "tls", "noise", "websocket", "http2"]. Default: "tcp"
[client.transport.tcp] # Optional. Also affects `noise` and `tls`
proxy = "socks5://user:passwd@127.0.0.1:1080" # Optional. The proxy used to connect to the server. `http` and `socks5` is supported.
@ -131,6 +131,9 @@ remote_public_key = "key_encoded_in_base64" # Optional
[client.transport.websocket] # Necessary if `type` is "websocket"
tls = true # If `true` then it will use settings in `client.transport.tls`
[client.transport.http2] # Necessary if `type` is "http2"
tls = true # If `true` then it will use settings in `client.transport.tls`
[client.services.service1] # A service that needs forwarding. The name `service1` can change arbitrarily, as long as identical to the name in the server's configuration
type = "tcp" # Optional. The protocol that needs forwarding. Possible values: ["tcp", "udp"]. Default: "tcp"
token = "whatever" # Necessary if `client.default_token` not set
@ -166,6 +169,9 @@ remote_public_key = "key_encoded_in_base64"
[server.transport.websocket] # Necessary if `type` is "websocket"
tls = true # If `true` then it will use settings in `server.transport.tls`
[server.transport.http2] # Necessary if `type` is "http2"
tls = true # If `true` then it will use settings in `server.transport.tls`
[server.services.service1] # The service name must be identical to the client side
type = "tcp" # Optional. Same as the client `[client.services.X.type]
token = "whatever" # Necessary if `server.default_token` not set

View File

@ -21,6 +21,8 @@ use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
use tokio::time::{self, Duration, Instant};
use tracing::{debug, error, info, instrument, trace, warn, Instrument, Span};
#[cfg(any(feature = "http2-native-tls", feature = "http2-rustls"))]
use crate::transport::HTTP2Transport;
#[cfg(feature = "noise")]
use crate::transport::NoiseTransport;
#[cfg(any(feature = "native-tls", feature = "rustls"))]
@ -74,6 +76,15 @@ pub async fn run_client(
#[cfg(not(any(feature = "websocket-native-tls", feature = "websocket-rustls")))]
crate::helper::feature_neither_compile("websocket-native-tls", "websocket-rustls")
}
TransportType::HTTP2 => {
#[cfg(any(feature = "http2-native-tls", feature = "http2-rustls"))]
{
let mut client = Client::<HTTP2Transport>::from(config).await?;
client.run(shutdown_rx, update_rx).await
}
#[cfg(not(any(feature = "http2-native-tls", feature = "http2-rustls")))]
crate::helper::feature_neither_compile("http2-native-tls", "http2-rustls")
}
}
}

10
src/config.rs Normal file → Executable file
View File

@ -51,6 +51,8 @@ pub enum TransportType {
Noise,
#[serde(rename = "websocket")]
Websocket,
#[serde(rename = "http2")]
HTTP2,
}
/// Per service config
@ -141,6 +143,12 @@ pub struct WebsocketConfig {
pub tls: bool,
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(deny_unknown_fields)]
pub struct HTTP2Config {
pub tls: bool,
}
fn default_nodelay() -> bool {
DEFAULT_NODELAY
}
@ -186,6 +194,7 @@ pub struct TransportConfig {
pub tls: Option<TlsConfig>,
pub noise: Option<NoiseConfig>,
pub websocket: Option<WebsocketConfig>,
pub http2: Option<HTTP2Config>,
}
fn default_heartbeat_timeout() -> u64 {
@ -320,6 +329,7 @@ impl Config {
Ok(())
}
TransportType::Websocket => Ok(()),
TransportType::HTTP2 => Ok(()),
}
}

View File

@ -23,6 +23,8 @@ use tokio::sync::{broadcast, mpsc, RwLock};
use tokio::time;
use tracing::{debug, error, info, info_span, instrument, warn, Instrument, Span};
#[cfg(any(feature = "http2-native-tls", feature = "http2-rustls"))]
use crate::transport::HTTP2Transport;
#[cfg(feature = "noise")]
use crate::transport::NoiseTransport;
#[cfg(any(feature = "native-tls", feature = "rustls"))]
@ -83,6 +85,15 @@ pub async fn run_server(
#[cfg(not(any(feature = "websocket-native-tls", feature = "websocket-rustls")))]
crate::helper::feature_neither_compile("websocket-native-tls", "websocket-rustls")
}
TransportType::HTTP2 => {
#[cfg(any(feature = "http2-native-tls", feature = "http2-rustls"))]
{
let mut server = Server::<HTTP2Transport>::from(config).await?;
server.run(shutdown_rx, update_rx).await?;
}
#[cfg(not(any(feature = "http2-native-tls", feature = "http2-rustls")))]
crate::helper::feature_neither_compile("http2-native-tls", "http2-rustls")
}
}
Ok(())

398
src/transport/http2.rs Executable file
View File

@ -0,0 +1,398 @@
use core::result::Result;
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{self, Context, Poll};
use anyhow::anyhow;
use async_trait::async_trait;
use bytes::{Buf, Bytes};
use futures_core::Stream;
use http::{Method, Request, Response, Uri};
use http_body_util::StreamBody;
use hyper::body::{Body, Incoming};
use hyper::server::conn::http2 as Server;
use hyper::service::Service;
use hyper_util::client::legacy::connect::{Connected, Connection};
use hyper_util::client::legacy::Client;
use hyper_util::rt::tokio::TokioExecutor;
use hyper_util::rt::tokio::TokioIo;
use tokio::io::{self, AsyncRead, AsyncWrite, ReadBuf, ReadHalf, SimplexStream, WriteHalf};
use tokio::net::{TcpListener, ToSocketAddrs};
use tokio::select;
use tokio::sync::Mutex;
use tokio::sync::{broadcast, mpsc};
use tokio_util::io::ReaderStream;
use super::maybe_tls::{MaybeTLSStream, MaybeTLSTransport};
use super::{AddrMaybeCached, SocketOpts, Transport};
use crate::config::TransportConfig;
#[derive(Debug)]
struct IncomingHyper {
inner: hyper::body::Incoming,
current_chunk: Option<bytes::Bytes>,
}
impl AsyncRead for IncomingHyper {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
loop {
if let Some(chunk) = &mut self.current_chunk {
let len = std::cmp::min(chunk.len(), buf.remaining());
buf.put_slice(&chunk[..len]);
chunk.advance(len);
if !chunk.has_remaining() {
self.current_chunk = None;
}
return Poll::Ready(Ok(()));
}
match Pin::new(&mut self.inner).poll_frame(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(None) => return Poll::Ready(Ok(())),
Poll::Ready(Some(Err(err))) => {
return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, err)))
}
Poll::Ready(Some(Ok(frame))) => match frame.into_data() {
Err(_) => {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::Other,
"non data frame received",
)))
}
Ok(data) => {
self.current_chunk = Some(data);
continue;
}
},
}
}
}
}
#[derive(Debug)]
pub struct HTTP2Stream {
send: WriteHalf<SimplexStream>,
recv: IncomingHyper,
}
impl AsyncRead for HTTP2Stream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.get_mut().recv).poll_read(cx, buf)
}
}
impl AsyncWrite for HTTP2Stream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
Pin::new(&mut self.get_mut().send).poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
Pin::new(&mut self.get_mut().send).poll_flush(cx)
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
Pin::new(&mut self.get_mut().send).poll_shutdown(cx)
}
}
#[derive(Debug)]
pub struct OutgoingSimplex {
inner: ReaderStream<ReadHalf<SimplexStream>>,
}
impl Stream for OutgoingSimplex {
type Item = Result<hyper::body::Frame<Bytes>, io::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match Pin::new(&mut self.get_mut().inner).poll_next(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
Poll::Ready(Some(Ok(data))) => Poll::Ready(Some(Ok(hyper::body::Frame::data(data)))),
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.inner.size_hint()
}
}
#[derive(Debug)]
struct Svc {
req_sender: mpsc::Sender<anyhow::Result<(SocketAddr, Incoming, mpsc::Sender<OutgoingSimplex>)>>,
addr: SocketAddr,
}
impl Service<Request<Incoming>> for Svc {
type Response = Response<StreamBody<OutgoingSimplex>>;
type Error = anyhow::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn call(&self, req: Request<Incoming>) -> Self::Future {
let req_sender = self.req_sender.clone();
let addr = self.addr;
let future = async move {
let (res_sender, mut res_receiver) = mpsc::channel::<OutgoingSimplex>(1);
if let Err(err) = req_sender
.send(Ok((addr, req.into_body(), res_sender)))
.await
{
return Err(anyhow!(err.to_string()));
}
match res_receiver.recv().await {
None => Err(anyhow!("Channel closed")),
Some(body) => Ok(Response::new(http_body_util::StreamBody::new(body))),
}
};
Box::pin(future) // Return the boxed future
}
}
async fn start_http_server(
listener: TcpListener,
transport: Arc<MaybeTLSTransport>,
req_sender: mpsc::Sender<anyhow::Result<(SocketAddr, Incoming, mpsc::Sender<OutgoingSimplex>)>>,
stop_receiver: broadcast::Receiver<()>,
) -> anyhow::Result<()> {
loop {
let mut stop_receiver = stop_receiver.resubscribe();
let conn = async {
let (conn, addr) = transport.accept(&listener).await?;
let stream = transport.handshake(conn).await?;
Ok((stream, addr))
};
select! {
_ = stop_receiver.recv() => {
return Ok(());
}
conn = conn => {
if let Err(err) = conn {
if let Err(err)= req_sender.send(Err(err)).await {
eprintln!("Error sending error message: {}", err);
}
continue
}
let (socket, addr) = conn.unwrap();
let io = TokioIo::new(socket);
let svc = Svc {
req_sender: req_sender.clone(),
addr,
};
let req_sender= req_sender.clone();
tokio::spawn(async move {
let mut conn = Server::Builder::new(TokioExecutor::new())
.serve_connection(io, svc);
select! {
_ = stop_receiver.recv() => {
Pin::new(&mut conn).graceful_shutdown();
}
res = &mut conn => {
if let Err(err) = res {
if let Err(err)= req_sender.send(Err(anyhow!(err.to_string()))).await {
eprintln!("Error sending error message: {}", err);
}
}
}
}
});
}
}
}
}
impl Connection for MaybeTLSStream {
fn connected(&self) -> Connected {
let connected = Connected::new();
if let (Ok(remote_addr), Ok(local_addr)) = (
self.get_tcpstream().peer_addr(),
self.get_tcpstream().local_addr(),
) {
connected.extra((remote_addr, local_addr))
} else {
connected
}
}
}
#[derive(Clone)]
struct MaybeTLSConnector {
sub: Arc<MaybeTLSTransport>,
}
impl tower_service::Service<Uri> for MaybeTLSConnector {
type Response = TokioIo<MaybeTLSStream>;
type Error = anyhow::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, u: Uri) -> Self::Future {
let sub = self.sub.clone();
let future = async move {
let addr = match (u.host(), u.port()) {
(Some(host), Some(port)) => format!("{}:{}", host, port),
_ => String::from(""),
};
let addr = AddrMaybeCached::new(addr.as_str());
let stream = sub.connect(&addr).await?;
Ok(TokioIo::new(stream))
};
Box::pin(future)
}
}
#[derive(Debug)]
pub struct HTTP2Transport {
sub: Arc<MaybeTLSTransport>,
client: Client<MaybeTLSConnector, StreamBody<OutgoingSimplex>>,
stop_sender: broadcast::Sender<()>,
stop_receiver: broadcast::Receiver<()>,
}
#[async_trait]
impl Transport for HTTP2Transport {
type Acceptor = Arc<
Mutex<
mpsc::Receiver<anyhow::Result<(SocketAddr, Incoming, mpsc::Sender<OutgoingSimplex>)>>,
>,
>;
type RawStream = (Incoming, mpsc::Sender<OutgoingSimplex>);
type Stream = HTTP2Stream;
fn new(config: &TransportConfig) -> anyhow::Result<Self> {
let cfg = config
.http2
.as_ref()
.ok_or_else(|| anyhow!("Missing http2 config"))?;
let (stop_sender, stop_receiver) = broadcast::channel(1);
let sub = Arc::new(MaybeTLSTransport::new_explicit(cfg.tls, config)?);
let client = Client::builder(TokioExecutor::new())
.http2_only(true)
.build(MaybeTLSConnector { sub: sub.clone() });
Ok(HTTP2Transport {
sub,
client,
stop_sender,
stop_receiver,
})
}
fn hint(_: &Self::Stream, _: SocketOpts) {}
async fn bind<A: ToSocketAddrs + Send + Sync>(
&self,
addr: A,
) -> anyhow::Result<Self::Acceptor> {
let listener = self.sub.bind(addr).await?;
let (req_sender, req_receiver) = mpsc::channel::<
anyhow::Result<(SocketAddr, Incoming, mpsc::Sender<OutgoingSimplex>)>,
>(1);
let req_receiver = Arc::new(Mutex::new(req_receiver));
let sub_transport = self.sub.clone();
let stop_receiver = self.stop_receiver.resubscribe();
tokio::spawn(start_http_server(
listener,
sub_transport,
req_sender,
stop_receiver,
));
Ok(req_receiver)
}
async fn accept(&self, a: &Self::Acceptor) -> anyhow::Result<(Self::RawStream, SocketAddr)> {
let mut receiver = a.lock().await;
match receiver.recv().await {
None => Err(anyhow!("Channel closed")),
Some(Err(err)) => Err(err),
Some(Ok((addr, incoming, res_sender))) => Ok(((incoming, res_sender), addr)),
}
}
async fn handshake(&self, conn: Self::RawStream) -> anyhow::Result<Self::Stream> {
let (incoming, res_sender) = conn;
let (sread, swrite) = io::simplex(4096);
if let Err(err) = res_sender
.send(OutgoingSimplex {
inner: ReaderStream::with_capacity(sread, 4096),
})
.await
{
return Err(anyhow!(err.to_string()));
}
Ok(HTTP2Stream {
recv: IncomingHyper {
inner: incoming,
current_chunk: None,
},
send: swrite,
})
}
async fn connect(&self, addr: &AddrMaybeCached) -> anyhow::Result<Self::Stream> {
let client = self.client.clone();
let (sread, swrite) = io::simplex(4096);
let body = http_body_util::StreamBody::new(OutgoingSimplex {
inner: ReaderStream::with_capacity(sread, 4096),
});
let req = Request::builder()
.method(Method::POST)
.uri(format!("http://{}", &addr.addr.as_str()))
.body(body)
.expect("request builder");
let res = client.request(req).await;
if let Err(err) = res {
return Err(anyhow!("Error: {}", err));
}
let res = res.unwrap();
if !res.status().is_success() {
return Err(anyhow!("Bad status code: {}", res.status()));
}
Ok(HTTP2Stream {
recv: IncomingHyper {
inner: res.into_body(),
current_chunk: None,
},
send: swrite,
})
}
}
impl Drop for HTTP2Transport {
fn drop(&mut self) {
let _ = self.stop_sender.send(());
}
}

147
src/transport/maybe_tls.rs Normal file
View File

@ -0,0 +1,147 @@
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use async_trait::async_trait;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
use super::tls::get_tcpstream;
use super::tls::TlsStream;
use super::{AddrMaybeCached, SocketOpts, TlsTransport};
use super::{TcpTransport, Transport};
use crate::config::TransportConfig;
#[derive(Debug)]
pub(super) enum MaybeTLSStream {
No(TcpStream),
Yes(TlsStream<TcpStream>),
}
impl MaybeTLSStream {
pub(super) fn get_tcpstream(&self) -> &TcpStream {
match self {
MaybeTLSStream::No(s) => s,
MaybeTLSStream::Yes(s) => get_tcpstream(s),
}
}
}
impl AsyncRead for MaybeTLSStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
match self.get_mut() {
MaybeTLSStream::No(s) => Pin::new(s).poll_read(cx, buf),
MaybeTLSStream::Yes(s) => Pin::new(s).poll_read(cx, buf),
}
}
}
impl AsyncWrite for MaybeTLSStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
match self.get_mut() {
MaybeTLSStream::No(s) => Pin::new(s).poll_write(cx, buf),
MaybeTLSStream::Yes(s) => Pin::new(s).poll_write(cx, buf),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
match self.get_mut() {
MaybeTLSStream::No(s) => Pin::new(s).poll_flush(cx),
MaybeTLSStream::Yes(s) => Pin::new(s).poll_flush(cx),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
match self.get_mut() {
MaybeTLSStream::No(s) => Pin::new(s).poll_shutdown(cx),
MaybeTLSStream::Yes(s) => Pin::new(s).poll_shutdown(cx),
}
}
}
#[derive(Debug)]
pub(super) enum MaybeTLSTransport {
Yes(TlsTransport),
No(TcpTransport),
}
impl MaybeTLSTransport {
pub(super) fn new_explicit(tls: bool, tconfig: &TransportConfig) -> anyhow::Result<Self> {
match tls {
true => Ok(MaybeTLSTransport::Yes(TlsTransport::new(tconfig)?)),
false => Ok(MaybeTLSTransport::No(TcpTransport::new(tconfig)?)),
}
}
}
#[async_trait]
impl Transport for MaybeTLSTransport {
type Acceptor = TcpListener;
type RawStream = TcpStream;
type Stream = MaybeTLSStream;
fn new(config: &TransportConfig) -> anyhow::Result<Self> {
MaybeTLSTransport::new_explicit(config.tls.is_some(), config)
}
async fn bind<A: ToSocketAddrs + Send + Sync>(
&self,
addr: A,
) -> anyhow::Result<Self::Acceptor> {
match self {
MaybeTLSTransport::Yes(t) => t.bind(addr).await,
MaybeTLSTransport::No(t) => t.bind(addr).await,
}
}
fn hint(conn: &Self::Stream, opt: SocketOpts) {
match conn {
MaybeTLSStream::Yes(t) => TlsTransport::hint(t, opt),
MaybeTLSStream::No(t) => TcpTransport::hint(t, opt),
}
}
async fn accept(&self, a: &Self::Acceptor) -> anyhow::Result<(Self::RawStream, SocketAddr)> {
match self {
MaybeTLSTransport::Yes(t) => t.accept(a).await,
MaybeTLSTransport::No(t) => t.accept(a).await,
}
}
async fn handshake(&self, conn: Self::RawStream) -> anyhow::Result<Self::Stream> {
match self {
MaybeTLSTransport::Yes(t) => {
let stream = t.handshake(conn).await?;
Ok(MaybeTLSStream::Yes(stream))
}
MaybeTLSTransport::No(t) => {
let stream = t.handshake(conn).await?;
Ok(MaybeTLSStream::No(stream))
}
}
}
async fn connect(&self, addr: &AddrMaybeCached) -> anyhow::Result<Self::Stream> {
match self {
MaybeTLSTransport::Yes(t) => {
let stream = t.connect(addr).await?;
Ok(MaybeTLSStream::Yes(stream))
}
MaybeTLSTransport::No(t) => {
let stream = t.connect(addr).await?;
Ok(MaybeTLSStream::No(stream))
}
}
}
}

13
src/transport/mod.rs Normal file → Executable file
View File

@ -85,6 +85,14 @@ use rustls as tls;
#[cfg(any(feature = "native-tls", feature = "rustls"))]
pub(crate) use tls::TlsTransport;
#[cfg(any(
feature = "websocket-native-tls",
feature = "http2-native-tls",
feature = "websocket-rustls",
feature = "http2-rustls"
))]
mod maybe_tls;
#[cfg(feature = "noise")]
mod noise;
#[cfg(feature = "noise")]
@ -95,6 +103,11 @@ mod websocket;
#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
pub use websocket::WebsocketTransport;
#[cfg(any(feature = "http2-native-tls", feature = "http2-rustls"))]
mod http2;
#[cfg(any(feature = "http2-native-tls", feature = "http2-rustls"))]
pub use http2::HTTP2Transport;
#[derive(Debug, Clone, Copy)]
struct Keepalive {
// tcp_keepalive_time if the underlying protocol is TCP

View File

@ -110,7 +110,7 @@ impl Transport for TlsTransport {
}
}
#[cfg(feature = "websocket-native-tls")]
#[cfg(any(feature = "websocket-native-tls", feature = "http2-native-tls"))]
pub(crate) fn get_tcpstream(s: &TlsStream<TcpStream>) -> &TcpStream {
s.get_ref().get_ref().get_ref()
}

View File

@ -151,6 +151,7 @@ impl Transport for TlsTransport {
}
}
#[cfg(any(feature = "websocket-rustls", feature = "http2-rustls"))]
pub(crate) fn get_tcpstream(s: &TlsStream<TcpStream>) -> &TcpStream {
&s.get_ref().0
}

View File

@ -4,8 +4,6 @@ use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use super::{AddrMaybeCached, SocketOpts, TcpTransport, TlsTransport, Transport};
use crate::config::TransportConfig;
use anyhow::anyhow;
use async_trait::async_trait;
use bytes::Bytes;
@ -13,78 +11,18 @@ use futures_core::stream::Stream;
use futures_sink::Sink;
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
#[cfg(any(feature = "native-tls", feature = "rustls"))]
use super::tls::get_tcpstream;
#[cfg(any(feature = "native-tls", feature = "rustls"))]
use super::tls::TlsStream;
use tokio_tungstenite::tungstenite::protocol::{Message, WebSocketConfig};
use tokio_tungstenite::{accept_async_with_config, client_async_with_config, WebSocketStream};
use tokio_util::io::StreamReader;
use url::Url;
#[derive(Debug)]
enum TransportStream {
Insecure(TcpStream),
Secure(TlsStream<TcpStream>),
}
impl TransportStream {
fn get_tcpstream(&self) -> &TcpStream {
match self {
TransportStream::Insecure(s) => s,
TransportStream::Secure(s) => get_tcpstream(s),
}
}
}
impl AsyncRead for TransportStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
match self.get_mut() {
TransportStream::Insecure(s) => Pin::new(s).poll_read(cx, buf),
TransportStream::Secure(s) => Pin::new(s).poll_read(cx, buf),
}
}
}
impl AsyncWrite for TransportStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
match self.get_mut() {
TransportStream::Insecure(s) => Pin::new(s).poll_write(cx, buf),
TransportStream::Secure(s) => Pin::new(s).poll_write(cx, buf),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
match self.get_mut() {
TransportStream::Insecure(s) => Pin::new(s).poll_flush(cx),
TransportStream::Secure(s) => Pin::new(s).poll_flush(cx),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
match self.get_mut() {
TransportStream::Insecure(s) => Pin::new(s).poll_shutdown(cx),
TransportStream::Secure(s) => Pin::new(s).poll_shutdown(cx),
}
}
}
use super::maybe_tls::{MaybeTLSStream, MaybeTLSTransport};
use super::{AddrMaybeCached, SocketOpts, Transport};
use crate::config::TransportConfig;
#[derive(Debug)]
struct StreamWrapper {
inner: WebSocketStream<TransportStream>,
inner: WebSocketStream<MaybeTLSStream>,
}
impl Stream for StreamWrapper {
@ -170,15 +108,9 @@ impl AsyncWrite for WebsocketTunnel {
}
}
#[derive(Debug)]
enum SubTransport {
Secure(TlsTransport),
Insecure(TcpTransport),
}
#[derive(Debug)]
pub struct WebsocketTransport {
sub: SubTransport,
sub: MaybeTLSTransport,
conf: WebSocketConfig,
}
@ -198,10 +130,7 @@ impl Transport for WebsocketTransport {
write_buffer_size: 0,
..WebSocketConfig::default()
};
let sub = match wsconfig.tls {
true => SubTransport::Secure(TlsTransport::new(config)?),
false => SubTransport::Insecure(TcpTransport::new(config)?),
};
let sub = MaybeTLSTransport::new_explicit(wsconfig.tls, config)?;
Ok(WebsocketTransport { sub, conf })
}
@ -213,22 +142,15 @@ impl Transport for WebsocketTransport {
&self,
addr: A,
) -> anyhow::Result<Self::Acceptor> {
TcpListener::bind(addr).await.map_err(Into::into)
self.sub.bind(addr).await
}
async fn accept(&self, a: &Self::Acceptor) -> anyhow::Result<(Self::RawStream, SocketAddr)> {
let (s, addr) = match &self.sub {
SubTransport::Insecure(t) => t.accept(a).await?,
SubTransport::Secure(t) => t.accept(a).await?,
};
Ok((s, addr))
self.sub.accept(a).await
}
async fn handshake(&self, conn: Self::RawStream) -> anyhow::Result<Self::Stream> {
let tsream = match &self.sub {
SubTransport::Insecure(t) => TransportStream::Insecure(t.handshake(conn).await?),
SubTransport::Secure(t) => TransportStream::Secure(t.handshake(conn).await?),
};
let tsream = self.sub.handshake(conn).await?;
let wsstream = accept_async_with_config(tsream, Some(self.conf)).await?;
let tun = WebsocketTunnel {
inner: StreamReader::new(StreamWrapper { inner: wsstream }),
@ -239,10 +161,7 @@ impl Transport for WebsocketTransport {
async fn connect(&self, addr: &AddrMaybeCached) -> anyhow::Result<Self::Stream> {
let u = format!("ws://{}", &addr.addr.as_str());
let url = Url::parse(&u).unwrap();
let tstream = match &self.sub {
SubTransport::Insecure(t) => TransportStream::Insecure(t.connect(addr).await?),
SubTransport::Secure(t) => TransportStream::Secure(t.connect(addr).await?),
};
let tstream = self.sub.connect(addr).await?;
let (wsstream, _) = client_async_with_config(url, tstream, Some(self.conf))
.await
.expect("failed to connect");

View File

@ -0,0 +1,33 @@
[client]
remote_addr = "127.0.0.1:2333"
default_token = "default_token_if_not_specify"
[client.transport]
type = "http2"
[client.transport.tls]
trusted_root = "examples/tls/rootCA.crt"
hostname = "localhost"
[client.transport.http2]
tls = true
[client.services.echo]
local_addr = "127.0.0.1:8080"
[client.services.pingpong]
local_addr = "127.0.0.1:8081"
[server]
bind_addr = "0.0.0.0:2333"
default_token = "default_token_if_not_specify"
[server.transport]
type = "http2"
[server.transport.tls]
pkcs12 = "examples/tls/identity.pfx"
pkcs12_password = "1234"
[server.transport.http2]
tls = true
[server.services.echo]
bind_addr = "0.0.0.0:2334"
[server.services.pingpong]
bind_addr = "0.0.0.0:2335"

View File

@ -0,0 +1,27 @@
[client]
remote_addr = "127.0.0.1:2333"
default_token = "default_token_if_not_specify"
[client.transport]
type = "http2"
[client.transport.http2]
tls = false
[client.services.echo]
local_addr = "127.0.0.1:8080"
[client.services.pingpong]
local_addr = "127.0.0.1:8081"
[server]
bind_addr = "0.0.0.0:2333"
default_token = "default_token_if_not_specify"
[server.transport]
type = "http2"
[server.transport.http2]
tls = false
[server.services.echo]
bind_addr = "0.0.0.0:2334"
[server.services.pingpong]
bind_addr = "0.0.0.0:2335"

View File

@ -0,0 +1,37 @@
[client]
remote_addr = "127.0.0.1:2332"
default_token = "default_token_if_not_specify"
[client.transport]
type = "http2"
[client.transport.tls]
trusted_root = "examples/tls/rootCA.crt"
hostname = "localhost"
[client.transport.http2]
tls = true
[client.services.echo]
type = "udp"
local_addr = "127.0.0.1:8080"
[client.services.pingpong]
type = "udp"
local_addr = "127.0.0.1:8081"
[server]
bind_addr = "0.0.0.0:2332"
default_token = "default_token_if_not_specify"
[server.transport]
type = "http2"
[server.transport.tls]
pkcs12 = "examples/tls/identity.pfx"
pkcs12_password = "1234"
[server.transport.http2]
tls = true
[server.services.echo]
type = "udp"
bind_addr = "0.0.0.0:2334"
[server.services.pingpong]
type = "udp"
bind_addr = "0.0.0.0:2335"

View File

@ -0,0 +1,31 @@
[client]
remote_addr = "127.0.0.1:2332"
default_token = "default_token_if_not_specify"
[client.transport]
type = "http2"
[client.transport.http2]
tls = false
[client.services.echo]
type = "udp"
local_addr = "127.0.0.1:8080"
[client.services.pingpong]
type = "udp"
local_addr = "127.0.0.1:8081"
[server]
bind_addr = "0.0.0.0:2332"
default_token = "default_token_if_not_specify"
[server.transport]
type = "http2"
[server.transport.http2]
tls = false
[server.services.echo]
type = "udp"
bind_addr = "0.0.0.0:2334"
[server.services.pingpong]
type = "udp"
bind_addr = "0.0.0.0:2335"

View File

@ -65,11 +65,16 @@ async fn tcp() -> Result<()> {
#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
test("tests/for_tcp/websocket_transport.toml", Type::Tcp).await?;
#[cfg(not(target_os = "macos"))]
#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
test("tests/for_tcp/websocket_tls_transport.toml", Type::Tcp).await?;
#[cfg(any(feature = "http2-native-tls", feature = "http2-rustls"))]
test("tests/for_tcp/http2_transport.toml", Type::Tcp).await?;
#[cfg(not(target_os = "macos"))]
#[cfg(any(feature = "http2-native-tls", feature = "http2-rustls"))]
test("tests/for_tcp/http2_tls_transport.toml", Type::Tcp).await?;
Ok(())
}
@ -102,11 +107,16 @@ async fn udp() -> Result<()> {
#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
test("tests/for_udp/websocket_transport.toml", Type::Udp).await?;
#[cfg(not(target_os = "macos"))]
#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
test("tests/for_udp/websocket_tls_transport.toml", Type::Udp).await?;
#[cfg(any(feature = "http2-native-tls", feature = "http2-rustls"))]
test("tests/for_udp/http2_transport.toml", Type::Udp).await?;
#[cfg(not(target_os = "macos"))]
#[cfg(any(feature = "http2-native-tls", feature = "http2-rustls"))]
test("tests/for_udp/http2_tls_transport.toml", Type::Udp).await?;
Ok(())
}