-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
update baseline model of semantic segmentation
- Loading branch information
Showing
71 changed files
with
2,059 additions
and
0 deletions.
There are no files selected for viewing
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Semantic Segmentation with Distribution Alignment | ||
|
||
## Setup | ||
1. Install mmcv-full | ||
2. Clone mmsegmentation via git and install it | ||
3. Download the pre-trained model checkpoints from mmsegmentation | ||
4. Prepare the ADE20K dataset | ||
|
||
## Training | ||
Please edit `exps/setup_env.sh` so that it points to your own paths. | ||
|
||
1. FCN model | ||
``` | ||
source exps/setup_env.sh | ||
bash exps/ade20k_fcn_disalign/disalign_fcn_r50-d8_512x512_160k_ade20k.sh | ||
``` | ||
|
||
## Model Zoo | ||
|
||
- Baseline Results | ||
|
||
```bash | ||
# ResNet-50 Backbone | ||
bash exps/fcn_r50-d8_512x512_160k_ade20k.sh | ||
# ResNet-101 Backbone | ||
bash exps/fcn_r101-d8_512x512_160k_ade20k.sh | ||
# ResNeSt-101 Backbone | ||
bash exps/fcn_s101-d8_512x512_160k_ade20k.sh | ||
``` | ||
|
||
| Method |AugTest| mIoU | mAcc | mHeadIoU | mBodyIoU | mTailIoU | mHeadAcc | mBodyAcc | mTailAcc |Log| | ||
|---------------|-------|------|-------|----------|----------|----------|----------|----------|----------|---| | ||
|FCN-R50-D8-160K| False | 36.1 | 45.41 | 62.53 | 38.12 | 27.58 | 76.88 | 48.82 | 34.51 || | ||
|FCN-R50-D8-160K| True |38.08 | 46.27 | 64.64 | 39.95 | 29.62 | 78.64 | 49.3 | 35.41 || | ||
|FCN-R101-D8-160K| False | 39.91 | 49.62 | 65.28 | 41.96 | 31.65 | 79.14 | 52.58 | 39.58 || | ||
|FCN-R101-D8-160K| True | 41.4 | 50.21 | 66.97 | 43.32 | 33.17 | 80.61 | 52.88 | 40.15 || | ||
|FCN-S101-D8-160K| False | 45.62 | 57.76 | 66.6 | 47.54 | 38.63 | 78.77 | 62.14 | 48.94 || | ||
|FCN-S101-D8-160K| True | 46.16 | 57.34 | 67.56 | 47.99 | 39.12 | 79.37 | 61.73 | 48.24 || | ||
|
||
- DisAlign Results | ||
```bash | ||
# TODO: to be updated | ||
``` |
32 changes: 32 additions & 0 deletions
32
semantic_seg/configs/ade20k_fcn_disalign/disalign_fcn_r50-d8_512x512_160k_ade20k.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import os | ||
from mmcv import Config | ||
|
||
# Root of the local mmsegmentation checkout; read from the MMSEG_HOME
# environment variable (a KeyError is raised here if it is not set).
mmseg_home = os.environ["MMSEG_HOME"]

# Start from the upstream FCN-R50-D8 512x512 160k ADE20K config that ships
# with mmsegmentation, then override individual fields below.
cfg = Config.fromfile(os.path.join(
    mmseg_home,
    'configs/fcn/fcn_r50-d8_512x512_160k_ade20k.py'
    )
)
# runtime settings
# Iteration-based schedule of 4000 iterations, with a checkpoint and an
# evaluation every 800 iterations.
cfg.runner = dict(type='IterBasedRunner', max_iters=4000)
cfg.checkpoint_config = dict(by_epoch=False, interval=800)
cfg.evaluation = dict(interval=800, metric='mIoU', pre_eval=True)

# model
# Replace the decode head type and its loss with the DisAlign variants.
cfg.model.decode_head.type="DisAlignFCNHead"
# NOTE(review): class_weight is given as a path to objectInfo150.txt rather
# than a list of weights -- presumably GRWCrossEntropyLoss derives the
# per-class weights from that file; confirm against the loss implementation.
cfg.model.decode_head.loss_decode=dict(
    type="GRWCrossEntropyLoss",
    use_sigmoid=False,
    loss_weight=1.0,
    class_weight="./data/ade/ADEChallengeData2016/objectInfo150.txt",
    exp_scale=0.2
)

# A loss weight of 0.0 removes the auxiliary head's contribution to the
# total loss.
cfg.model.auxiliary_head.loss_decode.loss_weight=0.0

