Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/14 create dynamic authentcation #15

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
40 changes: 31 additions & 9 deletions scripts/lighthouse.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import os
import sys
from pathlib import Path
from fastapi import FastAPI, Query
from pydantic import BaseModel
import uvicorn
from fastapi import FastAPI
from fastapi.security import OAuth2PasswordBearer
from typing import List, Optional
from pyrepositories import JsonTable, DataSource, Entity, IdTypes, FieldBase, FieldKeyTypes, FieldTypes

from pyrepositories import JsonTable, DataSource, Entity, IdTypes
import uvicorn

path_root = Path(__file__).parents[1]
sys.path.append(os.path.join(path_root, 'src'))

from crud import CRUDApi, Model, EntityFactory
from crud import CRUDApi, Model, EntityFactory, AuthConfig, UserEntity, USER_FIELDS


class Organizer(Model):
Expand All @@ -34,40 +35,60 @@ class Event(Model):
class EventEntity(Entity):
@property
def date(self):
self.get_field("date")
return self.get_field("date")

@date.setter
def date(self, value):
self.set_field_value("date", value)

@property
def organizer(self):
return self.get_field("organizer")

@organizer.setter
def organizer(self, value):
self.set_field_value("organizer", value)

@property
def status(self):
return self.get_field("status")

@status.setter
def status(self, value):
self.set_field_value("status", value)

@property
def max_attendees(self):
return self.get_field("max_attendees")

@max_attendees.setter
def max_attendees(self, value):
self.set_field_value("max_attendees", value)

@property
def joiners(self):
return self.get_field("joiners")

@joiners.setter
def joiners(self, value):
self.set_field_value("joiners", value)



app = FastAPI()


def get_dummy_users_db(datasource: DataSource):
tbl = JsonTable("users", os.path.join(path_root, "data"), USER_FIELDS)
tbl.clear()
users = [UserEntity("jonhdoe", "[email protected]", "John Doe"), UserEntity("janedoe", "[email protected]", "Jane Doe")]
datasource.add_table(tbl)

for user in users:
datasource.insert("users", user)


ds = DataSource(id_type=IdTypes.UUID)
authConfig = AuthConfig(get_dummy_users_db(ds), OAuth2PasswordBearer(tokenUrl="token"))

