SOLO代码解读[head]

目前在做实例分割相关的课题,这次介绍基于mmdetection的SOLO代码,主要是head部分的代码解读。

detectors/solo.py

这里里面定义了整体的结构,可见直接继承SingleStageInsDetector类,从config文件可以看出,SOLO的backbone为resne,neck为FPN,head为solo head。所以重点是在solo head。

1
2
3
4
5
6
7
8
9
10
11
12
@DETECTORS.register_module
class SOLO(SingleStageInsDetector):

    def __init__(self,
                 backbone,
                 neck,
                 bbox_head,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):
        super(SOLO, self).__init__(backbone, neck, bbox_head, train_cfg,
                                   test_cfg, pretrained)

anchor_heads/solo_head.py

这里就是SOLO的重点介绍部分,包括head的定义,以及loss的计算。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
    def _init_layers(self):
        norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
        self.ins_convs = nn.ModuleList()
        self.cate_convs = nn.ModuleList()
        # 在SOLO中self.stacked_convs=7,category和mas branch都是先经过7个conv block最后输出
        for i in range(self.stacked_convs):
            chn = self.in_channels + 2 if i == 0 else self.seg_feat_channels
            self.ins_convs.append(
                ConvModule(
                    chn,
                    self.seg_feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    bias=norm_cfg is None))

            chn = self.in_channels if i == 0 else self.seg_feat_channels
            self.cate_convs.append(
                ConvModule(
                    chn,
                    self.seg_feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    bias=norm_cfg is None))
        # 因为mask branch的输出通道数由grid决定,所有这里是个list,
        # 即不同的level对应不同的输出conv block
        self.solo_ins_list = nn.ModuleList()
        for seg_num_grid in self.seg_num_grids:
            self.solo_ins_list.append(
                nn.Conv2d(
                    self.seg_feat_channels, seg_num_grid**2, 1))
		# category branch的输出conv block的通道都是类别数-1
        self.solo_cate = nn.Conv2d(
            self.seg_feat_channels, self.cate_out_channels, 3, padding=1)

定义完网络结构后,下面开始前向计算

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
    def forward(self, feats, eval=False):
        # self.split_feats将5个level的feature map尺寸重新处理一下
        new_feats = self.split_feats(feats)
        # 取出每个level的尺寸
        featmap_sizes = [featmap.size()[-2:] for featmap in new_feats]
        upsampled_size = (featmap_sizes[0][0] * 2, featmap_sizes[0][1] * 2)
        # 这是前向计算的主体函数,self.forward_single每次处理一个level
        ins_pred, cate_pred = multi_apply(self.forward_single, new_feats, 
                                          list(range(len(self.seg_num_grids))),
                                          eval=eval, upsampled_size=upsampled_size)
        return ins_pred, cate_pred
    
    # 这里为什么将第一个和最后一个feature map的尺寸进行变换
    def split_feats(self, feats):
        return (F.interpolate(feats[0], scale_factor=0.5, mode='bilinear'), 
                feats[1], 
                feats[2], 
                feats[3], 
                F.interpolate(feats[4], size=feats[3].shape[-2:], mode='bilinear'))

