Skip to content

Commit

Permalink
models and auth manager
Browse files Browse the repository at this point in the history
  • Loading branch information
mike0sv committed Feb 12, 2024
1 parent b5fc777 commit fade686
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 21 deletions.
53 changes: 36 additions & 17 deletions src/evidently/ui/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,19 +302,31 @@ class Permission(Enum):
PROJECT_SNAPSHOT_DELETE = "project_snapshot_delete"


class EntityType(Enum):
Project = "project"
Team = "team"
Org = "org"


class AuthManager(EvidentlyBaseModel):
allow_default_user: bool = True

@abstractmethod
def get_available_project_ids(self, user_id: UserID) -> Optional[Set[ProjectID]]:
raise NotImplementedError

@abstractmethod
def check_team_permission(self, user_id: UserID, team_id: TeamID, permission: Permission) -> bool:
raise NotImplementedError
# @abstractmethod
# def check_team1_permission(self, user_id: UserID, team_id: TeamID, permission: Permission) -> bool:
# raise NotImplementedError
#
# @abstractmethod
# def check_projec1t_permission(self, user_id: UserID, project_id: ProjectID, permission: Permission) -> bool:
# raise NotImplementedError

@abstractmethod
def check_project_permission(self, user_id: UserID, project_id: ProjectID, permission: Permission) -> bool:
def check_entity_permission(
self, user_id: UserID, entity_id: uuid.UUID, entity_type: EntityType, permission: Permission
) -> bool:
raise NotImplementedError

@abstractmethod
Expand Down Expand Up @@ -370,7 +382,7 @@ def _add_user_to_team(self, team_id: TeamID, user_id: UserID):
raise NotImplementedError

def add_user_to_team(self, manager: UserID, team_id: TeamID, user_id: UserID):
if not self.check_team_permission(manager, team_id, Permission.TEAM_USER_ADD):
if not self.check_entity_permission(manager, team_id, EntityType.Team, Permission.TEAM_USER_ADD):
raise NotEnoughPermissions()
self._add_user_to_team(team_id, user_id)

Expand All @@ -379,8 +391,11 @@ def _remove_user_from_team(self, team_id: TeamID, user_id: UserID):
raise NotImplementedError

