AI编程助手
AI免费问答

【论文复现赛第六期-语义分割】CCNet

P粉084495128   2025-07-31 13:39   837浏览 原创
本文复现了CCNet语义分割模型,其核心为Criss-Cross Attention模块,通过循环操作让像素建立联系以获取丰富语义。使用PaddleSeg复现,采用ResNet101骨干网,在Cityscapes验证集上mIoU达80.95%,已合入PaddleSeg,还提供了训练、验证等流程及复现经验。

【论文复现赛第六期-语义分割】ccnet - php中文网

【论文复现赛第六期-语义分割】CCNet: Criss-Cross Attention for Semantic Segmentation

paper:CCNet: Criss-Cross Attention for Semantic Segmentation
github:https://github.com/speedinghzl/CCNet/tree/pure-python
复现地址: https://github.com/justld/CCNet_paddle

上下文信息在语义分割任务中非常重要,CCNet提出了criss-cross attention模块,同时引入循环操作,使得图片中每个像素都可以和其他像素建立联系,从而使得每个像素都可以获得丰富的语义信息。本项目使用PaddleSeg复现CCNet,在Cityscapes验证集上miou为80.95%,该算法已被PaddleSeg合入。

模型预测结果如下(图片来自cityscapes val):【论文复现赛第六期-语义分割】CCNet - php中文网        

一、Criss-Cross Attention Module

Criss-Cross Attention Module是本文的核心,该模块使得不同位置的像素建立联系,从而丰富语义信息。特征图经过该模块每个像素即可得到其横向和纵向所有像素的语义信息,故只需要2个Criss-Cross Attention Module,每个像素即可与其他所有像素建立联系,从而其丰富语义特征。 【论文复现赛第六期-语义分割】CCNet - php中文网 假设输入为X:[N, C, H, W],以纵向为例说明计算过程:

①通过1x1卷积,得到 Q_h:[N, Cr, H, W],K_h:[N, Cr, H, W], V_h:[N, C, H, W] (Q_w\K_w\V_w同理);

②维度变换,reshape得到 Q_h:[N * W,H,Cr],K_h:[N * W,Cr,H], V_h: [N * W,C,H] ;

③Q_h和K_h矩阵乘法,得到energy_h:[N * W, H, H];(源码中Enegy_H计算时加上了个维度为[N*W, H, H]的对角-inf矩阵,但是energy_w计算时没加,有点没搞懂。。)

④类似上面的流程,得到energy_h:[N * W, H, H]和energy_w:[N * H, W, W],reshape后维度变换得到energy_h:[N, H, W, H]和energy_w:[N, H, W, W],拼接得到energy:[N, H, W, H + W];

⑤在energy最后一个维度使用softmax,得到attention系数;

⑥将attention系数拆分为attn_h:[N, H, W, H]和attn_w:[N, H, W, W],维度变换后与V_h和V_w分别相乘得到输出out_h和out_w;

⑦将out_h+out_w,并乘上一个系数γ(可学习参数),再加上residual connection,得到最终输出。

其pytorch源码如下:

def INF(B,H,W):
     return -torch.diag(torch.tensor(float("inf")).cuda().repeat(H),0).unsqueeze(0).repeat(B*W,1,1) 
 
class CrissCrossAttention(nn.Module):
    """ Criss-Cross Attention Module"""
    def __init__(self, in_dim):
        super(CrissCrossAttention,self).__init__()
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.softmax = Softmax(dim=3)
        self.INF = INF
        self.gamma = nn.Parameter(torch.zeros(1)) 
 
    def forward(self, x):
        m_batchsize, _, height, width = x.size()
        proj_query = self.query_conv(x)
        proj_query_H = proj_query.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height).permute(0, 2, 1)
        proj_query_W = proj_query.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width).permute(0, 2, 1)
        proj_key = self.key_conv(x)
        proj_key_H = proj_key.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
        proj_key_W = proj_key.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)
        proj_value = self.value_conv(x)
        proj_value_H = proj_value.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
        proj_value_W = proj_value.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)
        energy_H = (torch.bmm(proj_query_H, proj_key_H)+self.INF(m_batchsize, height, width)).view(m_batchsize,width,height,height).permute(0,2,1,3)
        energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize,height,width,width)
        concate = self.softmax(torch.cat([energy_H, energy_W], 3))
 
        att_H = concate[:,:,:,0:height].permute(0,2,1,3).contiguous().view(m_batchsize*width,height,height)        #print(concate)
        #print(att_H) 
        att_W = concate[:,:,:,height:height+width].contiguous().view(m_batchsize*height,width,width)
        out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize,width,-1,height).permute(0,2,3,1)
        out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize,height,-1,width).permute(0,2,1,3)        #print(out_H.size(),out_W.size())
        return self.gamma*(out_H + out_W) + x

       

二、Recurrent Criss-Cross Attention (RCCA)

