Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RTDETR update #88

Open
wants to merge 50 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
859c67d
update
shijianjian Sep 4, 2024
6522333
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 4, 2024
095a268
update
shijianjian Sep 4, 2024
bdeee27
update
shijianjian Sep 4, 2024
d7da930
update
shijianjian Sep 4, 2024
33120ab
update
shijianjian Sep 4, 2024
a5ded63
update
shijianjian Sep 4, 2024
e8019f0
update
shijianjian Sep 4, 2024
85715a6
update
shijianjian Sep 5, 2024
05106d3
update
shijianjian Sep 5, 2024
2dcba2c
update
shijianjian Sep 5, 2024
83fa545
update
shijianjian Sep 6, 2024
023133a
update
shijianjian Sep 6, 2024
b69d535
update
shijianjian Sep 6, 2024
a9412fc
update
shijianjian Sep 6, 2024
5f2aae5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 6, 2024
928aec7
update
shijianjian Sep 6, 2024
3d7ac84
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 6, 2024
db87c53
update
shijianjian Sep 6, 2024
70ae085
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 6, 2024
82e3240
update
shijianjian Sep 6, 2024
22bb115
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 6, 2024
d639025
update
shijianjian Sep 6, 2024
a902739
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 6, 2024
292f410
doc update
shijianjian Sep 7, 2024
7e87160
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 7, 2024
db1cb53
post processor as in the original code
edgarriba Sep 7, 2024
82f2062
fix typing
edgarriba Sep 8, 2024
fea1f92
update
shijianjian Sep 8, 2024
6f4b5f3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 8, 2024
f4a8b12
update
shijianjian Sep 8, 2024
9f6a0c2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 8, 2024
33e454c
update
shijianjian Sep 8, 2024
4b70885
Merge branch 'feat/rtdetr_update' of https://github.com/shijianjian/k…
shijianjian Sep 8, 2024
2f3f531
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 8, 2024
6790ed1
update
shijianjian Sep 8, 2024
ee3903e
Merge branch 'feat/rtdetr_update' of https://github.com/shijianjian/k…
shijianjian Sep 8, 2024
fb07bb9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 8, 2024
e84ab3a
update
shijianjian Sep 8, 2024
5f49f28
update
shijianjian Sep 8, 2024
eb4bd54
update
shijianjian Sep 8, 2024
41d94ff
update
shijianjian Sep 8, 2024
0dd67b5
update
shijianjian Sep 8, 2024
03edce3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 8, 2024
11c6f7f
update
shijianjian Sep 8, 2024
4fe3acd
Merge branch 'feat/rtdetr_update' of https://github.com/shijianjian/k…
shijianjian Sep 8, 2024
365de3c
update
shijianjian Sep 8, 2024
224392e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 8, 2024
0a2da30
update
shijianjian Sep 9, 2024
ab5fb53
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions docs/source/models/rt_detr.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,28 @@
Real-Time Detection Transformer (RT-DETR)
=========================================

.. code-block:: python

from kornia.io import load_image
from kornia.models.detector.rtdetr import RTDETRDetectorBuilder

input_img = load_image(img_path)[None] # Load image to BCHW

# NOTE: available models: 'rtdetr_r18vd', 'rtdetr_r34vd', 'rtdetr_r50vd_m', 'rtdetr_r50vd', 'rtdetr_r101vd'.
# NOTE: recommended image scales: [480, 512, 544, 576, 608, 640, 640, 640, 672, 704, 736, 768, 800]
detector = RTDETRDetectorBuilder.build("rtdetr_r18vd", image_size=640)

# get the output boxes
boxes = detector(input_img)

# draw the bounding boxes on the images directly.
output = detector.draw(input_img, output_type="pil")
output[0].save("Kornia-RTDETR-output.png")

# convert the whole model to ONNX directly
RTDETRDetectorBuilder.to_onnx("RTDETR-640.onnx", model_name="rtdetr_r18vd", image_size=640)


.. card::
:link: https://arxiv.org/abs/2304.08069

Expand Down
2 changes: 1 addition & 1 deletion kornia/color/yuv.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def yuv_to_rgb(image: Tensor) -> Tensor:
if not isinstance(image, Tensor):
raise TypeError(f"Input type is not a Tensor. Got {type(image)}")

if len(image.shape) < 3 or image.shape[-3] != 3:
if image.dim() < 3 or image.shape[-3] != 3:
raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

