APP下载

基于生成对抗网络的遥感影像场景分类

2022-07-26陈红顺陈文杰

微型电脑应用 2022年6期
关键词:源域分类器特征提取

陈红顺, 陈文杰

(北京师范大学珠海分校,信息技术学院, 广东,珠海 519087)

0 引言

近年来,卷积神经网络(CNN)在图像分类方面取得了很多的应用[1-3],这得益于带有训练标签的大型图像数据集。由于遥感图像特殊性,制作带有训练标签的大型遥感图像数据集成本高、过程复杂,因此目前公开的遥感图像数据集规模不大,且不同数据集之间存在遥感平台、传感器、拍摄角度、分辨率、拍摄时间的差异,造成同一类型的地物在不同的数据集中差异巨大,这给遥感图像场景分类造成了困难。近年来,迁移学习方法开始应用于遥感图像场景分类[4]。

生成对抗网络[5](GAN)由Goodfellow等人于2014年提出,自提出以来引起了许多研究者的兴趣。目前,许多学者提出了各种改进模型,如深度卷积生成对抗网络(Deep Convolutional GAN,DCGAN)[6]、条件生成对抗网络(Conditional GAN)[7]等,并广泛应用于图像超分辨率合成、图像风格转移、图像分割[8]等领域。

本文针对遥感图像场景分类中带有标签数据不足的问题,提出一种生成对抗网络的分类算法,并应用于遥感图像场景分类。

1 生成对抗网络

如图1所示,生成对抗网络的结构主要包含2个模型:生成器(Generator,记作G)和判别器(Discriminator,记作D)。生成器通过学习得到真实数据的分布,生成尽可能与真实相似的数据,以达到让判别器无法鉴别;判别器的目标则是尽可能的区分真实数据和生成器生成的假数据。生成器和判别器交替训练,互相对抗博弈,最终达到纳什均衡。此时,生成器能够生成与真实数据分布相似的数据,判别器无法识别数据的“真假”。

图1 生成对抗网络结构

修改原始GAN判别器的输出类别标签可以将GAN扩展成半监督分类器[9]。此时,生成器不再生成数据而是作为特征提取器,通过对抗学习方法拟合两个数据分布,并把结果分别传入判别器和分类网络,并调整标签决策来增加类间差异。

基于GAN的训练机制,通过交替训练拟合两组数据,生成器(也称为特征提取器)在源域和目标域中共享参数,把源域数据和目标域数据交替放进去训练,达到生成网络能够弱化2个域的差异,提取到2个域的共同特征。

2 基于标签改进生成对抗网络的分类算法

2.1 算法模型

在前人研究的基础上,本文在原有模型中添加标签,以解决域适应过程中减弱类间差异的问题。整个算法流程如图2所示。为测试算法效果,使用MNIST和MNIST-M数据集进行分类(见图3)。MNIST-M是彩色带背景的手写数据集,它是由BSDS500数据集中随机提取图片,然后对其随机位置剪裁成28*28的大小,减去黑白的手写数据集取绝对值得到的,其线条特征与MNIST有一定的相似。如表1所示,网络结构由两部分组成,一部分是2个卷积层和2个池化层组成的特征提取网络,另一部分由全连接层组成的类间分类器和领域判别器。

图2 算法流程

(左:MNIST数据集,右:MNIST-M数据集)图3 测试数据集

表1 基于标签改进生成对抗网络的网络结构

2.2 训练算法

在开始训练源域网络前,将MNIST数据集图像转为假彩色,作为3通道输入到网络,测试时仍转假彩色去测试。batch设置为64,epoch设置为30,使用Adam优化器,其学习率为0.001,衰减率为(0.9,0.99),以迭代器的方式提取数据来应对源域和目标域数据量不相等。数据量少的将会重复抽取,直到数据量大的数据集完成一次读取为止。一个Epoch里,域分类器和特征提取器的训练次数为1∶10,原因是域分类器训练效果明显,收敛速度快,而特征提取器以欺骗域分类器训练以求得到域不变特征,故训练时间要长一些。损失函数无监督训练采用经sigmoid处理后的BCE损失函数,验证集部分采用交叉熵损失函数。最后是权重衰减参数,衰减权重为

W=w+w(1-t/epochs)

(7)

其中,w为目标域的验证集相关权重,设置为0.1,t为当前的epoch次数,epochs为总的迭代次数。总损失函数为

LOSS=Wloss1+(1-W)loss2

(8)

算法训练过程如下。

(1)训练源域网络,把源域数据和类别放入特征提取器和分类器,使网络能够很好的区分源域数据的类别,然后把此网络的参数模型保存下来,称为源域网络。

(2)训练领域判别器,固定特征提取器的参数交替(或合并)地输入源域和目标域的数据,以此来判断输入的图片来自源域或目标域,使领域判别器的参数更新,能够更好地区分图片来自源域还是目标域。

(3)训练特征提取器,固定领域判别器的参数交替(或合并)地输入源域和目标域的数据,此时,域分类器将以错误的域类别作为训练,以此来训练特征提取器,使特征提取器能够提取到两个领域的共同特征,以此来欺骗域分类器。两种方式以一定的比例交替训练,达到拟合两组数据的分布。

