APP下载

基于生成对抗网络的空中目标图像生成算法研究

2021-06-24祁生勇臧月进吕国云

空天防御 2021年2期
关键词:梯度损失距离

祁生勇,臧月进,吕国云,杜 明

(1.西北工业大学电子信息学院,陕西西安 710072;2.上海机电工程研究所,上海 201109)

0 引言

2014年,Goodfellow 等[1]提出无监督生成对抗网络(generative adversarial networks,GAN),得益于GAN 巧妙的构思,其在图像生成领域取得了巨大的成功。

2015年,Denton 等[2]提出了拉普拉斯金字塔生成对抗网络(Laplacian pyramid GAN,LAPGAN),该算法将拉普拉斯金字塔和条件生成对抗网络[3]相结合,提升了生成图像的质量。2015年,Radford[4]提出了DCGAN,该模型结合卷积神经网络对于图像特征的提取能力和GAN 对数据建模的能力,进一步提升了生成图像的质量。2017年,Arjovsky[5]等提出了WGAN(Wasserstein GAN),从理论层面分析了原始GAN 训练不稳定的问题。 WGAN 使用Wasserstein 距离衡量真实数据分布和生成数据分布的距离,能够在任何情况下为生成器提供梯度信息以更新参数。

目前,生成对抗网络在图像生成、图像转换、图像超分辨率等[6]领域取得了巨大的成功。空中目标种类繁多,并且各种机型姿态各异,公开的数据集较少,因此针对空中目标图像生成的难度较大。本文基于DCGAN架构,通过优化判别器损失函数,提高了模型训练稳定性,同时提高了生成图像的质量。

1 基于DCGAN模型的图像生成算法

DCGAN 首次将GAN 和卷积神经网络结合起来,同时设计了一些优化训练的技巧防止模式崩塌,一定程度上解决了原始GAN 训练不稳定的问题,在MNIST[7]和LSUN[8]数据集中取得了较好的结果。

1.1 DCGAN网络结构

DCGAN 训练过程如图1所示,首先随机向量输入生成器得到生成图像,判别器负责区分图像为真的概率。判别器输出概率越接近1 说明输入图像越真实,接近0 则说明与真实图像差距较大。生成器和判别器不断地对抗迭代优化,理论上,最终判别器无法区分输入的是真实还是生成的图像,对任何输入得到的输出都为0.5,这时模型就达到了最优,生成器完整捕获了数据的真实分布。

图1 DCGAN训练过程Fig.1 DCGAN training process

图2 展示了DCGAN 生成器网络结构,首先输入大小为100且服从均匀分布的随机向量z,接着将其映射为1 024 个4×4 大小的特征图,特征图通过4 个不同的步幅卷积步骤后得到大小为64×64 的彩色图像G(z)。

图2 DCGAN生成器网络结构Fig.2 DCGAN generator network structure

DCGAN 判别器网络结构如图3所示,判别器由4个卷积层和1 个全连接层构成,输入真实图像或生成图像,输出图像为真的概率,若为真实图像则概率接近1,若为生成图像则概率接近0。同生成器一样,除了输入层,其他所有层都进行批归一化[9]处理。

图3 DCGAN判别器网络结构Fig.3 DCGAN discriminator network structure

1.2 DCGAN模型的损失函数

DCGAN 的损失函数和原始GAN 的损失函数一样,都为交叉熵损失函数,如式(1)所示。

式中:z和x分别表示随机向量和真实图像;pdata(x)和pz(z)分别表示真实图像和随机向量的概率分布;G和D分别表示生成器网络和判别器网络;E表示数学期望;V(D,G)表示目标函数。判别器D的目标是区分真实图像和生成图像:对于式(1)右侧的第1 项,输入是真实图像,判别器D希望输出的概率接近1;对于式(1)右侧的第2 项,输入为生成图像G(z),判别器希望输出趋近于0,取反之后也是越大越好,这就是max(D)的含义。生成器训练时,D保持不变,为了“欺骗”判别器,希望D(G(z))接近1,这时生成的图像会更接近真实图像,整体越小代表生成效果越好,这就是min(G)的含义。

1.3 DCGAN模型存在的问题

空中目标种类繁多、姿态各异,图像特征复杂,DCGAN 使用交叉熵损失函数易导致模型梯度消失,陷入局部最优解,不能完整地捕捉数据真实的分布。图4 给出了生成器数据分布与真实数据分布示意图,蓝色实线表示生成器捕捉的数据分布,黑色虚线表示真实的数据分布:(a)表示训练刚开始,两者距离较远;(b)表示随着模型不断地训练,生成器学习到的数据分布向真实分布靠近;(c)表示理想状况下生成器完全学习到了真实分布。

图4 生成器数据分布与真实数据分布示意图Fig.4 Schematic diagram of generator data distribution and real data distribution

