Skip to content

Commit

Permalink
update IamAwsProvider as per minio-go implementation (#1437)
Browse files Browse the repository at this point in the history
  • Loading branch information
setu4993 authored Aug 22, 2024
1 parent c9b4c49 commit f673f09
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 61 deletions.
107 changes: 81 additions & 26 deletions minio/credentials/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=too-many-branches

"""Credential providers."""

from __future__ import annotations
Expand All @@ -29,7 +31,7 @@
from datetime import timedelta
from pathlib import Path
from typing import Callable, cast
from urllib.parse import urlencode, urlsplit
from urllib.parse import urlencode, urlsplit, urlunsplit
from xml.etree import ElementTree as ET

import certifi
Expand All @@ -42,7 +44,7 @@

from urllib3.util import Retry, parse_url

from minio.helpers import sha256_hash
from minio.helpers import sha256_hash, url_replace
from minio.signer import sign_v4_sts
from minio.time import from_iso8601utc, to_amz_date, utcnow
from minio.xml import find, findtext
Expand Down Expand Up @@ -381,6 +383,13 @@ def __init__(
self,
custom_endpoint: str | None = None,
http_client: PoolManager | None = None,
auth_token: str | None = None,
relative_uri: str | None = None,
full_uri: str | None = None,
token_file: str | None = None,
role_arn: str | None = None,
role_session_name: str | None = None,
region: str | None = None,
):
self._custom_endpoint = custom_endpoint
self._http_client = http_client or PoolManager(
Expand All @@ -390,22 +399,41 @@ def __init__(
status_forcelist=[500, 502, 503, 504],
),
)
self._token_file = os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE")
self._aws_region = os.environ.get("AWS_REGION")
self._role_arn = os.environ.get("AWS_ROLE_ARN")
self._role_session_name = os.environ.get("AWS_ROLE_SESSION_NAME")
self._relative_uri = os.environ.get(
"AWS_CONTAINER_CREDENTIALS_RELATIVE_URI",
self._token = (
os.environ.get("AWS_CONTAINER_AUTHORIZATION_TOKEN") or
auth_token
)
self._token_file = (
os.environ.get("AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE") or
auth_token
)
self._identity_file = (
os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE") or token_file
)
self._aws_region = os.environ.get("AWS_REGION") or region
self._role_arn = os.environ.get("AWS_ROLE_ARN") or role_arn
self._role_session_name = (
os.environ.get("AWS_ROLE_SESSION_NAME") or role_session_name
)
self._relative_uri = (
os.environ.get("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") or
relative_uri
)
if self._relative_uri and not self._relative_uri.startswith("/"):
self._relative_uri = "/" + self._relative_uri
self._full_uri = os.environ.get("AWS_CONTAINER_CREDENTIALS_FULL_URI")
self._full_uri = (
os.environ.get("AWS_CONTAINER_CREDENTIALS_FULL_URI") or
full_uri
)
self._credentials: Credentials | None = None

def fetch(self, url: str) -> Credentials:
"""Fetch credentials from EC2/ECS. """

res = _urlopen(self._http_client, "GET", url)
def fetch(
self,
url: str,
headers: dict[str, str | list[str] | tuple[str]] | None = None,
) -> Credentials:
"""Fetch credentials from EC2/ECS."""
res = _urlopen(self._http_client, "GET", url, headers=headers)
data = json.loads(res.data)
if data.get("Code", "Success") != "Success":
raise ValueError(
Expand All @@ -428,14 +456,16 @@ def retrieve(self) -> Credentials:
return self._credentials

url = self._custom_endpoint
if self._token_file:
if self._identity_file:
if not url:
url = "https://sts.amazonaws.com"
if self._aws_region:
url = f"https://sts.{self._aws_region}.amazonaws.com"
if self._aws_region.startswith("cn-"):
url += ".cn"

provider = WebIdentityProvider(
lambda: _get_jwt_token(cast(str, self._token_file)),
lambda: _get_jwt_token(cast(str, self._identity_file)),
url,
role_arn=self._role_arn,
role_session_name=self._role_session_name,
Expand All @@ -444,30 +474,55 @@ def retrieve(self) -> Credentials:
self._credentials = provider.retrieve()
return cast(Credentials, self._credentials)

headers: dict[str, str | list[str] | tuple[str]] | None = None
if self._relative_uri:
if not url:
url = "http://169.254.170.2" + self._relative_uri
headers = {"Authorization": self._token} if self._token else None
elif self._full_uri:
if not url:
token = self._token
if self._token_file:
url = self._full_uri
_check_loopback_host(url)
with open(self._token_file, encoding="utf-8") as file:
token = file.read()
else:
if not url:
url = self._full_uri
_check_loopback_host(url)
headers = {"Authorization": token} if token else None
else:
if not url:
url = (
"http://169.254.169.254" +
"/latest/meta-data/iam/security-credentials/"
)

res = _urlopen(self._http_client, "GET", url)
url = "http://169.254.169.254"

# Get IMDS Token
res = _urlopen(
self._http_client,
"PUT",
url+"/latest/api/token",
headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
)
token = res.data.decode("utf-8")
headers = {"X-aws-ec2-metadata-token": token} if token else None

# Get role name
res = _urlopen(
self._http_client,
"GET",
urlunsplit(
url_replace(
urlsplit(url),
path="/latest/meta-data/iam/security-credentials/",
),
),
headers=headers,
)
role_names = res.data.decode("utf-8").split("\n")
if not role_names:
raise ValueError(f"no IAM roles attached to EC2 service {url}")
url += "/" + role_names[0].strip("\r")

if not url:
raise ValueError("url is empty; this should not happen")

self._credentials = self.fetch(url)
self._credentials = self.fetch(url, headers=headers)
return self._credentials


Expand Down
35 changes: 0 additions & 35 deletions tests/unit/credentials_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import unittest.mock as mock
from datetime import datetime, timedelta
from unittest import TestCase

from minio.credentials.credentials import Credentials
from minio.credentials.providers import (AWSConfigProvider, ChainedProvider,
EnvAWSProvider, EnvMinioProvider,
IamAwsProvider,
MinioClientConfigProvider,
StaticProvider)

Expand All @@ -45,36 +40,6 @@ def test_credentials_get(self):
self.assertEqual(creds.session_token, None)


class CredListResponse(object):
status = 200
data = b"test-s3-full-access-for-minio-ec2"


class CredsResponse(object):
status = 200
data = json.dumps({
"Code": "Success",
"Type": "AWS-HMAC",
"AccessKeyId": "accessKey",
"SecretAccessKey": "secret",
"Token": "token",
"Expiration": "2014-12-16T01:51:37Z",
"LastUpdated": "2009-11-23T0:00:00Z"
})


class IamAwsProviderTest(TestCase):
@mock.patch("urllib3.PoolManager.urlopen")
def test_iam(self, mock_connection):
mock_connection.side_effect = [CredListResponse(), CredsResponse()]
provider = IamAwsProvider()
creds = provider.retrieve()
self.assertEqual(creds.access_key, "accessKey")
self.assertEqual(creds.secret_key, "secret")
self.assertEqual(creds.session_token, "token")
self.assertEqual(creds._expiration, datetime(2014, 12, 16, 1, 51, 37))


class ChainedProviderTest(TestCase):
def test_chain_retrieve(self):
# clear environment
Expand Down

0 comments on commit f673f09

Please sign in to comment.