Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support different number of input channels to YOLOX backbone #1239

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions exps/default/yolox_nano.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def __init__(self):
super(Exp, self).__init__()
self.depth = 0.33
self.width = 0.25
self.backbone_in_channels = 3
self.input_size = (416, 416)
self.random_size = (10, 20)
self.mosaic_scale = (0.5, 1.5)
Expand All @@ -34,8 +35,12 @@ def init_yolo(M):
in_channels = [256, 512, 1024]
# The NANO model uses depthwise=True, which is the main difference.
backbone = YOLOPAFPN(
self.depth, self.width, in_channels=in_channels,
act=self.act, depthwise=True,
self.depth,
self.width,
backbone_in_channels=self.backbone_in_channels,
in_channels=in_channels,
act=self.act,
depthwise=True,
)
head = YOLOXHead(
self.num_classes, self.width, in_channels=in_channels,
Expand Down
14 changes: 12 additions & 2 deletions yolox/exp/yolox_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ def __init__(self):
super().__init__()

# ---------------- model config ---------------- #
# number of input channels, e.g. 3 for RGB input
self.backbone_in_channels = 3
# number of classes the model detects
self.num_classes = 80
# factor of model depth
Expand Down Expand Up @@ -118,8 +120,16 @@ def init_yolo(M):

if getattr(self, "model", None) is None:
in_channels = [256, 512, 1024]
backbone = YOLOPAFPN(self.depth, self.width, in_channels=in_channels, act=self.act)
head = YOLOXHead(self.num_classes, self.width, in_channels=in_channels, act=self.act)
backbone = YOLOPAFPN(
self.depth,
self.width,
backbone_in_channels=self.backbone_in_channels,
in_channels=in_channels,
act=self.act,
)
head = YOLOXHead(
self.num_classes, self.width, in_channels=in_channels, act=self.act
)
self.model = YOLOX(backbone, head)

self.model.apply(init_yolo)
Expand Down
61 changes: 44 additions & 17 deletions yolox/models/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@


def create_yolox_model(
name: str, pretrained: bool = True, num_classes: int = 80, device=None
name: str,
pretrained: bool = True,
backbone_in_channels: int = 3,
num_classes: int = 80,
device=None,
) -> nn.Module:
"""creates and loads a YOLOX model

Expand All @@ -48,11 +52,20 @@ def create_yolox_model(
device = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device)

assert name in _CKPT_FULL_PATH, f"user should use one of value in {_CKPT_FULL_PATH.keys()}"
assert (
name in _CKPT_FULL_PATH
), f"user should use one of value in {_CKPT_FULL_PATH.keys()}"
exp: Exp = get_exp(exp_name=name)
exp.backbone_in_channels = backbone_in_channels
exp.num_classes = num_classes
yolox_model = exp.get_model()
if pretrained and num_classes == 80:
if pretrained:
assert (
backbone_in_channels == 3
), f"There are no pretrained weights for the model whose number of input channels are {backbone_in_channels}"
assert (
num_classes == 80
), f"There are no pretrained weights for the model whose number of output classes are {num_classes}"
weights_url = _CKPT_FULL_PATH[name]
ckpt = load_state_dict_from_url(weights_url, map_location="cpu")
if "model" in ckpt:
Expand All @@ -63,29 +76,43 @@ def create_yolox_model(
return yolox_model


def yolox_nano(pretrained=True, num_classes=80, device=None):
return create_yolox_model("yolox-nano", pretrained, num_classes, device)
def yolox_nano(pretrained=True, backbone_in_channels=3, num_classes=80, device=None):
    """Build the "yolox-nano" model through ``create_yolox_model``.

    Every argument is forwarded unchanged; this wrapper only fixes the
    model name.

    Args:
        pretrained: forwarded as ``pretrained``.
        backbone_in_channels: forwarded as ``backbone_in_channels``.
        num_classes: forwarded as ``num_classes``.
        device: forwarded as ``device``.

    Returns:
        The value returned by ``create_yolox_model`` for "yolox-nano".
    """
    # Keywords make the mapping explicit instead of relying on argument order.
    model = create_yolox_model(
        name="yolox-nano",
        pretrained=pretrained,
        backbone_in_channels=backbone_in_channels,
        num_classes=num_classes,
        device=device,
    )
    return model


def yolox_tiny(pretrained=True, num_classes=80, device=None):
return create_yolox_model("yolox-tiny", pretrained, num_classes, device)
def yolox_tiny(pretrained=True, backbone_in_channels=3, num_classes=80, device=None):
    """Build the "yolox-tiny" model through ``create_yolox_model``.

    Every argument is forwarded unchanged; this wrapper only fixes the
    model name.

    Args:
        pretrained: forwarded as ``pretrained``.
        backbone_in_channels: forwarded as ``backbone_in_channels``.
        num_classes: forwarded as ``num_classes``.
        device: forwarded as ``device``.

    Returns:
        The value returned by ``create_yolox_model`` for "yolox-tiny".
    """
    # Keywords make the mapping explicit instead of relying on argument order.
    model = create_yolox_model(
        name="yolox-tiny",
        pretrained=pretrained,
        backbone_in_channels=backbone_in_channels,
        num_classes=num_classes,
        device=device,
    )
    return model


def yolox_s(pretrained=True, num_classes=80, device=None):
return create_yolox_model("yolox-s", pretrained, num_classes, device)
def yolox_s(pretrained=True, backbone_in_channels=3, num_classes=80, device=None):
    """Build the "yolox-s" model through ``create_yolox_model``.

    Every argument is forwarded unchanged; this wrapper only fixes the
    model name.

    Args:
        pretrained: forwarded as ``pretrained``.
        backbone_in_channels: forwarded as ``backbone_in_channels``.
        num_classes: forwarded as ``num_classes``.
        device: forwarded as ``device``.

    Returns:
        The value returned by ``create_yolox_model`` for "yolox-s".
    """
    # Keywords make the mapping explicit instead of relying on argument order.
    model = create_yolox_model(
        name="yolox-s",
        pretrained=pretrained,
        backbone_in_channels=backbone_in_channels,
        num_classes=num_classes,
        device=device,
    )
    return model


def yolox_m(pretrained=True, num_classes=80, device=None):
return create_yolox_model("yolox-m", pretrained, num_classes, device)
def yolox_m(pretrained=True, backbone_in_channels=3, num_classes=80, device=None):
    """Build the "yolox-m" model through ``create_yolox_model``.

    Every argument is forwarded unchanged; this wrapper only fixes the
    model name.

    Args:
        pretrained: forwarded as ``pretrained``.
        backbone_in_channels: forwarded as ``backbone_in_channels``.
        num_classes: forwarded as ``num_classes``.
        device: forwarded as ``device``.

    Returns:
        The value returned by ``create_yolox_model`` for "yolox-m".
    """
    # Keywords make the mapping explicit instead of relying on argument order.
    model = create_yolox_model(
        name="yolox-m",
        pretrained=pretrained,
        backbone_in_channels=backbone_in_channels,
        num_classes=num_classes,
        device=device,
    )
    return model


def yolox_l(pretrained=True, num_classes=80, device=None):
return create_yolox_model("yolox-l", pretrained, num_classes, device)
def yolox_l(pretrained=True, backbone_in_channels=3, num_classes=80, device=None):
    """Build the "yolox-l" model through ``create_yolox_model``.

    Every argument is forwarded unchanged; this wrapper only fixes the
    model name.

    Args:
        pretrained: forwarded as ``pretrained``.
        backbone_in_channels: forwarded as ``backbone_in_channels``.
        num_classes: forwarded as ``num_classes``.
        device: forwarded as ``device``.

    Returns:
        The value returned by ``create_yolox_model`` for "yolox-l".
    """
    # Keywords make the mapping explicit instead of relying on argument order.
    model = create_yolox_model(
        name="yolox-l",
        pretrained=pretrained,
        backbone_in_channels=backbone_in_channels,
        num_classes=num_classes,
        device=device,
    )
    return model


def yolox_x(pretrained=True, num_classes=80, device=None):
return create_yolox_model("yolox-x", pretrained, num_classes, device)
def yolox_x(pretrained=True, backbone_in_channels=3, num_classes=80, device=None):
    """Build the "yolox-x" model through ``create_yolox_model``.

    Every argument is forwarded unchanged; this wrapper only fixes the
    model name.

    Args:
        pretrained: forwarded as ``pretrained``.
        backbone_in_channels: forwarded as ``backbone_in_channels``.
        num_classes: forwarded as ``num_classes``.
        device: forwarded as ``device``.

    Returns:
        The value returned by ``create_yolox_model`` for "yolox-x".
    """
    # Keywords make the mapping explicit instead of relying on argument order.
    model = create_yolox_model(
        name="yolox-x",
        pretrained=pretrained,
        backbone_in_channels=backbone_in_channels,
        num_classes=num_classes,
        device=device,
    )
    return model


def yolov3(pretrained=True, num_classes=80, device=None):
return create_yolox_model("yolox-tiny", pretrained, num_classes, device)
def yolov3(pretrained=True, backbone_in_channels=3, num_classes=80, device=None):
    """Build a model through ``create_yolox_model`` using the "yolox-tiny" name.

    Every argument is forwarded unchanged; this wrapper only fixes the
    model name.

    Args:
        pretrained: forwarded as ``pretrained``.
        backbone_in_channels: forwarded as ``backbone_in_channels``.
        num_classes: forwarded as ``num_classes``.
        device: forwarded as ``device``.

    Returns:
        The value returned by ``create_yolox_model`` for "yolox-tiny".
    """
    # NOTE(review): despite the function name, the model requested here is
    # "yolox-tiny". This looks like a copy-paste slip; confirm whether a
    # dedicated YOLOv3 entry exists in the checkpoint table before changing it.
    model = create_yolox_model(
        name="yolox-tiny",
        pretrained=pretrained,
        backbone_in_channels=backbone_in_channels,
        num_classes=num_classes,
        device=device,
    )
    return model
3 changes: 2 additions & 1 deletion yolox/models/darknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def __init__(
self,
dep_mul,
wid_mul,
in_channels=3,
out_features=("dark3", "dark4", "dark5"),
depthwise=False,
act="silu",
Expand All @@ -112,7 +113,7 @@ def __init__(
base_depth = max(round(dep_mul * 3), 1) # 3

# stem
self.stem = Focus(3, base_channels, ksize=3, act=act)
self.stem = Focus(in_channels, base_channels, ksize=3, act=act)

# dark2
self.dark2 = nn.Sequential(
Expand Down
3 changes: 2 additions & 1 deletion yolox/models/yolo_fpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ class YOLOFPN(nn.Module):
def __init__(
self,
depth=53,
backbone_in_channels=3,
in_features=["dark3", "dark4", "dark5"],
):
super().__init__()

self.backbone = Darknet(depth)
self.backbone = Darknet(depth, in_channels=backbone_in_channels)
self.in_features = in_features

# out 1
Expand Down
5 changes: 4 additions & 1 deletion yolox/models/yolo_pafpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@ def __init__(
self,
depth=1.0,
width=1.0,
backbone_in_channels=3,
in_features=("dark3", "dark4", "dark5"),
in_channels=[256, 512, 1024],
depthwise=False,
act="silu",
):
super().__init__()
self.backbone = CSPDarknet(depth, width, depthwise=depthwise, act=act)
self.backbone = CSPDarknet(
depth, width, in_channels=backbone_in_channels, depthwise=depthwise, act=act
)
self.in_features = in_features
self.in_channels = in_channels
Conv = DWConv if depthwise else BaseConv
Expand Down