APP下载

基于可调节判别器的领域适应

2022-03-05赵小强蒋红梅

兰州理工大学学报 2022年1期
关键词:源域特征提取损失

赵小强, 蒋红梅

(1. 兰州理工大学 电气工程与信息工程学院, 甘肃 兰州 730050; 2. 兰州理工大学 甘肃省工业过程先进控制重点实验室, 甘肃 兰州 730050; 3. 兰州理工大学 国家级电气与控制工程实验教学中心, 甘肃 兰州 730050)

基于有监督的深度卷积神经网络已经在很多应用领域取得了显著的成就[1],然而在有监督学习中,一个泛化能力强的模型往往需要大量的标记数据.数据的收集和标注过程费时费力、成本高,这使得基于有监督的深度卷积神经网络算法难以训练出一个泛化能力较强的模型.此外,在真实场景中,由于环境的变化,训练数据集和真实数据集之间存在分布差异,导致有监督的模型泛化能力较差[2].针对上述问题,领域适应(domain adaptation,DA)[3]应运而生,利用其他领域的标记样本在领域间建立桥梁,同时对无标记的目标域进行适配,从而提高目标域的预测质量.领域适应在计算机视觉和自然语言等领域有着较为广泛的应用,如使用电商网站上已经分类好的图片数据对手机拍摄图片进行分类[4],使用已有情感标记的电影评论数据对快餐评论数据进行情感预测[5],这些问题都属于领域适应研究的范畴.领域适应消除了训练数据和测试数据必须服从独立同分布的限制,解决了标准监督学习所面临的训练数据不足、训练数据与测试数据存在分布偏移的问题.

近年来,DA算法取得了快速发展.早期的典型算法是基于最大均值差异(maximum mean difference,MMD),其核心思想是将不同域的样本映射在同一特征空间,通过最小化不同域的特征表示的MMD,减小域间差异.基于MMD的DA算法主要包含:深度域混淆(deep domain confusion,DDC)[6]在两个权值共享的CNN的特征层之间加入一个适应层,通过最小化源域及目标域特征之间的MMD减小域间差异,但MMD与核函数直接关联,如果选择了表现较差的核函数则MMD表现也会较差,从而导致最终目标域的预测质量较差;深度适应网络(deep adaptation network,DAN)[7]在DDC的基础上,为了避免单个MMD中核函数的选择比较困难,以及获取更多的有用信息,采用一个多核的MMD适配多个全连接层;联合适配网络(joint adaptation network,JAN)[8]则用一个联合的MMD适配多个全连接层.

相比较早期的浅层领域适应,应用深层神经网络的领域适应算法在共享特征提取和分类精度上表现更为优越.近年来,随着生成对抗网络(generative adversarial network,GAN)[9]不断发展,基于对抗学习(adversarial learning,AL)[10]的算法得到广泛研究.领域对抗神经网络(domain adversarial neural network,DANN)[11]是首次将对抗学习引入领域适应的方法,其通过特征提取器与判别器之间的对抗来获得域不变特征,但未考虑特征的类别信息,导致目标域的预测精确度较低;受条件生成对抗网络(conditional generative adversarial nets,CGAN)[12]的启发,对抗判别领域适应(adversarial discriminative domain adaptation,ADDA)[13]提出了一种基于对抗学习方法的统一框架,并根据是否使用生成器、使用何种损失函数和是否跨域共享权重等角度对现有方法进行了总结,取得了较好的结果.

然而,在上述的领域适应算法中,对目标样本的适应性较差,且通用的熵最小化函数使易转移样本的梯度大、难转移样本的梯度小,从而导致目标域的预测精确度较低.针对上述问题,本文提出基于可调节判别器的领域适应(A-DADA)算法.首先,设计了可调节判别器,在目标域中,利用两个可调节判别器(2K维)分类概率的差值作为对抗训练的衡量指标,旨在减少已对齐的目标样本对抗训练的次数,增加未对齐目标样本的对抗训练次数;同时,将平方熵损失作为最小熵损失函数,降低了易转移样本的梯度幅度,提高了难转移样本的训练效率.然后,使用随机梯度下降法(stochastic gradient descent,SGD)以可调节的学习率策略对A-DADA网络进行训练,得到该网络的具有良好迁移性的模型.最后,对目标域中的测试集样本进行测试,经多个对比实验验证,本文的算法具有更好的性能.

