APP下载

自适应局部关系网络的小样本学习方法

2021-10-18魏胜楠张景异耿俊香王中洲

沈阳理工大学学报 2021年4期
关键词:集上度量准确率

魏胜楠,张景异,陈 亮,耿俊香,王中洲

(沈阳理工大学 自动化与电气工程学院,沈阳 110159)

近年来,深度学习在计算机视觉任务中的多个方面均取得了巨大的进步[1],同时也为日常生活带来智能便捷的体验[2]。但深度学习的成功有赖于大量的有监督数据,对于一些无法收集到大量数据的情况,如医学影像的获取、稀有动植物的采集等[3],深度学习并不具备足够的优势。基于此,小样本学习[4]应势而生。小样本学习是在只给出少量标注数据的情况下对其进行识别分类[5],即给定一组带有足够标记量的基类数据和一组从未见过的新颖类别,通过对基类的学习训练实现对新颖类别的分类。

现阶段小样本学习研究多通过元学习来完成。元学习从一组训练数据中提取可迁移的知识,一个好的初始条件[6]、一种优化方法[7]或一个度量空间[4,8-9]都可以成为迁移的对象,将这些知识推广并应用到新的测试任务中。元训练阶段的任务设置通常与元测试阶段保持一致,以减少训练和测试之间的差距,提高模型的泛化能力。

基于元学习的度量方法[3,10]由于其突出的便捷高效性已受到诸多关注,该方法一般通过episode[8]的训练方式对测试的环境进行模拟,从而学习一个嵌入空间(将数据转换成一个特征向量,相似样本之间的特征向量距离小,不相似样本之间的特征向量距离大),确保来自同一类别的样本聚集在一起,不同类的样本彼此分离,测试集的图像则通过学习到的嵌入空间实现精确分类。Vinyals O等[8]提出的匹配网络是该度量方法的典型代表,其用一种带有嵌入特征提取器的最近邻方法实现了小样本分类。Li W等[11]认为考虑图像-图像级别的比较存在困难,故引入了局部描述子的概念,提出深度最近邻网络,实现小样本分类,但基于最近邻方法几乎不可能通过少量示例学习到一个包含各种复杂表象的真正概念。Sung F等[9]提出通过学习一种可迁移的深度度量方式比较图像之间的关系,从而实现小样本学习,但由于其嵌入模块需要满足测试集中没有见过的新颖类别,导致网络泛化能力较弱。

本文提出一种自适应局部关系网络的小样本学习研究方法。该方法利用自注意力机制[12]使网络能够有针对性地提取到特定任务的信息,达到增强网络泛化性的目的;将任务相关的局部描述子采用add的连接方式[13]构成复数连接空间,输入到非线性度量函数中,提取局部特征相似性的同时有效减少参数;在关系模块中加入SENet结构,确保网络能够抑制无效信息,更加关注特征通道上的有效信息;最后通过改进的损失函数得到高效准确的分类模型。

1 自适应局部关系网络

1.1 问题定义

小样本学习目前广泛应用的元学习方法又被称为学会学习,该方法在训练阶段构造多个不同的元任务,通过这些元任务学习到一个具备良好泛化能力的网络,使其面对新的测试任务时无需改变现有模型便可进行分类。每个元任务都在训练集中随机采样C个类,每个类由K个样本生成,总共C×K个样本,将该数据作为模型的支持集,再从C个类剩余的样本中抽取一批数据作为模型的预测对象。期望模型能从C×K个数据中学会如何区分C个类别,该问题一般称为C-wayK-shot问题。

为验证改进算法的性能,本文在MiniImagenet和Omniglot两个数据集上分别对5-way 1-shot和5-way 5-shot两种情况进行实验验证。

1.2 局部描述子

局部描述子是计算机视觉中一个基本的研究问题,文献[14]在1999年提出尺度不变特征转换(Scale Invariant Feature Transform)的局部描述子概念。假设有一张大小为h×w的图像,图像经过卷积后,输出的特征图可表示为一个h×w×c的张量T,其中c是通道数。如将每个长度为c的特征向量看成一个局部描述子xi,则该张量可以表示为m个局部描述子,m=h×w,每个局部描述子对应图片中的一个区域。局部描述子结构如图1所示。

图1 局部描述子结构

1.3 自注意力机制

自注意力机制是注意力机制的变体,其减少了对外部信息的依赖,更加擅长捕获到数据或特征的内部相关性。利用自注意力机制构建一个转换器,使网络产生特定于任务的嵌入信息。转换器处理流程如图2所示。

图2 转换器处理流程

对于自注意力机制来说,其存储一种三元信息,即一组查询点Q、键K、值V。为计算接近度和返回值,该点会先线性映射到某个空间上,映射方法为

(1)

(2)

将计算得到的相似度作为权重,用于计算和任务相关的局部描述子信息,计算式为

(3)

(4)

引入局部描述子和自注意力机制后的嵌入网络整体结构如图3所示。

图3 嵌入网络结构图

由自注意力机制构成的转换器可实现网络的自适应,即满足式(5)。

