APP下载

基于注意力机制的肺结节分类研究

2022-01-28洪敏杰刘星辰贾俊铖

计算机应用与软件 2022年1期
关键词:卷积结节神经网络

匡 健 洪敏杰 刘星辰 贾俊铖

(苏州大学计算机科学与技术学院 江苏 苏州 215000)

0 引 言

在已知的致死疾病中,肺部疾病占据了很大比重。其中又以肺癌最为突出,每年因肺癌死亡的人数约130万[1]。在肺癌早期,肺部病变更多地以结节的形式展现出来,如果能在该阶段及时发现并治疗,将极大提高病患的生还率。计算机辅助诊断技术用以清晰展示患者的病灶,在辅助医生确定病情上大有裨益。这其中最常见的属电子计算机断层扫描(Computed Tomography,CT)。医生根据CT图像,结合个人经验做出相应诊断。随着CT图像的普及,影像科医生需要开展繁重的阅片工作,因此在这一过程中也存在着漏诊[2]、错诊的情况。为进一步减少医生阅片的负担,肺结节分类技术应运而生。对于大量的CT图片,通过肺结节分类技术,即可初步对其进行良恶性诊断,在降低医生工作量的同时,也大大提高了诊断率。

利用机器学习的分类算法近年来受到人们关注,其通过对所给的实验数据进行学习,从而获得一个趋近于客观事实的结果。在进行学习之前,需要将图片数据进行处理。一般地,将数据分成训练集和测试集两部分。训练集用以学得模型,包括已经标记好类别的图片,而测试集是为了验证模型性能,包括那些未标记的CT图片。机器学习任务通过对训练集中的CT图片输入到模型进行训练,获得稳定的分类器,然后在训练好的机器学习模型上,为未标记的图片确定一个类别。因此,近年来在提升分类算法准确度上的研究也相继展开。在此过程中,研究人员提出了一种借鉴人类视觉的注意力机制[3]。注意力机制通过在训练过程中仅关注感兴趣区域而抑制相关区域,从而有效提升模型的性能。

本文在卷积神经网络的基础上提出一种基于注意力机制的3D深度卷积神经网络用于肺结节分类。实验结果表明,基于注意力机制的肺结节分类方法有效提升了分类的准确性。

1 相关工作

1.1 图像分类

传统的图像分类先将若干图像组成训练集,每个图像都分别被标记为不同类别中的一类。再使用该训练集训练一个分类器来学习每个类别的特征。最后通过预测一组新图像的类标签来评估分类器的性能。随着神经网络的兴起和数据集的丰富,深度卷积神经网络更多地被用在图像分类中。现存的许多图像分类算法大多通过在ImageNet数据集进行训练以证明其有效性。肺结节分类较之传统的图像分类算法大体相似却略有区别。

1.2 肺结节分类算法

肺结节分类算法一般由医生先给出肺部结节所在的位置,再将图片输入神经网络以训练。目前的肺结节分类主要集中在良恶性分类工作中,主要利用了神经网络、支持向量机(Support Vector Machine,SVM)等相关机器学习方法提取图像特征。胡强等[4]提出了一种基于遗传算法和BP神经网络的分类算法,对分类器进行了优化,实现了孤立性肺结节的良恶性分类。杨帆等[5]将卷积神经网络[6](Convolutional Neural Networks,CNN)引入了筛查存在肺结节的CT图像诊断,提出了一种基于CNN的分类算法。以上肺结节分类方法是基于2D图像层面的。Zhu等[7]则在此基础上使用双通路网络(Dual Path Network,DPN)进行了肺结节的检测与分类,从而实现了完整的肺结节诊断系统。

1.3 注意力机制

