Skip to content

Commit

Permalink
ENG-528: remove credits enforcer (#2195)
Browse files (browse the repository at this point in the history)
* ENG-528: remove credits enforcer
  • Loading branch information
parikls authored Dec 20, 2024
1 parent f869d8d commit aa36a70
Show file tree
Hide file tree
Showing 10 changed files with 137 additions and 237 deletions.
9 changes: 0 additions & 9 deletions platform_api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

from platform_api.orchestrator.job_policy_enforcer import (
CreditsLimitEnforcer,
CreditsNotificationsEnforcer,
JobPolicyEnforcePoller,
RetentionPolicyEnforcer,
RuntimeLimitEnforcer,
Expand Down Expand Up @@ -509,14 +508,6 @@ async def _init_app(app: aiohttp.web.Application) -> AsyncIterator[None]:
enforcers=[
RuntimeLimitEnforcer(jobs_service),
CreditsLimitEnforcer(jobs_service, admin_client),
CreditsNotificationsEnforcer(
jobs_service=jobs_service,
admin_client=admin_client,
notifications_client=notifications_client,
notification_threshold=(
config.job_policy_enforcer.credit_notification_threshold
),
),
StopOnClusterRemoveEnforcer(
jobs_service=jobs_service,
auth_client=auth_client,
Expand Down
77 changes: 0 additions & 77 deletions platform_api/orchestrator/job_policy_enforcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,12 @@
from collections import defaultdict
from collections.abc import Callable, Iterable, Mapping
from datetime import timedelta
from decimal import Decimal
from typing import Any, Optional, TypeVar

from aiohttp import ClientResponseError
from neuro_admin_client import AdminClient, ClusterUser, OrgCluster
from neuro_auth_client import AuthClient
from neuro_logging import new_trace, trace
from neuro_notifications_client import (
Client as NotificationsClient,
CreditsWillRunOutSoon,
)

from platform_api.cluster import ClusterConfigRegistry
from platform_api.config import JobPolicyEnforcerConfig
Expand All @@ -35,78 +30,6 @@ async def enforce(self) -> None:
pass


class CreditsNotificationsEnforcer(JobPolicyEnforcer):
    """Send "credits will run out soon" notifications to owners of active jobs.

    On every ``enforce()`` pass the owners of all active jobs are collected
    together with the ``(cluster_name, org_name)`` pairs they run jobs in.
    For each pair the owner's balance is looked up through the admin client
    and a ``CreditsWillRunOutSoon`` notification is sent when the balance is
    below ``notification_threshold``.
    """

    def __init__(
        self,
        jobs_service: JobsService,
        admin_client: AdminClient,
        notifications_client: NotificationsClient,
        notification_threshold: Decimal,
    ):
        """Store collaborators and the threshold below which users are notified."""
        self._jobs_service = jobs_service
        self._admin_client = admin_client
        self._notifications_client = notifications_client
        self._threshold = notification_threshold
        # Last balance a notification was sent for, keyed by
        # (username, cluster_name, org_name).  org_name is part of the key
        # because the same user may hold different balances in different
        # orgs of one cluster; sharing a single entry between them would make
        # the entries overwrite each other and defeat the de-duplication.
        self._sent: dict[tuple[str, str, Optional[str]], Decimal] = {}

    async def _notify_user_if_needed(
        self,
        username: str,
        cluster_name: str,
        org_name: Optional[str],
        credits: Optional[Decimal],
    ) -> None:
        """Notify ``username`` if ``credits`` is below the threshold.

        Nothing is sent when ``credits`` is ``None``, when it is at or above
        the threshold, or when a notification for exactly this balance was
        already sent for the same (user, cluster, org).
        """
        if credits is None or credits >= self._threshold:
            return
        notification_key = (username, cluster_name, org_name)
        # Note: this check is also performed in notifications service
        # using redis storage, so it's OK to use in memory dict here:
        # this is just an optimization to avoid spamming it
        # with duplicate notifications
        if self._sent.get(notification_key) == credits:
            return
        # TODO patch notifications to support org_name
        await self._notifications_client.notify(
            CreditsWillRunOutSoon(
                user_id=username, cluster_name=cluster_name, credits=credits
            )
        )
        # Recorded only after notify() returned, so a failed send is retried
        # on the next pass.
        self._sent[notification_key] = credits

    @trace
    async def enforce(self) -> None:
        """Check balances of every owner that currently has active jobs."""
        user_to_clusters: dict[str, set[tuple[str, Optional[str]]]] = defaultdict(set)
        job_filter = JobFilter(
            statuses={JobStatus(item) for item in JobStatus.active_values()}
        )
        async with self._jobs_service.iter_all_jobs(job_filter) as running_jobs:
            async for job in running_jobs:
                user_to_clusters[job.owner].add((job.cluster_name, job.org_name))
        await run_and_log_exceptions(
            self._enforce_for_user(username, clusters_with_org)
            for username, clusters_with_org in user_to_clusters.items()
        )

    async def _enforce_for_user(
        self, username: str, clusters_and_orgs: set[tuple[str, Optional[str]]]
    ) -> None:
        """Look up ``username``'s balances and notify for each given pair.

        Pairs for which the admin client returns no matching cluster user
        are skipped.
        """
        base_name = username.split("/", 1)[0]  # SA inherit balance from main user
        _, cluster_users = await self._admin_client.get_user_with_clusters(base_name)
        cluster_to_user = {
            (cluster_user.cluster_name, cluster_user.org_name): cluster_user
            for cluster_user in cluster_users
        }
        for cluster_name, org_name in clusters_and_orgs:
            cluster_user = cluster_to_user.get((cluster_name, org_name))
            if cluster_user:
                await self._notify_user_if_needed(
                    username=username,
                    cluster_name=cluster_name,
                    org_name=org_name,
                    credits=cluster_user.balance.credits,
                )


class RuntimeLimitEnforcer(JobPolicyEnforcer):
def __init__(self, service: JobsService):
self._service = service
Expand Down
57 changes: 32 additions & 25 deletions platform_api/orchestrator/jobs_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
AdminClient,
ClusterUser,
GetUserResponse,
Org,
OrgCluster,
OrgUser,
ProjectUser,
Expand Down Expand Up @@ -166,14 +167,12 @@ async def _raise_for_orgs_running_jobs_quota(self, org_cluster: OrgCluster) -> N
if running_count >= org_cluster.quota.total_running_jobs:
raise RunningJobsQuotaExceededError.create_for_org(org_cluster.org_name)

async def _raise_for_no_credits(
self, cluster_entry: Union[ClusterUser, OrgCluster]
) -> None:
if cluster_entry.balance.is_non_positive:
if isinstance(cluster_entry, ClusterUser):
raise NoCreditsError.create_for_user(cluster_entry.user_name)
async def _raise_for_no_credits(self, org_entry: Union[OrgUser, Org]) -> None:
if org_entry.balance.is_non_positive:
if isinstance(org_entry, OrgUser):
raise NoCreditsError.create_for_user(org_entry.user_name)
else:
raise NoCreditsError.create_for_org(cluster_entry.org_name)
raise NoCreditsError.create_for_org(org_entry.name)

async def _make_pass_config_token(
self, username: str, cluster_name: str, job_id: str
Expand Down Expand Up @@ -243,37 +242,45 @@ async def create_job(
base_name = get_base_owner(
user.name
) # SA has access to same clusters as a user
cluster_user = await self._admin_client.get_cluster_user(
user_name=base_name,
cluster_name=cluster_name,
org_name=org_name,
)

if job_name is not None and maybe_job_id(job_name):
raise JobsServiceException(
"Failed to create job: job name cannot start with 'job-' prefix."
)

# check quotas for both a user and a cluster
cluster_user = await self._admin_client.get_cluster_user(
user_name=base_name,
cluster_name=cluster_name,
org_name=org_name,
)
if not wait_for_jobs_quota:
await self._raise_for_running_jobs_quota(cluster_user)
try:
await self._raise_for_no_credits(cluster_user)
except NoCreditsError:
await self._notifications_client.notify(
JobCannotStartNoCredits(
user_id=user.name,
cluster_name=cluster_name,
)
)
raise

if org_name:
# check that OrgCluster itself has enough credits and quota:
org_cluster = await self._admin_client.get_org_cluster(
cluster_name, org_name
)
if not wait_for_jobs_quota:
await self._raise_for_orgs_running_jobs_quota(org_cluster)
# TODO: add notification about org cluster credits exhausted
await self._raise_for_no_credits(org_cluster)
org_user = await self._admin_client.get_org_user(
org_name=org_name,
user_name=base_name,
)

try:
await self._raise_for_no_credits(org_user)
except NoCreditsError:
await self._notifications_client.notify(
JobCannotStartNoCredits(
user_id=user.name,
cluster_name=cluster_name,
)
)
raise

org = await self._admin_client.get_org(org_name)
await self._raise_for_no_credits(org)

if pass_config:
job_request = await self._setup_pass_config(
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ install_requires =
alembic==1.9.4
psycopg2-binary==2.9.7
typing-extensions==4.9.0
neuro-admin-client==24.11.0
neuro-admin-client==24.12.4
yarl==1.12.1

[options.entry_points]
Expand Down
15 changes: 12 additions & 3 deletions tests/integration/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,11 @@ def test_cluster_name() -> str:
return "test-cluster"


@pytest.fixture
def test_org_name() -> str:
    """Provide the organization name used by integration-test fixtures and tests."""
    org_name = "test-org"
    return org_name


class UserFactory(Protocol):
async def __call__(
self,
Expand All @@ -183,6 +188,7 @@ async def regular_user_factory(
token_factory: Callable[[str], str],
admin_token: str,
test_cluster_name: str,
test_org_name: str,
admin_client_factory: Callable[[str], Awaitable[AdminClient]],
) -> UserFactory:
async def _factory(
Expand All @@ -201,6 +207,7 @@ async def _factory(
await admin_client.create_user(name=name, email=f"{name}@email.com")
user_token = token_factory(name)
user_admin_client = await admin_client_factory(user_token)
admin_admin_client = await admin_client_factory(admin_token)
for entry in clusters:
org_name: str | None = None
if len(entry) == 3:
Expand Down Expand Up @@ -228,6 +235,7 @@ async def _factory(
org_name=org_name,
user_name=name,
role=org_user_role,
balance=balance,
)
except ClientResponseError:
pass
Expand All @@ -236,8 +244,10 @@ async def _factory(
cluster_name=cluster,
org_name=org_name,
)
await admin_client.update_org_cluster_balance(
cluster_name=cluster,
except ClientResponseError:
pass
try:
await admin_admin_client.update_org_balance(
org_name=org_name,
credits=Decimal("100"),
)
Expand All @@ -249,7 +259,6 @@ async def _factory(
org_name=org_name,
role=cluster_user_role,
user_name=name,
balance=balance,
quota=quota,
)
except ClientResponseError:
Expand Down
15 changes: 12 additions & 3 deletions tests/integration/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3609,9 +3609,17 @@ async def test_create_job_has_credits(
regular_user_factory: UserFactory,
jobs_client_factory: Callable[[_User], JobsClient],
test_cluster_name: str,
test_org_name: str,
) -> None:
user = await regular_user_factory(
clusters=[(test_cluster_name, Balance(credits=Decimal("100")), Quota())]
clusters=[
(
test_cluster_name,
test_org_name,
Balance(credits=Decimal("100")),
Quota(),
)
]
)
url = api.jobs_base_url
job_request = job_request_factory()
Expand All @@ -3634,9 +3642,10 @@ async def test_create_job_no_credits(
regular_user_factory: UserFactory,
credits: Decimal,
cluster_name: str,
test_org_name: str,
) -> None:
user = await regular_user_factory(
clusters=[(cluster_name, Balance(credits=credits), Quota())]
clusters=[(cluster_name, test_org_name, Balance(credits=credits), Quota())]
)
url = api.jobs_base_url
job_request = job_request_factory()
Expand Down Expand Up @@ -3973,7 +3982,7 @@ async def test_get_all_jobs_filter_by_org(

org_user = await regular_user_factory(
clusters=[
("test-cluster", Balance(), Quota()),
("test-cluster", "org", Balance(), Quota()),
("test-cluster", "org1", Balance(), Quota()),
("test-cluster", "org2", Balance(), Quota()),
],
Expand Down
Loading

0 comments on commit aa36a70

Please sign in to comment.