[web] use actual vectored write

Since IoSlice::advance isn't stable I had to hack my own. I will need to
confirm that my implementation is safe.

Signed-off-by: Ben Aaron Goldberg <ben@benaaron.dev>
This commit is contained in:
Ben Aaron Goldberg 2021-11-03 19:43:46 -04:00
parent 4ed748cf1e
commit 63cfa014d5
2 changed files with 95 additions and 9 deletions

View file

@ -2,6 +2,7 @@ pub mod ogp_images;
pub mod contrast;
pub mod statics;
pub mod configuration;
mod write_vectored_all;
use crate::configuration::Conf;
use std::io::IoSlice;
@ -16,6 +17,7 @@ use percent_encoding::{percent_decode_str, percent_encode, NON_ALPHANUMERIC};
use pronouns_today::user_preferences::ParseError;
use pronouns_today::UserPreferences;
use configuration::ConfigError;
use write_vectored_all::AsyncWriteAllVectored;
use std::net::SocketAddr;
use std::process::exit;
@ -135,16 +137,11 @@ async fn handle_request(
let response = route_request(&req)
.generate_response(conf);
let io_slices = response.into_io_slices();
let mut io_slices = response.into_io_slices();
for slice in io_slices {
if let Err(e) = stream.write_all(&slice).await {
log::warn!(
"Encountered an IO error while sending response: {}", e
);
break;
}
}
if let Err(e) = stream.write_all_vectored(&mut io_slices).await {
log::warn!("Encountered an IO error while sending response: {}", e);
}
if let Err(e) = stream.close().await {
log::warn!(
"Encountered an IO error while closing connection to server: {}", e

View file

@ -0,0 +1,89 @@
use std::slice;
use std::future::Future;
use std::io::{Error, ErrorKind, IoSlice, Result};
use std::mem::replace;
use std::pin::Pin;
use std::task::Poll;
use futures_lite::AsyncWrite;
use futures_lite::ready;
pub trait AsyncWriteAllVectored: AsyncWrite {
fn write_all_vectored<'a>(&'a mut self, bufs: &'a mut [IoSlice<'a>]) -> WriteAllVectoredFuture<'a, Self>
where
Self: Unpin,
{
WriteAllVectoredFuture { writer: self, bufs }
}
}
impl <T: AsyncWrite> AsyncWriteAllVectored for T {}
pub struct WriteAllVectoredFuture<'a, W: Unpin + ?Sized> {
writer: &'a mut W,
bufs: &'a mut [IoSlice<'a>],
}
impl<W: Unpin + ?Sized> Unpin for WriteAllVectoredFuture<'_, W> {}
impl<W: AsyncWrite + Unpin + ?Sized> Future for WriteAllVectoredFuture<'_, W> {
type Output = Result<()>;
fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
let Self { writer, bufs } = &mut *self;
// Guarantee that bufs is empty if it contains no data,
// to avoid calling write_vectored if there is no data to be written.
advance_slices(bufs, 0);
while !bufs.is_empty() {
match ready!(Pin::new(&mut ** writer).poll_write_vectored(cx, bufs)) {
Ok(0) => {
return Poll::Ready(Err(Error::new(
ErrorKind::WriteZero,
"failed to write whole buffer",
)));
}
Ok(n) => advance_slices(bufs, n),
Err(ref e) if e.kind() == ErrorKind::Interrupted => {}
Err(e) => return Poll::Ready(Err(e)),
}
}
Poll::Ready(Ok(()))
}
}
fn advance_slices(bufs: &mut &mut [IoSlice<'_>], n: usize) {
// Number of buffers to remove.
let mut remove = 0;
// Total length of all the to be removed buffers.
let mut accumulated_len = 0;
for buf in bufs.iter() {
if accumulated_len + buf.len() > n {
break;
} else {
accumulated_len += buf.len();
remove += 1;
}
}
*bufs = &mut replace(bufs, &mut [])[remove..];
if !bufs.is_empty() {
advance(&mut bufs[0], n - accumulated_len);
}
}
fn advance<'a>(buf: &mut IoSlice<'a>, n: usize) {
if buf.len() < n {
panic!("advancing IoSlice beyond its length");
}
// SAFTEY: hopefully
unsafe {
let mut ptr = buf.as_ptr() as *mut u8;
ptr = ptr.add(n);
let len = buf.len() - n;
let new_slice: &'a [u8] = slice::from_raw_parts(ptr, len);
*buf = IoSlice::new(new_slice);
}
}