From c8e731b8ccbff8cd18d035f58d11a2e8c33bab3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johan=20Bj=C3=A4reholt?= Date: Tue, 7 Nov 2023 10:09:05 +0100 Subject: [PATCH] Add support for "Expect: 100-continue" header Change-Id: I7cc1035e9d36e4b2c94edc5ad571b74c11f74c94 --- src/http_crate.rs | 10 ++- src/http_interop.rs | 11 ++- src/response.rs | 176 +++++++++++++++++++++++++++----------------- src/stream.rs | 5 +- src/unit.rs | 18 ++++- 5 files changed, 142 insertions(+), 78 deletions(-) diff --git a/src/http_crate.rs b/src/http_crate.rs index a7e92c72..0a954a68 100644 --- a/src/http_crate.rs +++ b/src/http_crate.rs @@ -3,7 +3,9 @@ use std::{ net::{IpAddr, Ipv4Addr, SocketAddr}, }; -use crate::{header::HeaderLine, response::ResponseStatusIndex, Request, Response}; +use crate::{ + header::HeaderLine, response::PendingReader, response::ResponseStatusIndex, Request, Response, +}; /// Converts an [`http::Response`] into a [`Response`]. /// @@ -44,7 +46,7 @@ impl + Send + Sync + 'static> From> for Respons HeaderLine::from(raw_header).into_header().unwrap() }) .collect::>(), - reader: Box::new(Cursor::new(value.into_body())), + reader: PendingReader::Reader(Box::new(Cursor::new(value.into_body()))), remote_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 80), local_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0), history: vec![], @@ -328,12 +330,14 @@ mod tests { #[test] fn convert_to_http_response_bytes() { + use crate::response::PendingReader; use http::Response; use std::io::Cursor; let mut response = super::Response::new(200, "OK", "tbr").unwrap(); // b'\xFF' as invalid UTF-8 character - response.reader = Box::new(Cursor::new(vec![b'\xFF', 0xde, 0xad, 0xbe, 0xef])); + response.reader = + PendingReader::Reader(Box::new(Cursor::new(vec![b'\xFF', 0xde, 0xad, 0xbe, 0xef]))); let http_response: Response> = response.into(); assert_eq!( diff --git a/src/http_interop.rs b/src/http_interop.rs index c51041c5..0a34bda6 100644 --- a/src/http_interop.rs 
+++ b/src/http_interop.rs @@ -5,7 +5,9 @@ use std::{ net::{IpAddr, Ipv4Addr, SocketAddr}, }; -use crate::{header::HeaderLine, response::ResponseStatusIndex, Request, Response}; +use crate::{ + header::HeaderLine, response::PendingReader, response::ResponseStatusIndex, Request, Response, +}; /// Converts an [`http::Response`] into a [`Response`]. /// @@ -47,7 +49,7 @@ impl + Send + Sync + 'static> From> for Respons HeaderLine::from(raw_header).into_header().unwrap() }) .collect::>(), - reader: Box::new(Cursor::new(value.into_body())), + reader: PendingReader::Reader(Box::new(Cursor::new(value.into_body()))), remote_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 80), local_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0), history: vec![], @@ -280,6 +282,7 @@ impl From for http::request::Builder { mod tests { use crate::header::{add_header, get_header_raw, HeaderLine}; use http_02 as http; + use std::io::Read; #[test] fn convert_http_response() { @@ -376,7 +379,9 @@ mod tests { let mut response = super::Response::new(200, "OK", "tbr").unwrap(); // b'\xFF' as invalid UTF-8 character - response.reader = Box::new(Cursor::new(vec![b'\xFF', 0xde, 0xad, 0xbe, 0xef])); + response.reader = super::PendingReader::Reader(Box::new(Cursor::new(vec![ + b'\xFF', 0xde, 0xad, 0xbe, 0xef, + ]))); let http_response: Response> = response.into(); assert_eq!( diff --git a/src/response.rs b/src/response.rs index af1ebba7..ba5a9029 100644 --- a/src/response.rs +++ b/src/response.rs @@ -49,6 +49,11 @@ enum BodyType { CloseDelimited, } +pub(crate) enum PendingReader { + BeforeBodyStart, + Reader(Box), +} + /// Response instances are created as results of firing off requests. /// /// The `Response` is used to read response headers and decide what to do with the body. @@ -81,7 +86,7 @@ pub struct Response { pub(crate) index: ResponseStatusIndex, pub(crate) status: u16, pub(crate) headers: Vec
, - pub(crate) reader: Box, + pub(crate) reader: PendingReader, /// The socket address of the server that sent the response. pub(crate) remote_addr: SocketAddr, /// The socket address of the client that sent the request. @@ -282,7 +287,12 @@ impl Response { /// # } /// ``` pub fn into_reader(self) -> Box { - self.reader + match self.reader { + PendingReader::Reader(reader) => reader, + PendingReader::BeforeBodyStart => panic!( + "It is not valid to call into_reader before Request::stream_to_reader is called" + ), + } } // Determine what to do with the connection after we've read the body. @@ -548,39 +558,33 @@ impl Response { }) } - /// Create a response from a Read trait impl. - /// - /// This is hopefully useful for unit tests. - /// - /// Example: + /// Create a response from a DeadlineStream, reading and parsing only the status line, headers + /// and its following CRLF. /// - /// use std::io::Cursor; - /// - /// let text = "HTTP/1.1 401 Authorization Required\r\n\r\nPlease log in\n"; - /// let read = Cursor::new(text.to_string().into_bytes()); - /// let resp = ureq::Response::do_from_read(read); - /// - /// assert_eq!(resp.status(), 401); - pub(crate) fn do_from_stream(stream: Stream, unit: &Unit) -> Result { - let remote_addr = stream.remote_addr; + /// Since this function only reads the status line, header and the following CRLF, the returned + /// Response will have an empty reader and does not take ownership of DeadlineStream. + /// To read the following data, the DeadlineStream can be read again after the call to this + /// function. 
+ pub(crate) fn read_response_head( + stream: &mut DeadlineStream, + unit: &Unit, + ) -> Result { + let mut bytes_read = 0; + let remote_addr = stream.inner_ref().remote_addr; - let local_addr = match stream.socket() { + let local_addr = match stream.inner_ref().socket() { Some(sock) => sock.local_addr().map_err(Error::from)?, None => std::net::SocketAddrV4::new(std::net::Ipv4Addr::new(127, 0, 0, 1), 0).into(), }; - // - // HTTP/1.1 200 OK\r\n - let mut stream = stream::DeadlineStream::new(stream, unit.deadline); - // The status line we can ignore non-utf8 chars and parse as_str_lossy(). - let status_line = read_next_line(&mut stream, "the status line")?.into_string_lossy(); + let status_line = + read_next_line(stream, "the status line", &mut bytes_read)?.into_string_lossy(); let (index, status) = parse_status_line(status_line.as_str())?; - let http_version = &status_line.as_str()[0..index.http_version]; let mut headers: Vec
= Vec::new(); while headers.len() <= MAX_HEADER_COUNT { - let line = read_next_line(&mut stream, "a header")?; + let line = read_next_line(stream, "a header", &mut bytes_read)?; if line.is_empty() { break; } @@ -595,23 +599,8 @@ impl Response { )); } - let compression = - get_header(&headers, "content-encoding").and_then(Compression::from_header_value); - - let connection_option = - Self::connection_option(http_version, get_header(&headers, "connection")); - - let body_type = Self::body_type(&unit.method, status, http_version, &headers); - - // remove Content-Encoding and length due to automatic decompression - if compression.is_some() { - headers.retain(|h| !h.is_name("content-encoding") && !h.is_name("content-length")); - } - - let reader = - Self::stream_to_reader(stream, unit, body_type, compression, connection_option); - let url = unit.url.clone(); + let reader = PendingReader::BeforeBodyStart; let response = Response { url, @@ -627,6 +616,50 @@ impl Response { Ok(response) } + /// Create a Response from a DeadlineStream + /// + /// Parses and consumes the header from the stream and creates a Response with + /// the stream as the reader. The response reader also uncompresses the body + /// if it is compressed. 
+ pub(crate) fn do_from_stream( + mut stream: DeadlineStream, + unit: &Unit, + ) -> Result { + let mut response = Self::read_response_head(&mut stream, unit)?; + + let compression = get_header(&response.headers, "content-encoding") + .and_then(Compression::from_header_value); + + let connection_option = Self::connection_option( + response.http_version(), + get_header(&response.headers, "connection"), + ); + + let body_type = Self::body_type( + &unit.method, + response.status(), + response.http_version(), + &response.headers, + ); + + // remove Content-Encoding and length due to automatic decompression + if compression.is_some() { + response + .headers + .retain(|h| !h.is_name("content-encoding") && !h.is_name("content-length")); + } + + response.reader = PendingReader::Reader(Self::stream_to_reader( + stream, + unit, + body_type, + compression, + connection_option, + )); + + Ok(response) + } + #[cfg(test)] pub fn set_url(&mut self, url: Url) { self.url = url; @@ -766,16 +799,23 @@ impl FromStr for Response { &request_reader, None, ); + let stream = stream::DeadlineStream::new(stream, unit.deadline); Self::do_from_stream(stream, &unit) } } -fn read_next_line(reader: &mut impl BufRead, context: &str) -> io::Result { +fn read_next_line( + reader: &mut impl BufRead, + context: &str, + running_total: &mut usize, +) -> io::Result { let mut buf = Vec::new(); let result = reader .take((MAX_HEADER_SIZE + 1) as u64) .read_until(b'\n', &mut buf); + *running_total += buf.len(); + match result { Ok(0) => Err(io::Error::new( io::ErrorKind::ConnectionAborted, @@ -1078,7 +1118,8 @@ mod tests { const LEN: usize = MAX_HEADER_SIZE + 1; let s = format!("Long-Header: {}\r\n", "A".repeat(LEN),); let mut cursor = Cursor::new(s); - let result = read_next_line(&mut cursor, "some context"); + let mut bytes_read = 0; + let result = read_next_line(&mut cursor, "some context", &mut bytes_read); let err = result.expect_err("did not error on too-large header"); assert_eq!(err.kind(), 
io::ErrorKind::Other); assert_eq!( @@ -1117,9 +1158,9 @@ mod tests { encoding_rs::WINDOWS_1252.encode("HTTP/1.1 302 Déplacé Temporairement\r\n"); let bytes = cow.to_vec(); let mut reader = io::BufReader::new(io::Cursor::new(bytes)); - let r = read_next_line(&mut reader, "test status line"); - let h = r.unwrap(); - assert_eq!(h.to_string(), "HTTP/1.1 302 D�plac� Temporairement"); + let mut bytes_read = 0; + let header = read_next_line(&mut reader, "test status line", &mut bytes_read).unwrap(); + assert_eq!(header.to_string(), "HTTP/1.1 302 D�plac� Temporairement"); } #[test] @@ -1150,6 +1191,7 @@ mod tests { &request_reader, None, ); + let s = stream::DeadlineStream::new(s, unit.deadline); let resp = Response::do_from_stream(s.into(), &unit).unwrap(); assert_eq!(resp.status(), 200); assert_eq!(resp.header("x-geo-header"), None); @@ -1204,18 +1246,16 @@ mod tests { "1.1.1.1:4343".parse().unwrap(), PoolReturner::new(&agent, PoolKey::from_parts("https", "example.com", 443)), ); - Response::do_from_stream( - stream, - &Unit::new( - &agent, - "GET", - &"https://example.com/".parse().unwrap(), - vec![], - &Payload::Empty.into_read(), - None, - ), - ) - .unwrap(); + let unit = &Unit::new( + &agent, + "GET", + &"https://example.com/".parse().unwrap(), + vec![], + &Payload::Empty.into_read(), + None, + ); + let stream = stream::DeadlineStream::new(stream, unit.deadline); + Response::do_from_stream(stream, unit).unwrap(); assert_eq!(agent2.state.pool.len(), 1); } @@ -1236,18 +1276,16 @@ mod tests { "1.1.1.1:4343".parse().unwrap(), PoolReturner::none(), ); - let resp = Response::do_from_stream( - stream, - &Unit::new( - &agent, - "GET", - &"https://example.com/".parse().unwrap(), - vec![], - &Payload::Empty.into_read(), - None, - ), - ) - .unwrap(); + let unit = &Unit::new( + &agent, + "GET", + &"https://example.com/".parse().unwrap(), + vec![], + &Payload::Empty.into_read(), + None, + ); + let stream = stream::DeadlineStream::new(stream, unit.deadline); + let resp = 
Response::do_from_stream(stream, unit).unwrap(); let body = resp.into_string().unwrap(); assert_eq!(body, "hi\n"); } diff --git a/src/stream.rs b/src/stream.rs index 5c28d35c..cbf21519 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -455,8 +455,9 @@ pub(crate) fn connect_host( let s = stream.try_clone()?; let pool_key = PoolKey::from_parts(unit.url.scheme(), hostname, port); let pool_returner = PoolReturner::new(&unit.agent, pool_key); - let s = Stream::new(s, remote_addr, pool_returner); - let response = Response::do_from_stream(s, unit)?; + let stream = Stream::new(s, remote_addr, pool_returner); + let stream = DeadlineStream::new(stream, unit.deadline); + let response = Response::do_from_stream(stream, unit)?; Proxy::verify_response(&response)?; } } diff --git a/src/unit.rs b/src/unit.rs index 9ba03686..bae16a5e 100644 --- a/src/unit.rs +++ b/src/unit.rs @@ -261,6 +261,13 @@ fn connect_inner( debug!("sending request {} {}", method, url); } + let mut expect_100_continue = false; + for header in &unit.headers { + if header.name() == "Expect" && header.value() == Some("100-continue") { + expect_100_continue = true; + } + } + let send_result = send_prelude(unit, &mut stream); if let Err(err) = send_result { @@ -277,8 +284,17 @@ fn connect_inner( } let retryable = unit.is_retryable(&body); + let mut stream = stream::DeadlineStream::new(stream, unit.deadline); + + if expect_100_continue { + let response = Response::read_response_head(&mut stream, unit)?; + if response.status() != 100 { + return Err(Error::Status(response.status(), response)); + } + } + // send the body (which can be empty now depending on redirects) - body::send_body(body, unit.is_chunked, &mut stream)?; + body::send_body(body, unit.is_chunked, stream.inner_mut())?; // start reading the response to process cookies and redirects. let result = Response::do_from_stream(stream, unit);