"""models.brain: generator and critic model_fns for the brain-image GAN."""

import tensorflow as tf

layers = tf.contrib.layers
tfgan = tf.contrib.gan

from utils import img_grid_summary
import numpy as np

import logging

def generator_fn(gan_type):
    """Return the generator model_fn for ``gan_type``.

    Args:
        gan_type (str): Either "COND" or "UNCOND". "COND" is accepted but
            conditioning is not implemented yet: labels are ignored, the
            returned generator is unconditional, and a warning is logged
            when this factory is called.

    Returns:
        callable: ``generator(feature_dict, is_training=True)``, the model_fn
        for the GAN generator.
    """
    if gan_type == "COND":
        # The conditional branch below is still a TODO; make that visible
        # instead of silently training an unconditional generator.
        logging.getLogger(__name__).warning(
            "generator_fn: gan_type='COND' is not implemented yet; labels are "
            "ignored and an unconditional generator is built.")

    def generator(feature_dict, is_training=True):
        """Brain generator function for the unconditional (or future conditional) GAN.

        Maps a noise vector to a single-channel 128x128 image.

        Args:
            feature_dict (dict): Input features. Only ``feature_dict["noise"]``
                is read. For "COND" it is expected to also contain "labels",
                which are currently unused.
            is_training (bool): If True, batch norm uses batch statistics.
                If False, batch norm uses the exponential moving average
                collected from population statistics.

        Returns:
            Tensor: Generated images with shape [batch_size, 128, 128, 1] and
            values in [-1, 1] (tanh output).
        """
        arg_scope = tf.contrib.framework.arg_scope
        with arg_scope([layers.fully_connected, layers.conv2d_transpose],
                       activation_fn=tf.nn.relu,
                       normalizer_fn=layers.batch_norm), \
                arg_scope([layers.conv2d_transpose], kernel_size=[4, 4]), \
                arg_scope([layers.batch_norm], is_training=is_training):
            layer = layers.fully_connected(feature_dict["noise"], 1024)
            # TODO: Implement a conditional brain GAN (feed
            # feature_dict["labels"] in here when gan_type == "COND").
            layer = layers.fully_connected(layer, 4 * 4 * 512)
            layer = tf.reshape(layer, [-1, 4, 4, 512])  # 4x4
            # Each stride-2 transposed conv doubles the spatial size; the
            # stride-1 layers refine features at that resolution.
            layer = layers.conv2d_transpose(layer, 512, stride=2)  # 8x8
            layer = layers.conv2d_transpose(layer, 512, stride=1)  # 8x8
            layer = layers.conv2d_transpose(layer, 256, stride=2)  # 16x16
            layer = layers.conv2d_transpose(layer, 256, stride=1)  # 16x16
            layer = layers.conv2d_transpose(layer, 128, stride=2)  # 32x32
            layer = layers.conv2d_transpose(layer, 128, stride=1)  # 32x32
            layer = layers.conv2d_transpose(layer, 64, stride=2)  # 64x64
            layer = layers.conv2d_transpose(layer, 64, stride=1)  # 64x64
            layer = layers.conv2d_transpose(layer, 32, stride=2)  # 128x128
            # Project to a single channel without batch norm; tanh bounds
            # the pixel values to [-1, 1].
            layer = layers.conv2d(layer, 1, [4, 4], stride=1,
                                  activation_fn=tf.tanh,
                                  normalizer_fn=None)  # 128x128
        img_grid_summary("fake", layer)
        return layer

    return generator
def critic_fn(gan_type, *, per_example=False):
    """Return the critic model_fn for ``gan_type``.

    Args:
        gan_type (str): Either "COND" or "UNCOND". "COND" is accepted but
            conditioning is not implemented yet: labels are ignored, the
            returned critic is unconditional, and a warning is logged when
            this factory is called.
        per_example (bool): If False (default, original behaviour), the critic
            returns one scalar: the mean of the score map over the whole
            batch and every spatial position. If True, it averages only over
            the height, width and channel axes and returns one score per
            example, shape [batch_size]. Averaging over the batch scales the
            gradient with respect to each input image by 1/batch_size. Use
            True for losses that differentiate the critic per example, such
            as a WGAN gradient penalty.

    Returns:
        callable: ``critic(images, feature_dict)``, the model_fn for the GAN
        critic.
    """
    if gan_type == "COND":
        # The conditional branch below is still a TODO; make that visible
        # instead of silently training an unconditional critic.
        logging.getLogger(__name__).warning(
            "critic_fn: gan_type='COND' is not implemented yet; labels are "
            "ignored and an unconditional critic is built.")

    def critic(images, feature_dict):
        """Brain critic function for the unconditional (or future conditional) GAN.

        Expects 2D images with spatial size 128x128. The generator in this
        module produces [batch_size, 128, 128, 1]. Real images are presumably
        the same shape (TODO confirm).

        Args:
            images (Tensor): Real or generated brain images.
            feature_dict (dict): Input features. Currently unused. It contains
                "noise" for UNCOND and "noise" + "labels" for COND.

        Returns:
            Tensor: Critic output used to estimate the Wasserstein distance.
            This is a scalar by default, or shape [batch_size] when the
            factory was called with ``per_example=True``.
        """
        arg_scope = tf.contrib.framework.arg_scope
        # No activation_fn override: conv layers use the layers.conv2d default.
        with arg_scope([layers.conv2d], stride=1, kernel_size=[3, 3]), \
                arg_scope([layers.max_pool2d], stride=2, kernel_size=[2, 2]):
            layer = layers.conv2d(images, 16)
            # TODO: Implement a conditional brain GAN (feed
            # feature_dict["labels"] in here when gan_type == "COND").
            layer = layers.max_pool2d(layer)  # 64x64 for 128x128 input
            layer = layers.conv2d(layer, 32)
            layer = layers.max_pool2d(layer)  # 32x32
            layer = layers.conv2d(layer, 64)
            layer = layers.conv2d(layer, 64)
            layer = layers.max_pool2d(layer)  # 16x16
            layer = layers.conv2d(layer, 128)
            layer = layers.conv2d(layer, 128)
            layer = layers.max_pool2d(layer)  # 8x8
            layer = layers.conv2d(layer, 256)
            layer = layers.conv2d(layer, 256)
            layer = layers.conv2d(layer, 256)
            # Linear 1x1 projection to one unbounded score per position;
            # a WGAN critic output must not be squashed.
            layer = layers.conv2d(layer, 1, kernel_size=[1, 1],
                                  activation_fn=tf.identity)
        if per_example:
            # Keep the batch axis: average over height, width and channel.
            return tf.reduce_mean(layer, axis=[1, 2, 3])
        return tf.reduce_mean(layer)

    return critic