From d127bd10a6caa3b7d5eced3693ec1ef3cb9623fb Mon Sep 17 00:00:00 2001 From: Leo Nash Date: Sat, 6 Dec 2025 01:47:59 +0000 Subject: [PATCH 1/3] Add option to verify JWT tokens in the HTTP Authorization header --- rust/Cargo.lock | 1 + rust/auth-impls/src/lib.rs | 4 ++- rust/server/Cargo.toml | 1 + rust/server/src/main.rs | 55 +++++++++++++++++++++--------- rust/server/src/util/config.rs | 3 +- rust/server/vss-server-config.toml | 1 + 6 files changed, 46 insertions(+), 19 deletions(-) diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 6c37e11..4e32722 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1981,6 +1981,7 @@ name = "vss-server" version = "0.1.0" dependencies = [ "api", + "auth-impls", "bytes", "http-body-util", "hyper 1.4.1", diff --git a/rust/auth-impls/src/lib.rs b/rust/auth-impls/src/lib.rs index 86c6853..2bc0fb3 100644 --- a/rust/auth-impls/src/lib.rs +++ b/rust/auth-impls/src/lib.rs @@ -14,10 +14,12 @@ use api::auth::{AuthResponse, Authorizer}; use api::error::VssError; use async_trait::async_trait; -use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation}; +use jsonwebtoken::{decode, Algorithm, Validation}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; +pub use jsonwebtoken::DecodingKey; + /// A JWT based authorizer, only allows requests with verified 'JsonWebToken' signed by the given /// issuer key. /// diff --git a/rust/server/Cargo.toml b/rust/server/Cargo.toml index 2a0e6f1..6c66812 100644 --- a/rust/server/Cargo.toml +++ b/rust/server/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] api = { path = "../api" } +auth-impls = { path = "../auth-impls" } impls = { path = "../impls" } hyper = { version = "1", default-features = false, features = ["server", "http1"] } diff --git a/rust/server/src/main.rs b/rust/server/src/main.rs index 38fdccd..88a5dc0 100644 --- a/rust/server/src/main.rs +++ b/rust/server/src/main.rs @@ -20,11 +20,14 @@ use hyper_util::rt::TokioIo; use crate::vss_service::VssService; use api::auth::{Authorizer, NoopAuthorizer}; use api::kv_store::KvStore; +use auth_impls::{DecodingKey, JWTAuthorizer}; use impls::postgres_store::{Certificate, PostgresPlaintextBackend, PostgresTlsBackend}; use std::sync::Arc; -pub(crate) mod util; -pub(crate) mod vss_service; +mod util; +mod vss_service; + +use util::config::{Config, ServerConfig}; fn main() { let args: Vec = std::env::args().collect(); @@ -33,22 +36,21 @@ fn main() { std::process::exit(1); } - let config = match util::config::load_config(&args[1]) { - Ok(cfg) => cfg, - Err(e) => { - eprintln!("Failed to load configuration: {}", e); - std::process::exit(1); - }, - }; - - let addr: SocketAddr = - match format!("{}:{}", config.server_config.host, config.server_config.port).parse() { - Ok(addr) => addr, + let Config { server_config: ServerConfig { host, port, rsa_pub_file_path }, postgresql_config } = + match util::config::load_config(&args[1]) { + Ok(cfg) => cfg, Err(e) => { - eprintln!("Invalid host/port configuration: {}", e); + eprintln!("Failed to load configuration: {}", e); std::process::exit(1); }, }; + let addr: SocketAddr = match format!("{}:{}", host, port).parse() { + Ok(addr) => addr, + Err(e) => { + eprintln!("Invalid host/port configuration: {}", e); + std::process::exit(1); + }, + }; let runtime = match tokio::runtime::Builder::new_multi_thread().enable_all().build() { Ok(runtime) => Arc::new(runtime), @@ -66,9 +68,27 @@ fn main() { std::process::exit(-1); }, }; - let authorizer: Arc = Arc::new(NoopAuthorizer {}); - let postgresql_config = - config.postgresql_config.expect("PostgreSQLConfig must be defined in config file."); + + let authorizer: Arc = if let Some(file_path) = rsa_pub_file_path { + let rsa_pub_file = match std::fs::read(file_path) { + Ok(pem) => pem, + Err(e) => { + println!("Failed to read RSA public key file: {}", e); + std::process::exit(-1); + }, + }; + let rsa_public_key = match DecodingKey::from_rsa_pem(&rsa_pub_file) { + Ok(pem) => pem, + Err(e) => { + println!("Failed to parse RSA public key file: {}", e); + std::process::exit(-1); + }, + }; + Arc::new(JWTAuthorizer::new(rsa_public_key).await) + } else { + Arc::new(NoopAuthorizer {}) + }; + let endpoint = postgresql_config.to_postgresql_endpoint(); let db_name = postgresql_config.database; let store: Arc = if let Some(tls_config) = postgresql_config.tls { @@ -109,6 +129,7 @@ fn main() { Arc::new(postgres_plaintext_backend) }; println!("Connected to PostgreSQL backend with DSN: {}/{}", endpoint, db_name); + let rest_svc_listener = TcpListener::bind(&addr).await.expect("Failed to bind listening port"); println!("Listening for incoming connections on {}", addr); diff --git a/rust/server/src/util/config.rs b/rust/server/src/util/config.rs index 801d1bd..9a79248 100644 --- a/rust/server/src/util/config.rs +++ b/rust/server/src/util/config.rs @@ -3,13 +3,14 @@ use serde::Deserialize; #[derive(Deserialize)] pub(crate) struct Config { pub(crate) server_config: ServerConfig, - pub(crate) postgresql_config: Option, + pub(crate) postgresql_config: PostgreSQLConfig, } #[derive(Deserialize)] pub(crate) struct ServerConfig { pub(crate) host: String, pub(crate) port: u16, + pub(crate) rsa_pub_file_path: Option, } #[derive(Deserialize)] diff --git a/rust/server/vss-server-config.toml b/rust/server/vss-server-config.toml index 8c3d9c0..76a4da4 100644 --- a/rust/server/vss-server-config.toml +++ b/rust/server/vss-server-config.toml @@ -1,6 +1,7 @@ [server_config] host = "127.0.0.1" port = 8080 +# rsa_pub_file_path = "rsa_public_key.pem" # Uncomment to verify JWT tokens in the HTTP Authorization header [postgresql_config] username = "postgres" # Optional in TOML, can be overridden by env var `VSS_POSTGRESQL_USERNAME` From 55b7f543d2e0a0a5cd0d4177f847a51724612191 Mon Sep 17 00:00:00 2001 From: Leo Nash Date: Sat, 6 Dec 2025 03:10:59 +0000 Subject: [PATCH 2/3] Add option to specify default postgres db name Not all postgres hosted services use `postgres` Also rename `database` config parameter to `vss_database` --- rust/impls/src/postgres_store.rs | 88 +++++++++++++++++++----------- rust/server/src/main.rs | 13 +++-- rust/server/src/util/config.rs | 3 +- rust/server/vss-server-config.toml | 3 +- 4 files changed, 68 insertions(+), 39 deletions(-) diff --git a/rust/impls/src/postgres_store.rs b/rust/impls/src/postgres_store.rs index 4424a00..f27f929 100644 --- a/rust/impls/src/postgres_store.rs +++ b/rust/impls/src/postgres_store.rs @@ -64,14 +64,16 @@ pub type PostgresPlaintextBackend = PostgresBackend; /// A postgres backend with TLS connections to the database pub type PostgresTlsBackend = PostgresBackend; -async fn make_postgres_db_connection(postgres_endpoint: &str, tls: T) -> Result +async fn make_db_connection( + postgres_endpoint: &str, db_name: &str, tls: T, +) -> Result where T: MakeTlsConnect + Clone + Send + Sync + 'static, T::Stream: Send + Sync, T::TlsConnect: Send, <>::TlsConnect as TlsConnect>::Future: Send, { - let dsn = format!("{}/{}", postgres_endpoint, "postgres"); + let dsn = format!("{}/{}", postgres_endpoint, db_name); let (client, connection) = tokio_postgres::connect(&dsn, tls) .await .map_err(|e| Error::new(ErrorKind::Other, format!("Connection error: {}", e)))?; @@ -84,8 +86,8 @@ where Ok(client) } -async fn initialize_vss_database( - postgres_endpoint: &str, db_name: &str, tls: T, +async fn create_database( + postgres_endpoint: &str, default_db: &str, db_name: &str, tls: T, ) -> Result<(), Error> where T: MakeTlsConnect + Clone + Send + Sync + 'static, @@ -93,7 +95,7 @@ where T::TlsConnect: Send, <>::TlsConnect as TlsConnect>::Future: Send, { - let client = make_postgres_db_connection(&postgres_endpoint, tls).await?; + let client = make_db_connection(postgres_endpoint, default_db, tls).await?; let num_rows = client.execute(CHECK_DB_STMT, &[&db_name]).await.map_err(|e| { Error::new( @@ -113,14 +115,16 @@ where } #[cfg(test)] -async fn drop_database(postgres_endpoint: &str, db_name: &str, tls: T) -> Result<(), Error> +async fn drop_database( + postgres_endpoint: &str, default_db: &str, db_name: &str, tls: T, +) -> Result<(), Error> where T: MakeTlsConnect + Clone + Send + Sync + 'static, T::Stream: Send + Sync, T::TlsConnect: Send, <>::TlsConnect as TlsConnect>::Future: Send, { - let client = make_postgres_db_connection(&postgres_endpoint, tls).await?; + let client = make_db_connection(postgres_endpoint, default_db, tls).await?; let drop_database_statement = format!("{} {};", DROP_DB_CMD, db_name); let num_rows = client.execute(&drop_database_statement, &[]).await.map_err(|e| { @@ -133,15 +137,18 @@ where impl PostgresPlaintextBackend { /// Constructs a [`PostgresPlaintextBackend`] using `postgres_endpoint` for PostgreSQL connection information. - pub async fn new(postgres_endpoint: &str, db_name: &str) -> Result { - PostgresBackend::new_internal(postgres_endpoint, db_name, NoTls).await + pub async fn new( + postgres_endpoint: &str, default_db: &str, vss_db: &str, + ) -> Result { + PostgresBackend::new_internal(postgres_endpoint, default_db, vss_db, NoTls).await } } impl PostgresTlsBackend { /// Constructs a [`PostgresTlsBackend`] using `postgres_endpoint` for PostgreSQL connection information. pub async fn new( - postgres_endpoint: &str, db_name: &str, additional_certificate: Option, + postgres_endpoint: &str, default_db: &str, vss_db: &str, + additional_certificate: Option, ) -> Result { let mut builder = TlsConnector::builder(); if let Some(cert) = additional_certificate { @@ -150,8 +157,13 @@ impl PostgresTlsBackend { let connector = builder.build().map_err(|e| { Error::new(ErrorKind::Other, format!("Error building tls connector: {}", e)) })?; - PostgresBackend::new_internal(postgres_endpoint, db_name, MakeTlsConnector::new(connector)) - .await + PostgresBackend::new_internal( + postgres_endpoint, + default_db, + vss_db, + MakeTlsConnector::new(connector), + ) + .await } } @@ -162,9 +174,11 @@ where T::TlsConnect: Send, <>::TlsConnect as TlsConnect>::Future: Send, { - async fn new_internal(postgres_endpoint: &str, db_name: &str, tls: T) -> Result { - initialize_vss_database(postgres_endpoint, db_name, tls.clone()).await?; - let vss_dsn = format!("{}/{}", postgres_endpoint, db_name); + async fn new_internal( + postgres_endpoint: &str, default_db: &str, vss_db: &str, tls: T, + ) -> Result { + create_database(postgres_endpoint, default_db, vss_db, tls.clone()).await?; + let vss_dsn = format!("{}/{}", postgres_endpoint, vss_db); let manager = PostgresConnectionManager::new_from_stringlike(vss_dsn, tls).map_err(|e| { Error::new( @@ -649,24 +663,27 @@ mod tests { use tokio_postgres::NoTls; const POSTGRES_ENDPOINT: &str = "postgresql://postgres:postgres@localhost:5432"; + const DEFAULT_DB: &str = "postgres"; const MIGRATIONS_START: usize = 0; const MIGRATIONS_END: usize = MIGRATIONS.len(); static START: OnceCell<()> = OnceCell::const_new(); define_kv_store_tests!(PostgresKvStoreTest, PostgresPlaintextBackend, { - let db_name = "postgres_kv_store_tests"; + let vss_db = "postgres_kv_store_tests"; START .get_or_init(|| async { - let _ = drop_database(POSTGRES_ENDPOINT, db_name, NoTls).await; - let store = - PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap(); + let _ = drop_database(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db, NoTls).await; + let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db) + .await + .unwrap(); let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap(); assert_eq!(start, MIGRATIONS_START); assert_eq!(end, MIGRATIONS_END); }) .await; - let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap(); + let store = + PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap(); let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap(); assert_eq!(start, MIGRATIONS_END); assert_eq!(end, MIGRATIONS_END); @@ -678,28 +695,31 @@ mod tests { #[tokio::test] #[should_panic(expected = "We do not allow downgrades")] async fn panic_on_downgrade() { - let db_name = "panic_on_downgrade_test"; - let _ = drop_database(POSTGRES_ENDPOINT, db_name, NoTls).await; + let vss_db = "panic_on_downgrade_test"; + let _ = drop_database(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db, NoTls).await; { let mut migrations = MIGRATIONS.to_vec(); migrations.push(DUMMY_MIGRATION); - let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap(); + let store = + PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap(); let (start, end) = store.migrate_vss_database(&migrations).await.unwrap(); assert_eq!(start, MIGRATIONS_START); assert_eq!(end, MIGRATIONS_END + 1); }; { - let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap(); + let store = + PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap(); let _ = store.migrate_vss_database(MIGRATIONS).await.unwrap(); }; } #[tokio::test] async fn new_migrations_increments_upgrades() { - let db_name = "new_migrations_increments_upgrades_test"; - let _ = drop_database(POSTGRES_ENDPOINT, db_name, NoTls).await; + let vss_db = "new_migrations_increments_upgrades_test"; + let _ = drop_database(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db, NoTls).await; { - let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap(); + let store = + PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap(); let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap(); assert_eq!(start, MIGRATIONS_START); assert_eq!(end, MIGRATIONS_END); @@ -707,7 +727,8 @@ mod tests { assert_eq!(store.get_schema_version().await, MIGRATIONS_END); }; { - let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap(); + let store = + PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap(); let (start, end) = store.migrate_vss_database(MIGRATIONS).await.unwrap(); assert_eq!(start, MIGRATIONS_END); assert_eq!(end, MIGRATIONS_END); @@ -718,7 +739,8 @@ mod tests { let mut migrations = MIGRATIONS.to_vec(); migrations.push(DUMMY_MIGRATION); { - let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap(); + let store = + PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap(); let (start, end) = store.migrate_vss_database(&migrations).await.unwrap(); assert_eq!(start, MIGRATIONS_END); assert_eq!(end, MIGRATIONS_END + 1); @@ -729,7 +751,8 @@ mod tests { migrations.push(DUMMY_MIGRATION); migrations.push(DUMMY_MIGRATION); { - let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap(); + let store = + PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap(); let (start, end) = store.migrate_vss_database(&migrations).await.unwrap(); assert_eq!(start, MIGRATIONS_END + 1); assert_eq!(end, MIGRATIONS_END + 3); @@ -741,13 +764,14 @@ mod tests { }; { - let store = PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, db_name).await.unwrap(); + let store = + PostgresPlaintextBackend::new(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db).await.unwrap(); let list = store.get_upgrades_list().await; assert_eq!(list, [MIGRATIONS_START, MIGRATIONS_END, MIGRATIONS_END + 1]); let version = store.get_schema_version().await; assert_eq!(version, MIGRATIONS_END + 3); } - drop_database(POSTGRES_ENDPOINT, db_name, NoTls).await.unwrap(); + drop_database(POSTGRES_ENDPOINT, DEFAULT_DB, vss_db, NoTls).await.unwrap(); } } diff --git a/rust/server/src/main.rs b/rust/server/src/main.rs index 88a5dc0..cc17f76 100644 --- a/rust/server/src/main.rs +++ b/rust/server/src/main.rs @@ -90,9 +90,10 @@ fn main() { }; let endpoint = postgresql_config.to_postgresql_endpoint(); - let db_name = postgresql_config.database; + let default_db = postgresql_config.default_database; + let vss_db = postgresql_config.vss_database; let store: Arc = if let Some(tls_config) = postgresql_config.tls { - let additional_certificate = tls_config.ca_file.map(|file| { + let addl_certificate = tls_config.ca_file.map(|file| { let certificate = match std::fs::read(&file) { Ok(cert) => cert, Err(e) => { @@ -109,7 +110,9 @@ fn main() { } }); let postgres_tls_backend = - match PostgresTlsBackend::new(&endpoint, &db_name, additional_certificate).await { + match PostgresTlsBackend::new(&endpoint, &default_db, &vss_db, addl_certificate) + .await + { Ok(backend) => backend, Err(e) => { println!("Failed to start postgres tls backend: {}", e); @@ -119,7 +122,7 @@ fn main() { Arc::new(postgres_tls_backend) } else { let postgres_plaintext_backend = - match PostgresPlaintextBackend::new(&endpoint, &db_name).await { + match PostgresPlaintextBackend::new(&endpoint, &default_db, &vss_db).await { Ok(backend) => backend, Err(e) => { println!("Failed to start postgres plaintext backend: {}", e); @@ -128,7 +131,7 @@ fn main() { }; Arc::new(postgres_plaintext_backend) }; - println!("Connected to PostgreSQL backend with DSN: {}/{}", endpoint, db_name); + println!("Connected to PostgreSQL backend with DSN: {}/{}", endpoint, vss_db); let rest_svc_listener = TcpListener::bind(&addr).await.expect("Failed to bind listening port"); diff --git a/rust/server/src/util/config.rs b/rust/server/src/util/config.rs index 9a79248..13e1ee2 100644 --- a/rust/server/src/util/config.rs +++ b/rust/server/src/util/config.rs @@ -19,7 +19,8 @@ pub(crate) struct PostgreSQLConfig { pub(crate) password: Option, // Optional in TOML, can be overridden by env pub(crate) host: String, pub(crate) port: u16, - pub(crate) database: String, + pub(crate) default_database: String, + pub(crate) vss_database: String, pub(crate) tls: Option, } diff --git a/rust/server/vss-server-config.toml b/rust/server/vss-server-config.toml index 76a4da4..e549f0d 100644 --- a/rust/server/vss-server-config.toml +++ b/rust/server/vss-server-config.toml @@ -8,6 +8,7 @@ username = "postgres" # Optional in TOML, can be overridden by env var `VSS_POS password = "postgres" # Optional in TOML, can be overridden by env var `VSS_POSTGRESQL_PASSWORD` host = "localhost" port = 5432 -database = "postgres" +default_database = "postgres" +vss_database = "vss" # tls = { } # Uncomment to make TLS connections to the postgres database using your machine's PKI # tls = { ca_file = "ca.pem" } # Uncomment to make TLS connections to the postgres database with an additional root certificate From 485243df8b0d218cf203dc4213b2889b77610503 Mon Sep 17 00:00:00 2001 From: Leo Nash Date: Sat, 6 Dec 2025 05:03:35 +0000 Subject: [PATCH 3/3] Add env var override to all postgresql configuration settings Also consolidate host and port settings into single socket address settings. --- rust/impls/src/postgres_store.rs | 5 +- rust/server/src/main.rs | 130 +++++++++--------------- rust/server/src/util/config.rs | 157 ++++++++++++++++++++++------- rust/server/vss-server-config.toml | 20 ++-- 4 files changed, 182 insertions(+), 130 deletions(-) diff --git a/rust/impls/src/postgres_store.rs b/rust/impls/src/postgres_store.rs index f27f929..494c89d 100644 --- a/rust/impls/src/postgres_store.rs +++ b/rust/impls/src/postgres_store.rs @@ -147,11 +147,10 @@ impl PostgresPlaintextBackend { impl PostgresTlsBackend { /// Constructs a [`PostgresTlsBackend`] using `postgres_endpoint` for PostgreSQL connection information. pub async fn new( - postgres_endpoint: &str, default_db: &str, vss_db: &str, - additional_certificate: Option, + postgres_endpoint: &str, default_db: &str, vss_db: &str, certificate: Option, ) -> Result { let mut builder = TlsConnector::builder(); - if let Some(cert) = additional_certificate { + if let Some(cert) = certificate { builder.add_root_certificate(cert); } let connector = builder.build().map_err(|e| { diff --git a/rust/server/src/main.rs b/rust/server/src/main.rs index cc17f76..63bca0d 100644 --- a/rust/server/src/main.rs +++ b/rust/server/src/main.rs @@ -9,8 +9,6 @@ #![deny(rustdoc::private_intra_doc_links)] #![deny(missing_docs)] -use std::net::SocketAddr; - use tokio::net::TcpListener; use tokio::signal::unix::SignalKind; @@ -20,15 +18,13 @@ use hyper_util::rt::TokioIo; use crate::vss_service::VssService; use api::auth::{Authorizer, NoopAuthorizer}; use api::kv_store::KvStore; -use auth_impls::{DecodingKey, JWTAuthorizer}; -use impls::postgres_store::{Certificate, PostgresPlaintextBackend, PostgresTlsBackend}; +use auth_impls::JWTAuthorizer; +use impls::postgres_store::{PostgresPlaintextBackend, PostgresTlsBackend}; use std::sync::Arc; mod util; mod vss_service; -use util::config::{Config, ServerConfig}; - fn main() { let args: Vec = std::env::args().collect(); if args.len() != 2 { @@ -36,21 +32,10 @@ fn main() { std::process::exit(1); } - let Config { server_config: ServerConfig { host, port, rsa_pub_file_path }, postgresql_config } = - match util::config::load_config(&args[1]) { - Ok(cfg) => cfg, - Err(e) => { - eprintln!("Failed to load configuration: {}", e); - std::process::exit(1); - }, - }; - let addr: SocketAddr = match format!("{}:{}", host, port).parse() { - Ok(addr) => addr, - Err(e) => { - eprintln!("Invalid host/port configuration: {}", e); - std::process::exit(1); - }, - }; + let config = util::config::load_configuration(&args[1]).unwrap_or_else(|e| { + eprintln!("Failed to load configuration: {}", e); + std::process::exit(-1); + }); let runtime = match tokio::runtime::Builder::new_multi_thread().enable_all().build() { Ok(runtime) => Arc::new(runtime), @@ -69,73 +54,58 @@ fn main() { }, }; - let authorizer: Arc = if let Some(file_path) = rsa_pub_file_path { - let rsa_pub_file = match std::fs::read(file_path) { - Ok(pem) => pem, - Err(e) => { - println!("Failed to read RSA public key file: {}", e); - std::process::exit(-1); - }, + let authorizer: Arc = + if let Some(rsa_public_key) = config.jwt_rsa_public_key { + let jwt_authorizer = JWTAuthorizer::new(rsa_public_key).await; + println!("Configured JWT authorizer"); + Arc::new(jwt_authorizer) + } else { + let noop_authorizer = NoopAuthorizer {}; + println!("No authentication method configured"); + Arc::new(noop_authorizer) }; - let rsa_public_key = match DecodingKey::from_rsa_pem(&rsa_pub_file) { - Ok(pem) => pem, - Err(e) => { - println!("Failed to parse RSA public key file: {}", e); - std::process::exit(-1); - }, - }; - Arc::new(JWTAuthorizer::new(rsa_public_key).await) - } else { - Arc::new(NoopAuthorizer {}) - }; - let endpoint = postgresql_config.to_postgresql_endpoint(); - let default_db = postgresql_config.default_database; - let vss_db = postgresql_config.vss_database; - let store: Arc = if let Some(tls_config) = postgresql_config.tls { - let addl_certificate = tls_config.ca_file.map(|file| { - let certificate = match std::fs::read(&file) { - Ok(cert) => cert, - Err(e) => { - println!("Failed to read certificate file: {}", e); - std::process::exit(-1); - }, - }; - match Certificate::from_pem(&certificate) { - Ok(cert) => cert, - Err(e) => { - println!("Failed to parse certificate file: {}", e); - std::process::exit(-1); - }, - } + let store: Arc = if let Some(certificate) = config.tls_config { + let postgres_tls_backend = PostgresTlsBackend::new( + &config.postgresql_prefix, + &config.default_db, + &config.vss_db, + certificate, + ) + .await + .unwrap_or_else(|e| { + println!("Failed to start postgres TLS backend: {}", e); + std::process::exit(-1); }); - let postgres_tls_backend = - match PostgresTlsBackend::new(&endpoint, &default_db, &vss_db, addl_certificate) - .await - { - Ok(backend) => backend, - Err(e) => { - println!("Failed to start postgres tls backend: {}", e); - std::process::exit(-1); - }, - }; + println!( + "Connected to PostgreSQL TLS backend with DSN: {}/{}", + config.postgresql_prefix, config.vss_db + ); Arc::new(postgres_tls_backend) } else { - let postgres_plaintext_backend = - match PostgresPlaintextBackend::new(&endpoint, &default_db, &vss_db).await { - Ok(backend) => backend, - Err(e) => { - println!("Failed to start postgres plaintext backend: {}", e); - std::process::exit(-1); - }, - }; + let postgres_plaintext_backend = PostgresPlaintextBackend::new( + &config.postgresql_prefix, + &config.default_db, + &config.vss_db, + ) + .await + .unwrap_or_else(|e| { + println!("Failed to start postgres plaintext backend: {}", e); + std::process::exit(-1); + }); + println!( + "Connected to PostgreSQL plaintext backend with DSN: {}/{}", + config.postgresql_prefix, config.vss_db + ); Arc::new(postgres_plaintext_backend) }; - println!("Connected to PostgreSQL backend with DSN: {}/{}", endpoint, vss_db); - let rest_svc_listener = - TcpListener::bind(&addr).await.expect("Failed to bind listening port"); - println!("Listening for incoming connections on {}", addr); + let rest_svc_listener = TcpListener::bind(&config.bind_address).await.unwrap_or_else(|e| { + println!("Failed to bind listening port: {}", e); + std::process::exit(-1); + }); + println!("Listening for incoming connections on {}", config.bind_address); + loop { tokio::select! { res = rest_svc_listener.accept() => { diff --git a/rust/server/src/util/config.rs b/rust/server/src/util/config.rs index 13e1ee2..dba387c 100644 --- a/rust/server/src/util/config.rs +++ b/rust/server/src/util/config.rs @@ -1,53 +1,138 @@ +use auth_impls::DecodingKey; +use impls::postgres_store::Certificate; use serde::Deserialize; +use std::net::SocketAddr; #[derive(Deserialize)] -pub(crate) struct Config { - pub(crate) server_config: ServerConfig, - pub(crate) postgresql_config: PostgreSQLConfig, +struct Config { + server_config: ServerConfig, + postgresql_config: Option, } #[derive(Deserialize)] -pub(crate) struct ServerConfig { - pub(crate) host: String, - pub(crate) port: u16, - pub(crate) rsa_pub_file_path: Option, +struct ServerConfig { + bind_address: SocketAddr, + rsa_pub_file_path: Option, } +// All fields can be overriden by their corresponding environment variables #[derive(Deserialize)] -pub(crate) struct PostgreSQLConfig { - pub(crate) username: Option, // Optional in TOML, can be overridden by env - pub(crate) password: Option, // Optional in TOML, can be overridden by env - pub(crate) host: String, - pub(crate) port: u16, - pub(crate) default_database: String, - pub(crate) vss_database: String, - pub(crate) tls: Option, +struct PostgreSQLConfig { + // VSS_POSTGRESQL_USERNAME + username: Option, + // VSS_POSTGRESQL_PASSWORD + password: Option, + // VSS_POSTGRESQL_ADDRESS + address: Option, + // VSS_POSTGRESQL_DEFAULT_DATABASE + default_database: Option, + // VSS_POSTGRESQL_VSS_DATABASE + vss_database: Option, + // Set VSS_POSTGRESQL_TLS=1 for tls: Some(TlsConfig { crt_file_path: None }) + // Set VSS_POSTGRESQL_CRT_FILE_PATH=ca.crt for tls: Some(TlsConfig { crt_file_path: String::from("ca.crt") }) + tls: Option, } #[derive(Deserialize)] -pub(crate) struct TlsConfig { - pub(crate) ca_file: Option, +struct TlsConfig { + crt_file_path: Option, } -impl PostgreSQLConfig { - pub(crate) fn to_postgresql_endpoint(&self) -> String { - let username_env = std::env::var("VSS_POSTGRESQL_USERNAME"); - let username = username_env.as_ref() - .ok() - .or_else(|| self.username.as_ref()) - .expect("PostgreSQL database username must be provided in config or env var VSS_POSTGRESQL_USERNAME must be set."); - let password_env = std::env::var("VSS_POSTGRESQL_PASSWORD"); - let password = password_env.as_ref() - .ok() - .or_else(|| self.password.as_ref()) - .expect("PostgreSQL database password must be provided in config or env var VSS_POSTGRESQL_PASSWORD must be set."); - - format!("postgresql://{}:{}@{}:{}", username, password, self.host, self.port) - } +pub(crate) struct Configuration { + pub(crate) bind_address: SocketAddr, + pub(crate) jwt_rsa_public_key: Option, + pub(crate) postgresql_prefix: String, + pub(crate) default_db: String, + pub(crate) vss_db: String, + // The Some(None) variant maps to a TLS connection with no additional certificates + pub(crate) tls_config: Option>, } -pub(crate) fn load_config(config_path: &str) -> Result> { - let config_str = std::fs::read_to_string(config_path)?; - let config: Config = toml::from_str(&config_str)?; - Ok(config) +fn load_postgresql_prefix(config: Option<&PostgreSQLConfig>) -> Result { + let username_env = std::env::var("VSS_POSTGRESQL_USERNAME").ok(); + let username = username_env.as_ref() + .or(config.and_then(|c| c.username.as_ref())) + .ok_or("PostgreSQL database username must be provided in config or env var VSS_POSTGRESQL_USERNAME must be set.")?; + + let password_env = std::env::var("VSS_POSTGRESQL_PASSWORD").ok(); + let password = password_env.as_ref() + .or(config.and_then(|c| c.password.as_ref())) + .ok_or("PostgreSQL database password must be provided in config or env var VSS_POSTGRESQL_PASSWORD must be set.")?; + + let address_env: Option = + if let Some(addr) = std::env::var("VSS_POSTGRESQL_ADDRESS").ok() { + let socket_addr = addr + .parse() + .map_err(|e| format!("Unable to parse postgresql address env var: {}", e))?; + Some(socket_addr) + } else { + None + }; + let address = address_env.as_ref() + .or(config.and_then(|c| c.address.as_ref())) + .ok_or("PostgreSQL service address must be provided in config or env var VSS_POSTGRESQL_ADDRESS must be set.")?; + + Ok(format!("postgresql://{}:{}@{}", username, password, address)) +} + +pub(crate) fn load_configuration(config_file_path: &str) -> Result { + let config_file = std::fs::read_to_string(config_file_path) + .map_err(|e| format!("Failed to read configuration file: {}", e))?; + let Config { + server_config: ServerConfig { bind_address, rsa_pub_file_path }, + postgresql_config, + } = toml::from_str(&config_file) + .map_err(|e| format!("Failed to parse configuration file: {}", e))?; + + let jwt_rsa_public_key = if let Some(file_path) = rsa_pub_file_path { + let rsa_pub_file = std::fs::read(file_path) + .map_err(|e| format!("Failed to read RSA public key file: {}", e))?; + let rsa_public_key = DecodingKey::from_rsa_pem(&rsa_pub_file) + .map_err(|e| format!("Failed to parse RSA public key file: {}", e))?; + Some(rsa_public_key) + } else { + None + }; + + let postgresql_prefix = load_postgresql_prefix(postgresql_config.as_ref())?; + + let default_db_env = std::env::var("VSS_POSTGRESQL_DEFAULT_DATABASE").ok(); + let default_db = default_db_env + .or(postgresql_config.as_ref().and_then(|c| c.default_database.clone())) + .ok_or(String::from("PostgreSQL default database name must be provided in config or env var VSS_POSTGRESQL_DEFAULT_DATABASE must be set."))?; + + let vss_db_env = std::env::var("VSS_POSTGRESQL_VSS_DATABASE").ok(); + let vss_db = vss_db_env + .or(postgresql_config.as_ref().and_then(|c| c.vss_database.clone())) + .ok_or(String::from("PostgreSQL vss database name must be provided in config or env var VSS_POSTGRESQL_VSS_DATABASE must be set."))?; + + let crt_file_path_env = std::env::var("VSS_POSTGRESQL_CRT_FILE_PATH").ok(); + let crt_file_path = crt_file_path_env.or(postgresql_config + .as_ref() + .and_then(|c| c.tls.as_ref()) + .and_then(|tls| tls.crt_file_path.clone())); + let certificate = if let Some(file_path) = crt_file_path { + let crt_file = std::fs::read(&file_path) + .map_err(|e| format!("Failed to read certificate file: {}", e))?; + let certificate = Certificate::from_pem(&crt_file) + .map_err(|e| format!("Failed to parse certificate file: {}", e))?; + Some(certificate) + } else { + None + }; + + let tls_config_env = std::env::var("VSS_POSTGRESQL_TLS").ok(); + let tls_config = (certificate.is_some() + || tls_config_env.is_some() + || postgresql_config.and_then(|c| c.tls).is_some()) + .then_some(certificate); + + Ok(Configuration { + bind_address, + jwt_rsa_public_key, + postgresql_prefix, + default_db, + vss_db, + tls_config, + }) } diff --git a/rust/server/vss-server-config.toml b/rust/server/vss-server-config.toml index e549f0d..8f9820b 100644 --- a/rust/server/vss-server-config.toml +++ b/rust/server/vss-server-config.toml @@ -1,14 +1,12 @@ [server_config] -host = "127.0.0.1" -port = 8080 -# rsa_pub_file_path = "rsa_public_key.pem" # Uncomment to verify JWT tokens in the HTTP Authorization header +bind_address = "127.0.0.1:8080" +# rsa_pub_file_path = "rsa_public_key.pem" # Uncomment to verify JWT tokens in the HTTP Authorization header [postgresql_config] -username = "postgres" # Optional in TOML, can be overridden by env var `VSS_POSTGRESQL_USERNAME` -password = "postgres" # Optional in TOML, can be overridden by env var `VSS_POSTGRESQL_PASSWORD` -host = "localhost" -port = 5432 -default_database = "postgres" -vss_database = "vss" -# tls = { } # Uncomment to make TLS connections to the postgres database using your machine's PKI -# tls = { ca_file = "ca.pem" } # Uncomment to make TLS connections to the postgres database with an additional root certificate +username = "postgres" # Optional in TOML, can be overridden by env var `VSS_POSTGRESQL_USERNAME` +password = "postgres" # Optional in TOML, can be overridden by env var `VSS_POSTGRESQL_PASSWORD` +address = "127.0.0.1:5432" # Optional in TOML, can be overridden by env var `VSS_POSTGRESQL_ADDRESS` +default_database = "postgres" # Optional in TOML, can be overridden by env var `VSS_POSTGRESQL_DEFAULT_DATABASE` +vss_database = "vss" # Optional in TOML, can be overridden by env var `VSS_POSTGRESQL_VSS_DATABASE` +# tls = { } # Uncomment, or set env var `VSS_POSTGRESQL_TLS=1` to make TLS connections to the postgres database using your machine's PKI +# tls = { crt_file_path = "ca.crt" } # Uncomment, or set env var `VSS_POSTGRESQL_CRT_FILE_PATH=ca.crt` to make TLS connections to the postgres database with an additional root certificate