import tensorflow as tf
layers = tf.contrib.layers
tfgan = tf.contrib.gan
from utils import img_grid_summary
import numpy as np
import logging
def generator_fn(gan_type):
    """Return generator model_fn for gan_type

    Args:
        gan_type (str): Either "COND" or "UNCOND". Only the exact value
            "COND" enables label conditioning; any other value builds the
            unconditional generator.

    Returns:
        generator_fn: model_fn for GAN generator
    """
    def generator(feature_dict, is_training=True):
        """MNIST generator function for unconditional or conditional GAN

        Args:
            feature_dict (dict): Input features, contains only "noise"
                for UNCOND, and "noise" + "labels" for COND.
            is_training (bool): If True, batch norm uses batch statistics.
                If False, batch norm uses the exponential moving average
                collected from population statistics.

        Returns:
            Tensor: A tensor with shape [batch_size, 28, 28, 1]
        """
        # Dense and transposed-conv layers default to ReLU + batch norm;
        # the second scope forwards the training flag to every batch norm.
        with tf.contrib.framework.arg_scope(
                [layers.fully_connected, layers.conv2d_transpose],
                activation_fn=tf.nn.relu,
                normalizer_fn=layers.batch_norm), \
            tf.contrib.framework.arg_scope(
                [layers.batch_norm],
                is_training=is_training):
            layer = layers.fully_connected(feature_dict["noise"], 1024)
            if gan_type == "COND":
                # Inject the class label (one-hot over the 10 digits) into
                # the hidden representation.
                layer = tfgan.features.condition_tensor_from_onehot(
                    layer,
                    tf.one_hot(feature_dict["labels"], 10)
                )
            layer = layers.fully_connected(layer, 7*7*256)
            layer = tf.reshape(layer, [-1, 7, 7, 256])
            layer = layers.conv2d_transpose(layer, 64, [4, 4], stride=2)  # 14x14
            layer = layers.conv2d_transpose(layer, 32, [4, 4], stride=2)  # 28x28
            # Output layer: single channel, tanh activation, no batch norm.
            layer = layers.conv2d(layer, 1, [4, 4], stride=1,
                                  activation_fn=tf.tanh,
                                  normalizer_fn=None)
        # Side effect: registers an image summary named "fake".
        img_grid_summary("fake", layer)
        return layer
    return generator
def critic_fn(gan_type):
    """Return critic model_fn for gan_type

    Args:
        gan_type (str): Either "COND" or "UNCOND". Only the exact value
            "COND" enables label conditioning; any other value builds the
            unconditional critic.

    Returns:
        critic_fn: model_fn for GAN critic
    """
    def critic(images, feature_dict):
        """MNIST critic function for unconditional or conditional GAN

        Args:
            images (tensor): Real or generated MNIST images.
            feature_dict (dict): Input features, contains only "noise"
                for UNCOND, and "noise" + "labels" for COND. Only "labels"
                is read here, and only when gan_type is "COND".

        Returns:
            Tensor: Scalar critic output (mean over all elements of the
                final layer) used to estimate Wasserstein distance.
        """
        # Hidden layers default to ReLU; no normalizer is configured.
        with tf.contrib.framework.arg_scope(
                [layers.fully_connected, layers.conv2d],
                activation_fn=tf.nn.relu):
            layer = layers.conv2d(images, 64, [4, 4], stride=2)  # 14x14
            if gan_type == "COND":
                # Inject the class label (one-hot over the 10 digits) into
                # the first feature map.
                layer = tfgan.features.condition_tensor_from_onehot(
                    layer,
                    tf.one_hot(feature_dict["labels"], 10)
                )
            layer = layers.conv2d(layer, 128, [4, 4], stride=2)  # 7x7
            # The output layer must be linear: inheriting the scope's ReLU
            # would clamp the critic score to >= 0 and zero its gradient
            # wherever the unit is inactive.
            layer = layers.conv2d(layer, 1, [4, 4], stride=2,
                                  activation_fn=None)  # 4x4 (SAME padding)
        return tf.reduce_mean(layer)
    return critic