Skip to content

Commit

Permalink
RequestBuilder::query to add query parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
algesten committed Oct 8, 2024
1 parent 9334df1 commit 908b45c
Show file tree
Hide file tree
Showing 5 changed files with 274 additions and 4 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ http = "1.1.0"
log = "0.4.22"
once_cell = "1.19.0"
utf-8 = "0.7.6"
percent-encoding = "2.3.1"

# These are used regardless of TLS implementation.
rustls-pemfile = { version = "2.1.2", optional = true, default-features = false, features = ["std"] }
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ mod config;
mod error;
mod pool;
mod proxy;
mod query;
mod request;
mod run;
mod send_body;
Expand Down
147 changes: 147 additions & 0 deletions src/query.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
use std::borrow::Cow;
use std::fmt;
use std::iter::Enumerate;
use std::ops::Deref;
use std::str::Chars;

use percent_encoding::utf8_percent_encode;

#[derive(Clone)]
pub(crate) struct QueryParam<'a> {
source: Source<'a>,
}

#[derive(Clone)]
enum Source<'a> {
Borrowed(&'a str),
Owned(String),
}

fn enc(i: &str) -> Cow<str> {
utf8_percent_encode(i, percent_encoding::NON_ALPHANUMERIC).into()
}

impl<'a> QueryParam<'a> {
pub fn new_key_value(param: &str, value: &str) -> QueryParam<'static> {
let s = format!("{}={}", enc(param), enc(value));
QueryParam {
source: Source::Owned(s),
}
}

fn as_str(&self) -> &str {
match &self.source {
Source::Borrowed(v) => v,
Source::Owned(v) => v.as_str(),
}
}
}

pub(crate) fn parse_query_params(query_string: &str) -> impl Iterator<Item = QueryParam<'_>> {
assert!(query_string.is_ascii());
QueryParamIterator(query_string, query_string.chars().enumerate())
}

struct QueryParamIterator<'a>(&'a str, Enumerate<Chars<'a>>);

impl<'a> Iterator for QueryParamIterator<'a> {
type Item = QueryParam<'a>;

fn next(&mut self) -> Option<Self::Item> {
let mut first = None;
let mut value = None;
let mut separator = None;

while let Some((n, c)) = self.1.next() {

Check failure on line 55 in src/query.rs

View workflow job for this annotation

GitHub Actions / Lint

this loop could be written as a `for` loop

Check failure on line 55 in src/query.rs

View workflow job for this annotation

GitHub Actions / Lint

this loop could be written as a `for` loop
if first.is_none() {
first = Some(n);
}
if value.is_none() && c == '=' {
value = Some(n + 1);
}
if c == '&' {
separator = Some(n);
break;
}
}

if let Some(start) = first {
let end = separator.unwrap_or(self.0.len());
let chunk = &self.0[start..end];
return Some(QueryParam {
source: Source::Borrowed(chunk),
});
}

None
}
}

impl<'a> fmt::Debug for QueryParam<'a> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("QueryParam").field(&self.as_str()).finish()
}
}

impl<'a> fmt::Display for QueryParam<'a> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.source {
Source::Borrowed(v) => write!(f, "{}", v),
Source::Owned(v) => write!(f, "{}", v),
}
}
}

impl<'a> Deref for QueryParam<'a> {
type Target = str;

fn deref(&self) -> &Self::Target {
self.as_str()
}
}

impl<'a> PartialEq for QueryParam<'a> {
fn eq(&self, other: &Self) -> bool {
self.as_str() == other.as_str()
}
}

#[cfg(test)]
mod test {
use super::*;

use http::Uri;

#[test]
fn query_string_does_not_start_with_question_mark() {
let u: Uri = "https://foo.com/qwe?abc=qwe".parse().unwrap();
assert_eq!(u.query(), Some("abc=qwe"));
}

#[test]
fn percent_encoding_is_not_decoded() {
let u: Uri = "https://foo.com/qwe?abc=%20123".parse().unwrap();
assert_eq!(u.query(), Some("abc=%20123"));
}

#[test]
fn fragments_are_not_a_thing() {
let u: Uri = "https://foo.com/qwe?abc=qwe#yaz".parse().unwrap();
assert_eq!(u.to_string(), "https://foo.com/qwe?abc=qwe");
}

fn p(s: &str) -> Vec<String> {
parse_query_params(s).map(|q| q.to_string()).collect()
}

#[test]
fn parse_query_string() {
assert_eq!(parse_query_params("").next(), None);
assert_eq!(p("&"), vec![""]);
assert_eq!(p("="), vec!["="]);
assert_eq!(p("&="), vec!["", "="]);
assert_eq!(p("foo=bar"), vec!["foo=bar"]);
assert_eq!(p("foo=bar&"), vec!["foo=bar"]);
assert_eq!(p("foo=bar&foo2=bar2"), vec!["foo=bar", "foo2=bar2"]);
}
}
128 changes: 124 additions & 4 deletions src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ use http::{HeaderName, HeaderValue, Method, Request, Response, Uri, Version};

use crate::body::Body;
use crate::config::RequestLevelConfig;
use crate::query::{parse_query_params, QueryParam};
use crate::send_body::AsSendBody;
use crate::util::private::Private;
use crate::util::UriExt;
use crate::{Agent, Config, Error, SendBody, Timeouts};

