mirror of https://github.com/rapiz1/rathole.git
feat(transport): add http2 transport (#392)
This commit is contained in:
parent
be14d124a2
commit
ddc97cff78
|
@ -128,7 +128,7 @@ jobs:
|
||||||
version: v4.0.2
|
version: v4.0.2
|
||||||
files: target/${{ matrix.target }}/release/${{ matrix.exe }}
|
files: target/${{ matrix.target }}/release/${{ matrix.exe }}
|
||||||
args: -q --best --lzma
|
args: -q --best --lzma
|
||||||
- uses: actions/upload-artifact@v2
|
- uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: rathole-${{ matrix.target }}
|
name: rathole-${{ matrix.target }}
|
||||||
path: target/${{ matrix.target }}/release/${{ matrix.exe }}
|
path: target/${{ matrix.target }}/release/${{ matrix.exe }}
|
||||||
|
|
|
@ -34,7 +34,7 @@ jobs:
|
||||||
- name: Check all features
|
- name: Check all features
|
||||||
run: >
|
run: >
|
||||||
cargo hack check --feature-powerset --no-dev-deps
|
cargo hack check --feature-powerset --no-dev-deps
|
||||||
--mutually-exclusive-features default,native-tls,websocket-native-tls,rustls,websocket-rustls
|
--mutually-exclusive-features default,native-tls,websocket-native-tls,http2-native-tls,rustls,websocket-rustls,http2-rustls
|
||||||
|
|
||||||
build:
|
build:
|
||||||
name: Build for ${{ matrix.target }}
|
name: Build for ${{ matrix.target }}
|
||||||
|
@ -67,8 +67,8 @@ jobs:
|
||||||
- name: Run tests with native-tls
|
- name: Run tests with native-tls
|
||||||
run: cargo test --verbose
|
run: cargo test --verbose
|
||||||
- name: Run tests with rustls
|
- name: Run tests with rustls
|
||||||
run: cargo test --verbose --no-default-features --features server,client,rustls,noise,websocket-rustls,hot-reload
|
run: cargo test --verbose --no-default-features --features server,client,rustls,noise,websocket-rustls,http2-rustls,hot-reload
|
||||||
- uses: actions/upload-artifact@v2
|
- uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: rathole-${{ matrix.target }}
|
name: rathole-${{ matrix.target }}
|
||||||
path: target/debug/${{ matrix.exe }}
|
path: target/debug/${{ matrix.exe }}
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -17,6 +17,7 @@ default = [
|
||||||
"native-tls",
|
"native-tls",
|
||||||
"noise",
|
"noise",
|
||||||
"websocket-native-tls",
|
"websocket-native-tls",
|
||||||
|
"http2-native-tls",
|
||||||
"hot-reload",
|
"hot-reload",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -53,6 +54,28 @@ websocket-rustls = [
|
||||||
"rustls",
|
"rustls",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# HTTP2 support
|
||||||
|
http2-native-tls = [
|
||||||
|
"hyper",
|
||||||
|
"hyper-util",
|
||||||
|
"http",
|
||||||
|
"http-body-util",
|
||||||
|
"futures-core",
|
||||||
|
"tokio-util",
|
||||||
|
"tower-service",
|
||||||
|
"native-tls",
|
||||||
|
]
|
||||||
|
http2-rustls = [
|
||||||
|
"hyper",
|
||||||
|
"hyper-util",
|
||||||
|
"http",
|
||||||
|
"http-body-util",
|
||||||
|
"futures-core",
|
||||||
|
"tokio-util",
|
||||||
|
"tower-service",
|
||||||
|
"rustls",
|
||||||
|
]
|
||||||
|
|
||||||
# Configuration hot-reload support
|
# Configuration hot-reload support
|
||||||
hot-reload = ["notify"]
|
hot-reload = ["notify"]
|
||||||
|
|
||||||
|
@ -117,6 +140,11 @@ async-http-proxy = { version = "1.2", features = [
|
||||||
async-socks5 = "0.5"
|
async-socks5 = "0.5"
|
||||||
url = { version = "2.2", features = ["serde"] }
|
url = { version = "2.2", features = ["serde"] }
|
||||||
tokio-tungstenite = { version = "0.20.1", optional = true }
|
tokio-tungstenite = { version = "0.20.1", optional = true }
|
||||||
|
http = { version = "1.1.0", optional = true }
|
||||||
|
hyper = { version = "1.4.1", optional = true , features = ["client","server","http2"] }
|
||||||
|
hyper-util = { version = "0.1.9", optional = true , features = ["full"]}
|
||||||
|
http-body-util = { version = "0.1.2", optional = true }
|
||||||
|
tower-service = { version = "0.3.3", optional = true }
|
||||||
tokio-util = { version = "0.7.9", optional = true, features = ["io"] }
|
tokio-util = { version = "0.7.9", optional = true, features = ["io"] }
|
||||||
futures-core = { version = "0.3.28", optional = true }
|
futures-core = { version = "0.3.28", optional = true }
|
||||||
futures-sink = { version = "0.3.28", optional = true }
|
futures-sink = { version = "0.3.28", optional = true }
|
||||||
|
|
|
@ -111,7 +111,7 @@ heartbeat_timeout = 40 # Optional. Set to 0 to disable the application-layer hea
|
||||||
retry_interval = 1 # Optional. The interval between retry to connect to the server. Default: 1 second
|
retry_interval = 1 # Optional. The interval between retry to connect to the server. Default: 1 second
|
||||||
|
|
||||||
[client.transport] # The whole block is optional. Specify which transport to use
|
[client.transport] # The whole block is optional. Specify which transport to use
|
||||||
type = "tcp" # Optional. Possible values: ["tcp", "tls", "noise"]. Default: "tcp"
|
type = "tcp" # Optional. Possible values: ["tcp", "tls", "noise", "websocket", "http2"]. Default: "tcp"
|
||||||
|
|
||||||
[client.transport.tcp] # Optional. Also affects `noise` and `tls`
|
[client.transport.tcp] # Optional. Also affects `noise` and `tls`
|
||||||
proxy = "socks5://user:passwd@127.0.0.1:1080" # Optional. The proxy used to connect to the server. `http` and `socks5` is supported.
|
proxy = "socks5://user:passwd@127.0.0.1:1080" # Optional. The proxy used to connect to the server. `http` and `socks5` is supported.
|
||||||
|
@ -131,6 +131,9 @@ remote_public_key = "key_encoded_in_base64" # Optional
|
||||||
[client.transport.websocket] # Necessary if `type` is "websocket"
|
[client.transport.websocket] # Necessary if `type` is "websocket"
|
||||||
tls = true # If `true` then it will use settings in `client.transport.tls`
|
tls = true # If `true` then it will use settings in `client.transport.tls`
|
||||||
|
|
||||||
|
[client.transport.http2] # Necessary if `type` is "http2"
|
||||||
|
tls = true # If `true` then it will use settings in `client.transport.tls`
|
||||||
|
|
||||||
[client.services.service1] # A service that needs forwarding. The name `service1` can change arbitrarily, as long as identical to the name in the server's configuration
|
[client.services.service1] # A service that needs forwarding. The name `service1` can change arbitrarily, as long as identical to the name in the server's configuration
|
||||||
type = "tcp" # Optional. The protocol that needs forwarding. Possible values: ["tcp", "udp"]. Default: "tcp"
|
type = "tcp" # Optional. The protocol that needs forwarding. Possible values: ["tcp", "udp"]. Default: "tcp"
|
||||||
token = "whatever" # Necessary if `client.default_token` not set
|
token = "whatever" # Necessary if `client.default_token` not set
|
||||||
|
@ -166,6 +169,9 @@ remote_public_key = "key_encoded_in_base64"
|
||||||
[server.transport.websocket] # Necessary if `type` is "websocket"
|
[server.transport.websocket] # Necessary if `type` is "websocket"
|
||||||
tls = true # If `true` then it will use settings in `server.transport.tls`
|
tls = true # If `true` then it will use settings in `server.transport.tls`
|
||||||
|
|
||||||
|
[server.transport.http2] # Necessary if `type` is "http2"
|
||||||
|
tls = true # If `true` then it will use settings in `server.transport.tls`
|
||||||
|
|
||||||
[server.services.service1] # The service name must be identical to the client side
|
[server.services.service1] # The service name must be identical to the client side
|
||||||
type = "tcp" # Optional. Same as the client `[client.services.X.type]
|
type = "tcp" # Optional. Same as the client `[client.services.X.type]
|
||||||
token = "whatever" # Necessary if `server.default_token` not set
|
token = "whatever" # Necessary if `server.default_token` not set
|
||||||
|
|
|
@ -21,6 +21,8 @@ use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
|
||||||
use tokio::time::{self, Duration, Instant};
|
use tokio::time::{self, Duration, Instant};
|
||||||
use tracing::{debug, error, info, instrument, trace, warn, Instrument, Span};
|
use tracing::{debug, error, info, instrument, trace, warn, Instrument, Span};
|
||||||
|
|
||||||
|
#[cfg(any(feature = "http2-native-tls", feature = "http2-rustls"))]
|
||||||
|
use crate::transport::HTTP2Transport;
|
||||||
#[cfg(feature = "noise")]
|
#[cfg(feature = "noise")]
|
||||||
use crate::transport::NoiseTransport;
|
use crate::transport::NoiseTransport;
|
||||||
#[cfg(any(feature = "native-tls", feature = "rustls"))]
|
#[cfg(any(feature = "native-tls", feature = "rustls"))]
|
||||||
|
@ -74,6 +76,15 @@ pub async fn run_client(
|
||||||
#[cfg(not(any(feature = "websocket-native-tls", feature = "websocket-rustls")))]
|
#[cfg(not(any(feature = "websocket-native-tls", feature = "websocket-rustls")))]
|
||||||
crate::helper::feature_neither_compile("websocket-native-tls", "websocket-rustls")
|
crate::helper::feature_neither_compile("websocket-native-tls", "websocket-rustls")
|
||||||
}
|
}
|
||||||
|
TransportType::HTTP2 => {
|
||||||
|
#[cfg(any(feature = "http2-native-tls", feature = "http2-rustls"))]
|
||||||
|
{
|
||||||
|
let mut client = Client::<HTTP2Transport>::from(config).await?;
|
||||||
|
client.run(shutdown_rx, update_rx).await
|
||||||
|
}
|
||||||
|
#[cfg(not(any(feature = "http2-native-tls", feature = "http2-rustls")))]
|
||||||
|
crate::helper::feature_neither_compile("http2-native-tls", "http2-rustls")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -51,6 +51,8 @@ pub enum TransportType {
|
||||||
Noise,
|
Noise,
|
||||||
#[serde(rename = "websocket")]
|
#[serde(rename = "websocket")]
|
||||||
Websocket,
|
Websocket,
|
||||||
|
#[serde(rename = "http2")]
|
||||||
|
HTTP2,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Per service config
|
/// Per service config
|
||||||
|
@ -141,6 +143,12 @@ pub struct WebsocketConfig {
|
||||||
pub tls: bool,
|
pub tls: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
|
#[serde(deny_unknown_fields)]
|
||||||
|
pub struct HTTP2Config {
|
||||||
|
pub tls: bool,
|
||||||
|
}
|
||||||
|
|
||||||
fn default_nodelay() -> bool {
|
fn default_nodelay() -> bool {
|
||||||
DEFAULT_NODELAY
|
DEFAULT_NODELAY
|
||||||
}
|
}
|
||||||
|
@ -186,6 +194,7 @@ pub struct TransportConfig {
|
||||||
pub tls: Option<TlsConfig>,
|
pub tls: Option<TlsConfig>,
|
||||||
pub noise: Option<NoiseConfig>,
|
pub noise: Option<NoiseConfig>,
|
||||||
pub websocket: Option<WebsocketConfig>,
|
pub websocket: Option<WebsocketConfig>,
|
||||||
|
pub http2: Option<HTTP2Config>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_heartbeat_timeout() -> u64 {
|
fn default_heartbeat_timeout() -> u64 {
|
||||||
|
@ -320,6 +329,7 @@ impl Config {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
TransportType::Websocket => Ok(()),
|
TransportType::Websocket => Ok(()),
|
||||||
|
TransportType::HTTP2 => Ok(()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,8 @@ use tokio::sync::{broadcast, mpsc, RwLock};
|
||||||
use tokio::time;
|
use tokio::time;
|
||||||
use tracing::{debug, error, info, info_span, instrument, warn, Instrument, Span};
|
use tracing::{debug, error, info, info_span, instrument, warn, Instrument, Span};
|
||||||
|
|
||||||
|
#[cfg(any(feature = "http2-native-tls", feature = "http2-rustls"))]
|
||||||
|
use crate::transport::HTTP2Transport;
|
||||||
#[cfg(feature = "noise")]
|
#[cfg(feature = "noise")]
|
||||||
use crate::transport::NoiseTransport;
|
use crate::transport::NoiseTransport;
|
||||||
#[cfg(any(feature = "native-tls", feature = "rustls"))]
|
#[cfg(any(feature = "native-tls", feature = "rustls"))]
|
||||||
|
@ -83,6 +85,15 @@ pub async fn run_server(
|
||||||
#[cfg(not(any(feature = "websocket-native-tls", feature = "websocket-rustls")))]
|
#[cfg(not(any(feature = "websocket-native-tls", feature = "websocket-rustls")))]
|
||||||
crate::helper::feature_neither_compile("websocket-native-tls", "websocket-rustls")
|
crate::helper::feature_neither_compile("websocket-native-tls", "websocket-rustls")
|
||||||
}
|
}
|
||||||
|
TransportType::HTTP2 => {
|
||||||
|
#[cfg(any(feature = "http2-native-tls", feature = "http2-rustls"))]
|
||||||
|
{
|
||||||
|
let mut server = Server::<HTTP2Transport>::from(config).await?;
|
||||||
|
server.run(shutdown_rx, update_rx).await?;
|
||||||
|
}
|
||||||
|
#[cfg(not(any(feature = "http2-native-tls", feature = "http2-rustls")))]
|
||||||
|
crate::helper::feature_neither_compile("http2-native-tls", "http2-rustls")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
|
@ -0,0 +1,398 @@
|
||||||
|
use core::result::Result;
|
||||||
|
use std::future::Future;
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::task::{self, Context, Poll};
|
||||||
|
|
||||||
|
use anyhow::anyhow;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use bytes::{Buf, Bytes};
|
||||||
|
use futures_core::Stream;
|
||||||
|
use http::{Method, Request, Response, Uri};
|
||||||
|
use http_body_util::StreamBody;
|
||||||
|
use hyper::body::{Body, Incoming};
|
||||||
|
use hyper::server::conn::http2 as Server;
|
||||||
|
use hyper::service::Service;
|
||||||
|
use hyper_util::client::legacy::connect::{Connected, Connection};
|
||||||
|
use hyper_util::client::legacy::Client;
|
||||||
|
use hyper_util::rt::tokio::TokioExecutor;
|
||||||
|
use hyper_util::rt::tokio::TokioIo;
|
||||||
|
use tokio::io::{self, AsyncRead, AsyncWrite, ReadBuf, ReadHalf, SimplexStream, WriteHalf};
|
||||||
|
use tokio::net::{TcpListener, ToSocketAddrs};
|
||||||
|
use tokio::select;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
use tokio::sync::{broadcast, mpsc};
|
||||||
|
use tokio_util::io::ReaderStream;
|
||||||
|
|
||||||
|
use super::maybe_tls::{MaybeTLSStream, MaybeTLSTransport};
|
||||||
|
use super::{AddrMaybeCached, SocketOpts, Transport};
|
||||||
|
use crate::config::TransportConfig;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct IncomingHyper {
|
||||||
|
inner: hyper::body::Incoming,
|
||||||
|
current_chunk: Option<bytes::Bytes>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncRead for IncomingHyper {
|
||||||
|
fn poll_read(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut ReadBuf<'_>,
|
||||||
|
) -> Poll<std::io::Result<()>> {
|
||||||
|
loop {
|
||||||
|
if let Some(chunk) = &mut self.current_chunk {
|
||||||
|
let len = std::cmp::min(chunk.len(), buf.remaining());
|
||||||
|
buf.put_slice(&chunk[..len]);
|
||||||
|
|
||||||
|
chunk.advance(len);
|
||||||
|
if !chunk.has_remaining() {
|
||||||
|
self.current_chunk = None;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Poll::Ready(Ok(()));
|
||||||
|
}
|
||||||
|
|
||||||
|
match Pin::new(&mut self.inner).poll_frame(cx) {
|
||||||
|
Poll::Pending => return Poll::Pending,
|
||||||
|
Poll::Ready(None) => return Poll::Ready(Ok(())),
|
||||||
|
Poll::Ready(Some(Err(err))) => {
|
||||||
|
return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, err)))
|
||||||
|
}
|
||||||
|
Poll::Ready(Some(Ok(frame))) => match frame.into_data() {
|
||||||
|
Err(_) => {
|
||||||
|
return Poll::Ready(Err(io::Error::new(
|
||||||
|
io::ErrorKind::Other,
|
||||||
|
"non data frame received",
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
Ok(data) => {
|
||||||
|
self.current_chunk = Some(data);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct HTTP2Stream {
|
||||||
|
send: WriteHalf<SimplexStream>,
|
||||||
|
recv: IncomingHyper,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncRead for HTTP2Stream {
|
||||||
|
fn poll_read(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut ReadBuf<'_>,
|
||||||
|
) -> Poll<std::io::Result<()>> {
|
||||||
|
Pin::new(&mut self.get_mut().recv).poll_read(cx, buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncWrite for HTTP2Stream {
|
||||||
|
fn poll_write(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<Result<usize, std::io::Error>> {
|
||||||
|
Pin::new(&mut self.get_mut().send).poll_write(cx, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
|
||||||
|
Pin::new(&mut self.get_mut().send).poll_flush(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
) -> Poll<Result<(), std::io::Error>> {
|
||||||
|
Pin::new(&mut self.get_mut().send).poll_shutdown(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct OutgoingSimplex {
|
||||||
|
inner: ReaderStream<ReadHalf<SimplexStream>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Stream for OutgoingSimplex {
|
||||||
|
type Item = Result<hyper::body::Frame<Bytes>, io::Error>;
|
||||||
|
|
||||||
|
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||||
|
match Pin::new(&mut self.get_mut().inner).poll_next(cx) {
|
||||||
|
Poll::Pending => Poll::Pending,
|
||||||
|
Poll::Ready(None) => Poll::Ready(None),
|
||||||
|
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
|
||||||
|
Poll::Ready(Some(Ok(data))) => Poll::Ready(Some(Ok(hyper::body::Frame::data(data)))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn size_hint(&self) -> (usize, Option<usize>) {
|
||||||
|
self.inner.size_hint()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct Svc {
|
||||||
|
req_sender: mpsc::Sender<anyhow::Result<(SocketAddr, Incoming, mpsc::Sender<OutgoingSimplex>)>>,
|
||||||
|
addr: SocketAddr,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Service<Request<Incoming>> for Svc {
|
||||||
|
type Response = Response<StreamBody<OutgoingSimplex>>;
|
||||||
|
type Error = anyhow::Error;
|
||||||
|
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||||
|
|
||||||
|
fn call(&self, req: Request<Incoming>) -> Self::Future {
|
||||||
|
let req_sender = self.req_sender.clone();
|
||||||
|
let addr = self.addr;
|
||||||
|
|
||||||
|
let future = async move {
|
||||||
|
let (res_sender, mut res_receiver) = mpsc::channel::<OutgoingSimplex>(1);
|
||||||
|
if let Err(err) = req_sender
|
||||||
|
.send(Ok((addr, req.into_body(), res_sender)))
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
return Err(anyhow!(err.to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
match res_receiver.recv().await {
|
||||||
|
None => Err(anyhow!("Channel closed")),
|
||||||
|
Some(body) => Ok(Response::new(http_body_util::StreamBody::new(body))),
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Box::pin(future) // Return the boxed future
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn start_http_server(
|
||||||
|
listener: TcpListener,
|
||||||
|
transport: Arc<MaybeTLSTransport>,
|
||||||
|
req_sender: mpsc::Sender<anyhow::Result<(SocketAddr, Incoming, mpsc::Sender<OutgoingSimplex>)>>,
|
||||||
|
stop_receiver: broadcast::Receiver<()>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
loop {
|
||||||
|
let mut stop_receiver = stop_receiver.resubscribe();
|
||||||
|
let conn = async {
|
||||||
|
let (conn, addr) = transport.accept(&listener).await?;
|
||||||
|
let stream = transport.handshake(conn).await?;
|
||||||
|
Ok((stream, addr))
|
||||||
|
};
|
||||||
|
select! {
|
||||||
|
_ = stop_receiver.recv() => {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
conn = conn => {
|
||||||
|
if let Err(err) = conn {
|
||||||
|
if let Err(err)= req_sender.send(Err(err)).await {
|
||||||
|
eprintln!("Error sending error message: {}", err);
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
let (socket, addr) = conn.unwrap();
|
||||||
|
let io = TokioIo::new(socket);
|
||||||
|
let svc = Svc {
|
||||||
|
req_sender: req_sender.clone(),
|
||||||
|
addr,
|
||||||
|
};
|
||||||
|
let req_sender= req_sender.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut conn = Server::Builder::new(TokioExecutor::new())
|
||||||
|
.serve_connection(io, svc);
|
||||||
|
select! {
|
||||||
|
_ = stop_receiver.recv() => {
|
||||||
|
Pin::new(&mut conn).graceful_shutdown();
|
||||||
|
}
|
||||||
|
|
||||||
|
res = &mut conn => {
|
||||||
|
if let Err(err) = res {
|
||||||
|
if let Err(err)= req_sender.send(Err(anyhow!(err.to_string()))).await {
|
||||||
|
eprintln!("Error sending error message: {}", err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Connection for MaybeTLSStream {
|
||||||
|
fn connected(&self) -> Connected {
|
||||||
|
let connected = Connected::new();
|
||||||
|
if let (Ok(remote_addr), Ok(local_addr)) = (
|
||||||
|
self.get_tcpstream().peer_addr(),
|
||||||
|
self.get_tcpstream().local_addr(),
|
||||||
|
) {
|
||||||
|
connected.extra((remote_addr, local_addr))
|
||||||
|
} else {
|
||||||
|
connected
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct MaybeTLSConnector {
|
||||||
|
sub: Arc<MaybeTLSTransport>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl tower_service::Service<Uri> for MaybeTLSConnector {
|
||||||
|
type Response = TokioIo<MaybeTLSStream>;
|
||||||
|
type Error = anyhow::Error;
|
||||||
|
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||||
|
|
||||||
|
fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call(&mut self, u: Uri) -> Self::Future {
|
||||||
|
let sub = self.sub.clone();
|
||||||
|
let future = async move {
|
||||||
|
let addr = match (u.host(), u.port()) {
|
||||||
|
(Some(host), Some(port)) => format!("{}:{}", host, port),
|
||||||
|
_ => String::from(""),
|
||||||
|
};
|
||||||
|
let addr = AddrMaybeCached::new(addr.as_str());
|
||||||
|
let stream = sub.connect(&addr).await?;
|
||||||
|
Ok(TokioIo::new(stream))
|
||||||
|
};
|
||||||
|
Box::pin(future)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct HTTP2Transport {
|
||||||
|
sub: Arc<MaybeTLSTransport>,
|
||||||
|
client: Client<MaybeTLSConnector, StreamBody<OutgoingSimplex>>,
|
||||||
|
|
||||||
|
stop_sender: broadcast::Sender<()>,
|
||||||
|
stop_receiver: broadcast::Receiver<()>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Transport for HTTP2Transport {
|
||||||
|
type Acceptor = Arc<
|
||||||
|
Mutex<
|
||||||
|
mpsc::Receiver<anyhow::Result<(SocketAddr, Incoming, mpsc::Sender<OutgoingSimplex>)>>,
|
||||||
|
>,
|
||||||
|
>;
|
||||||
|
type RawStream = (Incoming, mpsc::Sender<OutgoingSimplex>);
|
||||||
|
type Stream = HTTP2Stream;
|
||||||
|
|
||||||
|
fn new(config: &TransportConfig) -> anyhow::Result<Self> {
|
||||||
|
let cfg = config
|
||||||
|
.http2
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| anyhow!("Missing http2 config"))?;
|
||||||
|
|
||||||
|
let (stop_sender, stop_receiver) = broadcast::channel(1);
|
||||||
|
let sub = Arc::new(MaybeTLSTransport::new_explicit(cfg.tls, config)?);
|
||||||
|
let client = Client::builder(TokioExecutor::new())
|
||||||
|
.http2_only(true)
|
||||||
|
.build(MaybeTLSConnector { sub: sub.clone() });
|
||||||
|
|
||||||
|
Ok(HTTP2Transport {
|
||||||
|
sub,
|
||||||
|
client,
|
||||||
|
stop_sender,
|
||||||
|
stop_receiver,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn hint(_: &Self::Stream, _: SocketOpts) {}
|
||||||
|
|
||||||
|
async fn bind<A: ToSocketAddrs + Send + Sync>(
|
||||||
|
&self,
|
||||||
|
addr: A,
|
||||||
|
) -> anyhow::Result<Self::Acceptor> {
|
||||||
|
let listener = self.sub.bind(addr).await?;
|
||||||
|
|
||||||
|
let (req_sender, req_receiver) = mpsc::channel::<
|
||||||
|
anyhow::Result<(SocketAddr, Incoming, mpsc::Sender<OutgoingSimplex>)>,
|
||||||
|
>(1);
|
||||||
|
let req_receiver = Arc::new(Mutex::new(req_receiver));
|
||||||
|
let sub_transport = self.sub.clone();
|
||||||
|
let stop_receiver = self.stop_receiver.resubscribe();
|
||||||
|
tokio::spawn(start_http_server(
|
||||||
|
listener,
|
||||||
|
sub_transport,
|
||||||
|
req_sender,
|
||||||
|
stop_receiver,
|
||||||
|
));
|
||||||
|
Ok(req_receiver)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn accept(&self, a: &Self::Acceptor) -> anyhow::Result<(Self::RawStream, SocketAddr)> {
|
||||||
|
let mut receiver = a.lock().await;
|
||||||
|
match receiver.recv().await {
|
||||||
|
None => Err(anyhow!("Channel closed")),
|
||||||
|
Some(Err(err)) => Err(err),
|
||||||
|
Some(Ok((addr, incoming, res_sender))) => Ok(((incoming, res_sender), addr)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handshake(&self, conn: Self::RawStream) -> anyhow::Result<Self::Stream> {
|
||||||
|
let (incoming, res_sender) = conn;
|
||||||
|
|
||||||
|
let (sread, swrite) = io::simplex(4096);
|
||||||
|
if let Err(err) = res_sender
|
||||||
|
.send(OutgoingSimplex {
|
||||||
|
inner: ReaderStream::with_capacity(sread, 4096),
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
return Err(anyhow!(err.to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(HTTP2Stream {
|
||||||
|
recv: IncomingHyper {
|
||||||
|
inner: incoming,
|
||||||
|
current_chunk: None,
|
||||||
|
},
|
||||||
|
send: swrite,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect(&self, addr: &AddrMaybeCached) -> anyhow::Result<Self::Stream> {
|
||||||
|
let client = self.client.clone();
|
||||||
|
let (sread, swrite) = io::simplex(4096);
|
||||||
|
let body = http_body_util::StreamBody::new(OutgoingSimplex {
|
||||||
|
inner: ReaderStream::with_capacity(sread, 4096),
|
||||||
|
});
|
||||||
|
let req = Request::builder()
|
||||||
|
.method(Method::POST)
|
||||||
|
.uri(format!("http://{}", &addr.addr.as_str()))
|
||||||
|
.body(body)
|
||||||
|
.expect("request builder");
|
||||||
|
let res = client.request(req).await;
|
||||||
|
if let Err(err) = res {
|
||||||
|
return Err(anyhow!("Error: {}", err));
|
||||||
|
}
|
||||||
|
|
||||||
|
let res = res.unwrap();
|
||||||
|
if !res.status().is_success() {
|
||||||
|
return Err(anyhow!("Bad status code: {}", res.status()));
|
||||||
|
}
|
||||||
|
Ok(HTTP2Stream {
|
||||||
|
recv: IncomingHyper {
|
||||||
|
inner: res.into_body(),
|
||||||
|
current_chunk: None,
|
||||||
|
},
|
||||||
|
send: swrite,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for HTTP2Transport {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
let _ = self.stop_sender.send(());
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,147 @@
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::task::{Context, Poll};
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||||
|
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
|
||||||
|
|
||||||
|
use super::tls::get_tcpstream;
|
||||||
|
use super::tls::TlsStream;
|
||||||
|
use super::{AddrMaybeCached, SocketOpts, TlsTransport};
|
||||||
|
use super::{TcpTransport, Transport};
|
||||||
|
use crate::config::TransportConfig;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub(super) enum MaybeTLSStream {
|
||||||
|
No(TcpStream),
|
||||||
|
Yes(TlsStream<TcpStream>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MaybeTLSStream {
|
||||||
|
pub(super) fn get_tcpstream(&self) -> &TcpStream {
|
||||||
|
match self {
|
||||||
|
MaybeTLSStream::No(s) => s,
|
||||||
|
MaybeTLSStream::Yes(s) => get_tcpstream(s),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncRead for MaybeTLSStream {
|
||||||
|
fn poll_read(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &mut ReadBuf<'_>,
|
||||||
|
) -> Poll<std::io::Result<()>> {
|
||||||
|
match self.get_mut() {
|
||||||
|
MaybeTLSStream::No(s) => Pin::new(s).poll_read(cx, buf),
|
||||||
|
MaybeTLSStream::Yes(s) => Pin::new(s).poll_read(cx, buf),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncWrite for MaybeTLSStream {
|
||||||
|
fn poll_write(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<Result<usize, std::io::Error>> {
|
||||||
|
match self.get_mut() {
|
||||||
|
MaybeTLSStream::No(s) => Pin::new(s).poll_write(cx, buf),
|
||||||
|
MaybeTLSStream::Yes(s) => Pin::new(s).poll_write(cx, buf),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
|
||||||
|
match self.get_mut() {
|
||||||
|
MaybeTLSStream::No(s) => Pin::new(s).poll_flush(cx),
|
||||||
|
MaybeTLSStream::Yes(s) => Pin::new(s).poll_flush(cx),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut Context<'_>,
|
||||||
|
) -> Poll<Result<(), std::io::Error>> {
|
||||||
|
match self.get_mut() {
|
||||||
|
MaybeTLSStream::No(s) => Pin::new(s).poll_shutdown(cx),
|
||||||
|
MaybeTLSStream::Yes(s) => Pin::new(s).poll_shutdown(cx),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub(super) enum MaybeTLSTransport {
|
||||||
|
Yes(TlsTransport),
|
||||||
|
No(TcpTransport),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MaybeTLSTransport {
|
||||||
|
pub(super) fn new_explicit(tls: bool, tconfig: &TransportConfig) -> anyhow::Result<Self> {
|
||||||
|
match tls {
|
||||||
|
true => Ok(MaybeTLSTransport::Yes(TlsTransport::new(tconfig)?)),
|
||||||
|
false => Ok(MaybeTLSTransport::No(TcpTransport::new(tconfig)?)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Transport for MaybeTLSTransport {
|
||||||
|
type Acceptor = TcpListener;
|
||||||
|
type RawStream = TcpStream;
|
||||||
|
type Stream = MaybeTLSStream;
|
||||||
|
|
||||||
|
fn new(config: &TransportConfig) -> anyhow::Result<Self> {
|
||||||
|
MaybeTLSTransport::new_explicit(config.tls.is_some(), config)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn bind<A: ToSocketAddrs + Send + Sync>(
|
||||||
|
&self,
|
||||||
|
addr: A,
|
||||||
|
) -> anyhow::Result<Self::Acceptor> {
|
||||||
|
match self {
|
||||||
|
MaybeTLSTransport::Yes(t) => t.bind(addr).await,
|
||||||
|
MaybeTLSTransport::No(t) => t.bind(addr).await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn hint(conn: &Self::Stream, opt: SocketOpts) {
|
||||||
|
match conn {
|
||||||
|
MaybeTLSStream::Yes(t) => TlsTransport::hint(t, opt),
|
||||||
|
MaybeTLSStream::No(t) => TcpTransport::hint(t, opt),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn accept(&self, a: &Self::Acceptor) -> anyhow::Result<(Self::RawStream, SocketAddr)> {
|
||||||
|
match self {
|
||||||
|
MaybeTLSTransport::Yes(t) => t.accept(a).await,
|
||||||
|
MaybeTLSTransport::No(t) => t.accept(a).await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handshake(&self, conn: Self::RawStream) -> anyhow::Result<Self::Stream> {
|
||||||
|
match self {
|
||||||
|
MaybeTLSTransport::Yes(t) => {
|
||||||
|
let stream = t.handshake(conn).await?;
|
||||||
|
Ok(MaybeTLSStream::Yes(stream))
|
||||||
|
}
|
||||||
|
MaybeTLSTransport::No(t) => {
|
||||||
|
let stream = t.handshake(conn).await?;
|
||||||
|
Ok(MaybeTLSStream::No(stream))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect(&self, addr: &AddrMaybeCached) -> anyhow::Result<Self::Stream> {
|
||||||
|
match self {
|
||||||
|
MaybeTLSTransport::Yes(t) => {
|
||||||
|
let stream = t.connect(addr).await?;
|
||||||
|
Ok(MaybeTLSStream::Yes(stream))
|
||||||
|
}
|
||||||
|
MaybeTLSTransport::No(t) => {
|
||||||
|
let stream = t.connect(addr).await?;
|
||||||
|
Ok(MaybeTLSStream::No(stream))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -85,6 +85,14 @@ use rustls as tls;
|
||||||
#[cfg(any(feature = "native-tls", feature = "rustls"))]
|
#[cfg(any(feature = "native-tls", feature = "rustls"))]
|
||||||
pub(crate) use tls::TlsTransport;
|
pub(crate) use tls::TlsTransport;
|
||||||
|
|
||||||
|
#[cfg(any(
|
||||||
|
feature = "websocket-native-tls",
|
||||||
|
feature = "http2-native-tls",
|
||||||
|
feature = "websocket-rustls",
|
||||||
|
feature = "http2-rustls"
|
||||||
|
))]
|
||||||
|
mod maybe_tls;
|
||||||
|
|
||||||
#[cfg(feature = "noise")]
|
#[cfg(feature = "noise")]
|
||||||
mod noise;
|
mod noise;
|
||||||
#[cfg(feature = "noise")]
|
#[cfg(feature = "noise")]
|
||||||
|
@ -95,6 +103,11 @@ mod websocket;
|
||||||
#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
|
#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
|
||||||
pub use websocket::WebsocketTransport;
|
pub use websocket::WebsocketTransport;
|
||||||
|
|
||||||
|
#[cfg(any(feature = "http2-native-tls", feature = "http2-rustls"))]
|
||||||
|
mod http2;
|
||||||
|
#[cfg(any(feature = "http2-native-tls", feature = "http2-rustls"))]
|
||||||
|
pub use http2::HTTP2Transport;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
struct Keepalive {
|
struct Keepalive {
|
||||||
// tcp_keepalive_time if the underlying protocol is TCP
|
// tcp_keepalive_time if the underlying protocol is TCP
|
||||||
|
|
|
@ -110,7 +110,7 @@ impl Transport for TlsTransport {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "websocket-native-tls")]
|
#[cfg(any(feature = "websocket-native-tls", feature = "http2-native-tls"))]
|
||||||
pub(crate) fn get_tcpstream(s: &TlsStream<TcpStream>) -> &TcpStream {
|
pub(crate) fn get_tcpstream(s: &TlsStream<TcpStream>) -> &TcpStream {
|
||||||
s.get_ref().get_ref().get_ref()
|
s.get_ref().get_ref().get_ref()
|
||||||
}
|
}
|
||||||
|
|
|
@ -151,6 +151,7 @@ impl Transport for TlsTransport {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(any(feature = "websocket-rustls", feature = "http2-rustls"))]
|
||||||
pub(crate) fn get_tcpstream(s: &TlsStream<TcpStream>) -> &TcpStream {
|
pub(crate) fn get_tcpstream(s: &TlsStream<TcpStream>) -> &TcpStream {
|
||||||
&s.get_ref().0
|
&s.get_ref().0
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,8 +4,6 @@ use std::net::SocketAddr;
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::task::{ready, Context, Poll};
|
use std::task::{ready, Context, Poll};
|
||||||
|
|
||||||
use super::{AddrMaybeCached, SocketOpts, TcpTransport, TlsTransport, Transport};
|
|
||||||
use crate::config::TransportConfig;
|
|
||||||
use anyhow::anyhow;
|
use anyhow::anyhow;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
|
@ -13,78 +11,18 @@ use futures_core::stream::Stream;
|
||||||
use futures_sink::Sink;
|
use futures_sink::Sink;
|
||||||
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
|
use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
|
||||||
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
|
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
|
||||||
|
|
||||||
#[cfg(any(feature = "native-tls", feature = "rustls"))]
|
|
||||||
use super::tls::get_tcpstream;
|
|
||||||
#[cfg(any(feature = "native-tls", feature = "rustls"))]
|
|
||||||
use super::tls::TlsStream;
|
|
||||||
|
|
||||||
use tokio_tungstenite::tungstenite::protocol::{Message, WebSocketConfig};
|
use tokio_tungstenite::tungstenite::protocol::{Message, WebSocketConfig};
|
||||||
use tokio_tungstenite::{accept_async_with_config, client_async_with_config, WebSocketStream};
|
use tokio_tungstenite::{accept_async_with_config, client_async_with_config, WebSocketStream};
|
||||||
use tokio_util::io::StreamReader;
|
use tokio_util::io::StreamReader;
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
||||||
#[derive(Debug)]
|
use super::maybe_tls::{MaybeTLSStream, MaybeTLSTransport};
|
||||||
enum TransportStream {
|
use super::{AddrMaybeCached, SocketOpts, Transport};
|
||||||
Insecure(TcpStream),
|
use crate::config::TransportConfig;
|
||||||
Secure(TlsStream<TcpStream>),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TransportStream {
|
|
||||||
fn get_tcpstream(&self) -> &TcpStream {
|
|
||||||
match self {
|
|
||||||
TransportStream::Insecure(s) => s,
|
|
||||||
TransportStream::Secure(s) => get_tcpstream(s),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AsyncRead for TransportStream {
|
|
||||||
fn poll_read(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
buf: &mut ReadBuf<'_>,
|
|
||||||
) -> Poll<std::io::Result<()>> {
|
|
||||||
match self.get_mut() {
|
|
||||||
TransportStream::Insecure(s) => Pin::new(s).poll_read(cx, buf),
|
|
||||||
TransportStream::Secure(s) => Pin::new(s).poll_read(cx, buf),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl AsyncWrite for TransportStream {
|
|
||||||
fn poll_write(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
buf: &[u8],
|
|
||||||
) -> Poll<Result<usize, std::io::Error>> {
|
|
||||||
match self.get_mut() {
|
|
||||||
TransportStream::Insecure(s) => Pin::new(s).poll_write(cx, buf),
|
|
||||||
TransportStream::Secure(s) => Pin::new(s).poll_write(cx, buf),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
|
|
||||||
match self.get_mut() {
|
|
||||||
TransportStream::Insecure(s) => Pin::new(s).poll_flush(cx),
|
|
||||||
TransportStream::Secure(s) => Pin::new(s).poll_flush(cx),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn poll_shutdown(
|
|
||||||
self: Pin<&mut Self>,
|
|
||||||
cx: &mut Context<'_>,
|
|
||||||
) -> Poll<Result<(), std::io::Error>> {
|
|
||||||
match self.get_mut() {
|
|
||||||
TransportStream::Insecure(s) => Pin::new(s).poll_shutdown(cx),
|
|
||||||
TransportStream::Secure(s) => Pin::new(s).poll_shutdown(cx),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct StreamWrapper {
|
struct StreamWrapper {
|
||||||
inner: WebSocketStream<TransportStream>,
|
inner: WebSocketStream<MaybeTLSStream>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Stream for StreamWrapper {
|
impl Stream for StreamWrapper {
|
||||||
|
@ -170,15 +108,9 @@ impl AsyncWrite for WebsocketTunnel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
enum SubTransport {
|
|
||||||
Secure(TlsTransport),
|
|
||||||
Insecure(TcpTransport),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct WebsocketTransport {
|
pub struct WebsocketTransport {
|
||||||
sub: SubTransport,
|
sub: MaybeTLSTransport,
|
||||||
conf: WebSocketConfig,
|
conf: WebSocketConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -198,10 +130,7 @@ impl Transport for WebsocketTransport {
|
||||||
write_buffer_size: 0,
|
write_buffer_size: 0,
|
||||||
..WebSocketConfig::default()
|
..WebSocketConfig::default()
|
||||||
};
|
};
|
||||||
let sub = match wsconfig.tls {
|
let sub = MaybeTLSTransport::new_explicit(wsconfig.tls, config)?;
|
||||||
true => SubTransport::Secure(TlsTransport::new(config)?),
|
|
||||||
false => SubTransport::Insecure(TcpTransport::new(config)?),
|
|
||||||
};
|
|
||||||
Ok(WebsocketTransport { sub, conf })
|
Ok(WebsocketTransport { sub, conf })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -213,22 +142,15 @@ impl Transport for WebsocketTransport {
|
||||||
&self,
|
&self,
|
||||||
addr: A,
|
addr: A,
|
||||||
) -> anyhow::Result<Self::Acceptor> {
|
) -> anyhow::Result<Self::Acceptor> {
|
||||||
TcpListener::bind(addr).await.map_err(Into::into)
|
self.sub.bind(addr).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn accept(&self, a: &Self::Acceptor) -> anyhow::Result<(Self::RawStream, SocketAddr)> {
|
async fn accept(&self, a: &Self::Acceptor) -> anyhow::Result<(Self::RawStream, SocketAddr)> {
|
||||||
let (s, addr) = match &self.sub {
|
self.sub.accept(a).await
|
||||||
SubTransport::Insecure(t) => t.accept(a).await?,
|
|
||||||
SubTransport::Secure(t) => t.accept(a).await?,
|
|
||||||
};
|
|
||||||
Ok((s, addr))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handshake(&self, conn: Self::RawStream) -> anyhow::Result<Self::Stream> {
|
async fn handshake(&self, conn: Self::RawStream) -> anyhow::Result<Self::Stream> {
|
||||||
let tsream = match &self.sub {
|
let tsream = self.sub.handshake(conn).await?;
|
||||||
SubTransport::Insecure(t) => TransportStream::Insecure(t.handshake(conn).await?),
|
|
||||||
SubTransport::Secure(t) => TransportStream::Secure(t.handshake(conn).await?),
|
|
||||||
};
|
|
||||||
let wsstream = accept_async_with_config(tsream, Some(self.conf)).await?;
|
let wsstream = accept_async_with_config(tsream, Some(self.conf)).await?;
|
||||||
let tun = WebsocketTunnel {
|
let tun = WebsocketTunnel {
|
||||||
inner: StreamReader::new(StreamWrapper { inner: wsstream }),
|
inner: StreamReader::new(StreamWrapper { inner: wsstream }),
|
||||||
|
@ -239,10 +161,7 @@ impl Transport for WebsocketTransport {
|
||||||
async fn connect(&self, addr: &AddrMaybeCached) -> anyhow::Result<Self::Stream> {
|
async fn connect(&self, addr: &AddrMaybeCached) -> anyhow::Result<Self::Stream> {
|
||||||
let u = format!("ws://{}", &addr.addr.as_str());
|
let u = format!("ws://{}", &addr.addr.as_str());
|
||||||
let url = Url::parse(&u).unwrap();
|
let url = Url::parse(&u).unwrap();
|
||||||
let tstream = match &self.sub {
|
let tstream = self.sub.connect(addr).await?;
|
||||||
SubTransport::Insecure(t) => TransportStream::Insecure(t.connect(addr).await?),
|
|
||||||
SubTransport::Secure(t) => TransportStream::Secure(t.connect(addr).await?),
|
|
||||||
};
|
|
||||||
let (wsstream, _) = client_async_with_config(url, tstream, Some(self.conf))
|
let (wsstream, _) = client_async_with_config(url, tstream, Some(self.conf))
|
||||||
.await
|
.await
|
||||||
.expect("failed to connect");
|
.expect("failed to connect");
|
||||||
|
|
|
@ -0,0 +1,33 @@
|
||||||
|
[client]
|
||||||
|
remote_addr = "127.0.0.1:2333"
|
||||||
|
default_token = "default_token_if_not_specify"
|
||||||
|
|
||||||
|
[client.transport]
|
||||||
|
type = "http2"
|
||||||
|
[client.transport.tls]
|
||||||
|
trusted_root = "examples/tls/rootCA.crt"
|
||||||
|
hostname = "localhost"
|
||||||
|
[client.transport.http2]
|
||||||
|
tls = true
|
||||||
|
|
||||||
|
[client.services.echo]
|
||||||
|
local_addr = "127.0.0.1:8080"
|
||||||
|
[client.services.pingpong]
|
||||||
|
local_addr = "127.0.0.1:8081"
|
||||||
|
|
||||||
|
[server]
|
||||||
|
bind_addr = "0.0.0.0:2333"
|
||||||
|
default_token = "default_token_if_not_specify"
|
||||||
|
|
||||||
|
[server.transport]
|
||||||
|
type = "http2"
|
||||||
|
[server.transport.tls]
|
||||||
|
pkcs12 = "examples/tls/identity.pfx"
|
||||||
|
pkcs12_password = "1234"
|
||||||
|
[server.transport.http2]
|
||||||
|
tls = true
|
||||||
|
|
||||||
|
[server.services.echo]
|
||||||
|
bind_addr = "0.0.0.0:2334"
|
||||||
|
[server.services.pingpong]
|
||||||
|
bind_addr = "0.0.0.0:2335"
|
|
@ -0,0 +1,27 @@
|
||||||
|
[client]
|
||||||
|
remote_addr = "127.0.0.1:2333"
|
||||||
|
default_token = "default_token_if_not_specify"
|
||||||
|
|
||||||
|
[client.transport]
|
||||||
|
type = "http2"
|
||||||
|
[client.transport.http2]
|
||||||
|
tls = false
|
||||||
|
|
||||||
|
[client.services.echo]
|
||||||
|
local_addr = "127.0.0.1:8080"
|
||||||
|
[client.services.pingpong]
|
||||||
|
local_addr = "127.0.0.1:8081"
|
||||||
|
|
||||||
|
[server]
|
||||||
|
bind_addr = "0.0.0.0:2333"
|
||||||
|
default_token = "default_token_if_not_specify"
|
||||||
|
|
||||||
|
[server.transport]
|
||||||
|
type = "http2"
|
||||||
|
[server.transport.http2]
|
||||||
|
tls = false
|
||||||
|
|
||||||
|
[server.services.echo]
|
||||||
|
bind_addr = "0.0.0.0:2334"
|
||||||
|
[server.services.pingpong]
|
||||||
|
bind_addr = "0.0.0.0:2335"
|
|
@ -0,0 +1,37 @@
|
||||||
|
[client]
|
||||||
|
remote_addr = "127.0.0.1:2332"
|
||||||
|
default_token = "default_token_if_not_specify"
|
||||||
|
|
||||||
|
[client.transport]
|
||||||
|
type = "http2"
|
||||||
|
[client.transport.tls]
|
||||||
|
trusted_root = "examples/tls/rootCA.crt"
|
||||||
|
hostname = "localhost"
|
||||||
|
[client.transport.http2]
|
||||||
|
tls = true
|
||||||
|
|
||||||
|
[client.services.echo]
|
||||||
|
type = "udp"
|
||||||
|
local_addr = "127.0.0.1:8080"
|
||||||
|
[client.services.pingpong]
|
||||||
|
type = "udp"
|
||||||
|
local_addr = "127.0.0.1:8081"
|
||||||
|
|
||||||
|
[server]
|
||||||
|
bind_addr = "0.0.0.0:2332"
|
||||||
|
default_token = "default_token_if_not_specify"
|
||||||
|
|
||||||
|
[server.transport]
|
||||||
|
type = "http2"
|
||||||
|
[server.transport.tls]
|
||||||
|
pkcs12 = "examples/tls/identity.pfx"
|
||||||
|
pkcs12_password = "1234"
|
||||||
|
[server.transport.http2]
|
||||||
|
tls = true
|
||||||
|
|
||||||
|
[server.services.echo]
|
||||||
|
type = "udp"
|
||||||
|
bind_addr = "0.0.0.0:2334"
|
||||||
|
[server.services.pingpong]
|
||||||
|
type = "udp"
|
||||||
|
bind_addr = "0.0.0.0:2335"
|
|
@ -0,0 +1,31 @@
|
||||||
|
[client]
|
||||||
|
remote_addr = "127.0.0.1:2332"
|
||||||
|
default_token = "default_token_if_not_specify"
|
||||||
|
|
||||||
|
[client.transport]
|
||||||
|
type = "http2"
|
||||||
|
[client.transport.http2]
|
||||||
|
tls = false
|
||||||
|
|
||||||
|
[client.services.echo]
|
||||||
|
type = "udp"
|
||||||
|
local_addr = "127.0.0.1:8080"
|
||||||
|
[client.services.pingpong]
|
||||||
|
type = "udp"
|
||||||
|
local_addr = "127.0.0.1:8081"
|
||||||
|
|
||||||
|
[server]
|
||||||
|
bind_addr = "0.0.0.0:2332"
|
||||||
|
default_token = "default_token_if_not_specify"
|
||||||
|
|
||||||
|
[server.transport]
|
||||||
|
type = "http2"
|
||||||
|
[server.transport.http2]
|
||||||
|
tls = false
|
||||||
|
|
||||||
|
[server.services.echo]
|
||||||
|
type = "udp"
|
||||||
|
bind_addr = "0.0.0.0:2334"
|
||||||
|
[server.services.pingpong]
|
||||||
|
type = "udp"
|
||||||
|
bind_addr = "0.0.0.0:2335"
|
|
@ -65,11 +65,16 @@ async fn tcp() -> Result<()> {
|
||||||
|
|
||||||
#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
|
#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
|
||||||
test("tests/for_tcp/websocket_transport.toml", Type::Tcp).await?;
|
test("tests/for_tcp/websocket_transport.toml", Type::Tcp).await?;
|
||||||
|
|
||||||
#[cfg(not(target_os = "macos"))]
|
#[cfg(not(target_os = "macos"))]
|
||||||
#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
|
#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
|
||||||
test("tests/for_tcp/websocket_tls_transport.toml", Type::Tcp).await?;
|
test("tests/for_tcp/websocket_tls_transport.toml", Type::Tcp).await?;
|
||||||
|
|
||||||
|
#[cfg(any(feature = "http2-native-tls", feature = "http2-rustls"))]
|
||||||
|
test("tests/for_tcp/http2_transport.toml", Type::Tcp).await?;
|
||||||
|
#[cfg(not(target_os = "macos"))]
|
||||||
|
#[cfg(any(feature = "http2-native-tls", feature = "http2-rustls"))]
|
||||||
|
test("tests/for_tcp/http2_tls_transport.toml", Type::Tcp).await?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -102,11 +107,16 @@ async fn udp() -> Result<()> {
|
||||||
|
|
||||||
#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
|
#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
|
||||||
test("tests/for_udp/websocket_transport.toml", Type::Udp).await?;
|
test("tests/for_udp/websocket_transport.toml", Type::Udp).await?;
|
||||||
|
|
||||||
#[cfg(not(target_os = "macos"))]
|
#[cfg(not(target_os = "macos"))]
|
||||||
#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
|
#[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
|
||||||
test("tests/for_udp/websocket_tls_transport.toml", Type::Udp).await?;
|
test("tests/for_udp/websocket_tls_transport.toml", Type::Udp).await?;
|
||||||
|
|
||||||
|
#[cfg(any(feature = "http2-native-tls", feature = "http2-rustls"))]
|
||||||
|
test("tests/for_udp/http2_transport.toml", Type::Udp).await?;
|
||||||
|
#[cfg(not(target_os = "macos"))]
|
||||||
|
#[cfg(any(feature = "http2-native-tls", feature = "http2-rustls"))]
|
||||||
|
test("tests/for_udp/http2_tls_transport.toml", Type::Udp).await?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue