diff --git a/docs/source/models/rt_detr.rst b/docs/source/models/rt_detr.rst
index cfeaa06c8a..ca67cb06b0 100644
--- a/docs/source/models/rt_detr.rst
+++ b/docs/source/models/rt_detr.rst
@@ -1,6 +1,28 @@
 Real-Time Detection Transformer (RT-DETR)
 =========================================
 
+.. code-block:: python
+
+    from kornia.io import load_image
+    from kornia.models.detector.rtdetr import RTDETRDetectorBuilder
+
+    input_img = load_image(img_path)[None]  # load the image into BCHW layout
+
+    # NOTE: available models: 'rtdetr_r18vd', 'rtdetr_r34vd', 'rtdetr_r50vd_m', 'rtdetr_r50vd', 'rtdetr_r101vd'.
+    # NOTE: recommended image scales: [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
+    detector = RTDETRDetectorBuilder.build("rtdetr_r18vd", image_size=640)
+
+    # get the output boxes
+    boxes = detector(input_img)
+
+    # draw the bounding boxes on the images directly
+    output = detector.draw(input_img, output_type="pil")
+    output[0].save("Kornia-RTDETR-output.png")
+
+    # convert the whole model to ONNX directly
+    RTDETRDetectorBuilder.to_onnx(model_name="rtdetr_r18vd", onnx_name="RTDETR-640.onnx", image_size=640)
+
+
 .. card::
     :link: https://arxiv.org/abs/2304.08069
diff --git a/kornia/color/yuv.py b/kornia/color/yuv.py
index e250e1ac36..1334be1089 100644
--- a/kornia/color/yuv.py
+++ b/kornia/color/yuv.py
@@ -122,7 +122,7 @@ def yuv_to_rgb(image: Tensor) -> Tensor:
     if not isinstance(image, Tensor):
         raise TypeError(f"Input type is not a Tensor. Got {type(image)}")
 
-    if len(image.shape) < 3 or image.shape[-3] != 3:
+    if image.dim() < 3 or image.shape[-3] != 3:
         raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")
 
     y: Tensor = image[..., 0, :, :]
diff --git a/kornia/contrib/models/common.py b/kornia/contrib/models/common.py
index 7ad86d54c6..b46049f93f 100644
--- a/kornia/contrib/models/common.py
+++ b/kornia/contrib/models/common.py
@@ -9,7 +9,16 @@ class ConvNormAct(nn.Sequential):
     def __init__(
-        self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, act: str = "relu", groups: int = 1
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        act: str = "relu",
+        groups: int = 1,
+        conv_naming: str = "conv",
+        norm_naming: str = "norm",
+        act_naming: str = "act",
     ) -> None:
         super().__init__()
         if kernel_size % 2 == 0:
@@ -23,9 +32,13 @@ def __init__(
             padding = 0
         else:
             padding = (kernel_size - 1) // 2
-        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, 1, groups, False)
-        self.norm = nn.BatchNorm2d(out_channels)
-        self.act = {"relu": nn.ReLU, "silu": nn.SiLU, "none": nn.Identity}[act](inplace=True)
+        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, 1, groups, False)
+        norm = nn.BatchNorm2d(out_channels)
+        activation = {"relu": nn.ReLU, "silu": nn.SiLU, "none": nn.Identity}[act](inplace=True)
+
+        # Register the submodules under configurable names so parameter names can match
+        # external checkpoints (e.g. "0"/"1"/"2" instead of "conv"/"norm"/"act").
+        self.__setattr__(conv_naming, conv)
+        self.__setattr__(norm_naming, norm)
+        self.__setattr__(act_naming, activation)
 
 
 # Lightly adapted from
diff --git a/kornia/contrib/models/rt_detr/architecture/hybrid_encoder.py b/kornia/contrib/models/rt_detr/architecture/hybrid_encoder.py
index 6e9573abb9..5e2a8cc98e 100644
--- a/kornia/contrib/models/rt_detr/architecture/hybrid_encoder.py
+++ b/kornia/contrib/models/rt_detr/architecture/hybrid_encoder.py
@@ -4,6 +4,7 @@
 
 from __future__ import annotations
 
+import copy
 from typing import Optional
 
 import torch
@@ -80,15 +81,16 @@ class AIFI(Module):
     def __init__(self, embed_dim: int, num_heads: int, dim_feedforward: int, dropout: float = 
0.0) -> None: super().__init__() self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout) # NOTE: batch_first = False - self.dropout1 = nn.Dropout(dropout) - self.norm1 = nn.LayerNorm(embed_dim) self.linear1 = nn.Linear(embed_dim, dim_feedforward) - self.act = nn.GELU() self.dropout = nn.Dropout(dropout) self.linear2 = nn.Linear(dim_feedforward, embed_dim) + + self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) + self.norm1 = nn.LayerNorm(embed_dim) self.norm2 = nn.LayerNorm(embed_dim) + self.act = nn.GELU() def forward(self, x: Tensor) -> Tensor: # using post-norm @@ -149,6 +151,20 @@ def build_2d_sincos_pos_emb( return pos_emb.unsqueeze(1) # (H * W, 1, C) +class TransformerEncoder(nn.Module): + def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None: + super().__init__() + self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)]) + self.num_layers = num_layers + + def forward(self, src: Tensor) -> Tensor: # NOTE: Missing src_mask: Tensor = None, pos_embed: Tensor = None + output = src + for layer in self.layers: + output = layer(output) + + return output + + class CCFM(Module): def __init__(self, num_fmaps: int, hidden_dim: int, expansion: float = 1.0) -> None: super().__init__() @@ -192,12 +208,20 @@ def forward(self, fmaps: list[Tensor]) -> list[Tensor]: class HybridEncoder(Module): def __init__(self, in_channels: list[int], hidden_dim: int, dim_feedforward: int, expansion: float = 1.0) -> None: super().__init__() - self.input_proj = nn.ModuleList([ConvNormAct(in_ch, hidden_dim, 1, act="none") for in_ch in in_channels]) - self.aifi = AIFI(hidden_dim, 8, dim_feedforward) + self.input_proj = nn.ModuleList( + [ + ConvNormAct( # To align the naming strategy for the official weights + in_ch, hidden_dim, 1, act="none", conv_naming="0", norm_naming="1", act_naming="2" + ) + for in_ch in in_channels + ] + ) + encoder_layer = AIFI(hidden_dim, 8, dim_feedforward) + self.encoder = nn.Sequential(TransformerEncoder(encoder_layer, 1)) self.ccfm = CCFM(len(in_channels), hidden_dim, expansion) def forward(self, fmaps: list[Tensor]) -> list[Tensor]: projected_maps = [proj(fmap) for proj, fmap in zip(self.input_proj, fmaps)] - projected_maps[-1] = self.aifi(projected_maps[-1]) + projected_maps[-1] = self.encoder(projected_maps[-1]) new_fmaps = self.ccfm(projected_maps) return new_fmaps diff --git a/kornia/contrib/models/rt_detr/architecture/resnet_d.py b/kornia/contrib/models/rt_detr/architecture/resnet_d.py index 57fa171c82..4432b38ac3 100644 --- a/kornia/contrib/models/rt_detr/architecture/resnet_d.py +++ b/kornia/contrib/models/rt_detr/architecture/resnet_d.py @@ -6,6 +6,8 @@ from __future__ import annotations +from collections import OrderedDict + from torch import nn from kornia.contrib.models.common import ConvNormAct @@ -15,7 +17,9 @@ def _make_shortcut(in_channels: int, out_channels: int, stride: int) -> Module: return ( - nn.Sequential(nn.AvgPool2d(2, 2), ConvNormAct(in_channels, out_channels, 1, act="none")) + nn.Sequential( + OrderedDict([("pool", nn.AvgPool2d(2, 2)), ("conv", ConvNormAct(in_channels, out_channels, 1, act="none"))]) + ) if stride == 2 else ConvNormAct(in_channels, out_channels, 1, act="none") ) @@ -28,14 +32,18 @@ def __init__(self, in_channels: int, out_channels: int, stride: int, shortcut: b KORNIA_CHECK(stride in {1, 2}) super().__init__() self.convs = nn.Sequential( - ConvNormAct(in_channels, out_channels, 3, stride=stride), - ConvNormAct(out_channels, out_channels, 3, act="none"), + OrderedDict( 
+                [
+                    ("branch2a", ConvNormAct(in_channels, out_channels, 3, stride=stride)),
+                    ("branch2b", ConvNormAct(out_channels, out_channels, 3, act="none")),
+                ]
+            )
         )
-        self.shortcut = nn.Identity() if shortcut else _make_shortcut(in_channels, out_channels, stride)
+        self.short = nn.Identity() if shortcut else _make_shortcut(in_channels, out_channels, stride)
         self.relu = nn.ReLU(inplace=True)
 
     def forward(self, x: Tensor) -> Tensor:
-        return self.relu(self.convs(x) + self.shortcut(x))
+        return self.relu(self.convs(x) + self.short(x))
 
 
 class BottleneckD(Module):
@@ -46,15 +54,25 @@ def __init__(self, in_channels: int, out_channels: int, stride: int, shortcut: b
         super().__init__()
         expanded_out_channels = out_channels * self.expansion
         self.convs = nn.Sequential(
-            ConvNormAct(in_channels, out_channels, 1),
-            ConvNormAct(out_channels, out_channels, 3, stride=stride),
-            ConvNormAct(out_channels, expanded_out_channels, 1, act="none"),
+            OrderedDict(
+                [
+                    ("branch2a", ConvNormAct(in_channels, out_channels, 1)),
+                    ("branch2b", ConvNormAct(out_channels, out_channels, 3, stride=stride)),
+                    ("branch2c", ConvNormAct(out_channels, expanded_out_channels, 1, act="none")),
+                ]
+            )
         )
-        self.shortcut = nn.Identity() if shortcut else _make_shortcut(in_channels, expanded_out_channels, stride)
+        self.short = nn.Identity() if shortcut else _make_shortcut(in_channels, expanded_out_channels, stride)
         self.relu = nn.ReLU(inplace=True)
 
     def forward(self, x: Tensor) -> Tensor:
-        return self.relu(self.convs(x) + self.shortcut(x))
+        return self.relu(self.convs(x) + self.short(x))
+
+
+# Thin wrapper that nests a stage under a `blocks.` prefix so that parameter names
+# (e.g. `res_layers.0.blocks.0.convs...`) line up with the official checkpoints.
+class Block(nn.Sequential):
+    def __init__(self, blocks: Module) -> None:
+        super().__init__()
+        self.blocks = blocks
 
 
 class ResNetD(Module):
@@ -63,16 +81,22 @@ def __init__(self, n_blocks: list[int], block: type[BasicBlockD | BottleneckD])
         super().__init__()
         in_channels = 64
         self.conv1 = nn.Sequential(
-            ConvNormAct(3, in_channels // 2, 3, stride=2),
-            ConvNormAct(in_channels // 2, in_channels // 2, 3),
-            ConvNormAct(in_channels // 2, in_channels, 3),
-            nn.MaxPool2d(3, stride=2, padding=1),
+            OrderedDict(
+                [
+                    ("conv1_1", ConvNormAct(3, in_channels // 2, 3, stride=2)),
+                    ("conv1_2", ConvNormAct(in_channels // 2, in_channels // 2, 3)),
+                    ("conv1_3", ConvNormAct(in_channels // 2, in_channels, 3)),
+                    ("pool", nn.MaxPool2d(3, stride=2, padding=1)),
+                ]
+            )
         )
-        self.res2, in_channels = self.make_stage(in_channels, 64, 1, n_blocks[0], block)
-        self.res3, in_channels = self.make_stage(in_channels, 128, 2, n_blocks[1], block)
-        self.res4, in_channels = self.make_stage(in_channels, 256, 2, n_blocks[2], block)
-        self.res5, in_channels = self.make_stage(in_channels, 512, 2, n_blocks[3], block)
+        res2, in_channels = self.make_stage(in_channels, 64, 1, n_blocks[0], block)
+        res3, in_channels = self.make_stage(in_channels, 128, 2, n_blocks[1], block)
+        res4, in_channels = self.make_stage(in_channels, 256, 2, n_blocks[2], block)
+        res5, in_channels = self.make_stage(in_channels, 512, 2, n_blocks[3], block)
+
+        self.res_layers = nn.ModuleList([res2, res3, res4, res5])
 
         self.out_channels = [ch * block.expansion for ch in [128, 256, 512]]
 
@@ -80,18 +104,20 @@ def __init__(self, n_blocks: list[int], block: type[BasicBlockD | BottleneckD])
     def make_stage(
         in_channels: int, out_channels: int, stride: int, n_blocks: int, block: type[BasicBlockD | BottleneckD]
     ) -> tuple[Module, int]:
-        stage = nn.Sequential(
-            block(in_channels, out_channels, stride, False),
-            *[block(out_channels * block.expansion, out_channels, 1, True) for _ in range(n_blocks - 1)],
+        stage = Block(
+            nn.Sequential(
+                block(in_channels, out_channels, stride, False),
+                *[block(out_channels * block.expansion, out_channels, 1, True) for _ in range(n_blocks - 1)],
+            )
         )
         return stage, out_channels * block.expansion
 
     def forward(self, x: Tensor) -> list[Tensor]:
         x = self.conv1(x)
-        res2 = self.res2(x)
-        res3 = self.res3(res2)
-        res4 = self.res4(res3)
-        res5 = self.res5(res4)
+        res2 = self.res_layers[0](x)
+        res3 = self.res_layers[1](res2)
+        res4 = self.res_layers[2](res3)
+        res5 = self.res_layers[3](res4)
         return [res3, res4, res5]
 
     @staticmethod
diff --git a/kornia/contrib/models/rt_detr/architecture/rtdetr_head.py b/kornia/contrib/models/rt_detr/architecture/rtdetr_head.py
index b75e56eca6..6345577e5a 100644
--- a/kornia/contrib/models/rt_detr/architecture/rtdetr_head.py
+++ b/kornia/contrib/models/rt_detr/architecture/rtdetr_head.py
@@ -3,6 +3,7 @@
 
 from __future__ import annotations
 
+import copy
 from typing import Optional
 
 import torch
@@ -192,14 +193,10 @@ def forward(
         return out
 
 
-class TransformerDecoder:
-    def __init__(self, hidden_dim: int, decoder_layers: nn.ModuleList, num_layers: int, eval_idx: int = -1) -> None:
+class TransformerDecoder(Module):
+    def __init__(self, hidden_dim: int, decoder_layer: nn.Module, num_layers: int, eval_idx: int = -1) -> None:
         super().__init__()
-        self.layers = decoder_layers
-        # TODO: come back to this later
-        # self.layers = nn.ModuleList([
-        #     copy.deepcopy(decoder_layer) for _ in range(num_layers)
-        # ])
+        self.layers = nn.ModuleList([copy.deepcopy(decoder_layer) for _ in range(num_layers)])
         self.hidden_dim = hidden_dim
         self.num_layers = num_layers
         self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx
@@ -272,13 +269,16 @@ def __init__(
         num_decoder_layers: int,
         num_heads: int = 8,
         num_decoder_points: int = 4,
-        # num_levels: int = 3,
+        num_levels: int = 3,
         dropout: float = 0.0,
+        num_denoising: int = 100,
     ) -> None:
         super().__init__()
         self.num_queries = num_queries
         # TODO: verify this is correct
-        self.num_levels = len(in_channels)
+        if len(in_channels) > num_levels:
+            raise ValueError(f"`num_levels` cannot be less than the number of input feature maps ({len(in_channels)}). 
Got {num_levels}.") + self.num_levels = num_levels # build the input projection layers self.input_proj = nn.ModuleList() @@ -288,25 +288,23 @@ def __init__( # https://github.com/lyuwenyu/RT-DETR/blob/main/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L403-L410 # NOTE: need to be integrated with the TransformerDecoderLayer - self.decoder_layers = nn.ModuleList( - [ - TransformerDecoderLayer( - embed_dim=hidden_dim, - num_heads=num_heads, - dropout=dropout, - num_levels=len(in_channels), - num_points=num_decoder_points, - ) - for _ in range(num_decoder_layers) - ] + decoder_layer = TransformerDecoderLayer( + embed_dim=hidden_dim, + num_heads=num_heads, + dropout=dropout, + num_levels=self.num_levels, + num_points=num_decoder_points, ) self.decoder = TransformerDecoder( - hidden_dim=hidden_dim, decoder_layers=self.decoder_layers, num_layers=num_decoder_layers + hidden_dim=hidden_dim, decoder_layer=decoder_layer, num_layers=num_decoder_layers ) # denoising part - self.denoising_class_embed = nn.Embedding(num_classes, hidden_dim) # not used in evaluation + if num_denoising > 0: + self.denoising_class_embed = nn.Embedding( + num_classes + 1, hidden_dim, padding_idx=num_classes + ) # not used in evaluation # decoder embedding self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, num_layers=2) @@ -334,7 +332,7 @@ def forward(self, feats: Tensor) -> tuple[Tensor, Tensor]: ) # decoder - out_bboxes, out_logits = self.decoder.forward( + out_bboxes, out_logits = self.decoder( target, init_ref_points_unact, memory, diff --git a/kornia/contrib/models/rt_detr/model.py b/kornia/contrib/models/rt_detr/model.py index 1f250c77aa..3e6e6226bb 100644 --- a/kornia/contrib/models/rt_detr/model.py +++ b/kornia/contrib/models/rt_detr/model.py @@ -2,10 +2,13 @@ from __future__ import annotations +import re from dataclasses import dataclass from enum import Enum from typing import Optional +import torch + from kornia.contrib.models.base import ModelBase from kornia.contrib.models.rt_detr.architecture.hgnetv2 import PPHGNetV2 from kornia.contrib.models.rt_detr.architecture.hybrid_encoder import HybridEncoder @@ -13,6 +16,14 @@ from kornia.contrib.models.rt_detr.architecture.rtdetr_head import RTDETRHead from kornia.core import Tensor +URLs = { + "rtdetr_r18vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r18vd_dec3_6x_coco_from_paddle.pth", + "rtdetr_r34vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r34vd_dec4_6x_coco_from_paddle.pth", + "rtdetr_r50vd_m": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r50vd_m_6x_coco_from_paddle.pth", + "rtdetr_r50vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r50vd_6x_coco_from_paddle.pth", + "rtdetr_r101vd": "https://github.com/lyuwenyu/storage/releases/download/v0.1/rtdetr_r101vd_6x_coco_from_paddle.pth", +} + class RTDETRModelType(Enum): """Enum class that maps RT-DETR model type.""" @@ -23,6 +34,7 @@ class RTDETRModelType(Enum): resnet101d = 3 hgnetv2_l = 4 hgnetv2_x = 5 + resnet50d_m = 6 @dataclass @@ -65,7 +77,7 @@ class RTDETRConfig: class RTDETR(ModelBase[RTDETRConfig]): """RT-DETR Object Detection model, as described in https://arxiv.org/abs/2304.08069.""" - def __init__(self, backbone: ResNetD | PPHGNetV2, neck: HybridEncoder, head: RTDETRHead): + def __init__(self, backbone: ResNetD | PPHGNetV2, encoder: HybridEncoder, decoder: RTDETRHead): """Construct RT-DETR Object Detection model. 
Args: @@ -75,8 +87,8 @@ def __init__(self, backbone: ResNetD | PPHGNetV2, neck: HybridEncoder, head: RTD """ super().__init__() self.backbone = backbone - self.neck = neck - self.head = head + self.encoder = encoder + self.decoder = decoder @staticmethod def from_config(config: RTDETRConfig) -> RTDETR: @@ -119,6 +131,13 @@ def from_config(config: RTDETRConfig) -> RTDETR: head_num_decoder_layers = config.head_num_decoder_layers or 6 neck_expansion = config.neck_expansion or 1.0 + elif model_type == RTDETRModelType.resnet50d_m: + backbone = ResNetD.from_config(50) + neck_hidden_dim = config.neck_hidden_dim or 256 + neck_dim_feedforward = config.neck_dim_feedforward or 1024 + head_num_decoder_layers = config.head_num_decoder_layers or 6 + neck_expansion = config.neck_expansion or 0.5 + elif model_type == RTDETRModelType.resnet101d: backbone = ResNetD.from_config(101) neck_hidden_dim = config.neck_hidden_dim or 384 @@ -156,6 +175,76 @@ def from_config(config: RTDETRConfig) -> RTDETR: model.load_checkpoint(config.checkpoint) return model + @staticmethod + def from_pretrained(model_name: str) -> RTDETR: + """Load model from pretrained weights. + + Args: + model_name: 'rtdetr_r18vd', 'rtdetr_r34vd', 'rtdetr_r50vd_m', 'rtdetr_r50vd', 'rtdetr_r101vd'. + """ + + if model_name not in URLs: + raise ValueError(f"No pretrained model for '{model_name}'. Please select from {list(URLs.keys())}.") + + state_dict = torch.hub.load_state_dict_from_url( + URLs[model_name], map_location="cuda:0" if torch.cuda.is_available() else "cpu" + ) + + def map_name(old_name: str) -> str: + new_name = old_name + + # Encoder renaming + new_name = re.sub("encoder.pan_blocks", "encoder.ccfm.pan_blocks", new_name) + new_name = re.sub("encoder.downsample_convs", "encoder.ccfm.downsample_convs", new_name) + new_name = re.sub("encoder.fpn_blocks", "encoder.ccfm.fpn_blocks", new_name) + new_name = re.sub("encoder.lateral_convs", "encoder.ccfm.lateral_convs", new_name) + + # Backbone renaming + new_name = re.sub(".branch2b.", ".convs.branch2b.", new_name) + new_name = re.sub(".branch2a.", ".convs.branch2a.", new_name) + new_name = re.sub(".branch2c.", ".convs.branch2c.", new_name) + + return new_name + + def _state_dict_proc(state_dict: dict[str, Tensor]) -> dict[str, Tensor]: + state_dict = state_dict["ema"]["module"] # type:ignore + new_state_dict = {} + + # Apply the regex-based mapping function to each key + for old_name in state_dict.keys(): + new_name = map_name(old_name) + new_state_dict[new_name] = state_dict[old_name] + + return new_state_dict + + model = RTDETR.from_name(model_name, num_classes=80) + + model.load_state_dict(_state_dict_proc(state_dict)) + return model + + @staticmethod + def from_name(model_name: str, num_classes: int = 80) -> RTDETR: + """Load model without pretrained weights. + + Args: + model_name: 'rtdetr_r18vd', 'rtdetr_r34vd', 'rtdetr_r50vd_m', 'rtdetr_r50vd', 'rtdetr_r101vd'. 
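+            num_classes: number of classes for the detection head; defaults to the 80 COCO classes.
+
+        Example:
+            A minimal sketch; any of the names listed above can be used the same way:
+
+            >>> model = RTDETR.from_name("rtdetr_r18vd", num_classes=80)
+            >>> model = model.eval()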
+        """
+
+        if model_name == "rtdetr_r18vd":
+            model = RTDETR.from_config(RTDETRConfig(RTDETRModelType.resnet18d, num_classes))
+        elif model_name == "rtdetr_r34vd":
+            model = RTDETR.from_config(RTDETRConfig(RTDETRModelType.resnet34d, num_classes))
+        elif model_name == "rtdetr_r50vd_m":
+            model = RTDETR.from_config(RTDETRConfig(RTDETRModelType.resnet50d_m, num_classes))
+        elif model_name == "rtdetr_r50vd":
+            model = RTDETR.from_config(RTDETRConfig(RTDETRModelType.resnet50d, num_classes))
+        elif model_name == "rtdetr_r101vd":
+            model = RTDETR.from_config(RTDETRConfig(RTDETRModelType.resnet101d, num_classes))
+        else:
+            raise ValueError(f"Unknown model name '{model_name}'. Please select from {list(URLs.keys())}.")
+
+        return model
+
     def forward(self, images: Tensor) -> tuple[Tensor, Tensor]:
         """Detect objects in an image.
@@ -167,10 +256,8 @@ def forward(self, images: Tensor) -> tuple[Tensor, Tensor]:
             :math:`K` is the number of classes.
             - **boxes** - Tensor of shape :math:`(N, Q, 4)`, where :math:`Q` is the number of queries.
         """
-        if self.training:
-            raise RuntimeError("Only evaluation mode is supported. Please call model.eval().")
         feats = self.backbone(images)
-        feats_buf = self.neck(feats)
-        logits, boxes = self.head(feats_buf)
+        feats_buf = self.encoder(feats)
+        logits, boxes = self.decoder(feats_buf)
         return logits, boxes
diff --git a/kornia/contrib/models/rt_detr/post_processor.py b/kornia/contrib/models/rt_detr/post_processor.py
index d5fce83870..95b35623df 100644
--- a/kornia/contrib/models/rt_detr/post_processor.py
+++ b/kornia/contrib/models/rt_detr/post_processor.py
@@ -1,18 +1,50 @@
+"""Post-processor for the RT-DETR model."""
+
 from __future__ import annotations
 
-# TODO:
+from typing import Optional
+
 import torch
 
 from kornia.core import Module, Tensor, concatenate
-from kornia.image.base import ImageSize
 
+
+def mod(a: Tensor, b: int) -> Tensor:
+    """Compute the modulo operation for two numbers.
+
+    This function calculates the remainder of the division of 'a' by 'b'
+    using the formula: a - (a // b) * b, which is equivalent to the modulo operation.
+
+    Args:
+        a: The dividend.
+        b: The divisor.
+
+    Returns:
+        The remainder of a divided by b.
+
+    Example:
+        >>> mod(7, 3)
+        1
+    """
+    return a - (a // b) * b
+
+
+# TODO: deprecate `confidence_threshold`; `num_top_queries` and `num_classes` are now constructor parameters
 class DETRPostProcessor(Module):
-    def __init__(self, confidence_threshold: float) -> None:
+    def __init__(
+        self,
+        confidence_threshold: Optional[float] = None,
+        num_classes: int = 80,
+        num_top_queries: int = 300,
+        confidence_filtering: bool = True,
+    ) -> None:
         super().__init__()
         self.confidence_threshold = confidence_threshold
+        self.num_classes = num_classes
+        self.confidence_filtering = confidence_filtering
+        self.num_top_queries = num_top_queries
 
-    def forward(self, logits: Tensor, boxes: Tensor, original_sizes: list[ImageSize]) -> list[Tensor]:
+    def forward(self, logits: Tensor, boxes: Tensor, original_sizes: Tensor) -> Tensor:
        """Post-process outputs from DETR.
 
         Args:
@@ -20,7 +52,8 @@ def forward(self, logits: Tensor, boxes: Tensor, original_sizes: list[ImageSize]
             queries, :math:`K` is the number of classes.
             boxes: tensor with shape :math:`(N, Q, 4)`, where :math:`N` is the batch size, :math:`Q` is the number
                 of queries.
-            original_sizes: list of tuples, each tuple represent (img_height, img_width).
+            original_sizes: tensor with shape :math:`(N, 2)`, where :math:`N` is the batch size and each row
+                holds the (img_height, img_width) of the corresponding input image.
 
         Returns:
            Processed detections. For each image, the detections have shape (D, 6), where D is the number of detections
@@ -37,24 +70,24 @@
         cxcy, wh = boxes[..., :2], boxes[..., 2:]
         boxes_xy = concatenate([cxcy - wh * 0.5, wh], -1)
 
-        sizes_wh = torch.empty(1, 1, 2, device=boxes.device, dtype=boxes.dtype)
-        sizes_wh[..., 0] = original_sizes[0].width
-        sizes_wh[..., 1] = original_sizes[0].height
-        sizes_wh = sizes_wh.repeat(1, 1, 2)
+        # Get the size dynamically from the input tensor itself.
+        # NOTE: assumes all images in the batch share the same original size.
+        sizes_wh = original_sizes[0].flip(0).unsqueeze(0).unsqueeze(0).repeat(1, 1, 2)
 
         boxes_xy = boxes_xy * sizes_wh
         scores = logits.sigmoid()  # RT-DETR was trained with focal loss. thus sigmoid is used instead of softmax
 
-        # the original code is slightly different
-        # it allows 1 bounding box to have multiple classes (multi-label)
-        scores, labels = scores.max(-1)
+        # keep the top-k (query, class) pairs over all queries and classes; a query may
+        # therefore be kept more than once with different classes (multi-label)
+        # https://github.com/lyuwenyu/RT-DETR/blob/b6bf0200b249a6e35b44e0308b6058f55b99696b/rtdetrv2_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py#L55-L62
+        scores, index = torch.topk(scores.flatten(1), self.num_top_queries, dim=-1)
+        labels = mod(index, self.num_classes)
+        index = index // self.num_classes
+        boxes = boxes_xy.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes_xy.shape[-1]))
+
+        all_boxes = concatenate([labels[..., None], scores[..., None], boxes], -1)
 
-        detections: list[Tensor] = []
-        for i in range(scores.shape[0]):
-            mask = scores[i] >= self.confidence_threshold
-            labels_i = labels[i, mask].unsqueeze(-1)
-            scores_i = scores[i, mask].unsqueeze(-1)
-            boxes_i = boxes_xy[i, mask]
-            detections.append(concatenate([labels_i, scores_i, boxes_i], -1))
+        if not self.confidence_filtering or self.confidence_threshold is None or self.confidence_threshold == 0:
+            return all_boxes
 
-        return detections
+        return all_boxes[(all_boxes[:, :, 1] > self.confidence_threshold).unsqueeze(-1).expand_as(all_boxes)].view(
+            all_boxes.shape[0], -1, all_boxes.shape[-1]
+        )
diff --git a/kornia/contrib/object_detection.py b/kornia/contrib/object_detection.py
index ad1849378f..4a4b965dc2 100644
--- a/kornia/contrib/object_detection.py
+++ b/kornia/contrib/object_detection.py
@@ -1,14 +1,21 @@
 from __future__ import annotations
 
+import datetime
+import logging
+import os
 from dataclasses import dataclass
 from enum import Enum
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 
 from kornia.core import Module, Tensor, concatenate
 from kornia.core.check import KORNIA_CHECK_SHAPE
-from kornia.image.base import ImageSize
+from kornia.core.external import PILImage as Image
+from kornia.core.external import numpy as np
+from kornia.geometry.transform import resize
+from kornia.io import write_image
+from kornia.utils.draw import draw_rectangle
 
 __all__ = [
     "BoundingBoxDataFormat",
@@ -19,6 +26,8 @@
     "ObjectDetectorResult",
 ]
 
+logger = logging.getLogger(__name__)
+
 
 class BoundingBoxDataFormat(Enum):
     """Enum class that maps bounding box data format."""
@@ -113,17 +122,22 @@ def __init__(self, size: tuple[int, int], interpolation_mode: str = "bilinear")
         self.size = size
         self.interpolation_mode = interpolation_mode
 
-    def forward(self, imgs: list[Tensor]) -> tuple[Tensor, list[ImageSize]]:
+    def forward(self, imgs: Union[Tensor, list[Tensor]]) -> tuple[Tensor, Tensor]:
+        """Resize the input images and keep track of their original sizes.
+
+        Returns:
+            resized_imgs: resized images in a batch.
+            original_sizes: the original image sizes of (height, width).
+        """
+        # TODO: support other input formats e.g. file path, numpy
-        resized_imgs, original_sizes = [], []
-        for i in range(len(imgs)):
+        resized_imgs: list[Tensor] = []
+
+        iters = len(imgs) if isinstance(imgs, list) else imgs.shape[0]
+        original_sizes = imgs[0].new_zeros((iters, 2))
+        for i in range(iters):
             img = imgs[i]
-            # NOTE: assume that image layout is CHW
-            original_sizes.append(ImageSize(height=img.shape[1], width=img.shape[2]))
-            resized_imgs.append(
-                # TODO: fix kornia resize to support onnx
-                torch.nn.functional.interpolate(img.unsqueeze(0), size=self.size, mode=self.interpolation_mode)
-            )
+            original_sizes[i, 0] = img.shape[-2]  # Height
+            original_sizes[i, 1] = img.shape[-1]  # Width
+            resized_imgs.append(resize(img[None], size=self.size, interpolation=self.interpolation_mode))
         return concatenate(resized_imgs), original_sizes
@@ -145,11 +159,12 @@ def __init__(self, model: Module, pre_processor: Module, post_processor: Module)
         self.post_processor = post_processor.eval()
 
     @torch.inference_mode()
-    def forward(self, images: list[Tensor]) -> list[Tensor]:
+    def forward(self, images: Union[Tensor, list[Tensor]]) -> Tensor:
         """Detect objects in a given list of images.
 
         Args:
-            images: list of RGB images. Each image is a Tensor with shape :math:`(3, H, W)`.
+            images: RGB images, either a list of tensors each with shape :math:`(3, H, W)`
+                or a single batched tensor with shape :math:`(B, 3, H, W)`.
 
         Returns:
             list of detections found in each image. For item in a batch, shape is :math:`(D, 6)`, where :math:`D` is the
@@ -160,6 +175,52 @@ def forward(self, images: Union[Tensor, list[Tensor]]) -> Tensor:
         detections = self.post_processor(logits, boxes, images_sizes)
         return detections
 
+    def draw(
+        self, images: Union[Tensor, list[Tensor]], detections: Optional[Tensor] = None, output_type: str = "torch"
+    ) -> Union[Tensor, list[Tensor], list[Image.Image]]:  # type: ignore
+        """Draw the detected bounding boxes on top of the input images.
+
+        Intentionally simple for now; fancier rendering can be added later.
+        """
+        if detections is None:
+            detections = self.forward(images)
+        output = []
+        for image, detection in zip(images, detections):
+            out_img = image[None].clone()
+            for out in detection:
+                # each detection is (label, score, x, y, w, h); draw_rectangle expects (x1, y1, x2, y2)
+                out_img = draw_rectangle(
+                    out_img,
+                    torch.Tensor([[[out[-4], out[-3], out[-4] + out[-2], out[-3] + out[-1]]]]),
+                )
+            if output_type == "torch":
+                output.append(out_img[0])
+            elif output_type == "pil":
+                output.append(Image.fromarray((out_img[0] * 255).permute(1, 2, 0).numpy().astype(np.uint8)))  # type: ignore
+            else:
+                raise RuntimeError(f"Unsupported output type `{output_type}`.")
+        return output
+
+    def save(
+        self, images: Union[Tensor, list[Tensor]], detections: Optional[Tensor] = None, directory: Optional[str] = None
+    ) -> None:
+        """Saves the output image(s) to a directory.
+
+        Args:
+            images: images to run the detection on.
+            detections: detections to draw. If None, they are computed from ``images`` via :meth:`forward`.
+            directory: directory to save the images to. If None, a timestamped folder under
+                ``Kornia_outputs`` is created. 
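+
+        Example:
+            A minimal sketch; assumes ``detector`` was built with ``RTDETRDetectorBuilder.build()`` and
+            ``img`` is a :math:`(3, H, W)` float tensor. Files are written as ``000000.jpg``, ``000001.jpg``, ...
+
+            >>> detector.save([img], directory="Kornia_outputs")  # doctest: +SKIP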
+        """
+        if directory is None:
+            name = f"detection-{datetime.datetime.now(tz=datetime.timezone.utc).strftime('%Y%m%d%H%M%S')!s}"
+            directory = os.path.join("Kornia_outputs", name)
+        outputs = self.draw(images, detections)
+        os.makedirs(directory, exist_ok=True)
+        for i, out_image in enumerate(outputs):
+            write_image(
+                os.path.join(directory, f"{str(i).zfill(6)}.jpg"),
+                out_image.mul(255.0).byte(),
+            )
+        logger.info(f"Outputs are saved in {directory}")
+
     def compile(
         self,
         *,
diff --git a/kornia/core/external.py b/kornia/core/external.py
index 4efdbfe189..1e0160035d 100644
--- a/kornia/core/external.py
+++ b/kornia/core/external.py
@@ -1,7 +1,12 @@
 import importlib
+import logging
+import subprocess
+import sys
 from types import ModuleType
 from typing import List, Optional
 
+logger = logging.getLogger(__name__)
+
 
 class LazyLoader:
     """A class that implements lazy loading for Python modules.
@@ -15,6 +20,8 @@ class LazyLoader:
         module: The actual module object, initialized to None and loaded upon first access.
     """
 
+    auto_install: bool = False
+
     def __init__(self, module_name: str) -> None:
         """Initializes the LazyLoader with the name of the module.
 
@@ -24,6 +31,10 @@ def __init__(self, module_name: str) -> None:
         self.module_name = module_name
         self.module: Optional[ModuleType] = None
 
+    def _install_package(self, module_name: str) -> None:
+        logger.info(f"Installing `{module_name}` ...")
+        subprocess.run([sys.executable, "-m", "pip", "install", "-U", module_name], shell=False, check=False)  # noqa: S603
+
     def _load(self) -> None:
         """Loads the module if it hasn't been loaded yet.
 
@@ -34,10 +45,26 @@ def _load(self) -> None:
         try:
             self.module = importlib.import_module(self.module_name)
         except ImportError as e:
-            raise ImportError(
-                f"Optional dependency '{self.module_name}' is not installed. "
-                f"Please install it to use this functionality."
-            ) from e
+            if self.auto_install:
+                self._install_package(self.module_name)
+            else:
+                if_install = input(
+                    f"Optional dependency '{self.module_name}' is not installed. "
+                    "Do you wish to install the dependency? [Y]es, [N]o, [A]ll."
+                )
+                if if_install.lower() == "y":
+                    self._install_package(self.module_name)
+                elif if_install.lower() == "a":
+                    self.auto_install = True
+                    self._install_package(self.module_name)
+                else:
+                    raise ImportError(
+                        f"Optional dependency '{self.module_name}' is not installed. "
+                        f"Please install it to use this functionality."
+                    ) from e
+            # Retry the import now that the dependency has been installed; without this,
+            # `self.module` would remain None and attribute access would fail.
+            self.module = importlib.import_module(self.module_name)
 
     def __getattr__(self, item: str) -> object:
         """Loads the module (if not already loaded) and returns the requested attribute.
diff --git a/kornia/geometry/transform/affwarp.py b/kornia/geometry/transform/affwarp.py
index b0abb14fe1..c85c5ddf3d 100644
--- a/kornia/geometry/transform/affwarp.py
+++ b/kornia/geometry/transform/affwarp.py
@@ -1,3 +1,4 @@
+import warnings
 from typing import Optional, Tuple, Union
 
 import torch
@@ -567,11 +568,15 @@ def resize(
     input_size = h, w = input.shape[-2:]
     if isinstance(size, int):
+        if torch.onnx.is_in_onnx_export():
+            warnings.warn("Please pass the size as a tuple when exporting to ONNX so that it is traced correctly.")
         aspect_ratio = w / h
         size = _side_to_image_size(size, aspect_ratio, side)
 
-    if size == input_size:
-        return input
+    # Skip this dangerous if-else when converting to ONNX.
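+    # (During tracing, `size == input_size` would be evaluated once on the example input and
+    # baked into the exported graph as a constant, silently skipping the resize for other
+    # input sizes.)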
+    if not torch.onnx.is_in_onnx_export():
+        if size == input_size:
+            return input
 
     factors = (h / size[0], w / size[1])
diff --git a/kornia/io/io.py b/kornia/io/io.py
index 62d9ed45b9..1e8797d8b2 100644
--- a/kornia/io/io.py
+++ b/kornia/io/io.py
@@ -65,7 +65,9 @@ def _to_uint8(image: Tensor) -> Tensor:
     return image.mul(255.0).byte()
 
 
-def load_image(path_file: str | Path, desired_type: ImageLoadType, device: Device = "cpu") -> Tensor:
+def load_image(
+    path_file: str | Path, desired_type: ImageLoadType = ImageLoadType.RGB32, device: Device = "cpu"
+) -> Tensor:
     """Read an image file and decode using the Kornia Rust backend.
 
     Args:
diff --git a/kornia/models/__init__.py b/kornia/models/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/kornia/models/detector/__init__.py b/kornia/models/detector/__init__.py
new file mode 100644
index 0000000000..55a62efc22
--- /dev/null
+++ b/kornia/models/detector/__init__.py
@@ -0,0 +1 @@
+from .rtdetr import *
diff --git a/kornia/models/detector/rtdetr.py b/kornia/models/detector/rtdetr.py
new file mode 100644
index 0000000000..4bc1c6bd2d
--- /dev/null
+++ b/kornia/models/detector/rtdetr.py
@@ -0,0 +1,154 @@
+import warnings
+from typing import Optional
+
+import torch
+from torch import nn
+
+from kornia.contrib.models.rt_detr import DETRPostProcessor
+from kornia.contrib.models.rt_detr.model import RTDETR, RTDETRConfig
+from kornia.contrib.object_detection import ObjectDetector, ResizePreProcessor
+from kornia.core import rand
+
+__all__ = ["RTDETRDetectorBuilder"]
+
+
+class RTDETRDetectorBuilder:
+    """A builder class for constructing RT-DETR object detection models.
+
+    This class provides static methods to:
+        - Build an object detection model from a model name or configuration.
+        - Export the model to ONNX format for inference.
+    """
+
+    @staticmethod
+    def build(
+        model_name: Optional[str] = None,
+        config: Optional[RTDETRConfig] = None,
+        pretrained: bool = True,
+        image_size: Optional[int] = 640,
+        confidence_threshold: float = 0.5,
+        confidence_filtering: Optional[bool] = None,
+    ) -> ObjectDetector:
+        """Builds and returns an RT-DETR object detector model.
+
+        Either `model_name` or `config` may be provided; if both are omitted,
+        a default pretrained model (`rtdetr_r18vd`) will be built.
+
+        Args:
+            model_name:
+                Name of the RT-DETR model to load. Can be one of the available pretrained models,
+                including 'rtdetr_r18vd', 'rtdetr_r34vd', 'rtdetr_r50vd_m', 'rtdetr_r50vd', 'rtdetr_r101vd'.
+            config:
+                A custom configuration object for building the RT-DETR model.
+            pretrained:
+                Whether to load a pretrained version of the model (applies when `model_name` is provided).
+            image_size:
+                The size to which input images will be resized during preprocessing.
+                If None, no resizing will be performed before passing to the model. Recommended scales include
+                [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800].
+            confidence_threshold:
+                The confidence threshold used during post-processing to filter detections.
+            confidence_filtering:
+                Whether to filter the resulting boxes by confidence. If None, filtering is disabled
+                when exporting to ONNX and performed as per `confidence_threshold` otherwise.
+
+        Returns:
+            ObjectDetector
+                An object detector instance initialized with the specified model, preprocessor, and post-processor. 
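+
+        Example:
+            A minimal sketch (downloads the pretrained weights on first use):
+
+            >>> import torch
+            >>> detector = RTDETRDetectorBuilder.build("rtdetr_r18vd", image_size=640)  # doctest: +SKIP
+            >>> detections = detector(torch.rand(1, 3, 640, 640))  # doctest: +SKIP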
+        """
+        if model_name is not None and config is not None:
+            raise ValueError("Either `model_name` or `config` should be `None`.")
+
+        if config is not None:
+            model = RTDETR.from_config(config)
+        elif model_name is not None:
+            if pretrained:
+                model = RTDETR.from_pretrained(model_name)
+            else:
+                model = RTDETR.from_name(model_name)
+        else:
+            warnings.warn("No `model_name` or `config` found. Will build pretrained `rtdetr_r18vd`.")
+            model = RTDETR.from_pretrained("rtdetr_r18vd")
+
+        return ObjectDetector(
+            model,
+            ResizePreProcessor((image_size, image_size)) if image_size is not None else nn.Identity(),
+            DETRPostProcessor(
+                confidence_threshold,
+                num_classes=config.num_classes if config is not None else 80,
+                confidence_filtering=confidence_filtering
+                if confidence_filtering is not None
+                else not torch.onnx.is_in_onnx_export(),
+            ),
+        )
+
+    @staticmethod
+    def to_onnx(
+        model_name: Optional[str] = None,
+        onnx_name: Optional[str] = None,
+        config: Optional[RTDETRConfig] = None,
+        pretrained: bool = True,
+        image_size: Optional[int] = 640,
+        confidence_threshold: float = 0.5,
+        confidence_filtering: Optional[bool] = None,
+    ) -> tuple[str, ObjectDetector]:
+        """Exports an RT-DETR object detection model to ONNX format.
+
+        Either `model_name` or `config` may be provided; if both are omitted,
+        a default pretrained model (`rtdetr_r18vd`) will be built.
+
+        Args:
+            model_name:
+                Name of the RT-DETR model to load. Can be one of the available pretrained models.
+            onnx_name:
+                Filename for the exported ONNX model. If None, a name of the form
+                `Kornia-RTDETR-<model_name>-<image_size>.onnx` is generated.
+            config:
+                A custom configuration object for building the RT-DETR model.
+            pretrained:
+                Whether to load a pretrained version of the model (applies when `model_name` is provided).
+            image_size:
+                The size to which input images will be resized during preprocessing.
+                If None, image_size will be dynamic. Recommended scales include
+                [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800].
+            confidence_threshold:
+                The confidence threshold used during post-processing to filter detections.
+            confidence_filtering:
+                Whether to filter the resulting boxes by confidence. If None, filtering is disabled
+                when exporting to ONNX and performed as per `confidence_threshold` otherwise.
+
+        Returns:
+            - The name of the ONNX model.
+            - The exported torch model.
+        """
+
+        detector = RTDETRDetectorBuilder.build(
+            model_name=model_name,
+            config=config,
+            pretrained=pretrained,
+            image_size=image_size,
+            confidence_threshold=confidence_threshold,
+            confidence_filtering=confidence_filtering,
+        )
+        if onnx_name is None:
+            _model_name = model_name
+            if model_name is None and config is not None:
+                _model_name = "rtdetr-customized"
+            elif model_name is None and config is None:
+                _model_name = "rtdetr_r18vd"
+            onnx_name = f"Kornia-RTDETR-{_model_name}-{image_size}.onnx"
+
+        if image_size is None:
+            val_image = rand(1, 3, 640, 640)
+        else:
+            val_image = rand(1, 3, image_size, image_size)
+
+        dynamic_axes = {"input": {0: "batch_size", 2: "height", 3: "width"}, "output": {0: "batch_size"}}
+        torch.onnx.export(
+            detector,
+            val_image,
+            onnx_name,
+            export_params=True,
+            opset_version=17,
+            do_constant_folding=True,
+            input_names=["input"],
+            output_names=["output"],
+            dynamic_axes=dynamic_axes,
+        )
+
+        return onnx_name, detector
diff --git a/kornia/utils/image.py b/kornia/utils/image.py
index a3f6a76393..05293e757d 100644
--- a/kornia/utils/image.py
+++ b/kornia/utils/image.py
@@ -264,7 +264,7 @@ def _wrapper(input: Tensor, *args: Any, **kwargs: Any) -> Tensor:
     if not isinstance(input, Tensor):
         raise TypeError(f"Input input type is not a Tensor. 
Got {type(input)}") - if input.numel() == 0: + if input.shape.numel() == 0: raise ValueError("Invalid input tensor, it is empty.") input_shape = input.shape diff --git a/tests/contrib/test_object_detector.py b/tests/contrib/test_object_detector.py index dd5c07aacb..00eee98b81 100644 --- a/tests/contrib/test_object_detector.py +++ b/tests/contrib/test_object_detector.py @@ -17,7 +17,7 @@ def test_smoke(self, device, dtype): config = RTDETRConfig("resnet50d", 10, head_num_queries=10) model = RTDETR.from_config(config).to(device, dtype).eval() pre_processor = kornia.contrib.object_detection.ResizePreProcessor((32, 32)) - post_processor = DETRPostProcessor(confidence).to(device, dtype).eval() + post_processor = DETRPostProcessor(confidence, num_top_queries=3).to(device, dtype).eval() detector = kornia.contrib.ObjectDetector(model, pre_processor, post_processor) sizes = torch.randint(5, 10, (batch_size, 2)) * 32 @@ -39,8 +39,8 @@ def test_smoke(self, device, dtype): def test_onnx(self, device, dtype, tmp_path: Path, variant: str): config = RTDETRConfig(variant, 1) model = RTDETR.from_config(config).to(device=device, dtype=dtype).eval() - pre_processor = kornia.contrib.object_detection.ResizePreProcessor(640) - post_processor = DETRPostProcessor(0.3) + pre_processor = kornia.contrib.object_detection.ResizePreProcessor((640, 640)) + post_processor = DETRPostProcessor(0.3, num_top_queries=3) detector = kornia.contrib.ObjectDetector(model, pre_processor, post_processor) data = torch.rand(3, 400, 640, device=device, dtype=dtype) @@ -55,7 +55,7 @@ def test_onnx(self, device, dtype, tmp_path: Path, variant: str): input_names=["images"], output_names=["detections"], dynamic_axes=dynamic_axes, - opset_version=16, + opset_version=17, ) assert model_path.is_file() diff --git a/tests/core/test_lazyloader.py b/tests/core/test_lazyloader.py index 1025f97827..b832de99c8 100644 --- a/tests/core/test_lazyloader.py +++ b/tests/core/test_lazyloader.py @@ -1,3 +1,5 @@ +from io import StringIO + import pytest from kornia.core.external import LazyLoader @@ -19,7 +21,8 @@ def test_lazy_loader_loading_module(self): assert loader.sqrt(4) == 2.0 assert loader.module is not None # Should be loaded now - def test_lazy_loader_invalid_module(self): + def test_lazy_loader_invalid_module(self, monkeypatch): + monkeypatch.setattr("sys.stdin", StringIO("n")) # Test that LazyLoader raises an ImportError for an invalid module loader = LazyLoader("non_existent_module") with pytest.raises(ImportError) as excinfo: