import os
import tensorflow as tf
from sacred import Experiment
from shutil import rmtree
from .serialize import mnist_serializer_ingred
class TestMNISTSerializer(tf.test.TestCase):
    """End-to-end tests for the ``mnist_serializer`` Sacred ingredient.

    ``setUpClass`` runs the ``mnist_serializer.serialize`` command once with
    a small test configuration that writes TFRecord files under
    ``datasets/mnist/test_records``. The test methods then inspect the
    resulting run config and the files on disk. ``tearDownClass`` deletes
    the output directory again.
    """

    @classmethod
    def setUpClass(cls):
        """Run the serializer once for the whole class.

        Any existing output directory is removed first, so stale records
        from an aborted earlier run cannot satisfy the existence checks.
        """
        super(TestMNISTSerializer, cls).setUpClass()
        cls.ex = Experiment("test_mnist_serializer",
                            ingredients=[mnist_serializer_ingred])
        cls.config_updates = {
            "mnist_serializer": {
                "record_dir": "datasets/mnist/test_records",
                "record_pattern": "mnist_{split}_{idx}.tfrecord",
                "samples_per_record": 10000,
                "split_to_size": {
                    "train": 35000,
                    "eval": 0,
                    "test": 10000
                },
                "compression": "GZIP"
            }
        }
        cls._remove_record_dir()
        try:
            cls.run_object = cls.ex.run("mnist_serializer.serialize",
                                        config_updates=cls.config_updates)
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises,
            # so clean up any partially written output here.
            cls._remove_record_dir()
            raise

    @classmethod
    def tearDownClass(cls):
        """Delete the records written by ``setUpClass``."""
        try:
            cls._remove_record_dir()
        finally:
            super(TestMNISTSerializer, cls).tearDownClass()

    @classmethod
    def _remove_record_dir(cls):
        """Remove the configured ``record_dir`` if it exists."""
        record_dir = cls.config_updates["mnist_serializer"]["record_dir"]
        if os.path.exists(record_dir):
            rmtree(record_dir)

    def _record_path(self, split, idx):
        """Return the expected path of shard ``idx`` of ``split``."""
        config = self.config_updates["mnist_serializer"]
        file_name = config["record_pattern"].format(split=split, idx=idx)
        return os.path.join(config["record_dir"], file_name)

    def record_path_iterator(self):
        """Yield the paths of every record file the run should produce.

        These are the expected (split, shard) pairs for the config above:
        four train shards for 35000 samples at 10000 per record, one test
        shard, and no eval shard.
        """
        for split, idx in [("train", "0"), ("train", "1"), ("train", "2"),
                           ("train", "3"), ("test", "0")]:
            yield self._record_path(split, idx)

    def test_config_update(self):
        """Every overridden value appears in the run's final config."""
        run_config = self.run_object.config["mnist_serializer"]
        for key, val in self.config_updates["mnist_serializer"].items():
            self.assertEqual(run_config[key], val)

    def test_dataset_dir_exists(self):
        """The run creates the configured output directory."""
        self.assertTrue(
            os.path.exists(
                self.config_updates["mnist_serializer"]["record_dir"]))

    def test_records_exist(self):
        """Each expected record file is present on disk."""
        for record_path in self.record_path_iterator():
            self.assertTrue(os.path.exists(record_path), msg=record_path)

    def test_eval_record_does_not_exist(self):
        """A split with size 0 ("eval") produces no record file."""
        self.assertFalse(os.path.exists(self._record_path("eval", 0)))

    def test_records_size(self):
        """Record file sizes fall in the expected 100 kB buckets.

        ``train_3`` is the last train shard and holds fewer samples
        (presumably 35000 - 3 * 10000 = 5000), so it is smaller than the
        full shards. The expected buckets are empirical for the GZIP config
        above.
        """
        short_shard = self._record_path("train", "3")
        for record_path in self.record_path_iterator():
            expected_bucket = 8 if record_path == short_shard else 17
            self.assertEqual(os.path.getsize(record_path) // 10**5,
                             expected_bucket, msg=record_path)
if __name__ == "__main__":
    # Executed as a script (not imported): hand off to TensorFlow's test runner.
    tf.test.main()