Skip to content

Commit

Permalink
Merge pull request #14 from explosion/fix/serialization-new
Browse files Browse the repository at this point in the history
Fix serialization of extension attributes
  • Loading branch information
ines authored Dec 7, 2024
2 parents 470c6db + e7631f9 commit 0eef96c
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 13 deletions.
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
spacy>=3.7.5
docling>=2.5.2
pandas # version range set by Docling
srsly # version range set by spaCy
# Dev requirements
pytest
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ python_requires = >=3.10
install_requires =
spacy>=3.7.5
docling>=2.5.2
pandas # version range set by Docling
srsly # version range set by spaCy

[bdist_wheel]
universal = true
Expand Down
20 changes: 8 additions & 12 deletions spacy_layout/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,30 @@
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Iterable, Iterator

import srsly
from docling.datamodel.base_models import DocumentStream
from docling.document_converter import DocumentConverter
from docling_core.types.doc.base import CoordOrigin
from docling_core.types.doc.labels import DocItemLabel
from spacy.tokens import Doc, Span, SpanGroup

from .types import Attrs, DocLayout, DoclingItem, PageLayout, SpanLayout
from .util import decode_df, decode_obj, encode_df, encode_obj, get_bounding_box

if TYPE_CHECKING:
from docling.datamodel.base_models import InputFormat
from docling.document_converter import ConversionResult, FormatOption
from docling_core.types.doc.base import BoundingBox
from pandas import DataFrame
from spacy.language import Language


TABLE_PLACEHOLDER = "TABLE"

# Register msgpack encoders and decoders for custom types
srsly.msgpack_encoders.register("spacy-layout.dataclass", func=encode_obj)
srsly.msgpack_decoders.register("spacy-layout.dataclass", func=decode_obj)
srsly.msgpack_encoders.register("spacy-layout.dataframe", func=encode_df)
srsly.msgpack_decoders.register("spacy-layout.dataframe", func=decode_df)


class spaCyLayout:
def __init__(
Expand Down Expand Up @@ -181,13 +187,3 @@ def get_tables(self, doc: Doc) -> list[Span]:
for span in doc.spans[self.attrs.span_group]
if span.label_ == DocItemLabel.TABLE
]


def get_bounding_box(
bbox: "BoundingBox", page_height: float
) -> tuple[float, float, float, float]:
is_bottom = bbox.coord_origin == CoordOrigin.BOTTOMLEFT
y = page_height - bbox.t if is_bottom else bbox.t
height = bbox.t - bbox.b if is_bottom else bbox.b - bbox.t
width = bbox.r - bbox.l
return (bbox.l, y, width, height)
13 changes: 13 additions & 0 deletions spacy_layout/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,22 @@ class PageLayout:
width: float
height: float

@classmethod
def from_dict(cls, data: dict) -> "PageLayout":
return cls(**data)


@dataclass
class DocLayout:
"""Document layout features added to Doc object"""

pages: list[PageLayout]

@classmethod
def from_dict(cls, data: dict) -> "DocLayout":
pages = [PageLayout.from_dict(page) for page in data.get("pages", [])]
return cls(pages=pages)


@dataclass
class SpanLayout:
Expand All @@ -46,3 +55,7 @@ class SpanLayout:
width: float
height: float
page_no: int

@classmethod
def from_dict(cls, data: dict) -> "SpanLayout":
return cls(**data)
54 changes: 54 additions & 0 deletions spacy_layout/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import dataclasses
from typing import TYPE_CHECKING, Callable

from docling_core.types.doc.base import CoordOrigin
from pandas import DataFrame

from .types import DocLayout, PageLayout, SpanLayout

if TYPE_CHECKING:
from docling_core.types.doc.base import BoundingBox

TYPE_ATTR = "__type__"
OBJ_TYPES = {"SpanLayout": SpanLayout, "DocLayout": DocLayout, "PageLayout": PageLayout}


def encode_obj(obj, chain: Callable | None = None):
"""Convert custom dataclass to dict for serialization."""
if isinstance(obj, tuple(OBJ_TYPES.values())):
result = dataclasses.asdict(obj)
result[TYPE_ATTR] = type(obj).__name__
return result
return obj if chain is None else chain(obj)


