【收藏】简单易用 TensorFlow 代码集,GAN通用框架、函数

2019 年 2 月 18 日 新智元




  新智元报道  

来源:GitHub    作者:Junho Kim

【新智元导读】今天为大家推荐一个实用的GitHub项目:TensorFlow-Cookbook。 这是一个易用的TensorFlow代码集,包含了对GAN有用的一些通用架构和函数。


今天为大家推荐一个实用的GitHub项目:TensorFlow-Cookbook


这是一个易用的TensorFlow代码集,作者是来自韩国的AI研究科学家Junho Kim,内容涵盖了谱归一化卷积、部分卷积、pixel shuffle、几种归一化函数、 tf-datasetAPI,等等。


作者表示,这个repo包含了对GAN有用的一些通用架构和函数。


项目正在进行中,作者将持续为其他领域添加有用的代码,目前正在添加的是 tf-Eager mode的代码。欢迎提交pull requests和issues。


Github地址 : 

https://github.com/taki0112/Tensorflow-Cookbook


如何使用

Import

  • ops.py

    • operations

    • from ops import *

  • utils.py

    • image processing

    • from utils import *

Network template

def network(x, is_training=True, reuse=False, scope="network"):    with tf.variable_scope(scope, reuse=reuse):        x = conv(...)                ...                return logit

使用DatasetAPI向网络插入数据

Image_Data_Class = ImageData(img_size, img_ch, augment_flag) trainA = trainA.map(Image_Data_Class.image_processing, num_parallel_calls=16) trainA = trainA.shuffle(buffer_size=10000).prefetch(buffer_size=batch_size).batch(batch_size).repeat() trainA_iterator = trainA.make_one_shot_iterator() data_A = trainA_iterator.get_next() logit = network(data_A)


  • 了解更多,请阅读:

    https://github.com/taki0112/Tensorflow-DatasetAPI

Option

  • padding='SAME'

    • pad = ceil[ (kernel - stride) / 2 ]

  • pad_type

    • 'zero' or 'reflect'

  • sn

    • use spectral_normalization or not

  • Ra

    • use relativistic gan or not

  • loss_func

    • gan

    • lsgan

    • hinge

    • wgan

    • wgan-gp

    • dragan

注意

  • 如果你不想共享变量,请以不同的方式设置所有作用域名称。


权重(Weight)

weight_init = tf.truncated_normal_initializer(mean=0.0, stddev=0.02) weight_regularizer = tf.contrib.layers.l2_regularizer(0.0001) weight_regularizer_fully = tf.contrib.layers.l2_regularizer(0.0001)

初始化(Initialization)

  • Xavier : tf.contrib.layers.xavier_initializer()

  • He : tf.contrib.layers.variance_scaling_initializer()

  • Normal : tf.random_normal_initializer(mean=0.0, stddev=0.02)

  • Truncated_normal : tf.truncated_normal_initializer(mean=0.0, stddev=0.02)

  • Orthogonal : tf.orthogonal_initializer(1.0) / # if relu = sqrt(2), the others = 1.0

正则化(Regularization)

  • l2_decay : tf.contrib.layers.l2_regularizer(0.0001)

  • orthogonal_regularizer : orthogonal_regularizer(0.0001) & orthogonal_regularizer_fully(0.0001)


卷积(Convolution)


basic conv

x = conv(x, channels=64, kernel=3, stride=2, pad=1, pad_type='reflect', use_bias=True, sn=True, scope='conv')


partial conv (NVIDIA Partial Convolution)

x = partial_conv(x, channels=64, kernel=3, stride=2, use_bias=True, padding='SAME', sn=True, scope='partial_conv')



dilated conv

x = dilate_conv(x, channels=64, kernel=3, rate=2, use_bias=True, padding='SAME', sn=True, scope='dilate_conv')


Deconvolution


basic deconv


x = deconv(x, channels=64, kernel=3, stride=2, padding='SAME', use_bias=True, sn=True, scope='deconv')


Fully-connected

x = fully_conneted(x, units=64, use_bias=True, sn=True, scope='fully_connected')

Pixel shuffle

