Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
Signed-off-by: Xuanwo <[email protected]>
  • Loading branch information
Xuanwo committed Jul 23, 2023
1 parent a9607f5 commit fab7195
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 40 deletions.
241 changes: 201 additions & 40 deletions src/aws/credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@ use reqwest::Client;
use serde::Deserialize;

use super::config::Config;
use super::v4::Signer;
use crate::time::now;
use crate::time::parse_rfc3339;
use crate::time::DateTime;

pub const EMPTY_STRING_SHA256: &str =
"e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";

/// Credential that holds the access_key and secret_key.
#[derive(Default, Clone)]
#[cfg_attr(test, derive(Debug))]
Expand Down Expand Up @@ -54,7 +58,7 @@ impl Credential {

/// Loader trait will try to load credential from different sources.
#[async_trait]
pub trait CredentialLoad: 'static + Send + Sync + Debug {
pub trait CredentialLoad: 'static + Send + Sync {
/// Load credential from sources.
///
/// - If succeed, return `Ok(Some(cred))`
Expand All @@ -72,12 +76,6 @@ pub struct DefaultLoader {
credential: Arc<Mutex<Option<Credential>>>,
}

impl Debug for DefaultLoader {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("DefaultLoader").finish()
}
}

impl DefaultLoader {
/// Create a new CredentialLoader
pub fn new(client: Client, config: Config) -> Self {
Expand Down Expand Up @@ -138,14 +136,6 @@ impl DefaultLoader {
return Ok(Some(cred));
}

if let Ok(Some(cred)) = self
.load_via_assume_role()
.await
.map_err(|err| debug!("load credential via assume_role failed: {err:?}"))
{
return Ok(Some(cred));
}

if let Ok(Some(cred)) = self
.load_via_imds_v2()
.await
Expand Down Expand Up @@ -247,20 +237,20 @@ impl DefaultLoader {
Ok(Some(cred))
}

async fn load_via_assume_role(&self) -> Result<Option<Credential>> {
let role_arn = match &self.config.role_arn {
Some(role_arn) => role_arn,
None => return Ok(None),
};
async fn load_via_assume_role_with_web_identity(&self) -> Result<Option<Credential>> {
let (token_file, role_arn) =
match (&self.config.web_identity_token_file, &self.config.role_arn) {
(Some(token_file), Some(role_arn)) => (token_file, role_arn),
_ => return Ok(None),
};

let token = fs::read_to_string(token_file)?;
let role_session_name = &self.config.role_session_name;

let endpoint = self.sts_endpoint()?;

// Construct request to AWS STS Service.
let mut url = format!("https://{endpoint}/?Action=AssumeRole&RoleArn={role_arn}&Version=2011-06-15&RoleSessionName={role_session_name}");
if let Some(external_id) = &self.config.external_id {
write!(url, "&ExternalId={external_id}")?;
}
let url = format!("https://{endpoint}/?Action=AssumeRoleWithWebIdentity&RoleArn={role_arn}&WebIdentityToken={token}&Version=2011-06-15&RoleSessionName={role_session_name}");
let req = self.client.get(&url).header(
http::header::CONTENT_TYPE.as_str(),
"application/x-www-form-urlencoded",
Expand All @@ -272,7 +262,7 @@ impl DefaultLoader {
return Err(anyhow!("request to AWS STS Services failed: {content}"));
}

let resp: AssumeRoleResponse = de::from_str(&resp.text().await?)?;
let resp: AssumeRoleWithWebIdentityResponse = de::from_str(&resp.text().await?)?;
let resp_cred = resp.result.credentials;

let cred = Credential {
Expand All @@ -285,32 +275,118 @@ impl DefaultLoader {
Ok(Some(cred))
}

async fn load_via_assume_role_with_web_identity(&self) -> Result<Option<Credential>> {
let (token_file, role_arn) =
match (&self.config.web_identity_token_file, &self.config.role_arn) {
(Some(token_file), Some(role_arn)) => (token_file, role_arn),
_ => return Ok(None),
};
/// Get the sts endpoint.
///
/// The returning format may look like `sts.{region}.amazonaws.com`
///
/// # Notes
///
/// AWS could have different sts endpoint based on it's region.
/// We can check them by region name.
///
/// ref: https://github.com/awslabs/aws-sdk-rust/blob/31cfae2cf23be0c68a47357070dea1aee9227e3a/sdk/sts/src/aws_endpoint.rs
fn sts_endpoint(&self) -> Result<String> {
// use regional sts if sts_regional_endpoints has been set.
if self.config.sts_regional_endpoints == "regional" {
let region = self.config.region.clone().ok_or_else(|| {
anyhow!("sts_regional_endpoints set to reginal, but region is not set")
})?;
if region.starts_with("cn-") {
Ok(format!("sts.{region}.amazonaws.com.cn"))
} else {
Ok(format!("sts.{region}.amazonaws.com"))
}
} else {
let region = self.config.region.clone().unwrap_or_default();
if region.starts_with("cn") {
// TODO: seems aws china doesn't support global sts?
Ok("sts.amazonaws.com.cn".to_string())
} else {
Ok("sts.amazonaws.com".to_string())
}
}
}
}

let token = fs::read_to_string(token_file)?;
#[async_trait]
impl CredentialLoad for DefaultLoader {
async fn load_credential(&self, _: Client) -> Result<Option<Credential>> {
self.load().await
}
}

/// AssumeRoleLoader will load credential via assume role.
pub struct AssumeRoleLoader {
client: Client,
config: Config,

source_credential: Box<dyn CredentialLoad>,
sts_signer: Signer,
}

impl AssumeRoleLoader {
/// Create a new assume role loader.
pub fn new(
client: Client,
config: Config,
source_credential: Box<dyn CredentialLoad>,
) -> Result<Self> {
let region = config.region.clone().ok_or_else(|| {
anyhow!("assume role loader requires region, but not found, please check your configuration")
})?;

Ok(Self {
client,
config,
source_credential,

sts_signer: Signer::new("sts", &region),
})
}

/// Load credential via assume role.
pub async fn load(&self) -> Result<Option<Credential>> {
let role_arn = match &self.config.role_arn {
Some(role_arn) => role_arn,
None => return Ok(None),
};
let role_session_name = &self.config.role_session_name;

let endpoint = self.sts_endpoint()?;

// Construct request to AWS STS Service.
let url = format!("https://{endpoint}/?Action=AssumeRoleWithWebIdentity&RoleArn={role_arn}&WebIdentityToken={token}&Version=2011-06-15&RoleSessionName={role_session_name}");
let req = self.client.get(&url).header(
http::header::CONTENT_TYPE.as_str(),
"application/x-www-form-urlencoded",
);
let mut url = format!("https://{endpoint}/?Action=AssumeRole&RoleArn={role_arn}&Version=2011-06-15&RoleSessionName={role_session_name}");
if let Some(external_id) = &self.config.external_id {
write!(url, "&ExternalId={external_id}")?;
}
let mut req = self
.client
.get(&url)
.header(
http::header::CONTENT_TYPE.as_str(),
"application/x-www-form-urlencoded",
)
// Set content sha to empty string.
.header("X_AMZ_CONTENT_SHA_256", EMPTY_STRING_SHA256)
.build()?;

let source_cred = self
.source_credential
.load_credential(self.client.clone())
.await?
.ok_or_else(|| {
anyhow!("source credential is required for AssumeRole, but not found, please check your configuration")
})?;

let resp = req.send().await?;
self.sts_signer.sign(&mut req, &source_cred)?;

let resp = self.client.execute(req).await?;
if resp.status() != http::StatusCode::OK {
let content = resp.text().await?;
return Err(anyhow!("request to AWS STS Services failed: {content}"));
}

let resp: AssumeRoleWithWebIdentityResponse = de::from_str(&resp.text().await?)?;
let resp: AssumeRoleResponse = de::from_str(&resp.text().await?)?;
let resp_cred = resp.result.credentials;

let cred = Credential {
Expand Down Expand Up @@ -357,7 +433,7 @@ impl DefaultLoader {
}

#[async_trait]
impl CredentialLoad for DefaultLoader {
impl CredentialLoad for AssumeRoleLoader {
async fn load_credential(&self, _: Client) -> Result<Option<Credential>> {
self.load().await
}
Expand Down Expand Up @@ -676,6 +752,91 @@ mod tests {
Ok(())
}

#[test]
fn test_signer_with_web_loader_assume_role() -> Result<()> {
let _ = env_logger::builder().is_test(true).try_init();

dotenv::from_filename(".env").ok();

if env::var("REQSIGN_AWS_S3_TEST").is_err()
|| env::var("REQSIGN_AWS_S3_TEST").unwrap() != "on"
{
return Ok(());
}

// Ignore test if role_arn not set
let role_arn = if let Ok(v) = env::var("REQSIGN_AWS_ROLE_ARN") {
v
} else {
return Ok(());
};
// Ignore test if assume_role_arn not set
let assume_role_arn = if let Ok(v) = env::var("REQSIGN_AWS_ASSUME_ROLE_ARN") {
v
} else {
return Ok(());
};

// let provider_arn = env::var("REQSIGN_AWS_PROVIDER_ARN").expect("REQSIGN_AWS_PROVIDER_ARN not exist");
let region = env::var("REQSIGN_AWS_S3_REGION").expect("REQSIGN_AWS_S3_REGION not exist");

let github_token = env::var("GITHUB_ID_TOKEN").expect("GITHUB_ID_TOKEN not exist");
let file_path = format!(
"{}/testdata/services/aws/web_identity_token_file",
env::current_dir()
.expect("current_dir must exist")
.to_string_lossy()
);
fs::write(&file_path, github_token)?;

temp_env::with_vars(
vec![
(AWS_REGION, Some(&region)),
(AWS_ROLE_ARN, Some(&role_arn)),
(AWS_WEB_IDENTITY_TOKEN_FILE, Some(&file_path)),
],
|| {
RUNTIME.block_on(async {
let client = reqwest::Client::new();
let default_loader =
DefaultLoader::new(client.clone(), Config::default().from_env())
.with_disable_ec2_metadata();

let cfg = Config {
role_arn: Some(assume_role_arn.clone()),
region: Some(region.clone()),
..Default::default()
};
let loader =
AssumeRoleLoader::new(client.clone(), cfg, Box::new(default_loader))
.expect("AssumeRoleLoader must be valid");

let signer = Signer::new("s3", &region);
let endpoint = format!("https://s3.{}.amazonaws.com/opendal-testing", region);
let mut req = Request::new("");
*req.method_mut() = http::Method::GET;
*req.uri_mut() =
http::Uri::from_str(&format!("{}/{}", endpoint, "not_exist_file")).unwrap();
let cred = loader
.load()
.await
.expect("credential must be valid")
.unwrap();
signer.sign(&mut req, &cred).expect("sign must success");
debug!("signed request url: {:?}", req.uri().to_string());
debug!("signed request: {:?}", req);
let client = Client::new();
let resp = client.execute(req.try_into().unwrap()).await.unwrap();
let status = resp.status();
debug!("got response: {:?}", resp);
debug!("got response content: {:?}", resp.text().await.unwrap());
assert_eq!(status, StatusCode::NOT_FOUND);
})
},
);
Ok(())
}

#[test]
fn test_parse_assume_role_with_web_identity_response() -> Result<()> {
let _ = env_logger::builder().is_test(true).try_init();
Expand Down
1 change: 1 addition & 0 deletions src/aws/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod config;
pub use config::Config as AwsConfig;

mod credential;
pub use credential::AssumeRoleLoader as AwsAssumeRoleLoader;
pub use credential::Credential as AwsCredential;
pub use credential::CredentialLoad as AwsCredentialLoad;
pub use credential::DefaultLoader as AwsDefaultLoader;
Expand Down

0 comments on commit fab7195

Please sign in to comment.