Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ignore invalid keys instead of raising an error #68

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 28 additions & 11 deletions kornia/augmentation/container/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,9 @@ class AugmentationSequential(TransformMatrixMinIn, ImageSequential):
... )
>>> out = aug_list(input, mask, bbox)

How to use a dictionary as input with AugmentationSequential? The dictionary should starts with
one of the datakey availables.
How to use a dictionary as input with AugmentationSequential? The dictionary keys that start with
one of the available datakeys will be augmented accordingly. Otherwise, the dictionary item is passed
without any augmentation.

>>> import kornia.augmentation as K
>>> img = torch.randn(1, 3, 256, 256)
Expand Down Expand Up @@ -291,7 +292,7 @@ def inverse( # type: ignore[override]
"""
original_keys = None
if len(args) == 1 and isinstance(args[0], dict):
original_keys, data_keys, args = self._preproc_dict_data(args[0])
original_keys, data_keys, args, invalid_data = self._preproc_dict_data(args[0])

# args here should already be `DataType`
# NOTE: how to right type to: unpacked args <-> tuple of args to unpack
Expand Down Expand Up @@ -324,7 +325,10 @@ def inverse( # type: ignore[override]
outputs = self._arguments_postproc(args, outputs, data_keys=self.transform_op.data_keys) # type: ignore

if isinstance(original_keys, tuple):
return {k: v for v, k in zip(outputs, original_keys)}
result = {k: v for v, k in zip(outputs, original_keys)}
if invalid_data:
result.update(invalid_data)
return result

if len(outputs) == 1 and isinstance(outputs, list):
return outputs[0]
Expand Down Expand Up @@ -414,7 +418,7 @@ def forward( # type: ignore[override]
# Unpack/handle dictionary args
original_keys = None
if len(args) == 1 and isinstance(args[0], dict):
original_keys, data_keys, args = self._preproc_dict_data(args[0])
original_keys, data_keys, args, invalid_data = self._preproc_dict_data(args[0])

self.transform_op.data_keys = self.transform_op.preproc_datakeys(data_keys)

Expand Down Expand Up @@ -455,7 +459,10 @@ def forward( # type: ignore[override]
self._params = params

if isinstance(original_keys, tuple):
return {k: v for v, k in zip(outputs, original_keys)}
result = {k: v for v, k in zip(outputs, original_keys)}
if invalid_data:
result.update(invalid_data)
return result

if len(outputs) == 1 and isinstance(outputs, list):
return outputs[0]
Expand All @@ -464,17 +471,19 @@ def forward( # type: ignore[override]

def _preproc_dict_data(
self, data: Dict[str, DataType]
) -> Tuple[Tuple[str, ...], List[DataKey], Tuple[DataType, ...]]:
) -> Tuple[Tuple[str, ...], List[DataKey], Tuple[DataType, ...], Optional[Dict[str, Any]]]:
if self.data_keys is not None:
raise ValueError("If you are using a dictionary as input, the data_keys should be None.")

keys = tuple(data.keys())
data_keys = self._read_datakeys_from_dict(keys)
data_keys, invalid_keys = self._read_datakeys_from_dict(keys)
invalid_data = {i: data.pop(i) for i in invalid_keys} if invalid_keys else None
keys = tuple(k for k in keys if k not in invalid_keys) if invalid_keys else keys
data_unpacked = tuple(data.values())

return keys, data_keys, data_unpacked
return keys, data_keys, data_unpacked, invalid_data

def _read_datakeys_from_dict(self, keys: Sequence[str]) -> List[DataKey]:
def _read_datakeys_from_dict(self, keys: Sequence[str]) -> Tuple[List[DataKey], Optional[List[str]]]:
def retrieve_key(key: str) -> DataKey:
"""Try to retrieve the datakey value by matching `<datakey>*`"""
# Alias cases, like INPUT, will not be get by the enum iterator.
Expand All @@ -492,7 +501,15 @@ def retrieve_key(key: str) -> DataKey:
f"Your input data dictionary keys should start with some of datakey values: {allowed_dk}. Got `{key}`"
)

return [DataKey.get(retrieve_key(k)) for k in keys]
valid_data_keys = []
invalid_keys = []
for k in keys:
try:
valid_data_keys.append(DataKey.get(retrieve_key(k)))
except ValueError:
invalid_keys.append(k)

return valid_data_keys, invalid_keys

def _preproc_mask(self, arg: MaskDataType) -> MaskDataType:
if isinstance(arg, list):
Expand Down
4 changes: 3 additions & 1 deletion tests/augmentation/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,20 +776,22 @@ def test_dict_as_input_forward_and_inverse(self, random_apply, bbox_key, device,
random_apply=random_apply,
)

data = {"input": inp, "mask": mask, bbox_key: bbox, "keypoints": keypoints}
data = {"input": inp, "mask": mask, bbox_key: bbox, "keypoints": keypoints, "id": 45}
out = aug(data)
assert out["input"].shape == inp.shape
assert out["mask"].shape == mask.shape
assert out[bbox_key].shape == bbox.shape
assert out["keypoints"].shape == keypoints.shape
assert set(out["mask"].unique().tolist()).issubset(set(mask.unique().tolist()))
assert out["id"] == 45

out_inv = aug.inverse(out)
assert out_inv["input"].shape == inp.shape
assert out_inv["mask"].shape == mask.shape
assert out_inv[bbox_key].shape == bbox.shape
assert out_inv["keypoints"].shape == keypoints.shape
assert set(out_inv["mask"].unique().tolist()).issubset(set(mask.unique().tolist()))
assert out_inv["id"] == 45

if random_apply is False:
reproducibility_test(data, aug)
Expand Down
Loading