x = conv_pixel_shuffle_down(x, scale_factor=2, use_bias=True, sn=True, scope='pixel_shuffle_down') x = conv_pixel_shuffle_up(x, scale_factor=2, use_bias=True, sn=True, scope='pixel_shuffle_up')


  • down ===> [height, width] -> [height // scale_factor, width // scale_factor]

  • up ===> [height, width] -> [height * scale_factor, width * scale_factor]


Block


residual block


x = resblock(x, channels=64, is_training=is_training, use_bias=True, sn=True, scope='residual_block') x = resblock_down(x, channels=64, is_training=is_training, use_bias=True, sn=True, scope='residual_block_down') x = resblock_up(x, channels=64, is_training=is_training, use_bias=True, sn=True, scope='residual_block_up')


  • down ===> [height, width] -> [height // 2, width // 2]

  • up ===> [height, width] -> [height * 2, width * 2]



attention block


x = self_attention(x, channels=64, use_bias=True, sn=True, scope='self_attention') x = self_attention_with_pooling(x, channels=64, use_bias=True, sn=True, scope='self_attention_version_2') x = squeeze_excitation(x, channels=64, ratio=16, use_bias=True, sn=True, scope='squeeze_excitation') x = convolution_block_attention(x, channels=64, ratio=16, use_bias=True, sn=True, scope='convolution_block_attention')




Normalization


Normalization


x = batch_norm(x, is_training=is_training, scope='batch_norm') x = instance_norm(x, scope='instance_norm') x = layer_norm(x, scope='layer_norm') x = group_norm(x, groups=32, scope='group_norm') x = pixel_norm(x) x = batch_instance_norm(x, scope='batch_instance_norm') x = condition_batch_norm(x, z, is_training=is_training, scope='condition_batch_norm'): x = adaptive_instance_norm(x, gamma, beta):


  • 如何使用 condition_batch_norm,请参考:

https://github.com/taki0112/BigGAN-Tensorflow

  • 如何使用 adaptive_instance_norm,请参考:

  https://github.com/taki0112/MUNIT-Tensorflow



Activation

x = relu(x) x = lrelu(x, alpha=0.2) x = tanh(x) x = sigmoid(x) x = swish(x)

Pooling & Resize

x = up_sample(x, scale_factor=2) x = max_pooling(x, pool_size=2) x = avg_pooling(x, pool_size=2) x = global_max_pooling(x) x = global_avg_pooling(x) x = flatten(x) x = hw_flatten(x)


Loss


classification loss


loss, accuracy = classification_loss(logit, label)

pixel loss

loss = L1_loss(x, y) loss = L2_loss(x, y) loss = huber_loss(x, y) loss = histogram_loss(x, y)

  • histogram_loss 表示图像像素值在颜色分布上的差异。

gan loss

d_loss = discriminator_loss(Ra=True, loss_func='wgan-gp', real=real_logit, fake=fake_logit) g_loss = generator_loss(Ra=True, loss_func='wgan_gp', real=real_logit, fake=fake_logit)


  • 如何使用 gradient_penalty,请参考:

https://github.com/taki0112/BigGAN-Tensorflow/blob/master/BigGAN_512.py#L180

kl-divergence (z ~ N(0, 1))

loss = kl_loss(mean, logvar)

Author

Junho Kim


Github地址 : 

https://github.com/taki0112/Tensorflow-Cookbook



【加入社群】


新智元AI技术+产业社群招募中,欢迎对AI技术+产业落地感兴趣的同学,加小助手微信号:aiera2015_2   入群;通过审核后我们将邀请进群,加入社群后务必修改群备注(姓名 - 公司 - 职位;专业群审核较严,敬请谅解)。


登录查看更多
4

相关内容

《深度学习》圣经花书的数学推导、原理与Python代码实现
《强化学习—使用 Open AI、TensorFlow和Keras实现》174页pdf
专知会员服务
136+阅读 · 2020年3月1日
一网打尽!100+深度学习模型TensorFlow与Pytorch代码实现集合
【书籍】深度学习框架:PyTorch入门与实践(附代码)
专知会员服务
163+阅读 · 2019年10月28日
TensorFlow 2.0 学习资源汇总
专知会员服务
66+阅读 · 2019年10月9日
机器学习相关资源(框架、库、软件)大列表
专知会员服务
39+阅读 · 2019年10月9日
Github 项目推荐 | PyTorch 实现的 GAN 文本生成框架
AI研习社
35+阅读 · 2019年6月10日
Github项目推荐 | Pytorch TVM 扩展
AI研习社
11+阅读 · 2019年5月5日
Github项目推荐 | GAN评估指标的Tensorflow简单实现
AI研习社
16+阅读 · 2019年4月19日
实战 | 源码入门之Faster RCNN
计算机视觉life
19+阅读 · 2019年4月16日
超强干货!TensorFlow易用代码大集合...
机器学习算法与Python学习
6+阅读 · 2019年2月20日
请快点粘贴复制,这是一份好用的TensorFlow代码集
TF Boys必看!一文搞懂TensorFlow 2.0新架构!
引力空间站
18+阅读 · 2019年1月16日
手把手教你由TensorFlow上手PyTorch(附代码)
数据派THU
5+阅读 · 2017年10月1日
Seeing What a GAN Cannot Generate
Arxiv
8+阅读 · 2019年10月24日
Arxiv
8+阅读 · 2018年5月21日
Arxiv
5+阅读 · 2018年5月1日
Arxiv
4+阅读 · 2018年3月30日
VIP会员
相关VIP内容
《深度学习》圣经花书的数学推导、原理与Python代码实现
《强化学习—使用 Open AI、TensorFlow和Keras实现》174页pdf
专知会员服务
136+阅读 · 2020年3月1日
一网打尽!100+深度学习模型TensorFlow与Pytorch代码实现集合
【书籍】深度学习框架:PyTorch入门与实践(附代码)
专知会员服务
163+阅读 · 2019年10月28日
TensorFlow 2.0 学习资源汇总
专知会员服务
66+阅读 · 2019年10月9日
机器学习相关资源(框架、库、软件)大列表
专知会员服务
39+阅读 · 2019年10月9日
相关资讯
Github 项目推荐 | PyTorch 实现的 GAN 文本生成框架
AI研习社
35+阅读 · 2019年6月10日
Github项目推荐 | Pytorch TVM 扩展
AI研习社
11+阅读 · 2019年5月5日
Github项目推荐 | GAN评估指标的Tensorflow简单实现
AI研习社
16+阅读 · 2019年4月19日
实战 | 源码入门之Faster RCNN
计算机视觉life
19+阅读 · 2019年4月16日
超强干货!TensorFlow易用代码大集合...
机器学习算法与Python学习
6+阅读 · 2019年2月20日
请快点粘贴复制,这是一份好用的TensorFlow代码集
TF Boys必看!一文搞懂TensorFlow 2.0新架构!
引力空间站
18+阅读 · 2019年1月16日
手把手教你由TensorFlow上手PyTorch(附代码)
数据派THU
5+阅读 · 2017年10月1日
Top
微信扫码咨询专知VIP会员