{ψx;∀x∈χtrain∪xtest}=
T({φx;∀x∈χtrain∪xtest})

(5)

1.4 特征提取模块和关系模块

本文的特征提取模块采用和关系网络相同的设置。采用四个卷积块,每个卷积块都包含64个3×3的滤波器、一个批处理归一化、一个Relu非线性层。为使需要输出的特征信息能够在关系模块中进一步卷积,对前两个卷积块应用2×2的最大池化层,后两个卷积块则不采用。为更好地建立局部描述子特征通道之间的相互依赖关系,提升有用特征的同时抑制对当前任务贡献较小的特征,提出在关系模块中引入SENet结构。改进后的关系模块由三个卷积块和两个全连接层构成,每个卷积块包含64个3×3的滤波器、一个批归一化、一个非线性Relu层。对前两个卷积块应用2×2最大池化,最后一个卷积块应用全局平均池化。两个全连接层后接sigmoid函数和缩放函数,网络的整体架构如图4所示。

图4 自适应局部关系网络结构框图

1.5 局部描述子连接

支持集S和查询集Q的图像通过嵌入网络后,得到任务相关的局部描述子信息。每一张支持集都能够得到一个由m(m=h×w)个局部描述子构成的特征向量,即

D=[x1,x2,...,xm]∈Rc×m

(6)

式中xi为第i个局部描述子。选择add的连接方式,对支持集和查询集的局部描述子进行连接,该方法能够有效地把不同的特征信息进行加权,保留中心处最强的信息,同时有效减少参数数量,连接表示为

H(DFi,Dq)=DFi+jDq(i=1,2,…,k)

(7)

式中:j为虚数单位;DFi和Dq分别为支持集和查询集的局部描述子构成的特征向量;H为特征向量构成的复数连接空间,其维数是1×m。查询集图像与支持集中的每张图像分别连接,构成k个连接空间H1,H2,…,Hk。局部描述子连接方式如图5所示。

图5 局部描述子连接方式

1.6 距离度量函数与损失函数

本文采用卷积神经网络去拟合一种非线性的度量方式,从而确定距离度量函数,该方法更具灵活性,也更能捕获到特征信息之间的相似度[9]。得到复数空间后,将复数空间输入到关系模块中,提取两个局部描述子之间的局部不变性,对构成的k个复数连接空间分别给出一个分数,分数高的被判为属于同一类别。整体网络结构如图6所示。

图6 网络整体模型

将要解决的问题看成一种回归问题,基于此,损失函数可采用均方误差的形式,计算式为

(8)

式中

ri,j=gφ(C(fφ(xi),fφ(xj))),i=1,2,…,C

(9)

(10)

将模型预测出的相似性得分ri,j和标签y进行比较,并累加求和得到最终损失值。对于小样本来说,由于每个测试任务只有少量示例,极易出现过拟合现象,因此提出在原损失函数的基础上,添加L2正则化项,增强模型的鲁棒性,改进之后的分类器的损失函数为

(11)

式中γ为正则化惩罚系数,反向传播的求导公式为

(12)

未加入正则化项时梯度下降公式为

φ,φ=φ,φ-η∇LMSE(φ,φ)

(13)

加入L2正则化项后的梯度下降公式为

φ,φ=φ,φ-η(∇LMSE(φ,φ)+γ(φ,φ))=
(1-ηγ)(φ,φ)-η∇LMSE(φ,φ)

(14)

式中:φ、φ为整个网络的可学习参数;η为学习率。经过损失函数训练得到的网络最终会达到相同类别的样本分数接近1,不同类别的样本分数接近0的效果,正则化项的加入也会使网络的权值不断衰减,最终可解释变量数量减少,保证分类准确性的同时提高模型抗干扰能力,有效地避免过拟合现象。

2 实验结果与分析

2.1 数据集

MiniImageNet数据集由文献[8]于2016 年提出,是计算机视觉领域的一个重要基准数据集,其来源于图像网,从图像网中随机抽取100个类,每个类由600张图片构成;Omniglot数据集是Lake等提出的语言文字数据集[15],其 中包含来自50个不同字母的1623个字符(类),每个类包含不同的人绘制的20个样本。

2.2 实验设置

在Windows操作系统下,基于PyTorch深度学习框架进行实验,所有实验中使用初始率为10-3的Adam算法进行训练,每10000次迭代学习率减半,最大迭代次数为100000。MiniImageNet数据集中64个类别用于训练,16个类别用于验证,20个类别用于测试,所有图像的大小设置为84×84,将测试集上随机生成的600个批次的分类准确率的均值作为最终识别准确率;Omniglot数据集分为三个部分,其中1200类用于训练,123类用于验证,300类用于测试,所有图像的大小设置为28×28,同时通过对Omniglot数据集进行90°、180°和270°旋转来达到数据增强的目的[3],将测试集上随机生成的1000个批次的分类准确率的均值作为最终识别准确率。本文在5-way 1-shot、5-way 5-shot两种情况下进行实验。

2.3 实验结果

2.3.1 MiniImageNet数据集

