From 21c0fb9cc700a05532c1766e5f4b9b1f3be45c4d Mon Sep 17 00:00:00 2001 From: minhtuevo Date: Fri, 10 Jan 2025 17:07:32 -0800 Subject: [PATCH] Added support for batch size and multi processing --- fiftyone/utils/eval/detection.py | 119 ++++++++++++++++++++++--------- 1 file changed, 87 insertions(+), 32 deletions(-) diff --git a/fiftyone/utils/eval/detection.py b/fiftyone/utils/eval/detection.py index fbab3e782e..5d3f5f53da 100644 --- a/fiftyone/utils/eval/detection.py +++ b/fiftyone/utils/eval/detection.py @@ -5,10 +5,13 @@ | `voxel51.com `_ | """ +import contextlib +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor from copy import deepcopy import inspect import itertools import logging +import multiprocessing as mp import numpy as np @@ -41,6 +44,9 @@ def evaluate_detections( classwise=True, dynamic=True, progress=None, + batch_size=None, + num_workers=None, + executor_type="process", **kwargs, ): """Evaluates the predicted detections in the given samples with respect to @@ -135,6 +141,12 @@ def evaluate_detections( progress (None): whether to render a progress bar (True/False), use the default value ``fiftyone.config.show_progress_bars`` (None), or a progress callback function to invoke instead + batch_size (None): the batch size at which to process samples. By + default, all samples are processed in a single (1) batch + num_workers (None): number of parallel workers. Defaults to CPU count - 1 + executor_type ('process'): type of parallel executor to use: + - 'process': ProcessPoolExecutor for CPU-bound tasks + - 'thread': ThreadPoolExecutor for I/O-bound tasks **kwargs: optional keyword arguments for the constructor of the :class:`DetectionEvaluationConfig` being used @@ -176,45 +188,67 @@ def evaluate_detections( processing_frames = samples._is_frame_field(pred_field) save = eval_key is not None - if save: - tp_field = "%s_tp" % eval_key - fp_field = "%s_fp" % eval_key - fn_field = "%s_fn" % eval_key - if config.requires_additional_fields: _samples = samples else: _samples = samples.select_fields([gt_field, pred_field]) + # Determine number of workers + if num_workers is None: + num_workers = max(1, mp.cpu_count() - 1) + + logger.info( + f"Evaluating detections in parallel with {num_workers} workers..." + ) + matches = [] - logger.info("Evaluating detections...") - for sample in _samples.iter_samples(progress=progress, autosave=save): - if processing_frames: - docs = sample.frames.values() - else: - docs = [sample] - - sample_tp = 0 - sample_fp = 0 - sample_fn = 0 - for doc in docs: - doc_matches = eval_method.evaluate(doc, eval_key=eval_key) - matches.extend(doc_matches) - tp, fp, fn = _tally_matches(doc_matches) - sample_tp += tp - sample_fp += fp - sample_fn += fn - - if processing_frames and save: - doc[tp_field] = tp - doc[fp_field] = fp - doc[fn_field] = fn - - if save: - sample[tp_field] = sample_tp - sample[fp_field] = sample_fp - sample[fn_field] = sample_fn + with contextlib.ExitStack() as stack: + if use_masks: + stack.enter_context( + _samples.download_context( + media_fields=[gt_field, pred_field], progress=progress + ) + ) + # Create a pool of workers + executor_cls = ( + ProcessPoolExecutor + if executor_type == "process" + else ThreadPoolExecutor + ) + with executor_cls(max_workers=num_workers) as executor: + futures = [] + + # Submit batches of samples to the worker pool + for batch in _samples.iter_samples( + progress=progress, batch_size=batch_size, autosave=save + ): + if processing_frames: + docs = batch.frames.values() + else: + docs = [batch] + + future = executor.submit( + _process_sample, + docs, + eval_method, + eval_key, + processing_frames, + ) + futures.append((future, batch)) + + # Collect results + for future, sample in futures: + batch_matches, batch_stats = future.result() + matches.extend(batch_matches) + + if save: + sample_tp, sample_fp, sample_fn = batch_stats + sample[f"{eval_key}_tp"] = sample_tp + sample[f"{eval_key}_fp"] = sample_fp + sample[f"{eval_key}_fn"] = sample_fn + + # Generate and save final results results = eval_method.generate_results( samples, matches, @@ -228,6 +262,27 @@ def evaluate_detections( return results +def _process_sample(docs, eval_method, eval_key, processing_frames): + """Process a single sample or its frames.""" + matches = [] + sample_tp = sample_fp = sample_fn = 0 + + for doc in docs: + doc_matches = eval_method.evaluate(doc, eval_key=eval_key) + matches.extend(doc_matches) + tp, fp, fn = _tally_matches(doc_matches) + sample_tp += tp + sample_fp += fp + sample_fn += fn + + if processing_frames and eval_key is not None: + doc[f"{eval_key}_tp"] = tp + doc[f"{eval_key}_fp"] = fp + doc[f"{eval_key}_fn"] = fn + + return matches, (sample_tp, sample_fp, sample_fn) + + class DetectionEvaluationConfig(foe.EvaluationMethodConfig): """Base class for configuring :class:`DetectionEvaluation` instances.