Skip to content

Commit

Permalink
Update benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Jan 10, 2025
1 parent 7174903 commit 05b6ed7
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 28 deletions.
19 changes: 13 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ layout_predictions = layout_predictor([image])

## Table Recognition

This command will write out a json file with the detected table cells and row/column ids, along with row/column bounding boxes. If you want to get a formatted markdown table, check out the [tabled](https://www.github.com/VikParuchuri/tabled) repo.
This command will write out a json file with the detected table cells and row/column ids, along with row/column bounding boxes. If you want to get a formatted markdown or HTML table, check out the [marker](https://www.github.com/VikParuchuri/marker) repo. You can use the `TableConverter` to detect and extract tables in images and PDFs.

```shell
surya_table DATA_PATH
Expand All @@ -254,12 +254,19 @@ The `results.json` file will contain a json dictionary where the keys are the in
- `rows` - detected table rows
- `bbox` - the bounding box of the table row
- `row_id` - the id of the row
- `is_header` - if it is a header row.
- `cols` - detected table columns
- `bbox` - the bounding box of the table column
- `col_id`- the id of the column
- `is_header` - if it is a header column
- `cells` - detected table cells
- `bbox` - the axis-aligned rectangle for the text line in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner.
      - `text` - if text could be pulled out of the PDF, the text of this cell.
- `row_id` - the id of the row the cell belongs to.
- `col_id` - the id of the column the cell belongs to.
- `colspan` - the number of columns spanned by the cell.
- `rowspan` - the number of rows spanned by the cell.
- `is_header` - whether it is a header cell.
- `page` - the page number in the file
- `table_idx` - the index of the table on the page (sorted in vertical order)
- `image_bbox` - the bbox for the image in (x1, y1, x2, y2) format. (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner. All line bboxes will be contained within this bbox.
Expand Down Expand Up @@ -395,12 +402,12 @@ The accuracy is computed by finding if each pair of layout boxes is in the corre

## Table Recognition

| Model | Row Intersection | Col Intersection | Time Per Image |
|-------------------|------------------|------------------|------------------|
| Surya | 0.97 | 0.93 | 0.03 |
| Table transformer | 0.72 | 0.84 | 0.02 |
| Model | Row Intersection | Col Intersection | Time Per Image |
|-------------------|--------------------|--------------------|------------------|
| Surya | 1 | 0.98625 | 0.30202 |
| Table transformer | 0.84 | 0.86857 | 0.08082 |

Higher is better for intersection, which is the percentage of the actual row/column overlapped by the predictions.
Higher is better for intersection, which is the percentage of the actual row/column overlapped by the predictions. This benchmark is mostly a sanity check - there is a more rigorous one in [marker](https://www.github.com/VikParuchuri/marker).

**Methodology**

Expand Down
38 changes: 18 additions & 20 deletions benchmark/table_recognition.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import argparse

import click
from PIL import ImageDraw
import collections
import json
Expand All @@ -16,21 +18,19 @@
import datasets


def main():
parser = argparse.ArgumentParser(description="Benchmark surya table recognition model.")
parser.add_argument("--results_dir", type=str, help="Path to JSON file with benchmark results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
parser.add_argument("--max", type=int, help="Maximum number of images to run benchmark on.", default=None)
parser.add_argument("--tatr", action="store_true", help="Run table transformer.", default=False)
parser.add_argument("--debug", action="store_true", help="Enable debug mode.", default=False)
args = parser.parse_args()

@click.command(help="Benchmark table rec dataset")
@click.option("--results_dir", type=str, help="Path to JSON file with benchmark results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
@click.option("--max_rows", type=int, help="Maximum number of images to run benchmark on.", default=512)
@click.option("--tatr", is_flag=True, help="Run table transformer.", default=False)
@click.option("--debug", is_flag=True, help="Enable debug mode.", default=False)
def main(results_dir: str, max_rows: int, tatr: bool, debug: bool):
table_rec_predictor = TableRecPredictor()

pathname = "table_rec_bench"
# These have already been shuffled randomly, so sampling from the start is fine
split = "train"
if args.max is not None:
split = f"train[:{args.max}]"
if max_rows is not None:
split = f"train[:{max_rows}]"
dataset = datasets.load_dataset(settings.TABLE_REC_BENCH_DATASET_NAME, split=split)
images = list(dataset["image"])
images = convert_if_not_rgb(images)
Expand All @@ -44,7 +44,7 @@ def main():
surya_time = time.time() - start

folder_name = os.path.basename(pathname).split(".")[0]
result_path = os.path.join(args.results_dir, folder_name)
result_path = os.path.join(results_dir, folder_name)
os.makedirs(result_path, exist_ok=True)

page_metrics = collections.OrderedDict()
Expand All @@ -54,8 +54,8 @@ def main():
row = dataset[idx]
pred_row_boxes = [p.bbox for p in pred.rows]
pred_col_bboxes = [p.bbox for p in pred.cols]
actual_row_bboxes = row["rows"]
actual_col_bboxes = row["cols"]
actual_row_bboxes = [r["bbox"] for r in row["rows"]]
actual_col_bboxes = [c["bbox"] for c in row["columns"]]
row_score = penalized_iou_score(pred_row_boxes, actual_row_bboxes)
col_score = penalized_iou_score(pred_col_bboxes, actual_col_bboxes)
page_results = {
Expand All @@ -70,16 +70,14 @@ def main():

page_metrics[idx] = page_results

if args.debug:
if debug:
# Save debug images
draw_img = image.copy()
draw = ImageDraw.Draw(draw_img)
draw_bboxes_on_image(pred_row_boxes, draw_img, [f"Row {i}" for i in range(len(pred_row_boxes))])
draw_bboxes_on_image(pred_col_bboxes, draw_img, [f"Col {i}" for i in range(len(pred_col_bboxes))], color="blue")
draw_img.save(os.path.join(result_path, f"{idx}_bbox.png"))

actual_draw_image = image.copy()
draw = ImageDraw.Draw(actual_draw_image)
draw_bboxes_on_image(actual_row_bboxes, actual_draw_image, [f"Row {i}" for i in range(len(actual_row_bboxes))])
draw_bboxes_on_image(actual_col_bboxes, actual_draw_image, [f"Col {i}" for i in range(len(actual_col_bboxes))], color="blue")
actual_draw_image.save(os.path.join(result_path, f"{idx}_actual.png"))
Expand All @@ -95,7 +93,7 @@ def main():
"page_metrics": page_metrics
}}

if args.tatr:
if tatr:
tatr_model = load_tatr()
start = time.time()
tatr_predictions = batch_inference_tatr(tatr_model, images, 1)
Expand All @@ -108,8 +106,8 @@ def main():
row = dataset[idx]
pred_row_boxes = [p["bbox"] for p in pred["rows"]]
pred_col_bboxes = [p["bbox"] for p in pred["cols"]]
actual_row_bboxes = row["rows"]
actual_col_bboxes = row["cols"]
actual_row_bboxes = [r["bbox"] for r in row["rows"]]
actual_col_bboxes = [c["bbox"] for c in row["columns"]]
row_score = penalized_iou_score(pred_row_boxes, actual_row_bboxes)
col_score = penalized_iou_score(pred_col_bboxes, actual_col_bboxes)
page_results = {
Expand Down Expand Up @@ -143,7 +141,7 @@ def main():
f"{surya_time / len(images):.5f}"],
]

if args.tatr:
if tatr:
table.append(["Table transformer", f"{out_data['tatr']['mean_row_iou']:.2f}", f"{out_data['tatr']['mean_col_iou']:.5f}",
f"{tatr_time / len(images):.5f}"])

Expand Down
2 changes: 1 addition & 1 deletion surya/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def TORCH_DEVICE_MODEL(self) -> str:
TABLE_REC_IMAGE_SIZE: Dict = {"height": 768, "width": 768}
TABLE_REC_MAX_BOXES: int = 150
TABLE_REC_BATCH_SIZE: Optional[int] = None
TABLE_REC_BENCH_DATASET_NAME: str = "vikp/fintabnet_bench"
TABLE_REC_BENCH_DATASET_NAME: str = "datalab-to/fintabnet_bench"
COMPILE_TABLE_REC: bool = False

# OCR Error Detection
Expand Down
2 changes: 1 addition & 1 deletion surya/table_rec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def batch_table_recognition(
row_inputs = self.processor(images=None, query_items=row_query_items, columns=columns, convert_images=False)
row_input_ids = row_inputs["input_ids"].to(self.model.device)
cell_predictions = []
for j in tqdm(range(0, len(row_input_ids), batch_size), desc="Recognizing table cells"):
for j in range(0, len(row_input_ids), batch_size):
cell_batch_hidden_states = row_encoder_hidden_states[j:j + batch_size]
cell_batch_input_ids = row_input_ids[j:j + batch_size]
cell_batch_size = len(cell_batch_input_ids)
Expand Down

0 comments on commit 05b6ed7

Please sign in to comment.