# dataset
# Use the long-tail aware ADE20K dataset type for every split.
cfg.data.val.type="ADE20KLTDataset"
cfg.data.test.type="ADE20KLTDataset"
cfg.data.train.type="ADE20KLTDataset"
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
from mmseg.datasets.builder import DATASETS | ||
from mmseg.datasets import ADE20KDataset | ||
import numpy as np | ||
from mmseg.core import mean_iou | ||
from mmcv.utils import print_log | ||
from functools import reduce | ||
from prettytable import PrettyTable | ||
from mmseg.core import eval_metrics, intersect_and_union, pre_eval_to_metrics | ||
import mmcv | ||
from collections import OrderedDict | ||
|
||
|
||
@DATASETS.register_module()
class ADE20KLTDataset(ADE20KDataset):
    """ADE20K dataset whose evaluation also reports long-tail group metrics.

    In addition to the usual summary and per-class tables, ``evaluate``
    averages the per-class IoU and Acc over three contiguous class-index
    groups (head, body, tail) and adds them to the summary as
    ``mHeadIoU``, ``mBodyIoU``, ``mTailIoU``, ``mHeadAcc``, ``mBodyAcc``
    and ``mTailAcc``.
    """

    # Class-index boundaries of the groups:
    # head = [0, HEAD_END), body = [HEAD_END, BODY_END), tail = [BODY_END, end).
    HEAD_END = 20
    BODY_END = 75

    def evaluate(self,
                 results,
                 metric='mIoU',
                 logger=None,
                 gt_seg_maps=None,
                 **kwargs):
        """Evaluate the dataset.

        Args:
            results (list[tuple[torch.Tensor]] | list[str]): per image
                pre_eval results or predict segmentation map for computing
                evaluation metric.
            metric (str | list[str]): Metrics to be evaluated. 'mIoU',
                'mDice' and 'mFscore' are supported.
            logger (logging.Logger | None | str): Logger used for printing
                related information during evaluation. Default: None.
            gt_seg_maps (generator[ndarray]): Custom gt seg maps as input,
                used in ConcatDataset.

        Returns:
            dict[str, float]: Summary metrics (including the head/body/tail
            group means when the corresponding per-class metric was
            computed) and per-class metrics, all scaled to [0, 1].

        Raises:
            KeyError: If ``metric`` contains an unsupported metric name.
        """
        if isinstance(metric, str):
            metric = [metric]
        allowed_metrics = ['mIoU', 'mDice', 'mFscore']
        if not set(metric).issubset(set(allowed_metrics)):
            raise KeyError('metric {} is not supported'.format(metric))

        eval_results = {}
        # test a list of files
        if mmcv.is_list_of(results, np.ndarray) or mmcv.is_list_of(
                results, str):
            if gt_seg_maps is None:
                gt_seg_maps = self.get_gt_seg_maps()
            num_classes = len(self.CLASSES)
            ret_metrics = eval_metrics(
                results,
                gt_seg_maps,
                num_classes,
                self.ignore_index,
                metric,
                label_map=self.label_map,
                reduce_zero_label=self.reduce_zero_label)
        # test a list of pre_eval_results
        else:
            ret_metrics = pre_eval_to_metrics(results, metric)

        # Because dataset.CLASSES is required for per-eval.
        if self.CLASSES is None:
            # `num_classes` is not bound on the pre-eval path, so infer the
            # class count from any per-class metric array ('aAcc' is the
            # only entry that is not per-class).
            per_class_values = next(
                value for key, value in ret_metrics.items() if key != 'aAcc')
            class_names = tuple(range(len(per_class_values)))
        else:
            class_names = self.CLASSES

        # summary table
        ret_metrics_summary = OrderedDict({
            ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
            for ret_metric, ret_metric_value in ret_metrics.items()
        })

        # each class table
        ret_metrics.pop('aAcc', None)
        ret_metrics_class = OrderedDict({
            ret_metric: np.round(ret_metric_value * 100, 2)
            for ret_metric, ret_metric_value in ret_metrics.items()
        })
        ret_metrics_class.update({'Class': class_names})
        ret_metrics_class.move_to_end('Class', last=False)

        # for logger
        class_table_data = PrettyTable()
        for key, val in ret_metrics_class.items():
            class_table_data.add_column(key, val)

        summary_table_data = PrettyTable()

        # Long-tail metrics: mean of the (already percent-scaled) per-class
        # values over each class group. 'IoU' and 'Acc' are only present
        # for some metric choices (e.g. not for 'mFscore'), so skip the
        # ones that were not computed instead of raising KeyError.
        group_slices = (
            ('Head', slice(None, self.HEAD_END)),
            ('Body', slice(self.HEAD_END, self.BODY_END)),
            ('Tail', slice(self.BODY_END, None)),
        )
        for metric_name in ('IoU', 'Acc'):
            if metric_name not in ret_metrics_class:
                continue
            per_class = ret_metrics_class[metric_name]
            ret_metrics_summary.update({
                group + metric_name: np.round(
                    np.nanmean(per_class[span]), 2)
                for group, span in group_slices
            })

        for key, val in ret_metrics_summary.items():
            if key == 'aAcc':
                summary_table_data.add_column(key, [val])
            else:
                summary_table_data.add_column('m' + key, [val])

        print_log('per class results:', logger)
        print_log('\n' + class_table_data.get_string(), logger=logger)
        print_log('Summary:', logger)
        print_log('\n' + summary_table_data.get_string(), logger=logger)

        # each metric dict
        for key, value in ret_metrics_summary.items():
            if key == 'aAcc':
                eval_results[key] = value / 100.0
            else:
                eval_results['m' + key] = value / 100.0

        ret_metrics_class.pop('Class', None)
        for key, value in ret_metrics_class.items():
            eval_results.update({
                key + '.' + str(name): value[idx] / 100.0
                for idx, name in enumerate(class_names)
            })

        return eval_results
