diff --git a/scripts/lighthouse.py b/scripts/lighthouse.py index c47f7a5..d90c31f 100644 --- a/scripts/lighthouse.py +++ b/scripts/lighthouse.py @@ -1,17 +1,18 @@ import os import sys from pathlib import Path -from fastapi import FastAPI, Query -from pydantic import BaseModel -import uvicorn +from fastapi import FastAPI +from fastapi.security import OAuth2PasswordBearer from typing import List, Optional from pyrepositories import JsonTable, DataSource, Entity, IdTypes, FieldBase, FieldKeyTypes, FieldTypes +from pyrepositories import JsonTable, DataSource, Entity, IdTypes +import uvicorn path_root = Path(__file__).parents[1] sys.path.append(os.path.join(path_root, 'src')) -from crud import CRUDApi, Model, EntityFactory +from crud import CRUDApi, Model, EntityFactory, AuthConfig, UserEntity, USER_FIELDS class Organizer(Model): @@ -34,40 +35,60 @@ class Event(Model): class EventEntity(Entity): @property def date(self): - self.get_field("date") + return self.get_field("date") + @date.setter def date(self, value): self.set_field_value("date", value) + @property def organizer(self): return self.get_field("organizer") + @organizer.setter def organizer(self, value): self.set_field_value("organizer", value) + @property def status(self): return self.get_field("status") + @status.setter def status(self, value): self.set_field_value("status", value) + @property def max_attendees(self): return self.get_field("max_attendees") + @max_attendees.setter def max_attendees(self, value): self.set_field_value("max_attendees", value) + @property def joiners(self): return self.get_field("joiners") + @joiners.setter def joiners(self, value): self.set_field_value("joiners", value) - app = FastAPI() + +def get_dummy_users_db(datasource: DataSource): + tbl = JsonTable("users", os.path.join(path_root, "data"), USER_FIELDS) + tbl.clear() + users = [UserEntity("jonhdoe", "john@doe.com", "John Doe"), UserEntity("janedoe", "jane@doe.com", "Jane Doe")] + datasource.add_table(tbl) + + for user in users: + datasource.insert("users", user) + + ds = DataSource(id_type=IdTypes.UUID) +authConfig = AuthConfig(get_dummy_users_db(ds), OAuth2PasswordBearer(tokenUrl="token")) fields = [ FieldBase("date", FieldTypes.STR, FieldKeyTypes.PRIMARY), @@ -88,16 +109,17 @@ def joiners(self, value): # filters = { "date": (str, ""), "organizer": (str, ""), "status": (str, ""), "event_type": (str, ""), } ds.add_table(t) -api = CRUDApi(ds, app) +api = CRUDApi(ds, app, authConfig) + +router = api.register_router("event", Event, filters=filters).get_base() -router = api.register_router("event" , Event, filters=filters).get_base() @router.get("/test", tags=["event"]) def test(): return "test" -api.publish() +api.publish() if __name__ == "__main__": diff --git a/src/crud/__init__.py b/src/crud/__init__.py index 359f9f4..fb432b2 100644 --- a/src/crud/__init__.py +++ b/src/crud/__init__.py @@ -1,5 +1,5 @@ from .app import CRUDApi, CRUDApiRouter -from .lib import Model +from .lib import Model, AuthConfig, UserEntity, USER_FIELDS from .entities import EntityFactory -__all__ = ['CRUDApi', 'CRUDApiRouter', 'Model', 'EntityFactory'] +__all__ = ['CRUDApi', 'CRUDApiRouter', 'Model', 'EntityFactory', 'AuthConfig', 'UserEntity', 'USER_FIELDS'] diff --git a/src/crud/app.py b/src/crud/app.py index 39eec33..35ef8d3 100644 --- a/src/crud/app.py +++ b/src/crud/app.py @@ -1,11 +1,15 @@ -from typing import List +from typing import Annotated, List from enum import Enum from fastapi import Depends, FastAPI from pyrepositories import DataSource, Entity, FieldBase, FieldTypes +from fastapi import Depends, FastAPI, HTTPException, status +from pyrepositories import DataSource, Entity, EntityField from pydantic import create_model from fastapi.routing import APIRouter +from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from .entities import EntityFactory from .lib import convert_field_to_filter, convert_dict_to_filter +from .lib import AuthConfig, decode_token, User id_path = '/single/{id}' @@ -16,7 +20,7 @@ def construct_path(base_path: str, path: str, is_plural: bool, use_prefix: bool) return f'{base_path}{plural}{path}' else: return f'{path}' - + def get_tags(name: str, use_name_as_tag: bool) -> List[str | Enum] | None: return [name] if use_name_as_tag else [] @@ -38,17 +42,78 @@ def convert2int(value: str) -> bool: return False +def setup_routes_with_auth(router, base_path: str, datatype: str, datasource: DataSource, model_type: type, + factory: EntityFactory, use_prefix: bool, auth_scheme: OAuth2PasswordBearer, filters: list[FieldBase] | None = None, + tags: List[str | Enum] | None = None): + table = datasource.get_table(datatype) + if not table: + raise ValueError(f'Table {datatype} not found in datasource') + + @router.get(construct_path(f'{base_path}', '', True, use_prefix), tags=tags) + async def read_items(token: Annotated[str, Depends(auth_scheme)]): + return format_entities(datasource.get_all(datatype) or []) + + if len(filters) > 0: + filter_dict = convert_field_to_filter(filters) + + @router.get(construct_path(f'{base_path}', '/filter', True, use_prefix), tags=tags) + async def filter_items(token: Annotated[str, Depends(auth_scheme)], + params: create_model("Query", **filter_dict) = Depends()): + fields = params.dict() + processed_filters = convert_dict_to_filter(fields) + result = format_entities(datasource.get_by_filter(datatype, processed_filters) or []) + return result + + @router.get(construct_path(base_path, id_path, False, use_prefix), tags=tags) + async def read_item(token: Annotated[str, Depends(auth_scheme)], id: int | str): + return datasource.get_by_id(datatype, id) + + @router.post(construct_path(base_path, '', False, use_prefix), tags=tags) + async def create_item(token: Annotated[str, Depends(auth_scheme)], item: model_type): + return datasource.insert(datatype, factory.create_entity(table.field_structure, item.model_dump())) + + @router.put(construct_path(base_path, id_path, False, use_prefix), tags=tags) + async def update_item(token: Annotated[str, Depends(auth_scheme)], id: int | str, + item: model_type): + entity = factory.create_entity(table.field_structure, item.model_dump()) + if isinstance(id, str) and convert2int(id): + id = int(id) + return datasource.update(datatype, id, entity) + + @router.delete(construct_path(base_path, id_path, False, use_prefix), tags=tags) + async def delete_item(token: Annotated[str, Depends(auth_scheme)], id: int | str): + if isinstance(id, str) and convert2int(id): + id = int(id) + return datasource.delete(datatype, id) + + @router.delete(construct_path(base_path, '', True, use_prefix), tags=tags) + async def delete_all_items(token: Annotated[str, Depends(auth_scheme)]): + return datasource.clear(datatype) + + class CRUDApiRouter: - def __init__(self, datasource: DataSource, name: str, model_type: type, factory: EntityFactory, use_prefix: bool = True, use_name_as_tag: bool = True, filters: list[FieldBase] = []): + def __init__( + self, + datasource: DataSource, + name: str, + model_type: type, + factory: EntityFactory, + use_prefix: bool = True, + use_name_as_tag: bool = True, + auth: AuthConfig | None = None, + filters: list[FieldBase] | None = None + ): self.__datasource = datasource self.__is_included = False + self.__auth = auth self.name = name self.use_prefix = use_prefix self.use_name_as_tag = use_name_as_tag + self.__filters = filters or [] datatype = name.lower() tags = get_tags(name, use_name_as_tag) - table = self.__datasource.get_table(datatype) - if not table: + self.__table = self.__datasource.get_table(datatype) + if not self.__table: raise ValueError(f'Table {datatype} not found in datasource') base_path = f'/{datatype}' @@ -57,12 +122,46 @@ def __init__(self, datasource: DataSource, name: str, model_type: type, factory: prefix=get_prefix(datatype, use_prefix) ) + if auth: + self.__setup_routes_with_auth(base_path, tags, datatype, model_type, factory, use_prefix) + else: + self.__setup_routes(base_path, tags, datatype, model_type, factory, use_prefix) + + def get_base(self): + return self.__router + + def get_datasource(self): + return self.__datasource + + @property + def is_included(self): + return self.__is_included + + def include(self): + self.__is_included = True + + def __setup_routes_with_auth(self, base_path: str, tags: List[str | Enum] | None, datatype: str, model_type: type, + factory: EntityFactory, use_prefix: bool): + if not self.__table: + raise ValueError(f'Table {datatype} not found in datasource') + + if not self.__auth: + raise ValueError('Auth config is required for this route') + + global_auth = self.__auth + setup_routes_with_auth(self.__router, base_path, datatype, self.__datasource, model_type, factory, use_prefix, global_auth.oauth2_scheme, self.__filters, tags) + + def __setup_routes(self, base_path: str, tags: List[str | Enum] | None, datatype: str, model_type: type, + factory: EntityFactory, use_prefix: bool): + if not self.__table: + raise ValueError(f'Table {datatype} not found in datasource') + @self.__router.get(construct_path(f'{base_path}', '', True, use_prefix), tags=tags) async def read_items(): return format_entities(self.__datasource.get_all(datatype) or []) - - if len(filters) > 0: - filter_dict = convert_field_to_filter(filters) + + if len(self.__filters) > 0: + filter_dict = convert_field_to_filter(self.__filters) @self.__router.get(construct_path(f'{base_path}', '/filter', True, use_prefix), tags=tags) async def filter_items(params: create_model("Query", **filter_dict) = Depends()): @@ -78,7 +177,7 @@ async def read_item(id: int | str): @self.__router.post(construct_path(base_path, '', False, use_prefix), tags=tags) async def create_item(item: model_type): try: - entity = factory.create_entity(table.field_structure, item.model_dump()) + entity = factory.create_entity(self.__table.field_structure, item.model_dump()) result = self.__datasource.insert(datatype, entity) if result: return {'success': True, 'created_entity': result.serialize()} @@ -90,7 +189,7 @@ async def create_item(item: model_type): @self.__router.put(construct_path(base_path, id_path, False, use_prefix), tags=tags) async def update_item(item_id: int | str, item: model_type): try: - entity = factory.create_entity(table.field_structure, item.model_dump()) + entity = factory.create_entity(self.__table.field_structure, item.model_dump()) if isinstance(item_id, str) and convert2int(item_id): item_id = int(item_id) result = self.__datasource.update(datatype, item_id, entity) @@ -116,25 +215,65 @@ async def delete_item(item_id: int | str): async def delete_all_items(): return self.__datasource.clear(datatype) - def get_base(self): - return self.__router - - def get_datasource(self): - return self.__datasource - - @property - def is_included(self): - return self.__is_included - - def include(self): - self.__is_included = True - class CRUDApi: - def __init__(self, datasource: DataSource, app: FastAPI): + def __init__(self, datasource: DataSource, app: FastAPI, auth: AuthConfig | None = None): self.__datasource = datasource - self.__app = app # type: FastAPI - self.__routers = {} # type: dict[str, CRUDApiRouter] + self.__app = app # type: FastAPI + self.__routers = {} # type: dict[str, CRUDApiRouter] + self.__auth = auth + + if self.__auth: + self.__setup_auth() + + def __setup_auth(self): + if not self.__auth: + return None + + global_auth = self.__auth + + def get_user(username: str): + return global_auth.users_db.get_unique('username', username) + + def fake_decode_token(token: str): + # This doesn't provide any security at all + # Check the next version + user = get_user(token) + return user + + async def get_current_user(token: Annotated[str, Depends(global_auth.oauth2_scheme)]): + user = fake_decode_token(token) + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + return user + + async def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]): + if current_user.disabled: + raise HTTPException(status_code=400, detail="Inactive user") + return current_user + + def fake_hash_password(password: str): + return "fakehashed" + password + + @self.__app.post("/token", tags=["auth"]) + async def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()]): + results = global_auth.users_db.get_unique('username', form_data.username) + if len(results) != 1: + raise HTTPException(status_code=400, detail="Incorrect username or password") + user = results[0] + hashed_password = fake_hash_password(form_data.password) + if not hashed_password == user.hashed_password: + raise HTTPException(status_code=400, detail="Incorrect username or password") + + return {"access_token": user.username, "token_type": "bearer"} + + @self.__app.get("/users/me", tags=["auth"]) + async def read_users_me(current_user: Annotated[User, Depends(get_current_active_user)]): + return current_user def get_app(self) -> FastAPI: return self.__app @@ -148,13 +287,20 @@ def get_router(self, datatype: str) -> CRUDApiRouter | None: def get_datasource(self) -> DataSource: return self.__datasource - def register_router(self, datatype: str, model_type: type, factory: EntityFactory = EntityFactory(), use_prefix: bool = True, filters: list[FieldBase] = []) -> CRUDApiRouter: - router = CRUDApiRouter(self.__datasource, datatype, model_type, factory, use_prefix, filters=filters) + def register_router(self, datatype: str, model_type: type, factory: EntityFactory = EntityFactory(), + use_prefix: bool = True, filters: list[FieldBase] = None) -> CRUDApiRouter: + if filters is None: + filters = [] + + router = CRUDApiRouter(self.__datasource, datatype, model_type, factory, use_prefix, auth=self.__auth, filters=filters) self.__routers[datatype] = router return router - def include_router(self, datatype: str, model_type: type, factory: EntityFactory = EntityFactory(), use_prefix: bool = True, filters: list[FieldBase] = []) -> CRUDApiRouter: - router = self.register_router(datatype, model_type, factory, use_prefix, filters) + def include_router(self, datatype: str, model_type: type, factory: EntityFactory = EntityFactory(), + use_prefix: bool = True, filters: list[FieldBase] = None) -> CRUDApiRouter: + if filters is None: + filters = [] + router = self.register_router(datatype, model_type, factory, use_prefix, filters=filters) self.__app.include_router(router.get_base()) router.include() return router diff --git a/src/crud/lib.py b/src/crud/lib.py index 328bedf..774720e 100644 --- a/src/crud/lib.py +++ b/src/crud/lib.py @@ -1,5 +1,6 @@ from pydantic import BaseModel -from pyrepositories import IdTypes, FieldBase, FieldTypes, Filter, FilterCondition, FilterCombination, FilterTypes +from pyrepositories import DataTable, Entity, FieldBase, FieldTypes, FieldKeyTypes, Filter, FilterCombination, FilterCondition, FilterTypes, EntityField +from fastapi.security import OAuth2PasswordBearer class Model(BaseModel): @@ -24,3 +25,90 @@ def convert_dict_to_filter(data: dict) -> Filter: conditions.append(FilterCondition(key, value, FilterTypes.CONTAINS)) return Filter(conditions, FilterCombination.AND) + + +class User(BaseModel): + username: str + email: str | None = None + full_name: str | None = None + disabled: bool | None = None + + +def decode_token(token: str): + return User( + username=token + "fakedecoded", email="john@doe.com", full_name="John Doe" + ) + + +class AuthConfig: + def __init__(self, users_db: DataTable, oauth2_scheme: OAuth2PasswordBearer): + self.users_db = users_db + self.oauth2_scheme = oauth2_scheme + + +class UserEntity(Entity): + def __init__(self, username: str, email: str, full_name: str): + base = [ + FieldBase("username", FieldTypes.STR, FieldKeyTypes.UNIQUE, username), + FieldBase("email", FieldTypes.STR, FieldKeyTypes.UNIQUE, email), + FieldBase("full_name", FieldTypes.STR, FieldKeyTypes.STANDARD, full_name), + FieldBase("disabled", FieldTypes.BOOL, FieldKeyTypes.STANDARD, False), + FieldBase("hashed_password", FieldTypes.STR, FieldKeyTypes.STANDARD, ""), + ] + entity_fields = [] + for field in base: + entity_fields.append(EntityField(field)) + super().__init__(entity_fields) + self.username = username + self.email = email + self.full_name = full_name + self.disabled = False + + @property + def username(self): + return self.get_field("username") + + @username.setter + def username(self, value): + self.set_field_value("username", value) + + @property + def email(self): + return self.get_field("email") + + @email.setter + def email(self, value): + self.set_field_value("email", value) + + @property + def full_name(self): + return self.get_field("full_name") + + @full_name.setter + def full_name(self, value): + self.set_field_value("full_name", value) + + @property + def disabled(self): + return self.get_field("disabled") + + @disabled.setter + def disabled(self, value): + self.set_field_value("disabled", value) + + @property + def hashed_password(self): + return self.get_field("hashed_password") + + @hashed_password.setter + def hashed_password(self, value): + self.set_field_value("hashed_password", value) + + +USER_FIELDS = [ + FieldBase("username", FieldTypes.STR, FieldKeyTypes.UNIQUE, ""), + FieldBase("email", FieldTypes.STR, FieldKeyTypes.UNIQUE, ""), + FieldBase("full_name", FieldTypes.STR, FieldKeyTypes.STANDARD, ""), + FieldBase("disabled", FieldTypes.BOOL, FieldKeyTypes.STANDARD, False), + FieldBase("hashed_password", FieldTypes.STR, FieldKeyTypes.STANDARD, ""), +] \ No newline at end of file