diff --git a/web/src/main.rs b/web/src/main.rs index 5e4bdeb..305d506 100644 --- a/web/src/main.rs +++ b/web/src/main.rs @@ -2,6 +2,7 @@ pub mod ogp_images; pub mod contrast; pub mod statics; pub mod configuration; +mod write_vectored_all; use crate::configuration::Conf; use std::io::IoSlice; @@ -16,6 +17,7 @@ use percent_encoding::{percent_decode_str, percent_encode, NON_ALPHANUMERIC}; use pronouns_today::user_preferences::ParseError; use pronouns_today::UserPreferences; use configuration::ConfigError; +use write_vectored_all::AsyncWriteAllVectored; use std::net::SocketAddr; use std::process::exit; @@ -135,16 +137,11 @@ async fn handle_request( let response = route_request(&req) .generate_response(conf); - let io_slices = response.into_io_slices(); + let mut io_slices = response.into_io_slices(); - for slice in io_slices { - if let Err(e) = stream.write_all(&slice).await { - log::warn!( - "Encountered an IO error while sending response: {}", e - ); - break; - } - } + if let Err(e) = stream.write_all_vectored(&mut io_slices).await { + log::warn!("Encountered an IO error while sending response: {}", e); + } if let Err(e) = stream.close().await { log::warn!( "Encountered an IO error while closing connection to server: {}", e diff --git a/web/src/write_vectored_all.rs b/web/src/write_vectored_all.rs new file mode 100644 index 0000000..cc153f6 --- /dev/null +++ b/web/src/write_vectored_all.rs @@ -0,0 +1,89 @@ +use std::slice; +use std::future::Future; +use std::io::{Error, ErrorKind, IoSlice, Result}; +use std::mem::replace; +use std::pin::Pin; +use std::task::Poll; + +use futures_lite::AsyncWrite; +use futures_lite::ready; + +pub trait AsyncWriteAllVectored: AsyncWrite { + fn write_all_vectored<'a>(&'a mut self, bufs: &'a mut [IoSlice<'a>]) -> WriteAllVectoredFuture<'a, Self> + where + Self: Unpin, + { + WriteAllVectoredFuture { writer: self, bufs } + } +} + +impl AsyncWriteAllVectored for T {} + +pub struct WriteAllVectoredFuture<'a, W: Unpin + ?Sized> { + writer: &'a mut W, + bufs: &'a mut [IoSlice<'a>], +} + +impl Unpin for WriteAllVectoredFuture<'_, W> {} + +impl Future for WriteAllVectoredFuture<'_, W> { + type Output = Result<()>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + let Self { writer, bufs } = &mut *self; + + // Guarantee that bufs is empty if it contains no data, + // to avoid calling write_vectored if there is no data to be written. + advance_slices(bufs, 0); + while !bufs.is_empty() { + match ready!(Pin::new(&mut ** writer).poll_write_vectored(cx, bufs)) { + Ok(0) => { + return Poll::Ready(Err(Error::new( + ErrorKind::WriteZero, + "failed to write whole buffer", + ))); + } + Ok(n) => advance_slices(bufs, n), + Err(ref e) if e.kind() == ErrorKind::Interrupted => {} + Err(e) => return Poll::Ready(Err(e)), + } + } + + Poll::Ready(Ok(())) + } +} + +fn advance_slices(bufs: &mut &mut [IoSlice<'_>], n: usize) { + // Number of buffers to remove. + let mut remove = 0; + // Total length of all the to be removed buffers. + let mut accumulated_len = 0; + for buf in bufs.iter() { + if accumulated_len + buf.len() > n { + break; + } else { + accumulated_len += buf.len(); + remove += 1; + } + } + + *bufs = &mut replace(bufs, &mut [])[remove..]; + if !bufs.is_empty() { + advance(&mut bufs[0], n - accumulated_len); + } +} + +fn advance<'a>(buf: &mut IoSlice<'a>, n: usize) { + if buf.len() < n { + panic!("advancing IoSlice beyond its length"); + } + // SAFTEY: hopefully + unsafe { + let mut ptr = buf.as_ptr() as *mut u8; + ptr = ptr.add(n); + let len = buf.len() - n; + let new_slice: &'a [u8] = slice::from_raw_parts(ptr, len); + *buf = IoSlice::new(new_slice); + } +} +