Inf-Net论文精读笔记

论文名称:Inf-Net: Automatic COVID-19 Lung Infection Segmentation From CT Images

论文链接:https://arxiv.org/pdf/2004.14133

github链接:https://github.com/DengPingFan/Inf-Net

文章主要成果简述

  • 作者从临床医生根据CT图像检测肺部感染区域的方法—先粗略的确定感染区域位置,再根据局部外观提取其精细的轮廓,中得到灵感,从而提出了一个用来实现新冠肺炎肺部CT图像感染区域分割的深度学习网络(Inf-Net)。该网络主要是通过显式的**边缘注意力模块(EA)从低层特征中提取感染区域边缘信息,然后通过并行部分解码器(PPD)聚合高层特征生成全局映射图提取出部分区域信息,最后通过一组级联隐式的反向注意力模块(RA)**引导加强边缘和区域信息之间的联系。后文会详细介绍EA,PPD以及RA具体的作用和实现细节
  • 由于新冠肺炎的CT图像感染区域标记数据集数量极为稀少有限且难以进行标记,因此作者采用一种基于随机采样策略的**半监督学习(SSL)**算法,使用了大量无标记的新冠肺炎CT图像对训练样本进行扩充后,对网络进行有效的训练,极大的提升了模型的精确性和有效性。
  • 最后,作者也设计大量的实验证明网络中的每一部分的有效性,同时也设计了对照实验充分验证Inf-Net网络在新冠肺炎CT图像进行区域分割的准确性和有效性。

背景介绍

2020年以来新冠肺炎在全球蔓延,截至本文2021年10月31日,根据WHO最新数据显示全球累计已有 245,373,039人确诊新冠肺炎,4,979,421死于新冠肺炎。一直以来RT-PCR(可以简单理解为核酸检测)都被认为是检测新冠肺炎最有效的检测手段,但也存在较高的误检漏检率。作为核酸检测的有效补充,肺部CT影像也是可以用于对新冠肺炎的早期筛查的。但是由于新冠肺炎感染区域的高度变化且感染前后的CT图像差异对比度极小,感染区域的划分与诊断更多依靠的是医护人员主观上丰富的临床经验,所以肺部CT图像更多用于新冠肺炎患者后期的辅助治疗。除此之外,正式因为上述原因,新冠肺炎CT图像感染区域的标记是十分耗费时间且难以进行的,所以现存的标记数据集及其稀少有限。

核心架构解读

整体网络架构概述

image-20211031145834721

​ 主干网络可选,文中主要采用Res2Net。首先CT图像被送入两层卷积层来提取得到高分辨率低层次特征f2,同时这里还加了一个EA模块用来显示提升低层特征边界信息的表达。然后f2被送入三层卷积层来提取高层特征,与此同时这三层卷积层提取的特征被并行的送入PPD进而生成粗略定位感染区域的全局映射图Sg,再然后在Sg和f2共同指导下通过级联的RA模块逐步聚合边界和区域信息的联系,最后通过Sigmoid生成最终的肺部感染区域预测图Sp。

边缘注意力模块 EA–辅助监督从低层特征学习提取准确的边缘信息

文中的EA模块具体来讲就是将富含边缘信息的低层特征f2作为输入通过一个卷积核的卷积层得到边缘映射图Se,然后用真实标注的数据图GT生成的边缘图Ge来指导Se的生成。简单说就是用Ge和Se的BCE(二值交叉熵损失)加入到最终的损失函数中指导模型的训练方向从而让模型准确有效的提取到边缘信息

并行部分解码器PPD–用来生成粗略定位无结构信息的感染区域全局映射图

image-20211031145936990

关于PPD作者在文中并没有具体介绍实现与应用细节,只是说这种方式受nnUnet启发(低层特征耗费了大量的计算资源但是对性能的贡献少)以并行的方式聚合高层特征,生成粗糙的全局映射图作为RA模块的指导,所以这里我就结合文中提供的PPD图片以及Inf-Net代码简要叙述PPD具体的操作流程。f5先是进行了一次上采样和3x3的卷积(图片应该是标反了)直接和f4进行了一次点乘,然后和f5本身进行一次上采样和卷积的结果进行结合得到临时T1。然后f5通过两次上采样和一次3x3的卷积同f4进行了一次上采样和卷积的结果以及f3进行点乘得到T2。上面得到的T1进行卷积,上采样,卷积后和T2进行结合,最后进行两次3x3卷积以及一次1x1卷积输出粗略定位感染区域的全局映射图

