Skip to content

Commit

Permalink
Proposed fixes for token-based auth in azure fileshare service (#820)
Browse files Browse the repository at this point in the history
To use token-based authentication for `ShareClient`, I think we should
be passing in the credential object derived from `TokenCredential` (in
our case, some instance of `DefaultAzureCredential`).

Previously, we were passing specific string tokens to the `credential`
argument, which were being interpreted as SAS tokens.
This leads to errors like:
`azure.core.exceptions.ClientAuthenticationError: Server failed to
authenticate the request. Make sure the value of Authorization header is
formed correctly including the signature.`


See the [ShareClient](https://learn.microsoft.com/en-us/python/api/azure-storage-file-share/azure.storage.fileshare.shareclient?view=azure-python)
documentation on the `credential` argument.

By passing in the whole `TokenCredential` object, I believe
`ShareClient` will manage the token lifecycle and we won't need to do so
as mentioned in #818.

---------

Co-authored-by: Eu Jing Chua <[email protected]>
Co-authored-by: Sergiy Matusevych <[email protected]>
  • Loading branch information
3 people authored Aug 3, 2024
1 parent 7fe167d commit 45528cf
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 24 deletions.
32 changes: 16 additions & 16 deletions mlos_bench/mlos_bench/services/remote/azure/azure_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Union

import azure.identity as azure_id
from azure.core.credentials import TokenCredential
from azure.identity import CertificateCredential, DefaultAzureCredential
from azure.keyvault.secrets import SecretClient
from pytz import UTC

Expand All @@ -20,7 +21,7 @@
_LOG = logging.getLogger(__name__)


class AzureAuthService(Service, SupportsAuth):
class AzureAuthService(Service, SupportsAuth[TokenCredential]):
"""Helper methods to get access to Azure services."""

_REQ_INTERVAL = 300 # = 5 min
Expand Down Expand Up @@ -56,6 +57,7 @@ def __init__(
[
self.get_access_token,
self.get_auth_headers,
self.get_credential,
],
),
)
Expand All @@ -65,10 +67,7 @@ def __init__(

self._access_token = "RENEW *NOW*"
self._token_expiration_ts = datetime.now(UTC) # Typically, some future timestamp.

# Login as the first identity available, usually ourselves or a managed identity
self._cred: Union[azure_id.DefaultAzureCredential, azure_id.CertificateCredential]
self._cred = azure_id.DefaultAzureCredential()
self._cred: Optional[TokenCredential] = None

# Verify info required for SP auth early
if "spClientId" in self.config:
Expand All @@ -82,18 +81,22 @@ def __init__(
},
)

def _init_sp(self) -> None:
def get_credential(self) -> TokenCredential:
"""Return the Azure SDK credential object."""
# Perform this initialization outside of __init__ so that environment loading tests
# don't need to specifically mock keyvault interactions out
if self._cred is not None:
return self._cred

# Already logged in as SP
if isinstance(self._cred, azure_id.CertificateCredential):
return
self._cred = DefaultAzureCredential()
if "spClientId" not in self.config:
return self._cred

sp_client_id = self.config["spClientId"]
keyvault_name = self.config["keyVaultName"]
cert_name = self.config["certName"]
tenant_id = self.config["tenant"]
_LOG.debug("Log in with Azure Service Principal %s", sp_client_id)

# Get a client for fetching cert info
keyvault_secrets_client = SecretClient(
Expand All @@ -108,23 +111,20 @@ def _init_sp(self) -> None:
cert_bytes = b64decode(secret.value)

# Reauthenticate as the service principal.
self._cred = azure_id.CertificateCredential(
self._cred = CertificateCredential(
tenant_id=tenant_id,
client_id=sp_client_id,
certificate_data=cert_bytes,
)
return self._cred

def get_access_token(self) -> str:
"""Get the access token from Azure CLI, if expired."""
# Ensure we are logged as the Service Principal, if provided
if "spClientId" in self.config:
self._init_sp()

ts_diff = (self._token_expiration_ts - datetime.now(UTC)).total_seconds()
_LOG.debug("Time to renew the token: %.2f sec.", ts_diff)
if ts_diff < self._req_interval:
_LOG.debug("Request new accessToken")
res = self._cred.get_token("https://management.azure.com/.default")
res = self.get_credential().get_token("https://management.azure.com/.default")
self._token_expiration_ts = datetime.fromtimestamp(res.expires_on, tz=UTC)
self._access_token = res.token
_LOG.info("Got new accessToken. Expiration time: %s", self._token_expiration_ts)
Expand Down
14 changes: 10 additions & 4 deletions mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
from typing import Any, Callable, Dict, List, Optional, Set, Union

from azure.core.credentials import TokenCredential
from azure.core.exceptions import ResourceNotFoundError
from azure.storage.fileshare import ShareClient

Expand Down Expand Up @@ -60,20 +61,25 @@ def __init__(
"storageFileShareName",
},
)
assert self._parent is not None and isinstance(
self._parent, SupportsAuth
), "Authorization service not provided. Include service-auth.jsonc?"
self._auth_service: SupportsAuth[TokenCredential] = self._parent
self._share_client: Optional[ShareClient] = None

def _get_share_client(self) -> ShareClient:
"""Get the Azure file share client object."""
if self._share_client is None:
assert self._parent is not None and isinstance(
self._parent, SupportsAuth
), "Authorization service not provided. Include service-auth.jsonc?"
credential = self._auth_service.get_credential()
assert isinstance(
credential, TokenCredential
), f"Expected a TokenCredential, but got {type(credential)} instead."
self._share_client = ShareClient.from_share_url(
self._SHARE_URL.format(
account_name=self.config["storageAccountName"],
fs_name=self.config["storageFileShareName"],
),
credential=self._parent.get_access_token(),
credential=credential,
token_intent="backup",
)
return self._share_client
Expand Down
16 changes: 14 additions & 2 deletions mlos_bench/mlos_bench/services/types/authenticator_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
#
"""Protocol interface for authentication for the cloud services."""

from typing import Protocol, runtime_checkable
from typing import Protocol, TypeVar, runtime_checkable

T_co = TypeVar("T_co", covariant=True)


@runtime_checkable
class SupportsAuth(Protocol):
class SupportsAuth(Protocol[T_co]):
"""Protocol interface for authentication for the cloud services."""

def get_access_token(self) -> str:
Expand All @@ -30,3 +32,13 @@ def get_auth_headers(self) -> dict:
access_header : dict
HTTP header containing the access token.
"""

def get_credential(self) -> T_co:
"""
Get the credential object for cloud services.
Returns
-------
credential : T
Cloud-specific credential object.
"""
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,28 @@ def test_load_service_config_examples(
config_path: str,
) -> None:
"""Tests loading a config example."""
parent: Service = config_loader_service
config = config_loader_service.load_config(config_path, ConfigSchema.SERVICE)
# Add other services that require a SupportsAuth parent service as necessary.
requires_auth_service_parent = {
"AzureFileShareService",
}
config_class_name = str(config.get("class", "MISSING CLASS")).rsplit(".", maxsplit=1)[-1]
if config_class_name in requires_auth_service_parent:
# AzureFileShareService requires an auth service to be loaded as well.
auth_service_config = config_loader_service.load_config(
"services/remote/mock/mock_auth_service.jsonc",
ConfigSchema.SERVICE,
)
auth_service = config_loader_service.build_service(
config=auth_service_config,
parent=config_loader_service,
)
parent = auth_service
# Make an instance of the class based on the config.
service_inst = config_loader_service.build_service(
config=config,
parent=config_loader_service,
parent=parent,
)
assert service_inst is not None
assert isinstance(service_inst, Service)
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
_LOG = logging.getLogger(__name__)


class MockAuthService(Service, SupportsAuth):
class MockAuthService(Service, SupportsAuth[str]):
"""A collection Service functions for mocking authentication ops."""

def __init__(
Expand All @@ -32,6 +32,7 @@ def __init__(
[
self.get_access_token,
self.get_auth_headers,
self.get_credential,
],
),
)
Expand All @@ -41,3 +42,6 @@ def get_access_token(self) -> str:

def get_auth_headers(self) -> dict:
return {"Authorization": "Bearer " + self.get_access_token()}

def get_credential(self) -> str:
return "MOCK CREDENTIAL"

0 comments on commit 45528cf

Please sign in to comment.