1 基于可调节判别器的领域适应

在DA中,通常包含两个数据集,分别为源域数据集和目标域数据集.设Xs为含有标签信息的源域数据集,对应的标签信息数据集为Ys,ys为Xs中的一个样本xs对应的标签,Xt为无标签信息的目标域数据集,xt为Xt中的一个样本.源域和目标域存在分布差异,本文的目的是设计一种模型来尽可能减小源域和目标域之间的差异[14-15],使得使用源域数据训练出的模型,能够有效地适用于目标域.

多数基于特征对齐的领域适应算法的基本结构通常由两部分组成,分别为特征提取器G及分类器F.首先,将样本xs或xt输入到特征提取器网络G,获得对应的特征图,将特征图输入到分类器网络F,输出K维的类概率分布向量p(y|x),其中K为类别数[16];然后,特征提取器与分类器之间进行对抗训练直到达到最优结果.还有些DA算法考虑了类边界和目标样本之间的关系,其基本结构由三部分组成,即在上述结构的基础上,又加入一个分类器,分别为F1、F2.首先,训练两个分类器F1和F2,使目标样本的特征差异最大化,从而有效地检测源域支持外的目标样本;其次,通过训练特征提取器G“欺骗”判别器F1和F2,使目标样本的差异最小化,鼓励在源域的支持范围内生成目标样本.但上述算法对目标域样本的迁移性考虑不足,同时使用一般的熵损失函数在易转移样本的梯度幅度大,在难转移样本的梯度幅度小,从而导致类别不平衡.

本文主要从以下两个方面解决上述问题,首先,提出可调节判别器,减少已对齐目标样本的对抗训练的次数,增加未对齐目标样本的对抗训练的次数;其次,利用平方熵损失函数,旨在降低易转移样本的梯度幅度,增加难转移样本的梯度幅度.

1.1 可调节判别器

本文中的两个判别器D1、D2的输出设定为2K维向量[17],第一个K维是源域的类分布,第二个K维是目标域的类分布,从而同时学习域和类变量的对齐.在目标域的对抗训练中,首先,目标域数据无标签信息,使用对应的伪标签y′t;其次,判别器为2K维,因此判别器的正确输出应该是[0,y′t],而特征提取器“欺骗”判别器将其分类到源域中,即[y′t,0];最终,把两个判别器的分类概率的距离ldt作为权重应用到对抗训练损失函数上,旨在减少已对齐的目标样本对抗训练的次数,增加未对齐目标样本分布,在单个判别器中实现域级别和类级别的对抗训练的次数,从而构成可调节判别器.因此,可调节判别器在目标域样本上的对抗损失为

(1)

其中:f(x)=F(G(x));fD1(x)=D1(G(x));fD2(x)=D2(G(x));G为特征提取器;F为分类器.

1.2 平方熵

为了进一步提高目标样本的适应性,对目标域样本在类预测器上的熵损失函数进行熵最小化,一般采用的熵损失函数为香浓熵损失函数,表示为

(2)

考虑到二分类的情况,其对应的梯度函数为

(3)

由式(3)可知,高概率类别的梯度比中概率类别的梯度大得多.然而,香农熵损失函数最小化是由目标样本的高概率类别主导,忽略了中低概率类别,因此,本文提出平方熵损失函数替代香农熵损失函数,其表示如下:

(4)

对应的梯度函数为

(5)