反向注意力模块RA

image-20211031150038024

在代码中与文中的公式是一样的,但是和图片感觉可能略有不同

RA的输入有三个:本层提取到的高层次特征fi,前两层提取到的低层特征f2,以及反向回传的感染区域映射指导Si+1

首先对回传的指导映射图Si+1进行适当的采样处理,对其使用sigmoid函数后进行反转操作(0变1,1变0)然后再扩展为64通道得到RA模块的权值Ai,然后用Ai和高层特征fi做点乘再和经过下采样的低层特征做拼接后经三层卷积得到Ri,最后Ri同Si+1进行叠加得到Si

这里RA采用的是逐步擦除的策略,最终将不准确和粗糙的预测区域细化为准确完整的预测图

损失函数设计

不太会用markdown写数学公式,后续这里会进行适当的补充和修改

image-20211102133809074

用到损失函数主要有两种,一个是前面EA模块提到的二元交叉熵损失(BCELoss),另一个是在分割任务中常用的交并比损失(IoU)。用来监督边缘信息提取的Ledge就是直接使用了BCELoss。而在后续区域分割阶段的损失使用的BCE和IoU与常规的略有不同,主要是在困难像素样本点处增加了权值,Lseg由加权的BCE加上加权的IoU组成。最终总的损失函数就由EA模块中的BCE(Se,Ge),PPD模块中的加权BCE(Sg,Gs)+加权IoU(Sg,Gs),以及RA模块中每一部分上采样到和GT图相同大小后的BCE(Si,Gs)+IoU(Si,Gs)组成

半监督学习算法

image-20211102125037628

数据集的划分

总计有100张带标记的新冠肺炎CT图像感染区域数据以及1600张无标记的肺部CT图像,随机选择用50张带标记的数据作为训练集,剩余48张作为测试集(有两张不够明显,最后人为去除了),然后用1600张无标记的训练集做扩充

基于随机采样的半监督学习算法(具体实施策略)

1.首先用所有带标记的50张数据作为训练集对模型进行训练得到模型M

2.从1600张无标记的数据中随机抽出5张用M进行预测生成伪标记数据

3.将这五张带有伪标记的数据加入到训练集并从无标记训练集中删除

4.用新的训练集对模型进行微调

5.重复上述步骤直到所有1600张无标记数据都被加入到训练集当中

实验评价指标与结果

具体实验设计与结果文中描述的比较清晰易懂就暂时不在这里具体展开,后续会根据情况进行补充

实际运行代码后感悟

原论文中采用的超参数 epoch=100 batchsize=24 学习率为0.0001 每50个epoch降为原来的0.1倍

  • Inf-Net
    • 直接用50张有正确标注的CT图片作为训练集训练100轮得到 Inf-Net-100.pth,然后就可以用这个训练好的模型做预测,结果就是inf-net的结果
  • Semi-Inf-Net
    • 先将1600张无标注数据的CT图片随机打乱,五个一组的分成320组(split_1600.py)
    • 按照论文中半监督学习算法进行迭代学习直到混合数据集中包含50张真正标注的图片和1600张生成的伪标注数据(PseudoGenerator.py 初始时采用的Inf-Net-100.pth作为模型的初始化权重)
      • 初始时采用Inf-Net-100.pth 对第一组五张无标记CT图进行预测。
      • 然后对生成伪标记图做边缘提取。
      • 再将这三种图放入到混合数据图中
      • 利用混合数据集做10个epoch的微调生成新的Semi-Inf-Net-10.pth
      • 用新生成的Semi-Inf-Net-10.pth做替换重复上述过程直到1600张全部加入到混合数据集中
    • 先用伪标注的1600张数据做预训练得到 Inf-Net-Pesduo-100.pth,然后再用这个作为初始化权重用50张真正数据做100轮训练进行微调得到Semi-Inf-Net-100.pth
    • 最后进行预测

具体的代码运行流程