DCGAN 损失函数为交叉熵损失,生成器很容易收敛到局部最优,参数无法更新,最终生成器的数据分布与真实数据分布如图4(b)所示,这会导致生成图像有伪影、图像模糊等问题,因此需要对损失函数进行优化。

2 改进的DCGAN图像生成算法

2.1 DCGAN损失函数缺陷

基于第1 章的分析,DCGAN 的损失函数包含最大化判别器和最小化生成器。Goodfellow[1]等证明,当损失函数为交叉熵损失时,假设最优的判别器固定,生成器G更新的目标如式(2)所示。

式中:pdata(x)和分别代表真实数据的分布和生成器的数据分布;x和分别表示真实图像和生成图像;LG表示生成器损失函数。对式(2)进一步化简得到式(3)。

所以DCGAN 的生成器损失函数近似于最小化pdata(x)和之间的JS 散度(Jensen-Shannon divergence)。Arjovsky[5]指出,若pdata(x)和不重叠,则两个分布之间的JS 散度为常数2 lg2,因此生成器的梯度将变为0,而WGAN 中使用Wasserstein 距离去判断两者之间的距离,每次更新判别器权重时都强制映射到一个区间,保证了反向传播过程永远有梯度信息,两者之间的关系如图5所示。

图5 WGAN梯度示意图Fig.5 Schematic diagram of WGAN gradient

2.2 Wasserstein距离

针对原始GAN 存在的问题,WGAN 使用Wasserstein 距离反映两个分布之间的距离,如式(4)所示。

式中:γ表示pdata(x)和联合分布;Π表示所有联合分布的集合;表示两个数学分布的Wasserstein距离。

对于每个可能的联合分布γ,从中采样得到真实图像x和生成图像,计算该联合分布下的期望值,在所有可能的期望值中取下界,就得到了Wasserstein 距离。相较于JS 散度,即使两个分布不重叠,Wasserstein距离也能用来衡量两者之间的距离关系。

虽然Wasserstein 距离更准确地度量了生成图像和真实图像的分布距离,但是式(4)无法直接求解,因此进一步化简为式(5)。

式中,D∈1-Lipschitz 表示判别器D满足1-Lipschitz连续性条件,所以每次判别器权重参数会被强制截断到[-0.01,0.01]之间。

总体来说,原始GAN 中交叉熵损失函数不能很好地判别两个分布的距离。所以,WGAN 提出使用Wasserstein 距离衡量两个分布之间的距离,无论两者距离多远,都能提供有效的梯度信息以更新网络参数。

2.3 改进的损失函数

为了更清楚地说明判别器的作用,将式(1)中判别器的部分改写为式(6)。

式中,LD表示判别器损失。根据2.2节的分析,为了保证模型训练过程中梯度不为0,WGAN 在梯度更新时把判别器权重截断在[-0.01,0.01]之间,使其满足Lipschitz连续性条件。但是强制性截断处理会使得判别器丢失一部分图像信息,只能学习到一个简单的分布,这时判别器对复杂的飞机图像特征分辨能力较弱,因此本文对WGAN 判别器的权重参数使用梯度惩罚[10]替代强制性截断,同时满足Lipschitz 连续性条件,如式(7)所示。

式(7)中:y表示在x和的连线上随机插值采样得到的一个新样本;ppenalty表示新样本的分布;λ表示惩罚系数。进一步化简可以得到式(8)。

梯度惩罚过程如下:首先随机选择一个真实样本x∼pdata(x)和一个生成样本;然后在x和的连线上随机插值采样得到的一个新的样本y∼ppenalty。最后一项惩罚项表示:判别器D对采样得到的样本y求梯度,梯度大于1 的时候,惩罚项会使梯度信息接近1,这样就会将梯度限制在一定范围,pdata(x)和的距离也会越来越近,生成的样本越来越符合真实场景。

最终加入惩罚项的判别器损失函数,如式(9)所示。

3 实验结果与分析

3.1 实验概述

1)实验环境配置及数据集。实验环境基于酷睿i7 处理器与英伟达GTX1080Ti GPU 环境以及Pytorch 1.4.0深度学习框架。

构建空中机动目标数据集,该数据集包含6 800张各种类型的空中目标图片,网络模型的训练集与交叉验证集包含4 760 张图片,测试数据集包含2 040 张空中目标图片。

2)训练方式设计。由于网络参数较多,为了防止模型对于训练数据集过拟合,在模型训练时使用Dropout[11]技术,随机固定某些参数不更新。

生成器和判别器组成的DCGAN 网络体现的是一种相互对抗学习的过程,如果判别器训练效果足够好,生成器梯度会消失;判别器训练效果不好,生成器梯度又会不够准确。为了权衡两者的关系,在训练过程中,判别器参数更新多次,生成器参数更新一次。

