量化百科

DCGAN的小尝试(1)

由ypyu创建,最终由ypyu 被浏览 22 用户

话说当今的深度学习网络框架世界,除了Caffe,还有很多不错的框架。这一次为了省事,我们直接找一个开源的应用进行分析和尝试。而这次的框架主角是keras,一个拥有简洁API的框架。而我们今天的主角来自深度学习界的大神Yann LeCun(我比较喜欢叫他颜乐存,哈哈……)在Quora上的对这个问题的回答:

What are some recent and potentially upcoming breakthroughs in deep learning?


There are many interesting recent development in deep learning, probably too many for me to describe them all here. But there are a few ideas that caught my attention enough for me to get personally involved in research projects.The most important one, in my opinion, is adversarial training (also called GAN for Generative Adversarial Networks).

总结成一句话:深度学习的未来,就是干(GAN)!

听上去就让人热血沸腾啊~

抱着极大的好奇心,我们开始对GAN的探险之旅。

DCGAN

如果从GAN的起点开始聊起,那么等我们聊到正题,估计好几集都过去了。所以让我们忘掉前面的种种解法,直接来到我们的深度学习部分:DCGAN。全称是Deep Convolution GAN。也就是用深度卷积网络进行对抗生成网络的建模。

对抗神经网络(GAN)有两个主角——

  • 一个是G(Generator),也就是生成模型;它的输入是一个随机生成的向量,长度不定,输出是一个具有一定大小的图像(N*N*3)和(N*N*1)。
  • 一个是D(Discriminator),也就是判别模型。在我们接下来介绍的模型中,它的输入维度和G的输出一样,输出是一个长度为1 的向量,数字的范围从0到1,表示图像像一个正常图片的程度。

G的输入和输出都比较好理解,D的输入也比较好理解,那么D的输出是什么含义呢?它表示了对给定输出是否像我们给定的标准的输入数据。这句话可能有点绕口,我们可以把判别模型理解成一个解决分类问题的模型,那么在这个问题中判别模型的结果就是区分一个输入属于下面两个类别中的哪个——“正常输入”和“非正常输入”。

举个更具体的例子。对于MNIST数据集来说,每一个手写的数字都可以认为是一个“正常输入”,而随便生成的一个不像手写数字的输入都可以认为是一个“非正常输入”。而我们的判别模型就是要判断这个问题,我们学习的目标也是学习出一个能够解决这个问题的模型。

那么,我们的生成模型的目标呢?就是我们能够从一个随机生成的向量生成一样“正常输入”的图像。听上去有点神奇吧,不过现实中这个效果是可以实现的。我们可以想象我们的输入空间是满足某种分布的一个空间,对于空间中的每一个点,我们都可以利用生成模型将其映射成为一个图像,现在我们限定了生成的图像必须是“正常输入”,那么输入和输出在某种程度上已经确定,我们就可以用监督学习的方式进行学习了。不过对于生成模型来说,我们的loss是生成图像的likelihood,这个和判别模型的loss不太一样。

好了,两个模型的输入输出已经说完了,下面还有两个问题需要解决:

  1. 判别模型的训练数据该如何准备?正例可以用现有数据,那负例呢?
  2. 生成模型的loss该如何计算?

其实要想解决这两个问题,我们需要把两个模型连起来。因为生成模型和输出和判别模型的输入在维度和含义上都是相同的,这样连起来我们就可以解决上面的两个问题。我们利用判别模型去判断生成模型的likelihood,而用生成模型产生的结果去做判别模型的负例,这样就把上面的两个问题解决了。

当然,关于把生成模型的输出作负例这件事,听上去还是有点奇怪的。生成模型的目标是生成“正常输入”,那么生成了“正常输入”还被当成负例,也是够冤的。不过这种矛盾的关系在机器学习中经常存在,就像优化目标中的loss项和正则项一样,这两个目标往往也是一对矛盾体。所以这种矛盾的存在并不奇怪,这也是这个模型被称为“对抗”的原因。

大家都喜欢用警察和小偷的关系来比喻生成模型和判别模型之间的“对抗”关系,我觉得可以用“魔高一尺,道高一丈”,“道高一丈,魔高十丈”来解释两个模型随着对抗不断强化的关系。判别模型在进化中能够捕捉不像“正常输入”的所有细节,而生成模型则会尽全力地模仿判别模型心中“正常输入”的形象。

好了,说了这么多,我们来看看针对“Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks” 这篇论文的keras版“实现”:GitHub - jacobgil/keras-dcgan: Keras implementation of Deep Convolutional Generative Adversarial Networks,说是“实现”是因为这个实现实际上和论文中期望的有点小不同。

这个代码使用的数据集是MNIST,经典的小数据集,手写数字。在我的实验结果中,生成模型生成的手写数字是这样的:

除了个别数字之外,大多数的数字生成得还是有模有样的嘛!

另外我们看一下两个模型在训练过程中的Loss:

其中蓝色是生成模型的loss,绿色是判别模型的loss,可以看出两个模型的Loss都存在一定程度的抖动,也可以算是对抗过程中的此消彼长吧。

最后套用乐存大师的话做结尾:

It seems like a rather technical issue, but I really think it opens the door to an entire world of possibilities.

既然乐存老师都这么说了,我们真得好好看看这个模型了。下一会我们来看看前面提到的论文和上面的提到的实现。

\

标签

深度学习API