"""Source code for models.mnist: MNIST generator and critic model functions."""

import tensorflow as tf

# Short aliases for the TF 1.x contrib submodules used by the model functions
# below: `layers` supplies the layer constructors, `tfgan` the GAN helpers.
layers, tfgan = tf.contrib.layers, tf.contrib.gan

from utils import img_grid_summary
import numpy as np

import logging

def generator_fn(gan_type):
    """Return generator model_fn for gan_type.

    Args:
        gan_type (str): Either "COND" (label-conditioned) or "UNCOND".

    Returns:
        generator_fn: model_fn for GAN generator.

    Raises:
        ValueError: If ``gan_type`` is neither "COND" nor "UNCOND".
    """
    # Fail fast: any other value would otherwise silently build an
    # unconditional generator.
    if gan_type not in ("COND", "UNCOND"):
        raise ValueError(
            'gan_type must be "COND" or "UNCOND", got {!r}'.format(gan_type))

    def generator(feature_dict, is_training=True):
        """MNIST generator function for unconditional or conditional GAN.

        Args:
            feature_dict (dict): Input features; contains only "noise" for
                UNCOND, and "noise" + "labels" for COND.
            is_training (bool): If True, batch norm uses batch statistics.
                If False, batch norm uses the exponential moving average
                collected from population statistics.

        Returns:
            Tensor: A tensor with shape [batch_size, 28, 28, 1].
        """
        # Fully connected and transposed-conv layers default to ReLU + batch
        # norm; batch norm follows the train/inference mode passed in.
        with tf.contrib.framework.arg_scope(
                [layers.fully_connected, layers.conv2d_transpose],
                activation_fn=tf.nn.relu,
                normalizer_fn=layers.batch_norm), \
            tf.contrib.framework.arg_scope(
                [layers.batch_norm], is_training=is_training):
            layer = layers.fully_connected(feature_dict["noise"], 1024)
            if gan_type == "COND":
                # Condition the features on the one-hot labels (10 classes).
                layer = tfgan.features.condition_tensor_from_onehot(
                    layer, tf.one_hot(feature_dict["labels"], 10))
            layer = layers.fully_connected(layer, 7 * 7 * 256)
            layer = tf.reshape(layer, [-1, 7, 7, 256])  # 7x7
            layer = layers.conv2d_transpose(layer, 64, [4, 4], stride=2)  # 14x14
            layer = layers.conv2d_transpose(layer, 32, [4, 4], stride=2)  # 28x28
            # conv2d is not in the arg_scope above; output layer uses tanh and
            # no batch norm, explicitly.
            layer = layers.conv2d(layer, 1, [4, 4], stride=1,
                                  activation_fn=tf.tanh, normalizer_fn=None)
            img_grid_summary("fake", layer)
        return layer

    return generator
def critic_fn(gan_type):
    """Return critic model_fn for gan_type.

    Args:
        gan_type (str): Either "COND" (label-conditioned) or "UNCOND".

    Returns:
        critic_fn: model_fn for GAN critic.

    Raises:
        ValueError: If ``gan_type`` is neither "COND" nor "UNCOND".
    """
    # Fail fast: any other value would otherwise silently build an
    # unconditional critic.
    if gan_type not in ("COND", "UNCOND"):
        raise ValueError(
            'gan_type must be "COND" or "UNCOND", got {!r}'.format(gan_type))

    def critic(images, feature_dict):
        """MNIST critic function for unconditional or conditional GAN.

        Args:
            images (Tensor): Real or generated MNIST images.
            feature_dict (dict): Input features; contains only "noise" for
                UNCOND, and "noise" + "labels" for COND.

        Returns:
            Tensor: Scalar critic output (mean over every element of the final
            feature map, batch included), used to estimate the Wasserstein
            distance.
        """
        with tf.contrib.framework.arg_scope(
                [layers.fully_connected, layers.conv2d],
                activation_fn=tf.nn.relu):
            layer = layers.conv2d(images, 64, [4, 4], stride=2)  # 14x14
            if gan_type == "COND":
                # Condition the features on the one-hot labels (10 classes).
                layer = tfgan.features.condition_tensor_from_onehot(
                    layer, tf.one_hot(feature_dict["labels"], 10))
            layer = layers.conv2d(layer, 128, [4, 4], stride=2)  # 7x7
            # Linear output: a Wasserstein critic must be able to emit any real
            # score. Inheriting ReLU from the arg_scope would clamp scores at 0
            # and zero the gradient for negatively-scored inputs.
            layer = layers.conv2d(layer, 1, [4, 4], stride=2,
                                  activation_fn=None)  # 4x4
        # NOTE(review): reduce_mean without `axis` also averages over the
        # batch, so one scalar is returned per batch rather than one score per
        # image. If a per-example gradient penalty (e.g. WGAN-GP) is applied to
        # this critic, each image's gradient is scaled by 1/batch_size;
        # consider axis=[1, 2, 3] -- changing the output shape needs caller
        # review first.
        return tf.reduce_mean(layer)

    return critic