注意力机制多用于自然语言处理、机器翻译、计算机视觉等方面,能有效提升模型的性能。在分类中也卓有成效。传统的神经网络通常是在经过一系列的卷积、池化、激活函数,以及线性变换等操作后生成特征图,其质量极大程度取决于模型的优劣,因此近年来许多研究人员致力于探索性能更优的卷积神经网络模型。在此过程中,如ResNet[8]、DenseNet[9]等网络被相继提出,但是对于有效生成特征图问题仍然有很大的改进空间。

由此引出的注意力机制实质上是借鉴了人类视觉认知中对于视野中的特定区域的关注行为。在计算机视觉任务中,注意力机制的核心即是给图像加入权重,避免不相关因素对于最终结果的影响。

对于现有的模型,我们可以引入注意力以突出任务中关键的部分。目前,可以应用在计算机视觉任务中的注意力机制主要分为强注意力(Hard Attention)软注意力(Soft Attention)。强注意力即通过数据标注的方式,在模型训练过程中显式地告诉模型感兴趣区域的位置。强注意力更加关注点,也就是图像中的每个点都有可能延伸出注意力,同时强注意力是一个随机的预测过程,更强调动态变化,其训练过程往往是通过增强学习(Reinforcement Learning)来完成的。软注意力即更关注区域或者通道的注意力机制。软注意力是确定性的注意力,学习完成后直接可以通过网络生成。通过软注意力可使网络模型通过反向传播这一过程主动地学习任务所需要关注的区域。在近年来的计算机视觉任务研究中,以这两种注意力机制最为常见。Oktay等[10]提出了基于注意力机制的U-Net网络模型,这使得肺部分割的精确率得到了有效的提升。Hu等[11]在ImageNet2018中首次提出了采用基于图像通道的注意力机制的方法。

2 基于注意力机制的卷积神经网络设计

本文在借鉴Zhu等[7]的基础上提出了一种基于注意力机制的神经网络,将3D卷积神经网络与注意力机制结合,随后采用梯度提升树算法[12](Gradient Boosting Machine,GBM),实现了完整的肺结节分类网络的构建。

2.1 3D卷积神经网络(3D CNN)设计

在传统的机器学习任务中,通常将预处理之后的图片输入到网络中进行训练,以期得到较为满意的结果,但是在本实验中所得到的病患肺部数据实则是每隔一定层后所得到的一组图片数据,此时若仍使用传统2D数据进行训练,势必会存在肺部结节不够清晰与全面,以至于提取的特征不明显、最终训练效果差等问题。为解决这些问题,可将原始的肺部CT图像处理成3D CT图像作为网络输入。然而由于受到GPU显存的限制,无法将一个病患的完整3D图像输入网络,因此对原始3D图像进行裁剪,根据医生提供的标注信息,最终得到了包含了肺结节的17×17×17大小的三维数组。

本文中所使用的网络结构主体是ResNet网络,主要由残差网络块、池化层、全连接层和注意力机制模块等组成,具体结构如图1所示。

图1 3D神经网络结构

对于输入网络的CT图片采用常用的批处理方案,将每个批次大小(Batch_size)设置为16,由于输入的是灰度图,初始通道数为1。于是每个批次送入网络的是一个16×1×17×17×17的数组,在卷积、归一化、通过激活函数和最大池化层后,在通道数增加至64的同时,也减小了图片的尺寸。将改变后的图片数据输入残差网络卷积块中,每个残差块[13]包含若干卷积操作以提取特征并减小尺寸。在经历过一系列残差块后,再次进行平均池化操作,使得图片尺寸降为1×1×1,最终通过拼接和线性变换得到了16×2的张量。

2.2 注意力模块应用

注意力机制可用于为特征赋予一个权重来提升效率,因此在本文中提出一种可训练的注意力机制模块,并将其整合到上述的卷积神经网络中。注意力因子(即权重)对不相关的特征具有抑制作用,而突出图像分类任务中感兴趣的区域,具有对特征映射重新采样的作用。本文对现有的网络所得到的特征相继使用通道注意力机制运算以及空间注意力机制运算,注意力机制模块在ResNet网络中的使用如图2所示。

