// PronounsToday/web/src/write_vectored_all.rs
//
// 106 lines
// 3.0 KiB
// Rust

use std::slice;
use std::future::Future;
use std::io::{Error, ErrorKind, IoSlice, Result};
use std::mem::replace;
use std::pin::Pin;
use std::task::Poll;
use futures_lite::AsyncWrite;
use futures_lite::ready;
/// Extension trait that adds [`write_all_vectored`](AsyncWriteAllVectored::write_all_vectored)
/// to every [`AsyncWrite`] implementor.
pub trait AsyncWriteAllVectored: AsyncWrite {
    /// Returns a future that writes every byte referenced by `bufs` into `self`.
    ///
    /// The future repeatedly calls `poll_write_vectored` until all slices are
    /// drained. The slices in `bufs` are advanced in place while writing, so the
    /// caller should not rely on their contents after the future has been polled.
    ///
    /// # Errors
    ///
    /// The future resolves to an error of kind [`ErrorKind::WriteZero`] if the
    /// writer accepts zero bytes while data remains. It also propagates any
    /// other non-`Interrupted` error from the writer.
    ///
    /// [`ErrorKind::WriteZero`]: std::io::ErrorKind::WriteZero
    fn write_all_vectored<'a>(&'a mut self, bufs: &'a mut [IoSlice<'a>]) -> WriteAllVectoredFuture<'a, Self>
    where
        Self: Unpin,
    {
        let writer = self;
        WriteAllVectoredFuture { bufs, writer }
    }
}

// Every async writer gets the extension method for free.
impl<T> AsyncWriteAllVectored for T where T: AsyncWrite {}
/// Future returned by `write_all_vectored`.
///
/// It borrows both the destination writer and the caller's slice list. The
/// slice list is shrunk and advanced in place as bytes are accepted by the
/// writer.
pub struct WriteAllVectoredFuture<'a, W: Unpin + ?Sized> {
    /// Destination that receives the data.
    writer: &'a mut W,
    /// Slices whose bytes have not been written yet.
    bufs: &'a mut [IoSlice<'a>],
}

// Both fields are plain references, so the future can be moved freely between
// polls; this impl states that explicitly.
impl<W> Unpin for WriteAllVectoredFuture<'_, W> where W: Unpin + ?Sized {}
impl<W: AsyncWrite + Unpin + ?Sized> Future for WriteAllVectoredFuture<'_, W> {
    type Output = Result<()>;

    /// Writes slices until none remain, yielding whenever the writer is not ready.
    ///
    /// Progress is recorded by advancing `bufs` in place. A later poll therefore
    /// resumes with exactly the bytes that have not been written yet.
    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        let this = &mut *self;
        // Strip leading empty slices first. An input that contains no bytes at
        // all then finishes without ever calling `poll_write_vectored`.
        advance_slices(&mut this.bufs, 0);
        loop {
            if this.bufs.is_empty() {
                return Poll::Ready(Ok(()));
            }
            let written = match ready!(Pin::new(&mut *this.writer).poll_write_vectored(cx, &*this.bufs)) {
                // Data remains but nothing was accepted: treat as a hard failure
                // rather than spinning forever.
                Ok(0) => {
                    return Poll::Ready(Err(Error::new(
                        ErrorKind::WriteZero,
                        "failed to write whole buffer",
                    )));
                }
                Ok(count) => count,
                // Interrupted writes are retried immediately.
                Err(err) if err.kind() == ErrorKind::Interrupted => continue,
                Err(err) => return Poll::Ready(Err(err)),
            };
            advance_slices(&mut this.bufs, written);
        }
    }
}
/// Consumes `n` bytes from the front of a list of `IoSlice`s, in place.
///
/// Removes from `bufs` every leading slice that lies entirely within the first
/// `n` bytes, including zero-length slices at that boundary. The first
/// remaining slice, if any, is then advanced past the rest of the `n` bytes.
/// If `n` covers the total length, `bufs` becomes empty. With `n == 0`, the
/// function only strips leading empty slices.
fn advance_slices(bufs: &mut &mut [IoSlice<'_>], n: usize) {
    // Combined length of the slices that are dropped entirely.
    let mut covered = 0;
    let dropped = bufs
        .iter()
        .take_while(|slice| {
            let end = covered + slice.len();
            let fully_written = end <= n;
            if fully_written {
                covered = end;
            }
            fully_written
        })
        .count();

    // Temporarily take the slice out of `bufs` so the shortened sub-slice can
    // keep the original borrow's lifetime.
    let whole = replace(bufs, &mut []);
    *bufs = &mut whole[dropped..];

    if let Some(first) = bufs.first_mut() {
        // `first` is longer than `n - covered`, so this never panics.
        advance(first, n - covered);
    }
}
/// Moves the start of `buf` forward by `n` bytes, keeping its original lifetime `'a`.
///
/// After the call, `buf` refers to the last `buf.len() - n` bytes of the slice
/// it previously referred to.
///
/// # Panics
///
/// Panics if `n` is greater than `buf.len()`.
fn advance<'a>(buf: &mut IoSlice<'a>, n: usize) {
    assert!(n <= buf.len(), "advancing IoSlice beyond its length");
    let remaining = buf.len() - n;
    // `IoSlice` only derefs to a slice whose lifetime is tied to the borrow of
    // `buf`. The tail must live for the full `'a`, so it is rebuilt from the
    // (read-only) data pointer.
    let start = buf.as_ptr();
    // SAFETY: `buf` was created from a `&'a [u8]`, so `start` is valid for reads
    // of `buf.len()` bytes for all of `'a`. The assert above guarantees
    // `n <= buf.len()`, so `start.add(n)` stays within (or one past the end of)
    // that allocation. The `remaining` bytes after it lie entirely inside the
    // original slice. The data is only ever read, never written.
    let tail: &'a [u8] = unsafe { slice::from_raw_parts(start.add(n), remaining) };
    *buf = IoSlice::new(tail);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Advancing by 10 must leave exactly the trailing 90 bytes visible.
    #[test]
    fn test_advance() {
        let source: Vec<u8> = (0..100).collect();
        let mut view = IoSlice::new(&source);
        advance(&mut view, 10);
        assert_eq!(view.len(), 90);
        assert_eq!(&*view, &source[10..]);
    }
}