将本文改进算法在MiniImageNet数据集上的结果与小样本学习的一些经典网络算法进行比较,结果如表1所示。由表1可以看出本文的改进算法在两种情况下均效果良好。

表1 MiniImageNet数据集上的实验结果

在MiniImageNet数据集上的5-way 1-shot和5-way 5-shot情况下不同算法的迭代次数和分类准确率的关系如图7所示。

由图7可以看出,在训练的前20000次,模型的准确率随着训练次数的增加而增加,在训练进行到40000次时准确率基本趋于平稳。本文的改进算法在5-way 1-shot问题下达到52.34%的准确率,比RN提高1.9%,比DN4提高1.1%;在5-way 5-shot问题下达到73.08%的准确率,比RN提高7.76%,比DN4提高2.06%。

图7 两种情况下的分类准确率

2.3.2 Omniglot数据集

本文改进算法与其他经典的小样本学习方法在Omniglot数据集上的实验结果如表2所示。由表2可以看出,本文算法在5-way 1-shot 、5-way 5-shot情况下均获得较好的性能,MAML方法在5-shot情况下的分类准确率最高,其原因是模型进行了微调。

模型(算法)微调5-way分类准确率/%1-shot5-shotMATCHING NETSY97.998.7MANNN82.894.9MAMLY98.7±0.499.9±0.1PROTOTYPICAL NETSN98.899.7RELATION NETS(RN)N99.6±0.299.8±0.1OURSN99.699.7OURS(L2正则化)N99.7±0.2199.8±0.21

2.4 实验分析

2.4.1 自注意力机制实验分析

为进一步说明自注意力机制的引入对分类准确率的影响,在MiniImageNet数据集上对带有自注意力机制模块的网络和普通卷积神经网络两种情况进行对比实验,结果如表3所示。

表3 MiniImageNet数据集上自注意力机制方法对比

由表3可知,在1-shot情况下带有自注意力机制的网络分类准确率比普通卷积神经网络提高0.46%,在5-shot情况下提升4.74%。改进算法能有效地增强网络的泛化能力,满足测试集中的新样本,达到提高分类准确率的目的。

2.4.2 add连接方式实验分析

本文特征连接方式从简单的特征向量间的串联变成局部描述子之间的add连接,相比于DN4和RN在MiniImageNet数据集的5-way 1-shot 、5-way 5-shot情况下的分类准确率均有所提升。为进一步说明连接方式对分类准确率的影响,在MiniImageNet数据集上分别对不同方法的连接方式进行实验,结果如表4所示。

表4 MiniImageNet数据集上局部描述子连接方式对比

由表4可知,局部描述子的add连接方式的准确率最高,其原因是并行的连接方式能够有效地对特征信息进行加权,保留其中最重要的特征,滤掉相对无关的信息,从而提高分类准确率。

2.4.3 关系模块实验分析

为进一步说明加入SENet对信息进行筛选后分类准确率有所提升,在MiniImageNet数据集上对两种情况分别进行实验,结果如表5所示。

表5 MiniImageNet数据集上关系模块方法对比

由表5可知,加入SENet结构后网络的分类准确率在1-shot情况下提升0.76%,在5-shot情况下提升5.74%,原因是SENet结构使网络更加关注特征通道间的相关性,从而达到提高分类准确率的效果。

2.4.4 距离度量函数实验分析

本文的距离度量函数采用非线性度量,在MiniImageNet数据集上将其与现有人为设计的常见度量方式(如余弦距离或欧氏距离等)进行比较,实验结果如表6所示。

表6 MiniImageNet数据集上度量方法对比

由表6可见,非线性度量方式的分类效果最好,因欧式距离或余弦距离为设定好的固定的度量方式,很大程度上依赖于提取的特征信息,本文采用的利用网络自行获取度量距离的方法更为灵活,更能够捕获到特征之间的相似程度,故更大程度地提高分类准确率。

3 结论

应用一个自注意力机制使嵌入网络能够有针对性地提取与任务相关的特征信息,提升网络的泛化能力;引入局部描述子的思想,关注图像-类别的比较,局部描述子的连接由串联方式变为一种并行策略,将局部描述子连接后构成的复数连接空间应用普通卷积神经网络加SENet的结构去拟合一种非线性度量,有效地抑制通道中的无用信息,提高度量的准确率。通过在两个小样本学习常用的MiniImageNet数据集和Omniglot数据集上的实验,证明改进算法可高效提高分类准确率。

猜你喜欢

集上度量准确率
鲍文慧《度量空间之一》
GCD封闭集上的幂矩阵行列式间的整除性
乳腺超声检查诊断乳腺肿瘤的特异度及准确率分析
不同序列磁共振成像诊断脊柱损伤的临床准确率比较探讨
2015—2017 年宁夏各天气预报参考产品质量检验分析
颈椎病患者使用X线平片和CT影像诊断的临床准确率比照观察
基于互信息的多级特征选择算法
代数群上由模糊(拟)伪度量诱导的拓扑
突出知识本质 关注知识结构提升思维能力
度 量