3)实验参数设置。本文所有的训练过程中模型的学习率均为0.000 2;batch size 为64;梯度惩罚λ=10;生成器和判别器参数更新次数比例为1∶5,即判别器参数更新5次,生成器参数更新1次。

3.2 评价标准

为评价改进后算法的性能,从模型训练过程稳定性以及生成图像FID[12]和IS[13]得分两方面进行分析比较。改进前后生成器损失函数对比如图6所示。

图6 改进前后生成器损失函数对比Fig.6 Comparison of generator loss function before and after improvement

训练稳定性评估:为了评估改进DCGAN 训练过程的稳定性,本文采用相同的数据集对改进前后的损失函数进行比较,由图6可知,改进后模型的生成器损失函数整体波动更小,训练过程更稳定。

图像生成质量评估:FID 分析于2017年被提出,该方法首先把真实图像和生成图像输入训练好的分类模型中(如Inception Net-V3网络),去除了最后的池化层得到一个高维特征向量,通过计算生成图像和真实图像高维特征向量的距离,就可以得到FID 分数,FID 越小,表示生成的图像质量越好、多样性越好,数学表示如式(10)所示。

式中:μ为经验均值;Σ为经验协方差;Tr为矩阵的迹;x代表真实图像,g代表生成图像。

IS 是另一个通用的GAN 模型评价标准,IS 评价的思路也是使用一个训练好的网络对生成图像进行分类,如果分类的准确性越高,说明生成的图像越真实;同时生成每个种类图像的概率越平均,说明模型生成图像的多样性越高。综合得分越高表明生成器生成图像质量越好,数学表示如式(11)所示。

式中:表示生成图像;m表示标签信息;是这两个分布的KL 散度,对其求指数就得到了最终IS分数。

3.3 实验结果分析

本文使用空中机动目标数据集,基于Pytorch深度学习框架分别对原始的DCGAN和改进的DCGAN网络进行训练,测试了32×32 和64×64 两种分辨率的图像。

图6 为训练损失函数曲线,由图6 可知,改进后的生成器损失函数波动明显下降,训练过程更加稳定。图7(a)和图7(b)分别为改进前后32×32 分辨率的图像生成结果,可以看出,改进后的生成图像的边缘细节更清楚,图像噪点明显减少并且目标主体和背景区分明显。图8(a)表示真实图像,图8(b)和图8(c)分别表示改进前后64×64分辨率的图像生成结果。与图7相比,当分辨率增大时,更多的图像细节显示了出来,如:直升机机翼、战斗机尾翼以及起落架都更加清晰明显。从图8(b)可以看出改进前生成的飞机图像轮廓不明显,并且容易与背景混合,会出现很多不真实的纹理;从图8(c)可以看出改进后的飞机主体与背景更容易区分,当飞机颜色与背景相近时也不会出现两者混在一起的情况,生成图像更加接近真实场景。

图7 改进前后模型生成结果对比(32×32)Fig.7 Comparison of generated results(32×32)before and after improvement

图8 真实图像与改进前后模型生成结果对比(64×64)Fig.8 Comparison of real images and generated results(64×64)before and after improvement

综上所述,改进后模型生成的图像能够展示更多细节,飞机主体与背景差异显著增强。通过观察生成结果还可以看出,改进后模型生成的空中目标图像与真实图像更接近,虚假纹理减少,图像边缘细节更加丰富,可为空中目标检测识别任务提供更强的数据支持。表1给出了两种算法的FID和IS得分情况。

从表1 中可以看出,改进后的模型在32×32 分辨率下FID 和IS 得分分别提高了9.4%和7.6%;64×64分辨率下FID 和IS 得分分别提高了5.9%和4.8%。由此可得,改进后的模型生成图像的质量更高,图像细节更丰富,生成种类更加多样化,没有出现原始DCGAN模式崩溃的情况。

表1 两种算法FID和IS得分比较Tab.1 Comparison of FID and IS scores between two algorithms

4 结束语

本文提出一种改进的DCGAN 图像生成算法。模型训练过程中使用改进的Wasserstein 距离衡量生成数据分布和真实数据分布,优化了原始DCGAN 的判别器损失函数,能够在任何情况下为生成器提供梯度信息。实验结果表明,针对空中目标数据集,改进后的模型训练过程更加稳定,生成图像更清晰,并且能够根据需求生成不同分辨率的图像,可以有效扩充空中目标检测任务的数据样本。

猜你喜欢

梯度损失距离
洪涝造成孟加拉损失25.4万吨大米
两败俱伤
一个具梯度项的p-Laplace 方程弱解的存在性
内容、形式与表达——有梯度的语言教学策略研究
距离美
航磁梯度数据实测与计算对比研究
组合常见模型梯度设置问题
损失
爱的距离
距离有多远