def remove_user_from_team(self, manager: UserID, team_id: TeamID, user_id: UserID):
if not self.check_team_permission(
manager, team_id, Permission.TEAM_USER_REMOVE if manager != user_id else Permission.TEAM_USER_REMOVE_SELF
if not self.check_entity_permission(
manager,
team_id,
EntityType.Team,
Permission.TEAM_USER_REMOVE if manager != user_id else Permission.TEAM_USER_REMOVE_SELF,
):
raise NotEnoughPermissions()
self._remove_user_from_team(team_id, user_id)
Expand All @@ -399,7 +414,7 @@ def _delete_team(self, team_id: TeamID):
raise NotImplementedError

def delete_team(self, user_id: UserID, team_id: TeamID):
if not self.check_team_permission(user_id, team_id, Permission.TEAM_DELETE):
if not self.check_entity_permission(user_id, team_id, EntityType.Team, Permission.TEAM_DELETE):
raise NotEnoughPermissions()
self._delete_team(team_id)

Expand All @@ -408,7 +423,7 @@ def _list_team_users(self, team_id: TeamID) -> List[User]:
raise NotImplementedError

def list_team_users(self, user_id: UserID, team_id: TeamID) -> List[User]:
if not self.check_team_permission(user_id, team_id, Permission.TEAM_READ):
if not self.check_entity_permission(user_id, team_id, EntityType.Team, Permission.TEAM_READ):
raise TeamNotFound()
return self._list_team_users(team_id)

Expand Down Expand Up @@ -449,13 +464,13 @@ def add_project(

def update_project(self, user_id: Optional[UserID], project: Project):
user = self.auth.get_or_default_user(user_id)
if not self.auth.check_project_permission(user.id, project.id, Permission.PROJECT_WRITE):
if not self.auth.check_entity_permission(user.id, project.id, EntityType.Project, Permission.PROJECT_WRITE):
raise ProjectNotFound()
return self.metadata.update_project(project)

def get_project(self, user_id: Optional[UserID], project_id: ProjectID) -> Optional[Project]:
user = self.auth.get_or_default_user(user_id)
if not self.auth.check_project_permission(user.id, project_id, Permission.PROJECT_READ):
if not self.auth.check_entity_permission(user.id, project_id, EntityType.Project, Permission.PROJECT_READ):
raise ProjectNotFound()
project = self.metadata.get_project(project_id)
if project is None:
Expand All @@ -464,7 +479,7 @@ def get_project(self, user_id: Optional[UserID], project_id: ProjectID) -> Optio

def delete_project(self, user_id: Optional[UserID], project_id: UUID):
user = self.auth.get_or_default_user(user_id)
if not self.auth.check_project_permission(user.id, project_id, Permission.PROJECT_DELETE):
if not self.auth.check_entity_permission(user.id, project_id, EntityType.Project, Permission.PROJECT_DELETE):
raise ProjectNotFound()
return self.metadata.delete_project(project_id)

Expand All @@ -475,15 +490,19 @@ def list_projects(self, user_id: Optional[UserID]) -> List[Project]:

def add_snapshot(self, user_id: Optional[UserID], project_id: UUID, snapshot: Snapshot):
user = self.auth.get_or_default_user(user_id)
if not self.auth.check_project_permission(user.id, project_id, Permission.PROJECT_SNAPSHOT_ADD):
if not self.auth.check_entity_permission(
user.id, project_id, EntityType.Project, Permission.PROJECT_SNAPSHOT_ADD
):
raise ProjectNotFound() # todo: better exception
blob_id = self.blob.put_snapshot(project_id, snapshot)
self.metadata.add_snapshot(project_id, snapshot, blob_id)
self.data.extract_points(project_id, snapshot)

def delete_snapshot(self, user_id: Optional[UserID], project_id: UUID, snapshot_id: UUID):
user = self.auth.get_or_default_user(user_id)
if not self.auth.check_project_permission(user.id, project_id, Permission.PROJECT_SNAPSHOT_DELETE):
if not self.auth.check_entity_permission(
user.id, project_id, EntityType.Project, Permission.PROJECT_SNAPSHOT_DELETE
):
raise ProjectNotFound() # todo: better exception
# todo
# self.data.remove_points(project_id, snapshot_id)
Expand All @@ -498,7 +517,7 @@ def search_project(self, user_id: Optional[UserID], project_name: str) -> List[P
def list_snapshots(
self, user_id: UserID, project_id: ProjectID, include_reports: bool = True, include_test_suites: bool = True
) -> List[SnapshotMetadata]:
if not self.auth.check_project_permission(user_id, project_id, Permission.PROJECT_READ):
if not self.auth.check_entity_permission(user_id, project_id, EntityType.Project, Permission.PROJECT_READ):
raise NotEnoughPermissions()
snapshots = self.metadata.list_snapshots(project_id, include_reports, include_test_suites)
for s in snapshots:
Expand All @@ -514,13 +533,13 @@ def load_snapshot(self, user_id: UserID, project_id: UUID, snapshot: Union[UUID,
def get_snapshot_metadata(
self, user_id: UserID, project_id: ProjectID, snapshot_id: SnapshotID
) -> SnapshotMetadata:
if not self.auth.check_project_permission(user_id, project_id, Permission.PROJECT_READ):
if not self.auth.check_entity_permission(user_id, project_id, EntityType.Project, Permission.PROJECT_READ):
raise NotEnoughPermissions()
meta = self.metadata.get_snapshot_metadata(project_id, snapshot_id)
meta.project.bind(self, user_id)
return meta

def reload_snapshots(self, user_id: UserID, project_id: UUID):
if not self.auth.check_project_permission(user_id, project_id, Permission.PROJECT_READ):
if not self.auth.check_entity_permission(user_id, project_id, EntityType.Project, Permission.PROJECT_READ):
raise NotEnoughPermissions()
self.metadata.reload_snapshots(project_id)
9 changes: 5 additions & 4 deletions src/evidently/ui/storage/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import uuid
from typing import Callable
from typing import ClassVar
from typing import List
Expand All @@ -9,6 +10,7 @@
from fastapi.security import APIKeyHeader

from evidently.ui.base import AuthManager
from evidently.ui.base import EntityType
from evidently.ui.base import Permission
from evidently.ui.base import Team
from evidently.ui.base import User
Expand Down Expand Up @@ -42,10 +44,9 @@ class NoopAuthManager(AuthManager):
def get_available_project_ids(self, user_id: UserID) -> Optional[Set[ProjectID]]:
return None

def check_team_permission(self, user_id: UserID, team_id: TeamID, permission: Permission) -> bool:
return True

def check_project_permission(self, user_id: UserID, project_id: ProjectID, permission: Permission) -> bool:
def check_entity_permission(
self, user_id: UserID, entity_id: uuid.UUID, entity_type: EntityType, permission: Permission
) -> bool:
return True

def create_user(self, user_id: UserID, name: Optional[str]) -> User:
Expand Down

0 comments on commit fade686

Please sign in to comment.