Skip to content

Commit

Permalink
aws - cache clients by region (#9107)
Browse files Browse the repository at this point in the history
  • Loading branch information
kapilt committed May 17, 2024
1 parent 92830bf commit 9706bf6
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 26 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ install:
.PHONY: test

test:
. $(PWD)/test.env && poetry run pytest -n auto tests tools $(ARGS)
. $(PWD)/test.env && poetry run pytest -n auto $(ARGS) tests tools

test-coverage:
. $(PWD)/test.env && poetry run pytest -n auto \
Expand Down
46 changes: 44 additions & 2 deletions c7n/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
Authentication utilities
"""
import threading
import os

from botocore.credentials import RefreshableCredentials
Expand All @@ -19,6 +20,47 @@
'C7N_USE_STS_REGIONAL', '').lower() in ('yes', 'true')


class CustodianSession(Session):

# track clients and return extant ones if present
_clients = {}
lock = threading.Lock()

def client(self, service_name, region_name=None, *args, **kw):
if kw.get('config'):
return super().client(service_name, region_name, *args, **kw)

key = self._cache_key(service_name, region_name)
client = self._clients.get(key)
if client is not None:
return client

with self.lock:
client = self._clients.get(key)
if client is not None:
return client

client = super().client(service_name, region_name, *args, **kw)
self._clients[key] = client
return client

def _cache_key(self, service_name, region_name):
region_name = region_name or self.region_name
return (
# namedtuple so stable comparison
hash(self.get_credentials().get_frozen_credentials()),
service_name,
region_name
)

@classmethod
def close(cls):
with cls.lock:
for c in cls._clients.values():
c.close()
cls._clients = {}


class SessionFactory:

def __init__(self, region, profile=None, assume_role=None, external_id=None):
Expand All @@ -45,7 +87,7 @@ def __call__(self, assume=True, region=None):
self.assume_role, self.session_name, session,
region or self.region, self.external_id)
else:
session = Session(
session = CustodianSession(
region_name=region or self.region, profile_name=self.profile)

return self.update(session)
Expand Down Expand Up @@ -118,7 +160,7 @@ def refresh():
if region is None:
region = s.get_config_variable('region') or 'us-east-1'
s.set_config_variable('region', region)
return Session(botocore_session=s)
return CustodianSession(botocore_session=s)


def get_sts_client(session, region):
Expand Down
3 changes: 3 additions & 0 deletions c7n/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,9 @@ def reset_session_cache():
for k in [k for k in dir(CONN_CACHE) if not k.startswith('_')]:
setattr(CONN_CACHE, k, {})

from .credentials import CustodianSession
CustodianSession.close()


def annotation(i, k):
return i.get(k, ())
Expand Down
51 changes: 28 additions & 23 deletions tests/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@
import placebo

from c7n import credentials
from c7n.credentials import SessionFactory, assumed_session, get_sts_client
from c7n.credentials import (
CustodianSession, SessionFactory, assumed_session, get_sts_client
)
from c7n.version import version
from c7n.utils import local_session

from .common import BaseTest

import freezegun


class Credential(BaseTest):

Expand All @@ -35,30 +39,31 @@ def test_regional_sts(self):
'arn:aws:iam::644160558196:user/kapil')

def test_assumed_session(self):
factory = self.replay_flight_data("test_credential_sts")
session = assumed_session(
role_arn='arn:aws:iam::644160558196:role/CustodianGuardDuty',
session_name="custodian-dev",
session=factory(),
)
with freezegun.freeze_time('2019-09-29T22:07:46+00:00'):
factory = self.replay_flight_data("test_credential_sts")
session = assumed_session(
role_arn='arn:aws:iam::644160558196:role/CustodianGuardDuty',
session_name="custodian-dev",
session=factory(),
)

# attach the placebo flight recorder to the new session.
pill = placebo.attach(
session, os.path.join(self.placebo_dir, 'test_credential_sts'))
if self.recording:
pill.record()
else:
pill.playback()
self.addCleanup(pill.stop)
# attach the placebo flight recorder to the new session.
pill = placebo.attach(
session, os.path.join(self.placebo_dir, 'test_credential_sts'))
if self.recording:
pill.record()
else:
pill.playback()
self.addCleanup(pill.stop)

try:
identity = session.client("sts").get_caller_identity()
except ClientError as e:
self.assertEqual(e.response["Error"]["Code"], "ValidationError")
try:
identity = session.client("sts").get_caller_identity()
except ClientError as e:
self.assertEqual(e.response["Error"]["Code"], "ValidationError")

self.assertEqual(
identity['Arn'],
'arn:aws:sts::644160558196:assumed-role/CustodianGuardDuty/custodian-dev')
self.assertEqual(
identity['Arn'],
'arn:aws:sts::644160558196:assumed-role/CustodianGuardDuty/custodian-dev')

def test_policy_name_user_agent(self):
session = SessionFactory("us-east-1")
Expand All @@ -80,7 +85,7 @@ def test_local_session_agent_update(self):
client = local_session(factory).client('ec2')
self.assertTrue(
'check-ebs' in client._client_config.user_agent)

CustodianSession.close()
factory.policy_name = "check-ec2"
factory.update(local_session(factory))
client = local_session(factory).client('ec2')
Expand Down

0 comments on commit 9706bf6

Please sign in to comment.