fields = [
FieldBase("date", FieldTypes.STR, FieldKeyTypes.PRIMARY),
Expand All @@ -88,16 +109,17 @@ def joiners(self, value):
# filters = { "date": (str, ""), "organizer": (str, ""), "status": (str, ""), "event_type": (str, ""), }

ds.add_table(t)
api = CRUDApi(ds, app)
api = CRUDApi(ds, app, authConfig)

router = api.register_router("event", Event, filters=filters).get_base()

router = api.register_router("event" , Event, filters=filters).get_base()

@router.get("/test", tags=["event"])
def test():
return "test"

api.publish()

api.publish()


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions src/crud/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .app import CRUDApi, CRUDApiRouter
from .lib import Model
from .lib import Model, AuthConfig, UserEntity, USER_FIELDS
from .entities import EntityFactory

__all__ = ['CRUDApi', 'CRUDApiRouter', 'Model', 'EntityFactory']
__all__ = ['CRUDApi', 'CRUDApiRouter', 'Model', 'EntityFactory', 'AuthConfig', 'UserEntity', 'USER_FIELDS']
206 changes: 176 additions & 30 deletions src/crud/app.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from typing import List
from typing import Annotated, List
from enum import Enum
from fastapi import Depends, FastAPI
from pyrepositories import DataSource, Entity, FieldBase, FieldTypes
from fastapi import Depends, FastAPI, HTTPException, status
from pyrepositories import DataSource, Entity, EntityField
from pydantic import create_model
from fastapi.routing import APIRouter
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
Comment on lines +5 to +9
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider organizing imports to improve readability.

- from fastapi import Depends, FastAPI, HTTPException, status
+ from fastapi import Depends, FastAPI
+ from fastapi.exceptions import HTTPException
+ from fastapi import status

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
from fastapi import Depends, FastAPI, HTTPException, status
from pyrepositories import DataSource, Entity, EntityField
from pydantic import create_model
from fastapi.routing import APIRouter
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi import Depends, FastAPI
from fastapi.exceptions import HTTPException
from fastapi import status
from pyrepositories import DataSource, Entity, EntityField
from pydantic import create_model
from fastapi.routing import APIRouter
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm

from .entities import EntityFactory
from .lib import convert_field_to_filter, convert_dict_to_filter
from .lib import AuthConfig, decode_token, User

id_path = '/single/{id}'

Expand All @@ -16,7 +20,7 @@ def construct_path(base_path: str, path: str, is_plural: bool, use_prefix: bool)
return f'{base_path}{plural}{path}'
else:
return f'{path}'


def get_tags(name: str, use_name_as_tag: bool) -> List[str | Enum] | None:
return [name] if use_name_as_tag else []
Expand All @@ -38,17 +42,78 @@ def convert2int(value: str) -> bool:
return False


def setup_routes_with_auth(router, base_path: str, datatype: str, datasource: DataSource, model_type: type,
factory: EntityFactory, use_prefix: bool, auth_scheme: OAuth2PasswordBearer, filters: list[FieldBase] | None = None,
tags: List[str | Enum] | None = None):
table = datasource.get_table(datatype)
if not table:
raise ValueError(f'Table {datatype} not found in datasource')

@router.get(construct_path(f'{base_path}', '', True, use_prefix), tags=tags)
async def read_items(token: Annotated[str, Depends(auth_scheme)]):
return format_entities(datasource.get_all(datatype) or [])

if len(filters) > 0:
filter_dict = convert_field_to_filter(filters)

@router.get(construct_path(f'{base_path}', '/filter', True, use_prefix), tags=tags)
async def filter_items(token: Annotated[str, Depends(auth_scheme)],
params: create_model("Query", **filter_dict) = Depends()):
fields = params.dict()
processed_filters = convert_dict_to_filter(fields)
result = format_entities(datasource.get_by_filter(datatype, processed_filters) or [])
return result

@router.get(construct_path(base_path, id_path, False, use_prefix), tags=tags)
async def read_item(token: Annotated[str, Depends(auth_scheme)], id: int | str):
return datasource.get_by_id(datatype, id)

@router.post(construct_path(base_path, '', False, use_prefix), tags=tags)
async def create_item(token: Annotated[str, Depends(auth_scheme)], item: model_type):
return datasource.insert(datatype, factory.create_entity(table.field_structure, item.model_dump()))

@router.put(construct_path(base_path, id_path, False, use_prefix), tags=tags)
async def update_item(token: Annotated[str, Depends(auth_scheme)], id: int | str,
item: model_type):
entity = factory.create_entity(table.field_structure, item.model_dump())
if isinstance(id, str) and convert2int(id):
id = int(id)
return datasource.update(datatype, id, entity)

@router.delete(construct_path(base_path, id_path, False, use_prefix), tags=tags)
async def delete_item(token: Annotated[str, Depends(auth_scheme)], id: int | str):
if isinstance(id, str) and convert2int(id):
id = int(id)
return datasource.delete(datatype, id)

@router.delete(construct_path(base_path, '', True, use_prefix), tags=tags)
async def delete_all_items(token: Annotated[str, Depends(auth_scheme)]):
return datasource.clear(datatype)


class CRUDApiRouter:
def __init__(self, datasource: DataSource, name: str, model_type: type, factory: EntityFactory, use_prefix: bool = True, use_name_as_tag: bool = True, filters: list[FieldBase] = []):
def __init__(
self,
datasource: DataSource,
name: str,
model_type: type,
factory: EntityFactory,
use_prefix: bool = True,
use_name_as_tag: bool = True,
auth: AuthConfig | None = None,
filters: list[FieldBase] | None = None
):
self.__datasource = datasource
self.__is_included = False
self.__auth = auth
self.name = name
self.use_prefix = use_prefix
self.use_name_as_tag = use_name_as_tag
self.__filters = filters or []
datatype = name.lower()
tags = get_tags(name, use_name_as_tag)
table = self.__datasource.get_table(datatype)
if not table:
self.__table = self.__datasource.get_table(datatype)
if not self.__table:
raise ValueError(f'Table {datatype} not found in datasource')

base_path = f'/{datatype}'
Expand All @@ -57,12 +122,46 @@ def __init__(self, datasource: DataSource, name: str, model_type: type, factory:
prefix=get_prefix(datatype, use_prefix)
)

if auth:
self.__setup_routes_with_auth(base_path, tags, datatype, model_type, factory, use_prefix)
else:
self.__setup_routes(base_path, tags, datatype, model_type, factory, use_prefix)

def get_base(self):
return self.__router

def get_datasource(self):
return self.__datasource

@property
def is_included(self):
return self.__is_included

def include(self):
self.__is_included = True

def __setup_routes_with_auth(self, base_path: str, tags: List[str | Enum] | None, datatype: str, model_type: type,
factory: EntityFactory, use_prefix: bool):
if not self.__table:
raise ValueError(f'Table {datatype} not found in datasource')

if not self.__auth:
raise ValueError('Auth config is required for this route')

global_auth = self.__auth
setup_routes_with_auth(self.__router, base_path, datatype, self.__datasource, model_type, factory, use_prefix, global_auth.oauth2_scheme, self.__filters, tags)
Comment on lines +151 to +152
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optimize the use of setup_routes_with_auth by reducing redundancy.

Consider refactoring to avoid passing so many parameters explicitly if they can be derived from existing class properties or through dependency injection.


def __setup_routes(self, base_path: str, tags: List[str | Enum] | None, datatype: str, model_type: type,
factory: EntityFactory, use_prefix: bool):
if not self.__table:
raise ValueError(f'Table {datatype} not found in datasource')

@self.__router.get(construct_path(f'{base_path}', '', True, use_prefix), tags=tags)
async def read_items():
return format_entities(self.__datasource.get_all(datatype) or [])
if len(filters) > 0:
filter_dict = convert_field_to_filter(filters)

if len(self.__filters) > 0:
filter_dict = convert_field_to_filter(self.__filters)

@self.__router.get(construct_path(f'{base_path}', '/filter', True, use_prefix), tags=tags)
async def filter_items(params: create_model("Query", **filter_dict) = Depends()):
Expand All @@ -78,7 +177,7 @@ async def read_item(id: int | str):
@self.__router.post(construct_path(base_path, '', False, use_prefix), tags=tags)
async def create_item(item: model_type):
try:
entity = factory.create_entity(table.field_structure, item.model_dump())
entity = factory.create_entity(self.__table.field_structure, item.model_dump())
result = self.__datasource.insert(datatype, entity)
if result:
return {'success': True, 'created_entity': result.serialize()}
Expand All @@ -90,7 +189,7 @@ async def create_item(item: model_type):
@self.__router.put(construct_path(base_path, id_path, False, use_prefix), tags=tags)
async def update_item(item_id: int | str, item: model_type):
try:
entity = factory.create_entity(table.field_structure, item.model_dump())
entity = factory.create_entity(self.__table.field_structure, item.model_dump())
if isinstance(item_id, str) and convert2int(item_id):
item_id = int(item_id)
result = self.__datasource.update(datatype, item_id, entity)
Expand All @@ -116,25 +215,65 @@ async def delete_item(item_id: int | str):
async def delete_all_items():
return self.__datasource.clear(datatype)

def get_base(self):
return self.__router

def get_datasource(self):
return self.__datasource

@property
def is_included(self):
return self.__is_included

def include(self):
self.__is_included = True


class CRUDApi:
def __init__(self, datasource: DataSource, app: FastAPI):
def __init__(self, datasource: DataSource, app: FastAPI, auth: AuthConfig | None = None):
self.__datasource = datasource
self.__app = app # type: FastAPI
self.__routers = {} # type: dict[str, CRUDApiRouter]
self.__app = app # type: FastAPI
self.__routers = {} # type: dict[str, CRUDApiRouter]
self.__auth = auth

if self.__auth:
self.__setup_auth()

def __setup_auth(self):
if not self.__auth:
return None

global_auth = self.__auth

def get_user(username: str):
return global_auth.users_db.get_unique('username', username)

def fake_decode_token(token: str):
# This doesn't provide any security at all
# Check the next version
user = get_user(token)
return user
Comment on lines +235 to +242
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replace fake_decode_token with a secure implementation.

The fake_decode_token function is marked as insecure. Replace it with a secure token decoding mechanism before deploying to production.

- def fake_decode_token(token: str):
-     # This doesn't provide any security at all
-     # Check the next version
-     user = get_user(token)
-     return user
+ def secure_decode_token(token: str):
+     # Implement secure token decoding logic here

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
def get_user(username: str):
return global_auth.users_db.get_unique('username', username)
def fake_decode_token(token: str):
# This doesn't provide any security at all
# Check the next version
user = get_user(token)
return user
def get_user(username: str):
return global_auth.users_db.get_unique('username', username)
def secure_decode_token(token: str):
# Implement secure token decoding logic here


async def get_current_user(token: Annotated[str, Depends(global_auth.oauth2_scheme)]):
user = fake_decode_token(token)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers={"WWW-Authenticate": "Bearer"},
)
return user