def decode_obj(obj, chain: Callable | None = None):
"""Load custom dataclass from serialized dict."""
if isinstance(obj, dict) and obj.get(TYPE_ATTR) in OBJ_TYPES:
obj_type = obj.pop(TYPE_ATTR)
return OBJ_TYPES[obj_type].from_dict(obj)
return obj if chain is None else chain(obj)


def encode_df(obj, chain: Callable | None = None):
"""Convert pandas.DataFrame for serialization."""
if isinstance(obj, DataFrame):
return {"data": obj.to_dict(), TYPE_ATTR: "DataFrame"}
return obj if chain is None else chain(obj)


def decode_df(obj, chain: Callable | None = None):
"""Load pandas.DataFrame from serialized data."""
if isinstance(obj, dict) and obj.get(TYPE_ATTR) == "DataFrame":
return DataFrame(obj["data"])
return obj if chain is None else chain(obj)


def get_bounding_box(
bbox: "BoundingBox", page_height: float
) -> tuple[float, float, float, float]:
is_bottom = bbox.coord_origin == CoordOrigin.BOTTOMLEFT
y = page_height - bbox.t if is_bottom else bbox.t
height = bbox.t - bbox.b if is_bottom else bbox.b - bbox.t
width = bbox.r - bbox.l
return (bbox.l, y, width, height)
45 changes: 44 additions & 1 deletion tests/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@

import pytest
import spacy
import srsly
from docling_core.types.doc.base import BoundingBox, CoordOrigin
from docling_core.types.doc.labels import DocItemLabel
from pandas import DataFrame
from pandas.testing import assert_frame_equal
from spacy.tokens import DocBin

from spacy_layout import spaCyLayout
from spacy_layout.layout import TABLE_PLACEHOLDER, get_bounding_box
from spacy_layout.types import DocLayout, SpanLayout
from spacy_layout.types import DocLayout, PageLayout, SpanLayout

PDF_STARCRAFT = Path(__file__).parent / "data" / "starcraft.pdf"
PDF_SIMPLE = Path(__file__).parent / "data" / "simple.pdf"
Expand Down Expand Up @@ -118,3 +122,42 @@ def test_bounding_box(box, page_height, expected):
top, bottom, left, right, origin = box
bbox = BoundingBox(t=top, b=bottom, l=left, r=right, coord_origin=origin)
assert get_bounding_box(bbox, page_height) == expected


def test_serialize_objects():
span_layout = SpanLayout(x=10, y=20, width=30, height=40, page_no=1)
doc_layout = DocLayout(pages=[PageLayout(page_no=1, width=500, height=600)])
bytes_data = srsly.msgpack_dumps({"span": span_layout, "doc": doc_layout})
data = srsly.msgpack_loads(bytes_data)
assert isinstance(data, dict)
assert data["span"] == span_layout
assert data["doc"] == doc_layout
df = DataFrame(data={"col1": [1, 2], "col2": [3, 4]})
bytes_data = srsly.msgpack_dumps({"df": df})
data = srsly.msgpack_loads(bytes_data)
assert isinstance(data, dict)
assert_frame_equal(df, data["df"])


@pytest.mark.parametrize("path", [PDF_SIMPLE, PDF_TABLE])
def test_serialize_roundtrip(path, nlp):
layout = spaCyLayout(nlp)
doc = layout(path)
doc_bin = DocBin(store_user_data=True)
doc_bin.add(doc)
bytes_data = doc_bin.to_bytes()
new_doc_bin = DocBin().from_bytes(bytes_data)
new_doc = list(new_doc_bin.get_docs(nlp.vocab))[0]
layout_spans = new_doc.spans[layout.attrs.span_group]
assert len(layout_spans) == len(doc.spans[layout.attrs.span_group])
assert all(
isinstance(span._.get(layout.attrs.span_layout), SpanLayout)
for span in layout_spans
)
assert isinstance(new_doc._.get(layout.attrs.doc_layout), DocLayout)
tables = doc._.get(layout.attrs.doc_tables)
new_tables = new_doc._.get(layout.attrs.doc_tables)
for before, after in zip(tables, new_tables):
table_before = before._.get(layout.attrs.span_data)
table_after = after._.get(layout.attrs.span_data)
assert_frame_equal(table_before, table_after)

0 comments on commit 0eef96c

Please sign in to comment.