Skip to content

Commit

Permalink
Added support for batch size and multi processing
Browse files Browse the repository at this point in the history
  • Loading branch information
minhtuevo committed Jan 11, 2025
1 parent 5801576 commit 21c0fb9
Showing 1 changed file with 87 additions and 32 deletions.
119 changes: 87 additions & 32 deletions fiftyone/utils/eval/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
| `voxel51.com <https://voxel51.com/>`_
|
"""
import contextlib
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from copy import deepcopy
import inspect
import itertools
import logging
import multiprocessing as mp

import numpy as np

Expand Down Expand Up @@ -41,6 +44,9 @@ def evaluate_detections(
classwise=True,
dynamic=True,
progress=None,
batch_size=None,
num_workers=None,
executor_type="process",
**kwargs,
):
"""Evaluates the predicted detections in the given samples with respect to
Expand Down Expand Up @@ -135,6 +141,12 @@ def evaluate_detections(
progress (None): whether to render a progress bar (True/False), use the
default value ``fiftyone.config.show_progress_bars`` (None), or a
progress callback function to invoke instead
batch_size (None): the batch size at which to process samples. By
default, all samples are processed in a single (1) batch
num_workers (None): number of parallel workers. Defaults to CPU count - 1
executor_type ('process'): type of parallel executor to use:
- 'process': ProcessPoolExecutor for CPU-bound tasks
- 'thread': ThreadPoolExecutor for I/O-bound tasks
**kwargs: optional keyword arguments for the constructor of the
:class:`DetectionEvaluationConfig` being used
Expand Down Expand Up @@ -176,45 +188,67 @@ def evaluate_detections(
processing_frames = samples._is_frame_field(pred_field)
save = eval_key is not None

if save:
tp_field = "%s_tp" % eval_key
fp_field = "%s_fp" % eval_key
fn_field = "%s_fn" % eval_key

if config.requires_additional_fields:
_samples = samples
else:
_samples = samples.select_fields([gt_field, pred_field])

# Determine number of workers
if num_workers is None:
num_workers = max(1, mp.cpu_count() - 1)

logger.info(
f"Evaluating detections in parallel with {num_workers} workers..."
)

matches = []
logger.info("Evaluating detections...")
for sample in _samples.iter_samples(progress=progress, autosave=save):
if processing_frames:
docs = sample.frames.values()
else:
docs = [sample]

sample_tp = 0
sample_fp = 0
sample_fn = 0
for doc in docs:
doc_matches = eval_method.evaluate(doc, eval_key=eval_key)
matches.extend(doc_matches)
tp, fp, fn = _tally_matches(doc_matches)
sample_tp += tp
sample_fp += fp
sample_fn += fn

if processing_frames and save:
doc[tp_field] = tp
doc[fp_field] = fp
doc[fn_field] = fn

if save:
sample[tp_field] = sample_tp
sample[fp_field] = sample_fp
sample[fn_field] = sample_fn
with contextlib.ExitStack() as stack:
if use_masks:
stack.enter_context(
_samples.download_context(
media_fields=[gt_field, pred_field], progress=progress
)
)

# Create a pool of workers
executor_cls = (
ProcessPoolExecutor
if executor_type == "process"
else ThreadPoolExecutor
)
with executor_cls(max_workers=num_workers) as executor:
futures = []

# Submit batches of samples to the worker pool
for batch in _samples.iter_samples(
progress=progress, batch_size=batch_size, autosave=save
):
if processing_frames:
docs = batch.frames.values()
else:
docs = [batch]

future = executor.submit(
_process_sample,
docs,
eval_method,
eval_key,
processing_frames,
)
futures.append((future, batch))

# Collect results
for future, sample in futures:
batch_matches, batch_stats = future.result()
matches.extend(batch_matches)

if save:
sample_tp, sample_fp, sample_fn = batch_stats
sample[f"{eval_key}_tp"] = sample_tp
sample[f"{eval_key}_fp"] = sample_fp
sample[f"{eval_key}_fn"] = sample_fn

# Generate and save final results
results = eval_method.generate_results(
samples,
matches,
Expand All @@ -228,6 +262,27 @@ def evaluate_detections(
return results


def _process_sample(docs, eval_method, eval_key, processing_frames):
"""Process a single sample or its frames."""
matches = []
sample_tp = sample_fp = sample_fn = 0

for doc in docs:
doc_matches = eval_method.evaluate(doc, eval_key=eval_key)
matches.extend(doc_matches)
tp, fp, fn = _tally_matches(doc_matches)
sample_tp += tp
sample_fp += fp
sample_fn += fn

if processing_frames and eval_key is not None:
doc[f"{eval_key}_tp"] = tp
doc[f"{eval_key}_fp"] = fp
doc[f"{eval_key}_fn"] = fn

return matches, (sample_tp, sample_fp, sample_fn)


class DetectionEvaluationConfig(foe.EvaluationMethodConfig):
"""Base class for configuring :class:`DetectionEvaluation` instances.
Expand Down

0 comments on commit 21c0fb9

Please sign in to comment.