Commit 6051368: finish implementation

brimoor committed Jan 3, 2025
1 parent 1ee2b61 commit 6051368
Showing 8 changed files with 776 additions and 87 deletions.
78 changes: 70 additions & 8 deletions fiftyone/core/clips.py
@@ -9,6 +9,7 @@
 from copy import deepcopy
 
 from bson import ObjectId
+from pymongo import UpdateOne, UpdateMany
 
 import eta.core.utils as etau
 
@@ -47,10 +48,13 @@ def _sample_id(self):
         return ObjectId(self._doc.sample_id)
 
     def _save(self, deferred=False):
-        sample_ops, frame_ops = super()._save(deferred=deferred)
+        sample_ops, frame_ops = super()._save(deferred=True)
 
         if not deferred:
-            self._view._sync_source_sample(self)
+            self._view._save_sample(
+                self, sample_ops=sample_ops, frame_ops=frame_ops
+            )
+            return None, []
 
         return sample_ops, frame_ops
 
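Note: the deferred-save pattern above means a clip sample no longer writes to
MongoDB itself; it hands its pymongo ops to the owning view, which applies them
and syncs the source dataset. A minimal sketch of the idea (standalone toy
class, not FiftyOne's actual implementation):

    from pymongo import UpdateOne

    class DeferredSample:
        """Toy sample that collects field edits as pymongo ops."""

        def __init__(self, _id):
            self._id = _id
            self._changes = {}

        def set_field(self, field, value):
            self._changes[field] = value

        def _save(self, deferred=False):
            # emit the pending edits as ops rather than applying them
            ops = [UpdateOne({"_id": self._id}, {"$set": dict(self._changes)})]
            self._changes.clear()
            if deferred:
                return ops, []  # caller applies the ops (and syncs) itself

            return None, []  # a non-deferred path would apply the ops here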
@@ -334,14 +338,73 @@ def reload(self):
 
         super().reload()
 
-    def _sync_source_sample(self, sample):
-        if not self._classification_field:
+    def _check_for_field_edits(self, ops, fields):
+        updated_fields = set()
+
+        for op in ops:
+            if isinstance(op, (UpdateOne, UpdateMany)):
+                updated_fields.update(op._doc.get("$set", {}).keys())
+                updated_fields.update(op._doc.get("$unset", {}).keys())
+
+        for field in list(updated_fields):
+            chunks = field.split(".")
+            for i in range(1, len(chunks)):
+                updated_fields.add(".".join(chunks[:i]))
+
+        return bool(updated_fields & set(fields))
+
+    def _bulk_write(
+        self,
+        ops,
+        ids=None,
+        sample_ids=None,
+        frames=False,
+        ordered=False,
+        progress=False,
+    ):
+        self._clips_dataset._bulk_write(
+            ops,
+            ids=ids,
+            sample_ids=sample_ids,
+            frames=frames,
+            ordered=ordered,
+            progress=progress,
+        )
+
+        # Clips views directly use their source collection's frames, so there's
+        # no need to sync
+        if frames:
+            return
+
+        # Sync label + support to underlying TemporalDetection
+        field = self._classification_field
+        if field is not None and self._check_for_field_edits(ops, [field]):
+            self._sync_source(fields=[field], ids=ids)
+            self._source_collection._dataset._reload_docs(ids=ids)
+
+    def _save_sample(self, sample, sample_ops=None, frame_ops=None):
+        if sample_ops:
+            foo.bulk_write(sample_ops, self._clips_dataset._sample_collection)
+
+        if frame_ops:
+            foo.bulk_write(frame_ops, self._clips_dataset._frame_collection)
+
+        self._sync_source_sample(
+            sample, sample_ops=sample_ops, frame_ops=frame_ops
+        )
+
+    def _sync_source_sample(self, sample, sample_ops=None, frame_ops=None):
+        field = self._classification_field
+
+        if not field:
             return
 
+        if sample_ops is not None and not self._check_for_field_edits(
+            sample_ops, [field]
+        ):
+            return
+
         # Sync label + support to underlying TemporalDetection
-
         classification = sample[field]
         if classification is not None:
             doc = classification.to_dict()
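Note: `_check_for_field_edits()` expands every dotted key touched by an
`UpdateOne`/`UpdateMany` into all of its parent paths before intersecting with
the fields of interest, so an edit to an embedded attribute counts as an edit
to its root field. A quick illustration of the prefix expansion (same logic as
above, with a hypothetical `my_label` field):

    from pymongo import UpdateOne, UpdateMany

    ops = [UpdateOne({"_id": 1}, {"$set": {"my_label.label": "cat"}})]

    updated_fields = set()
    for op in ops:
        if isinstance(op, (UpdateOne, UpdateMany)):
            updated_fields.update(op._doc.get("$set", {}).keys())
            updated_fields.update(op._doc.get("$unset", {}).keys())

    for field in list(updated_fields):
        chunks = field.split(".")
        for i in range(1, len(chunks)):
            updated_fields.add(".".join(chunks[:i]))

    print(updated_fields)  # {'my_label.label', 'my_label'}
    # an edit to `my_label.label` therefore triggers a sync of `my_label`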
@@ -353,10 +416,9 @@ def _sync_source_sample(self, sample):
         self._source_collection._set_labels(field, [sample.sample_id], [doc])
 
     def _sync_source(self, fields=None, ids=None, update=True, delete=False):
-        if not self._classification_field:
-            return
-
         field = self._classification_field
+        if not field:
+            return
+
         if fields is not None and field not in fields:
             return
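Note: the practical effect of the clips.py changes is that edits to a clip's
classification field are written to the clips dataset and then synced back to
the source video's TemporalDetection only when that field was actually
modified. A hedged usage sketch (the dataset and field names are hypothetical):

    import fiftyone as fo

    dataset = fo.load_dataset("my-videos")  # a video dataset with a
    clips = dataset.to_clips("events")      # TemporalDetections field "events"

    clip = clips.first()
    clip.events.label = "meeting"
    clip.save()  # routes through _save_sample(), which skips the source sync
                 # entirely if "events" was not among the edited fields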
71 changes: 44 additions & 27 deletions fiftyone/core/collections.py
@@ -113,14 +113,9 @@ def __init__(
         self.batch_size = batch_size
 
         self._dataset = sample_collection._dataset
-        self._sample_coll = sample_collection._dataset._sample_collection
-        self._frame_coll = sample_collection._dataset._frame_collection
-        self._is_generated = sample_collection._is_generated
-
         self._sample_ops = []
         self._frame_ops = []
         self._batch_ids = []
-        self._reload_parents = []
 
         self._batching_strategy = batching_strategy
         self._curr_batch_size = None
@@ -154,20 +149,16 @@ def save(self, sample):
         )
 
         sample_ops, frame_ops = sample._save(deferred=True)
-        updated = sample_ops or frame_ops
 
         if sample_ops:
             self._sample_ops.extend(sample_ops)
 
         if frame_ops:
             self._frame_ops.extend(frame_ops)
 
-        if updated and self._is_generated:
+        if sample_ops or frame_ops:
             self._batch_ids.append(sample.id)
 
-        if updated and isinstance(sample, fosa.SampleView):
-            self._reload_parents.append(sample)
-
         if self._batching_strategy == "static":
             self._curr_batch_size += 1
             if self._curr_batch_size >= self.batch_size:
@@ -194,22 +185,23 @@ def save(self, sample):
 
     def _save_batch(self):
         if self._sample_ops:
-            foo.bulk_write(self._sample_ops, self._sample_coll, ordered=False)
+            self.sample_collection._bulk_write(
+                self._sample_ops,
+                ids=self._batch_ids,
+                ordered=False,
+            )
             self._sample_ops.clear()
 
         if self._frame_ops:
-            foo.bulk_write(self._frame_ops, self._frame_coll, ordered=False)
+            self.sample_collection._bulk_write(
+                self._frame_ops,
+                sample_ids=self._batch_ids,
+                frames=True,
+                ordered=False,
+            )
             self._frame_ops.clear()
 
-        if self._batch_ids and self._is_generated:
-            self.sample_collection._sync_source(ids=self._batch_ids)
-            self._batch_ids.clear()
-
-        if self._reload_parents:
-            for sample in self._reload_parents:
-                sample._reload_parents()
-
-            self._reload_parents.clear()
+        self._batch_ids.clear()
 
 
 class SampleCollection(object):
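Note: `SaveContext._save_batch()` now routes batched ops through
`SampleCollection._bulk_write()`, so generated views (clips/patches) get their
source-sync and reload behavior from `_bulk_write()` itself rather than the
removed `_is_generated`/`_reload_parents` bookkeeping. Typical usage is
unchanged; a sketch against a hypothetical dataset:

    import fiftyone as fo

    dataset = fo.load_dataset("my-dataset")

    with dataset.save_context() as ctx:
        for sample in dataset:
            sample.tags.append("processed")
            ctx.save(sample)  # ops are queued and flushed in batches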
@@ -1904,24 +1896,31 @@ def untag_samples(self, tags):
         view = self.match_tags(tags)
         view._edit_sample_tags(update)
 
-    def _edit_sample_tags(self, update):
+    def _edit_sample_tags(self, update, ids=None):
         if self._is_read_only_field("tags"):
             raise ValueError("Cannot edit read-only field 'tags'")
 
+        if ids is None:
+            _ids = self.values("_id")
+        else:
+            _ids = [ObjectId(_id) for _id in ids]
+
         update["$set"] = {"last_modified_at": datetime.utcnow()}
 
         ids = []
         ops = []
         batch_size = fou.recommend_batch_size_for_value(
             ObjectId(), max_size=100000
         )
-        for _ids in fou.iter_batches(self.values("_id"), batch_size):
-            ids.extend(_ids)
-            ops.append(UpdateMany({"_id": {"$in": _ids}}, update))
+        for _batch_ids in fou.iter_batches(_ids, batch_size):
+            ids.extend(_batch_ids)
+            ops.append(UpdateMany({"_id": {"$in": _batch_ids}}, update))
 
         if ops:
             self._dataset._bulk_write(ops, ids=ids)
 
+        return ids
+
     def count_sample_tags(self):
         """Counts the occurrences of sample tags in this collection.
@@ -2054,8 +2053,8 @@ def _edit_label_tags(
         if ids is None or label_ids is None:
             if is_frame_field:
                 ids, label_ids = self.values(["frames._id", id_path])
-                ids = itertools.chain.from_iterable(ids)
-                label_ids = itertools.chain.from_iterable(label_ids)
+                ids = list(itertools.chain.from_iterable(ids))
+                label_ids = list(itertools.chain.from_iterable(label_ids))
             else:
                 ids, label_ids = self.values(["_id", id_path])
@@ -3217,6 +3216,24 @@ def _set_labels(self, field_name, sample_ids, label_docs, progress=False):
     def _delete_labels(self, ids, fields=None):
         self._dataset.delete_labels(ids=ids, fields=fields)
 
+    def _bulk_write(
+        self,
+        ops,
+        ids=None,
+        sample_ids=None,
+        frames=False,
+        ordered=False,
+        progress=False,
+    ):
+        self._dataset._bulk_write(
+            ops,
+            ids=ids,
+            sample_ids=sample_ids,
+            frames=frames,
+            ordered=ordered,
+            progress=progress,
+        )
+
def compute_metadata(
self,
overwrite=False,
32 changes: 25 additions & 7 deletions fiftyone/core/dataset.py
@@ -3465,7 +3465,13 @@ def _make_dict(
         return d
 
     def _bulk_write(
-        self, ops, ids=None, frames=False, ordered=False, progress=False
+        self,
+        ops,
+        ids=None,
+        sample_ids=None,
+        frames=False,
+        ordered=False,
+        progress=False,
     ):
         if frames:
             coll = self._frame_collection
@@ -3475,7 +3481,11 @@ def _bulk_write(
         foo.bulk_write(ops, coll, ordered=ordered, progress=progress)
 
         if frames:
-            fofr.Frame._reload_docs(self._frame_collection_name, frame_ids=ids)
+            fofr.Frame._reload_docs(
+                self._frame_collection_name,
+                sample_ids=sample_ids,
+                frame_ids=ids,
+            )
         else:
             fos.Sample._reload_docs(
                 self._sample_collection_name, sample_ids=ids
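Note: `Dataset._bulk_write()` now threads `sample_ids` through to the frame
reload so that, after a frame-level write, only in-memory frames belonging to
the affected samples are refreshed. A hedged sketch of a caller (private API;
the field and variable names here are illustrative):

    from bson import ObjectId
    from pymongo import UpdateMany

    sample_oids = [ObjectId(_id) for _id in dataset.take(2).values("id")]
    ops = [
        UpdateMany(
            {"_sample_id": {"$in": sample_oids}},
            {"$set": {"reviewed": True}},
        )
    ]
    dataset._bulk_write(ops, sample_ids=sample_oids, frames=True)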
@@ -5076,7 +5086,7 @@ def _clear_frames(self, view=None, sample_ids=None, frame_ids=None):
             self._frame_collection_name, sample_ids=sample_ids
         )
 
-    def _keep_frames(self, view=None, frame_ids=None):
+    def _keep_frames(self, view=None):
         sample_collection = view if view is not None else self
         if not sample_collection._contains_videos(any_slice=True):
             return
@@ -8001,6 +8011,7 @@ def reload(self):
         """Reloads the dataset and any in-memory samples from the database."""
         self._reload(hard=True)
         self._reload_docs(hard=True)
+        self._reload_docs(frames=True, hard=True)

def clear_cache(self):
"""Clears the dataset's in-memory cache.
@@ -8042,11 +8053,18 @@ def _reload(self, hard=False):
 
         self._update_last_loaded_at()
 
-    def _reload_docs(self, hard=False):
-        fos.Sample._reload_docs(self._sample_collection_name, hard=hard)
+    def _reload_docs(self, ids=None, frames=False, hard=False):
+        if frames:
+            if not self._has_frame_fields():
+                return
 
-        if self._has_frame_fields():
-            fofr.Frame._reload_docs(self._frame_collection_name, hard=hard)
+            fofr.Frame._reload_docs(
+                self._frame_collection_name, frame_ids=ids, hard=hard
+            )
+        else:
+            fos.Sample._reload_docs(
+                self._sample_collection_name, sample_ids=ids, hard=hard
+            )

def _serialize(self):
return self._doc.to_dict(extended=True)
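Note: `_reload_docs()` is now symmetric across document types: one call reloads
samples or frames, optionally scoped to specific ids, which is why `reload()`
above issues two calls. A sketch of the variants (private API; `edited_ids` is
hypothetical):

    dataset._reload_docs(hard=True)               # all in-memory samples
    dataset._reload_docs(frames=True, hard=True)  # all in-memory frames
    dataset._reload_docs(ids=edited_ids)          # just the samples written above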
Loading

0 comments on commit 6051368

Please sign in to comment.