y: Tensor = image[..., 0, :, :]
Expand Down
21 changes: 17 additions & 4 deletions kornia/contrib/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,16 @@

class ConvNormAct(nn.Sequential):
def __init__(
self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, act: str = "relu", groups: int = 1
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
act: str = "relu",
groups: int = 1,
conv_naming: str = "conv",
norm_naming: str = "norm",
act_naming: str = "act",
) -> None:
super().__init__()
if kernel_size % 2 == 0:
Expand All @@ -23,9 +32,13 @@ def __init__(
padding = 0
else:
padding = (kernel_size - 1) // 2
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, 1, groups, False)
self.norm = nn.BatchNorm2d(out_channels)
self.act = {"relu": nn.ReLU, "silu": nn.SiLU, "none": nn.Identity}[act](inplace=True)
conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, 1, groups, False)
norm = nn.BatchNorm2d(out_channels)
activation = {"relu": nn.ReLU, "silu": nn.SiLU, "none": nn.Identity}[act](inplace=True)

self.__setattr__(conv_naming, conv)
self.__setattr__(norm_naming, norm)
self.__setattr__(act_naming, activation)


# Lightly adapted from
Expand Down
36 changes: 30 additions & 6 deletions kornia/contrib/models/rt_detr/architecture/hybrid_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from __future__ import annotations

import copy
from typing import Optional

import torch
Expand Down Expand Up @@ -80,15 +81,16 @@ class AIFI(Module):
def __init__(self, embed_dim: int, num_heads: int, dim_feedforward: int, dropout: float = 0.0) -> None:
super().__init__()
self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout) # NOTE: batch_first = False
self.dropout1 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(embed_dim)

self.linear1 = nn.Linear(embed_dim, dim_feedforward)
self.act = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, embed_dim)

self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.act = nn.GELU()

def forward(self, x: Tensor) -> Tensor:
# using post-norm
Expand Down Expand Up @@ -149,6 +151,20 @@ def build_2d_sincos_pos_emb(
return pos_emb.unsqueeze(1) # (H * W, 1, C)


class TransformerEncoder(nn.Module):
def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
super().__init__()
self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)])
self.num_layers = num_layers

def forward(self, src: Tensor) -> Tensor: # NOTE: Missing src_mask: Tensor = None, pos_embed: Tensor = None
output = src
for layer in self.layers:
output = layer(output)

return output


class CCFM(Module):
def __init__(self, num_fmaps: int, hidden_dim: int, expansion: float = 1.0) -> None:
super().__init__()
Expand Down Expand Up @@ -192,12 +208,20 @@ def forward(self, fmaps: list[Tensor]) -> list[Tensor]:
class HybridEncoder(Module):
def __init__(self, in_channels: list[int], hidden_dim: int, dim_feedforward: int, expansion: float = 1.0) -> None:
super().__init__()
self.input_proj = nn.ModuleList([ConvNormAct(in_ch, hidden_dim, 1, act="none") for in_ch in in_channels])
self.aifi = AIFI(hidden_dim, 8, dim_feedforward)
self.input_proj = nn.ModuleList(
[
ConvNormAct( # To align the naming strategy for the official weights
in_ch, hidden_dim, 1, act="none", conv_naming="0", norm_naming="1", act_naming="2"
)
for in_ch in in_channels
]
)
encoder_layer = AIFI(hidden_dim, 8, dim_feedforward)
self.encoder = nn.Sequential(TransformerEncoder(encoder_layer, 1))
self.ccfm = CCFM(len(in_channels), hidden_dim, expansion)

def forward(self, fmaps: list[Tensor]) -> list[Tensor]:
projected_maps = [proj(fmap) for proj, fmap in zip(self.input_proj, fmaps)]
projected_maps[-1] = self.aifi(projected_maps[-1])
projected_maps[-1] = self.encoder(projected_maps[-1])
new_fmaps = self.ccfm(projected_maps)
return new_fmaps
76 changes: 51 additions & 25 deletions kornia/contrib/models/rt_detr/architecture/resnet_d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from __future__ import annotations

from collections import OrderedDict

from torch import nn

from kornia.contrib.models.common import ConvNormAct
Expand All @@ -15,7 +17,9 @@

