From f2782dcaea6d06442fddf3f323654462e60ca0eb Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Fri, 27 Jan 2023 12:04:12 +0200 Subject: [PATCH 01/10] refactor: some cosmetic changes to make rust and clippy happy --- src/decode.rs | 4 ++-- src/encode.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/decode.rs b/src/decode.rs index af0cdb5..c48066c 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -861,7 +861,7 @@ mod test { let mut output = Vec::new(); let mut decoder = Decoder::new(&*zero_encoded, &zero_hash); decoder.read_to_end(&mut output).unwrap(); - assert_eq!(&output, &[]); + assert_eq!(output.len(), 0); // Decoding the empty tree with any other hash should fail. let mut output = Vec::new(); @@ -936,7 +936,7 @@ mod test { let mut decoder = Decoder::new(Cursor::new(&encoded), &hash); decoder.seek(SeekFrom::Start(case as u64)).unwrap(); decoder.read_to_end(&mut output).unwrap(); - assert_eq!(&output, &[]); + assert_eq!(output.len(), 0); // Seeking to EOF should fail if the root hash is wrong. let mut bad_hash_bytes = *hash.as_bytes(); diff --git a/src/encode.rs b/src/encode.rs index e8dce8f..a719618 100644 --- a/src/encode.rs +++ b/src/encode.rs @@ -1377,7 +1377,7 @@ mod test { let mut output = Vec::new(); let mut encoder = Encoder::new(io::Cursor::new(&mut output)); encoder.write_all(input).unwrap(); - encoder.write(&[]).unwrap(); + encoder.write_all(&[]).unwrap(); let hash = encoder.finalize().unwrap(); assert_eq!((output, hash), encode(input)); assert_eq!(hash, blake3::hash(input)); From c9d07006a88aa3218b3fdcf6549bdaa9b7866a41 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Fri, 27 Jan 2023 12:06:11 +0200 Subject: [PATCH 02/10] refactor: do not require Read on DecoderShared this is so DecoderShared can be used from async decoders --- src/decode.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/decode.rs b/src/decode.rs index c48066c..6c070a7 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -205,7 +205,7 @@ impl From for io::Error { // Shared between Decoder and SliceDecoder. #[derive(Clone)] -struct DecoderShared { +struct DecoderShared { input: T, outboard: Option, state: VerifyState, @@ -214,8 +214,8 @@ struct DecoderShared { buf_end: usize, } -impl DecoderShared { - fn new(input: T, outboard: Option, hash: &Hash) -> Self { +impl DecoderShared { + pub fn new(input: T, outboard: Option, hash: &Hash) -> Self { Self { input, outboard, @@ -240,7 +240,9 @@ impl DecoderShared { self.buf_start = 0; self.buf_end = 0; } +} +impl DecoderShared { // These bytes are always verified before going in the buffer. fn take_buffered_bytes(&mut self, output: &mut [u8]) -> usize { let take = cmp::min(self.buf_len(), output.len()); @@ -441,7 +443,7 @@ impl DecoderShared { } } -impl fmt::Debug for DecoderShared { +impl fmt::Debug for DecoderShared { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, From fa1407261ca36de3f38a260740e7eb94b666b3de Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Fri, 27 Jan 2023 13:02:55 +0200 Subject: [PATCH 03/10] feature: add minimal, optional support for tokio async for now just for the Decoder in inline and outboard mode, not yet for the SliceDecoder. --- Cargo.toml | 7 ++ src/decode.rs | 222 +++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 228 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 1806dc0..87b6281 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,12 @@ edition = "2018" arrayref = "0.3.5" arrayvec = "0.7.1" blake3 = "1.0.0" +tokio = { version = "1.24.2", features = [], default-features=false, optional = true } + +[features] +# todo: remove before merge +default = ["tokio_io"] +tokio_io = ["tokio"] [dev-dependencies] lazy_static = "1.3.0" @@ -23,3 +29,4 @@ tempfile = "3.1.0" rand_chacha = "0.3.1" rand_xorshift = "0.3.0" page_size = "0.4.1" +tokio = { version = "1.24.2", features = ["full"] } diff --git a/src/decode.rs b/src/decode.rs index 6c070a7..b4a3cde 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -215,7 +215,7 @@ struct DecoderShared { } impl DecoderShared { - pub fn new(input: T, outboard: Option, hash: &Hash) -> Self { + fn new(input: T, outboard: Option, hash: &Hash) -> Self { Self { input, outboard, @@ -456,6 +456,226 @@ impl fmt::Debug for DecoderShared { } } +#[cfg(feature = "tokio_io")] +mod tokio_io { + use super::{DecoderShared, Hash, NextRead}; + use std::{ + cmp, + convert::TryInto, + io, + pin::Pin, + task::{self, ready}, + }; + use tokio::io::{AsyncRead, ReadBuf}; + + // tokio flavour async io utilities, requiing AsyncRead + impl DecoderShared { + fn poll_read( + &mut self, + cx: &mut task::Context, + buf: &mut ReadBuf<'_>, + ) -> task::Poll> { + // Explicitly short-circuit zero-length reads. We're within our rights + // to buffer an internal chunk in this case, or to make progress if + // there's an empty chunk, but this matches the current behavior of + // SliceExtractor for zero-length slices. This might change in the + // future. + if buf.remaining() == 0 { + return task::Poll::Ready(Ok(())); + } + + // Otherwise try to verify a new chunk. + loop { + // If there are bytes in the internal buffer, just return those. + if self.buf_len() > 0 { + let n = cmp::min(buf.remaining(), self.buf_len()); + buf.put_slice(&self.buf[self.buf_start..self.buf_start + n]); + self.buf_start += n; + // if we are done with writing, go into the reading state + if self.buf_len() == 0 { + self.clear_buf(); + } + return task::Poll::Ready(Ok(())); + } + + match self.state.read_next() { + NextRead::Done => { + // This is EOF. We know the internal buffer is empty, + // because we checked it before this loop. + return task::Poll::Ready(Ok(())); + } + NextRead::Header => { + // ensure reading state, reading 8 bytes + // we might already be in the reading state, + // so we must not set buf_start to 0 + self.buf_end = 8; + // header comes from outboard if we have one, otherwise from input + ready!(self.poll_fill_buffer_from_input_or_outboard(cx))?; + self.state.feed_header(self.buf[0..8].try_into().unwrap()); + // we don't want to write the header, so we are done with the buffer contents + self.clear_buf(); + } + NextRead::Parent => { + // ensure reading state, reading 64 bytes + // we might already be in the reading state, + // so we must not set buf_start to 0 + self.buf_end = 64; + // parent comes from outboard if we have one, otherwise from input + ready!(self.poll_fill_buffer_from_input_or_outboard(cx))?; + self.state + .feed_parent(&self.buf[0..64].try_into().unwrap())?; + // we don't want to write the parent, so we are done with the buffer contents + self.clear_buf(); + } + NextRead::Chunk { + size, + finalization, + skip, + index, + } => { + // todo: add direct output optimization + + // ensure reading state, reading size bytes + // we might already be in the reading state, + // so we must not set buf_start to 0 + self.buf_end = size; + // chunk never comes from outboard + ready!(self.poll_fill_buffer_from_input(cx))?; + + // Hash it and push its hash into the VerifyState. This + // returns an error if the hash is bad. Otherwise, the + // chunk is verified. + let read_buf = &self.buf[0..size]; + let chunk_hash = blake3::guts::ChunkState::new(index) + .update(read_buf) + .finalize(finalization.is_root()); + self.state.feed_chunk(&chunk_hash)?; + + // we go into the writing state now, starting from skip + self.buf_start = skip; + // we should have something to write, + // unless the entire chunk was empty + debug_assert!(self.buf_len() > 0 || size == 0); + } + } + } + } + + fn poll_fill_buffer_from_input( + &mut self, + cx: &mut task::Context<'_>, + ) -> task::Poll> { + let mut buf = ReadBuf::new(&mut self.buf[..self.buf_end]); + buf.advance(self.buf_start); + let src = &mut self.input; + while buf.remaining() > 0 { + ready!(AsyncRead::poll_read(Pin::new(src), cx, &mut buf))?; + self.buf_start = buf.filled().len(); + } + task::Poll::Ready(Ok(())) + } + + fn poll_fill_buffer_from_outboard( + &mut self, + cx: &mut task::Context<'_>, + ) -> task::Poll> { + let mut buf = ReadBuf::new(&mut self.buf[..self.buf_end]); + buf.advance(self.buf_start); + let src = self.outboard.as_mut().unwrap(); + while buf.remaining() > 0 { + ready!(AsyncRead::poll_read(Pin::new(src), cx, &mut buf))?; + self.buf_start = buf.filled().len(); + } + task::Poll::Ready(Ok(())) + } + + fn poll_fill_buffer_from_input_or_outboard( + &mut self, + cx: &mut task::Context<'_>, + ) -> task::Poll> { + if self.outboard.is_some() { + self.poll_fill_buffer_from_outboard(cx) + } else { + self.poll_fill_buffer_from_input(cx) + } + } + } + + #[derive(Clone, Debug)] + pub struct AsyncDecoder { + shared: DecoderShared, + } + + impl AsyncDecoder { + pub fn new(inner: T, hash: &Hash) -> Self { + Self { + shared: DecoderShared::new(inner, None, hash), + } + } + } + + impl AsyncDecoder { + pub fn new_outboard(inner: T, outboard: O, hash: &Hash) -> Self { + Self { + shared: DecoderShared::new(inner, Some(outboard), hash), + } + } + + /// Return the underlying reader and the outboard reader, if any. If the `Decoder` was created + /// with `Decoder::new`, the outboard reader will be `None`. + pub fn into_inner(self) -> (T, Option) { + (self.shared.input, self.shared.outboard) + } + } + + impl AsyncRead for AsyncDecoder { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> task::Poll> { + self.shared.poll_read(cx, buf) + } + } + + #[cfg(test)] + mod tests { + use super::*; + use crate::{decode::make_test_input, encode}; + + #[tokio::test] + async fn test_async_decode() { + for &case in crate::test::TEST_CASES { + use tokio::io::AsyncReadExt; + println!("case {}", case); + let input = make_test_input(case); + let (encoded, hash) = { encode::encode(&input) }; + let mut output = Vec::new(); + let mut reader = AsyncDecoder::new(&encoded[..], &hash); + reader.read_to_end(&mut output).await.unwrap(); + assert_eq!(input, output); + } + } + + #[tokio::test] + async fn test_async_decode_outboard() { + for &case in crate::test::TEST_CASES { + use tokio::io::AsyncReadExt; + println!("case {}", case); + let input = make_test_input(case); + let (outboard, hash) = { encode::outboard(&input) }; + let mut output = Vec::new(); + let mut reader = AsyncDecoder::new_outboard(&input[..], &outboard[..], &hash); + reader.read_to_end(&mut output).await.unwrap(); + assert_eq!(input, output); + } + } + } +} + +#[cfg(feature = "tokio_io")] +pub use tokio_io::AsyncDecoder; + /// An incremental decoder, which reads and verifies the output of /// [`Encoder`](../encode/struct.Encoder.html). /// From 482f1116f532024d86d5425b71d914d122d3f409 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Fri, 27 Jan 2023 16:05:31 +0200 Subject: [PATCH 04/10] feature: add AsyncRead for AsyncSliceDecoder This requires some seek logic, but no actual seeking in the underlying inputs --- src/decode.rs | 275 +++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 272 insertions(+), 3 deletions(-) diff --git a/src/decode.rs b/src/decode.rs index b4a3cde..7d5bad3 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -458,6 +458,7 @@ impl fmt::Debug for DecoderShared { #[cfg(feature = "tokio_io")] mod tokio_io { + use super::{DecoderShared, Hash, NextRead}; use std::{ cmp, @@ -561,6 +562,74 @@ mod tokio_io { } } + // Returns Ok(true) to indicate the seek is finished. Note that both the + // Decoder and the SliceDecoder will use this method (which doesn't depend on + // io::Seek), but only the Decoder will call handle_seek_bookkeeping first. + // This may read a chunk, but it never leaves output bytes in the buffer, + // because the only time seeking reads a chunk it also skips the entire + // thing. + fn poll_handle_seek_read( + &mut self, + next: NextRead, + cx: &mut task::Context, + ) -> task::Poll> { + task::Poll::Ready(Ok(match next { + NextRead::Header => { + // ensure reading state, reading 8 bytes + // we might already be in the reading state, + // so we must not set buf_start to 0 + self.buf_end = 8; + // header comes from outboard if we have one, otherwise from input + ready!(self.poll_fill_buffer_from_input_or_outboard(cx))?; + self.state.feed_header(self.buf[0..8].try_into().unwrap()); + // we don't want to write the header, so we are done with the buffer contents + self.clear_buf(); + // not done yet + false + } + NextRead::Parent => { + // ensure reading state, reading 64 bytes + // we might already be in the reading state, + // so we must not set buf_start to 0 + self.buf_end = 64; + // parent comes from outboard if we have one, otherwise from input + ready!(self.poll_fill_buffer_from_input_or_outboard(cx))?; + self.state + .feed_parent(&self.buf[0..64].try_into().unwrap())?; + // we don't want to write the parent, so we are done with the buffer contents + self.clear_buf(); + // not done yet + false + } + NextRead::Chunk { + size, + finalization, + skip: _, + index, + } => { + // ensure reading state, reading size bytes + // we might already be in the reading state, + // so we must not set buf_start to 0 + self.buf_end = size; + // chunk never comes from outboard + ready!(self.poll_fill_buffer_from_input(cx))?; + + // Hash it and push its hash into the VerifyState. This + // returns an error if the hash is bad. Otherwise, the + // chunk is verified. + let read_buf = &self.buf[0..size]; + let chunk_hash = blake3::guts::ChunkState::new(index) + .update(read_buf) + .finalize(finalization.is_root()); + self.state.feed_chunk(&chunk_hash)?; + // we don't want to write the chunk, so we are done with the buffer contents + self.clear_buf(); + false + } + NextRead::Done => true, // The seek is done. + })) + } + fn poll_fill_buffer_from_input( &mut self, cx: &mut task::Context<'_>, @@ -638,16 +707,108 @@ mod tokio_io { } } + pub struct AsyncSliceDecoder { + shared: DecoderShared, + slice_start: u64, + slice_remaining: u64, + // If the caller requested no bytes, the extractor is still required to + // include a chunk. We're not required to verify it, but we want to + // aggressively check for extractor bugs. + need_fake_read: bool, + } + + impl AsyncSliceDecoder { + pub fn new(inner: T, hash: &Hash, slice_start: u64, slice_len: u64) -> Self { + Self { + shared: DecoderShared::new(inner, None, hash), + slice_start, + slice_remaining: slice_len, + need_fake_read: slice_len == 0, + } + } + + /// Return the underlying reader. + pub fn into_inner(self) -> T { + self.shared.input + } + } + + impl AsyncRead for AsyncSliceDecoder { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> task::Poll> { + // If we haven't done the initial seek yet, do the full seek loop + // first. Note that this will never leave any buffered output. The only + // scenario where handle_seek_read reads a chunk is if it needs to + // validate the final chunk, and then it skips the whole thing. + if self.shared.state.content_position() < self.slice_start { + loop { + // we don't keep internal state but just call seek_next again if the + // call to poll_handle_seek_read below returns pending. + let bookkeeping = self.shared.state.seek_next(self.slice_start); + // Note here, we skip to seek_bookkeeping_done without + // calling handle_seek_bookkeeping. That is, we never + // perform any underlying seeks. The slice extractor + // already took care of lining everything up for us. + let next = self.shared.state.seek_bookkeeping_done(bookkeeping); + let done = ready!(self.shared.poll_handle_seek_read(next, cx))?; + if done { + break; + } + } + debug_assert_eq!(0, self.shared.buf_len()); + } + + // We either just finished the seek (if any), or already did it during + // a previous call. Continue the read. Cap the output buffer to be at + // most the slice bytes remaining. + if self.need_fake_read { + // Read one byte and throw it away, just to verify a chunk. + let mut tmp = [0]; + let mut buf = ReadBuf::new(tmp.as_mut_slice()); + ready!(self.shared.poll_read(cx, &mut buf))?; + self.need_fake_read = false; + } else if self.slice_remaining > 0 { + // We still got bytes to read. + let filled0 = buf.filled().len(); + // This can read more than we need. But that is ok. + // Reading might need the buffer to read an entire chunk + ready!(self.shared.poll_read(cx, buf))?; + let read = (buf.filled().len() - filled0) as u64; + if read <= self.slice_remaining { + // just decrease the remaining bytes + self.slice_remaining -= read; + } else { + // We read more than we needed. + // Truncate the buffer and set remaining to 0. + let overread = (read - self.slice_remaining) as usize; + buf.set_filled(buf.filled().len() - overread); + self.slice_remaining = 0; + } + }; + task::Poll::Ready(Ok(())) + } + } + #[cfg(test)] mod tests { + use std::io::{Cursor, Read}; + + use tokio::io::AsyncReadExt; + use super::*; - use crate::{decode::make_test_input, encode}; + use crate::{ + decode::{make_test_input, SliceDecoder}, + encode, CHUNK_SIZE, HEADER_SIZE, + }; #[tokio::test] async fn test_async_decode() { for &case in crate::test::TEST_CASES { use tokio::io::AsyncReadExt; - println!("case {}", case); + println!("case {case}"); let input = make_test_input(case); let (encoded, hash) = { encode::encode(&input) }; let mut output = Vec::new(); @@ -661,7 +822,7 @@ mod tokio_io { async fn test_async_decode_outboard() { for &case in crate::test::TEST_CASES { use tokio::io::AsyncReadExt; - println!("case {}", case); + println!("case {case}"); let input = make_test_input(case); let (outboard, hash) = { encode::outboard(&input) }; let mut output = Vec::new(); @@ -670,6 +831,114 @@ mod tokio_io { assert_eq!(input, output); } } + + #[tokio::test] + async fn test_async_slices() { + for &case in crate::test::TEST_CASES { + // for case in [1025] { + let input = make_test_input(case); + let (encoded, hash) = encode::encode(&input); + // Also make an outboard encoding, to test that case. + let (outboard, outboard_hash) = encode::outboard(&input); + assert_eq!(hash, outboard_hash); + for &slice_start in crate::test::TEST_CASES { + let expected_start = cmp::min(input.len(), slice_start); + let slice_lens = [0, 1, 2, CHUNK_SIZE - 1, CHUNK_SIZE, CHUNK_SIZE + 1]; + // let slice_lens = [CHUNK_SIZE - 1, CHUNK_SIZE, CHUNK_SIZE + 1]; + for &slice_len in slice_lens.iter() { + println!("\ncase {case} start {slice_start} len {slice_len}"); + let expected_end = cmp::min(input.len(), slice_start + slice_len); + let expected_output = &input[expected_start..expected_end]; + let mut slice = Vec::new(); + let mut extractor = encode::SliceExtractor::new( + Cursor::new(&encoded), + slice_start as u64, + slice_len as u64, + ); + extractor.read_to_end(&mut slice).unwrap(); + + // Make sure the outboard extractor produces the same output. + let mut slice_from_outboard = Vec::new(); + let mut extractor = encode::SliceExtractor::new_outboard( + Cursor::new(&input), + Cursor::new(&outboard), + slice_start as u64, + slice_len as u64, + ); + extractor.read_to_end(&mut slice_from_outboard).unwrap(); + assert_eq!(slice, slice_from_outboard); + + let mut output = Vec::new(); + let mut reader = AsyncSliceDecoder::new( + &*slice, + &hash, + slice_start as u64, + slice_len as u64, + ); + reader.read_to_end(&mut output).await.unwrap(); + assert_eq!(expected_output, &*output); + } + } + } + } + + #[tokio::test] + async fn test_async_corrupted_slice() { + let input = make_test_input(20_000); + let slice_start = 5_000; + let slice_len = 10_000; + let (encoded, hash) = encode::encode(&input); + + // Slice out the middle 10_000 bytes; + let mut slice = Vec::new(); + let mut extractor = encode::SliceExtractor::new( + Cursor::new(&encoded), + slice_start as u64, + slice_len as u64, + ); + extractor.read_to_end(&mut slice).unwrap(); + + // First confirm that the regular decode works. + let mut output = Vec::new(); + let mut reader = + SliceDecoder::new(&*slice, &hash, slice_start as u64, slice_len as u64); + reader.read_to_end(&mut output).unwrap(); + assert_eq!(&input[slice_start..][..slice_len], &*output); + + // Also confirm that the outboard slice extractor gives the same slice. + let (outboard, outboard_hash) = encode::outboard(&input); + assert_eq!(hash, outboard_hash); + let mut slice_from_outboard = Vec::new(); + let mut extractor = encode::SliceExtractor::new_outboard( + Cursor::new(&input), + Cursor::new(&outboard), + slice_start as u64, + slice_len as u64, + ); + extractor.read_to_end(&mut slice_from_outboard).unwrap(); + assert_eq!(slice, slice_from_outboard); + + // Now confirm that flipping bits anywhere in the slice other than the + // length header will corrupt it. Tweaking the length header doesn't + // always break slice decoding, because the only thing its guaranteed + // to break is the final chunk, and this slice doesn't include the + // final chunk. + let mut i = HEADER_SIZE; + while i < slice.len() { + let mut slice_clone = slice.clone(); + slice_clone[i] ^= 1; + let mut reader = AsyncSliceDecoder::new( + &*slice_clone, + &hash, + slice_start as u64, + slice_len as u64, + ); + output.clear(); + let err = reader.read_to_end(&mut output).await.unwrap_err(); + assert_eq!(io::ErrorKind::InvalidData, err.kind()); + i += 32; + } + } } } From bc8af4ecb7cbea7569ddddf0e50280bddcb0af0f Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Sat, 28 Jan 2023 12:51:16 +0200 Subject: [PATCH 05/10] refactor: explicit states for reading and writing --- src/decode.rs | 266 +++++++++++++++++++++++++++++++------------------- 1 file changed, 166 insertions(+), 100 deletions(-) diff --git a/src/decode.rs b/src/decode.rs index 7d5bad3..615397d 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -465,45 +465,30 @@ mod tokio_io { convert::TryInto, io, pin::Pin, - task::{self, ready}, + task::{self, ready, Context, Poll}, }; use tokio::io::{AsyncRead, ReadBuf}; // tokio flavour async io utilities, requiing AsyncRead impl DecoderShared { - fn poll_read( - &mut self, - cx: &mut task::Context, - buf: &mut ReadBuf<'_>, - ) -> task::Poll> { - // Explicitly short-circuit zero-length reads. We're within our rights - // to buffer an internal chunk in this case, or to make progress if - // there's an empty chunk, but this matches the current behavior of - // SliceExtractor for zero-length slices. This might change in the - // future. - if buf.remaining() == 0 { - return task::Poll::Ready(Ok(())); - } + /// write from the internal buffer to the output buffer + fn write_output(&mut self, buf: &mut ReadBuf<'_>) { + let n = cmp::min(buf.remaining(), self.buf_len()); + buf.put_slice(&self.buf[self.buf_start..self.buf_start + n]); + self.buf_start += n; + } - // Otherwise try to verify a new chunk. + /// fills the internal buffer from the input or outboard + /// + /// will return Poll::Pending until we have a chunk in the internal buffer to be read, + /// or we reach EOF + fn poll_input(&mut self, cx: &mut Context) -> Poll> { loop { - // If there are bytes in the internal buffer, just return those. - if self.buf_len() > 0 { - let n = cmp::min(buf.remaining(), self.buf_len()); - buf.put_slice(&self.buf[self.buf_start..self.buf_start + n]); - self.buf_start += n; - // if we are done with writing, go into the reading state - if self.buf_len() == 0 { - self.clear_buf(); - } - return task::Poll::Ready(Ok(())); - } - match self.state.read_next() { NextRead::Done => { // This is EOF. We know the internal buffer is empty, // because we checked it before this loop. - return task::Poll::Ready(Ok(())); + break Poll::Ready(Ok(())); } NextRead::Header => { // ensure reading state, reading 8 bytes @@ -557,6 +542,8 @@ mod tokio_io { // we should have something to write, // unless the entire chunk was empty debug_assert!(self.buf_len() > 0 || size == 0); + + break Poll::Ready(Ok(())); } } } @@ -571,9 +558,9 @@ mod tokio_io { fn poll_handle_seek_read( &mut self, next: NextRead, - cx: &mut task::Context, - ) -> task::Poll> { - task::Poll::Ready(Ok(match next { + cx: &mut Context, + ) -> Poll> { + Poll::Ready(Ok(match next { NextRead::Header => { // ensure reading state, reading 8 bytes // we might already be in the reading state, @@ -632,8 +619,8 @@ mod tokio_io { fn poll_fill_buffer_from_input( &mut self, - cx: &mut task::Context<'_>, - ) -> task::Poll> { + cx: &mut Context<'_>, + ) -> Poll> { let mut buf = ReadBuf::new(&mut self.buf[..self.buf_end]); buf.advance(self.buf_start); let src = &mut self.input; @@ -641,13 +628,13 @@ mod tokio_io { ready!(AsyncRead::poll_read(Pin::new(src), cx, &mut buf))?; self.buf_start = buf.filled().len(); } - task::Poll::Ready(Ok(())) + Poll::Ready(Ok(())) } fn poll_fill_buffer_from_outboard( &mut self, - cx: &mut task::Context<'_>, - ) -> task::Poll> { + cx: &mut Context<'_>, + ) -> Poll> { let mut buf = ReadBuf::new(&mut self.buf[..self.buf_end]); buf.advance(self.buf_start); let src = self.outboard.as_mut().unwrap(); @@ -655,13 +642,13 @@ mod tokio_io { ready!(AsyncRead::poll_read(Pin::new(src), cx, &mut buf))?; self.buf_start = buf.filled().len(); } - task::Poll::Ready(Ok(())) + Poll::Ready(Ok(())) } fn poll_fill_buffer_from_input_or_outboard( &mut self, - cx: &mut task::Context<'_>, - ) -> task::Poll> { + cx: &mut Context<'_>, + ) -> Poll> { if self.outboard.is_some() { self.poll_fill_buffer_from_outboard(cx) } else { @@ -670,44 +657,86 @@ mod tokio_io { } } - #[derive(Clone, Debug)] - pub struct AsyncDecoder { - shared: DecoderShared, + #[derive(Debug)] + enum DecoderState { + /// we are reading from the underlying reader + Reading(Box>), + /// we are being polled for output + Output(Box>), } + #[derive(Debug)] + pub struct AsyncDecoder(Option>); + impl AsyncDecoder { pub fn new(inner: T, hash: &Hash) -> Self { - Self { - shared: DecoderShared::new(inner, None, hash), - } + let state = DecoderShared::new(inner, None, hash); + Self(Some(DecoderState::Reading(Box::new(state)))) } } impl AsyncDecoder { pub fn new_outboard(inner: T, outboard: O, hash: &Hash) -> Self { - Self { - shared: DecoderShared::new(inner, Some(outboard), hash), - } - } - - /// Return the underlying reader and the outboard reader, if any. If the `Decoder` was created - /// with `Decoder::new`, the outboard reader will be `None`. - pub fn into_inner(self) -> (T, Option) { - (self.shared.input, self.shared.outboard) + let state = DecoderShared::new(inner, Some(outboard), hash); + Self(Some(DecoderState::Reading(Box::new(state)))) } } impl AsyncRead for AsyncDecoder { fn poll_read( mut self: Pin<&mut Self>, - cx: &mut task::Context<'_>, + cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, - ) -> task::Poll> { - self.shared.poll_read(cx, buf) + ) -> Poll> { + // on a zero length read, we do nothing whatsoever + if buf.remaining() == 0 { + return Poll::Ready(Ok(())); + } + + loop { + match self.0.take().unwrap() { + DecoderState::Reading(mut shared) => { + match shared.poll_input(cx) { + Poll::Ready(Ok(())) => { + // we have read a chunk from the underlying reader + // go to output state + self.0 = Some(DecoderState::Output(shared)); + continue; + } + Poll::Ready(Err(e)) => { + // we got an error from the underlying io + // stay in reading state + self.0 = Some(DecoderState::Reading(shared)); + break Poll::Ready(Err(e)); + } + Poll::Pending => { + // we don't have a complete chunk yet + // stay in reading state + self.0 = Some(DecoderState::Output(shared)); + break Poll::Pending; + } + } + } + DecoderState::Output(mut shared) => { + shared.write_output(buf); + if shared.buf_len() == 0 { + // the caller has consumed all the data in the buffer + // go to reading state + shared.clear_buf(); + self.0 = Some(DecoderState::Reading(shared)) + } else { + // we still have data in the buffer + // stay in output state + self.0 = Some(DecoderState::Output(shared)) + }; + break Poll::Ready(Ok(())); + } + } + } } } - pub struct AsyncSliceDecoder { + pub struct SliceDecoderInner { shared: DecoderShared, slice_start: u64, slice_remaining: u64, @@ -717,28 +746,8 @@ mod tokio_io { need_fake_read: bool, } - impl AsyncSliceDecoder { - pub fn new(inner: T, hash: &Hash, slice_start: u64, slice_len: u64) -> Self { - Self { - shared: DecoderShared::new(inner, None, hash), - slice_start, - slice_remaining: slice_len, - need_fake_read: slice_len == 0, - } - } - - /// Return the underlying reader. - pub fn into_inner(self) -> T { - self.shared.input - } - } - - impl AsyncRead for AsyncSliceDecoder { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut task::Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> task::Poll> { + impl SliceDecoderInner { + fn poll_input(&mut self, cx: &mut Context<'_>) -> Poll> { // If we haven't done the initial seek yet, do the full seek loop // first. Note that this will never leave any buffered output. The only // scenario where handle_seek_read reads a chunk is if it needs to @@ -766,36 +775,94 @@ mod tokio_io { // most the slice bytes remaining. if self.need_fake_read { // Read one byte and throw it away, just to verify a chunk. - let mut tmp = [0]; - let mut buf = ReadBuf::new(tmp.as_mut_slice()); - ready!(self.shared.poll_read(cx, &mut buf))?; + ready!(self.shared.poll_input(cx))?; + self.shared.clear_buf(); self.need_fake_read = false; } else if self.slice_remaining > 0 { - // We still got bytes to read. - let filled0 = buf.filled().len(); - // This can read more than we need. But that is ok. - // Reading might need the buffer to read an entire chunk - ready!(self.shared.poll_read(cx, buf))?; - let read = (buf.filled().len() - filled0) as u64; - if read <= self.slice_remaining { - // just decrease the remaining bytes - self.slice_remaining -= read; + ready!(self.shared.poll_input(cx))?; + let len = self.shared.buf_len() as u64; + if len <= self.slice_remaining { + self.slice_remaining -= len; } else { - // We read more than we needed. - // Truncate the buffer and set remaining to 0. - let overread = (read - self.slice_remaining) as usize; - buf.set_filled(buf.filled().len() - overread); + // We read more than we needed. Truncate the buffer. + self.shared.buf_end -= (len - self.slice_remaining) as usize; self.slice_remaining = 0; } }; - task::Poll::Ready(Ok(())) + Poll::Ready(Ok(())) + } + } + + enum SliceDecoderState { + /// we are reading from the underlying reader + Reading(Box>), + /// we are being polled for output + Output(Box>), + } + + pub struct AsyncSliceDecoder(Option>); + + impl AsyncSliceDecoder { + pub fn new(inner: T, hash: &Hash, slice_start: u64, slice_len: u64) -> Self { + let state = SliceDecoderInner { + shared: DecoderShared::new(inner, None, hash), + slice_start, + slice_remaining: slice_len, + need_fake_read: slice_len == 0, + }; + Self(Some(SliceDecoderState::Reading(Box::new(state)))) + } + } + + impl AsyncRead for AsyncSliceDecoder { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + // on a zero length read, we do nothing whatsoever + if buf.remaining() == 0 { + return Poll::Ready(Ok(())); + } + + loop { + match self.0.take().unwrap() { + SliceDecoderState::Reading(mut state) => match state.poll_input(cx) { + Poll::Ready(Ok(())) => { + self.0 = Some(SliceDecoderState::Output(state)); + continue; + } + Poll::Ready(Err(e)) => { + self.0 = Some(SliceDecoderState::Reading(state)); + break Poll::Ready(Err(e)); + } + Poll::Pending => { + self.0 = Some(SliceDecoderState::Reading(state)); + break Poll::Pending; + } + }, + SliceDecoderState::Output(mut state) => { + state.shared.write_output(buf); + if state.shared.buf_len() == 0 { + // the caller has consumed all the data in the buffer + // go to reading state + state.shared.clear_buf(); + self.0 = Some(SliceDecoderState::Reading(state)) + } else { + // we still have data in the buffer + // stay in output state + self.0 = Some(SliceDecoderState::Output(state)) + }; + break Poll::Ready(Ok(())); + } + } + } } } #[cfg(test)] mod tests { use std::io::{Cursor, Read}; - use tokio::io::AsyncReadExt; use super::*; @@ -835,7 +902,6 @@ mod tokio_io { #[tokio::test] async fn test_async_slices() { for &case in crate::test::TEST_CASES { - // for case in [1025] { let input = make_test_input(case); let (encoded, hash) = encode::encode(&input); // Also make an outboard encoding, to test that case. @@ -1554,7 +1620,7 @@ mod test { // be exactly the same as the entire encoded tree. This can act as a cheap way to convert // an outboard tree to a combined one. for &case in crate::test::TEST_CASES { - println!("case {}", case); + println!("case {case}"); let input = make_test_input(case); let (encoded, _) = encode::encode(&input); let (outboard, _) = encode::outboard(&input); From 09ecead343933a3c4d4e335331e3fa176c64e6d5 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Sat, 28 Jan 2023 16:39:40 +0200 Subject: [PATCH 06/10] refactor: avoid the option just give the state a case Done so we can implement take --- src/decode.rs | 58 +++++++++++++++++++++++++++++++++++---------------- 1 file changed, 40 insertions(+), 18 deletions(-) diff --git a/src/decode.rs b/src/decode.rs index 615397d..195410e 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -465,7 +465,7 @@ mod tokio_io { convert::TryInto, io, pin::Pin, - task::{self, ready, Context, Poll}, + task::{ready, Context, Poll}, }; use tokio::io::{AsyncRead, ReadBuf}; @@ -663,22 +663,30 @@ mod tokio_io { Reading(Box>), /// we are being polled for output Output(Box>), + /// we are done + Done, + } + + impl DecoderState { + fn take(&mut self) -> Self { + std::mem::replace(self, DecoderState::Done) + } } #[derive(Debug)] - pub struct AsyncDecoder(Option>); + pub struct AsyncDecoder(DecoderState); impl AsyncDecoder { pub fn new(inner: T, hash: &Hash) -> Self { let state = DecoderShared::new(inner, None, hash); - Self(Some(DecoderState::Reading(Box::new(state)))) + Self(DecoderState::Reading(Box::new(state))) } } impl AsyncDecoder { pub fn new_outboard(inner: T, outboard: O, hash: &Hash) -> Self { let state = DecoderShared::new(inner, Some(outboard), hash); - Self(Some(DecoderState::Reading(Box::new(state)))) + Self(DecoderState::Reading(Box::new(state))) } } @@ -694,25 +702,25 @@ mod tokio_io { } loop { - match self.0.take().unwrap() { + match self.0.take() { DecoderState::Reading(mut shared) => { match shared.poll_input(cx) { Poll::Ready(Ok(())) => { // we have read a chunk from the underlying reader // go to output state - self.0 = Some(DecoderState::Output(shared)); + self.0 = DecoderState::Output(shared); continue; } Poll::Ready(Err(e)) => { // we got an error from the underlying io // stay in reading state - self.0 = Some(DecoderState::Reading(shared)); + self.0 = DecoderState::Reading(shared); break Poll::Ready(Err(e)); } Poll::Pending => { // we don't have a complete chunk yet // stay in reading state - self.0 = Some(DecoderState::Output(shared)); + self.0 = DecoderState::Output(shared); break Poll::Pending; } } @@ -723,14 +731,17 @@ mod tokio_io { // the caller has consumed all the data in the buffer // go to reading state shared.clear_buf(); - self.0 = Some(DecoderState::Reading(shared)) + self.0 = DecoderState::Reading(shared); } else { // we still have data in the buffer // stay in output state - self.0 = Some(DecoderState::Output(shared)) + self.0 = DecoderState::Output(shared); }; break Poll::Ready(Ok(())); } + DecoderState::Done => { + break Poll::Ready(Ok(())); + } } } } @@ -798,9 +809,17 @@ mod tokio_io { Reading(Box>), /// we are being polled for output Output(Box>), + /// we are done + Done, } - pub struct AsyncSliceDecoder(Option>); + impl SliceDecoderState { + fn take(&mut self) -> Self { + std::mem::replace(self, SliceDecoderState::Done) + } + } + + pub struct AsyncSliceDecoder(SliceDecoderState); impl AsyncSliceDecoder { pub fn new(inner: T, hash: &Hash, slice_start: u64, slice_len: u64) -> Self { @@ -810,7 +829,7 @@ mod tokio_io { slice_remaining: slice_len, need_fake_read: slice_len == 0, }; - Self(Some(SliceDecoderState::Reading(Box::new(state)))) + Self(SliceDecoderState::Reading(Box::new(state))) } } @@ -826,18 +845,18 @@ mod tokio_io { } loop { - match self.0.take().unwrap() { + match self.0.take() { SliceDecoderState::Reading(mut state) => match state.poll_input(cx) { Poll::Ready(Ok(())) => { - self.0 = Some(SliceDecoderState::Output(state)); + self.0 = SliceDecoderState::Output(state); continue; } Poll::Ready(Err(e)) => { - self.0 = Some(SliceDecoderState::Reading(state)); + self.0 = SliceDecoderState::Reading(state); break Poll::Ready(Err(e)); } Poll::Pending => { - self.0 = Some(SliceDecoderState::Reading(state)); + self.0 = SliceDecoderState::Reading(state); break Poll::Pending; } }, @@ -847,14 +866,17 @@ mod tokio_io { // the caller has consumed all the data in the buffer // go to reading state state.shared.clear_buf(); - self.0 = Some(SliceDecoderState::Reading(state)) + self.0 = SliceDecoderState::Reading(state) } else { // we still have data in the buffer // stay in output state - self.0 = Some(SliceDecoderState::Output(state)) + self.0 = SliceDecoderState::Output(state) }; break Poll::Ready(Ok(())); } + SliceDecoderState::Done => { + break Poll::Ready(Ok(())); + } } } } From d44f5c7408b6251021462d9e0906754bf7f111bd Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Sat, 28 Jan 2023 17:57:36 +0200 Subject: [PATCH 07/10] refactor: resolve weirdness with using buf_start to track end we don't have to do this anymore now that we properly track state --- src/decode.rs | 44 ++++++++++++++++++-------------------------- 1 file changed, 18 insertions(+), 26 deletions(-) diff --git a/src/decode.rs b/src/decode.rs index 195410e..cc483bf 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -493,10 +493,8 @@ mod tokio_io { NextRead::Header => { // ensure reading state, reading 8 bytes // we might already be in the reading state, - // so we must not set buf_start to 0 - self.buf_end = 8; // header comes from outboard if we have one, otherwise from input - ready!(self.poll_fill_buffer_from_input_or_outboard(cx))?; + ready!(self.poll_fill_buffer_from_input_or_outboard(8, cx))?; self.state.feed_header(self.buf[0..8].try_into().unwrap()); // we don't want to write the header, so we are done with the buffer contents self.clear_buf(); @@ -504,10 +502,8 @@ mod tokio_io { NextRead::Parent => { // ensure reading state, reading 64 bytes // we might already be in the reading state, - // so we must not set buf_start to 0 - self.buf_end = 64; // parent comes from outboard if we have one, otherwise from input - ready!(self.poll_fill_buffer_from_input_or_outboard(cx))?; + ready!(self.poll_fill_buffer_from_input_or_outboard(64, cx))?; self.state .feed_parent(&self.buf[0..64].try_into().unwrap())?; // we don't want to write the parent, so we are done with the buffer contents @@ -524,9 +520,8 @@ mod tokio_io { // ensure reading state, reading size bytes // we might already be in the reading state, // so we must not set buf_start to 0 - self.buf_end = size; // chunk never comes from outboard - ready!(self.poll_fill_buffer_from_input(cx))?; + ready!(self.poll_fill_buffer_from_input(size, cx))?; // Hash it and push its hash into the VerifyState. This // returns an error if the hash is bad. Otherwise, the @@ -564,10 +559,8 @@ mod tokio_io { NextRead::Header => { // ensure reading state, reading 8 bytes // we might already be in the reading state, - // so we must not set buf_start to 0 - self.buf_end = 8; // header comes from outboard if we have one, otherwise from input - ready!(self.poll_fill_buffer_from_input_or_outboard(cx))?; + ready!(self.poll_fill_buffer_from_input_or_outboard(8, cx))?; self.state.feed_header(self.buf[0..8].try_into().unwrap()); // we don't want to write the header, so we are done with the buffer contents self.clear_buf(); @@ -577,10 +570,8 @@ mod tokio_io { NextRead::Parent => { // ensure reading state, reading 64 bytes // we might already be in the reading state, - // so we must not set buf_start to 0 - self.buf_end = 64; // parent comes from outboard if we have one, otherwise from input - ready!(self.poll_fill_buffer_from_input_or_outboard(cx))?; + ready!(self.poll_fill_buffer_from_input_or_outboard(64, cx))?; self.state .feed_parent(&self.buf[0..64].try_into().unwrap())?; // we don't want to write the parent, so we are done with the buffer contents @@ -596,10 +587,8 @@ mod tokio_io { } => { // ensure reading state, reading size bytes // we might already be in the reading state, - // so we must not set buf_start to 0 - self.buf_end = size; // chunk never comes from outboard - ready!(self.poll_fill_buffer_from_input(cx))?; + ready!(self.poll_fill_buffer_from_input(size, cx))?; // Hash it and push its hash into the VerifyState. This // returns an error if the hash is bad. Otherwise, the @@ -619,40 +608,43 @@ mod tokio_io { fn poll_fill_buffer_from_input( &mut self, + size: usize, cx: &mut Context<'_>, ) -> Poll> { - let mut buf = ReadBuf::new(&mut self.buf[..self.buf_end]); - buf.advance(self.buf_start); + let mut buf = ReadBuf::new(&mut self.buf[..size]); + buf.advance(self.buf_end); let src = &mut self.input; while buf.remaining() > 0 { ready!(AsyncRead::poll_read(Pin::new(src), cx, &mut buf))?; - self.buf_start = buf.filled().len(); + self.buf_end = buf.filled().len(); } Poll::Ready(Ok(())) } fn poll_fill_buffer_from_outboard( &mut self, + size: usize, cx: &mut Context<'_>, ) -> Poll> { - let mut buf = ReadBuf::new(&mut self.buf[..self.buf_end]); - buf.advance(self.buf_start); + let mut buf = ReadBuf::new(&mut self.buf[..size]); + buf.advance(self.buf_end); let src = self.outboard.as_mut().unwrap(); while buf.remaining() > 0 { ready!(AsyncRead::poll_read(Pin::new(src), cx, &mut buf))?; - self.buf_start = buf.filled().len(); + self.buf_end = buf.filled().len(); } Poll::Ready(Ok(())) } fn poll_fill_buffer_from_input_or_outboard( &mut self, + size: usize, cx: &mut Context<'_>, ) -> Poll> { if self.outboard.is_some() { - self.poll_fill_buffer_from_outboard(cx) + self.poll_fill_buffer_from_outboard(size, cx) } else { - self.poll_fill_buffer_from_input(cx) + self.poll_fill_buffer_from_input(size, cx) } } } @@ -1031,7 +1023,7 @@ mod tokio_io { } #[cfg(feature = "tokio_io")] -pub use tokio_io::AsyncDecoder; +pub use tokio_io::{AsyncDecoder, AsyncSliceDecoder}; /// An incremental decoder, which reads and verifies the output of /// [`Encoder`](../encode/struct.Encoder.html). From c1f7c9c6bbdaacff4d0f8ab888057f04157b1c07 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Mon, 30 Jan 2023 13:53:54 +0200 Subject: [PATCH 08/10] refactor: some type aliases for clippy also update comments --- src/decode.rs | 128 +++++++++++++++++++++++--------------------------- 1 file changed, 60 insertions(+), 68 deletions(-) diff --git a/src/decode.rs b/src/decode.rs index cc483bf..7c77f41 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -486,13 +486,11 @@ mod tokio_io { loop { match self.state.read_next() { NextRead::Done => { - // This is EOF. We know the internal buffer is empty, - // because we checked it before this loop. + // This is EOF. break Poll::Ready(Ok(())); } NextRead::Header => { - // ensure reading state, reading 8 bytes - // we might already be in the reading state, + // read 8 bytes // header comes from outboard if we have one, otherwise from input ready!(self.poll_fill_buffer_from_input_or_outboard(8, cx))?; self.state.feed_header(self.buf[0..8].try_into().unwrap()); @@ -500,8 +498,7 @@ mod tokio_io { self.clear_buf(); } NextRead::Parent => { - // ensure reading state, reading 64 bytes - // we might already be in the reading state, + // read 64 bytes // parent comes from outboard if we have one, otherwise from input ready!(self.poll_fill_buffer_from_input_or_outboard(64, cx))?; self.state @@ -515,11 +512,7 @@ mod tokio_io { skip, index, } => { - // todo: add direct output optimization - - // ensure reading state, reading size bytes - // we might already be in the reading state, - // so we must not set buf_start to 0 + // read size bytes // chunk never comes from outboard ready!(self.poll_fill_buffer_from_input(size, cx))?; @@ -557,8 +550,7 @@ mod tokio_io { ) -> Poll> { Poll::Ready(Ok(match next { NextRead::Header => { - // ensure reading state, reading 8 bytes - // we might already be in the reading state, + // read 8 bytes // header comes from outboard if we have one, otherwise from input ready!(self.poll_fill_buffer_from_input_or_outboard(8, cx))?; self.state.feed_header(self.buf[0..8].try_into().unwrap()); @@ -568,8 +560,7 @@ mod tokio_io { false } NextRead::Parent => { - // ensure reading state, reading 64 bytes - // we might already be in the reading state, + // read 64 bytes // parent comes from outboard if we have one, otherwise from input ready!(self.poll_fill_buffer_from_input_or_outboard(64, cx))?; self.state @@ -585,8 +576,7 @@ mod tokio_io { skip: _, index, } => { - // ensure reading state, reading size bytes - // we might already be in the reading state, + // read size bytes // chunk never comes from outboard ready!(self.poll_fill_buffer_from_input(size, cx))?; @@ -649,40 +639,27 @@ mod tokio_io { } } - #[derive(Debug)] + /// type alias to make clippy happy + type BoxedDecoderShared = Pin>>; + + /// state of the decoder + /// + /// This is the decoder, but it is a separate type so it can be private. + /// The public AsyncDecoder just wraps this. enum DecoderState { /// we are reading from the underlying reader - Reading(Box>), + Reading(BoxedDecoderShared), /// we are being polled for output - Output(Box>), - /// we are done - Done, + Output(BoxedDecoderShared), + /// invalid state + Invalid, } impl DecoderState { fn take(&mut self) -> Self { - std::mem::replace(self, DecoderState::Done) + std::mem::replace(self, DecoderState::Invalid) } - } - - #[derive(Debug)] - pub struct AsyncDecoder(DecoderState); - impl AsyncDecoder { - pub fn new(inner: T, hash: &Hash) -> Self { - let state = DecoderShared::new(inner, None, hash); - Self(DecoderState::Reading(Box::new(state))) - } - } - - impl AsyncDecoder { - pub fn new_outboard(inner: T, outboard: O, hash: &Hash) -> Self { - let state = DecoderShared::new(inner, Some(outboard), hash); - Self(DecoderState::Reading(Box::new(state))) - } - } - - impl AsyncRead for AsyncDecoder { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -694,44 +671,33 @@ mod tokio_io { } loop { - match self.0.take() { - DecoderState::Reading(mut shared) => { - match shared.poll_input(cx) { - Poll::Ready(Ok(())) => { - // we have read a chunk from the underlying reader - // go to output state - self.0 = DecoderState::Output(shared); - continue; - } - Poll::Ready(Err(e)) => { - // we got an error from the underlying io - // stay in reading state - self.0 = DecoderState::Reading(shared); - break Poll::Ready(Err(e)); - } - Poll::Pending => { - // we don't have a complete chunk yet - // stay in reading state - self.0 = DecoderState::Output(shared); - break Poll::Pending; - } + match self.take() { + Self::Reading(mut shared) => { + let res = shared.poll_input(cx); + if let Poll::Ready(Ok(())) = res { + // we have read a chunk from the underlying reader + // go to output state + *self = Self::Output(shared); + continue; } + *self = Self::Reading(shared); + break res; } - DecoderState::Output(mut shared) => { + Self::Output(mut shared) => { shared.write_output(buf); - if shared.buf_len() == 0 { + *self = if shared.buf_len() == 0 { // the caller has consumed all the data in the buffer // go to reading state shared.clear_buf(); - self.0 = DecoderState::Reading(shared); + Self::Reading(shared) } else { // we still have data in the buffer // stay in output state - self.0 = DecoderState::Output(shared); + Self::Output(shared) }; break Poll::Ready(Ok(())); } - DecoderState::Done => { + DecoderState::Invalid => { break Poll::Ready(Ok(())); } } @@ -739,6 +705,32 @@ mod tokio_io { } } + pub struct AsyncDecoder(DecoderState); + + impl AsyncDecoder { + pub fn new(inner: T, hash: &Hash) -> Self { + let state = DecoderShared::new(inner, None, hash); + Self(DecoderState::Reading(Box::pin(state))) + } + } + + impl AsyncDecoder { + pub fn new_outboard(inner: T, outboard: O, hash: &Hash) -> Self { + let state = DecoderShared::new(inner, Some(outboard), hash); + Self(DecoderState::Reading(Box::pin(state))) + } + } + + impl AsyncRead for AsyncDecoder { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.0).poll_read(cx, buf) + } + } + pub struct SliceDecoderInner { shared: DecoderShared, slice_start: u64, From 44943bdc7d1e4f6cfeeb892878c62bcb2f6fb2d9 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Wed, 8 Feb 2023 10:26:01 +0200 Subject: [PATCH 09/10] refactor: add clone and debug --- src/decode.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/decode.rs b/src/decode.rs index 7c77f41..681165a 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -731,7 +731,8 @@ mod tokio_io { } } - pub struct SliceDecoderInner { + #[derive(Clone, Debug)] + struct SliceDecoderInner { shared: DecoderShared, slice_start: u64, slice_remaining: u64, @@ -788,6 +789,7 @@ mod tokio_io { } } + #[derive(Clone, Debug)] enum SliceDecoderState { /// we are reading from the underlying reader Reading(Box>), @@ -803,6 +805,7 @@ mod tokio_io { } } + #[derive(Clone, Debug)] pub struct AsyncSliceDecoder(SliceDecoderState); impl AsyncSliceDecoder { From 2329a8f946ae644bb29e4b8243952244412ca7c0 Mon Sep 17 00:00:00 2001 From: Ruediger Klaehn Date: Wed, 8 Feb 2023 10:34:49 +0200 Subject: [PATCH 10/10] feature: add into_inner This can always succeed, because the SliceDecoder will never implement Seek --- src/decode.rs | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/decode.rs b/src/decode.rs index 681165a..2e4cc99 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -795,13 +795,13 @@ mod tokio_io { Reading(Box>), /// we are being polled for output Output(Box>), - /// we are done - Done, + /// value so we can implement take. If you see this, you've found a bug. + Taken, } impl SliceDecoderState { fn take(&mut self) -> Self { - std::mem::replace(self, SliceDecoderState::Done) + std::mem::replace(self, SliceDecoderState::Taken) } } @@ -818,6 +818,14 @@ mod tokio_io { }; Self(SliceDecoderState::Reading(Box::new(state))) } + + pub fn into_inner(self) -> T { + match self.0 { + SliceDecoderState::Reading(state) => state.shared.input, + SliceDecoderState::Output(state) => state.shared.input, + SliceDecoderState::Taken => unreachable!(), + } + } } impl AsyncRead for AsyncSliceDecoder { @@ -861,8 +869,8 @@ mod tokio_io { }; break Poll::Ready(Ok(())); } - SliceDecoderState::Done => { - break Poll::Ready(Ok(())); + SliceDecoderState::Taken => { + unreachable!(); } } }