From 0955dfe335ebbaf6c779e88dd47d27ead9bec14c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johan=20Bj=C3=A4reholt?= Date: Tue, 7 Nov 2023 09:16:28 +0100 Subject: [PATCH 1/4] response: Don't take ownership of unit in do_from_stream There does not seem to be any good reason to take ownership of it. Makes us able to remove a clone and a TODO comment. --- src/response.rs | 12 ++++++------ src/stream.rs | 2 +- src/unit.rs | 5 +---- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/response.rs b/src/response.rs index 1cc48e97..af1ebba7 100644 --- a/src/response.rs +++ b/src/response.rs @@ -561,7 +561,7 @@ impl Response { /// let resp = ureq::Response::do_from_read(read); /// /// assert_eq!(resp.status(), 401); - pub(crate) fn do_from_stream(stream: Stream, unit: Unit) -> Result { + pub(crate) fn do_from_stream(stream: Stream, unit: &Unit) -> Result { let remote_addr = stream.remote_addr; let local_addr = match stream.socket() { @@ -609,7 +609,7 @@ impl Response { } let reader = - Self::stream_to_reader(stream, &unit, body_type, compression, connection_option); + Self::stream_to_reader(stream, unit, body_type, compression, connection_option); let url = unit.url.clone(); @@ -766,7 +766,7 @@ impl FromStr for Response { &request_reader, None, ); - Self::do_from_stream(stream, unit) + Self::do_from_stream(stream, &unit) } } @@ -1150,7 +1150,7 @@ mod tests { &request_reader, None, ); - let resp = Response::do_from_stream(s.into(), unit).unwrap(); + let resp = Response::do_from_stream(s.into(), &unit).unwrap(); assert_eq!(resp.status(), 200); assert_eq!(resp.header("x-geo-header"), None); } @@ -1206,7 +1206,7 @@ mod tests { ); Response::do_from_stream( stream, - Unit::new( + &Unit::new( &agent, "GET", &"https://example.com/".parse().unwrap(), @@ -1238,7 +1238,7 @@ mod tests { ); let resp = Response::do_from_stream( stream, - Unit::new( + &Unit::new( &agent, "GET", &"https://example.com/".parse().unwrap(), diff --git a/src/stream.rs b/src/stream.rs index 
d85467af..5c28d35c 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -456,7 +456,7 @@ pub(crate) fn connect_host( let pool_key = PoolKey::from_parts(unit.url.scheme(), hostname, port); let pool_returner = PoolReturner::new(&unit.agent, pool_key); let s = Stream::new(s, remote_addr, pool_returner); - let response = Response::do_from_stream(s, unit.clone())?; + let response = Response::do_from_stream(s, unit)?; Proxy::verify_response(&response)?; } } diff --git a/src/unit.rs b/src/unit.rs index 13c3c9a6..9ba03686 100644 --- a/src/unit.rs +++ b/src/unit.rs @@ -281,10 +281,7 @@ fn connect_inner( body::send_body(body, unit.is_chunked, &mut stream)?; // start reading the response to process cookies and redirects. - // TODO: this unit.clone() bothers me. At this stage, we're not - // going to use the unit (much) anymore, and it should be possible - // to have ownership of it and pass it into the Response. - let result = Response::do_from_stream(stream, unit.clone()); + let result = Response::do_from_stream(stream, unit); // https://tools.ietf.org/html/rfc7230#section-6.3.1 // When an inbound connection is closed prematurely, a client MAY From 353d727ad7e6fd10e45aa5a155e2324764e62db7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johan=20Bj=C3=A4reholt?= Date: Tue, 7 Nov 2023 10:09:05 +0100 Subject: [PATCH 2/4] Add support for "Expect: 100-continue" header --- src/http_crate.rs | 10 ++- src/http_interop.rs | 11 ++- src/response.rs | 180 +++++++++++++++++++++++++---------------- src/stream.rs | 5 +- src/test/agent_test.rs | 43 +++++++++- src/testserver.rs | 96 +++++++++++++++++----- src/unit.rs | 19 ++++- 7 files changed, 264 insertions(+), 100 deletions(-) diff --git a/src/http_crate.rs b/src/http_crate.rs index a7e92c72..0a954a68 100644 --- a/src/http_crate.rs +++ b/src/http_crate.rs @@ -3,7 +3,9 @@ use std::{ net::{IpAddr, Ipv4Addr, SocketAddr}, }; -use crate::{header::HeaderLine, response::ResponseStatusIndex, Request, Response}; +use crate::{ + header::HeaderLine, 
response::PendingReader, response::ResponseStatusIndex, Request, Response, +}; /// Converts an [`http::Response`] into a [`Response`]. /// @@ -44,7 +46,7 @@ impl + Send + Sync + 'static> From> for Respons HeaderLine::from(raw_header).into_header().unwrap() }) .collect::>(), - reader: Box::new(Cursor::new(value.into_body())), + reader: PendingReader::Reader(Box::new(Cursor::new(value.into_body()))), remote_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 80), local_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0), history: vec![], @@ -328,12 +330,14 @@ mod tests { #[test] fn convert_to_http_response_bytes() { + use crate::response::PendingReader; use http::Response; use std::io::Cursor; let mut response = super::Response::new(200, "OK", "tbr").unwrap(); // b'\xFF' as invalid UTF-8 character - response.reader = Box::new(Cursor::new(vec![b'\xFF', 0xde, 0xad, 0xbe, 0xef])); + response.reader = + PendingReader::Reader(Box::new(Cursor::new(vec![b'\xFF', 0xde, 0xad, 0xbe, 0xef]))); let http_response: Response> = response.into(); assert_eq!( diff --git a/src/http_interop.rs b/src/http_interop.rs index c51041c5..0a34bda6 100644 --- a/src/http_interop.rs +++ b/src/http_interop.rs @@ -5,7 +5,9 @@ use std::{ net::{IpAddr, Ipv4Addr, SocketAddr}, }; -use crate::{header::HeaderLine, response::ResponseStatusIndex, Request, Response}; +use crate::{ + header::HeaderLine, response::PendingReader, response::ResponseStatusIndex, Request, Response, +}; /// Converts an [`http::Response`] into a [`Response`]. 
/// @@ -47,7 +49,7 @@ impl + Send + Sync + 'static> From> for Respons HeaderLine::from(raw_header).into_header().unwrap() }) .collect::>(), - reader: Box::new(Cursor::new(value.into_body())), + reader: PendingReader::Reader(Box::new(Cursor::new(value.into_body()))), remote_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 80), local_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 0), history: vec![], @@ -280,6 +282,7 @@ impl From for http::request::Builder { mod tests { use crate::header::{add_header, get_header_raw, HeaderLine}; use http_02 as http; + use std::io::Read; #[test] fn convert_http_response() { @@ -376,7 +379,9 @@ mod tests { let mut response = super::Response::new(200, "OK", "tbr").unwrap(); // b'\xFF' as invalid UTF-8 character - response.reader = Box::new(Cursor::new(vec![b'\xFF', 0xde, 0xad, 0xbe, 0xef])); + response.reader = super::PendingReader::Reader(Box::new(Cursor::new(vec![ + b'\xFF', 0xde, 0xad, 0xbe, 0xef, + ]))); let http_response: Response> = response.into(); assert_eq!( diff --git a/src/response.rs b/src/response.rs index af1ebba7..9b1add11 100644 --- a/src/response.rs +++ b/src/response.rs @@ -49,6 +49,11 @@ enum BodyType { CloseDelimited, } +pub(crate) enum PendingReader { + BeforeBodyStart, + Reader(Box), +} + /// Response instances are created as results of firing off requests. /// /// The `Response` is used to read response headers and decide what to do with the body. @@ -81,7 +86,7 @@ pub struct Response { pub(crate) index: ResponseStatusIndex, pub(crate) status: u16, pub(crate) headers: Vec
, - pub(crate) reader: Box, + pub(crate) reader: PendingReader, /// The socket address of the server that sent the response. pub(crate) remote_addr: SocketAddr, /// The socket address of the client that sent the request. @@ -282,7 +287,12 @@ impl Response { /// # } /// ``` pub fn into_reader(self) -> Box { - self.reader + match self.reader { + PendingReader::Reader(reader) => reader, + PendingReader::BeforeBodyStart => panic!( + "It is not valid to call into_reader before Request::stream_to_reader is called" + ), + } } // Determine what to do with the connection after we've read the body. @@ -548,39 +558,33 @@ impl Response { }) } - /// Create a response from a Read trait impl. - /// - /// This is hopefully useful for unit tests. - /// - /// Example: - /// - /// use std::io::Cursor; + /// Create a response from a DeadlineStream, reading and parsing only the status line, headers + /// and its following CRLF. /// - /// let text = "HTTP/1.1 401 Authorization Required\r\n\r\nPlease log in\n"; - /// let read = Cursor::new(text.to_string().into_bytes()); - /// let resp = ureq::Response::do_from_read(read); - /// - /// assert_eq!(resp.status(), 401); - pub(crate) fn do_from_stream(stream: Stream, unit: &Unit) -> Result { - let remote_addr = stream.remote_addr; + /// Since this function only reads the status line, header and the following CRLF, the returned + /// Response will have an empty reader and does not take ownership of DeadlineStream. + /// To read the following data, the DeadlineStream can be read again after the call to this + /// function. 
+ pub(crate) fn read_response_head( + stream: &mut DeadlineStream, + unit: &Unit, + ) -> Result { + let mut bytes_read = 0; + let remote_addr = stream.inner_ref().remote_addr; - let local_addr = match stream.socket() { + let local_addr = match stream.inner_ref().socket() { Some(sock) => sock.local_addr().map_err(Error::from)?, None => std::net::SocketAddrV4::new(std::net::Ipv4Addr::new(127, 0, 0, 1), 0).into(), }; - // - // HTTP/1.1 200 OK\r\n - let mut stream = stream::DeadlineStream::new(stream, unit.deadline); - // The status line we can ignore non-utf8 chars and parse as_str_lossy(). - let status_line = read_next_line(&mut stream, "the status line")?.into_string_lossy(); + let status_line = + read_next_line(stream, "the status line", &mut bytes_read)?.into_string_lossy(); let (index, status) = parse_status_line(status_line.as_str())?; - let http_version = &status_line.as_str()[0..index.http_version]; let mut headers: Vec
= Vec::new(); while headers.len() <= MAX_HEADER_COUNT { - let line = read_next_line(&mut stream, "a header")?; + let line = read_next_line(stream, "a header", &mut bytes_read)?; if line.is_empty() { break; } @@ -595,23 +599,8 @@ impl Response { )); } - let compression = - get_header(&headers, "content-encoding").and_then(Compression::from_header_value); - - let connection_option = - Self::connection_option(http_version, get_header(&headers, "connection")); - - let body_type = Self::body_type(&unit.method, status, http_version, &headers); - - // remove Content-Encoding and length due to automatic decompression - if compression.is_some() { - headers.retain(|h| !h.is_name("content-encoding") && !h.is_name("content-length")); - } - - let reader = - Self::stream_to_reader(stream, unit, body_type, compression, connection_option); - let url = unit.url.clone(); + let reader = PendingReader::BeforeBodyStart; let response = Response { url, @@ -627,6 +616,54 @@ impl Response { Ok(response) } + /// Attach a stream to Response, for reading the body. + /// + /// The response reader also uncompresses the body if it is compressed. 
+ pub(crate) fn take_body(&mut self, stream: DeadlineStream, unit: &Unit) -> Result<(), Error> { + let compression = + get_header(&self.headers, "content-encoding").and_then(Compression::from_header_value); + + let connection_option = + Self::connection_option(self.http_version(), get_header(&self.headers, "connection")); + + let body_type = Self::body_type( + &unit.method, + self.status(), + self.http_version(), + &self.headers, + ); + + // remove Content-Encoding and length due to automatic decompression + if compression.is_some() { + self.headers + .retain(|h| !h.is_name("content-encoding") && !h.is_name("content-length")); + } + + self.reader = PendingReader::Reader(Self::stream_to_reader( + stream, + unit, + body_type, + compression, + connection_option, + )); + + Ok(()) + } + + /// Create a Response from a DeadlineStream + /// + /// Parses and consumes the header from the stream and creates a Response with + /// the stream as the reader. The response reader also uncompresses the body + /// if it is compressed. 
+ pub(crate) fn do_from_stream( + mut stream: DeadlineStream, + unit: &Unit, + ) -> Result { + let mut response = Self::read_response_head(&mut stream, unit)?; + response.take_body(stream, unit)?; + Ok(response) + } + #[cfg(test)] pub fn set_url(&mut self, url: Url) { self.url = url; @@ -766,16 +803,23 @@ impl FromStr for Response { &request_reader, None, ); + let stream = stream::DeadlineStream::new(stream, unit.deadline); Self::do_from_stream(stream, &unit) } } -fn read_next_line(reader: &mut impl BufRead, context: &str) -> io::Result { +fn read_next_line( + reader: &mut impl BufRead, + context: &str, + running_total: &mut usize, +) -> io::Result { let mut buf = Vec::new(); let result = reader .take((MAX_HEADER_SIZE + 1) as u64) .read_until(b'\n', &mut buf); + *running_total += buf.len(); + match result { Ok(0) => Err(io::Error::new( io::ErrorKind::ConnectionAborted, @@ -1078,7 +1122,8 @@ mod tests { const LEN: usize = MAX_HEADER_SIZE + 1; let s = format!("Long-Header: {}\r\n", "A".repeat(LEN),); let mut cursor = Cursor::new(s); - let result = read_next_line(&mut cursor, "some context"); + let mut bytes_read = 0; + let result = read_next_line(&mut cursor, "some context", &mut bytes_read); let err = result.expect_err("did not error on too-large header"); assert_eq!(err.kind(), io::ErrorKind::Other); assert_eq!( @@ -1117,9 +1162,9 @@ mod tests { encoding_rs::WINDOWS_1252.encode("HTTP/1.1 302 Déplacé Temporairement\r\n"); let bytes = cow.to_vec(); let mut reader = io::BufReader::new(io::Cursor::new(bytes)); - let r = read_next_line(&mut reader, "test status line"); - let h = r.unwrap(); - assert_eq!(h.to_string(), "HTTP/1.1 302 D�plac� Temporairement"); + let mut bytes_read = 0; + let header = read_next_line(&mut reader, "test status line", &mut bytes_read).unwrap(); + assert_eq!(header.to_string(), "HTTP/1.1 302 D�plac� Temporairement"); } #[test] @@ -1150,6 +1195,7 @@ mod tests { &request_reader, None, ); + let s = stream::DeadlineStream::new(s, unit.deadline); 
let resp = Response::do_from_stream(s.into(), &unit).unwrap(); assert_eq!(resp.status(), 200); assert_eq!(resp.header("x-geo-header"), None); @@ -1204,18 +1250,16 @@ mod tests { "1.1.1.1:4343".parse().unwrap(), PoolReturner::new(&agent, PoolKey::from_parts("https", "example.com", 443)), ); - Response::do_from_stream( - stream, - &Unit::new( - &agent, - "GET", - &"https://example.com/".parse().unwrap(), - vec![], - &Payload::Empty.into_read(), - None, - ), - ) - .unwrap(); + let unit = &Unit::new( + &agent, + "GET", + &"https://example.com/".parse().unwrap(), + vec![], + &Payload::Empty.into_read(), + None, + ); + let stream = stream::DeadlineStream::new(stream, unit.deadline); + Response::do_from_stream(stream, unit).unwrap(); assert_eq!(agent2.state.pool.len(), 1); } @@ -1236,18 +1280,16 @@ mod tests { "1.1.1.1:4343".parse().unwrap(), PoolReturner::none(), ); - let resp = Response::do_from_stream( - stream, - &Unit::new( - &agent, - "GET", - &"https://example.com/".parse().unwrap(), - vec![], - &Payload::Empty.into_read(), - None, - ), - ) - .unwrap(); + let unit = &Unit::new( + &agent, + "GET", + &"https://example.com/".parse().unwrap(), + vec![], + &Payload::Empty.into_read(), + None, + ); + let stream = stream::DeadlineStream::new(stream, unit.deadline); + let resp = Response::do_from_stream(stream, unit).unwrap(); let body = resp.into_string().unwrap(); assert_eq!(body, "hi\n"); } diff --git a/src/stream.rs b/src/stream.rs index 5c28d35c..cbf21519 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -455,8 +455,9 @@ pub(crate) fn connect_host( let s = stream.try_clone()?; let pool_key = PoolKey::from_parts(unit.url.scheme(), hostname, port); let pool_returner = PoolReturner::new(&unit.agent, pool_key); - let s = Stream::new(s, remote_addr, pool_returner); - let response = Response::do_from_stream(s, unit)?; + let stream = Stream::new(s, remote_addr, pool_returner); + let stream = DeadlineStream::new(stream, unit.deadline); + let response = 
Response::do_from_stream(stream, unit)?; Proxy::verify_response(&response)?; } } diff --git a/src/test/agent_test.rs b/src/test/agent_test.rs index 1fdaedfd..4a30a6aa 100644 --- a/src/test/agent_test.rs +++ b/src/test/agent_test.rs @@ -1,7 +1,7 @@ #![allow(dead_code)] use crate::error::Error; -use crate::testserver::{read_request, TestServer}; +use crate::testserver::{read_request, read_request_body, read_request_headers, TestServer}; use std::io::{self, Read, Write}; use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream}; use std::thread; @@ -30,6 +30,30 @@ fn idle_timeout_handler_408(mut stream: TcpStream) -> io::Result<()> { Ok(()) } +// Handler that answers with 100-continue before reading body +fn expect_100_continue_handler(stream: TcpStream) -> io::Result<()> { + use std::io::BufReader; + + let mut bufreader = BufReader::new(&stream); + let request_headers = read_request_headers(&mut bufreader); + { + let stream_write = bufreader.get_mut(); + stream_write.write_all(b"HTTP/1.1 100 Continue\r\n\r\n")?; + stream_write.flush().unwrap(); + } + let request_body = read_request_body(&mut bufreader, &request_headers); + { + let stream_write = bufreader.get_mut(); + stream_write.write_all(b"HTTP/1.1 200 OK\r\n")?; + stream_write.write_all(b"Content-Length: ")?; + stream_write.write_all(request_body.len().to_string().as_bytes())?; + stream_write.write_all(b"\r\n\r\n")?; + stream_write.write_all(&request_body)?; + stream_write.flush().unwrap(); + } + Ok(()) +} + #[test] fn connection_reuse() { let testserver = TestServer::new(idle_timeout_handler); @@ -84,6 +108,23 @@ fn connection_reuse_with_408() { assert_eq!(resp.status(), 200); } +#[test] +fn expect_100_continue() { + let testserver = TestServer::new(expect_100_continue_handler); + let url = format!("http://localhost:{}", testserver.port); + let agent = Agent::new(); + let request_body = "this is a test string for the test expect_100_continue"; + let resp = agent + .post(&url) + .set("Expect", "100-continue") + 
.send_bytes(request_body.as_bytes()) + .unwrap(); + assert_eq!(resp.status(), 200); + let mut response_body = Vec::new(); + resp.into_reader().read_to_end(&mut response_body).unwrap(); + assert_eq!(request_body.as_bytes(), response_body); +} + #[test] fn custom_resolver() { use std::io::Read; diff --git a/src/testserver.rs b/src/testserver.rs index 293caf68..aae7bfd8 100644 --- a/src/testserver.rs +++ b/src/testserver.rs @@ -81,39 +81,93 @@ impl TestHeaders { pub fn headers(&self) -> &[String] { &self.0[1..] } + + pub(crate) fn to_headers(&self) -> Vec { + if self.0.len() <= 1 { + return Vec::new(); + } + + let mut headers = Vec::new(); + for line in &self.0[1..] { + let headerline = crate::header::HeaderLine::from(line.clone()); + let header = headerline.into_header().unwrap(); + headers.push(header); + } + headers + } } // Read a stream until reaching a blank line, in order to consume // request headers. #[cfg(test)] -pub fn read_request(stream: &TcpStream) -> TestHeaders { - use std::io::{BufRead, BufReader}; - +use std::io::{BufRead, BufReader}; +#[cfg(test)] +pub fn read_request_headers(bufreader: &mut BufReader<&TcpStream>) -> TestHeaders { let mut results = vec![]; - for line in BufReader::new(stream).lines() { - match line { - Err(e) => { - eprintln!("testserver: in read_request: {}", e); - break; - } - Ok(line) if line.is_empty() => break, - Ok(line) => results.push(line), - }; - } - // Consume rest of body. TODO maybe capture the body for inspection in the test? - // There's a risk stream is ended here, and fill_buf() would block. 
- stream.set_nonblocking(true).ok(); - let mut reader = BufReader::new(stream); - while let Ok(buf) = reader.fill_buf() { - let amount = buf.len(); - if amount == 0 { + loop { + let mut line = String::new(); + bufreader.read_line(&mut line).unwrap(); + // Remove \r\n + line.truncate(line.len().saturating_sub(2)); + + if line.is_empty() { break; } - reader.consume(amount); + results.push(line); } + let mut body = Vec::new(); + body.append(&mut bufreader.buffer().to_vec()); + TestHeaders(results) } +// Read whole body from a stream after reading the status line and headers +#[cfg(test)] +pub fn read_request_body( + bufreader: &mut BufReader<&TcpStream>, + request_headers: &TestHeaders, +) -> Vec { + let headers = request_headers.to_headers(); + + // NOTE: Currently only requests with "Content-Length" are supported. + let mut content_length = 0; + for header in headers { + if header.name() == "Content-Length" { + content_length = header.value().unwrap().parse().unwrap(); + } + } + + let mut bytes_read = 0; + let mut body = Vec::new(); + + // There's possibly already some data in the BufReader, read and consume + // that first + body.append(&mut bufreader.buffer().to_vec()); + bytes_read += body.len(); + bufreader.consume(body.len()); + + while bytes_read < content_length { + let buf = bufreader.fill_buf().unwrap(); + body.append(&mut buf.to_vec()); + let amount = buf.len(); + bytes_read += amount; + bufreader.consume(amount); + } + + body +} + +// Read a stream as a request and return the headers +#[cfg(test)] +pub fn read_request(stream: &TcpStream) -> TestHeaders { + let mut bufreader = BufReader::new(stream); + + let request_headers = read_request_headers(&mut bufreader); + let _body = read_request_body(&mut bufreader, &request_headers); + + request_headers +} + impl TestServer { pub fn new(handler: fn(TcpStream) -> io::Result<()>) -> Self { let listener = TcpListener::bind("127.0.0.1:0").unwrap(); diff --git a/src/unit.rs b/src/unit.rs index 9ba03686..6c6927c7 
100644 --- a/src/unit.rs +++ b/src/unit.rs @@ -261,6 +261,13 @@ fn connect_inner( debug!("sending request {} {}", method, url); } + let mut expect_100_continue = false; + for header in &unit.headers { + if header.name() == "Expect" && header.value() == Some("100-continue") { + expect_100_continue = true; + } + } + let send_result = send_prelude(unit, &mut stream); if let Err(err) = send_result { @@ -277,8 +284,18 @@ fn connect_inner( } let retryable = unit.is_retryable(&body); + let mut stream = stream::DeadlineStream::new(stream, unit.deadline); + + if expect_100_continue { + let mut response = Response::read_response_head(&mut stream, unit)?; + if response.status() != 100 { + response.take_body(stream, unit)?; + return Err(Error::Status(response.status(), response)); + } + } + // send the body (which can be empty now depending on redirects) - body::send_body(body, unit.is_chunked, &mut stream)?; + body::send_body(body, unit.is_chunked, stream.inner_mut())?; // start reading the response to process cookies and redirects. let result = Response::do_from_stream(stream, unit); From d808fa0f58f444ab1b5e291808cd94e94d2ecac5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johan=20Bj=C3=A4reholt?= Date: Wed, 10 Apr 2024 10:44:27 +0200 Subject: [PATCH 3/4] Retry on 417 for "Expect: 100-continue" --- src/test/agent_test.rs | 53 +++++++++++++++++++++++++++++++++++++++++- src/unit.rs | 34 +++++++++++++++++++++++---- 2 files changed, 82 insertions(+), 5 deletions(-) diff --git a/src/test/agent_test.rs b/src/test/agent_test.rs index 4a30a6aa..c417ff7f 100644 --- a/src/test/agent_test.rs +++ b/src/test/agent_test.rs @@ -30,7 +30,8 @@ fn idle_timeout_handler_408(mut stream: TcpStream) -> io::Result<()> { Ok(()) } -// Handler that answers with 100-continue before reading body +// Handler that answers with 100-continue before reading body, then sends +// request body as response body. 
fn expect_100_continue_handler(stream: TcpStream) -> io::Result<()> { use std::io::BufReader; @@ -54,6 +55,39 @@ fn expect_100_continue_handler(stream: TcpStream) -> io::Result<()> { Ok(()) } +// Handler that answers with 417 if "Expect: 100-continue" is set, otherwise +// returns 200 OK with request body as response body. +fn expect_100_continue_respond_417_handler(stream: TcpStream) -> io::Result<()> { + use std::io::BufReader; + + let mut bufreader = BufReader::new(&stream); + let request_headers = read_request_headers(&mut bufreader); + + let mut has_expect_100_continue = false; + let headers = request_headers.to_headers(); + for header in headers { + if header.name() == "Expect" { + has_expect_100_continue = true; + } + } + + if has_expect_100_continue { + let stream_write = bufreader.get_mut(); + stream_write.write_all(b"HTTP/1.1 417 Expectation Failed\r\n\r\n")?; + stream_write.flush().unwrap(); + } else { + let request_body = read_request_body(&mut bufreader, &request_headers); + let stream_write = bufreader.get_mut(); + stream_write.write_all(b"HTTP/1.1 200 OK\r\n")?; + stream_write.write_all(b"Content-Length: ")?; + stream_write.write_all(request_body.len().to_string().as_bytes())?; + stream_write.write_all(b"\r\n\r\n")?; + stream_write.write_all(&request_body)?; + stream_write.flush().unwrap(); + } + Ok(()) +} + #[test] fn connection_reuse() { let testserver = TestServer::new(idle_timeout_handler); @@ -125,6 +159,23 @@ fn expect_100_continue() { assert_eq!(request_body.as_bytes(), response_body); } +#[test] +fn retry_on_417() { + let testserver = TestServer::new(expect_100_continue_respond_417_handler); + let url = format!("http://localhost:{}", testserver.port); + let agent = Agent::new(); + let request_body = "this is a test string for the test retry_on_417"; + let resp = agent + .post(&url) + .set("Expect", "100-continue") + .send_bytes(request_body.as_bytes()) + .unwrap(); + assert_eq!(resp.status(), 200); + let mut response_body = Vec::new(); + 
resp.into_reader().read_to_end(&mut response_body).unwrap(); + assert_eq!(request_body.as_bytes(), response_body); +} + #[test] fn custom_resolver() { use std::io::Read; diff --git a/src/unit.rs b/src/unit.rs index 6c6927c7..c711d1a1 100644 --- a/src/unit.rs +++ b/src/unit.rs @@ -287,10 +287,36 @@ fn connect_inner( let mut stream = stream::DeadlineStream::new(stream, unit.deadline); if expect_100_continue { - let mut response = Response::read_response_head(&mut stream, unit)?; - if response.status() != 100 { - response.take_body(stream, unit)?; - return Err(Error::Status(response.status(), response)); + match Response::read_response_head(&mut stream, unit) { + Ok(mut response) => { + match response.status() { + 100 => debug!("Got 100-continue, proceeding with body"), + 200 => { + // TODO: How should we handle this case? + debug!("Got 200 OK on an expect 100-continue response, never got the chance to send the request body"); + response.take_body(stream, unit)?; + return Ok(response); + } + 417 => { + debug!("Got 417, trying again but without expect"); + response.take_body(stream, unit)?; + let mut unit_without_expect = unit.clone(); + for (idx, header) in unit.headers.iter().enumerate() { + if header.name() == "Expect" { + unit_without_expect.headers.remove(idx); + break; + } + } + return connect_inner(&unit_without_expect, use_pooled, body, history); + } + _ => { + debug!("Didn't get 100-continue, reading body"); + response.take_body(stream, unit)?; + return Err(Error::Status(response.status(), response)); + } + } + } + Err(err) => return Err(err), } } From d54a26480609648e0b88d82498da2e6d544e28af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johan=20Bj=C3=A4reholt?= Date: Thu, 11 Apr 2024 14:01:58 +0200 Subject: [PATCH 4/4] Fallback sending body when server don't understand "Expect:100-continue" If a server does not understand the "Expect: 100-continue" header, it will wait for the body indefinitely. 
To solve this issue, we add a shorter timeout on reading the response status+headers and if that timeout is hit we send the body anyway. --- src/error.rs | 10 ++++++++++ src/stream.rs | 4 ++++ src/test/agent_test.rs | 38 ++++++++++++++++++++++++++++++++++++++ src/unit.rs | 37 ++++++++++++++++++++++++++++++++++++- 4 files changed, 88 insertions(+), 1 deletion(-) diff --git a/src/error.rs b/src/error.rs index 805395ec..4b913723 100644 --- a/src/error.rs +++ b/src/error.rs @@ -501,3 +501,13 @@ mod tests { assert!(size < 500); // 344 on Macbook M1 } } + +pub(crate) fn error_get_root_source<'a>( + err: &'a (dyn std::error::Error + 'static), +) -> &'a (dyn std::error::Error + 'static) { + if let Some(err) = err.source() { + error_get_root_source(err) as &(dyn std::error::Error + 'static) + } else { + err as &(dyn std::error::Error + 'static) + } +} diff --git a/src/stream.rs b/src/stream.rs index cbf21519..06832814 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -72,6 +72,10 @@ impl DeadlineStream { pub(crate) fn inner_mut(&mut self) -> &mut Stream { &mut self.stream } + + pub(crate) fn into_inner(self) -> Stream { + self.stream + } } impl From for Stream { diff --git a/src/test/agent_test.rs b/src/test/agent_test.rs index c417ff7f..a9f33b99 100644 --- a/src/test/agent_test.rs +++ b/src/test/agent_test.rs @@ -88,6 +88,27 @@ fn expect_100_continue_respond_417_handler(stream: TcpStream) -> io::Result<()> Ok(()) } +// Handler that does not support "Expect: 100-continue", which ureq should +// handle gracefully by timing out on reading the response headers and send +// the request body after a short timeout. 
+fn expect_100_continue_not_supported_handler(stream: TcpStream) -> io::Result<()> { + use std::io::BufReader; + + let mut bufreader = BufReader::new(&stream); + let request_headers = read_request_headers(&mut bufreader); + let request_body = read_request_body(&mut bufreader, &request_headers); + { + let stream_write = bufreader.get_mut(); + stream_write.write_all(b"HTTP/1.1 200 OK\r\n")?; + stream_write.write_all(b"Content-Length: ")?; + stream_write.write_all(request_body.len().to_string().as_bytes())?; + stream_write.write_all(b"\r\n\r\n")?; + stream_write.write_all(&request_body)?; + stream_write.flush().unwrap(); + } + Ok(()) +} + #[test] fn connection_reuse() { let testserver = TestServer::new(idle_timeout_handler); @@ -176,6 +197,23 @@ fn retry_on_417() { assert_eq!(request_body.as_bytes(), response_body); } +#[test] +fn expect_100_continue_not_supported() { + let testserver = TestServer::new(expect_100_continue_not_supported_handler); + let url = format!("http://localhost:{}", testserver.port); + let agent = Agent::new(); + let request_body = "this is a test string for the test retry_on_417"; + let resp = agent + .post(&url) + .set("Expect", "100-continue") + .send_bytes(request_body.as_bytes()) + .unwrap(); + assert_eq!(resp.status(), 200); + let mut response_body = Vec::new(); + resp.into_reader().read_to_end(&mut response_body).unwrap(); + assert_eq!(request_body.as_bytes(), response_body); +} + #[test] fn custom_resolver() { use std::io::Read; diff --git a/src/unit.rs b/src/unit.rs index c711d1a1..beb6eb1c 100644 --- a/src/unit.rs +++ b/src/unit.rs @@ -12,7 +12,7 @@ use cookie::Cookie; use crate::agent::RedirectAuthHeaders; use crate::body::{self, BodySize, Payload, SizedReader}; -use crate::error::{Error, ErrorKind}; +use crate::error::{error_get_root_source, Error, ErrorKind}; use crate::header; use crate::header::{get_header, Header}; use crate::proxy::Proto; @@ -287,6 +287,22 @@ fn connect_inner( let mut stream = stream::DeadlineStream::new(stream, 
unit.deadline); if expect_100_continue { + // Set a lower timeout on reading status+headers. + // We do this in case the server does not support "Expect: 100-continue"; then we + // should continue sending the body instead of waiting for a "100 Continue" response. + let now = time::Instant::now(); + let timeout = std::time::Duration::from_secs(1); + let deadline = match now.checked_add(timeout) { + Some(dl) => Some(dl), + None => { + return Err(Error::new( + ErrorKind::Io, + Some("Request deadline overflowed".to_string()), + )) + } + }; + stream = stream::DeadlineStream::new(stream.into_inner(), deadline); + match Response::read_response_head(&mut stream, unit) { Ok(mut response) => { match response.status() { @@ -316,8 +332,27 @@ fn connect_inner( } } } + Err(crate::Error::Transport(err)) => { + // We need to fetch the root error here, because it is not guaranteed that the + // error itself has kind TimedOut even if the root cause was TimedOut (to preserve + // the stack-trace). + // See rejected GitHub PR #744. + let root_err = error_get_root_source(&err); + if let Some(io_err) = root_err.downcast_ref::() { + if io_err.kind() == std::io::ErrorKind::TimedOut { + debug!("Got timeout on reading response status+header for 'Expect: 100-continue' request, sending body even if we didn't get a '100 Continue' response"); + } else { + return Err(crate::Error::Transport(err)); + } + } else { + return Err(crate::Error::Transport(err)); + } + } Err(err) => return Err(err), } + + // reset DeadlineStream to be for complete request/response + stream = stream::DeadlineStream::new(stream.into_inner(), unit.deadline); } // send the body (which can be empty now depending on redirects)