106 lines
3.0 KiB
Rust
106 lines
3.0 KiB
Rust
use std::slice;
|
|
use std::future::Future;
|
|
use std::io::{Error, ErrorKind, IoSlice, Result};
|
|
use std::mem::replace;
|
|
use std::pin::Pin;
|
|
use std::task::Poll;
|
|
|
|
use futures_lite::AsyncWrite;
|
|
use futures_lite::ready;
|
|
|
|
pub trait AsyncWriteAllVectored: AsyncWrite {
|
|
fn write_all_vectored<'a>(&'a mut self, bufs: &'a mut [IoSlice<'a>]) -> WriteAllVectoredFuture<'a, Self>
|
|
where
|
|
Self: Unpin,
|
|
{
|
|
WriteAllVectoredFuture { writer: self, bufs }
|
|
}
|
|
}
|
|
|
|
impl <T: AsyncWrite> AsyncWriteAllVectored for T {}
|
|
|
|
/// Future returned by [`AsyncWriteAllVectored::write_all_vectored`].
///
/// It mutably borrows both the writer and the caller's slice of buffers for `'a`.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct WriteAllVectoredFuture<'a, W: Unpin + ?Sized> {
    /// Destination the buffered data is written to.
    writer: &'a mut W,
    /// Data still to be written; narrowed to the unwritten remainder as the write progresses.
    bufs: &'a mut [IoSlice<'a>],
}