图2 注意力机制模块在ResNet残差块中的结构

在通道注意力机制运算中,采用挤压和激励网络[11](Squeeze and Excitation Networks,SENet)对输入特征映射的空间维度的方法,结合了自适应均值池化和自适应最大池化来增强特征映射的表达能力。较之于普通池化技术只能通过调整池化步长得到期望的池化结果,自适应池化因其固定大小输出的池化技术而更为有效。

在本文的每个ResNet块中加入通道注意力机制,首先对输入特征分别进行均值池化和最大池化操作(若设置池化输出为1,则形如64×17×17×17的输入在自适应池化后将会被挤压成64×1×1×1的输出),随后对两种池化后的输出分别使用一个共享参数的自编码结构的多隐层神经网络进行激励操作,得到两中间结果,将这两个中间结果相加,最后通过Sigmoid函数求得注意力因子。

类似于通道注意力机制,空间注意力机制即是对图像空间层面添加一个注意力因子。该因子的计算仍是先根据特征映射的最大池化和均值池化结果得到,此处的池化操作有别于前者,是在通道层面进行的池化操作,会将形如64×17×17×17的特征映射在通道参数设为1的情况下生成1×17×17×17的空间注意力因子。

对应到本文中,是将经过通道注意力后的特征映射先后进行最大池化和均值池化操作,随后将两结果在通道层面上进行拼接操作,将拼接后的结果通过卷积层实现降维,得到单通道的空间特征映射。最后仍然使用Sigmoid函数求得最终的空间注意力因子。

依次对原始卷积结果的特征映射进行通道注意力和空间注意力操作,将会得到新的特征映射。通过有监督的反向传播训练过程,有效抑制3D图像中不相关区域,突出感兴趣区域,从而使得分类网络的性能有进一步的提升。

2.3 梯度提升决策树算法(GBM)

梯度提升决策树算法[12]利用一种可迭代的决策树算法,将所有树的结论累加以求得最终答案。

提升树模型的实质是多个决策树的累加和,其数学模型如式(1)所示。

(1)

式中:T(x;Θm)表示决策树;Θm是决策树的参数;M表示决策树的个数。针对样本K={(x1,y1),(x2,y2),…,(xN,yN)},提升树模型的训练就是选择决策树的参数Θ={Θ1,Θ2,…,ΘM}以最小化损失函数∑L(yi,fM(xi))。

提升树模型亦可表示为一个迭代过程,如式(2)所示。

fm(x)=fm-1(x)+T(x;Θm)m=1,2,…,M

(2)

因此,提升树的训练亦可按照迭代的过程来完成,在m次迭代中,生成一个新的决策树T(x;Θm)。

综上,提升树算法的过程大致如下:初始化f0(x)=0,对每一个样本(xi,yi),计算其残差rm,i=yi-fm-1(xi),i=1,2,…,N;利用{(xi,rm,i)}i=1,2,…,N训练一个决策树,得到T(x;Θm),之后不断更新式(2),最终得到如式(1)所示的提升树。

然而提升树在一些情况下不便求出残差,梯度提升树便是用损失函数的负梯度方向值来近似拟合残差。在本文实验中,由3D神经网络得出的特征将会输入梯度提升决策树作进一步训练。由于肺结节分类实验在本质上属于二分类问题,因此GBM的损失函数如式(3)所示。

L(y,f(x))=log(1+exp(-yf(x)))

(3)

实验中将恶性结节标注为1、良性结节标注为0,计算通过3D神经网络生成的特征与目标标注之间的差异,最终训练出提升决策树,得到一个较为准确的分类器。

3 实 验

3.1 实验设置

本文采用的数据集是LUNA16数据集,该数据集来源于LIDC-IDRI数据集,其中有888幅已脱敏的病人肺部图像。其中肺结节是将直径在3 mm以上的样本筛选出来,再将相近的结节融合所得,每个结节的标注信息由LIDC-IDRI数据集提供,在三位以上专家共同标注与评估下,得到了较为准确的良恶性分类,标注数据如表1所示。