下面的self.forward_single就是前向计算的主体部分,不复杂,注意的是在mask branch需要将x、y坐标信息拼接到对应level的feature map上。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
    def forward_single(self, x, idx, eval=False, upsampled_size=None):
        '''
        :param x: fpn每个level的feature map [N,C,H,W]
        :param idx:  [0,1,2,3,4]中的一个,用来指示当前的level级别
        :param eval: False
        :param upsampled_size: 最大feature map/C1 的h,w
        :return:
        '''
        # 老规矩,因为这里有两个branch
        ins_feat = x
        cate_feat = x

        # 这里先处理mask分支,1. 将x,y坐标拼接在feature map上,通道数+2,
        # 这里是将坐标信息concat到feature map上
        x_range = torch.linspace(-1, 1, ins_feat.shape[-1], device=ins_feat.device) # w --> x
        y_range = torch.linspace(-1, 1, ins_feat.shape[-2], device=ins_feat.device) # h --> y
        # 对x_range, y_range 进行扩充
        y, x = torch.meshgrid(y_range, x_range)
        # 将两个坐标扩成4维
        y = y.expand([ins_feat.shape[0], 1, -1, -1])
        x = x.expand([ins_feat.shape[0], 1, -1, -1])
        # 将坐标cancat到feature map的通道上
        coord_feat = torch.cat([x, y], 1)
        ins_feat = torch.cat([ins_feat, coord_feat], 1)
        # 将处理好的新的fearure map送进ins_convs
        for i, ins_layer in enumerate(self.ins_convs):
            ins_feat = ins_layer(ins_feat)
        # 这里将feature map上采样到2H*2W,应该是为了提高mask分割的精度
        ins_feat = F.interpolate(ins_feat, scale_factor=2, mode='bilinear')
        # 这里获得了mask分支的结果
        ins_pred = self.solo_ins_list[idx](ins_feat)
        
        # 这里开始处理category分支
        for i, cate_layer in enumerate(self.cate_convs):
            # 如果是第一个conv,则需要进行采样,因为category分支的尺寸是h=w=S
            if i == self.cate_down_pos:
                seg_num_grid = self.seg_num_grids[idx]
                cate_feat = F.interpolate(cate_feat, size=seg_num_grid, mode='bilinear')
            cate_feat = cate_layer(cate_feat)
        # 这里获得了category分支的结果
        cate_pred = self.solo_cate(cate_feat)

        # 如果使测试模式,
        # 将mask分支的结果取sigmoid,并且上采样到原始C1的尺寸
        # category分支的结果进行points_nms,这个待会再看
        if eval:
            ins_pred = F.interpolate(ins_pred.sigmoid(), size=upsampled_size, mode='bilinear')
            cate_pred = points_nms(cate_pred.sigmoid(), kernel=2).permute(0, 2, 3, 1)
        return ins_pred, cate_pred

此时模型的前向计算就结束了,框架很简洁,接下来的loss计算相对繁琐一些。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
    def loss(self,
             ins_preds,
             cate_preds,
             gt_bbox_list,
             gt_label_list,
             gt_mask_list,
             img_metas,
             cfg,
             gt_bboxes_ignore=None):

        # featmap_sizes里面包含了五个level的mask branch输出尺寸
        featmap_sizes = [featmap.size()[-2:] for featmap in
                         ins_preds]
		# 接下来又遇到multi_apply函数,说明self.solo_target_single每次处理batch中的一张图片
        # 其实从config文件也可以看出来,SOLO也用到了bbox信息
        # 接下来我们直接跳到self.solo_target_single函数
        ins_label_list, cate_label_list, ins_ind_label_list = multi_apply(
            self.solo_target_single,
            gt_bbox_list,
            gt_label_list,
            gt_mask_list,
            featmap_sizes=featmap_sizes)