(4)训练类间分类器,把验证集少量的目标域带标签数据输入生成的特征R作为监督学习,以此交叉熵作为损失函数的部分权重来训练整个网络。

2.3 实验结果与分析

将本文方法与WDGRL[10]、ADDA[11]方法进行比较,结果分别见图4和表2。可以看出,对假彩色处理后的MNIST已经有很好分类效果的源域网络直接用于目标域的彩色MNIST_M,有33%的准确率,说明源域和目标域之间确实有着一定的共同特征。ADDA通过域对抗方法,对齐2个域的分布,得到了良好的分类效果。WDGRL通过Wasserstein距离衡量2个样本的分布差异,参与了网络的更新标准。本文方法基于ADDA方法的基础上,目标域少量验证集以一定的减弱权重参与网络更新,相比较于前2种方法,准确率高一点,但训练时间要长约30%。

图4 训练过程中的目标域分类精度

表2 不同方法准确率对比

3 遥感影像场景分类实验

3.1 数据集

从AID[12]、NWPU-RESISC45[13]、UCMerced_LandUse[14]和WHU-RS19[15]数据集分别取出相应类别的数据来构建源域数据和目标域数据。AID共分为30个类别,每个类别有220~420张图像;NWPU-RESISC45数据集共分为45个类别,每个类别有700张图像;UCMerced_LandUse共分为21个类别,每个类别有100张图像;WHU-RS19共分为19个类别,每个类别有50张图像。源域数据来源为:从AID数据集中选取机场、桥梁、商业区、沙漠、工业区、湖泊、草地、公园、海港和地铁站共10个类别;从UCMerced_LandUse数据集中选取沙滩、林业区、河流和停车场共4个类别。目标域数据来源为:从NWPU-RESISC45数据集中选取机场、桥梁、商业区、沙漠、工业区、草地、停车场和地铁站共8个类别,从WHU-RS19数据集中选取沙滩、海港、湖泊、公园、林业区和河流共6个类别。

源域和目标域数据集数据的空间分辨率为0.2~30 m,从图5、图6可以看出,两者之间存在明显差异。由于不同数据集中的图片尺寸不一,需要对图片进行归一化处理,先把不同数据集的图片大小统一以最短边缩放成224,再把图片中心剪裁成224*224,对应VGG16网络的输入。训练过程中,分别从源域和目标域的各类别中取出10%的比例作为验证集。

图5 源域数据

图6 目标域数据

3.2 实验过程

各部分采用的网络结构如表3所示。生成器(特征提取器)采用VGG16网络[2],类间分类器和领域判别器均采用三层全连接层(fully connected layer),其中前两层全连接均使用激活函数ReLU,类间分类器在第1层引入了Droupout、第3层引入了平均池化。

表3 用于遥感网络场景分类的网络结构

由于网络初始化训练速度很慢,特别是生成器的训练,所以本文选取预训练的VGG16网络来初始化生成器的参数,加快初始训练速度。batch设置为64,epoch设置为30,使用Adam优化器,其学习率为0.002,衰减率为(0.9,0.99)。

一个epoch里,领域判别器和生成器的训练次数为1∶k,原因是域分类器训练效果明显,收敛速度快,而生成器以欺骗域分类器训练以求得到域不变特征训练时间长。本文在训练过程中,k的值设置为10。

3.3 实验结果与分析

本文的方法与DDC(Deep Domain Confusion)迁移学习方法[6]进行对比,其最终在目标域上的精度见表4,训练过程中的精度和loss变化分别见图7、图8。

图7 训练过程中目标域的分类准确率变化曲线

图8 训练过程中目标域的loss变化曲线

表4 精度评价

可以看出,仅利用训练好的源域网络对目标域进行分类,其精度达到47%,说明源域和目标域有着一定的相似特征。与DDC迁移学习方法相比,本文的方法分类精度略高。从loss变化趋势来看,DDC迁移学习方法波动明显,而本文方法的loss下降比较平稳。从准确率来看,两者的变化曲线相似,但是训练时间明显比DDC迁移学习方法要长。

4 总结

本文针对目标域带标签样本数据量少的问题,通过改进生成对抗网络模型,更好地利用标签信息增强目标域类间的区分度。在常用的遥感图像场景分类数据集上进行了实验,结果表明本文方法对目标域有较好的分类效果。但本文方法还存在一些问题,如在实际应用中往往不知道目标域的类别数量,无法明确地将目标域和源域类别一一对应,同时本论文方法也不适用于多领域迁移分类。

猜你喜欢

源域分类器特征提取
学贯中西(6):阐述ML分类器的工作流程
基于朴素Bayes组合的简易集成分类器①
基于参数字典的多源域自适应学习算法
基于动态分类器集成系统的卷烟感官质量预测方法
空间目标的ISAR成像及轮廓特征提取
基于Gazebo仿真环境的ORB特征提取与比对的研究
基于特征提取的绘本阅读机器人设计方案
一种自适应子融合集成多分类器方法
基于MED—MOMEDA的风电齿轮箱复合故障特征提取研究
从映射理论视角分析《麦田里的守望者》的成长主题