由式(5)可知,熵损失函数的梯度与对应的类别概率为线性关系.与香农熵最小化方法相比,虽然高概率类别仍然有较大的梯度,但它的主导作用已经减弱,使得中概率类别具有与高概率类别相差不大的训练梯度,因此,平方熵损失函数对不同的类别具有更均衡的梯度.

1.3 损失函数

基于可调节判别器领域适应的损失函数主要包含三部分:第一部分为类预测器的基于源域数据的分类损失函数lCE(f(x),y)和基于目标域数据的平方熵损失函数lte(F),第二部分为两个可调节判别器对应的分类损失ldsc1、ldsc2、ldtc1、ldtc2和对抗损失ldsa1、ldsa2、ldta1、ldta2,第三部分为两个可调节判别器输出同一域的类概率差值的绝对值之和ld.基于可调节判别器领域适应的损失函数如下式所示(可调节判别器的损失函数以D1为例,D2与其具有相同的形式):

(6)

其中:lCE(f(x),y)=-〈y,logf(x)〉,为交叉熵损失函数.

1.4 算法流程

本文提出的基于可调节判别器的领域适应的结构(图1)由四部分组成,分别为特征提取器G、两个可调节判别器D1、D2和一个类预测器F.其中,类预测器的输出为K维向量,可调节判别器的输出为2K维向量.对于源域样本xs和目标域样本xt,首先,使用共享的特征提取器G来提取样本特征,分别得到源域样本的特征G(xs)和目标域样本的特征G(xt);其次,源域样本的特征分别输入到两个可调节判别器D1、D2及类预测器F中,而目标域样本的特征不需要输入到类预测器中,分别得到源域样本的类别预测概率和目标域样本的类别预测概率;然后,对于目标域,将两个判别器输出的类别预测概率的差值作为权重应用在判别器的对抗损失上,得到关于目标域的对抗损失;最后,将平方熵损失函数作为熵最小化损失函数,以提高类别的平衡性.

图1 A-DADA structure diagram

A-DADA算法流程如下所示:

Input:源域数据集Ds=(Xs,Ys),目标域数据集Dt=Xt,训练次数分别为K1、K2,Batch Size的大小为n.

Step1:采用ImageNet[21]预训练模型的参数初始化网络层参数;

Step2: forkin 1:K1do

Step2.2:xsn通过网络G得到G(xsn),再分别通过F、D1、D2网络计算得到F(G(xsn))、D1(G(xsn))和D2(G(xsn));

Step2.3:根据式(6)中对应的损失函数lsc、ldsc1和ldsc2的计算公式,训练分类器和可调节判别器对源域样本进行正确分类,目标函数为

Step3: forkin 1:K2do

Step3.2:xsn、xtn通过网络G得到G(xsn)和G(xtn),再分别通过F、D1、D2网络计算得到F(G(xsn))、F(G(xtn))、D1(G(xsn))、D1(G(xtn))和D2(G(xtn));

Step3.3:根据式(6)中对应的损失函数lsc、lte、ldsc2、ldtc2、ldsc1、ldtc1和ld的计算公式,训练类预测器和判别器,目标函数为

λdsc1ldsc1+λdtc1ldtc1-λdld

Step3.4:根据式(6)中对应的损失函数ldsa1、ldta1、ldsa2、ldta2和ld的计算公式,训练特征提取器,目标函数如下:

Step4:end for

2 实验及分析

2.1 实验设置与数据

本文实验使用的数据集Office-31[19]是一个基于图片领域适应的应用较为广泛的数据集,一共包含4 652张图片,分为31个类别,这些图片源于3个不同的领域,分别为Amazon(A)、Webcam(W)和DSLR(D).其中,Amazon为电商网站Amazon.com的商品展示图片;Webcam为图像处理软件Webcam处理后的图片;DSLR为数码单反相机拍摄的图片.实验中,将这3个领域的数据集设置6种迁移任务,即A→W、D→W、A→D、W→D、D→A和W→A.