self.solo_target_single作用是为category和mask branch分支分配gt label

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
    def solo_target_single(self,
                               gt_bboxes_raw,
                               gt_labels_raw,
                               gt_masks_raw,
                               featmap_sizes=None):
        '''
        :param gt_bboxes_raw:  [objects,4]
        :param gt_labels_raw:  [objects,1]
        :param gt_masks_raw:   [objects,H,W]
        :param featmap_sizes: []
        :return:
        '''
        device = gt_labels_raw[0].device
        # 这里获得单张图片上所有的object的area [object]
        gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) * (
                gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))

        ins_label_list = []
        cate_label_list = []
        ins_ind_label_list = []
        # 这里每次遍历一个level
        for (lower_bound, upper_bound), stride, featmap_size, num_grid \
                in zip(self.scale_ranges, self.strides, featmap_sizes, self.seg_num_grids):
            # self.scale_ranges:((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048))
            #     每个level的object的面积范围,这是缓解object重叠的一种方式,
            #     将不同尺度的object分配到不同的level取预测
            # self.strides :[8, 8, 16, 32, 32],对应与原图的采样比例 2**n
            #     这里与之前的split_feat也对应上了,第一个下采样到1/2,最后一个上采样到2倍
            # featmap_sizes : 每个level的尺寸
            # self.seg_num_grids : [40, 36, 24, 16, 12]

            # ins_label的尺寸与输出一致[S*S,H,W]
            ins_label = torch.zeros([num_grid ** 2, featmap_size[0], featmap_size[1]], dtype=torch.uint8, device=device)
            # cate_label的尺寸为[S,S]一个通道直接表示gt的label.
            cate_label = torch.zeros([num_grid, num_grid], dtype=torch.int64, device=device)
            # mask branch中S*S个通道的是否有匹配的gt
            ins_ind_label = torch.zeros([num_grid ** 2], dtype=torch.bool, device=device)

            # 这里返回的是area满足面积范围的索引,
            # bbox是为了将不同大小的object分配到不同level
            hit_indices = ((gt_areas >= lower_bound) & (gt_areas <= upper_bound)).nonzero().flatten()
            # 如果没有满足的object,直接跳掉下一个level
            if len(hit_indices) == 0:
                ins_label_list.append(ins_label)
                cate_label_list.append(cate_label)
                ins_ind_label_list.append(ins_ind_label)
                continue
            # 取出满足面积要求的object
            gt_bboxes = gt_bboxes_raw[hit_indices]
            gt_labels = gt_labels_raw[hit_indices]
            # 这里为什么要.cpu().numpy()?
            # 调试发现gt_masks使numpy,并不是tensor,为什么?
            gt_masks = gt_masks_raw[hit_indices.cpu().numpy(), ...]
            # 为什么要乘self.sigma,不太懂. 0.1*h,0.1*w
            half_ws = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * self.sigma
            half_hs = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1]) * self.sigma

            # 因为maks分支输出的h,w是原level的2倍
            output_stride = stride / 2

            # 遍历每个object,先为每个grid分配label,根据矩形框和mask的中心
            for seg_mask, gt_label, half_h, half_w in zip(gt_masks, gt_labels, half_hs, half_ws):
                #如果object的mask小于10个像素,则忽略
                if seg_mask.sum() < 10:
                   continue

                # mass center
                # featmap_sizes[0][0] * 4是原图的尺寸(貌似可能不等于,因为图像尺寸不是统一的)
                upsampled_size = (featmap_sizes[0][0] * 4, featmap_sizes[0][1] * 4)
                # 找出object的中心坐标
                center_h, center_w = ndimage.measurements.center_of_mass(seg_mask)
                # 将object的中心位置映射到grid上
                coord_w = int((center_w / upsampled_size[1]) // (1. / num_grid))
                coord_h = int((center_h / upsampled_size[0]) // (1. / num_grid))

                # left, top, right, down
                # 这里的四个值标定了一个矩形框,大小为0.1的原矩形框
                # 我这样理解,如果一个object的面积稍大,则一个object对应的grid是多个
                top_box = max(0, int(((center_h - half_h) / upsampled_size[0]) // (1. / num_grid)))
                down_box = min(num_grid - 1, int(((center_h + half_h) / upsampled_size[0]) // (1. / num_grid)))
                left_box = max(0, int(((center_w - half_w) / upsampled_size[1]) // (1. / num_grid)))
                right_box = min(num_grid - 1, int(((center_w + half_w) / upsampled_size[1]) // (1. / num_grid)))

                # 判断映射到grid里的中心是否在上述的矩形框里
                # 当top==down==left==right时说明当前object只占S*S中的一个grid
                top = max(top_box, coord_h-1)
                down = min(down_box, coord_h+1)
                left = max(coord_w-1, left_box)
                right = min(right_box, coord_w+1)
                # 将该位置的label记为gt_label
                cate_label[top:(down+1), left:(right+1)] = gt_label

                # 将seg_mask resize到当前level的mask branch的输出尺寸
                seg_mask = mmcv.imrescale(seg_mask, scale=1. / output_stride)
                seg_mask = torch.Tensor(seg_mask)
                # 为mask branch分配gt
                for i in range(top, down+1):
                    for j in range(left, right+1):
                        label = int(i * num_grid + j)
                        ins_label[label, :seg_mask.shape[0], :seg_mask.shape[1]] = seg_mask
                        ins_ind_label[label] = True
            ins_label_list.append(ins_label)
            cate_label_list.append(cate_label)
            ins_ind_label_list.append(ins_ind_label)
        return ins_label_list, cate_label_list, ins_ind_label_list
  • ins_label_list: 保存了整个batch的mask branch的gt label
  • cate_label_list: 整个batch的category branch的gt label
  • ins_ind_label_list: 整个batch标识S*S通道存在object的通道

我们再返回loss函数接着看

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
        # zip(*ins_label_list) [num_levels, N, S*S, 2H, 2W]
        #    返回值每一组为整个batch的同一个level的mask branch gt label.
        # zip(*ins_ind_label_list)
        #    同上
        # zip(ins_labels_level, ins_ind_labels_level)
        #    每个image的同level
        # 这里是将gt构造成pred的形式,便于后面的loss计算
        ins_labels = [torch.cat([ins_labels_level_img[ins_ind_labels_level_img, ...]
                                 for ins_labels_level_img, ins_ind_labels_level_img in
                                 zip(ins_labels_level, ins_ind_labels_level)], 0)
                      for ins_labels_level, ins_ind_labels_level in zip(zip(*ins_label_list), zip(*ins_ind_label_list))]

        ins_preds = [torch.cat([ins_preds_level_img[ins_ind_labels_level_img, ...]
                                for ins_preds_level_img, ins_ind_labels_level_img in
                                zip(ins_preds_level, ins_ind_labels_level)], 0)
                     for ins_preds_level, ins_ind_labels_level in zip(ins_preds, zip(*ins_ind_label_list))]

        # ins_ind_labels [S*S*batch_size] eg. 40*40*2=3200
        ins_ind_labels = [
            torch.cat([ins_ind_labels_level_img.flatten()
                       for ins_ind_labels_level_img in ins_ind_labels_level])
            for ins_ind_labels_level in zip(*ins_ind_label_list)
        ]
        # flatten_ins_ind_labels [7744]
        flatten_ins_ind_labels = torch.cat(ins_ind_labels)

        num_ins = flatten_ins_ind_labels.int().sum()

最后就是实际的loss计算

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
        # dice loss
        # mask部分使用dict loss
        loss_ins = []
        for input, target in zip(ins_preds, ins_labels):
            if input.size()[0] == 0:
                continue
            input = torch.sigmoid(input)
            loss_ins.append(dice_loss(input, target))
        loss_ins = torch.cat(loss_ins).mean()
        loss_ins = loss_ins * self.ins_loss_weight

        # cate
        # category部分使用FocalLoss
        cate_labels = [
            torch.cat([cate_labels_level_img.flatten()
                       for cate_labels_level_img in cate_labels_level])
            for cate_labels_level in zip(*cate_label_list)
        ]
        flatten_cate_labels = torch.cat(cate_labels)

        cate_preds = [
            cate_pred.permute(0, 2, 3, 1).reshape(-1, self.cate_out_channels)
            for cate_pred in cate_preds
        ]
        flatten_cate_preds = torch.cat(cate_preds)

        loss_cate = self.loss_cate(flatten_cate_preds, flatten_cate_labels, avg_factor=num_ins + 1)
        return dict(
            loss_ins=loss_ins,
            loss_cate=loss_cate)