mirror of https://github.com/rapiz1/rathole.git
refactor: refactor handle_connection
This commit is contained in:
parent
8f3bf5c7c7
commit
f4b7e600bc
184
src/server.rs
184
src/server.rs
|
@ -129,92 +129,114 @@ async fn handle_connection(
|
|||
let hello = read_hello(&mut conn).await?;
|
||||
match hello {
|
||||
ControlChannelHello(_, service_digest) => {
|
||||
info!("New control channel incomming from {}", addr);
|
||||
|
||||
// Generate a nonce
|
||||
let mut nonce = vec![0u8; HASH_WIDTH_IN_BYTES];
|
||||
rand::thread_rng().fill_bytes(&mut nonce);
|
||||
|
||||
// Send hello
|
||||
let hello_send = Hello::ControlChannelHello(
|
||||
protocol::CURRENT_PROTO_VRESION,
|
||||
nonce.clone().try_into().unwrap(),
|
||||
);
|
||||
conn.write_all(&bincode::serialize(&hello_send).unwrap())
|
||||
do_control_channel_handshake(conn, addr, services, control_channels, service_digest)
|
||||
.await?;
|
||||
|
||||
// Lookup the service
|
||||
let services_guard = services.read().await;
|
||||
let service_config = match services_guard.get(&service_digest) {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
conn.write_all(&bincode::serialize(&Ack::ServiceNotExist).unwrap())
|
||||
.await?;
|
||||
bail!("No such a service {}", hex::encode(&service_digest));
|
||||
}
|
||||
};
|
||||
let service_name = &service_config.name;
|
||||
|
||||
// Calculate the checksum
|
||||
let mut concat = Vec::from(service_config.token.as_ref().unwrap().as_bytes());
|
||||
concat.append(&mut nonce);
|
||||
|
||||
// Read auth
|
||||
let d = match read_auth(&mut conn).await? {
|
||||
protocol::Auth(v) => v,
|
||||
};
|
||||
|
||||
// Validate
|
||||
let session_key = protocol::digest(&concat);
|
||||
if session_key != d {
|
||||
conn.write_all(&bincode::serialize(&Ack::AuthFailed).unwrap())
|
||||
.await?;
|
||||
debug!(
|
||||
"Expect {}, but got {}",
|
||||
hex::encode(session_key),
|
||||
hex::encode(d)
|
||||
);
|
||||
bail!("Service {} failed the authentication", service_name);
|
||||
} else {
|
||||
let mut h = control_channels.write().await;
|
||||
|
||||
if let Some(_) = h.remove1(&service_digest) {
|
||||
warn!(
|
||||
"Dropping previous control channel for digest {}",
|
||||
hex::encode(service_digest)
|
||||
);
|
||||
}
|
||||
|
||||
let service_config = service_config.clone();
|
||||
drop(services_guard);
|
||||
|
||||
// Send ack
|
||||
conn.write_all(&bincode::serialize(&Ack::Ok).unwrap())
|
||||
.await?;
|
||||
|
||||
info!(service = %service_config.name, "Control channel established");
|
||||
let handle = ControlChannelHandle::new(conn, service_config);
|
||||
|
||||
// Drop the old handle
|
||||
let _ = h.insert(service_digest, session_key, handle);
|
||||
}
|
||||
}
|
||||
DataChannelHello(_, nonce) => {
|
||||
// Validate
|
||||
let control_channels_guard = control_channels.read().await;
|
||||
match control_channels_guard.get2(&nonce) {
|
||||
Some(c_ch) => {
|
||||
if let Err(e) = set_tcp_keepalive(&conn) {
|
||||
error!("The connection may be unstable! {:?}", e);
|
||||
}
|
||||
do_data_channel_handshake(conn, control_channels, nonce).await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Send the data channel to the corresponding control channel
|
||||
c_ch.conn_pool.data_ch_tx.send(conn).await?;
|
||||
}
|
||||
None => {
|
||||
warn!("Data channel has incorrect nonce");
|
||||
}
|
||||
async fn do_control_channel_handshake(
|
||||
mut conn: TcpStream,
|
||||
addr: SocketAddr,
|
||||
services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
|
||||
control_channels: Arc<RwLock<ControlChannelMap>>,
|
||||
service_digest: ServiceDigest,
|
||||
) -> Result<()> {
|
||||
info!("New control channel incomming from {}", addr);
|
||||
|
||||
// Generate a nonce
|
||||
let mut nonce = vec![0u8; HASH_WIDTH_IN_BYTES];
|
||||
rand::thread_rng().fill_bytes(&mut nonce);
|
||||
|
||||
// Send hello
|
||||
let hello_send = Hello::ControlChannelHello(
|
||||
protocol::CURRENT_PROTO_VRESION,
|
||||
nonce.clone().try_into().unwrap(),
|
||||
);
|
||||
conn.write_all(&bincode::serialize(&hello_send).unwrap())
|
||||
.await?;
|
||||
|
||||
// Lookup the service
|
||||
let services_guard = services.read().await;
|
||||
let service_config = match services_guard.get(&service_digest) {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
conn.write_all(&bincode::serialize(&Ack::ServiceNotExist).unwrap())
|
||||
.await?;
|
||||
bail!("No such a service {}", hex::encode(&service_digest));
|
||||
}
|
||||
};
|
||||
let service_name = &service_config.name;
|
||||
|
||||
// Calculate the checksum
|
||||
let mut concat = Vec::from(service_config.token.as_ref().unwrap().as_bytes());
|
||||
concat.append(&mut nonce);
|
||||
|
||||
// Read auth
|
||||
let d = match read_auth(&mut conn).await? {
|
||||
protocol::Auth(v) => v,
|
||||
};
|
||||
|
||||
// Validate
|
||||
let session_key = protocol::digest(&concat);
|
||||
if session_key != d {
|
||||
conn.write_all(&bincode::serialize(&Ack::AuthFailed).unwrap())
|
||||
.await?;
|
||||
debug!(
|
||||
"Expect {}, but got {}",
|
||||
hex::encode(session_key),
|
||||
hex::encode(d)
|
||||
);
|
||||
bail!("Service {} failed the authentication", service_name);
|
||||
} else {
|
||||
let mut h = control_channels.write().await;
|
||||
|
||||
if let Some(_) = h.remove1(&service_digest) {
|
||||
warn!(
|
||||
"Dropping previous control channel for digest {}",
|
||||
hex::encode(service_digest)
|
||||
);
|
||||
}
|
||||
|
||||
let service_config = service_config.clone();
|
||||
drop(services_guard);
|
||||
|
||||
// Send ack
|
||||
conn.write_all(&bincode::serialize(&Ack::Ok).unwrap())
|
||||
.await?;
|
||||
|
||||
info!(service = %service_config.name, "Control channel established");
|
||||
let handle = ControlChannelHandle::new(conn, service_config);
|
||||
|
||||
// Drop the old handle
|
||||
let _ = h.insert(service_digest, session_key, handle);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn do_data_channel_handshake(
|
||||
conn: TcpStream,
|
||||
control_channels: Arc<RwLock<ControlChannelMap>>,
|
||||
nonce: Nonce,
|
||||
) -> Result<()> {
|
||||
// Validate
|
||||
let control_channels_guard = control_channels.read().await;
|
||||
match control_channels_guard.get2(&nonce) {
|
||||
Some(c_ch) => {
|
||||
if let Err(e) = set_tcp_keepalive(&conn) {
|
||||
error!("The connection may be unstable! {:?}", e);
|
||||
}
|
||||
|
||||
// Send the data channel to the corresponding control channel
|
||||
c_ch.conn_pool.data_ch_tx.send(conn).await?;
|
||||
}
|
||||
None => {
|
||||
warn!("Data channel has incorrect nonce");
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
|
|
Loading…
Reference in New Issue