Merge pull request #17 from Alch-Emi/timeout

Add timeouts to response handling
This commit is contained in:
panicbit 2020-11-18 21:04:48 +01:00 committed by GitHub
commit 4eae63ac4e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1,11 +1,18 @@
#[macro_use] extern crate log; #[macro_use] extern crate log;
use std::{panic::AssertUnwindSafe, convert::TryFrom, io::BufReader, sync::Arc}; use std::{
panic::AssertUnwindSafe,
convert::TryFrom,
io::BufReader,
sync::Arc,
time::Duration,
};
use futures::{future::BoxFuture, FutureExt}; use futures::{future::BoxFuture, FutureExt};
use tokio::{ use tokio::{
prelude::*, prelude::*,
io::{self, BufStream}, io::{self, BufStream},
net::{TcpStream, ToSocketAddrs}, net::{TcpStream, ToSocketAddrs},
time::timeout,
}; };
use tokio::net::TcpListener; use tokio::net::TcpListener;
use rustls::ClientCertVerifier; use rustls::ClientCertVerifier;
@ -32,6 +39,7 @@ pub struct Server {
tls_acceptor: TlsAcceptor, tls_acceptor: TlsAcceptor,
listener: Arc<TcpListener>, listener: Arc<TcpListener>,
handler: Handler, handler: Handler,
timeout: Duration,
} }
impl Server { impl Server {
@ -54,12 +62,22 @@ impl Server {
} }
async fn serve_client(self, stream: TcpStream) -> Result<()> { async fn serve_client(self, stream: TcpStream) -> Result<()> {
let fut_accept_request = async {
let stream = self.tls_acceptor.accept(stream).await let stream = self.tls_acceptor.accept(stream).await
.context("Failed to establish TLS session")?; .context("Failed to establish TLS session")?;
let mut stream = BufStream::new(stream); let mut stream = BufStream::new(stream);
let mut request = receive_request(&mut stream).await let request = receive_request(&mut stream).await
.context("Failed to receive request")?; .context("Failed to receive request")?;
Result::<_, anyhow::Error>::Ok((request, stream))
};
// Use a timeout for interacting with the client
let fut_accept_request = timeout(self.timeout, fut_accept_request);
let (mut request, mut stream) = fut_accept_request.await
.context("Client timed out while waiting for response")??;
debug!("Client requested: {}", request.uri()); debug!("Client requested: {}", request.uri());
// Identify the client certificate from the tls stream. This is the first // Identify the client certificate from the tls stream. This is the first
@ -83,11 +101,18 @@ impl Server {
}) })
.context("Request handler failed")?; .context("Request handler failed")?;
// Use a timeout for sending the response
let fut_send_and_flush = async {
send_response(response, &mut stream).await send_response(response, &mut stream).await
.context("Failed to send response")?; .context("Failed to send response")?;
stream.flush().await stream.flush()
.context("Failed to flush response data")?; .await
.context("Failed to flush response data")
};
timeout(self.timeout, fut_send_and_flush)
.await
.context("Client timed out receiving response data")??;
Ok(()) Ok(())
} }
@ -95,11 +120,29 @@ impl Server {
pub struct Builder<A> { pub struct Builder<A> {
addr: A, addr: A,
timeout: Duration,
} }
impl<A: ToSocketAddrs> Builder<A> { impl<A: ToSocketAddrs> Builder<A> {
fn bind(addr: A) -> Self { fn bind(addr: A) -> Self {
Self { addr } Self { addr, timeout: Duration::from_secs(1) }
}
/// Set the timeout on incoming requests
///
/// Note that this timeout is applied twice, once for the delivery of the request, and
/// once for sending the client's response. This means that for a 1 second timeout,
/// the client will have 1 second to complete the TLS handshake and deliver a request
/// header, then your API will have as much time as it needs to handle the request,
/// before the client has another second to receive the response.
///
/// If you would like a timeout for your code itself, please use
/// [`tokio::time::Timeout`] to implement it internally.
///
/// The default timeout is 1 second.
pub fn set_timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
self
} }
pub async fn serve<F>(self, handler: F) -> Result<()> pub async fn serve<F>(self, handler: F) -> Result<()>
@ -116,6 +159,7 @@ impl<A: ToSocketAddrs> Builder<A> {
tls_acceptor: TlsAcceptor::from(config), tls_acceptor: TlsAcceptor::from(config),
listener: Arc::new(listener), listener: Arc::new(listener),
handler: Arc::new(handler), handler: Arc::new(handler),
timeout: self.timeout,
}; };
server.serve().await server.serve().await