async def get_current_active_user(current_user: Annotated[User, Depends(get_current_user)]):
if current_user.disabled:
raise HTTPException(status_code=400, detail="Inactive user")
return current_user

def fake_hash_password(password: str):
return "fakehashed" + password

@self.__app.post("/token", tags=["auth"])
async def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()]):
results = global_auth.users_db.get_unique('username', form_data.username)
if len(results) != 1:
raise HTTPException(status_code=400, detail="Incorrect username or password")
user = results[0]
hashed_password = fake_hash_password(form_data.password)
if not hashed_password == user.hashed_password:
raise HTTPException(status_code=400, detail="Incorrect username or password")

return {"access_token": user.username, "token_type": "bearer"}
Comment on lines +262 to +272
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Improve error handling in the login function.

Consider adding more specific error messages and handling potential exceptions that could arise during the login process. This will enhance the user experience and make debugging easier.

- if len(results) != 1:
+ if not results:
    raise HTTPException(status_code=400, detail="Incorrect username or password")
+ except Exception as e:
+    raise HTTPException(status_code=500, detail=str(e))

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
@self.__app.post("/token", tags=["auth"])
async def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()]):
results = global_auth.users_db.get_unique('username', form_data.username)
if len(results) != 1:
raise HTTPException(status_code=400, detail="Incorrect username or password")
user = results[0]
hashed_password = fake_hash_password(form_data.password)
if not hashed_password == user.hashed_password:
raise HTTPException(status_code=400, detail="Incorrect username or password")
return {"access_token": user.username, "token_type": "bearer"}
@self.__app.post("/token", tags=["auth"])
async def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()]):
results = global_auth.users_db.get_unique('username', form_data.username)
if not results:
raise HTTPException(status_code=400, detail="Incorrect username or password")
user = results[0]
hashed_password = fake_hash_password(form_data.password)
if not hashed_password == user.hashed_password:
raise HTTPException(status_code=400, detail="Incorrect username or password")
return {"access_token": user.username, "token_type": "bearer"}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