/// Transparent wrapper around [`http::request::Builder`].
Expand All @@ -18,6 +20,7 @@ use crate::{Agent, Config, Error, SendBody, Timeouts};
pub struct RequestBuilder<B> {
agent: Agent,
builder: http::request::Builder,
query_extra: Vec<QueryParam<'static>>,

// This is only used in case http::request::Builder contains an error
// (such as URL parsing error), and the user wants a `.config()`.
Expand Down Expand Up @@ -57,6 +60,17 @@ impl<Any> RequestBuilder<Any> {
self
}

/// Add a query paramter to the URL.
///
/// Always appends a new parameter, also when using the name of
/// an already existing one.
///
/// Using this feature causes an allocation (of a `Vec` holding the parameters).
pub fn query(mut self, key: &str, value: &str) -> Self {
self.query_extra.push(QueryParam::new_key_value(key, value));
self
}

/// Overrides the URI for this request.
///
/// Typically this is set via `ureq::get(<uri>)` or `Agent::get(<uri>)`. This
Expand Down Expand Up @@ -194,6 +208,7 @@ impl RequestBuilder<WithoutBody> {
Self {
agent,
builder: Request::builder().method(method).uri(uri),
query_extra: vec![],
dummy_config: None,
_ph: PhantomData,
}
Expand All @@ -210,7 +225,7 @@ impl RequestBuilder<WithoutBody> {
/// ```
pub fn call(self) -> Result<Response<Body>, Error> {
let request = self.builder.body(())?;
do_call(self.agent, request, SendBody::none())
do_call(self.agent, request, self.query_extra, SendBody::none())
}
}

Expand All @@ -223,6 +238,7 @@ impl RequestBuilder<WithBody> {
Self {
agent,
builder: Request::builder().method(method).uri(uri),
query_extra: vec![],
dummy_config: None,
_ph: PhantomData,
}
Expand Down Expand Up @@ -255,7 +271,7 @@ impl RequestBuilder<WithBody> {
pub fn send(self, data: impl AsSendBody) -> Result<Response<Body>, Error> {
let request = self.builder.body(())?;
let mut data_ref = data;
do_call(self.agent, request, data_ref.as_body())
do_call(self.agent, request, self.query_extra, data_ref.as_body())
}

/// Send body data as JSON.
Expand Down Expand Up @@ -285,15 +301,67 @@ impl RequestBuilder<WithBody> {
pub fn send_json(self, data: impl serde::ser::Serialize) -> Result<Response<Body>, Error> {
let request = self.builder.body(())?;
let body = SendBody::from_json(&data)?;
do_call(self.agent, request, body)
do_call(self.agent, request, self.query_extra, body)
}
}

fn do_call(agent: Agent, request: Request<()>, body: SendBody) -> Result<Response<Body>, Error> {
fn do_call(
agent: Agent,
mut request: Request<()>,
query_extra: Vec<QueryParam<'static>>,
body: SendBody,
) -> Result<Response<Body>, Error> {
if !query_extra.is_empty() {
request.uri().ensure_valid_url()?;
request = amend_request_query(request, query_extra.into_iter());
}
let response = agent.run_via_middleware(request, body)?;
Ok(response)
}

fn amend_request_query(
request: Request<()>,
query_extra: impl Iterator<Item = QueryParam<'static>>,
) -> Request<()> {
let (mut parts, body) = request.into_parts();
let uri = parts.uri;
let mut path = uri.path().to_string();
let query_existing = parse_query_params(uri.query().unwrap_or(""));

let mut do_first = true;

fn append<'a>(
path: &mut String,
do_first: &mut bool,
iter: impl Iterator<Item = QueryParam<'a>>,
) {
for q in iter {
if *do_first {
*do_first = false;
path.push('?');
} else {
path.push('&');
}
path.push_str(&q);
}
}

append(&mut path, &mut do_first, query_existing);
append(&mut path, &mut do_first, query_extra);

// Unwraps are OK, because we had a correct URI to begin with
let rebuild = Uri::builder()
.scheme(uri.scheme().unwrap().clone())
.authority(uri.authority().unwrap().clone())
.path_and_query(path)
.build()
.unwrap();

parts.uri = rebuild;

Request::from_parts(parts, body)
}

impl<MethodLimit> Deref for RequestBuilder<MethodLimit> {
type Target = http::request::Builder;

Expand Down Expand Up @@ -368,4 +436,56 @@ mod test {
let mut req = get("http://x.y.z/ borked url");
req.timeouts().global = Some(Duration::from_millis(1));
}

#[test]
fn add_params_to_request_without_query() {
let request = Request::builder()
.uri("https://foo.bar/path")
.body(())
.unwrap();

let amended = amend_request_query(
request,
vec![
QueryParam::new_key_value("x", "z"),
QueryParam::new_key_value("ab", "cde"),
]
.into_iter(),
);

assert_eq!(amended.uri(), "https://foo.bar/path?x=z&ab=cde");
}

#[test]
fn add_params_to_request_with_query() {
let request = Request::builder()
.uri("https://foo.bar/path?x=z")
.body(())
.unwrap();

let amended = amend_request_query(
request,
vec![QueryParam::new_key_value("ab", "cde")].into_iter(),
);

assert_eq!(amended.uri(), "https://foo.bar/path?x=z&ab=cde");
}

#[test]
fn add_params_that_need_percent_encoding() {
let request = Request::builder()
.uri("https://foo.bar/path")
.body(())
.unwrap();

let amended = amend_request_query(
request,
vec![QueryParam::new_key_value("å ", "i åa ä e ö")].into_iter(),
);

assert_eq!(
amended.uri(),
"https://foo.bar/path?%C3%A5%20=i%20%C3%A5a%20%C3%A4%20e%20%C3%B6"
);
}
}

0 comments on commit 908b45c

Please sign in to comment.