Source code for datasets.mnist.test_input_fn

import os
import tensorflow as tf

from sacred import Experiment
from shutil import rmtree
from .input_fn import mnist_input_fn_ingred, mnist_input_fn


class TestMnistInputFn(tf.test.TestCase):
    """Smoke test: train a small softmax classifier fed by ``mnist_input_fn``.

    ``setUpClass`` runs a Sacred experiment once for the whole class. The
    experiment uses the ``mnist_input_fn`` ingredient and trains a single
    dense layer with ``tf.estimator``. The model directory and record
    directory named in ``config_updates`` are deleted in ``tearDownClass``,
    or immediately if setup itself fails.
    """

    @classmethod
    def setUpClass(cls):
        """Configure and run the ``train`` command once for all tests."""
        super(TestMnistInputFn, cls).setUpClass()
        cls.ex = Experiment("test_mnist_input_fn",
                            ingredients=[mnist_input_fn_ingred])
        # These paths are relative to the current working directory
        # (presumably the repository root -- TODO confirm how tests are run).
        cls.config_updates = {
            "model_dir": "datasets/mnist/test_model_dir",
            "mnist_feeder": {
                "record_dir": "datasets/mnist/test_records",
                "record_pattern": "mnist_{split}_{idx}.tfrecord",
                "samples_per_record": 60000,
                "batch_size": 512,
                "n_epochs": 1,
            },
        }

        @cls.ex.command
        def train(model_dir):
            """Train a one-layer softmax model with an Estimator."""

            def model_fn(features, labels):
                # Flatten the images and map them straight to 10 logits.
                X = tf.to_float(features["X"])
                X = tf.layers.flatten(X)
                onehot_labels = tf.one_hot(labels["y"], 10)
                logits = tf.layers.dense(X, units=10)
                loss = tf.losses.softmax_cross_entropy(onehot_labels, logits)
                train_op = tf.train.GradientDescentOptimizer(0.001).minimize(
                    loss, global_step=tf.train.get_global_step())
                # Only training is exercised, so the spec is TRAIN-only.
                return tf.estimator.EstimatorSpec(
                    mode=tf.estimator.ModeKeys.TRAIN,
                    loss=loss,
                    train_op=train_op)

            estimator = tf.estimator.Estimator(model_fn, model_dir)
            estimator.train(mnist_input_fn())

        try:
            cls.run_object = cls.ex.run(
                "train", config_updates=cls.config_updates)
        except Exception:
            # unittest skips tearDownClass when setUpClass raises, so clean
            # up any partially written directories here before re-raising.
            cls._remove_artifacts()
            raise

    @classmethod
    def tearDownClass(cls):
        """Delete the directories produced by ``setUpClass``."""
        cls._remove_artifacts()
        super(TestMnistInputFn, cls).tearDownClass()

    @classmethod
    def _remove_artifacts(cls):
        """Remove the model and record directories if they exist."""
        feeder_config = cls.config_updates["mnist_feeder"]
        for path in (cls.config_updates["model_dir"],
                     feeder_config["record_dir"]):
            if os.path.exists(path):
                rmtree(path)

    def test_training_produces_checkpoint(self):
        """The Sacred run completes and leaves a checkpoint in model_dir."""
        self.assertEqual(self.run_object.status, "COMPLETED")
        self.assertIsNotNone(
            tf.train.latest_checkpoint(self.config_updates["model_dir"]))
if __name__ == "__main__":
    # Allow running this module directly; tf.test.main() discovers and
    # runs the TestCase classes defined above.
    tf.test.main()