import datetime
import tensorflow as tf
from sacred import Experiment
from sacred.stflow import LogFileWriter
from datasets.brain.input_fn import brain_input_fn, brain2D_input_fn_ingred
from models.brain import generator_fn, critic_fn
from utils import scale
# Sacred experiment; registers the brain 2D input-pipeline ingredient so its
# settings can be overridden from the command line.
ex = Experiment(name="Brain GAN", ingredients=[brain2D_input_fn_ingred])
# Short alias for the TF-GAN contrib module (tf.contrib exists in TF 1.x only).
tfgan = tf.contrib.gan
@ex.config
def config():
    """Default configuration for the brain GAN experiment."""
    gan_type = "UNCOND"  # GAN type, either COND or UNCOND
    noise_dims = 64  # Length of latent noise vector

    gen_lr = 10 ** -3  # Generator learning rate
    crit_lr = 10 ** -4  # Critic learning rate
    max_train_steps = 2000  # Training stops once the global step reaches this

    model_dir = datetime.datetime.now().strftime("%Y%m%d%H%M%S")  # Output dir named by launch timestamp
    save_summary_steps = 100  # Steps between summary writes
    save_checkpoints_steps = 1000  # Steps between checkpoints
    log_step_count_steps = 100  # Steps between step-count log lines
@ex.automain
@LogFileWriter(ex)
def main(model_dir, save_summary_steps, save_checkpoints_steps,
         log_step_count_steps, gen_lr, crit_lr, max_train_steps, gan_type):
    """Run brain GAN training.

    Builds a TF-GAN ``GANEstimator`` from the generator/critic builders in
    ``models.brain``, using Wasserstein losses and RMSProp optimizers, and
    trains it. All arguments are injected by sacred from ``config``.

    Args:
        model_dir: Directory for checkpoints and summaries.
        save_summary_steps: Write summaries every this many steps.
        save_checkpoints_steps: Save a checkpoint every this many steps.
        log_step_count_steps: Log step-count information every this many
            steps.
        gen_lr: Learning rate of the generator's RMSProp optimizer.
        crit_lr: Learning rate of the critic's RMSProp optimizer.
        max_train_steps: Training stops once the global step reaches this.
        gan_type: Either "COND" or "UNCOND"; forwarded to ``generator_fn``
            and ``critic_fn``.

    Raises:
        ValueError: If ``gan_type`` is neither "COND" nor "UNCOND".

    Example:
        $ python -m exps.brain.gan with brain_feeder.batch_size=16
    """
    # Fail fast on a typo'd GAN type instead of deep inside model building.
    if gan_type not in ("COND", "UNCOND"):
        raise ValueError(
            "gan_type must be 'COND' or 'UNCOND', got {!r}".format(gan_type))

    # Named run_config so it does not shadow the module-level config scope.
    run_config = tf.estimator.RunConfig(
        model_dir=model_dir,
        save_summary_steps=save_summary_steps,
        save_checkpoints_steps=save_checkpoints_steps,
        log_step_count_steps=log_step_count_steps,
    )

    # NOTE(review): Wasserstein losses assume a (roughly) 1-Lipschitz critic,
    # e.g. via weight clipping or a gradient penalty. Nothing in this function
    # enforces that -- confirm critic_fn does.
    gan_estimator = tfgan.estimator.GANEstimator(
        model_dir=model_dir,
        generator_fn=generator_fn(gan_type),
        discriminator_fn=critic_fn(gan_type),
        generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
        discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
        generator_optimizer=tf.train.RMSPropOptimizer(gen_lr),
        discriminator_optimizer=tf.train.RMSPropOptimizer(crit_lr),
        config=run_config,
    )

    # NOTE(review): extended_input_fn is neither defined nor imported in this
    # module, and code placed after an automain function never runs, so this
    # line raises NameError as written. Estimator.train expects a callable,
    # so extended_input_fn() must return one. It presumably wrapped
    # brain_input_fn (imported at the top but otherwise unused) to also supply
    # generator noise of size noise_dims -- TODO restore it.
    gan_estimator.train(extended_input_fn(), max_steps=max_train_steps)