diff --git a/fiftyone/core/clips.py b/fiftyone/core/clips.py index b05cdb4c25..1990be274b 100644 --- a/fiftyone/core/clips.py +++ b/fiftyone/core/clips.py @@ -9,6 +9,7 @@ from copy import deepcopy from bson import ObjectId +from pymongo import UpdateOne, UpdateMany import eta.core.utils as etau @@ -47,10 +48,13 @@ def _sample_id(self): return ObjectId(self._doc.sample_id) def _save(self, deferred=False): - sample_ops, frame_ops = super()._save(deferred=deferred) + sample_ops, frame_ops = super()._save(deferred=True) if not deferred: - self._view._sync_source_sample(self) + self._view._save_sample( + self, sample_ops=sample_ops, frame_ops=frame_ops + ) + return None, [] return sample_ops, frame_ops @@ -334,14 +338,73 @@ def reload(self): super().reload() - def _sync_source_sample(self, sample): - if not self._classification_field: + def _check_for_field_edits(self, ops, fields): + updated_fields = set() + + for op in ops: + if isinstance(op, (UpdateOne, UpdateMany)): + updated_fields.update(op._doc.get("$set", {}).keys()) + updated_fields.update(op._doc.get("$unset", {}).keys()) + + for field in list(updated_fields): + chunks = field.split(".") + for i in range(1, len(chunks)): + updated_fields.add(".".join(chunks[:i])) + + return bool(updated_fields & set(fields)) + + def _bulk_write( + self, + ops, + ids=None, + sample_ids=None, + frames=False, + ordered=False, + progress=False, + ): + self._clips_dataset._bulk_write( + ops, + ids=ids, + sample_ids=sample_ids, + frames=frames, + ordered=ordered, + progress=progress, + ) + + # Clips views directly use their source collection's frames, so there's + # no need to sync + if frames: return - # Sync label + support to underlying TemporalDetection + field = self._classification_field + if field is not None and self._check_for_field_edits(ops, [field]): + self._sync_source(fields=[field], ids=ids) + self._source_collection._dataset._reload_docs(ids=ids) + + def _save_sample(self, sample, sample_ops=None, frame_ops=None): + if sample_ops: + foo.bulk_write(sample_ops, self._clips_dataset._sample_collection) + + if frame_ops: + foo.bulk_write(frame_ops, self._clips_dataset._frame_collection) + self._sync_source_sample( + sample, sample_ops=sample_ops, frame_ops=frame_ops + ) + + def _sync_source_sample(self, sample, sample_ops=None, frame_ops=None): field = self._classification_field + if not field: + return + + if sample_ops is not None and not self._check_for_field_edits( + sample_ops, [field] + ): + return + + # Sync label + support to underlying TemporalDetection + classification = sample[field] if classification is not None: doc = classification.to_dict() @@ -353,10 +416,9 @@ def _sync_source_sample(self, sample): self._source_collection._set_labels(field, [sample.sample_id], [doc]) def _sync_source(self, fields=None, ids=None, update=True, delete=False): - if not self._classification_field: - return - field = self._classification_field + if not field: + return if fields is not None and field not in fields: return diff --git a/fiftyone/core/collections.py b/fiftyone/core/collections.py index efe43ab7cb..fb3d12d0a9 100644 --- a/fiftyone/core/collections.py +++ b/fiftyone/core/collections.py @@ -113,14 +113,9 @@ def __init__( self.batch_size = batch_size self._dataset = sample_collection._dataset - self._sample_coll = sample_collection._dataset._sample_collection - self._frame_coll = sample_collection._dataset._frame_collection - self._is_generated = sample_collection._is_generated - self._sample_ops = [] self._frame_ops = [] self._batch_ids = [] - self._reload_parents = [] self._batching_strategy = batching_strategy self._curr_batch_size = None @@ -154,7 +149,6 @@ def save(self, sample): ) sample_ops, frame_ops = sample._save(deferred=True) - updated = sample_ops or frame_ops if sample_ops: self._sample_ops.extend(sample_ops) @@ -162,12 +156,9 @@ def save(self, sample): if frame_ops: self._frame_ops.extend(frame_ops) - if updated and self._is_generated: + if sample_ops or frame_ops: self._batch_ids.append(sample.id) - if updated and isinstance(sample, fosa.SampleView): - self._reload_parents.append(sample) - if self._batching_strategy == "static": self._curr_batch_size += 1 if self._curr_batch_size >= self.batch_size: @@ -194,22 +185,23 @@ def save(self, sample): def _save_batch(self): if self._sample_ops: - foo.bulk_write(self._sample_ops, self._sample_coll, ordered=False) + self.sample_collection._bulk_write( + self._sample_ops, + ids=self._batch_ids, + ordered=False, + ) self._sample_ops.clear() if self._frame_ops: - foo.bulk_write(self._frame_ops, self._frame_coll, ordered=False) + self.sample_collection._bulk_write( + self._frame_ops, + sample_ids=self._batch_ids, + frames=True, + ordered=False, + ) self._frame_ops.clear() - if self._batch_ids and self._is_generated: - self.sample_collection._sync_source(ids=self._batch_ids) - self._batch_ids.clear() - - if self._reload_parents: - for sample in self._reload_parents: - sample._reload_parents() - - self._reload_parents.clear() + self._batch_ids.clear() class SampleCollection(object): @@ -1904,10 +1896,15 @@ def untag_samples(self, tags): view = self.match_tags(tags) view._edit_sample_tags(update) - def _edit_sample_tags(self, update): + def _edit_sample_tags(self, update, ids=None): if self._is_read_only_field("tags"): raise ValueError("Cannot edit read-only field 'tags'") + if ids is None: + _ids = self.values("_id") + else: + _ids = [ObjectId(_id) for _id in ids] + update["$set"] = {"last_modified_at": datetime.utcnow()} ids = [] @@ -1915,13 +1912,15 @@ def _edit_sample_tags(self, update): batch_size = fou.recommend_batch_size_for_value( ObjectId(), max_size=100000 ) - for _ids in fou.iter_batches(self.values("_id"), batch_size): - ids.extend(_ids) - ops.append(UpdateMany({"_id": {"$in": _ids}}, update)) + for _batch_ids in fou.iter_batches(_ids, batch_size): + ids.extend(_batch_ids) + ops.append(UpdateMany({"_id": {"$in": _batch_ids}}, update)) if ops: self._dataset._bulk_write(ops, ids=ids) + return ids + def count_sample_tags(self): """Counts the occurrences of sample tags in this collection. @@ -2054,8 +2053,8 @@ def _edit_label_tags( if ids is None or label_ids is None: if is_frame_field: ids, label_ids = self.values(["frames._id", id_path]) - ids = itertools.chain.from_iterable(ids) - label_ids = itertools.chain.from_iterable(label_ids) + ids = list(itertools.chain.from_iterable(ids)) + label_ids = list(itertools.chain.from_iterable(label_ids)) else: ids, label_ids = self.values(["_id", id_path]) @@ -3217,6 +3216,24 @@ def _set_labels(self, field_name, sample_ids, label_docs, progress=False): def _delete_labels(self, ids, fields=None): self._dataset.delete_labels(ids=ids, fields=fields) + def _bulk_write( + self, + ops, + ids=None, + sample_ids=None, + frames=False, + ordered=False, + progress=False, + ): + self._dataset._bulk_write( + ops, + ids=ids, + sample_ids=sample_ids, + frames=frames, + ordered=ordered, + progress=progress, + ) + def compute_metadata( self, overwrite=False, diff --git a/fiftyone/core/dataset.py b/fiftyone/core/dataset.py index 7833070fdb..e7ab8b6313 100644 --- a/fiftyone/core/dataset.py +++ b/fiftyone/core/dataset.py @@ -3465,7 +3465,13 @@ def _make_dict( return d def _bulk_write( - self, ops, ids=None, frames=False, ordered=False, progress=False + self, + ops, + ids=None, + sample_ids=None, + frames=False, + ordered=False, + progress=False, ): if frames: coll = self._frame_collection @@ -3475,7 +3481,11 @@ def _bulk_write( foo.bulk_write(ops, coll, ordered=ordered, progress=progress) if frames: - fofr.Frame._reload_docs(self._frame_collection_name, frame_ids=ids) + fofr.Frame._reload_docs( + self._frame_collection_name, + sample_ids=sample_ids, + frame_ids=ids, + ) else: fos.Sample._reload_docs( self._sample_collection_name, sample_ids=ids @@ -5076,7 +5086,7 @@ def _clear_frames(self, view=None, sample_ids=None, frame_ids=None): self._frame_collection_name, sample_ids=sample_ids ) - def _keep_frames(self, view=None, frame_ids=None): + def _keep_frames(self, view=None): sample_collection = view if view is not None else self if not sample_collection._contains_videos(any_slice=True): return @@ -8001,6 +8011,7 @@ def reload(self): """Reloads the dataset and any in-memory samples from the database.""" self._reload(hard=True) self._reload_docs(hard=True) + self._reload_docs(frames=True, hard=True) def clear_cache(self): """Clears the dataset's in-memory cache. @@ -8042,11 +8053,18 @@ def _reload(self, hard=False): self._update_last_loaded_at() - def _reload_docs(self, hard=False): - fos.Sample._reload_docs(self._sample_collection_name, hard=hard) + def _reload_docs(self, ids=None, frames=False, hard=False): + if frames: + if not self._has_frame_fields(): + return - if self._has_frame_fields(): - fofr.Frame._reload_docs(self._frame_collection_name, hard=hard) + fofr.Frame._reload_docs( + self._frame_collection_name, frame_ids=ids, hard=hard + ) + else: + fos.Sample._reload_docs( + self._sample_collection_name, sample_ids=ids, hard=hard + ) def _serialize(self): return self._doc.to_dict(extended=True) diff --git a/fiftyone/core/materialize.py b/fiftyone/core/materialize.py index 71306cf9ba..5af477c3f0 100644 --- a/fiftyone/core/materialize.py +++ b/fiftyone/core/materialize.py @@ -11,6 +11,7 @@ import eta.core.utils as etau +import fiftyone.core.media as fom import fiftyone.core.sample as fos import fiftyone.core.odm as foo import fiftyone.core.utils as fou @@ -35,10 +36,13 @@ class MaterializedSampleView(fos.SampleView): """ def _save(self, deferred=False): - sample_ops, frame_ops = super()._save(deferred=deferred) + sample_ops, frame_ops = super()._save(deferred=True) if not deferred: - self._view._sync_source_sample(self) + self._view._save_sample( + self, sample_ops=sample_ops, frame_ops=frame_ops + ) + return None, [] return sample_ops, frame_ops @@ -151,22 +155,20 @@ def _set_name(self, name): def _set_media_type(self, media_type): self.__media_type = media_type - def _tag_labels(self, tags, label_field, ids=None, label_ids=None): - ids, label_ids = super()._tag_labels( - tags, label_field, ids=ids, label_ids=label_ids - ) + def _edit_sample_tags(self, update, ids=None): + ids = super()._edit_sample_tags(update, ids=ids) - self._source_collection._tag_labels( - tags, label_field, ids=ids, label_ids=label_ids - ) + self._source_collection._edit_sample_tags(update, ids=ids) - def _untag_labels(self, tags, label_field, ids=None, label_ids=None): - ids, label_ids = super()._untag_labels( - tags, label_field, ids=ids, label_ids=label_ids + def _edit_label_tags( + self, update_fcn, label_field, ids=None, label_ids=None + ): + ids, label_ids = super()._edit_label_tags( + update_fcn, label_field, ids=ids, label_ids=label_ids ) - self._source_collection._untag_labels( - tags, label_field, ids=ids, label_ids=label_ids + self._source_collection._edit_label_tags( + update_fcn, label_field, ids=ids, label_ids=label_ids ) def set_values(self, field_name, *args, **kwargs): @@ -244,6 +246,20 @@ def keep_fields(self): super().keep_fields() + def keep_frames(self): + """For each sample in the view, deletes all frames that are **not** in + the view from the underlying dataset. + + .. note:: + + This method is not a :class:`fiftyone.core.stages.ViewStage`; + it immediately writes the requested changes to the underlying + dataset. + """ + self._sync_source_keep_frames() + + super().keep_frames() + def reload(self): """Reloads the view. @@ -275,15 +291,83 @@ def _delete_labels(self, ids, fields=None): self._source_collection._delete_labels(ids, fields=fields) - def _sync_source_sample(self, sample): + def _bulk_write( + self, + ops, + ids=None, + sample_ids=None, + frames=False, + ordered=False, + progress=False, + ): + self._materialized_dataset._bulk_write( + ops, + ids=ids, + sample_ids=sample_ids, + frames=frames, + ordered=ordered, + progress=progress, + ) + self._sync_source_schema() - dst_dataset = self._source_collection._root_dataset + self._source_collection._bulk_write( + ops, + ids=ids, + sample_ids=sample_ids, + frames=frames, + ordered=ordered, + progress=progress, + ) + + def _save_sample(self, sample, sample_ops=None, frame_ops=None): + if sample_ops: + foo.bulk_write( + sample_ops, self._materialized_dataset._sample_collection + ) - match = {"_id": sample._id} - updates = sample.to_mongo_dict() + if frame_ops: + foo.bulk_write( + frame_ops, self._materialized_dataset._frame_collection + ) - dst_dataset._sample_collection.update_one(match, {"$set": updates}) + self._sync_source_sample( + sample, sample_ops=sample_ops, frame_ops=frame_ops + ) + + def _sync_source_sample(self, sample, sample_ops=None, frame_ops=None): + self._sync_source_schema() + + if sample_ops is None and frame_ops is None: + dst_dataset = self._source_collection._root_dataset + + match = {"_id": sample._id} + updates = sample.to_mongo_dict() + dst_dataset._sample_collection.update_one(match, {"$set": updates}) + + if sample.media_type == fom.VIDEO: + src_coll = self._materialized_dataset._frame_collection + dst_coll_name = dst_dataset._frame_collection_name + pipeline = [ + {"$match": {"_sample_id": sample._id}}, + { + "$merge": { + "into": dst_coll_name, + "whenMatched": "replace", + } + }, + ] + foo.aggregate(src_coll, pipeline) + else: + if sample_ops: + self._source_collection._bulk_write( + sample_ops, ids=[sample.id] + ) + + if frame_ops: + self._source_collection._bulk_write( + frame_ops, sample_ids=[sample.id], frames=True + ) def _sync_source(self, fields=None, ids=None, update=True, delete=False): has_frame_fields = self._has_frame_fields() @@ -365,8 +449,9 @@ def _sync_source(self, fields=None, ids=None, update=True, delete=False): ) if delete: - sample_ids = self._materialized_dataset.exclude(self).values("id") - dst_dataset._clear(sample_ids=sample_ids) + # It's okay to pass a materialized view to `dst_dataset` because + # they share sample IDs + dst_dataset._keep(view=self) def _sync_source_field_schema(self, path): field = self.get_field(path) @@ -489,6 +574,12 @@ def _sync_source_keep_fields(self): if del_fields: self._source_collection.exclude_fields(del_fields).keep_fields() + def _sync_source_keep_frames(self): + # It's okay to pass a materialized view to `dst_dataset` because they + # share sample IDs and frame numbers + dst_dataset = self._source_collection._dataset + dst_dataset._keep_frames(view=self) + def materialize_view(sample_collection, name=None, persistent=False): """Creates a dataset that contains a materialized copy of the given diff --git a/fiftyone/core/patches.py b/fiftyone/core/patches.py index 197bae101e..9c23d2a93d 100644 --- a/fiftyone/core/patches.py +++ b/fiftyone/core/patches.py @@ -9,6 +9,7 @@ from copy import deepcopy from bson import ObjectId +from pymongo import UpdateOne, UpdateMany import eta.core.utils as etau @@ -37,10 +38,13 @@ def _frame_id(self): return ObjectId(self._doc.frame_id) def _save(self, deferred=False): - sample_ops, frame_ops = super()._save(deferred=deferred) + sample_ops, frame_ops = super()._save(deferred=True) if not deferred: - self._view._sync_source_sample(self) + self._view._save_sample( + self, sample_ops=sample_ops, frame_ops=frame_ops + ) + return None, [] return sample_ops, frame_ops @@ -346,7 +350,59 @@ def reload(self): super().reload() - def _sync_source_sample(self, sample): + def _check_for_field_edits(self, ops, fields): + updated_fields = set() + + for op in ops: + if isinstance(op, (UpdateOne, UpdateMany)): + updated_fields.update(op._doc.get("$set", {}).keys()) + updated_fields.update(op._doc.get("$unset", {}).keys()) + + for field in list(updated_fields): + chunks = field.split(".") + for i in range(1, len(chunks)): + updated_fields.add(".".join(chunks[:i])) + + return bool(updated_fields & set(fields)) + + def _bulk_write( + self, + ops, + ids=None, + sample_ids=None, + frames=False, + ordered=False, + progress=False, + ): + self._patches_dataset._bulk_write( + ops, + ids=ids, + sample_ids=sample_ids, + frames=frames, + ordered=ordered, + progress=progress, + ) + + if self._check_for_field_edits(ops, self._label_fields): + self._sync_source(fields=self._label_fields, ids=ids) + self._source_collection._dataset._reload_docs(ids=ids) + + def _save_sample(self, sample, sample_ops=None, frame_ops=None): + if sample_ops: + foo.bulk_write( + sample_ops, self._patches_dataset._sample_collection + ) + + self._sync_source_sample( + sample, sample_ops=sample_ops, frame_ops=frame_ops + ) + + def _sync_source_sample(self, sample, sample_ops=None, frame_ops=None): + if sample_ops is not None and not self._check_for_field_edits( + sample_ops, self._label_fields + ): + return + for field in self._label_fields: self._sync_source_sample_field(sample, field) diff --git a/fiftyone/core/video.py b/fiftyone/core/video.py index 89a3f8b460..2b26a02564 100644 --- a/fiftyone/core/video.py +++ b/fiftyone/core/video.py @@ -11,7 +11,7 @@ import os from bson import ObjectId -from pymongo import UpdateOne +from pymongo import UpdateOne, UpdateMany import eta.core.utils as etau @@ -54,10 +54,13 @@ def _sample_id(self): return ObjectId(self._doc.sample_id) def _save(self, deferred=False): - sample_ops, frame_ops = super()._save(deferred=deferred) + sample_ops, frame_ops = super()._save(deferred=True) if not deferred: - self._view._sync_source_sample(self) + self._view._save_sample( + self, sample_ops=sample_ops, frame_ops=frame_ops + ) + return None, [] return sample_ops, frame_ops @@ -327,29 +330,93 @@ def _delete_labels(self, ids, fields=None): self._source_collection._delete_labels(ids, fields=frame_fields) - def _sync_source_sample(self, sample): - self._sync_source_schema() - - dst_dataset = self._source_collection._root_dataset + def _prune_sample_only_field_updates(self, ops): sample_only_fields = self._get_sample_only_fields( include_private=True, use_db_fields=True ) - updates = { - k: v - for k, v in sample.to_mongo_dict().items() - if k not in sample_only_fields - } + for op in ops: + if isinstance(op, (UpdateOne, UpdateMany)): + sets = op._doc.get("$set", None) + if sets: + for f in sample_only_fields: + sets.pop(f, None) - if not updates: - return + unsets = op._doc.get("$unset", None) + if unsets: + for f in sample_only_fields: + unsets.pop(f, None) + + def _bulk_write( + self, + ops, + ids=None, + sample_ids=None, + frames=False, + ordered=False, + progress=False, + ): + self._frames_dataset._bulk_write( + ops, + ids=ids, + sample_ids=sample_ids, + frames=frames, + ordered=ordered, + progress=progress, + ) - match = { - "_sample_id": sample._sample_id, - "frame_number": sample.frame_number, - } + self._sync_source_schema() + self._prune_sample_only_field_updates(ops) + + self._source_collection._bulk_write( + ops, + ids=ids, + sample_ids=sample_ids, + frames=True, + ordered=ordered, + progress=progress, + ) + + def _save_sample(self, sample, sample_ops=None, frame_ops=None): + if sample_ops: + foo.bulk_write(sample_ops, self._frames_dataset._sample_collection) + + self._sync_source_sample( + sample, sample_ops=sample_ops, frame_ops=frame_ops + ) - dst_dataset._frame_collection.update_one(match, {"$set": updates}) + def _sync_source_sample(self, sample, sample_ops=None, frame_ops=None): + self._sync_source_schema() + + if sample_ops is None: + dst_dataset = self._source_collection._root_dataset + sample_only_fields = self._get_sample_only_fields( + include_private=True, use_db_fields=True + ) + + updates = { + k: v + for k, v in sample.to_mongo_dict().items() + if k not in sample_only_fields + } + + if not updates: + return + + match = { + "_sample_id": sample._sample_id, + "frame_number": sample.frame_number, + } + + dst_dataset._frame_collection.update_one(match, {"$set": updates}) + else: + self._prune_sample_only_field_updates(sample_ops) + + self._source_collection._bulk_write( + sample_ops, + sample_ids=[sample.id], + frames=True, + ) def _sync_source(self, fields=None, ids=None, update=True, delete=False): dst_dataset = self._source_collection._root_dataset diff --git a/fiftyone/core/view.py b/fiftyone/core/view.py index 9022990b6f..c969b86eec 100644 --- a/fiftyone/core/view.py +++ b/fiftyone/core/view.py @@ -1324,8 +1324,8 @@ def keep_fields(self): self._dataset._keep_fields(view=self) def keep_frames(self): - """For each sample in the view, deletes all frames labels that are - **not** in the view from the underlying dataset. + """For each sample in the view, deletes all frames that are **not** in + the view from the underlying dataset. .. note:: diff --git a/tests/unittests/materialize_tests.py b/tests/unittests/materialize_tests.py new file mode 100644 index 0000000000..f3b28903bb --- /dev/null +++ b/tests/unittests/materialize_tests.py @@ -0,0 +1,378 @@ +""" +FiftyOne materialized view-related unit tests. + +| Copyright 2017-2024, Voxel51, Inc. +| `voxel51.com `_ +| +""" +from copy import deepcopy + +from bson import ObjectId +import unittest + +import fiftyone as fo +from fiftyone import ViewField as F + +from decorators import drop_datasets + + +class MaterializeTests(unittest.TestCase): + @drop_datasets + def test_materialize(self): + dataset = fo.Dataset() + + sample1 = fo.Sample( + filepath="video1.mp4", + tags=["test"], + weather="sunny", + ) + sample1.frames[1] = fo.Frame() + sample1.frames[2] = fo.Frame( + ground_truth=fo.Detections( + detections=[ + fo.Detection(label="cat"), + fo.Detection(label="dog"), + ] + ), + ) + sample1.frames[3] = fo.Frame() + + sample2 = fo.Sample( + filepath="video2.mp4", + tags=["test"], + weather="cloudy", + ) + sample2.frames[1] = fo.Frame( + ground_truth=fo.Detections( + detections=[ + fo.Detection(label="dog"), + fo.Detection(label="rabbit"), + ] + ), + ) + sample2.frames[3] = fo.Frame() + sample2.frames[5] = fo.Frame() + + sample3 = fo.Sample( + filepath="video3.mp4", + tags=["test"], + weather="rainy", + ) + + dataset.add_samples([sample1, sample2, sample3]) + + view = ( + dataset.limit(2) + .match_frames(F("frame_number") <= 2, omit_empty=False) + .materialize() + ) + + self.assertSetEqual( + set(view.get_field_schema().keys()), + { + "id", + "filepath", + "metadata", + "tags", + "created_at", + "last_modified_at", + "weather", + }, + ) + + self.assertSetEqual( + set(view.get_frame_field_schema().keys()), + { + "id", + "frame_number", + "created_at", + "last_modified_at", + "ground_truth", + }, + ) + + self.assertEqual( + view.get_field("metadata").document_type, + fo.VideoMetadata, + ) + + self.assertSetEqual( + set(view.select_fields().get_field_schema().keys()), + { + "id", + "filepath", + "metadata", + "tags", + "created_at", + "last_modified_at", + }, + ) + + self.assertSetEqual( + set(view.select_fields().get_frame_field_schema().keys()), + { + "id", + "frame_number", + "created_at", + "last_modified_at", + }, + ) + + with self.assertRaises(ValueError): + view.exclude_fields("tags") # can't exclude default field + + with self.assertRaises(ValueError): + view.exclude_fields( + "frames.frame_number" + ) # can't exclude default field + + index_info = view.get_index_information() + indexes = view.list_indexes() + default_indexes = { + "id", + "filepath", + "created_at", + "last_modified_at", + "frames.id", + "frames._sample_id_1_frame_number_1", + "frames.created_at", + "frames.last_modified_at", + } + + self.assertSetEqual(set(index_info.keys()), default_indexes) + self.assertSetEqual(set(indexes), default_indexes) + + with self.assertRaises(ValueError): + view.drop_index("id") # can't drop default index + + with self.assertRaises(ValueError): + view.drop_index("filepath") # can't drop default index + + with self.assertRaises(ValueError): + view.drop_index("frames.created_at") # can't drop default index + + self.assertEqual(len(view), 2) + self.assertEqual(view.count("frames"), 3) + + sample = view.first() + self.assertIsInstance(sample.id, str) + self.assertIsInstance(sample._id, ObjectId) + + for _id in view.values("id"): + self.assertIsInstance(_id, str) + + for oid in view.values("_id"): + self.assertIsInstance(oid, ObjectId) + + for _id in view.values("frames.id", unwind=True): + self.assertIsInstance(_id, str) + + for oid in view.values("frames._id", unwind=True): + self.assertIsInstance(oid, ObjectId) + + self.assertDictEqual(dataset.count_sample_tags(), {"test": 3}) + self.assertDictEqual(view.count_sample_tags(), {"test": 2}) + + view.tag_samples("foo") + + self.assertEqual(view.count_sample_tags()["foo"], 2) + self.assertEqual(dataset.count_sample_tags()["foo"], 2) + + view.untag_samples("foo") + + self.assertNotIn("foo", view.count_sample_tags()) + self.assertNotIn("foo", dataset.count_sample_tags()) + + view.tag_labels("test") + + self.assertDictEqual(view.count_label_tags(), {"test": 4}) + self.assertDictEqual(dataset.count_label_tags(), {"test": 4}) + + view.select_labels(tags="test").untag_labels("test") + + self.assertDictEqual(view.count_label_tags(), {}) + self.assertDictEqual(dataset.count_label_tags(), {}) + + view2 = view.limit(1).set_field( + "frames.ground_truth.detections.label", F("label").upper() + ) + + self.assertDictEqual( + view.count_values("frames.ground_truth.detections.label"), + {"cat": 1, "dog": 2, "rabbit": 1}, + ) + self.assertDictEqual( + view2.count_values("frames.ground_truth.detections.label"), + {"CAT": 1, "DOG": 1}, + ) + self.assertDictEqual( + dataset.count_values("frames.ground_truth.detections.label"), + {"cat": 1, "dog": 2, "rabbit": 1}, + ) + + values = { + _id: v + for _id, v in zip( + *view2.values( + [ + "frames.ground_truth.detections.id", + "frames.ground_truth.detections.label", + ], + unwind=True, + ) + ) + } + view.set_label_values( + "frames.ground_truth.detections.also_label", values + ) + + self.assertEqual( + view.count("frames.ground_truth.detections.also_label"), 2 + ) + self.assertEqual( + dataset.count("frames.ground_truth.detections.also_label"), 2 + ) + self.assertDictEqual( + view.count_values("frames.ground_truth.detections.also_label"), + dataset.count_values("frames.ground_truth.detections.also_label"), + ) + + view2.save() + + self.assertEqual(len(view), 2) + self.assertEqual(dataset.values(F("frames").length()), [3, 3, 0]) + self.assertDictEqual( + view.count_values("frames.ground_truth.detections.label"), + {"CAT": 1, "DOG": 1, "dog": 1, "rabbit": 1}, + ) + self.assertDictEqual( + dataset.count_values("frames.ground_truth.detections.label"), + {"CAT": 1, "DOG": 1, "dog": 1, "rabbit": 1}, + ) + + view2.keep() + view2.keep_frames() + view.reload() + + self.assertEqual(len(view), 1) + self.assertEqual(dataset.values(F("frames").length()), [2]) + self.assertDictEqual( + view.count_values("frames.ground_truth.detections.label"), + {"CAT": 1, "DOG": 1}, + ) + self.assertDictEqual( + dataset.count_values("frames.ground_truth.detections.label"), + {"CAT": 1, "DOG": 1}, + ) + + sample = view.exclude_fields("weather").first() + + sample["foo"] = "bar" + sample.save() + + self.assertIn("foo", view.get_field_schema()) + self.assertIn("foo", dataset.get_field_schema()) + self.assertIn("weather", view.get_field_schema()) + self.assertIn("weather", dataset.get_field_schema()) + self.assertEqual(view.count_values("foo")["bar"], 1) + self.assertEqual(dataset.count_values("foo")["bar"], 1) + self.assertDictEqual(view.count_values("weather"), {"sunny": 1}) + self.assertDictEqual(dataset.count_values("weather"), {"sunny": 1}) + + sample = view.exclude_fields("frames.ground_truth").first() + frame = sample.frames.first() + + frame["spam"] = "eggs" + sample.save() + + self.assertIn("spam", view.get_frame_field_schema()) + self.assertIn("spam", dataset.get_frame_field_schema()) + self.assertIn("ground_truth", view.get_frame_field_schema()) + self.assertIn("ground_truth", dataset.get_frame_field_schema()) + self.assertEqual(view.count_values("frames.spam")["eggs"], 1) + self.assertEqual(dataset.count_values("frames.spam")["eggs"], 1) + self.assertDictEqual( + view.count_values("frames.ground_truth.detections.label"), + {"CAT": 1, "DOG": 1}, + ) + self.assertDictEqual( + dataset.count_values("frames.ground_truth.detections.label"), + {"CAT": 1, "DOG": 1}, + ) + + dataset.untag_samples("test") + view.reload() + + self.assertEqual(dataset.count_sample_tags(), {}) + self.assertEqual(view.count_sample_tags(), {}) + + view.select_fields().keep_fields() + + self.assertNotIn("weather", view.get_field_schema()) + self.assertNotIn("weather", dataset.get_field_schema()) + self.assertNotIn("ground_truth", view.get_frame_field_schema()) + self.assertNotIn("ground_truth", dataset.get_frame_field_schema()) + + sample_view = view.first() + with self.assertRaises(KeyError): + sample_view["weather"] + + frame_view = sample_view.frames.first() + with self.assertRaises(KeyError): + frame_view["ground_truth"] + + # Test saving a materialized view + + self.assertIsNone(view.name) + + view_name = "test" + dataset.save_view(view_name, view) + self.assertEqual(view.name, view_name) + self.assertTrue(view.is_saved) + + also_view = dataset.load_saved_view(view_name) + self.assertEqual(view, also_view) + self.assertEqual(also_view.name, view_name) + self.assertTrue(also_view.is_saved) + + still_view = deepcopy(view) + self.assertEqual(still_view.name, view_name) + self.assertTrue(still_view.is_saved) + self.assertEqual(still_view, view) + + @drop_datasets + def test_materialize_save_context(self): + dataset = fo.Dataset() + + sample1 = fo.Sample(filepath="video1.mp4") + sample1.frames[1] = fo.Frame(filepath="frame11.jpg") + sample1.frames[2] = fo.Frame(filepath="frame12.jpg") + sample1.frames[3] = fo.Frame(filepath="frame13.jpg") + + sample2 = fo.Sample(filepath="video2.mp4") + + sample3 = fo.Sample(filepath="video3.mp4") + sample3.frames[1] = fo.Frame(filepath="frame31.jpg") + + dataset.add_samples([sample1, sample2, sample3]) + + view = ( + dataset.limit(2) + .match_frames(F("frame_number") != 2, omit_empty=False) + .materialize() + ) + + for sample in view.iter_samples(autosave=True): + sample["foo"] = "bar" + for frame in sample.frames.values(): + frame["foo"] = "bar" + + self.assertEqual(view.count("foo"), 2) + self.assertEqual(dataset.count("foo"), 2) + self.assertEqual(view.count("frames.foo"), 2) + self.assertEqual(dataset.count("frames.foo"), 2) + + +if __name__ == "__main__": + fo.config.show_progress_bars = False + unittest.main(verbosity=2)