博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
【机器学习PAI实战】—— 玩转人工智能之利用GAN自动生成二次元头像
阅读量:6004 次
发布时间:2019-06-20

本文共 4173 字,大约阅读时间需要 13 分钟。

前言

深度学习作为人工智能的重要手段,迎来了爆发,在NLP、CV、物联网、无人机等多个领域都发挥了非常重要的作用。最近几年,各种深度学习算法层出不穷, Generative Adverarial Network(GAN)自2014年提出以来,引起广泛关注,身为深度学习三巨头之一的Yan Lecun对GAN的评价颇高,认为GAN是近年来在深度学习上最大的突破,是近十年来机器学习上最有意思的工作。围绕GAN的论文数量也迅速增多,各种版本的GAN出现,主要在CV领域带来了一些贡献,如下图所示。

【机器学习PAI实战】—— 玩转人工智能之利用GAN自动生成二次元头像

我们可以利用GAN生成一些我们需要的图像或者文本,比如二次元头像。

GAN简介

GAN主要的应用是自动生成一些东西,包括图像和文本等,比如随机给一个向量作为输入,通过GAN的Generator生成一张图片,或者生成一串语句。Conditional GAN的应用更多一些,比如数据集是一段文字和图像的数据对,通过训练,GAN可以通过给定一段文字生成对应的图像。

GAN主要可以分为Generator(生成器)和Discriminator(判别器)两个部分,其中Generator其实就是一个神经网络,输入一个向量,可以输出一张图像(即一个高维的向量表示),如下图示。

【机器学习PAI实战】—— 玩转人工智能之利用GAN自动生成二次元头像

【机器学习PAI实战】—— 玩转人工智能之利用GAN自动生成二次元头像
Discriminator也是一个神经网络,输入为一张图像,输出为一个数值,输出的数值用于判断输入的图像是否是真的,数值越大,说明图像是真的,数值越小,说明图像为假的,如下图示。
【机器学习PAI实战】—— 玩转人工智能之利用GAN自动生成二次元头像
【机器学习PAI实战】—— 玩转人工智能之利用GAN自动生成二次元头像
Generator负责生成图像,Discriminator负责对Generator生成的图像和真实图像去进行对比,区别出真假,Generator需要不断优化来欺骗Discriminator,以假乱真;而Discriminator也不断优化,来提高识别能力,能够识别出Generator的把戏。二者的这种关系可以形象地通过下图展示。

【机器学习PAI实战】—— 玩转人工智能之利用GAN自动生成二次元头像

Generator和Discriminator连接起来,形成一个比较大的深层网络,即为GAN网络。

场景描述

深度学习的各种算法在PAI上可以通过PAI-DSW进行实现,在PAI-DSW上进行训练数据,利用GAN自动生成二次元头像。

数据准备

首先需要准备真实的二次元头像作为数据集,这里从网上找到一些共享的资源,存储在了钉钉钉盘中,钉盘地址 ,提取密码: c2pz,数据集如下图示,约5万多张:

【机器学习PAI实战】—— 玩转人工智能之利用GAN自动生成二次元头像

算法实践

利用PAI-DSW进行GAN算法实践,首先需要安装准备好环境。

首先进入到Notebook建模,创建新实例,之后打开实例,进入Terminal,在Terminal下用户可以像在自己本地一样安装相应的依赖包,进行操作。

准备好环境之后,我们可以通过如下图示方法,将基于Tensorflow的DCGAN代码和数据集上传上去。

【机器学习PAI实战】—— 玩转人工智能之利用GAN自动生成二次元头像

用于训练的DCGAN代码地址:。

数据集和代码上传成功,如下图示。

其中,data目录下的faces即为数据集,该文件夹下为对应的5万多张真实二次元头像。DCGAN-tensorflow为整个代码路径,其中最主要的两个代码文件是main.py和model.py,其中最主要的核心代码如下。

def main(_):

pp.pprint(flags.FLAGS.__flags)

if FLAGS.input_width is None:

FLAGS.input_width = FLAGS.input_height
if FLAGS.output_width is None:
FLAGS.output_width = FLAGS.output_height

if not os.path.exists(FLAGS.checkpoint_dir):

os.makedirs(FLAGS.checkpoint_dir)
if not os.path.exists(FLAGS.sample_dir):
os.makedirs(FLAGS.sample_dir)

#gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)

run_config = tf.ConfigProto()
run_config.gpu_options.allow_growth=True

with tf.Session(config=run_config) as sess:

