浅谈GAN网络

浅谈GAN网络GAN网络现在是研究的热门领域,CV中几乎所有的任务开始使用了GAN来提升性能,跟着研究潮流,这二天看了看LanGoodFellow2014年关于GAN网络的开山之作《GenerativeAdversarialNets》,因为看得是第一篇有关GAN的论文,所以想想还是有必要进行一个细致的总结。GAN网络实际上包含了2个网络,一个是生成网络(Generator),另一个是判别网络(Dis…

大家好,欢迎来到IT知识分享网。

GAN网络现在是研究的热门领域,CV中几乎所有的任务开始使用了GAN来提升性能,跟着研究潮流,这二天看了看Lan GoodFellow 2014年关于GAN网络的开山之作《Generative Adversarial Nets》,因为看得是第一篇有关GAN的论文,所以想想还是有必要进行一个细致的总结。

GAN网络实际上包含了2个网络,一个是生成网络(Generator),另一个是判别网络(Discriminator)。但是个人不成熟的理解就是GAN的核心目的其实还是在于它的生成器G,而至于为什么存在判别器D,主要是为了引入对抗训练,通过对抗训练的方式让生成器能够生成高质量的图片。在GAN之前也有很多比较经典的生成网络,比如自编码器AE,去噪自编码器DAE,变分自编码器VAE等等,都可以用来生成图片。那么GAN与它们有什么不同?主要区别还是在于生成网络的学习方式不同,AE,DAE只是单纯的从单个样例中学习,目标是重构出输入样例,说白了网络的输出就是尽可能重构出网络的输入,这样其实并不能说数据生成,充其量也就是数据复原罢了,于是更高级的VAE(变分自编码器)就出现了,VAE的思想不是去学习每个样例的特征,而是去学习一个数据集的分布,如果学习到了这个分布,那么我任意取出分布中的一个点,将这个点通过解码器部分,就能得到一个符合原数据分布,而且与原数据集中的数据不同的新数据。在这个过程中,往往假设数据集符合正太分布,编码器中有2个并行的全连接层,一个代表数据集的均值,另一个代表数据集的方差,因为前面是假设数据集符合一个正太分布,那么怎么保证这个假设成立?那么就用KL散度去度量学习到的数据集的正太分布和标准正太分布之间的差异性,去优化这个差异性,使得学习到的数据集的分布越来越趋近与正太分布。我们假设的数据集是一个正太分布,而所有的正太分布都可以由标准正太分布转化得到,N(\mu ,\sigma ) = \sigma *N(0,1)+\mu),KL散度计算2个分布见的差异性:KL(p||q)\sum p(x_{i})*[log(p(x_{i})-log(q(x_{i}))].在VAE中数据集的分布是假设得到的,而在GAN中没有显式的定义一个数据集的分布,GAN中数据集分布是由生成器G自动学习到的一个复杂的分布。

前面说到GAN的终极目的是学习一个高质量的生成器G,GAN通过引入判别器D来实现得到高质量的生成器G。G在训练过程中的目的是生成尽可能逼真的图片去让判别器判断不了这张图片到底是真实图片还是生成的虚假照片,D在训练过程中的目的就是尽可能取辨别真假图片,所以G是希望是D的犯错率最大化,而D则是希望自己犯错率最小化,二者互为对抗,在竞争中共同进步。理论上这种关系可以达到一个平衡点,即所谓的纳什均衡,也就是说G生成的图片D判别它为真实数据的概率是0.5,也即现在判别器已经无法区分生成器所生成的图片的真假,那么生成器的目的也就达到了,以假乱真了。Paper 中GAN的优化目标函数如下:

浅谈GAN网络

 

这是一个最大最小优化问题,这个表达式真是绝了,因为同时涉及到最大最小问题,很难一起优化,于是作者就采用的交替训练策略,首先训练判别器网络D,优化判别器的参数,然后在训练生成器网络G。训练步骤如下:

浅谈GAN网络

作者在论文中说明了,训练D网络的step=K,是一个超参数,在实验中作者取K=1,理论上当G网络的更新速度足够慢的时候,可以保存D网络训练的更充分,反过来又使的G网络的性能更好。但是上面训练生成器网络时也有一个问题,作者在原文中也指明了,在训练早期,G的性能较差,G生成的图片和原始数据差异较大,D网络很容易就判别它是虚假的图片,也即D(G(zi))很小,log(1-D(G(Zi)))趋近于0,导致梯度消失问题,会引起G网络训练不充分,于是作者在优化G网络时,对G网络的损失函数做了相应的调整变成minimize -log(D(G(Zi))).作者在文中也通过理论证明了,全局最优解的存在性,此时D网络的LOSS=1.38,D对于G生成的数据的准确率应该是50%,见下图:

浅谈GAN网络

使用tensorflow简单搭建一个ministGAN

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Jul 27 15:06:58 2018

@author: wsw
"""

# create mnist fc GAN

import tensorflow as tf
import numpy as np
import os
import time
import matplotlib.pyplot as plt
slim = tf.contrib.slim

tf.reset_default_graph()

def load_data():
    dataPath = '../mnist/train.npy'
    trainData = np.load(dataPath)
    return trainData



def get_next_batch(Datas,batchsize=128):
    image = tf.train.slice_input_producer([Datas],
                                          num_epochs=200,
                                          shuffle=True,
                                          )
    # form a batch date
    image_batch = tf.train.batch([image],
                                 batch_size=batchsize,
                                 capacity=1000,
                                 num_threads=4)
    return image_batch


def generator(inputs,is_training=True):
    
    with slim.arg_scope([slim.fully_connected],
                        activation_fn=tf.nn.leaky_relu,
                        ):
        net = slim.fully_connected(inputs,num_outputs=256,scope='fc1')
        net = slim.fully_connected(net,num_outputs=784,
                                   activation_fn=tf.nn.tanh,
                                   scope='fc2')
        return net

    
def discriminator(inputs):
    
    with slim.arg_scope([slim.fully_connected],
                        activation_fn=tf.nn.leaky_relu,
                        ):
        net = slim.fully_connected(inputs,num_outputs=512,scope='fc1')
        net = slim.fully_connected(net,num_outputs=1,
                                   activation_fn=tf.nn.sigmoid,
                                   scope='fc3')

        return net


def train_GAN():
    # load data
    train_Data = load_data()
    train_Data = np.float32(train_Data/255.0)
    # get batch data
    batchsize = 128
    image_batch = get_next_batch(train_Data,batchsize=batchsize)
    # build gan model
    # random noise inputs
    noise_inputs = tf.random_normal(shape=[128,100])
    
    with tf.variable_scope('Generator',reuse=tf.AUTO_REUSE):
        gen_imgs = generator(noise_inputs)
        
    with tf.variable_scope('Discriminator',reuse=tf.AUTO_REUSE):
        # compute d_loss
        # true data score
        truth_score = discriminator(image_batch)
        # fake data score
        fake_score = discriminator(gen_imgs)
    
    with tf.name_scope('compute_accuracy'):
        predict_fake_label = tf.where(fake_score>0.5,
                                      tf.ones([batchsize,1],
                                               dtype=tf.uint8),
                                      tf.zeros([batchsize,1],
                                               dtype=tf.uint8))
                                      
        gt_fake_label = tf.zeros(shape=[batchsize,1],dtype=tf.uint8)
        accuracy = tf.reduce_mean(tf.cast(tf.equal(predict_fake_label,gt_fake_label),
                                          dtype=tf.float32))
        
    with tf.name_scope('D_loss'):
        d_loss = -tf.reduce_mean(tf.log(truth_score)+tf.log(1-fake_score))
    
    with tf.name_scope('G_loss'):
        g_loss = -tf.reduce_mean(tf.log(fake_score))
    
    
    
    with tf.name_scope('optimizer'):
        global_step = tf.train.create_global_step()
        lr = tf.train.exponential_decay(learning_rate=1e-4,
                                        global_step=global_step,
                                        decay_rate=0.9,
                                        decay_steps=470)
        optimizer = tf.train.AdamOptimizer(lr)
        # get generator variable list
        gen_vars = slim.get_variables(scope='Generator')
        # get discriminator variable list
        disc_vars = slim.get_variables(scope='Discriminator')
        print('Generator Trainable Variables',gen_vars)
        print('Discriminator Trainable Variables',disc_vars)
        train_G = optimizer.minimize(g_loss,global_step,var_list=gen_vars)
        train_D = optimizer.minimize(d_loss,global_step,var_list=disc_vars)
        
    
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        # must needed 
        tf.local_variables_initializer().run()
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess,coord)
        try:
            epoch = 1
            while not coord.should_stop():
                
                start = time.time()
                # train Discriminator
                d_loss_value,_,accu = sess.run([d_loss,train_D,accuracy])
                # train Generator
                g_loss_value,_ = sess.run([g_loss,train_G])
                end = time.time()
                step = global_step.eval()//2
                fmt = 'Epoch:{:02d}-Step:{:05d}-Gloss:{:.3f}-Dloss:{:.3f}-D_accu:{:.5f}-Elapsed:{:.3f}(Sec)'\
                .format(epoch,step,g_loss_value,d_loss_value,accu,end-start)
                if step%100==0:
                    print(fmt)
                if step%470==0:
                    epoch += 1
                    valid_imgs_out = sess.run(gen_imgs)
                    show_result(valid_imgs_out,epoch)
                    
        except tf.errors.OutOfRangeError: 
            coord.request_stop()
            print('train finished!!!')
        coord.join(threads)


        
def show_result(valid_imgs,epoch):
    imgDir = './images2'
    if not os.path.exists(imgDir):
        os.mkdir(imgDir)
    for i in range(100):
        img = valid_imgs[i]
        img = img.reshape(28,28)
        plt.subplot(10,10,i+1)
        plt.imshow(img,cmap='gray')
        plt.axis('off')
    plt.savefig(os.path.join(imgDir,'%d.png'%epoch))
    
    
if __name__ == '__main__':
    train_GAN()

训练时有几个细节需要注意:

  1. 训练G网络时候需要固定D网络参数不变,为了方便操作,就用slim.get_variables()分别获取G和D网络对应的参数,然后分别进行优化训练
  2. G网络中引入BN,以及最后一层激活函数选择tanh,其他层激活函数选择leaky_relu效果比较好

     3. 以上代码我自己跑得时候,迭代200个epoch貌似效果也不是很好,如果大家发现问题在哪?请在下方留言,不胜感激

200epoch 后生成的图片效果(囧,感觉很糟糕)

浅谈GAN网络

 

非常模糊,只能看到每个数字的大概形状。

免责声明:本站所有文章内容,图片,视频等均是来源于用户投稿和互联网及文摘转载整编而成,不代表本站观点,不承担相关法律责任。其著作权各归其原作者或其出版社所有。如发现本站有涉嫌抄袭侵权/违法违规的内容,侵犯到您的权益,请在线联系站长,一经查实,本站将立刻删除。 本文来自网络,若有侵权,请联系删除,如若转载,请注明出处:https://yundeesoft.com/16237.html

(0)

相关推荐

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注微信