101 changes: 101 additions & 0 deletions
101
semantic_seg/disalign/models/decode_heads/fcn_head_disalign.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import torch | ||
import torch.nn as nn | ||
from mmcv.cnn import ConvModule | ||
|
||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead | ||
from mmseg.models.builder import HEADS | ||
|
||
|
||
@HEADS.register_module()
class DisAlignFCNHead(BaseDecodeHead):
    """FCN decode head with a DisAlign-style calibration of the logits.

    The convolutional trunk follows the FCN head design
    (`FCNNet <https://arxiv.org/abs/1411.4038>`_). On top of the class
    logits, a learnable per-class scale and bias are applied and blended
    with the raw logits using a per-pixel confidence map.

    Args:
        num_convs (int): Number of convs in the head. Default: 2.
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        concat_input (bool): Whether concat the input and output of convs
            before classification layer.
        dilation (int): The dilation rate for convs in the head. Default: 1.
    """

    def __init__(self,
                 num_convs=2,
                 kernel_size=3,
                 concat_input=True,
                 dilation=1,
                 **kwargs):
        assert num_convs >= 0 and dilation > 0 and isinstance(dilation, int)
        self.num_convs = num_convs
        self.concat_input = concat_input
        self.kernel_size = kernel_size
        super(DisAlignFCNHead, self).__init__(**kwargs)
        if num_convs == 0:
            assert self.in_channels == self.channels

        # Layer configuration shared by every ConvModule in this head.
        layer_cfg = dict(
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        trunk_padding = (kernel_size // 2) * dilation

        def make_trunk_conv(in_channels):
            """Build one dilated trunk conv mapping to ``self.channels``."""
            return ConvModule(
                in_channels,
                self.channels,
                kernel_size=kernel_size,
                padding=trunk_padding,
                dilation=dilation,
                **layer_cfg)

        # The first conv is always constructed (even when it ends up unused
        # for num_convs == 0), followed by num_convs - 1 same-width convs.
        trunk = [make_trunk_conv(self.in_channels)]
        trunk.extend(
            make_trunk_conv(self.channels) for _ in range(num_convs - 1))
        if num_convs == 0:
            self.convs = nn.Identity()
        else:
            self.convs = nn.Sequential(*trunk)

        if self.concat_input:
            # Fuses the head input with the trunk output.
            self.conv_cat = ConvModule(
                self.in_channels + self.channels,
                self.channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                **layer_cfg)

        # Magnitude (scale) and margin (bias) of DisAlign, one value per
        # class, broadcast over batch and spatial dimensions.
        self.logit_scale = nn.Parameter(
            torch.ones(1, self.num_classes, 1, 1))
        self.logit_bias = nn.Parameter(
            torch.zeros(1, self.num_classes, 1, 1))
        # Confidence function: a single-channel 1x1 conv module.
        # NOTE(review): this layer reuses the head's norm_cfg/act_cfg, so
        # any activation configured there runs before the sigmoid in
        # forward() -- confirm that is intended.
        self.confidence_layer = ConvModule(
            self.channels, 1, kernel_size=1, **layer_cfg)

    def forward(self, inputs):
        """Forward function."""
        feats = self._transform_inputs(inputs)
        hidden = self.convs(feats)
        if self.concat_input:
            hidden = self.conv_cat(torch.cat([feats, hidden], dim=1))

        # Confidence in (0, 1), computed from the same features that feed
        # the classifier.
        confidence = self.confidence_layer(hidden).sigmoid()
        logits = self.cls_seg(hidden)

        # Blend calibrated and raw logits, weighted by the confidence.
        calibrated = logits * self.logit_scale + self.logit_bias
        return confidence * calibrated + (1 - confidence) * logits
Oops, something went wrong.