diff --git a/Cargo.toml b/Cargo.toml
index 1806dc0..87b6281 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -13,6 +13,12 @@ edition = "2018"
 arrayref = "0.3.5"
 arrayvec = "0.7.1"
 blake3 = "1.0.0"
+tokio = { version = "1.24.2", features = [], default-features = false, optional = true }
+
+[features]
+# todo: remove before merge
+default = ["tokio_io"]
+tokio_io = ["tokio"]
 
 [dev-dependencies]
 lazy_static = "1.3.0"
@@ -23,3 +29,4 @@ tempfile = "3.1.0"
 rand_chacha = "0.3.1"
 rand_xorshift = "0.3.0"
 page_size = "0.4.1"
+tokio = { version = "1.24.2", features = ["full"] }
diff --git a/src/decode.rs b/src/decode.rs
index af0cdb5..2e4cc99 100644
--- a/src/decode.rs
+++ b/src/decode.rs
@@ -205,7 +205,7 @@ impl From<Error> for io::Error {
 
 // Shared between Decoder and SliceDecoder.
 #[derive(Clone)]
-struct DecoderShared<T: Read, O: Read> {
+struct DecoderShared<T, O> {
     input: T,
     outboard: Option<O>,
     state: VerifyState,
@@ -214,7 +214,7 @@ struct DecoderShared<T: Read, O: Read> {
     buf_end: usize,
 }
 
-impl<T: Read, O: Read> DecoderShared<T, O> {
+impl<T, O> DecoderShared<T, O> {
     fn new(input: T, outboard: Option<O>, hash: &Hash) -> Self {
         Self {
             input,
@@ -240,7 +240,9 @@ impl<T: Read, O: Read> DecoderShared<T, O> {
         self.buf_start = 0;
         self.buf_end = 0;
     }
+}
 
+impl<T: Read, O: Read> DecoderShared<T, O> {
     // These bytes are always verified before going in the buffer.
     fn take_buffered_bytes(&mut self, output: &mut [u8]) -> usize {
        let take = cmp::min(self.buf_len(), output.len());
@@ -441,7 +443,7 @@ impl<T: Read, O: Read> DecoderShared<T, O> {
     }
 }
 
-impl<T: Read, O: Read> fmt::Debug for DecoderShared<T, O> {
+impl<T, O> fmt::Debug for DecoderShared<T, O> {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         write!(
             f,
@@ -454,6 +456,578 @@ impl<T: Read, O: Read> fmt::Debug for DecoderShared<T, O> {
     }
 }
 
+#[cfg(feature = "tokio_io")]
+mod tokio_io {
+
+    use super::{DecoderShared, Hash, NextRead};
+    use std::{
+        cmp,
+        convert::TryInto,
+        io,
+        pin::Pin,
+        task::{ready, Context, Poll},
+    };
+    use tokio::io::{AsyncRead, ReadBuf};
+
+    // Tokio-flavored async I/O utilities, requiring AsyncRead.
+    impl<T: AsyncRead + Unpin, O: AsyncRead + Unpin> DecoderShared<T, O> {
+        /// Write from the internal buffer to the output buffer.
+        fn write_output(&mut self, buf: &mut ReadBuf<'_>) {
+            let n = cmp::min(buf.remaining(), self.buf_len());
+            buf.put_slice(&self.buf[self.buf_start..self.buf_start + n]);
+            self.buf_start += n;
+        }
+
+        /// Fill the internal buffer from the input or outboard.
+        ///
+        /// Returns Poll::Pending until we have a chunk in the internal buffer
+        /// to be read, or until we reach EOF.
+        fn poll_input(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+            loop {
+                match self.state.read_next() {
+                    NextRead::Done => {
+                        // This is EOF.
+                        break Poll::Ready(Ok(()));
+                    }
+                    NextRead::Header => {
+                        // Read 8 bytes. The header comes from the outboard if
+                        // we have one, otherwise from the input.
+                        ready!(self.poll_fill_buffer_from_input_or_outboard(8, cx))?;
+                        self.state.feed_header(self.buf[0..8].try_into().unwrap());
+                        // We don't want to write the header, so we are done
+                        // with the buffer contents.
+                        self.clear_buf();
+                    }
+                    NextRead::Parent => {
+                        // Read 64 bytes. A parent comes from the outboard if
+                        // we have one, otherwise from the input.
+                        ready!(self.poll_fill_buffer_from_input_or_outboard(64, cx))?;
+                        self.state
+                            .feed_parent(&self.buf[0..64].try_into().unwrap())?;
+                        // We don't want to write the parent, so we are done
+                        // with the buffer contents.
+                        self.clear_buf();
+                    }
+                    NextRead::Chunk {
+                        size,
+                        finalization,
+                        skip,
+                        index,
+                    } => {
+                        // Read size bytes. A chunk never comes from the outboard.
+                        ready!(self.poll_fill_buffer_from_input(size, cx))?;
+
+                        // Hash it and push its hash into the VerifyState. This
+                        // returns an error if the hash is bad. Otherwise, the
+                        // chunk is verified.
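+                        // (The chunk hash depends on the chunk's index in the
+                        // tree and on whether it is the root, which is why
+                        // read_next reports index and finalization here.)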
+                        let read_buf = &self.buf[0..size];
+                        let chunk_hash = blake3::guts::ChunkState::new(index)
+                            .update(read_buf)
+                            .finalize(finalization.is_root());
+                        self.state.feed_chunk(&chunk_hash)?;
+
+                        // We go into the writing state now, starting from skip.
+                        self.buf_start = skip;
+                        // We should have something to write, unless the entire
+                        // chunk was empty.
+                        debug_assert!(self.buf_len() > 0 || size == 0);
+
+                        break Poll::Ready(Ok(()));
+                    }
+                }
+            }
+        }
+
+        // Returns Ok(true) to indicate the seek is finished. Note that both the
+        // Decoder and the SliceDecoder will use this method (which doesn't
+        // depend on io::Seek), but only the Decoder will call
+        // handle_seek_bookkeeping first. This may read a chunk, but it never
+        // leaves output bytes in the buffer, because the only time seeking
+        // reads a chunk it also skips the entire thing.
+        fn poll_handle_seek_read(
+            &mut self,
+            next: NextRead,
+            cx: &mut Context<'_>,
+        ) -> Poll<io::Result<bool>> {
+            Poll::Ready(Ok(match next {
+                NextRead::Header => {
+                    // Read 8 bytes. The header comes from the outboard if we
+                    // have one, otherwise from the input.
+                    ready!(self.poll_fill_buffer_from_input_or_outboard(8, cx))?;
+                    self.state.feed_header(self.buf[0..8].try_into().unwrap());
+                    // We don't want to write the header, so we are done with
+                    // the buffer contents.
+                    self.clear_buf();
+                    // Not done yet.
+                    false
+                }
+                NextRead::Parent => {
+                    // Read 64 bytes. A parent comes from the outboard if we
+                    // have one, otherwise from the input.
+                    ready!(self.poll_fill_buffer_from_input_or_outboard(64, cx))?;
+                    self.state
+                        .feed_parent(&self.buf[0..64].try_into().unwrap())?;
+                    // We don't want to write the parent, so we are done with
+                    // the buffer contents.
+                    self.clear_buf();
+                    // Not done yet.
+                    false
+                }
+                NextRead::Chunk {
+                    size,
+                    finalization,
+                    skip: _,
+                    index,
+                } => {
+                    // Read size bytes. A chunk never comes from the outboard.
+                    ready!(self.poll_fill_buffer_from_input(size, cx))?;
+
+                    // Hash it and push its hash into the VerifyState. This
+                    // returns an error if the hash is bad. Otherwise, the
+                    // chunk is verified.
+                    let read_buf = &self.buf[0..size];
+                    let chunk_hash = blake3::guts::ChunkState::new(index)
+                        .update(read_buf)
+                        .finalize(finalization.is_root());
+                    self.state.feed_chunk(&chunk_hash)?;
+                    // We don't want to write the chunk, so we are done with
+                    // the buffer contents.
+                    self.clear_buf();
+                    false
+                }
+                NextRead::Done => true, // The seek is done.
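+                // Every arm above either clears the buffer or never fills it,
+                // so a finished seek leaves no output bytes behind.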
+            }))
+        }
+
+        fn poll_fill_buffer_from_input(
+            &mut self,
+            size: usize,
+            cx: &mut Context<'_>,
+        ) -> Poll<io::Result<()>> {
+            let mut buf = ReadBuf::new(&mut self.buf[..size]);
+            buf.advance(self.buf_end);
+            let src = &mut self.input;
+            while buf.remaining() > 0 {
+                let filled_before = buf.filled().len();
+                ready!(AsyncRead::poll_read(Pin::new(src), cx, &mut buf))?;
+                self.buf_end = buf.filled().len();
+                // A successful read that makes no progress means EOF. Error
+                // out instead of spinning forever on a truncated input.
+                if self.buf_end == filled_before {
+                    return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into()));
+                }
+            }
+            Poll::Ready(Ok(()))
+        }
+
+        fn poll_fill_buffer_from_outboard(
+            &mut self,
+            size: usize,
+            cx: &mut Context<'_>,
+        ) -> Poll<io::Result<()>> {
+            let mut buf = ReadBuf::new(&mut self.buf[..size]);
+            buf.advance(self.buf_end);
+            let src = self.outboard.as_mut().unwrap();
+            while buf.remaining() > 0 {
+                let filled_before = buf.filled().len();
+                ready!(AsyncRead::poll_read(Pin::new(src), cx, &mut buf))?;
+                self.buf_end = buf.filled().len();
+                // Same EOF guard as poll_fill_buffer_from_input above.
+                if self.buf_end == filled_before {
+                    return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into()));
+                }
+            }
+            Poll::Ready(Ok(()))
+        }
+
+        fn poll_fill_buffer_from_input_or_outboard(
+            &mut self,
+            size: usize,
+            cx: &mut Context<'_>,
+        ) -> Poll<io::Result<()>> {
+            if self.outboard.is_some() {
+                self.poll_fill_buffer_from_outboard(size, cx)
+            } else {
+                self.poll_fill_buffer_from_input(size, cx)
+            }
+        }
+    }
+
+    /// Type alias to make clippy happy.
+    type BoxedDecoderShared<T, O> = Pin<Box<DecoderShared<T, O>>>;
+
+    /// State of the decoder.
+    ///
+    /// This is the decoder proper, but it is a separate type so it can stay
+    /// private. The public AsyncDecoder just wraps this.
+    enum DecoderState<T: AsyncRead + Unpin, O: AsyncRead + Unpin> {
+        /// We are reading from the underlying reader.
+        Reading(BoxedDecoderShared<T, O>),
+        /// We are being polled for output.
+        Output(BoxedDecoderShared<T, O>),
+        /// Invalid state, used only as a placeholder during take().
+        Invalid,
+    }
+
+    impl<T: AsyncRead + Unpin, O: AsyncRead + Unpin> DecoderState<T, O> {
+        fn take(&mut self) -> Self {
+            std::mem::replace(self, DecoderState::Invalid)
+        }
+
+        fn poll_read(
+            mut self: Pin<&mut Self>,
+            cx: &mut Context<'_>,
+            buf: &mut ReadBuf<'_>,
+        ) -> Poll<io::Result<()>> {
+            // On a zero-length read, we do nothing whatsoever.
+            if buf.remaining() == 0 {
+                return Poll::Ready(Ok(()));
+            }
+
+            loop {
+                match self.take() {
+                    Self::Reading(mut shared) => {
+                        let res = shared.poll_input(cx);
+                        if let Poll::Ready(Ok(())) = res {
+                            // We have read a chunk from the underlying reader.
+                            // Go to the output state.
+                            *self = Self::Output(shared);
+                            continue;
+                        }
+                        *self = Self::Reading(shared);
+                        break res;
+                    }
+                    Self::Output(mut shared) => {
+                        shared.write_output(buf);
+                        *self = if shared.buf_len() == 0 {
+                            // The caller has consumed all the data in the
+                            // buffer. Go back to the reading state.
+                            shared.clear_buf();
+                            Self::Reading(shared)
+                        } else {
+                            // We still have data in the buffer. Stay in the
+                            // output state.
+                            Self::Output(shared)
+                        };
+                        break Poll::Ready(Ok(()));
+                    }
+                    DecoderState::Invalid => {
+                        break Poll::Ready(Ok(()));
+                    }
+                }
+            }
+        }
+    }
+
+    pub struct AsyncDecoder<T: AsyncRead + Unpin, O: AsyncRead + Unpin>(DecoderState<T, O>);
+
+    impl<T: AsyncRead + Unpin> AsyncDecoder<T, T> {
+        pub fn new(inner: T, hash: &Hash) -> Self {
+            let state = DecoderShared::new(inner, None, hash);
+            Self(DecoderState::Reading(Box::pin(state)))
+        }
+    }
+
+    impl<T: AsyncRead + Unpin, O: AsyncRead + Unpin> AsyncDecoder<T, O> {
+        pub fn new_outboard(inner: T, outboard: O, hash: &Hash) -> Self {
+            let state = DecoderShared::new(inner, Some(outboard), hash);
+            Self(DecoderState::Reading(Box::pin(state)))
+        }
+    }
+
+    impl<T: AsyncRead + Unpin, O: AsyncRead + Unpin> AsyncRead for AsyncDecoder<T, O> {
+        fn poll_read(
+            mut self: Pin<&mut Self>,
+            cx: &mut Context<'_>,
+            buf: &mut ReadBuf<'_>,
+        ) -> Poll<io::Result<()>> {
+            Pin::new(&mut self.0).poll_read(cx, buf)
+        }
+    }
+
+    #[derive(Clone, Debug)]
+    struct SliceDecoderInner<T: AsyncRead + Unpin> {
+        shared: DecoderShared<T, T>,
+        slice_start: u64,
+        slice_remaining: u64,
+        // If the caller requested no bytes, the extractor is still required to
+        // include a chunk. We're not required to verify it, but we want to
+        // aggressively check for extractor bugs.
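+        // (The sync SliceDecoder has the same field for the same reason.)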
+        need_fake_read: bool,
+    }
+
+    impl<T: AsyncRead + Unpin> SliceDecoderInner<T> {
+        fn poll_input(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+            // If we haven't done the initial seek yet, do the full seek loop
+            // first. Note that this will never leave any buffered output. The
+            // only scenario where handle_seek_read reads a chunk is if it
+            // needs to validate the final chunk, and then it skips the whole
+            // thing.
+            if self.shared.state.content_position() < self.slice_start {
+                loop {
+                    // We don't keep any internal seek state; if the call to
+                    // poll_handle_seek_read below returns Pending, we just
+                    // call seek_next again on the next poll.
+                    let bookkeeping = self.shared.state.seek_next(self.slice_start);
+                    // Note here, we skip to seek_bookkeeping_done without
+                    // calling handle_seek_bookkeeping. That is, we never
+                    // perform any underlying seeks. The slice extractor
+                    // already took care of lining everything up for us.
+                    let next = self.shared.state.seek_bookkeeping_done(bookkeeping);
+                    let done = ready!(self.shared.poll_handle_seek_read(next, cx))?;
+                    if done {
+                        break;
+                    }
+                }
+                debug_assert_eq!(0, self.shared.buf_len());
+            }
+
+            // We either just finished the seek (if any), or already did it
+            // during a previous call. Continue the read, capping the buffer
+            // to at most the slice bytes remaining.
+            if self.need_fake_read {
+                // Read a chunk and throw it away, just to verify it.
+                ready!(self.shared.poll_input(cx))?;
+                self.shared.clear_buf();
+                self.need_fake_read = false;
+            } else if self.slice_remaining > 0 {
+                ready!(self.shared.poll_input(cx))?;
+                let len = self.shared.buf_len() as u64;
+                if len <= self.slice_remaining {
+                    self.slice_remaining -= len;
+                } else {
+                    // We read more than we needed. Truncate the buffer.
+                    self.shared.buf_end -= (len - self.slice_remaining) as usize;
+                    self.slice_remaining = 0;
+                }
+            }
+            Poll::Ready(Ok(()))
+        }
+    }
+
+    #[derive(Clone, Debug)]
+    enum SliceDecoderState<T: AsyncRead + Unpin> {
+        /// We are reading from the underlying reader.
+        Reading(Box<SliceDecoderInner<T>>),
+        /// We are being polled for output.
+        Output(Box<SliceDecoderInner<T>>),
+        /// Placeholder value so we can implement take(). If you ever see this
+        /// state, you've found a bug.
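+        /// (take() below swaps this in while we move the inner value between
+        /// the other two states.)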
+        Taken,
+    }
+
+    impl<T: AsyncRead + Unpin> SliceDecoderState<T> {
+        fn take(&mut self) -> Self {
+            std::mem::replace(self, SliceDecoderState::Taken)
+        }
+    }
+
+    #[derive(Clone, Debug)]
+    pub struct AsyncSliceDecoder<T: AsyncRead + Unpin>(SliceDecoderState<T>);
+
+    impl<T: AsyncRead + Unpin> AsyncSliceDecoder<T> {
+        pub fn new(inner: T, hash: &Hash, slice_start: u64, slice_len: u64) -> Self {
+            let state = SliceDecoderInner {
+                shared: DecoderShared::new(inner, None, hash),
+                slice_start,
+                slice_remaining: slice_len,
+                need_fake_read: slice_len == 0,
+            };
+            Self(SliceDecoderState::Reading(Box::new(state)))
+        }
+
+        pub fn into_inner(self) -> T {
+            match self.0 {
+                SliceDecoderState::Reading(state) => state.shared.input,
+                SliceDecoderState::Output(state) => state.shared.input,
+                SliceDecoderState::Taken => unreachable!(),
+            }
+        }
+    }
+
+    impl<T: AsyncRead + Unpin> AsyncRead for AsyncSliceDecoder<T> {
+        fn poll_read(
+            mut self: Pin<&mut Self>,
+            cx: &mut Context<'_>,
+            buf: &mut ReadBuf<'_>,
+        ) -> Poll<io::Result<()>> {
+            // On a zero-length read, we do nothing whatsoever.
+            if buf.remaining() == 0 {
+                return Poll::Ready(Ok(()));
+            }
+
+            loop {
+                match self.0.take() {
+                    SliceDecoderState::Reading(mut state) => match state.poll_input(cx) {
+                        Poll::Ready(Ok(())) => {
+                            self.0 = SliceDecoderState::Output(state);
+                            continue;
+                        }
+                        Poll::Ready(Err(e)) => {
+                            self.0 = SliceDecoderState::Reading(state);
+                            break Poll::Ready(Err(e));
+                        }
+                        Poll::Pending => {
+                            self.0 = SliceDecoderState::Reading(state);
+                            break Poll::Pending;
+                        }
+                    },
+                    SliceDecoderState::Output(mut state) => {
+                        state.shared.write_output(buf);
+                        if state.shared.buf_len() == 0 {
+                            // The caller has consumed all the data in the
+                            // buffer. Go back to the reading state.
+                            state.shared.clear_buf();
+                            self.0 = SliceDecoderState::Reading(state)
+                        } else {
+                            // We still have data in the buffer. Stay in the
+                            // output state.
+                            self.0 = SliceDecoderState::Output(state)
+                        };
+                        break Poll::Ready(Ok(()));
+                    }
+                    SliceDecoderState::Taken => {
+                        unreachable!();
+                    }
+                }
+            }
+        }
+    }
+
+    #[cfg(test)]
+    mod tests {
+        use std::io::{Cursor, Read};
+        use tokio::io::AsyncReadExt;
+
+        use super::*;
+        use crate::{
+            decode::{make_test_input, SliceDecoder},
+            encode, CHUNK_SIZE, HEADER_SIZE,
+        };
+
+        #[tokio::test]
+        async fn test_async_decode() {
+            for &case in crate::test::TEST_CASES {
+                println!("case {case}");
+                let input = make_test_input(case);
+                let (encoded, hash) = encode::encode(&input);
+                let mut output = Vec::new();
+                let mut reader = AsyncDecoder::new(&encoded[..], &hash);
+                reader.read_to_end(&mut output).await.unwrap();
+                assert_eq!(input, output);
+            }
+        }
+
+        #[tokio::test]
+        async fn test_async_decode_outboard() {
+            for &case in crate::test::TEST_CASES {
+                println!("case {case}");
+                let input = make_test_input(case);
+                let (outboard, hash) = encode::outboard(&input);
+                let mut output = Vec::new();
+                let mut reader = AsyncDecoder::new_outboard(&input[..], &outboard[..], &hash);
+                reader.read_to_end(&mut output).await.unwrap();
+                assert_eq!(input, output);
+            }
+        }
+
+        #[tokio::test]
+        async fn test_async_slices() {
+            for &case in crate::test::TEST_CASES {
+                let input = make_test_input(case);
+                let (encoded, hash) = encode::encode(&input);
+                // Also make an outboard encoding, to test that case.
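+                // (A combined encoding and an outboard encoding of the same
+                // input always share the same root hash; the assert below
+                // relies on that.)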
+                let (outboard, outboard_hash) = encode::outboard(&input);
+                assert_eq!(hash, outboard_hash);
+                for &slice_start in crate::test::TEST_CASES {
+                    let expected_start = cmp::min(input.len(), slice_start);
+                    let slice_lens = [0, 1, 2, CHUNK_SIZE - 1, CHUNK_SIZE, CHUNK_SIZE + 1];
+                    for &slice_len in slice_lens.iter() {
+                        println!("\ncase {case} start {slice_start} len {slice_len}");
+                        let expected_end = cmp::min(input.len(), slice_start + slice_len);
+                        let expected_output = &input[expected_start..expected_end];
+                        let mut slice = Vec::new();
+                        let mut extractor = encode::SliceExtractor::new(
+                            Cursor::new(&encoded),
+                            slice_start as u64,
+                            slice_len as u64,
+                        );
+                        extractor.read_to_end(&mut slice).unwrap();
+
+                        // Make sure the outboard extractor produces the same output.
+                        let mut slice_from_outboard = Vec::new();
+                        let mut extractor = encode::SliceExtractor::new_outboard(
+                            Cursor::new(&input),
+                            Cursor::new(&outboard),
+                            slice_start as u64,
+                            slice_len as u64,
+                        );
+                        extractor.read_to_end(&mut slice_from_outboard).unwrap();
+                        assert_eq!(slice, slice_from_outboard);
+
+                        let mut output = Vec::new();
+                        let mut reader = AsyncSliceDecoder::new(
+                            &*slice,
+                            &hash,
+                            slice_start as u64,
+                            slice_len as u64,
+                        );
+                        reader.read_to_end(&mut output).await.unwrap();
+                        assert_eq!(expected_output, &*output);
+                    }
+                }
+            }
+        }
+
+        #[tokio::test]
+        async fn test_async_corrupted_slice() {
+            let input = make_test_input(20_000);
+            let slice_start = 5_000;
+            let slice_len = 10_000;
+            let (encoded, hash) = encode::encode(&input);
+
+            // Slice out the middle 10_000 bytes.
+            let mut slice = Vec::new();
+            let mut extractor = encode::SliceExtractor::new(
+                Cursor::new(&encoded),
+                slice_start as u64,
+                slice_len as u64,
+            );
+            extractor.read_to_end(&mut slice).unwrap();
+
+            // First confirm that the regular decode works.
+            let mut output = Vec::new();
+            let mut reader =
+                SliceDecoder::new(&*slice, &hash, slice_start as u64, slice_len as u64);
+            reader.read_to_end(&mut output).unwrap();
+            assert_eq!(&input[slice_start..][..slice_len], &*output);
+
+            // Also confirm that the outboard slice extractor gives the same slice.
+            let (outboard, outboard_hash) = encode::outboard(&input);
+            assert_eq!(hash, outboard_hash);
+            let mut slice_from_outboard = Vec::new();
+            let mut extractor = encode::SliceExtractor::new_outboard(
+                Cursor::new(&input),
+                Cursor::new(&outboard),
+                slice_start as u64,
+                slice_len as u64,
+            );
+            extractor.read_to_end(&mut slice_from_outboard).unwrap();
+            assert_eq!(slice, slice_from_outboard);
+
+            // Now confirm that flipping bits anywhere in the slice other than
+            // the length header will corrupt it. Tweaking the length header
+            // doesn't always break slice decoding, because the only thing it's
+            // guaranteed to break is the final chunk, and this slice doesn't
+            // include the final chunk.
+            let mut i = HEADER_SIZE;
+            while i < slice.len() {
+                let mut slice_clone = slice.clone();
+                slice_clone[i] ^= 1;
+                let mut reader = AsyncSliceDecoder::new(
+                    &*slice_clone,
+                    &hash,
+                    slice_start as u64,
+                    slice_len as u64,
+                );
+                output.clear();
+                let err = reader.read_to_end(&mut output).await.unwrap_err();
+                assert_eq!(io::ErrorKind::InvalidData, err.kind());
+                i += 32;
+            }
+        }
+    }
+}
+
+#[cfg(feature = "tokio_io")]
+pub use tokio_io::{AsyncDecoder, AsyncSliceDecoder};
+
 /// An incremental decoder, which reads and verifies the output of
 /// [`Encoder`](../encode/struct.Encoder.html).
/// @@ -861,7 +1435,7 @@ mod test { let mut output = Vec::new(); let mut decoder = Decoder::new(&*zero_encoded, &zero_hash); decoder.read_to_end(&mut output).unwrap(); - assert_eq!(&output, &[]); + assert_eq!(output.len(), 0); // Decoding the empty tree with any other hash should fail. let mut output = Vec::new(); @@ -936,7 +1510,7 @@ mod test { let mut decoder = Decoder::new(Cursor::new(&encoded), &hash); decoder.seek(SeekFrom::Start(case as u64)).unwrap(); decoder.read_to_end(&mut output).unwrap(); - assert_eq!(&output, &[]); + assert_eq!(output.len(), 0); // Seeking to EOF should fail if the root hash is wrong. let mut bad_hash_bytes = *hash.as_bytes(); @@ -1063,7 +1637,7 @@ mod test { // be exactly the same as the entire encoded tree. This can act as a cheap way to convert // an outboard tree to a combined one. for &case in crate::test::TEST_CASES { - println!("case {}", case); + println!("case {case}"); let input = make_test_input(case); let (encoded, _) = encode::encode(&input); let (outboard, _) = encode::outboard(&input); diff --git a/src/encode.rs b/src/encode.rs index e8dce8f..a719618 100644 --- a/src/encode.rs +++ b/src/encode.rs @@ -1377,7 +1377,7 @@ mod test { let mut output = Vec::new(); let mut encoder = Encoder::new(io::Cursor::new(&mut output)); encoder.write_all(input).unwrap(); - encoder.write(&[]).unwrap(); + encoder.write_all(&[]).unwrap(); let hash = encoder.finalize().unwrap(); assert_eq!((output, hash), encode(input)); assert_eq!(hash, blake3::hash(input));
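
For reviewers, a minimal sketch of how the new API is meant to be consumed (not part of the diff; it assumes the crate is depended on as `bao` with the `tokio_io` feature enabled):

```rust
use bao::decode::AsyncDecoder;
use tokio::io::AsyncReadExt;

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // Encode some input synchronously (the encoder is unchanged by this PR).
    let input = b"hello bao".to_vec();
    let (encoded, hash) = bao::encode::encode(&input);

    // Decode and verify it with the new async decoder. Any AsyncRead + Unpin
    // source works; a byte slice is used here for brevity.
    let mut decoder = AsyncDecoder::new(&encoded[..], &hash);
    let mut output = Vec::new();
    decoder.read_to_end(&mut output).await?;
    assert_eq!(input, output);
    Ok(())
}
```

AsyncSliceDecoder is driven the same way, with the extra slice_start/slice_len arguments shown in the tests above.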