rathole/src/config.rs

357 lines
9.8 KiB
Rust
Raw Normal View History

2021-12-11 13:30:42 +01:00
use anyhow::{anyhow, bail, Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
2021-12-11 13:30:42 +01:00
use tokio::fs;
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
2021-12-15 08:12:36 +01:00
pub enum TransportType {
#[serde(rename = "tcp")]
Tcp,
#[serde(rename = "tls")]
Tls,
2021-12-23 16:40:20 +01:00
#[serde(rename = "noise")]
Noise,
2021-12-11 13:30:42 +01:00
}
2021-12-15 08:12:36 +01:00
impl Default for TransportType {
fn default() -> TransportType {
TransportType::Tcp
}
2021-12-11 13:30:42 +01:00
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ClientServiceConfig {
2021-12-21 14:11:46 +01:00
#[serde(rename = "type", default = "default_service_type")]
pub service_type: ServiceType,
2021-12-11 13:30:42 +01:00
#[serde(skip)]
pub name: String,
pub local_addr: String,
pub token: Option<String>,
}
2021-12-21 14:11:46 +01:00
#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
pub enum ServiceType {
#[serde(rename = "tcp")]
Tcp,
#[serde(rename = "udp")]
Udp,
}
fn default_service_type() -> ServiceType {
ServiceType::Tcp
}
2021-12-11 13:30:42 +01:00
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ServerServiceConfig {
2021-12-21 14:11:46 +01:00
#[serde(rename = "type", default = "default_service_type")]
pub service_type: ServiceType,
2021-12-11 13:30:42 +01:00
#[serde(skip)]
pub name: String,
pub bind_addr: String,
pub token: Option<String>,
2021-12-15 08:12:36 +01:00
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TlsConfig {
pub hostname: Option<String>,
pub trusted_root: Option<String>,
pub pkcs12: Option<String>,
pub pkcs12_password: Option<String>,
}
2021-12-23 16:40:20 +01:00
fn default_noise_pattern() -> String {
String::from("Noise_NK_25519_ChaChaPoly_BLAKE2s")
}
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct NoiseConfig {
#[serde(default = "default_noise_pattern")]
pub pattern: String,
pub local_private_key: Option<String>,
pub remote_public_key: Option<String>,
// TODO: Maybe psk can be added
}
2021-12-15 08:12:36 +01:00
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct TransportConfig {
#[serde(rename = "type")]
pub transport_type: TransportType,
pub tls: Option<TlsConfig>,
2021-12-23 16:40:20 +01:00
pub noise: Option<NoiseConfig>,
2021-12-15 08:12:36 +01:00
}
fn default_transport() -> TransportConfig {
Default::default()
2021-12-11 13:30:42 +01:00
}
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct ClientConfig {
pub remote_addr: String,
pub default_token: Option<String>,
pub services: HashMap<String, ClientServiceConfig>,
2021-12-15 08:12:36 +01:00
#[serde(default = "default_transport")]
pub transport: TransportConfig,
2021-12-11 13:30:42 +01:00
}
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct ServerConfig {
pub bind_addr: String,
pub default_token: Option<String>,
pub services: HashMap<String, ServerServiceConfig>,
2021-12-15 08:12:36 +01:00
#[serde(default = "default_transport")]
pub transport: TransportConfig,
2021-12-11 13:30:42 +01:00
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Config {
pub server: Option<ServerConfig>,
pub client: Option<ClientConfig>,
}
impl Config {
fn from_str(s: &str) -> Result<Config> {
let mut config: Config = toml::from_str(s).with_context(|| "Failed to parse the config")?;
2021-12-15 08:12:36 +01:00
2021-12-11 13:30:42 +01:00
if let Some(server) = config.server.as_mut() {
2021-12-15 08:12:36 +01:00
Config::validate_server_config(server)?;
}
if let Some(client) = config.client.as_mut() {
Config::validate_client_config(client)?;
}
if config.server.is_none() && config.client.is_none() {
Err(anyhow!("Neither of `[server]` or `[client]` is defined"))
} else {
Ok(config)
}
}
fn validate_server_config(server: &mut ServerConfig) -> Result<()> {
// Validate services
for (name, s) in &mut server.services {
s.name = name.clone();
if s.token.is_none() {
s.token = server.default_token.clone();
2021-12-11 13:30:42 +01:00
if s.token.is_none() {
2021-12-15 08:12:36 +01:00
bail!("The token of service {} is not set", name);
2021-12-11 13:30:42 +01:00
}
}
}
2021-12-15 08:12:36 +01:00
Config::validate_transport_config(&server.transport, true)?;
Ok(())
}
fn validate_client_config(client: &mut ClientConfig) -> Result<()> {
// Validate services
for (name, s) in &mut client.services {
s.name = name.clone();
if s.token.is_none() {
s.token = client.default_token.clone();
2021-12-11 13:30:42 +01:00
if s.token.is_none() {
2021-12-15 08:12:36 +01:00
bail!("The token of service {} is not set", name);
2021-12-11 13:30:42 +01:00
}
}
}
2021-12-15 08:12:36 +01:00
Config::validate_transport_config(&client.transport, false)?;
Ok(())
}
fn validate_transport_config(config: &TransportConfig, is_server: bool) -> Result<()> {
match config.transport_type {
TransportType::Tcp => Ok(()),
TransportType::Tls => {
let tls_config = config
.tls
.as_ref()
.ok_or(anyhow!("Missing TLS configuration"))?;
if is_server {
tls_config
.pkcs12
.as_ref()
.and(tls_config.pkcs12_password.as_ref())
.ok_or(anyhow!("Missing `pkcs12` or `pkcs12_password`"))?;
} else {
tls_config
.trusted_root
.as_ref()
.ok_or(anyhow!("Missing `trusted_root`"))?;
}
Ok(())
}
2021-12-23 16:40:20 +01:00
TransportType::Noise => {
// The check is done in transport
Ok(())
}
2021-12-11 13:30:42 +01:00
}
}
pub async fn from_file(path: &Path) -> Result<Config> {
2021-12-11 13:30:42 +01:00
let s: String = fs::read_to_string(path)
.await
.with_context(|| format!("Failed to read the config {:?}", path))?;
Config::from_str(&s).with_context(|| {
"Configuration is invalid. Please refer to the configuration specification."
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::{fs, path::PathBuf};
2021-12-11 13:30:42 +01:00
use anyhow::Result;
fn list_config_files<T: AsRef<Path>>(root: T) -> Result<Vec<PathBuf>> {
let mut files = Vec::new();
for entry in fs::read_dir(root)? {
let entry = entry?;
let path = entry.path();
if path.is_file() {
files.push(path);
} else if path.is_dir() {
files.append(&mut list_config_files(path)?);
}
}
Ok(files)
}
fn get_all_example_config() -> Result<Vec<PathBuf>> {
Ok(list_config_files("./examples")?
.into_iter()
.filter(|x| x.ends_with(".toml"))
.collect())
}
2021-12-11 13:30:42 +01:00
#[test]
fn test_example_config() -> Result<()> {
let paths = get_all_example_config()?;
for p in paths {
let s = fs::read_to_string(p)?;
Config::from_str(&s)?;
}
2021-12-11 13:30:42 +01:00
Ok(())
}
#[test]
fn test_valid_config() -> Result<()> {
let paths = list_config_files("tests/config_test/valid_config")?;
for p in paths {
let s = fs::read_to_string(p)?;
Config::from_str(&s)?;
}
Ok(())
}
#[test]
fn test_invalid_config() -> Result<()> {
let paths = list_config_files("tests/config_test/invalid_config")?;
for p in paths {
let s = fs::read_to_string(p)?;
assert!(Config::from_str(&s).is_err());
}
Ok(())
}
#[test]
fn test_validate_server_config() -> Result<()> {
let mut cfg = ServerConfig::default();
cfg.services.insert(
"foo1".into(),
ServerServiceConfig {
2021-12-21 14:11:46 +01:00
service_type: ServiceType::Tcp,
name: "foo1".into(),
bind_addr: "127.0.0.1:80".into(),
token: None,
},
);
// Missing the token
assert!(Config::validate_server_config(&mut cfg).is_err());
// Use the default token
cfg.default_token = Some("123".into());
assert!(Config::validate_server_config(&mut cfg).is_ok());
assert_eq!(
cfg.services
.get("foo1")
.as_ref()
.unwrap()
.token
.as_ref()
.unwrap(),
"123"
);
// The default token won't override the service token
cfg.services.get_mut("foo1").unwrap().token = Some("4".into());
assert!(Config::validate_server_config(&mut cfg).is_ok());
assert_eq!(
cfg.services
.get("foo1")
.as_ref()
.unwrap()
.token
.as_ref()
.unwrap(),
"4"
);
Ok(())
}
#[test]
fn test_validate_client_config() -> Result<()> {
let mut cfg = ClientConfig::default();
cfg.services.insert(
"foo1".into(),
ClientServiceConfig {
2021-12-21 14:11:46 +01:00
service_type: ServiceType::Tcp,
name: "foo1".into(),
local_addr: "127.0.0.1:80".into(),
token: None,
},
);
// Missing the token
assert!(Config::validate_client_config(&mut cfg).is_err());
// Use the default token
cfg.default_token = Some("123".into());
assert!(Config::validate_client_config(&mut cfg).is_ok());
assert_eq!(
cfg.services
.get("foo1")
.as_ref()
.unwrap()
.token
.as_ref()
.unwrap(),
"123"
);
// The default token won't override the service token
cfg.services.get_mut("foo1").unwrap().token = Some("4".into());
assert!(Config::validate_client_config(&mut cfg).is_ok());
assert_eq!(
cfg.services
.get("foo1")
.as_ref()
.unwrap()
.token
.as_ref()
.unwrap(),
"4"
);
2021-12-11 13:30:42 +01:00
Ok(())
}
}