Completely reworked request handling to be able to serve SCGI

Multi ~~track~~ protocol ~~drifting~~ abstraction!!
This commit is contained in:
Emii Tatsuo 2020-12-01 02:31:08 -05:00
parent 86ed240761
commit 244fd25112
Signed by: Emi
GPG key ID: 68FAB2E2E6DFC98B
11 changed files with 371 additions and 94 deletions

View file

@ -12,24 +12,26 @@ readme = "README.md"
include = ["src/**", "Cargo.*", "CHANGELOG.md", "LICENSE*", "README.md"] include = ["src/**", "Cargo.*", "CHANGELOG.md", "LICENSE*", "README.md"]
[features] [features]
default = ["certgen"] default = ["scgi_srv"]
user_management = ["sled", "bincode", "serde/derive", "crc32fast", "lazy_static"] user_management = ["sled", "bincode", "serde/derive", "crc32fast", "lazy_static"]
user_management_advanced = ["rust-argon2", "ring", "user_management"] user_management_advanced = ["rust-argon2", "ring", "user_management"]
user_management_routes = ["user_management"] user_management_routes = ["user_management"]
serve_dir = ["mime_guess", "tokio/fs"] serve_dir = ["mime_guess", "tokio/fs"]
ratelimiting = ["dashmap"] ratelimiting = ["dashmap"]
certgen = ["rcgen"] certgen = ["rcgen", "gemini_srv"]
gemini_srv = ["tokio-rustls", "webpki", "rustls"]
scgi_srv = []
[dependencies] [dependencies]
anyhow = "1.0.33" anyhow = "1.0.33"
rustls = { version = "0.18.1", features = ["dangerous_configuration"] }
tokio-rustls = "0.20.0"
tokio = { version = "0.3.1", features = ["io-util","net","time", "rt"] } tokio = { version = "0.3.1", features = ["io-util","net","time", "rt"] }
uriparse = "0.6.3" uriparse = "0.6.3"
percent-encoding = "2.1.0" percent-encoding = "2.1.0"
log = "0.4.11" log = "0.4.11"
webpki = "0.21.0"
lazy_static = { version = "1.4.0", optional = true } lazy_static = { version = "1.4.0", optional = true }
rustls = { version = "0.18.1", features = ["dangerous_configuration"], optional = true}
webpki = { version = "0.21.0", optional = true}
tokio-rustls = { version = "0.20.0", optional = true}
mime_guess = { version = "2.0.3", optional = true } mime_guess = { version = "2.0.3", optional = true }
dashmap = { version = "3.11.10", optional = true } dashmap = { version = "3.11.10", optional = true }
sled = { version = "0.34.6", optional = true } sled = { version = "0.34.6", optional = true }
@ -39,6 +41,7 @@ rust-argon2 = { version = "0.8.2", optional = true }
crc32fast = { version = "1.2.1", optional = true } crc32fast = { version = "1.2.1", optional = true }
ring = { version = "0.16.15", optional = true } ring = { version = "0.16.15", optional = true }
rcgen = { version = "0.8.5", optional = true } rcgen = { version = "0.8.5", optional = true }
squeegee = { git = "https://gitlab.com/Alch_Emi/squeegee.git", branch = "main", optional = true }
[dev-dependencies] [dev-dependencies]
env_logger = "0.8.1" env_logger = "0.8.1"

View file

@ -1,7 +1,7 @@
use anyhow::*; use anyhow::*;
use log::LevelFilter; use log::LevelFilter;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use kochab::{Certificate, GEMINI_PORT, Request, Response, Server}; use kochab::{Certificate, Request, Response, Server};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
@ -16,9 +16,9 @@ async fn main() -> Result<()> {
let users = Arc::<RwLock::<HashMap<CertBytes, String>>>::default(); let users = Arc::<RwLock::<HashMap<CertBytes, String>>>::default();
Server::bind(("0.0.0.0", GEMINI_PORT)) Server::new()
.add_route("/", move|req| handle_request(users.clone(), req)) .add_route("/", move|req| handle_request(users.clone(), req))
.serve() .serve_unix("kochab.sock")
.await .await
} }

View file