表1 良恶性分类标注表

本文的实验基于Intel i5-8400处理器,两张NVIDIA GeForce 1070Ti显卡(8 GB显存),32 GB内存,操作系统为Ubuntu16.04。实验中的网络模型采用Pytorch深度学习框架实现。在实验中,根据标注信息,将肺结节图片数据集整合成3D数据,最终得到了17×17×17的三维数组。由于受到GPU显存的限制,将训练参数的批次大小设置成16,随后采用随机梯度下降算法,在700个Epoch中反复训练。在训练过程中初始学习率为0.1,在第300到第500个Epoch中将学习率调整为0.01,第500个后的学习率调整为0.001,权重衰减值设置为0.000 5,动量值设置为0.9。

为防止数据过拟合以及增强模型的泛化能力,在数据预处理阶段对CT图像进行了随机变换操作,例如左右翻转、随机裁剪等变换。

3.2 实验分析

本文的实验评估采用了ROC(Receiver Operating Characteristic)曲线和准确率(Accuracy)作为评判模型的指标。

ROC曲线的横轴是假阳性率(False Positive Rate,FPR),纵轴为真阳性率(True Positive Rate,TPR)。在医学领域中,真阳性率一般又称为敏感度(Sensitivity),其计算公式如式(4)所示。

(4)

式中:TP表示真阳性;FN表示假阴性。

ROC曲线在正负样本的分布发生变化时形状基本保持不变,因此该评估指标能降低不同测试集带来的干扰,更加客观地衡量模型本身的性能,适用于分类问题。

AUC(Area under the Curve of ROC)常被用来评价一个二分类模型的训练效果,即ROC曲线下方的面积,表示预测的正例排在负例前面的概率。

肺结节分类模型的ROC曲线如图3所示,其中实线为使用ResNet网络和梯度提升树算法(GBM)的运行结果,虚线为在原有基础上添加Attention机制的运行结果。可以看出,添加Attention机制的网络,其AUC值比未添加Attention的网络的AUC大,因此可以证明添加了Attention机制的网络比原网络效果更好。

图3 模型的ROC曲线

对于给定的数据集,准确率是分类器正确分类的样本数与总样本数之比,也常作为分类问题的度量标准。如表2所示,在实验中分别记录了仅使用ResNet模型、结合ResNet与GBM后的模型准确率,再将其与添加了Attention机制后的模型进行对比。最终发现ResNet网络结合GBM的模型在添加Attention机制后相较于使用Multi-scale CNN[14]和Deep 3D DPN[7]的模型有较大的提升,由此可见Attention机制可以有效地提升肺结节分类的准确性。

表2 LUNA16 数据集上各模型准确率

4 结 语

本文提出一种基于注意力机制的肺结节分类方法,采用了3D卷积神经网络和梯度提升树算法,结合了空间和通道注意力机制,在LUNA16公开数据集上验证了本文方法的有效性。实验证明了使用注意力机制的方法准确率达到了91.30%,超过了仅使用的ResNet的方法和ResNet+GBM的方法。

在未来的工作中,我们将继续提升肺结节分类的准确性,并进一步探索肺结节类型的多分类问题。

猜你喜欢

卷积结节神经网络
基于全卷积神经网络的猪背膘厚快速准确测定
基于神经网络的船舶电力系统故障诊断方法
基于人工智能LSTM循环神经网络的学习成绩预测
基于图像处理与卷积神经网络的零件识别
MIV-PSO-BP神经网络用户热负荷预测
体检发现的结节,离癌症有多远?
查出肺结节,先别慌
了解这些,自己读懂甲状腺B超报告
基于深度卷积网络与空洞卷积融合的人群计数
甲状腺结节能 自己消失吗?