@self.__app.get("/users/me", tags=["auth"])
async def read_users_me(current_user: Annotated[User, Depends(get_current_active_user)]):
return current_user

def get_app(self) -> FastAPI:
return self.__app
Expand All @@ -148,13 +287,20 @@ def get_router(self, datatype: str) -> CRUDApiRouter | None:
def get_datasource(self) -> DataSource:
return self.__datasource

def register_router(self, datatype: str, model_type: type, factory: EntityFactory = EntityFactory(), use_prefix: bool = True, filters: list[FieldBase] = []) -> CRUDApiRouter:
router = CRUDApiRouter(self.__datasource, datatype, model_type, factory, use_prefix, filters=filters)
def register_router(self, datatype: str, model_type: type, factory: EntityFactory = EntityFactory(),
use_prefix: bool = True, filters: list[FieldBase] = None) -> CRUDApiRouter:
if filters is None:
filters = []

router = CRUDApiRouter(self.__datasource, datatype, model_type, factory, use_prefix, auth=self.__auth, filters=filters)
self.__routers[datatype] = router
return router

def include_router(self, datatype: str, model_type: type, factory: EntityFactory = EntityFactory(), use_prefix: bool = True, filters: list[FieldBase] = []) -> CRUDApiRouter:
router = self.register_router(datatype, model_type, factory, use_prefix, filters)
def include_router(self, datatype: str, model_type: type, factory: EntityFactory = EntityFactory(),
use_prefix: bool = True, filters: list[FieldBase] = None) -> CRUDApiRouter:
if filters is None:
filters = []
router = self.register_router(datatype, model_type, factory, use_prefix, filters=filters)
self.__app.include_router(router.get_base())
router.include()
return router
Expand Down
Loading