本文实验所用的深度学习框架为Pytorch[20],在搭载GPU为GTX1080Ti的服务器实验环境下使用Python3.6,在网络框架中,使用ResNet-50作为基础的特征提取器,其初始学习率为0.004,其中ResNet的初始参数为使用ImageNet[21]预训练的模型参数.类预测器和两个判别器由两个全连接层构成,其初始学习率为0.04,对于优化器的设置,采用动量为0.9的SGD来更新参数,同时采用与文献[22]相同的优化策略,学习率ηp由公式ηp=η0/(1+αp)β计算所得,其中p指模型训练完成程度,范围为0~1.0,并设置η0=0.01、α=10和β=0.75的优化器参数组合,其他参数的最优设置见表1.

表1 参数的最优设置

2.2 实验结果及分析

为了客观地比较算法的优劣,实验依次使用ResNet、DANN、ADDA、JAN算法和本文算法作对比实验,基于Office-31数据集的实验结果见表2.相较于其他算法,本文的A-DADA算法在多个迁移任务上具有更好的性能,与ResNet、DANN、ADDA、JAN算法相比,平均精确度分别提高了10.7%、4.6%、3.9%、2.5%;在Office-31数据集的6个迁移任务中得到提升,尤其在A→W和A→D两个任务上有较大的提升.但是在源域数据集较小的两个迁移任务D→A和W→A的精确度较低,说明本文的算法还存在一定局限性,其主要原因是对于源域数据集较小的领域,模型的适应性能被弱化.但由于本文算法较好地考虑了类别平衡及目标样本的适应性,使其在整体性能上优于其他对比算法.

表2 基于Office-31数据集的实验结果

本文对Office-31数据集上的任务A→W的训练曲线进行可视化,如图2所示,进一步分析算法的稳定性和收敛性,其中前10 000次迭代是无领域适应时的目标样本的预测平均精确度.由图可知,未加入领域适应时,目标样本的预测精确度较低且处于震荡状态,加入领域适应后,目标样本测试的平均精确度快速上升,并最终趋于稳定.

图2 目标样本的平均精确度Fig.2 Average accuracy of the target sample

为了进一步验证算法的有效性,使用本文算法在Office-31数据集上的特征可视化图片如图3所示,红点表示源域数据,蓝点表示目标域数据.从图3可以看出,在无领域适应时,目标域数据散乱地分布,也未观察到任何关于目标域间的分类信息及源域和目标域域间的适应信息,这说明源域数据与目标域数据之间存在较大差异;而在使用A-DADA算法对其进行领域适应分类后,源域数据和目标域数据的类间距离变小,具有相同类别的源域样本和目标域样本较好地拟合在一起,进一步验证了本文算法的有效性.

图3 T-SNE feature visualization

3 结论

为了提高基于对抗学习的领域适应(DA)对目标样本的适应性,本文提出了A-DADA算法.算法的网络结构主要由特征提取器G、两个可调节判别器D1、D2和一个类预测器F连接组成.该算法将源域和目标域数据输入到网络中,经过特征提取器G与可调节判别器间的对抗训练及判别器间的对抗训练,使该网络在目标域上具有更好的适应性.与ResNet-50、DANN、ADDA、JAN算法相比,本文算法在Office-31数据集上的平均精确度得到了提高,从而有效地提高了目标域的预测精确度.

在下一步研究中,将探求如何解决在源域数据集较小的两个领域中的迁移能力弱化的问题,从而使模型更具适应性.

猜你喜欢

源域特征提取损失
胖胖损失了多少元
基于参数字典的多源域自适应学习算法
两败俱伤
空间目标的ISAR成像及轮廓特征提取
基于Gazebo仿真环境的ORB特征提取与比对的研究
基于特征提取的绘本阅读机器人设计方案
基于Daubechies(dbN)的飞行器音频特征提取
从映射理论视角分析《麦田里的守望者》的成长主题
菜烧好了应该尽量马上吃
损失