[web] use actual vectored write
Since IoSlice::advance isn't stable I had to hack my own. I will need to confirm that my implementation is safe. Signed-off-by: Ben Aaron Goldberg <ben@benaaron.dev>
This commit is contained in:
parent
4ed748cf1e
commit
63cfa014d5
|
@ -2,6 +2,7 @@ pub mod ogp_images;
|
|||
pub mod contrast;
|
||||
pub mod statics;
|
||||
pub mod configuration;
|
||||
mod write_vectored_all;
|
||||
|
||||
use crate::configuration::Conf;
|
||||
use std::io::IoSlice;
|
||||
|
@ -16,6 +17,7 @@ use percent_encoding::{percent_decode_str, percent_encode, NON_ALPHANUMERIC};
|
|||
use pronouns_today::user_preferences::ParseError;
|
||||
use pronouns_today::UserPreferences;
|
||||
use configuration::ConfigError;
|
||||
use write_vectored_all::AsyncWriteAllVectored;
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::process::exit;
|
||||
|
@ -135,16 +137,11 @@ async fn handle_request(
|
|||
|
||||
let response = route_request(&req)
|
||||
.generate_response(conf);
|
||||
let io_slices = response.into_io_slices();
|
||||
let mut io_slices = response.into_io_slices();
|
||||
|
||||
for slice in io_slices {
|
||||
if let Err(e) = stream.write_all(&slice).await {
|
||||
log::warn!(
|
||||
"Encountered an IO error while sending response: {}", e
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if let Err(e) = stream.write_all_vectored(&mut io_slices).await {
|
||||
log::warn!("Encountered an IO error while sending response: {}", e);
|
||||
}
|
||||
if let Err(e) = stream.close().await {
|
||||
log::warn!(
|
||||
"Encountered an IO error while closing connection to server: {}", e
|
||||
|
|
89
web/src/write_vectored_all.rs
Normal file
89
web/src/write_vectored_all.rs
Normal file
|
@ -0,0 +1,89 @@
|
|||
use std::slice;
|
||||
use std::future::Future;
|
||||
use std::io::{Error, ErrorKind, IoSlice, Result};
|
||||
use std::mem::replace;
|
||||
use std::pin::Pin;
|
||||
use std::task::Poll;
|
||||
|
||||
use futures_lite::AsyncWrite;
|
||||
use futures_lite::ready;
|
||||
|
||||
pub trait AsyncWriteAllVectored: AsyncWrite {
|
||||
fn write_all_vectored<'a>(&'a mut self, bufs: &'a mut [IoSlice<'a>]) -> WriteAllVectoredFuture<'a, Self>
|
||||
where
|
||||
Self: Unpin,
|
||||
{
|
||||
WriteAllVectoredFuture { writer: self, bufs }
|
||||
}
|
||||
}
|
||||
|
||||
impl <T: AsyncWrite> AsyncWriteAllVectored for T {}
|
||||
|
||||
pub struct WriteAllVectoredFuture<'a, W: Unpin + ?Sized> {
|
||||
writer: &'a mut W,
|
||||
bufs: &'a mut [IoSlice<'a>],
|
||||
}
|
||||
|
||||
impl<W: Unpin + ?Sized> Unpin for WriteAllVectoredFuture<'_, W> {}
|
||||
|
||||
impl<W: AsyncWrite + Unpin + ?Sized> Future for WriteAllVectoredFuture<'_, W> {
|
||||
type Output = Result<()>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
|
||||
let Self { writer, bufs } = &mut *self;
|
||||
|
||||
// Guarantee that bufs is empty if it contains no data,
|
||||
// to avoid calling write_vectored if there is no data to be written.
|
||||
advance_slices(bufs, 0);
|
||||
while !bufs.is_empty() {
|
||||
match ready!(Pin::new(&mut ** writer).poll_write_vectored(cx, bufs)) {
|
||||
Ok(0) => {
|
||||
return Poll::Ready(Err(Error::new(
|
||||
ErrorKind::WriteZero,
|
||||
"failed to write whole buffer",
|
||||
)));
|
||||
}
|
||||
Ok(n) => advance_slices(bufs, n),
|
||||
Err(ref e) if e.kind() == ErrorKind::Interrupted => {}
|
||||
Err(e) => return Poll::Ready(Err(e)),
|
||||
}
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
fn advance_slices(bufs: &mut &mut [IoSlice<'_>], n: usize) {
|
||||
// Number of buffers to remove.
|
||||
let mut remove = 0;
|
||||
// Total length of all the to be removed buffers.
|
||||
let mut accumulated_len = 0;
|
||||
for buf in bufs.iter() {
|
||||
if accumulated_len + buf.len() > n {
|
||||
break;
|
||||
} else {
|
||||
accumulated_len += buf.len();
|
||||
remove += 1;
|
||||
}
|
||||
}
|
||||
|
||||
*bufs = &mut replace(bufs, &mut [])[remove..];
|
||||
if !bufs.is_empty() {
|
||||
advance(&mut bufs[0], n - accumulated_len);
|
||||
}
|
||||
}
|
||||
|
||||
fn advance<'a>(buf: &mut IoSlice<'a>, n: usize) {
|
||||
if buf.len() < n {
|
||||
panic!("advancing IoSlice beyond its length");
|
||||
}
|
||||
// SAFTEY: hopefully
|
||||
unsafe {
|
||||
let mut ptr = buf.as_ptr() as *mut u8;
|
||||
ptr = ptr.add(n);
|
||||
let len = buf.len() - n;
|
||||
let new_slice: &'a [u8] = slice::from_raw_parts(ptr, len);
|
||||
*buf = IoSlice::new(new_slice);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in a new issue