@ -1,6 +1,6 @@
use anyhow::*; use anyhow::*;
use log::LevelFilter; use log::LevelFilter;
use kochab::{Server, Response, GEMINI_PORT, Document}; use kochab::{Server, Response, Document};
use kochab::document::HeadingLevel::*; use kochab::document::HeadingLevel::*;
#[tokio::main] #[tokio::main]
@ -38,8 +38,8 @@ async fn main() -> Result<()> {
)) ))
.into(); .into();
Server::bind(("localhost", GEMINI_PORT)) Server::new()
.add_route("/", response) .add_route("/", response)
.serve() .serve_unix("kochab.sock")
.await .await
} }

View file

@ -2,7 +2,7 @@ use std::time::Duration;
use anyhow::*; use anyhow::*;
use log::LevelFilter; use log::LevelFilter;
use kochab::{Server, Request, Response, GEMINI_PORT, Document}; use kochab::{Server, Request, Response, Document};
#[tokio::main] #[tokio::main]
async fn main() -> Result<()> { async fn main() -> Result<()> {
@ -10,10 +10,10 @@ async fn main() -> Result<()> {
.filter_module("kochab", LevelFilter::Debug) .filter_module("kochab", LevelFilter::Debug)
.init(); .init();
Server::bind(("localhost", GEMINI_PORT)) Server::new()
.add_route("/", handle_request) .add_route("/", handle_request)
.ratelimit("/limit", 2, Duration::from_secs(60)) .ratelimit("/limit", 2, Duration::from_secs(60))
.serve() .serve_unix("kochab.sock")
.await .await
} }

View file

@ -1,6 +1,6 @@
use anyhow::*; use anyhow::*;
use log::LevelFilter; use log::LevelFilter;
use kochab::{Document, document::HeadingLevel, Request, Response, GEMINI_PORT}; use kochab::{Document, document::HeadingLevel, Request, Response};
#[tokio::main] #[tokio::main]
async fn main() -> Result<()> { async fn main() -> Result<()> {
@ -8,11 +8,11 @@ async fn main() -> Result<()> {
.filter_module("kochab", LevelFilter::Debug) .filter_module("kochab", LevelFilter::Debug)
.init(); .init();
kochab::Server::bind(("localhost", GEMINI_PORT)) kochab::Server::new()
.add_route("/", handle_base) .add_route("/", handle_base)
.add_route("/route", handle_short) .add_route("/route", handle_short)
.add_route("/route/long", handle_long) .add_route("/route/long", handle_long)
.serve() .serve_unix("kochab.sock")
.await .await
} }

View file

@ -2,7 +2,7 @@ use std::path::PathBuf;
use anyhow::*; use anyhow::*;
use log::LevelFilter; use log::LevelFilter;
use kochab::{Server, GEMINI_PORT}; use kochab::Server;
#[tokio::main] #[tokio::main]
async fn main() -> Result<()> { async fn main() -> Result<()> {
@ -10,9 +10,9 @@ async fn main() -> Result<()> {
.filter_module("kochab", LevelFilter::Debug) .filter_module("kochab", LevelFilter::Debug)
.init(); .init();
Server::bind(("localhost", GEMINI_PORT)) Server::new()
.add_route("/", PathBuf::from("public")) // Serve directory listings & file contents .add_route("/", PathBuf::from("public")) // Serve directory listings & file contents
.add_route("/about", PathBuf::from("README.md")) // Serve a single file .add_route("/about", PathBuf::from("README.md")) // Serve a single file
.serve() .serve_unix("kochab.sock")
.await .await
} }

View file

@ -1,7 +1,6 @@
use anyhow::*; use anyhow::*;
use log::LevelFilter; use log::LevelFilter;
use kochab::{ use kochab::{
GEMINI_PORT,
Document, Document,
Request, Request,
Response, Response,
@ -26,7 +25,7 @@ async fn main() -> Result<()> {
.filter_module("kochab", LevelFilter::Debug) .filter_module("kochab", LevelFilter::Debug)
.init(); .init();
Server::bind(("0.0.0.0", GEMINI_PORT)) Server::new()
// Add our main routes // Add our main routes
.add_authenticated_route("/", handle_main) .add_authenticated_route("/", handle_main)
@ -36,7 +35,7 @@ async fn main() -> Result<()> {
.add_um_routes::<String>() .add_um_routes::<String>()
// Start the server // Start the server
.serve() .serve_unix("kochab.sock")
.await .await
} }

20
molly-brown.conf Normal file
View file

@ -0,0 +1,20 @@
# This is a super simple molly brown config file for the purpose of testing SCGI
# applications. Although you are welcome to use this as a base for an actual webserver,
# please find somewhere better for your production sockets.
#
# You can get a copy of molly brown and more information about configuring it from:
# https://tildegit.org/solderpunk/molly-brown
#
# Once installed, run the test server using the command
# molly-brown -c molly-brown.conf
Port = 1965
Hostname = "localhost"
CertPath = "cert/cert.pem"
KeyPath = "cert/key.pem"
AccessLog = "/dev/stdout"
ErrorLog = "/dev/stderr"
[SCGIPaths]
"/" = "kochab.sock"

View file

@ -1,26 +1,43 @@
#[macro_use] extern crate log; #[macro_use] extern crate log;
use std::{ use std::{
convert::TryFrom,
io::BufReader,
sync::Arc, sync::Arc,
path::PathBuf,
time::Duration, time::Duration,
}; };
#[cfg(feature = "ratelimiting")] #[cfg(feature = "gemini_srv")]
use std::net::IpAddr; use std::{
convert::TryFrom,
path::PathBuf,
};
#[cfg(feature = "scgi_srv")]
use std::{
collections::HashMap,
net::SocketAddr,
str::FromStr,
};
use tokio::{ use tokio::{
io,
io::BufReader,
net::TcpListener,
net::ToSocketAddrs,
prelude::*, prelude::*,
io::{self, BufStream}, };
net::{TcpStream, ToSocketAddrs}, #[cfg(feature = "scgi_srv")]
use tokio::net::UnixListener;
#[cfg(feature = "gemini_srv")]
use tokio::{
time::timeout, time::timeout,
net::TcpStream,
}; };
#[cfg(feature = "ratelimiting")] #[cfg(feature = "ratelimiting")]
use tokio::time::interval; use tokio::time::interval;
use tokio::net::TcpListener; #[cfg(feature = "gemini_srv")]
use rustls::ClientCertVerifier; use rustls::ClientCertVerifier;
#[cfg(feature = "gemini_srv")]
use rustls::internal::msgs::handshake::DigitallySignedStruct; use rustls::internal::msgs::handshake::DigitallySignedStruct;
#[cfg(feature = "gemini_srv")]
use tokio_rustls::{rustls, TlsAcceptor}; use tokio_rustls::{rustls, TlsAcceptor};
#[cfg(feature = "gemini_srv")]
use rustls::*; use rustls::*;
use anyhow::*; use anyhow::*;
use crate::util::opt_timeout; use crate::util::opt_timeout;
@ -54,6 +71,7 @@ use handling::Handler;
#[derive(Clone)] #[derive(Clone)]
struct ServerInner { struct ServerInner {
#[cfg(feature = "gemini_srv")]
tls_acceptor: TlsAcceptor, tls_acceptor: TlsAcceptor,
routes: Arc<RoutingNode<Handler>>, routes: Arc<RoutingNode<Handler>>,
timeout: Duration, timeout: Duration,
@ -65,7 +83,7 @@ struct ServerInner {
} }
impl ServerInner { impl ServerInner {
async fn serve(self, listener: TcpListener) -> Result<()> { async fn serve_ip(self, listener: TcpListener) -> Result<()> {
#[cfg(feature = "ratelimiting")] #[cfg(feature = "ratelimiting")]
tokio::spawn(prune_ratelimit_log(self.rate_limits.clone())); tokio::spawn(prune_ratelimit_log(self.rate_limits.clone()));
@ -82,48 +100,110 @@ impl ServerInner {
} }
} }
async fn serve_client(self, stream: TcpStream) -> Result<()> { #[cfg(feature = "scgi_srv")]
#[cfg(feature="ratelimiting")] // Yeah it's code duplication, but I can't find a way around it, so this is what we're
let peer_addr = stream.peer_addr()?.ip(); // getting for now
async fn serve_unix(self, listener: UnixListener) -> Result<()> {
#[cfg(feature = "ratelimiting")]
tokio::spawn(prune_ratelimit_log(self.rate_limits.clone()));
loop {
let (stream, _addr) = listener.accept().await
.context("Failed to accept client")?;
let this = self.clone();
tokio::spawn(async move {
if let Err(err) = this.serve_client(stream).await {
error!("{:?}", err);
}
});
}
}
async fn serve_client(
&self,
#[cfg(feature = "gemini_srv")]
stream: TcpStream,
#[cfg(feature = "scgi_srv")]
stream: impl AsyncWrite + AsyncRead + Unpin,
) -> Result<()> {
let fut_accept_request = async { let fut_accept_request = async {
#[cfg(feature = "gemini_srv")]
let stream = self.tls_acceptor.accept(stream).await let stream = self.tls_acceptor.accept(stream).await
.context("Failed to establish TLS session")?; .context("Failed to establish TLS session")?;
let mut stream = BufStream::new(stream); let mut stream = BufReader::new(stream);
#[cfg(feature="user_management")]
let request = self.receive_request(&mut stream).await let request = self.receive_request(&mut stream).await
.context("Failed to receive request")?; .context("Failed to receive request")?;
#[cfg(not(feature="user_management"))]
let request = Self::receive_request(&mut stream).await
.context("Failed to receive request")?;
Result::<_, anyhow::Error>::Ok((request, stream)) Result::<_, anyhow::Error>::Ok((request, stream))
}; };
// Use a timeout for interacting with the client
let fut_accept_request = timeout(self.timeout, fut_accept_request);
let (mut request, mut stream) = fut_accept_request.await
.context("Client timed out while waiting for response")??;
#[cfg(feature="ratelimiting")] // Wait for the request to be parsed
let (mut request, mut stream) = {
#[cfg(feature = "gemini_srv")] {
// Use a timeout for interacting with the client
let fut_accept_request = timeout(self.timeout, fut_accept_request);
fut_accept_request.await
.context("Client timed out while waiting for response")??
}
#[cfg(feature = "scgi_srv")]
fut_accept_request.await?
};
// Determine the remote client's IP address for logging and ratelimiting
let peer_addr = {
#[cfg(feature = "gemini_srv")] {
stream.get_ref()
.get_ref()
.0
.peer_addr()?
.ip()
}
#[cfg(feature = "scgi_srv")] {
SocketAddr::from_str(
request.headers()
.get("REMOTE_ADDR")
.ok_or(ParseError::Malformed("REMOTE_ADDR header not received"))?
.as_str()
).context("Received malformed IP address from upstream")?
.ip()
}
};
#[cfg(feature = "ratelimiting")]
// Perform ratelimiting checks
if let Some(resp) = self.check_rate_limits(peer_addr, &request) { if let Some(resp) = self.check_rate_limits(peer_addr, &request) {
// Log warning
warn!(
"Client from {} requesting {} was turned away by ratelimiting",
peer_addr,
request.uri()
);
// Send error response
self.send_response(resp, &mut stream).await self.send_response(resp, &mut stream).await
.context("Failed to send response")?; .context("Failed to send response")?;
// Exit
return Ok(()) return Ok(())
} }
debug!("Client requested: {}", request.uri()); info!("{} requested: {}", peer_addr, request.uri());
// Identify the client certificate from the tls stream. This is the first // Identify the client certificate from the tls stream. This is the first
// certificate in the certificate chain. // certificate in the certificate chain.
let client_cert = stream.get_ref() #[cfg(feature = "gemini_srv")] { // This is done earlier for `scgi_srv`
.get_ref() let client_cert = stream.get_ref()
.1 .get_ref()
.get_peer_certificates() .1
.and_then(|mut v| if v.is_empty() {None} else {Some(v.remove(0))}); .get_peer_certificates()
.and_then(|mut v| if v.is_empty() {None} else {Some(v.remove(0))});
request.set_cert(client_cert); request.set_cert(client_cert);
}
let response = if let Some((trailing, handler)) = self.routes.match_request(&request) { let response = if let Some((trailing, handler)) = self.routes.match_request(&request) {
request.set_trailing(trailing); request.set_trailing(trailing);
@ -197,13 +277,13 @@ impl ServerInner {
None None
} }
#[cfg(feature = "gemini_srv")]
async fn receive_request( async fn receive_request(
#[cfg(feature="user_management")]
&self, &self,
stream: &mut (impl AsyncBufRead + Unpin) stream: &mut (impl AsyncBufRead + Unpin),
) -> Result<Request> { ) -> Result<Request> {
let limit = REQUEST_URI_MAX_LEN + "\r\n".len(); const HEADER_LIMIT: usize = REQUEST_URI_MAX_LEN + "\r\n".len();
let mut stream = stream.take(limit as u64); let mut stream = stream.take(HEADER_LIMIT as u64);
let mut uri = Vec::new(); let mut uri = Vec::new();
stream.read_until(b'\n', &mut uri).await?; stream.read_until(b'\n', &mut uri).await?;
@ -223,23 +303,123 @@ impl ServerInner {
let uri = URIReference::try_from(&*uri) let uri = URIReference::try_from(&*uri)
.context("Request URI is invalid")? .context("Request URI is invalid")?
.into_owned(); .into_owned();
let request = Request::from_uri(
Request::new(
uri, uri,
#[cfg(feature="user_management")] #[cfg(feature="user_management")]
self.manager.clone(), self.manager.clone(),
) .context("Failed to create request from URI")?; ).context("Failed to create request from URI")
}
Ok(request) #[cfg(feature = "scgi_srv")]
async fn receive_request(
&self,
stream: &mut (impl AsyncBufRead + Unpin),
) -> Result<Request> {
let mut buff = Vec::with_capacity(4);
#[allow(clippy::char_lit_as_u8)]
// Read the length of the header netstring (e.g. "120:")
stream.read_until(':' as u8, &mut buff).await?;
buff.pop(); // Remove the trailing ':'
let len = std::str::from_utf8(&*buff)
.ok()
.and_then(|s| usize::from_str(s).ok())
.ok_or(ParseError::Malformed("netstring length"))?;
// Read in the headers
buff.clear();
buff.resize(len + 1, 0);
stream.read_exact(buff.as_mut()).await?;
buff.truncate(len - 1); // Remove the final \x00,
// Parse the headers
let (maybe_trailing, headers) = buff.split(|b| *b == 0) // Headers are null delimiited
.map(|bytes| // Convert to an &str
std::str::from_utf8(bytes)
.map_err(|_| ParseError::Malformed("scgi headers"))
.map(str::trim)
)
.try_fold( // Turn the array of [header, value, header, ...] into a map
(Option::<&str>::None, HashMap::<String, String>::with_capacity(16)),
|(last_header, mut headers), s| {
s.map(|text| {
match last_header {
None => (Some(text), headers),
Some(header) => {
headers.insert(header.to_string(), text.to_string());
(None, headers)
}
}
})
}
)?;
// If there's not the same number of headers as values, that's a problem
if maybe_trailing.is_some() {
bail!(ParseError::Malformed("trailing header"));
}
// Check the content length info
let cont_len_val = headers.get("CONTENT_LENGTH")
.ok_or(ParseError::Malformed("No content length header!"))?;
let cont_len = usize::from_str(cont_len_val)
.map_err(|_| ParseError::Malformed("Malformed content length"))?;
if cont_len > 0 {
bail!(ParseError::Malformed("Gemini SCGI requests should not have a body"));
}
// Spec requires setting an SCGI header to one
if *headers.get("SCGI").ok_or(ParseError::Malformed("No SCGI header"))? != "1" {
bail!(ParseError::Malformed("SCGI header not set to \"1\""));
}
trace!("Headers received: {:?}", headers);
Ok(Request::new(headers)?)
} }
} }
pub struct Server<A> { #[derive(Debug)]
addr: A, #[cfg(feature = "scgi_srv")]
cert_path: PathBuf, enum ParseError {
key_path: PathBuf, IO(io::Error),
Malformed(&'static str),
}
#[cfg(feature = "scgi_srv")]
impl From<io::Error> for ParseError {
fn from(e: io::Error) -> Self {
Self::IO(e)
}
}
#[cfg(feature = "scgi_srv")]
impl std::fmt::Display for ParseError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Self::IO(e) => write!(f, "IO Error while parsing and responding SCGI: {}", e),
Self::Malformed(e) => write!(f, "SCGI request malformed at {}", e),
}
}
}
#[cfg(feature = "scgi_srv")]
impl std::error::Error for ParseError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
if let Self::IO(e) = self { Some(e) } else { None }
}
}
pub struct Server {
timeout: Duration, timeout: Duration,
complex_body_timeout_override: Option<Duration>, complex_body_timeout_override: Option<Duration>,
routes: RoutingNode<Handler>, routes: RoutingNode<Handler>,
#[cfg(feature = "gemini_srv")]
cert_path: PathBuf,
#[cfg(feature = "gemini_srv")]
key_path: PathBuf,
#[cfg(feature="ratelimiting")] #[cfg(feature="ratelimiting")]
rate_limits: RoutingNode<RateLimiter<IpAddr>>, rate_limits: RoutingNode<RateLimiter<IpAddr>>,
#[cfg(feature="user_management")] #[cfg(feature="user_management")]
@ -250,15 +430,16 @@ pub struct Server<A> {
certgen_mode: CertGenMode, certgen_mode: CertGenMode,
} }
impl<A: ToSocketAddrs> Server<A> { impl Server {
pub fn bind(addr: A) -> Self { pub fn new() -> Self {
Self { Self {
addr,
timeout: Duration::from_secs(1), timeout: Duration::from_secs(1),
complex_body_timeout_override: Some(Duration::from_secs(30)), complex_body_timeout_override: Some(Duration::from_secs(30)),
cert_path: PathBuf::from("cert/cert.pem"),
key_path: PathBuf::from("cert/key.pem"),
routes: RoutingNode::default(), routes: RoutingNode::default(),
#[cfg(feature = "gemini_srv")]
cert_path: PathBuf::from("cert/cert.pem"),
#[cfg(feature = "gemini_srv")]
key_path: PathBuf::from("cert/key.pem"),
#[cfg(feature="ratelimiting")] #[cfg(feature="ratelimiting")]
rate_limits: RoutingNode::default(), rate_limits: RoutingNode::default(),
#[cfg(feature="user_management")] #[cfg(feature="user_management")]
@ -309,6 +490,7 @@ impl<A: ToSocketAddrs> Server<A> {
self self
} }
#[cfg(feature = "gemini_srv")]
/// Sets the directory that kochab should look for TLS certs and keys into /// Sets the directory that kochab should look for TLS certs and keys into
/// ///
/// Northstar will look for files called `cert.pem` and `key.pem` in the provided /// Northstar will look for files called `cert.pem` and `key.pem` in the provided
@ -324,6 +506,7 @@ impl<A: ToSocketAddrs> Server<A> {
.set_key(dir.join("key.pem")) .set_key(dir.join("key.pem"))
} }
#[cfg(feature = "gemini_srv")]
/// Set the path to the TLS certificate kochab will use /// Set the path to the TLS certificate kochab will use
/// ///
/// This defaults to `cert/cert.pem`. /// This defaults to `cert/cert.pem`.
@ -335,6 +518,7 @@ impl<A: ToSocketAddrs> Server<A> {
self self
} }
#[cfg(feature = "gemini_srv")]
/// Set the path to the ertificate key kochab will use /// Set the path to the ertificate key kochab will use
/// ///
/// This defaults to `cert/key.pem`. /// This defaults to `cert/key.pem`.
@ -436,7 +620,8 @@ impl<A: ToSocketAddrs> Server<A> {
self self
} }
pub async fn serve(mut self) -> Result<()> { fn build(mut self) -> Result<ServerInner> {
#[cfg(feature = "gemini_srv")]
let config = tls_config( let config = tls_config(
&self.cert_path, &self.cert_path,
&self.key_path, &self.key_path,
@ -444,28 +629,51 @@ impl<A: ToSocketAddrs> Server<A> {
self.certgen_mode self.certgen_mode
).context("Failed to create TLS config")?; ).context("Failed to create TLS config")?;
let listener = TcpListener::bind(self.addr).await
.context("Failed to create socket")?;
self.routes.shrink(); self.routes.shrink();
#[cfg(feature="user_management")] #[cfg(feature="user_management")]
let data_dir = self.data_dir; let data_dir = self.data_dir;
let server = ServerInner { Ok(ServerInner {
tls_acceptor: TlsAcceptor::from(config),
routes: Arc::new(self.routes), routes: Arc::new(self.routes),
timeout: self.timeout, timeout: self.timeout,
complex_timeout: self.complex_body_timeout_override, complex_timeout: self.complex_body_timeout_override,
#[cfg(feature = "gemini_srv")]
tls_acceptor: TlsAcceptor::from(config),
#[cfg(feature="ratelimiting")] #[cfg(feature="ratelimiting")]
rate_limits: Arc::new(self.rate_limits), rate_limits: Arc::new(self.rate_limits),
#[cfg(feature="user_management")] #[cfg(feature="user_management")]
manager: UserManager::new( manager: UserManager::new(
self.database.unwrap_or_else(move|| sled::open(data_dir).unwrap()) self.database.unwrap_or_else(move|| sled::open(data_dir).unwrap())
)?, )?,
}; })
}
server.serve(listener).await /// Start serving requests on a given bound address & port
///
/// `addr` can be anything `tokio` can parse, including just a string like
/// "localhost:1965"
pub async fn serve_ip(self, addr: impl ToSocketAddrs) -> Result<()> {
let server = self.build()?;
let socket = TcpListener::bind(addr).await?;
server.serve_ip(socket).await
}
#[cfg(feature = "scgi_srv")]
/// Start serving requests on a given unix socket
///
/// Requires an address in the form of a path to bind to. This is only available when
/// in `scgi_srv` mode.
pub async fn serve_unix(self, addr: impl AsRef<std::path::Path>) -> Result<()> {
let server = self.build()?;
let socket = UnixListener::bind(addr)?;
server.serve_unix(socket).await
}
}
impl Default for Server {
fn default() -> Self {
Self::new()
} }
} }
@ -512,6 +720,7 @@ async fn prune_ratelimit_log(rate_limits: Arc<RoutingNode<RateLimiter<IpAddr>>>)
} }
} }
#[cfg(feature = "gemini_srv")]
fn tls_config( fn tls_config(
cert_path: &PathBuf, cert_path: &PathBuf,
key_path: &PathBuf, key_path: &PathBuf,
@ -535,20 +744,22 @@ fn tls_config(
Ok(config.into()) Ok(config.into())
} }
#[cfg(feature = "gemini_srv")]
fn load_cert_chain(cert_path: &PathBuf) -> Result<Vec<Certificate>> { fn load_cert_chain(cert_path: &PathBuf) -> Result<Vec<Certificate>> {
let certs = std::fs::File::open(cert_path) let certs = std::fs::File::open(cert_path)
.with_context(|| format!("Failed to open `{:?}`", cert_path))?; .with_context(|| format!("Failed to open `{:?}`", cert_path))?;
let mut certs = BufReader::new(certs); let mut certs = std::io::BufReader::new(certs);
let certs = rustls::internal::pemfile::certs(&mut certs) let certs = rustls::internal::pemfile::certs(&mut certs)
.map_err(|_| anyhow!("failed to load certs `{:?}`", cert_path))?; .map_err(|_| anyhow!("failed to load certs `{:?}`", cert_path))?;
Ok(certs) Ok(certs)
} }
#[cfg(feature = "gemini_srv")]
fn load_key(key_path: &PathBuf) -> Result<PrivateKey> { fn load_key(key_path: &PathBuf) -> Result<PrivateKey> {
let keys = std::fs::File::open(key_path) let keys = std::fs::File::open(key_path)
.with_context(|| format!("Failed to open `{:?}`", key_path))?; .with_context(|| format!("Failed to open `{:?}`", key_path))?;
let mut keys = BufReader::new(keys); let mut keys = std::io::BufReader::new(keys);
let mut keys = rustls::internal::pemfile::pkcs8_private_keys(&mut keys) let mut keys = rustls::internal::pemfile::pkcs8_private_keys(&mut keys)
.map_err(|_| anyhow!("failed to load key `{:?}`", key_path))?; .map_err(|_| anyhow!("failed to load key `{:?}`", key_path))?;
@ -559,11 +770,14 @@ fn load_key(key_path: &PathBuf) -> Result<PrivateKey> {
Ok(key) Ok(key)
} }
#[cfg(feature = "gemini_srv")]
/// A client cert verifier that accepts all connections /// A client cert verifier that accepts all connections
/// ///
/// Unfortunately, rustls doesn't provide a ClientCertVerifier that accepts self-signed /// Unfortunately, rustls doesn't provide a ClientCertVerifier that accepts self-signed
/// certificates, so we need to implement this ourselves. /// certificates, so we need to implement this ourselves.
struct AllowAnonOrSelfsignedClient { } struct AllowAnonOrSelfsignedClient { }
#[cfg(feature = "gemini_srv")]
impl AllowAnonOrSelfsignedClient { impl AllowAnonOrSelfsignedClient {
/// Create a new verifier /// Create a new verifier
@ -573,6 +787,7 @@ impl AllowAnonOrSelfsignedClient {
} }
#[cfg(feature = "gemini_srv")]
impl ClientCertVerifier for AllowAnonOrSelfsignedClient { impl ClientCertVerifier for AllowAnonOrSelfsignedClient {
fn client_auth_root_subjects( fn client_auth_root_subjects(

View file

@ -1,4 +1,7 @@
#[cfg(feature = "gemini_srv")]
pub use rustls::Certificate; pub use rustls::Certificate;
#[cfg(feature = "scgi_srv")]
pub type Certificate = String;
pub use uriparse::URIReference; pub use uriparse::URIReference;
mod meta; mod meta;

View file

@ -1,8 +1,13 @@
use std::ops; use std::ops;
#[cfg(feature = "scgi_srv")]
use std::{
collections::HashMap,
convert::TryFrom,
};
use anyhow::*; use anyhow::*;
use percent_encoding::percent_decode_str; use percent_encoding::percent_decode_str;
use uriparse::URIReference; use uriparse::URIReference;
use rustls::Certificate; use crate::types::Certificate;
#[cfg(feature="user_management")] #[cfg(feature="user_management")]
use serde::{Serialize, de::DeserializeOwned}; use serde::{Serialize, de::DeserializeOwned};
@ -16,28 +21,37 @@ pub struct Request {
trailing_segments: Option<Vec<String>>, trailing_segments: Option<Vec<String>>,
#[cfg(feature="user_management")] #[cfg(feature="user_management")]
manager: UserManager, manager: UserManager,
#[cfg(feature = "scgi_srv")]
headers: HashMap<String, String>,
} }
impl Request { impl Request {
pub fn from_uri( pub fn new(
uri: URIReference<'static>, #[cfg(feature = "gemini_srv")]
#[cfg(feature="user_management")]
manager: UserManager,
) -> Result<Self> {
Self::with_certificate(
uri,
None,
#[cfg(feature="user_management")]
manager
)
}
pub fn with_certificate(
mut uri: URIReference<'static>, mut uri: URIReference<'static>,
certificate: Option<Certificate>, #[cfg(feature = "scgi_srv")]
headers: HashMap<String, String>,
#[cfg(feature="user_management")] #[cfg(feature="user_management")]
manager: UserManager, manager: UserManager,
) -> Result<Self> { ) -> Result<Self> {
#[cfg(feature = "scgi_srv")]
let (mut uri, certificate) = (
URIReference::try_from(
format!(
"{}{}",
headers.get("PATH_INFO")
.context("PATH_INFO header not received from SCGI client")?
.as_str(),
headers.get("QUERY_STRING")
.map(|q| format!("?{}", q))
.unwrap_or_else(String::new),
).as_str()
)
.context("Request URI is invalid")?
.into_owned(),
headers.get("TLS_CLIENT_HASH").cloned(),
);
uri.normalize(); uri.normalize();
let input = match uri.query() { let input = match uri.query() {
@ -54,8 +68,13 @@ impl Request {
Ok(Self { Ok(Self {
uri, uri,
input, input,
#[cfg(feature = "scgi_srv")]
certificate, certificate,
#[cfg(feature = "gemini_srv")]
certificate: None,
trailing_segments: None, trailing_segments: None,
#[cfg(feature = "scgi_srv")]
headers,
#[cfg(feature="user_management")] #[cfg(feature="user_management")]
manager, manager,
}) })
@ -103,6 +122,24 @@ impl Request {
self.input.as_deref() self.input.as_deref()
} }
#[cfg(feature="scgi_srv")]
/// View any headers sent by the SCGI client
///
/// When an SCGI client delivers a request (e.g. when your gemini server sends a
/// request to this app), it includes many headers which aren't always included in
/// the request otherwise. Bear in mind that **not all SCGI clients send the same
/// headers**, and these are *never* available when operating in `gemini_srv` mode.
///
/// Some examples of headers mollybrown sets are:
/// - `REMOTE_ADDR` (The user's IP address and port)
/// - `TLS_CLIENT_SUBJECT_CN` (The CommonName on the user's certificate, when present)
/// - `SERVER_NAME` (The host name of the server the request was received on)
/// - `SERVER_SOFTWARE` (= "MOLLY_BROWN")
/// - `SCRIPT_PATH` (The prefix the script is being served on)
pub const fn headers(&self) -> &HashMap<String, String> {
&self.headers
}
pub fn set_cert(&mut self, cert: Option<Certificate>) { pub fn set_cert(&mut self, cert: Option<Certificate>) {
self.certificate = cert; self.certificate = cert;
} }