diff --git a/eta/core/storage.py b/eta/core/storage.py index afc6d860..047cfb69 100644 --- a/eta/core/storage.py +++ b/eta/core/storage.py @@ -32,23 +32,26 @@ import configparser import datetime +import dateutil.parser import io import logging import os import re +import requests try: import urllib.parse as urlparse # Python 3 except ImportError: import urlparse # Python 2 -import dateutil.parser -import requests +import urllib3 try: import boto3 import botocore import botocore.config as bcc + import botocore.credentials as bcr + import botocore.session as bcs import google.api_core.exceptions as gae import google.api_core.retry as gar import google.cloud.storage as gcs @@ -636,8 +639,13 @@ def __init__( """Creates a _BotoStorageClient instance. Args: - credentials: a credentials dictionary to be passed to `boto3` via - `boto3.client("s3", **credentials)` + credentials: a credentials dictionary, which must contain one of + the following: + + - credentials to be passed to `boto3` via + `boto3.Session(**credentials)` + - a dict with `"role_arn"` and `"web_identity_token_file"` + keys to use to generate credentials alias: a prefix for all cloud path strings, e.g., "s3" endpoint_url: the storage endpoint, if different from the default AWS service endpoint @@ -646,9 +654,6 @@ def __init__( **kwargs: optional configuration options for `botocore.config.Config(**kwargs)` """ - self.alias = alias - self.endpoint_url = endpoint_url - prefixes = [] if alias is not None: @@ -662,8 +667,6 @@ def __init__( "At least one of `alias` and `endpoint_url` must be provided" ) - self._prefixes = tuple(prefixes) - if "retries" not in kwargs: kwargs["retries"] = {"max_attempts": 10, "mode": "standard"} @@ -676,26 +679,95 @@ def __init__( ): kwargs["max_pool_connections"] = max_pool_connections + self.alias = alias + self.endpoint_url = endpoint_url + + self._prefixes = tuple(prefixes) + self._role_arn = None + self._web_identity_token_file = None + self._duration_seconds = None + 
self._sts_client = None + + session = self._make_session(credentials) + config = bcc.Config(**kwargs) + client = session.client("s3", endpoint_url=endpoint_url, config=config) + + self._session = session + self._client = client + + def _make_session(self, credentials): + """Creates a `boto3.Session` from the given credentials dictionary.""" # - # We allow users to provide credentials via numerous options, which - # have already been parsed from above. However, the AWS libraries seem - # to complain when a profile is specified whose credentials aren't - # configured via environment variables or `~/.aws/credentials`. + # We allow users to provide credentials via a dictionary. However, the + # AWS libraries seem to complain when a profile is specified whose + # credentials aren't configured via environment variables or + # `~/.aws/credentials`. # - # Temporarily hiding the `AWS_PROFILE` env var when creating the client - # with manually provided credentials as kwargs bypasses the issue + # Temporarily hiding the `AWS_PROFILE` environment variable when + # creating the session below seems to bypass the issue # aws_profile = os.environ.pop("AWS_PROFILE", None) try: - config = bcc.Config(**kwargs) - self._client = boto3.client( - "s3", endpoint_url=endpoint_url, config=config, **credentials, + # Create session from permanent credentials + if "role_arn" not in credentials: + return boto3.Session(**credentials) + + # Create session with autorefreshing temporary credentials + role_arn = credentials["role_arn"] + web_identity_token_file = credentials["web_identity_token_file"] + region_name = credentials.get("region_name", None) + + sts_client = boto3.client("sts", region_name=region_name) + + # NOTE(review): `get_role` is an IAM operation, not an STS one, so + # this presumably always falls back to 3600 seconds; please confirm + try: + response = sts_client.get_role(RoleArn=role_arn) + duration_seconds = response["Role"]["MaxSessionDuration"] + except Exception: + duration_seconds = 3600 + + # Store the token file's path rather than its contents, since the + # token may be rotated and must be re-read on every refresh + self._role_arn = role_arn + self._web_identity_token_file = web_identity_token_file + self._duration_seconds = duration_seconds + self._sts_client = 
sts_client + + _credentials = bcr.RefreshableCredentials.create_from_metadata( + metadata=self._refresh_temporary_credentials(), + refresh_using=self._refresh_temporary_credentials, + method="assume-role-with-web-identity", + ) + + session = bcs.get_session() + session._credentials = _credentials + session.set_config_variable("region", region_name) + + return boto3.Session( + botocore_session=session, region_name=region_name ) finally: if aws_profile is not None: os.environ["AWS_PROFILE"] = aws_profile + def _refresh_temporary_credentials(self): + """Requests new temporary credentials for the configured role.""" + response = self._sts_client.assume_role_with_web_identity( + RoleArn=self._role_arn, + RoleSessionName="voxel51", + WebIdentityToken=etau.read_file(self._web_identity_token_file), + DurationSeconds=self._duration_seconds, + ) + + return { + "access_key": response["Credentials"]["AccessKeyId"], + "secret_key": response["Credentials"]["SecretAccessKey"], + "token": response["Credentials"]["SessionToken"], + "expiry_time": response["Credentials"]["Expiration"].isoformat(), + } + def upload(self, local_path, cloud_path, content_type=None, metadata=None): """Uploads the file to the cloud. 
@@ -1160,6 +1228,12 @@ class NeedsAWSCredentials(object): (4) setting the `AWS_CONFIG_FILE` environment variable to point to a valid credentials `.ini` file + (5) generating auto-refreshing temporary credentials from an IAM role + configured via the following environment variables: + + - `AWS_ROLE_ARN` + - `AWS_WEB_IDENTITY_TOKEN_FILE` + - (5) loading credentials from `~/.eta/aws-credentials.ini` that have been + (6) loading credentials from `~/.eta/aws-credentials.ini` that have been activated via `cls.activate_credentials()` @@ -1294,6 +1368,30 @@ def load_credentials(cls, credentials_path=None, profile=None): "AWS_CONFIG_FILE='%s'", credentials_path, ) + elif ( + "AWS_ROLE_ARN" in os.environ + and "AWS_WEB_IDENTITY_TOKEN_FILE" in os.environ + ): + logger.debug( + "Loading role ARN and web identity token file from " + "'AWS_ROLE_ARN' and 'AWS_WEB_IDENTITY_TOKEN_FILE' environment " + "variables" + ) + credentials = { + "role_arn": os.environ["AWS_ROLE_ARN"], + "web_identity_token_file": os.environ[ + "AWS_WEB_IDENTITY_TOKEN_FILE" + ], + } + + if "AWS_DEFAULT_REGION" in os.environ: + logger.debug( + "Loading region from 'AWS_DEFAULT_REGION' environment " + "variable" + ) + credentials["region"] = os.environ["AWS_DEFAULT_REGION"] + + return credentials, None elif cls.has_active_credentials(): credentials_path = cls.CREDENTIALS_PATH logger.debug( @@ -3125,6 +3223,7 @@ def __init__( set_content_type=False, chunk_size=None, max_pool_connections=None, + retry=None, ): """Creates an HTTPStorageClient instance. @@ -3134,16 +3233,28 @@ def __init__( chunk_size: an optional chunk size (in bytes) to use for downloads. By default, `DEFAULT_CHUNK_SIZE` is used max_pool_connections: an optional maximum number of connections to - keep in the connection pool - """ + keep in the connection pool. The default is 10 + retry: an optional value for the ``max_retries`` parameter of + `requests.adapters.HTTPAdapter`. 
By default, a good general + purpose exponential backoff strategy is used + """ + if max_pool_connections is None: + max_pool_connections = 10 + + if retry is None: + retry = urllib3.util.retry.Retry( + total=10, + status_forcelist=[408, 429, 500, 502, 503, 504, 509], + backoff_factor=0.1, + ) + session = requests.Session() - if max_pool_connections is not None: - adapter = requests.adapters.HTTPAdapter( - pool_maxsize=max_pool_connections - ) - session.mount("http://", adapter) - session.mount("https://", adapter) + adapter = requests.adapters.HTTPAdapter( + pool_maxsize=max_pool_connections, max_retries=retry + ) + session.mount("http://", adapter) + session.mount("https://", adapter) self.set_content_type = set_content_type self.chunk_size = chunk_size or self.DEFAULT_CHUNK_SIZE diff --git a/eta/core/utils.py b/eta/core/utils.py index 708db3fa..ebc61b95 100644 --- a/eta/core/utils.py +++ b/eta/core/utils.py @@ -1466,7 +1466,7 @@ class ProgressBar(object): `start()` to start the task, call `pause()` before any `print` statements, and call `close()` when the task is finalized. 
- `ProgressBar`s can optionally be configured to print any of the following + `ProgressBar` can optionally be configured to print any of the following statistics about the task: - the elapsed time since the task was started @@ -2784,7 +2784,8 @@ def ensure_dir(dirname): Args: dirname: the directory path """ - os.makedirs(dirname, exist_ok=True) + if dirname: + os.makedirs(dirname, exist_ok=True) def has_extension(filename, *args): diff --git a/eta/core/video.py b/eta/core/video.py index 58e0ea8f..22ae0d0b 100644 --- a/eta/core/video.py +++ b/eta/core/video.py @@ -2526,7 +2526,6 @@ def extract_clip( # Slower, more accurate option ffmpeg -ss -i -t - When fast is True, the following two-step ffmpeg process is used:: # Faster, less accurate option diff --git a/requirements/common.txt b/requirements/common.txt index 5b945dee..88b3dd72 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -26,3 +26,4 @@ six==1.11.0 sortedcontainers==2.1.0 tabulate==0.8.5 tzlocal==2.0.0 +urllib3==1.25.11 diff --git a/setup.py b/setup.py index e5456c49..93edbc0c 100644 --- a/setup.py +++ b/setup.py @@ -42,6 +42,7 @@ def finalize_options(self): "sortedcontainers", "tabulate", "tzlocal", + "urllib3", ]