# Source code for exps.brain.gan

import datetime
import tensorflow as tf

from sacred import Experiment
from sacred.stflow import LogFileWriter

from datasets.brain.input_fn import brain_input_fn, brain2D_input_fn_ingred
from models.brain import generator_fn, critic_fn
from utils import scale

# Short alias for the TF-GAN library bundled in TensorFlow 1.x contrib.
tfgan = tf.contrib.gan


# The Sacred experiment object; the brain 2D input ingredient is attached so
# its configuration is available to captured functions of this experiment.
ex = Experiment(name="Brain GAN", ingredients=[brain2D_input_fn_ingred])


@ex.config
def config():
    noise_dims = 64  # Length of the latent noise vector fed to the generator
    model_dir = datetime.datetime.now().strftime("%Y%m%d%H%M%S")  # e.g. 20190131235959

    save_summary_steps = 100  # RunConfig: write summaries every N global steps
    save_checkpoints_steps = 1000  # RunConfig: save a checkpoint every N global steps
    log_step_count_steps = 100  # RunConfig: step-count logging frequency
    max_train_steps = 2000  # Passed to Estimator.train as max_steps

    gen_lr = 1e-3  # Generator learning rate (RMSProp)
    crit_lr = 1e-4  # Critic learning rate (RMSProp)
    gan_type = "UNCOND"  # GAN type, either "COND" (conditional) or "UNCOND"


@ex.capture
def extended_input_fn(noise_dims, brain2D_feeder, gan_type):
    """Build an ``input_fn`` that adapts the brain dataset to TFGAN.

    TFGAN's ``GANEstimator`` feeds ``features`` to both the generator and the
    discriminator, while ``labels`` are used only by the discriminator as the
    real data. This wrapper therefore moves the (scaled) ``"image"`` into the
    labels position and augments the features dict with:

    * ``"noise"``: a ``[batch_size, noise_dims]`` standard-normal tensor, the
      generator's latent input;
    * ``"age"``: copied from the dataset's label dict, presumably the
      conditioning variable for the conditional GAN.

    The same features are produced for both GAN types, and the original
    ``"image"`` entry is left in the features dict as well. Which entries are
    actually consumed is up to the generator/critic chosen by ``gan_type``.

    Args:
        noise_dims: Length of the latent noise vector (from config).
        brain2D_feeder: Config of the brain 2D input ingredient; its
            ``"batch_size"`` and ``"num_parallel"`` entries are read here.
        gan_type: Captured from config but not used by this function.

    Returns:
        A zero-argument function that builds the input dataset, whose
        elements are ``(features, real_images)`` pairs.
    """
    def add_noise_and_swap(features, labels):
        # The real images become TFGAN's "labels", i.e. the real data.
        real_images = scale(features["image"])
        # NOTE(review): the noise has a fixed leading dimension of batch_size,
        # which assumes brain_input_fn yields already-batched elements of
        # exactly that size (no smaller final batch) — confirm.
        features["noise"] = tf.random_normal(
            [brain2D_feeder["batch_size"], noise_dims])
        features["age"] = labels["age"]
        return features, real_images

    def fn():
        dataset = brain_input_fn()()
        return dataset.map(add_noise_and_swap,
                           num_parallel_calls=brain2D_feeder["num_parallel"])

    return fn
@ex.automain
@LogFileWriter(ex)
def main(model_dir, save_summary_steps, save_checkpoints_steps,
         log_step_count_steps, gen_lr, crit_lr, max_train_steps, gan_type):
    """Train the brain GAN with a TFGAN ``GANEstimator``.

    Builds a GAN whose generator and critic come from ``models.brain``
    (selected by ``gan_type``), trained with TFGAN's Wasserstein generator
    and discriminator losses and one RMSProp optimizer per network. It trains
    on the pipeline from ``extended_input_fn`` until ``max_train_steps``.
    Summaries and checkpoints are written under ``model_dir``.

    All arguments are injected by Sacred from the experiment config.

    Example:
        $ python -m exps.brain.gan with brain2D_feeder.batch_size=16
    """
    run_config = tf.estimator.RunConfig(
        model_dir=model_dir,
        save_summary_steps=save_summary_steps,
        save_checkpoints_steps=save_checkpoints_steps,
        log_step_count_steps=log_step_count_steps,
    )
    # NOTE(review): Wasserstein losses assume a Lipschitz-constrained critic
    # (weight clipping or a gradient penalty). Nothing here adds one —
    # confirm that critic_fn enforces it.
    gan_estimator = tfgan.estimator.GANEstimator(
        model_dir=model_dir,  # same value as run_config.model_dir
        generator_fn=generator_fn(gan_type),
        discriminator_fn=critic_fn(gan_type),
        generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
        discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
        generator_optimizer=tf.train.RMSPropOptimizer(gen_lr),
        discriminator_optimizer=tf.train.RMSPropOptimizer(crit_lr),
        config=run_config,
    )
    gan_estimator.train(extended_input_fn(), max_steps=max_train_steps)