import tensorflow as tf
layers = tf.contrib.layers
tfgan = tf.contrib.gan
from utils import img_grid_summary
import numpy as np
import logging
def generator_fn(gan_type):
    """Return the generator model_fn for ``gan_type``.

    Args:
        gan_type (str): Either "COND" or "UNCOND". Conditioning is not
            implemented yet: for "COND" a warning is logged and the same
            unconditional network is built (``feature_dict["labels"]`` is
            ignored).

    Returns:
        generator_fn: model_fn for the GAN generator.
    """
    if gan_type == "COND":
        # TODO: Implement a conditional brain GAN. Until then, make the
        # fallback to the unconditional network visible instead of silent.
        logging.warning(
            "generator_fn: gan_type 'COND' is not implemented yet; building "
            "an unconditional generator and ignoring the labels.")

    def generator(feature_dict, is_training=True):
        """Brain generator function for unconditional or conditional GAN.

        Generates 2D images with shape [128, 128].

        Args:
            feature_dict (dict): Input features, contains only "noise" for
                UNCOND, and "noise" + "labels" for COND. Only "noise" is
                read by this network.
            is_training (bool): If True, batch norm uses batch statistics.
                If False, batch norm uses the exponential moving average
                collected from population statistics.

        Returns:
            Tensor: A tensor with shape [batch_size, 128, 128, 1] whose
            values lie in [-1, 1] (tanh output layer).
        """
        arg_scope = tf.contrib.framework.arg_scope
        # Hidden layers: ReLU + batch norm; all transposed convs use 4x4
        # kernels; batch norm mode follows ``is_training``.
        with arg_scope([layers.fully_connected, layers.conv2d_transpose],
                       activation_fn=tf.nn.relu,
                       normalizer_fn=layers.batch_norm), \
                arg_scope([layers.conv2d_transpose], kernel_size=[4, 4]), \
                arg_scope([layers.batch_norm], is_training=is_training):
            layer = layers.fully_connected(feature_dict["noise"], 1024)
            layer = layers.fully_connected(layer, 4 * 4 * 512)
            layer = tf.reshape(layer, [-1, 4, 4, 512])                 # 4x4
            # Each stride-2 transposed conv doubles the spatial size; the
            # stride-1 ones refine features at the same resolution.
            layer = layers.conv2d_transpose(layer, 512, stride=2)      # 8x8
            layer = layers.conv2d_transpose(layer, 512, stride=1)      # 8x8
            layer = layers.conv2d_transpose(layer, 256, stride=2)      # 16x16
            layer = layers.conv2d_transpose(layer, 256, stride=1)      # 16x16
            layer = layers.conv2d_transpose(layer, 128, stride=2)      # 32x32
            layer = layers.conv2d_transpose(layer, 128, stride=1)      # 32x32
            layer = layers.conv2d_transpose(layer, 64, stride=2)       # 64x64
            layer = layers.conv2d_transpose(layer, 64, stride=1)       # 64x64
            layer = layers.conv2d_transpose(layer, 32, stride=2)       # 128x128
            # Output layer: single channel, tanh, no batch norm.
            layer = layers.conv2d(layer, 1, [4, 4], stride=1,
                                  activation_fn=tf.tanh,
                                  normalizer_fn=None)                  # 128x128
        img_grid_summary("fake", layer)
        return layer

    return generator
def critic_fn(gan_type):
    """Return the critic model_fn for ``gan_type``.

    Args:
        gan_type (str): Either "COND" or "UNCOND". Conditioning is not
            implemented yet: for "COND" a warning is logged and the same
            unconditional network is built.

    Returns:
        critic_fn: model_fn for the GAN critic.
    """
    if gan_type == "COND":
        # TODO: Implement a conditional brain GAN. Until then, make the
        # fallback to the unconditional network visible instead of silent.
        logging.warning(
            "critic_fn: gan_type 'COND' is not implemented yet; building "
            "an unconditional critic and ignoring the labels.")

    def critic(images, feature_dict):
        """Brain critic function for unconditional or conditional GAN.

        Expects 2D images with shape [128, 128].

        Args:
            images (tensor): Real or generated brain images.
            feature_dict (dict): Input features, contains only "noise" for
                UNCOND, and "noise" + "labels" for COND. Not read by the
                current (unconditional) network.

        Returns:
            Tensor: Per-example critic scores with shape [batch_size], used
            to estimate the Wasserstein distance. Averaging them over the
            batch gives the single scalar this function used to return.
        """
        arg_scope = tf.contrib.framework.arg_scope
        with arg_scope([layers.conv2d], stride=1, kernel_size=[3, 3]), \
                arg_scope([layers.max_pool2d], stride=2, kernel_size=[2, 2]):
            layer = layers.conv2d(images, 16)      # 128x128
            layer = layers.max_pool2d(layer)       # 64x64
            layer = layers.conv2d(layer, 32)
            layer = layers.max_pool2d(layer)       # 32x32
            layer = layers.conv2d(layer, 64)
            layer = layers.conv2d(layer, 64)
            layer = layers.max_pool2d(layer)       # 16x16
            layer = layers.conv2d(layer, 128)
            layer = layers.conv2d(layer, 128)
            layer = layers.max_pool2d(layer)       # 8x8
            layer = layers.conv2d(layer, 256)
            layer = layers.conv2d(layer, 256)
            layer = layers.conv2d(layer, 256)
            # 1x1 projection to a linear (unbounded) score map.
            layer = layers.conv2d(layer, 1, kernel_size=[1, 1],
                                  activation_fn=tf.identity)
        # Reduce over spatial/channel axes only. Averaging over the batch
        # axis as well would make d(output)/d(example_i) equal to
        # (1/batch_size) * d(score_i)/d(example_i), shrinking every
        # per-example gradient used by gradient-norm penalties (WGAN-GP).
        return tf.reduce_mean(layer, axis=[1, 2, 3])

    return critic