// Both fields are references, and references are `Unpin` whatever they point to, so this
// impl only spells out what the compiler would infer for the struct anyway.
impl<W: Unpin + ?Sized> Unpin for WriteAllVectoredFuture<'_, W> {}
|
|
|
|
impl<W: AsyncWrite + Unpin + ?Sized> Future for WriteAllVectoredFuture<'_, W> {
|
|
type Output = Result<()>;
|
|
|
|
fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
|
|
let Self { writer, bufs } = &mut *self;
|
|
|
|
// Guarantee that bufs is empty if it contains no data,
|
|
// to avoid calling write_vectored if there is no data to be written.
|
|
advance_slices(bufs, 0);
|
|
while !bufs.is_empty() {
|
|
match ready!(Pin::new(&mut ** writer).poll_write_vectored(cx, bufs)) {
|
|
Ok(0) => {
|
|
return Poll::Ready(Err(Error::new(
|
|
ErrorKind::WriteZero,
|
|
"failed to write whole buffer",
|
|
)));
|
|
}
|
|
Ok(n) => advance_slices(bufs, n),
|
|
Err(ref e) if e.kind() == ErrorKind::Interrupted => {}
|
|
Err(e) => return Poll::Ready(Err(e)),
|
|
}
|
|
}
|
|
|
|
Poll::Ready(Ok(()))
|
|
}
|
|
}
|
|
|
|
/// Consumes `n` bytes from the front of `bufs`.
///
/// Slices entirely covered by `n` are removed from the front of `bufs`. The first slice that
/// is only partially covered is shortened in place by the remaining amount. Calling this with
/// `n == 0` strips any leading empty slices, so a `bufs` that holds no data ends up empty.
///
/// If `n` exceeds the total length of `bufs`, every slice is removed and the excess is
/// ignored.
fn advance_slices(bufs: &mut &mut [IoSlice<'_>], n: usize) {
    // Number of leading buffers that `n` consumes completely.
    let mut remove = 0;
    // Bytes of `n` not yet accounted for. Counting down from `n` instead of summing slice
    // lengths upward avoids a `usize` overflow: the slices may alias the same memory, so their
    // combined length is not bounded by the address space.
    let mut left = n;
    for buf in bufs.iter() {
        match left.checked_sub(buf.len()) {
            Some(rest) => {
                left = rest;
                remove += 1;
            }
            None => break,
        }
    }

    // Move the slice out (leaving an empty one behind) so it can be re-borrowed with its full
    // lifetime and narrowed past the consumed buffers.
    *bufs = &mut replace(bufs, &mut [])[remove..];
    if let Some(first) = bufs.first_mut() {
        // The loop stopped because `first.len() > left`, so this cannot panic.
        advance(first, left);
    }
}

/// Drops the first `n` bytes of `buf`, keeping the rest with the same lifetime `'a`.
///
/// # Panics
///
/// Panics if `n` is greater than `buf.len()`.
fn advance<'a>(buf: &mut IoSlice<'a>, n: usize) {
    assert!(n <= buf.len(), "advancing IoSlice beyond its length");

    let remaining = buf.len() - n;
    // SAFETY: an `IoSlice<'a>` is built from a `&'a [u8]`, so `buf.as_ptr()` points to
    // `buf.len()` initialised bytes that stay valid and unmutated for `'a`. The assert above
    // guarantees `n <= buf.len()`, so `add(n)` stays inside that slice (or lands one past its
    // end when `n == buf.len()`), and the `remaining` bytes that follow all belong to it.
    let tail: &'a [u8] = unsafe { slice::from_raw_parts(buf.as_ptr().add(n), remaining) };
    *buf = IoSlice::new(tail);
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use futures_lite::future::block_on;

    /// In-memory writer that accepts at most `max_chunk` bytes per `poll_write` call and
    /// appends them to `out`. It does not override `poll_write_vectored`, so the
    /// `AsyncWrite` default implementation is used for vectored writes.
    struct Trickle {
        out: Vec<u8>,
        max_chunk: usize,
    }

    impl AsyncWrite for Trickle {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
            buf: &[u8],
        ) -> Poll<Result<usize>> {
            let n = buf.len().min(self.max_chunk);
            self.out.extend_from_slice(&buf[..n]);
            Poll::Ready(Ok(n))
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> Poll<Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(
            self: Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> Poll<Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    #[test]
    fn test_advance() {
        let expected: Vec<u8> = (10..100).collect();
        let buf: Vec<u8> = (0..100).collect();
        let mut io_slice = IoSlice::new(&buf);
        advance(&mut io_slice, 10);
        assert_eq!(io_slice.len(), 90);
        assert_eq!(&*io_slice, &expected[..]);
    }

    #[test]
    #[should_panic(expected = "advancing IoSlice beyond its length")]
    fn advance_past_end_panics() {
        let data = [1u8, 2, 3];
        let mut io_slice = IoSlice::new(&data);
        advance(&mut io_slice, 4);
    }

    #[test]
    fn advance_slices_partial_and_whole() {
        let a = [1u8, 2, 3];
        let b = [4u8, 5];
        let c = [6u8, 7, 8, 9];
        let mut storage = [IoSlice::new(&a), IoSlice::new(&b), IoSlice::new(&c)];
        let mut bufs: &mut [IoSlice<'_>] = &mut storage;
        advance_slices(&mut bufs, 4);
        assert_eq!(bufs.len(), 2);
        assert_eq!(&*bufs[0], &[5u8][..]);
        assert_eq!(&*bufs[1], &c[..]);
        advance_slices(&mut bufs, 5);
        assert!(bufs.is_empty());
    }

    #[test]
    fn advance_slices_zero_strips_leading_empty_slices() {
        let a = [1u8];
        let mut storage = [
            IoSlice::new(&[]),
            IoSlice::new(&[]),
            IoSlice::new(&a),
            IoSlice::new(&[]),
        ];
        let mut bufs: &mut [IoSlice<'_>] = &mut storage;
        advance_slices(&mut bufs, 0);
        assert_eq!(bufs.len(), 2);
        assert_eq!(&*bufs[0], &[1u8][..]);
    }

    #[test]
    fn write_all_vectored_survives_short_writes() {
        let (a, b, c) = ([1u8, 2, 3], [4u8, 5], [6u8, 7, 8, 9]);
        let mut bufs = [
            IoSlice::new(&a),
            IoSlice::new(&[]),
            IoSlice::new(&b),
            IoSlice::new(&c),
        ];
        let mut writer = Trickle { out: Vec::new(), max_chunk: 2 };
        block_on(writer.write_all_vectored(&mut bufs)).unwrap();
        assert_eq!(writer.out, (1..=9).collect::<Vec<u8>>());
    }

    #[test]
    fn write_all_vectored_reports_write_zero() {
        let data = [1u8, 2, 3];
        let mut bufs = [IoSlice::new(&data)];
        let mut writer = Trickle { out: Vec::new(), max_chunk: 0 };
        let err = block_on(writer.write_all_vectored(&mut bufs)).unwrap_err();
        assert_eq!(err.kind(), ErrorKind::WriteZero);
        assert!(writer.out.is_empty());
    }

    #[test]
    fn write_all_vectored_without_data_skips_writer() {
        // A zero-capacity writer would fail with `WriteZero` if it were ever called.
        let mut bufs = [IoSlice::new(&[]), IoSlice::new(&[])];
        let mut writer = Trickle { out: Vec::new(), max_chunk: 0 };
        block_on(writer.write_all_vectored(&mut bufs)).unwrap();
        assert!(writer.out.is_empty());
    }
}
|