if FLAGS.dataset == 'mnist':
dcgan = DCGAN(
sess,
input_width=FLAGS.input_width,
input_height=FLAGS.input_height,
output_width=FLAGS.output_width,
output_height=FLAGS.output_height,
batch_size=FLAGS.batch_size,
sample_num=FLAGS.batch_size,
y_dim=10,
z_dim=FLAGS.generate_test_images,
dataset_name=FLAGS.dataset,
input_fname_pattern=FLAGS.input_fname_pattern,
crop=FLAGS.crop,
checkpoint_dir=FLAGS.checkpoint_dir,
sample_dir=FLAGS.sample_dir,
data_dir=FLAGS.data_dir)
else:
dcgan = DCGAN(
sess,
input_width=FLAGS.input_width,
input_height=FLAGS.input_height,
output_width=FLAGS.output_width,
output_height=FLAGS.output_height,
batch_size=FLAGS.batch_size,
sample_num=FLAGS.batch_size,
z_dim=FLAGS.generate_test_images,
dataset_name=FLAGS.dataset,
input_fname_pattern=FLAGS.input_fname_pattern,
crop=FLAGS.crop,
checkpoint_dir=FLAGS.checkpoint_dir,
sample_dir=FLAGS.sample_dir,
data_dir=FLAGS.data_dir)

show_all_variables()if FLAGS.train:  dcgan.train(FLAGS)    else:      # Update D network      _, summary_str = self.sess.run([d_optim, self.d_sum],        feed_dict={ self.inputs: batch_images, self.z: batch_z })      self.writer.add_summary(summary_str, counter)      # Update G network      _, summary_str = self.sess.run([g_optim, self.g_sum],        feed_dict={ self.z: batch_z })      self.writer.add_summary(summary_str, counter)      # Run g_optim twice to make sure that d_loss does not go to zero (different from paper)      _, summary_str = self.sess.run([g_optim, self.g_sum],        feed_dict={ self.z: batch_z })      self.writer.add_summary(summary_str, counter)      errD_fake = self.d_loss_fake.eval({ self.z: batch_z })      errD_real = self.d_loss_real.eval({ self.inputs: batch_images })      errG = self.g_loss.eval({self.z: batch_z})

一切就绪之后,我们执行命令进行训练,调用命令如下:

​python main.py --input_height 96 --input_width 96 --output_height 48 --output_width 48 --dataset faces --crop --train --epoch 300 --input_fname_pattern "*.jpg"

其中,参数dateset指定数据集的目录,epoch指定循环迭代的次数,input_height、input_width用于指定输入文件的大小,输出文件的大小同样也需要参数设定,代码执行过程如下图示:​

我们来看下执行结果,分别看一下epoch为1,30,100的时候生成的二次元头像效果图。

epoch=1

epoch=30

epoch=100​

我们发现,随着不断迭代,生成的二次元头像也越来越逼真。

总结

通过上面的实践,我们领略到了GAN的魅力,GAN的变种有很多,除此之外我们还可以利用GAN做非常多的有意思的事情,比如通过文字生成图像,通过简单文字生成宣传海报等。PAI-DSW像是一个练武场,为我们准备好了深度学习所需要的环境和条件,让我们可以尽情享受大数据和深度学习的乐趣,除了GAN,像比较火热的Bert等模型,我们也都可以试一试。

转载于:https://blog.51cto.com/14031893/2369459

你可能感兴趣的文章
Vijos 1082 丛林探险
查看>>
PHP 杂谈《重构-改善既有代码的设计》之四 简化条件表达式
查看>>
Linux 内核编码风格
查看>>
[引]用c#产生1-100之间的不重复的随机数,并且可进行降序 升序排序
查看>>
SVN客户端使用教程
查看>>
Windows 8 应用开发权威指南 之 应用程序的数据存储(1)应用程序安装目录操作...
查看>>
节点指向c语言新建双循环链表/遍历
查看>>
MVC 模型绑定
查看>>
视频教程视频Java+PHP+.NET海量教程来了 500G教程
查看>>
字符搜索正则表达式语法详解
查看>>
条件数据库Android:sqllite使用
查看>>
回溯法---->哈密顿环
查看>>
Javascript 连连看
查看>>
智能手机中显示信号强度格数
查看>>
Wcf服务引用报错数据包含无法解析的引用:没有终结点在侦听可以接受消息的 这通常是由于不正确的地址或者 SOAP 操作导致的...
查看>>
“内心强大"
查看>>
jsp自定义标签技术(原理和代码实现以及平台搭建)
查看>>
分析:重定向和请求转发
查看>>
java向上转型和向下转型
查看>>
正则表达式应用实例
查看>>