如下图左,单个Criss-Cross Attention Module可以使得某像素与其横向和纵向其他像素建立联系,当2个Criss-Cross Attention Module串行时,即可与其他所有像素建立联系。【论文复现赛第六期-语义分割】CCNet - php中文网        

三、网络结构

CCNet网络结构如下图所示,CNN表示特征提取器(backbone),Reduction减少特征图的通道数以减少后续计算量,Criss-Cross Attention用来建立不同位置像素间的联系从而丰富其语义信息,R表示Criss-Cross Attention Module的循环次数,注意多个Criss-Cross Attention Module共享参数。

【论文复现赛第六期-语义分割】CCNet - php中文网        

四、实验结果

在cityscapes验证集上,CCNet表现如下(每个配置训练3次,数据来自官方repo):

R cityscapes val miou link
1 77.31 & 77.91 & 76.89 77.91
2 79.74 & 79.22 & 78.40 79.74
2+OHEM 78.67 & 80.00 & 79.83 80.00

五、复现结果

本次复现的目标是CCNet-resnet101 R=2+OHEM在cityscapes验证集 mIOU= 80.0%,复现的miou为80.95%。详情见下表:

Model Backbone Resolution Training Iters mIoU mIoU (flip) mIoU (ms+flip) Links
CCNet ResNet101_OS8 769x769 60000 80.95% 81.23% 81.32% model|log|vdl

六、快速体验

运行以下cell,快速体验CCNet。

In [ ]
# step 1: unzip data%cd ~/PaddleSeg/
!mkdir data
!tar -xf ~/data/data64550/cityscapes.tar -C data/
%cd ~/
   
In [ ]
# step 2: 训练%cd ~/PaddleSeg
!python train.py --config configs/ccnet/ccnet_resnet101_os8_cityscapes_769x769_60k.yml \
    --do_eval --use_vdl --log_iter 100 --save_interval 4000 --save_dir output
   
In [ ]
# step 3: val%cd ~/PaddleSeg/
!python val.py \
       --config configs/ccnet/ccnet_resnet101_os8_cityscapes_769x769_60k.yml \
       --model_path output/best_model/model.pdparams       # --model_path /home/aistudio/converted_ddrnet23_imagenet.pdparams
   
In [ ]
# step 4: val flip%cd ~/PaddleSeg/
!python val.py \
       --config configs/ccnet/ccnet_resnet101_os8_cityscapes_769x769_60k.yml \
       --model_path output/best_model/model.pdparams \
       --aug_eval \
       --flip_horizontal
   
In [ ]
# step 5: val ms flip %cd ~/PaddleSeg/
!python val.py \
       --config configs/ccnet/ccnet_resnet101_os8_cityscapes_769x769_60k.yml \
       --model_path output/best_model/model.pdparams \
       --aug_eval \
       --scales 0.75 1.0 1.25 \
       --flip_horizontal
   
In [ ]
# step 6: 预测, 预测结果在~/PaddleSeg/output/result文件夹内%cd ~/PaddleSeg/
!python predict.py \
       --config configs/ccnet/ccnet_resnet101_os8_cityscapes_769x769_60k.yml \
       --model_path output/best_model/model.pdparams \
       --image_path data/cityscapes/leftImg8bit/val/frankfurt/frankfurt_000000_000294_leftImg8bit.png \
       --save_dir output/result
   
In [ ]
# step 7: export%cd ~/PaddleSeg
!python export.py \
       --config configs/ccnet/ccnet_resnet101_os8_cityscapes_769x769_60k.yml \
       --model_path output/best_model/model.pdparams \
       --save_dir output
   
In [ ]
# test tipc 1: prepare data%cd ~/PaddleSeg/
!bash test_tipc/prepare.sh ./test_tipc/configs/ccnet/train_infer_python.txt 'lite_train_lite_infer'
   
In [ ]
# test tipc 2: pip install%cd ~/PaddleSeg/test_tipc/
!pip install -r requirements.txt
   
In [ ]
# test tipc 3: 安装auto_log%cd ~/# !git clone https://github.com/LDOUBLEV/AutoLog            # 可以跳过git clone%cd AutoLog/
!pip3 install -r requirements.txt
!python3 setup.py bdist_wheel
!pip3 install ./dist/auto_log-1.2.0-py3-none-any.whl
   
In [ ]
# test tipc 4: test train inference%cd ~/PaddleSeg/
!bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ccnet/train_infer_python.txt 'lite_train_lite_infer'
   

七、复现经验

1、使用paddleseg套件复现论文,可以赢在起跑线;
2、为了防止组网错误,可以把官方权重转换为paddlepaddle,加载测试,确保模型组网无误;
3、模型组网完成后,一定要先测试模型导出是否有问题,确保无误再训练,否则test tipc不通过,会浪费很多时间。

声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn核实处理。