From 86d0571f29fafc9e1d42daed0eba9cc8fe75bff9 Mon Sep 17 00:00:00 2001
From: Vik Paruchuri
Date: Wed, 23 Oct 2024 09:17:55 -0400
Subject: [PATCH 1/3] Test parallel layout

---
 pyproject.toml                       |   2 +-
 surya/layout.py                      | 105 +++++++++++++++++++--------
 surya/model/recognition/tokenizer.py |   7 +-
 3 files changed, 82 insertions(+), 32 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index a1b69e35..18a9523b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "surya-ocr"
-version = "0.6.8"
+version = "0.6.9"
 description = "OCR, layout, reading order, and table recognition in 90+ languages"
 authors = ["Vik Paruchuri"]
 readme = "README.md"

diff --git a/surya/layout.py b/surya/layout.py
index d488b970..b259b71a 100644
--- a/surya/layout.py
+++ b/surya/layout.py
@@ -1,3 +1,4 @@
+import multiprocessing
 import threading
 from collections import defaultdict
 from concurrent.futures import ProcessPoolExecutor
@@ -195,37 +196,81 @@ def batch_layout_detection(images: List, model, processor, detection_results: Op
     results = []
     max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
     parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH
-
-    if parallelize:
-        with ProcessPoolExecutor(max_workers=max_workers) as executor:
+    batch_queue = Queue()
+    processing_error = threading.Event()
+
+    def inference_producer():
+        try:
+            for batch in layout_generator:
+                batch_queue.put(batch)
+                if processing_error.is_set():
+                    break
+        except Exception as e:
+            processing_error.set()
+            print("Error in layout detection producer", e)
+        finally:
+            batch_queue.put(None)  # Signal end of batches
+
+    def postprocessing_consumer(executor):
+        if parallelize:
             img_idx = 0
-            for preds, orig_sizes in layout_generator:
-                futures = []
-                for pred, orig_size in zip(preds, orig_sizes):
-                    future = executor.submit(
+            while not processing_error.is_set():
+                batch = batch_queue.get()
+                if batch is None:
+                    break
+
+                try:
+                    preds, orig_sizes = batch
+                    img_idxs = [img_idx + i for i in range(len(preds))]
+                    batch_results = list(executor.map(
                         parallel_get_regions,
-                        pred,
-                        orig_size,
-                        id2label,
-                        detection_results[img_idx] if detection_results else None
-                    )
-
-                    futures.append(future)
-                    img_idx += 1
-
-            for future in futures:
-                results.append(future.result())
-    else:
-        img_idx = 0
-        for preds, orig_sizes in layout_generator:
-            for pred, orig_size in zip(preds, orig_sizes):
-                results.append(parallel_get_regions(
-                    pred,
-                    orig_size,
-                    id2label,
-                    detection_results[img_idx] if detection_results else None
-                ))
-
-                img_idx += 1
+                        preds,
+                        orig_sizes,
+                        [id2label] * len(preds),
+                        [detection_results[idx] for idx in img_idxs] if detection_results else [None] * len(preds)
+                    ))
+
+                    results.extend(batch_results)
+                    img_idx += len(preds)
+                except Exception as e:
+                    processing_error.set()
+                    print("Error in layout postprocessing", e)
+        else:
+            img_idx = 0
+            while not processing_error.is_set():
+                batch = batch_queue.get()
+                if batch is None:
+                    break
+
+                try:
+                    preds, orig_sizes = batch
+                    img_idxs = [img_idx + i for i in range(len(preds))]
+                    batch_results = list(map(
+                        parallel_get_regions,
+                        preds,
+                        orig_sizes,
+                        [id2label] * len(preds),
+                        [detection_results[idx] for idx in img_idxs] if detection_results else [None] * len(preds)
+                    ))
+                    results.extend(batch_results)
+                    img_idx += len(preds)
+                except Exception as e:
+                    processing_error.set()
+                    print("Error in layout postprocessing", e)
+
+    # Start producer and consumer threads
+    producer = threading.Thread(target=inference_producer)
+    executor = ProcessPoolExecutor(max_workers=max_workers, mp_context=multiprocessing.get_context("spawn")) if parallelize else None
+    consumer = threading.Thread(target=postprocessing_consumer, args=(executor,))
+
+    producer.start()
+    consumer.start()
+
+    # Wait for both threads to complete
+    producer.join()
+    consumer.join()
+
+    if executor:
+        executor.shutdown()
 
     return results
\ No newline at end of file

diff --git a/surya/model/recognition/tokenizer.py b/surya/model/recognition/tokenizer.py
index d57239a0..30018f5f 100644
--- a/surya/model/recognition/tokenizer.py
+++ b/surya/model/recognition/tokenizer.py
@@ -26,7 +26,12 @@ def utf16_numbers_to_text(numbers):
         byte_array.append(number & 0xFF)  # Lower byte
         byte_array.append((number >> 8) & 0xFF)  # Upper byte
 
-    text = byte_array.decode('utf-16le', errors="ignore")
+    try:
+        text = byte_array.decode('utf-16le', errors="ignore")
+    except Exception as e:
+        print(f"Error decoding utf16: {e}")
+        text = ""
+
     return text
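Patch 1 replaces the serial inference-then-postprocess loop with a producer/consumer pipeline: a producer thread drains the model's batch generator into a Queue while a consumer thread fans each batch out to a ProcessPoolExecutor, so GPU inference and CPU postprocessing overlap instead of running back to back. Below is a minimal, self-contained sketch of that shape; fake_inference, fake_postprocess, and run_pipeline are hypothetical stand-ins for illustration, not surya functions.

    import threading
    from queue import Queue
    from concurrent.futures import ProcessPoolExecutor

    def fake_inference(n_batches):
        # stand-in for the batch generator: yields one batch of predictions at a time
        for i in range(n_batches):
            yield [i * 10 + j for j in range(4)]

    def fake_postprocess(pred):
        # stand-in for parallel_get_regions: CPU-bound work on one prediction
        return pred * 2

    def run_pipeline(n_batches=3, max_workers=2):
        results = []
        batch_queue = Queue()

        def inference_producer():
            try:
                for batch in fake_inference(n_batches):
                    batch_queue.put(batch)
            finally:
                batch_queue.put(None)  # sentinel: no more batches

        def postprocessing_consumer(executor):
            while True:
                batch = batch_queue.get()
                if batch is None:
                    break
                # fan the batch out across worker processes; map preserves order
                results.extend(executor.map(fake_postprocess, batch))

        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            producer = threading.Thread(target=inference_producer)
            consumer = threading.Thread(target=postprocessing_consumer, args=(executor,))
            producer.start()
            consumer.start()
            producer.join()
            consumer.join()
        return results

    if __name__ == "__main__":
        print(run_pipeline())  # inference and postprocessing now overlap

The None sentinel pushed in the producer's finally block is what lets the consumer terminate even when inference raises, which is also why the patch records failures in a processing_error event rather than re-raising across threads.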
From 026f9e35f347d04acc130dc5093404c15df556c4 Mon Sep 17 00:00:00 2001
From: Vik Paruchuri
Date: Wed, 23 Oct 2024 09:35:28 -0400
Subject: [PATCH 2/3] Store futures versus immediately generating results

---
 surya/layout.py        | 84 +++++++++++++++---------------------------
 surya/util/parallel.py |  6 +++
 2 files changed, 35 insertions(+), 55 deletions(-)
 create mode 100644 surya/util/parallel.py

diff --git a/surya/layout.py b/surya/layout.py
index b259b71a..3ebb3080 100644
--- a/surya/layout.py
+++ b/surya/layout.py
@@ -1,3 +1,4 @@
+import contextlib
 import multiprocessing
 import threading
 from collections import defaultdict
@@ -11,6 +12,7 @@
 from surya.postprocessing.heatmap import keep_largest_boxes, get_and_clean_boxes, get_detected_boxes
 from surya.schema import LayoutResult, LayoutBox, TextDetectionResult
 from surya.settings import settings
+from surya.util.parallel import FakeParallel
 
 
 def get_regions_from_detection_result(detection_result: TextDetectionResult, heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment, vertical_line_width=20) -> List[LayoutBox]:
@@ -193,7 +195,7 @@ def batch_layout_detection(images: List, model, processor, detection_results: Op
     layout_generator = batch_detection(images, model, processor, batch_size=batch_size)
     id2label = model.config.id2label
 
-    results = []
+    postprocessing_futures = []
     max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
     parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH
     batch_queue = Queue()
     processing_error = threading.Event()
@@ -212,65 +214,37 @@ def inference_producer():
             batch_queue.put(None)  # Signal end of batches
 
     def postprocessing_consumer(executor):
-        if parallelize:
-            img_idx = 0
-            while not processing_error.is_set():
-                batch = batch_queue.get()
-                if batch is None:
-                    break
-
-                try:
-                    preds, orig_sizes = batch
-                    img_idxs = [img_idx + i for i in range(len(preds))]
-                    batch_results = list(executor.map(
-                        parallel_get_regions,
-                        preds,
-                        orig_sizes,
-                        [id2label] * len(preds),
-                        [detection_results[idx] for idx in img_idxs] if detection_results else [None] * len(preds)
-                    ))
-
-                    results.extend(batch_results)
-                    img_idx += len(preds)
-                except Exception as e:
-                    processing_error.set()
-                    print("Error in layout postprocessing", e)
-        else:
-            img_idx = 0
-            while not processing_error.is_set():
-                batch = batch_queue.get()
-                if batch is None:
-                    break
-
-                try:
-                    preds, orig_sizes = batch
-                    img_idxs = [img_idx + i for i in range(len(preds))]
-                    batch_results = list(map(
-                        parallel_get_regions,
-                        preds,
-                        orig_sizes,
-                        [id2label] * len(preds),
-                        [detection_results[idx] for idx in img_idxs] if detection_results else [None] * len(preds)
-                    ))
-                    results.extend(batch_results)
-                    img_idx += len(preds)
-                except Exception as e:
-                    processing_error.set()
-                    print("Error in layout postprocessing", e)
+        img_idx = 0
+        while not processing_error.is_set():
+            batch = batch_queue.get()
+            if batch is None:
+                break
+
+            try:
+                preds, orig_sizes = batch
+                for pred, orig_size in zip(preds, orig_sizes):
+                    func = executor.submit if parallelize else FakeParallel
+                    future = func(parallel_get_regions, pred, orig_size, id2label, detection_results[img_idx] if detection_results else None)
+                    postprocessing_futures.append(future)
+                    img_idx += 1
+            except Exception as e:
+                processing_error.set()
+                print("Error in layout postprocessing", e)
 
     # Start producer and consumer threads
     producer = threading.Thread(target=inference_producer)
-    executor = ProcessPoolExecutor(max_workers=max_workers, mp_context=multiprocessing.get_context("spawn")) if parallelize else None
-    consumer = threading.Thread(target=postprocessing_consumer, args=(executor,))
 
-    producer.start()
-    consumer.start()
+    with ProcessPoolExecutor(
+        max_workers=max_workers,
+        mp_context=multiprocessing.get_context("spawn")
+    ) if parallelize else contextlib.nullcontext() as executor:
+        consumer = threading.Thread(target=postprocessing_consumer, args=(executor,))
 
-    # Wait for both threads to complete
-    producer.join()
-    consumer.join()
+        producer.start()
+        consumer.start()
+        producer.join()
+        consumer.join()
 
-    if executor:
-        executor.shutdown()
+    results = [future.result() for future in postprocessing_futures]
 
     return results
\ No newline at end of file

diff --git a/surya/util/parallel.py b/surya/util/parallel.py
new file mode 100644
index 00000000..015c4253
--- /dev/null
+++ b/surya/util/parallel.py
@@ -0,0 +1,6 @@
+class FakeParallel():
+    def __init__(self, func, *args):
+        self._result = func(*args)
+
+    def result(self):
+        return self._result
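Patch 2 stops materializing results inside the consumer and instead collects future-like objects, deferring all .result() calls until both threads have joined. The new FakeParallel class is the glue: it mimics the (func, *args) call shape of executor.submit and the .result() accessor of a Future, so one consumer loop serves both the serial and the parallel case. A small sketch of the idea; square and collect are illustrative names, not part of surya.

    from concurrent.futures import ProcessPoolExecutor

    class FakeParallel():
        # Future look-alike: runs the function eagerly, on the calling thread
        def __init__(self, func, *args):
            self._result = func(*args)

        def result(self):
            return self._result

    def square(x):
        return x * x

    def collect(values, executor=None):
        # executor.submit and FakeParallel share the (func, *args) signature and
        # both return an object exposing .result(), so one loop covers both modes
        func = executor.submit if executor is not None else FakeParallel
        futures = [func(square, v) for v in values]
        return [f.result() for f in futures]

    if __name__ == "__main__":
        serial = collect([1, 2, 3])
        with ProcessPoolExecutor(max_workers=2) as executor:
            parallel = collect([1, 2, 3], executor)
        assert serial == parallel == [1, 4, 9]

One caveat worth noting: FakeParallel does its work in __init__, so in serial mode "submitting" is synchronous and the later .result() calls are free. In both modes the futures list preserves submission order, so output ordering is identical.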
From ef8ec0d701e4ae5b86c0e26c5f698c9a64192582 Mon Sep 17 00:00:00 2001
From: Vik Paruchuri
Date: Wed, 23 Oct 2024 09:55:13 -0400
Subject: [PATCH 3/3] Parallel detection

---
 surya/detection.py | 60 +++++++++++++++++++++++++++++++++++++---------
 surya/layout.py    |  7 +++---
 2 files changed, 52 insertions(+), 15 deletions(-)

diff --git a/surya/detection.py b/surya/detection.py
index bef9d069..b51a5e5c 100644
--- a/surya/detection.py
+++ b/surya/detection.py
@@ -1,3 +1,5 @@
+import contextlib
+import multiprocessing
 import threading
 from queue import Queue
 from typing import List, Tuple, Generator
@@ -16,6 +18,8 @@
 from concurrent.futures import ProcessPoolExecutor
 import torch.nn.functional as F
 
+from surya.util.parallel import FakeParallel
+
 
 def get_batch_size():
     batch_size = settings.DETECTOR_BATCH_SIZE
@@ -127,18 +131,52 @@ def parallel_get_lines(preds, orig_sizes):
 
 def batch_text_detection(images: List, model, processor, batch_size=None) -> List[TextDetectionResult]:
     detection_generator = batch_detection(images, model, processor, batch_size=batch_size)
-    results = []
+    postprocessing_futures = []
     max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
     parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH
-
-    if parallelize:
-        with ProcessPoolExecutor(max_workers=max_workers) as executor:
-            for preds, orig_sizes in detection_generator:
-                batch_results = list(executor.map(parallel_get_lines, preds, orig_sizes))
-                results.extend(batch_results)
-    else:
-        for preds, orig_sizes in detection_generator:
-            for pred, orig_size in zip(preds, orig_sizes):
-                results.append(parallel_get_lines(pred, orig_size))
+    batch_queue = Queue()
+    processing_error = threading.Event()
+
+    def inference_producer():
+        try:
+            for batch in detection_generator:
+                batch_queue.put(batch)
+                if processing_error.is_set():
+                    break
+        except Exception as e:
+            processing_error.set()
+            print("Error with batch detection", e)
+        finally:
+            batch_queue.put(None)  # Signal end of batches
+
+    def postprocessing_consumer(executor):
+        while not processing_error.is_set():
+            batch = batch_queue.get()
+            if batch is None:
+                break
+
+            try:
+                preds, orig_sizes = batch
+                func = executor.submit if parallelize else FakeParallel
+                for pred, orig_size in zip(preds, orig_sizes):
+                    postprocessing_futures.append(func(parallel_get_lines, pred, orig_size))
+            except Exception as e:
+                processing_error.set()
+                print("Error with postprocessing", e)
+
+    # Start producer and consumer threads
+    producer = threading.Thread(target=inference_producer, daemon=True)
+    producer.start()
+
+    with ProcessPoolExecutor(
+        max_workers=max_workers,
+        mp_context=multiprocessing.get_context("spawn")
+    ) if parallelize else contextlib.nullcontext() as executor:
+        consumer = threading.Thread(target=postprocessing_consumer, args=(executor,), daemon=True)
+        consumer.start()
+        producer.join()
+        consumer.join()
+
+    results = [future.result() for future in postprocessing_futures]
 
     return results
\ No newline at end of file

diff --git a/surya/layout.py b/surya/layout.py
index 3ebb3080..eeb433dd 100644
--- a/surya/layout.py
+++ b/surya/layout.py
@@ -232,15 +232,14 @@ def postprocessing_consumer(executor):
             print("Error in layout postprocessing", e)
 
     # Start producer and consumer threads
-    producer = threading.Thread(target=inference_producer)
+    producer = threading.Thread(target=inference_producer, daemon=True)
+    producer.start()
 
     with ProcessPoolExecutor(
         max_workers=max_workers,
         mp_context=multiprocessing.get_context("spawn")
     ) if parallelize else contextlib.nullcontext() as executor:
-        consumer = threading.Thread(target=postprocessing_consumer, args=(executor,))
-
-        producer.start()
+        consumer = threading.Thread(target=postprocessing_consumer, args=(executor,), daemon=True)
         consumer.start()
         producer.join()
         consumer.join()
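Patch 3 ports the same pipeline to text detection and tidies layout.py, leaning on two tricks now shared by both files: a conditional context manager, where contextlib.nullcontext() stands in for the executor when parallelism is off so the with block reads identically in both modes, and a "spawn" multiprocessing context, which avoids fork-inheriting CUDA/PyTorch state into the worker processes. A minimal sketch of the conditional-context-manager pattern under those assumptions; work and run are illustrative names, not surya APIs.

    import contextlib
    import multiprocessing
    from concurrent.futures import ProcessPoolExecutor

    def work(x):
        # stand-in for parallel_get_lines
        return x + 1

    def run(values, parallelize):
        # one `with` block covers both modes: nullcontext yields None as executor
        with ProcessPoolExecutor(
            max_workers=2,
            mp_context=multiprocessing.get_context("spawn"),  # don't fork GPU/torch state
        ) if parallelize else contextlib.nullcontext() as executor:
            if executor is None:
                return [work(v) for v in values]
            return list(executor.map(work, values))

    if __name__ == "__main__":
        assert run([1, 2], True) == run([1, 2], False) == [2, 3]

The daemon=True on the producer and consumer threads is a belt-and-braces choice: if postprocessing ever deadlocks, the interpreter can still exit rather than hang on non-daemon threads at shutdown.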