-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Entities #405
base: mwp_v1
Are you sure you want to change the base?
Entities #405
Changes from 7 commits
3f2cd2d
60d90db
e6ab5f1
31b49c8
a4eb087
615f6f3
2340280
2b93d9a
02ac98a
a03a895
80da369
9fb795a
26e9791
0eb2052
7040ba1
91d2502
030d92a
7f7ab0a
6258a27
4344024
ca824be
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
from itertools import chain | ||
from typing import List | ||
|
||
from fastapi import APIRouter, Depends | ||
from sqlalchemy.orm import Session | ||
|
||
from api.dependencies import get_current_user, get_db_session | ||
from app.core.authorization.authz_user import AuthzUser | ||
from app.core.data.crud import Crud | ||
from app.core.data.crud.entity import crud_entity | ||
from app.core.data.crud.span_text import crud_span_text | ||
from app.core.data.dto.entity import ( | ||
EntityCreate, | ||
EntityMerge, | ||
EntityRead, | ||
EntityRelease, | ||
EntityUpdate, | ||
) | ||
|
||
router = APIRouter( | ||
prefix="/entity", dependencies=[Depends(get_current_user)], tags=["entity"] | ||
) | ||
|
||
|
||
@router.patch( | ||
"/{entity_id}", | ||
response_model=EntityRead, | ||
summary="Updates the Entity with the given ID.", | ||
) | ||
def update_by_id( | ||
*, | ||
db: Session = Depends(get_db_session), | ||
entity_id: int, | ||
entity: EntityUpdate, | ||
authz_user: AuthzUser = Depends(), | ||
) -> EntityRead: | ||
authz_user.assert_in_same_project_as(Crud.ENTITY, entity_id) | ||
entity.is_human = True | ||
db_obj = crud_entity.update(db=db, id=entity_id, update_dto=entity) | ||
return EntityRead.model_validate(db_obj) | ||
|
||
|
||
# add merge endpoint | ||
@router.put( | ||
"/merge", | ||
response_model=EntityRead, | ||
summary="Merges entities and/or span texts with given IDs.", | ||
) | ||
def merge_entities( | ||
*, | ||
db: Session = Depends(get_db_session), | ||
entity_merge: EntityMerge, | ||
authz_user: AuthzUser = Depends(), | ||
) -> EntityRead: | ||
authz_user.assert_in_same_project_as_many(Crud.ENTITY, entity_merge.entity_ids) | ||
all_span_texts = ( | ||
list( | ||
chain.from_iterable( | ||
[st.id for st in crud_entity.read(db=db, id=id).span_texts] | ||
for id in entity_merge.entity_ids | ||
) | ||
) | ||
+ entity_merge.spantext_ids | ||
) | ||
new_entity = EntityCreate( | ||
name=entity_merge.name, | ||
project_id=entity_merge.project_id, | ||
span_text_ids=all_span_texts, | ||
is_human=True, | ||
knowledge_base_id=entity_merge.knowledge_base_id, | ||
) | ||
db_obj = crud_entity.create(db=db, create_dto=new_entity, force=True) | ||
return EntityRead.model_validate(db_obj) | ||
|
||
|
||
# add resolve endpoint | ||
@router.put( | ||
"/release", | ||
response_model=List[EntityRead], | ||
summary="Releases entities and/or span texts with given IDs.", | ||
) | ||
def release_entities( | ||
*, | ||
db: Session = Depends(get_db_session), | ||
entity_resolve: EntityRelease, | ||
authz_user: AuthzUser = Depends(), | ||
) -> EntityRead: | ||
authz_user.assert_in_same_project_as_many(Crud.ENTITY, entity_resolve.entity_ids) | ||
all_span_texts = ( | ||
list( | ||
chain.from_iterable( | ||
[st.id for st in crud_entity.read(db=db, id=id).span_texts] | ||
for id in entity_resolve.entity_ids | ||
) | ||
) | ||
+ entity_resolve.spantext_ids | ||
) | ||
new_entities = [] | ||
for span_text_id in all_span_texts: | ||
span_text = crud_span_text.read(db=db, id=span_text_id) | ||
new_entity = EntityCreate( | ||
name=span_text.text, | ||
project_id=entity_resolve.project_id, | ||
span_text_ids=[span_text_id], | ||
) | ||
new_entities.append(new_entity) | ||
db_objs = crud_entity.create_multi(db=db, create_dtos=new_entities, force=True) | ||
return [EntityRead.model_validate(db_obj) for db_obj in db_objs] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from typing import List, Optional | ||
|
||
from fastapi.encoders import jsonable_encoder | ||
from sqlalchemy import select | ||
from sqlalchemy.orm import Session | ||
|
||
from app.core.data.crud.crud_base import CRUDBase | ||
from app.core.data.crud.span_text_entity_link import crud_span_text_entity_link | ||
from app.core.data.dto.entity import ( | ||
EntityCreate, | ||
EntityUpdate, | ||
) | ||
from app.core.data.dto.span_text_entity_link import ( | ||
SpanTextEntityLinkCreate, | ||
) | ||
from app.core.data.orm.entity import EntityORM | ||
from app.core.data.orm.span_text_entity_link import SpanTextEntityLinkORM | ||
|
||
|
||
class CRUDEntity(CRUDBase[EntityORM, EntityCreate, EntityUpdate]): | ||
def create( | ||
self, db: Session, *, create_dto: EntityCreate, force: bool = True | ||
) -> EntityORM: | ||
result = self.create_multi(db=db, create_dtos=[create_dto], force=force) | ||
return result[0] if len(result) > 0 else None | ||
|
||
def create_multi( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. create multi muss mit hilfe einer hash map umgesetzt werden. span_text_id -> entity dann alle entities erstellen. |
||
self, db: Session, *, create_dtos: List[EntityCreate], force: bool = True | ||
) -> List[EntityORM]: | ||
if len(create_dtos) == 0: | ||
return [] | ||
dto_objs_data = [ | ||
jsonable_encoder(dto, exclude={"span_text_ids"}) for dto in create_dtos | ||
] | ||
db_objs = [self.model(**data) for data in dto_objs_data] | ||
db.add_all(db_objs) | ||
db.flush() | ||
db.commit() | ||
|
||
links = [] | ||
for db_obj, create_dto in zip(db_objs, create_dtos): | ||
for span_text_id in create_dto.span_text_ids: | ||
links.append( | ||
SpanTextEntityLinkCreate( | ||
linked_entity_id=db_obj.id, linked_span_text_id=span_text_id | ||
) | ||
) | ||
crud_span_text_entity_link.create_multi(db=db, create_dtos=links, force=force) | ||
db.commit() | ||
self.remove_all_unused_entites(db=db) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. das muss weg |
||
return db_objs | ||
|
||
def read_by_project(self, db: Session, proj_id: int) -> List[EntityORM]: | ||
return db.query(self.model).filter(self.model.project_id == proj_id).all() | ||
|
||
def remove_multi(self, db: Session, *, ids: List[int]) -> List[EntityORM]: | ||
removed = db.query(EntityORM).filter(EntityORM.id.in_(ids)).all() | ||
db.query(EntityORM).filter(EntityORM.id.in_(ids)).delete( | ||
synchronize_session=False | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is this? |
||
) | ||
db.commit() | ||
return removed | ||
|
||
def remove_all_unused_entites(self, db: Session) -> List[EntityORM]: | ||
subquery = select(SpanTextEntityLinkORM.linked_entity_id).distinct().subquery() | ||
query = ( | ||
db.query(EntityORM) | ||
.outerjoin(subquery, EntityORM.id == subquery.c.linked_entity_id) | ||
.filter(subquery.c.linked_entity_id.is_(None)) | ||
) | ||
to_remove = query.all() | ||
return self.remove_multi(db=db, ids=[e.id for e in to_remove]) | ||
|
||
|
||
crud_entity = CRUDEntity(EntityORM) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Bitte crud_entity.release_... erstellen. und hier aufrufen.
Bei Entity Create fehlt is_human=false