"""Tests for the MNIST input pipeline (module ``datasets.mnist.test_input_fn``)."""
import os
import tensorflow as tf
from sacred import Experiment
from shutil import rmtree
from .input_fn import mnist_input_fn_ingred, mnist_input_fn
class TestMnistInputFn(tf.test.TestCase):
    """End-to-end check that ``mnist_input_fn`` can feed an Estimator.

    ``setUpClass`` runs a Sacred experiment whose ``train`` command fits a
    single dense layer (10 logits, softmax cross-entropy loss) on the
    batches produced by ``mnist_input_fn``. ``tearDownClass`` deletes the
    model and record directories named in ``config_updates``.
    """

    @classmethod
    def setUpClass(cls):
        """Run the ``train`` command once for the whole test class.

        Sets ``cls.ex`` (the Sacred experiment), ``cls.config_updates`` and
        ``cls.run_object`` (the object returned by ``Experiment.run``).
        If the run raises, any directories already written are removed
        before the exception propagates. unittest does not call
        ``tearDownClass`` when ``setUpClass`` fails, so it cannot clean up.
        """
        super(TestMnistInputFn, cls).setUpClass()
        cls.ex = Experiment("test_mnist_input_fn",
                            ingredients=[mnist_input_fn_ingred])
        # Paths are relative to the current working directory.
        cls.config_updates = {
            "model_dir": "datasets/mnist/test_model_dir",
            "mnist_feeder": {
                "record_dir": "datasets/mnist/test_records",
                "record_pattern": "mnist_{split}_{idx}.tfrecord",
                "samples_per_record": 60000,
                "batch_size": 512,
                "n_epochs": 1,
            },
        }

        @cls.ex.command
        def train(model_dir):
            # `model_dir` is filled in by Sacred from the experiment config.

            def model_fn(features, labels):
                # Always builds a TRAIN spec: this test only calls
                # `Estimator.train`, never evaluate/predict.
                X = tf.to_float(features["X"])
                X = tf.layers.flatten(X)
                onehot_labels = tf.one_hot(labels["y"], 10)
                logits = tf.layers.dense(X, units=10)
                loss = tf.losses.softmax_cross_entropy(onehot_labels, logits)
                train_op = tf.train.GradientDescentOptimizer(0.001).minimize(
                    loss,
                    global_step=tf.train.get_global_step())
                return tf.estimator.EstimatorSpec(
                    mode=tf.estimator.ModeKeys.TRAIN,
                    loss=loss,
                    train_op=train_op)

            estimator = tf.estimator.Estimator(model_fn, model_dir)
            # NOTE(review): `mnist_input_fn()` is *called* here and its
            # result is handed to `Estimator.train`, which expects a
            # callable input_fn. Presumably `mnist_input_fn` returns one;
            # confirm.
            estimator.train(mnist_input_fn())

        try:
            cls.run_object = cls.ex.run("train",
                                        config_updates=cls.config_updates)
        except BaseException:
            # Don't leave partially written artifacts behind; re-raise so
            # the failure is still reported.
            cls._remove_artifacts()
            raise

    @classmethod
    def _remove_artifacts(cls):
        """Delete the model directory, then the record directory, if present."""
        model_dir = cls.config_updates["model_dir"]
        record_dir = cls.config_updates["mnist_feeder"]["record_dir"]
        for path in (model_dir, record_dir):
            if os.path.exists(path):
                rmtree(path)

    @classmethod
    def tearDownClass(cls):
        """Remove the directories written while running ``setUpClass``."""
        cls._remove_artifacts()
        super(TestMnistInputFn, cls).tearDownClass()

    def test_training_completes(self):
        """The Sacred run finished and the Estimator saved a checkpoint."""
        self.assertEqual(self.run_object.status, "COMPLETED")
        self.assertIsNotNone(
            tf.train.latest_checkpoint(self.config_updates["model_dir"]))
if __name__ == '__main__':
    # Running this file directly hands control to TensorFlow's test runner.
    tf.test.main()