diff --git a/kornia/augmentation/container/augment.py b/kornia/augmentation/container/augment.py index 26ce74cdc9..914a1e29d4 100644 --- a/kornia/augmentation/container/augment.py +++ b/kornia/augmentation/container/augment.py @@ -175,8 +175,9 @@ class AugmentationSequential(TransformMatrixMinIn, ImageSequential): ... ) >>> out = aug_list(input, mask, bbox) - How to use a dictionary as input with AugmentationSequential? The dictionary should starts with - one of the datakey availables. + How to use a dictionary as input with AugmentationSequential? The dictionary keys that start with + one of the available datakeys will be augmented accordingly. Otherwise, the dictionary item is passed + without any augmentation. >>> import kornia.augmentation as K >>> img = torch.randn(1, 3, 256, 256) @@ -291,7 +292,7 @@ def inverse( # type: ignore[override] """ original_keys = None if len(args) == 1 and isinstance(args[0], dict): - original_keys, data_keys, args = self._preproc_dict_data(args[0]) + original_keys, data_keys, args, invalid_data = self._preproc_dict_data(args[0]) # args here should already be `DataType` # NOTE: how to right type to: unpacked args <-> tuple of args to unpack @@ -324,7 +325,10 @@ def inverse( # type: ignore[override] outputs = self._arguments_postproc(args, outputs, data_keys=self.transform_op.data_keys) # type: ignore if isinstance(original_keys, tuple): - return {k: v for v, k in zip(outputs, original_keys)} + result = {k: v for v, k in zip(outputs, original_keys)} + if invalid_data: + result.update(invalid_data) + return result if len(outputs) == 1 and isinstance(outputs, list): return outputs[0] @@ -414,7 +418,7 @@ def forward( # type: ignore[override] # Unpack/handle dictionary args original_keys = None if len(args) == 1 and isinstance(args[0], dict): - original_keys, data_keys, args = self._preproc_dict_data(args[0]) + original_keys, data_keys, args, invalid_data = self._preproc_dict_data(args[0]) self.transform_op.data_keys = self.transform_op.preproc_datakeys(data_keys) @@ -455,7 +459,10 @@ def forward( # type: ignore[override] self._params = params if isinstance(original_keys, tuple): - return {k: v for v, k in zip(outputs, original_keys)} + result = {k: v for v, k in zip(outputs, original_keys)} + if invalid_data: + result.update(invalid_data) + return result if len(outputs) == 1 and isinstance(outputs, list): return outputs[0] @@ -464,17 +471,19 @@ def forward( # type: ignore[override] def _preproc_dict_data( self, data: Dict[str, DataType] - ) -> Tuple[Tuple[str, ...], List[DataKey], Tuple[DataType, ...]]: + ) -> Tuple[Tuple[str, ...], List[DataKey], Tuple[DataType, ...], Optional[Dict[str, Any]]]: if self.data_keys is not None: raise ValueError("If you are using a dictionary as input, the data_keys should be None.") keys = tuple(data.keys()) - data_keys = self._read_datakeys_from_dict(keys) + data_keys, invalid_keys = self._read_datakeys_from_dict(keys) + invalid_data = {i: data.pop(i) for i in invalid_keys} if invalid_keys else None + keys = tuple(k for k in keys if k not in invalid_keys) if invalid_keys else keys data_unpacked = tuple(data.values()) - return keys, data_keys, data_unpacked + return keys, data_keys, data_unpacked, invalid_data - def _read_datakeys_from_dict(self, keys: Sequence[str]) -> List[DataKey]: + def _read_datakeys_from_dict(self, keys: Sequence[str]) -> Tuple[List[DataKey], Optional[List[str]]]: def retrieve_key(key: str) -> DataKey: """Try to retrieve the datakey value by matching `*`""" # Alias cases, like INPUT, will not be get by the enum iterator. @@ -492,7 +501,15 @@ def retrieve_key(key: str) -> DataKey: f"Your input data dictionary keys should start with some of datakey values: {allowed_dk}. Got `{key}`" ) - return [DataKey.get(retrieve_key(k)) for k in keys] + valid_data_keys = [] + invalid_keys = [] + for k in keys: + try: + valid_data_keys.append(DataKey.get(retrieve_key(k))) + except ValueError: + invalid_keys.append(k) + + return valid_data_keys, invalid_keys def _preproc_mask(self, arg: MaskDataType) -> MaskDataType: if isinstance(arg, list): diff --git a/tests/augmentation/test_container.py b/tests/augmentation/test_container.py index 3d58a621f3..f0dcff64b1 100644 --- a/tests/augmentation/test_container.py +++ b/tests/augmentation/test_container.py @@ -776,13 +776,14 @@ def test_dict_as_input_forward_and_inverse(self, random_apply, bbox_key, device, random_apply=random_apply, ) - data = {"input": inp, "mask": mask, bbox_key: bbox, "keypoints": keypoints} + data = {"input": inp, "mask": mask, bbox_key: bbox, "keypoints": keypoints, "id": 45} out = aug(data) assert out["input"].shape == inp.shape assert out["mask"].shape == mask.shape assert out[bbox_key].shape == bbox.shape assert out["keypoints"].shape == keypoints.shape assert set(out["mask"].unique().tolist()).issubset(set(mask.unique().tolist())) + assert out["id"] == 45 out_inv = aug.inverse(out) assert out_inv["input"].shape == inp.shape @@ -790,6 +791,7 @@ def test_dict_as_input_forward_and_inverse(self, random_apply, bbox_key, device, assert out_inv[bbox_key].shape == bbox.shape assert out_inv["keypoints"].shape == keypoints.shape assert set(out_inv["mask"].unique().tolist()).issubset(set(mask.unique().tolist())) + assert out_inv["id"] == 45 if random_apply is False: reproducibility_test(data, aug)