def _make_shortcut(in_channels: int, out_channels: int, stride: int) -> Module:
return (
nn.Sequential(nn.AvgPool2d(2, 2), ConvNormAct(in_channels, out_channels, 1, act="none"))
nn.Sequential(
OrderedDict([("pool", nn.AvgPool2d(2, 2)), ("conv", ConvNormAct(in_channels, out_channels, 1, act="none"))])
)
if stride == 2
else ConvNormAct(in_channels, out_channels, 1, act="none")
)
Expand All @@ -28,14 +32,18 @@ def __init__(self, in_channels: int, out_channels: int, stride: int, shortcut: b
KORNIA_CHECK(stride in {1, 2})
super().__init__()
self.convs = nn.Sequential(
ConvNormAct(in_channels, out_channels, 3, stride=stride),
ConvNormAct(out_channels, out_channels, 3, act="none"),
OrderedDict(
[
("branch2a", ConvNormAct(in_channels, out_channels, 3, stride=stride)),
("branch2b", ConvNormAct(out_channels, out_channels, 3, act="none")),
]
)
)
self.shortcut = nn.Identity() if shortcut else _make_shortcut(in_channels, out_channels, stride)
self.short = nn.Identity() if shortcut else _make_shortcut(in_channels, out_channels, stride)
self.relu = nn.ReLU(inplace=True)

def forward(self, x: Tensor) -> Tensor:
return self.relu(self.convs(x) + self.shortcut(x))
return self.relu(self.convs(x) + self.short(x))


class BottleneckD(Module):
Expand All @@ -46,15 +54,25 @@ def __init__(self, in_channels: int, out_channels: int, stride: int, shortcut: b
super().__init__()
expanded_out_channels = out_channels * self.expansion
self.convs = nn.Sequential(
ConvNormAct(in_channels, out_channels, 1),
ConvNormAct(out_channels, out_channels, 3, stride=stride),
ConvNormAct(out_channels, expanded_out_channels, 1, act="none"),
OrderedDict(
[
("branch2a", ConvNormAct(in_channels, out_channels, 1)),
("branch2b", ConvNormAct(out_channels, out_channels, 3, stride=stride)),
("branch2c", ConvNormAct(out_channels, expanded_out_channels, 1, act="none")),
]
)
)
self.shortcut = nn.Identity() if shortcut else _make_shortcut(in_channels, expanded_out_channels, stride)
self.short = nn.Identity() if shortcut else _make_shortcut(in_channels, expanded_out_channels, stride)
self.relu = nn.ReLU(inplace=True)

def forward(self, x: Tensor) -> Tensor:
return self.relu(self.convs(x) + self.shortcut(x))
return self.relu(self.convs(x) + self.short(x))


class Block(nn.Sequential):
def __init__(self, blocks: Module) -> None:
super().__init__()
self.blocks = blocks


class ResNetD(Module):
Expand All @@ -63,35 +81,43 @@ def __init__(self, n_blocks: list[int], block: type[BasicBlockD | BottleneckD])
super().__init__()
in_channels = 64
self.conv1 = nn.Sequential(
ConvNormAct(3, in_channels // 2, 3, stride=2),
ConvNormAct(in_channels // 2, in_channels // 2, 3),
ConvNormAct(in_channels // 2, in_channels, 3),
nn.MaxPool2d(3, stride=2, padding=1),
OrderedDict(
[
("conv1_1", ConvNormAct(3, in_channels // 2, 3, stride=2)),
("conv1_2", ConvNormAct(in_channels // 2, in_channels // 2, 3)),
("conv1_3", ConvNormAct(in_channels // 2, in_channels, 3)),
("pool", nn.MaxPool2d(3, stride=2, padding=1)),
]
)
)

self.res2, in_channels = self.make_stage(in_channels, 64, 1, n_blocks[0], block)
self.res3, in_channels = self.make_stage(in_channels, 128, 2, n_blocks[1], block)
self.res4, in_channels = self.make_stage(in_channels, 256, 2, n_blocks[2], block)
self.res5, in_channels = self.make_stage(in_channels, 512, 2, n_blocks[3], block)
res2, in_channels = self.make_stage(in_channels, 64, 1, n_blocks[0], block)
res3, in_channels = self.make_stage(in_channels, 128, 2, n_blocks[1], block)
res4, in_channels = self.make_stage(in_channels, 256, 2, n_blocks[2], block)
res5, in_channels = self.make_stage(in_channels, 512, 2, n_blocks[3], block)

self.res_layers = nn.ModuleList([res2, res3, res4, res5])

self.out_channels = [ch * block.expansion for ch in [128, 256, 512]]

@staticmethod
def make_stage(
in_channels: int, out_channels: int, stride: int, n_blocks: int, block: type[BasicBlockD | BottleneckD]
) -> tuple[Module, int]:
stage = nn.Sequential(
block(in_channels, out_channels, stride, False),
*[block(out_channels * block.expansion, out_channels, 1, True) for _ in range(n_blocks - 1)],
stage = Block(
nn.Sequential(
block(in_channels, out_channels, stride, False),
*[block(out_channels * block.expansion, out_channels, 1, True) for _ in range(n_blocks - 1)],
)
)
return stage, out_channels * block.expansion

def forward(self, x: Tensor) -> list[Tensor]:
x = self.conv1(x)
res2 = self.res2(x)
res3 = self.res3(res2)
res4 = self.res4(res3)
res5 = self.res5(res4)
res2 = self.res_layers[0](x)
res3 = self.res_layers[1](res2)
res4 = self.res_layers[2](res3)
res5 = self.res_layers[3](res4)
return [res3, res4, res5]

@staticmethod
Expand Down
44 changes: 21 additions & 23 deletions kornia/contrib/models/rt_detr/architecture/rtdetr_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from __future__ import annotations

import copy
from typing import Optional

import torch
Expand Down Expand Up @@ -192,14 +193,10 @@ def forward(
return out


class TransformerDecoder:
def __init__(self, hidden_dim: int, decoder_layers: nn.ModuleList, num_layers: int, eval_idx: int = -1) -> None:
class TransformerDecoder(Module):
def __init__(self, hidden_dim: int, decoder_layer: nn.Module, num_layers: int, eval_idx: int = -1) -> None:
super().__init__()
self.layers = decoder_layers
# TODO: come back to this later
# self.layers = nn.ModuleList([
# copy.deepcopy(decoder_layer) for _ in range(num_layers)
# ])
self.layers = nn.ModuleList([copy.deepcopy(decoder_layer) for _ in range(num_layers)])
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx
Expand Down Expand Up @@ -272,13 +269,16 @@ def __init__(
num_decoder_layers: int,
num_heads: int = 8,
num_decoder_points: int = 4,
# num_levels: int = 3,
num_levels: int = 3,
dropout: float = 0.0,
num_denoising: int = 100,
) -> None:
super().__init__()
self.num_queries = num_queries
# TODO: verify this is correct
self.num_levels = len(in_channels)
if len(in_channels) > num_levels:
raise ValueError(f"`num_levels` cannot be greater than {len(in_channels)}. Got {num_levels}.")
self.num_levels = num_levels

# build the input projection layers
self.input_proj = nn.ModuleList()
Expand All @@ -288,25 +288,23 @@ def __init__(
# https://github.com/lyuwenyu/RT-DETR/blob/main/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L403-L410

# NOTE: need to be integrated with the TransformerDecoderLayer
self.decoder_layers = nn.ModuleList(
[
TransformerDecoderLayer(
embed_dim=hidden_dim,
num_heads=num_heads,
dropout=dropout,
num_levels=len(in_channels),
num_points=num_decoder_points,
)
for _ in range(num_decoder_layers)
]
decoder_layer = TransformerDecoderLayer(
embed_dim=hidden_dim,
num_heads=num_heads,
dropout=dropout,
num_levels=self.num_levels,
num_points=num_decoder_points,
)

self.decoder = TransformerDecoder(
hidden_dim=hidden_dim, decoder_layers=self.decoder_layers, num_layers=num_decoder_layers
hidden_dim=hidden_dim, decoder_layer=decoder_layer, num_layers=num_decoder_layers
)

# denoising part
self.denoising_class_embed = nn.Embedding(num_classes, hidden_dim) # not used in evaluation
if num_denoising > 0:
self.denoising_class_embed = nn.Embedding(
num_classes + 1, hidden_dim, padding_idx=num_classes
) # not used in evaluation

# decoder embedding
self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, num_layers=2)
Expand Down Expand Up @@ -334,7 +332,7 @@ def forward(self, feats: Tensor) -> tuple[Tensor, Tensor]:
)

# decoder
out_bboxes, out_logits = self.decoder.forward(
out_bboxes, out_logits = self.decoder(
target,
init_ref_points_unact,
memory,
Expand Down
Loading
Loading