Source code for datasets.mnist.test_serialize

import os
import tensorflow as tf

from sacred import Experiment
from shutil import rmtree
from .serialize import mnist_serializer_ingred


class TestMNISTSerializer(tf.test.TestCase):
    """Integration tests for the ``mnist_serializer`` Sacred ingredient.

    The class-level fixture runs the ``mnist_serializer.serialize`` command
    once with the overrides stored in ``config_updates``.  The individual
    tests then inspect the configuration of that run and the record files
    found under ``record_dir``.  The record directory is removed again when
    the class is torn down.
    """

    @classmethod
    def setUpClass(cls):
        """Run the serializer once and keep the run for every test."""
        super(TestMNISTSerializer, cls).setUpClass()
        cls.ex = Experiment("test_mnist_serializer",
                            ingredients=[mnist_serializer_ingred])
        cls.config_updates = {
            "mnist_serializer": {
                "record_dir": "datasets/mnist/test_records",
                "record_pattern": "mnist_{split}_{idx}.tfrecord",
                "samples_per_record": 10000,
                "split_to_size": {
                    "train": 35000,
                    "eval": 0,
                    "test": 10000
                },
                "compression": "GZIP"
            }
        }
        cls.run_object = cls.ex.run("mnist_serializer.serialize",
                                    config_updates=cls.config_updates)

    def record_path_iterator(self):
        """Yield the path of every record file the tests expect to exist.

        These are four ``train`` records (idx 0-3) followed by one ``test``
        record, each built from ``record_pattern`` and joined onto
        ``record_dir``.
        """
        serializer_config = self.config_updates["mnist_serializer"]
        expected_records = (("train", "0"), ("train", "1"), ("train", "2"),
                            ("train", "3"), ("test", "0"))
        for split, idx in expected_records:
            file_name = serializer_config["record_pattern"].format(
                split=split, idx=idx)
            yield os.path.join(serializer_config["record_dir"], file_name)

    @classmethod
    def tearDownClass(cls):
        """Delete the record directory used by the test run, if present."""
        test_output_path = cls.config_updates["mnist_serializer"]["record_dir"]
        if os.path.exists(test_output_path):
            rmtree(test_output_path)
        super(TestMNISTSerializer, cls).tearDownClass()

    def test_config_update(self):
        """Every override must appear unchanged in the run's config."""
        run_config = self.run_object.config["mnist_serializer"]
        for key, val in self.config_updates["mnist_serializer"].items():
            self.assertEqual(run_config[key], val)

    def test_dataset_dir_exists(self):
        """The configured record directory must exist after the run."""
        record_dir = self.config_updates["mnist_serializer"]["record_dir"]
        self.assertTrue(os.path.exists(record_dir),
                        msg="missing record dir: " + record_dir)

    def test_records_exist(self):
        """Every expected record file must exist."""
        for record_path in self.record_path_iterator():
            self.assertTrue(os.path.exists(record_path),
                            msg="missing record: " + record_path)

    def test_eval_record_does_not_exist(self):
        """No ``eval`` record may be present (its split size is 0)."""
        serializer_config = self.config_updates["mnist_serializer"]
        record_file = serializer_config["record_pattern"].format(
            split="eval", idx=0)
        record_path = os.path.join(serializer_config["record_dir"],
                                   record_file)
        self.assertFalse(os.path.exists(record_path),
                         msg="unexpected record: " + record_path)

    def test_records_size(self):
        """Check each record's on-disk size in units of 10**5 bytes."""
        for record_path in self.record_path_iterator():
            size_in_100kb = os.path.getsize(record_path) // 10**5
            if "train_3" in record_path:
                # NOTE(review): presumably smaller because this last train
                # record only holds the 5000-sample remainder of the 35000
                # train samples -- confirm against the serializer.
                self.assertEqual(size_in_100kb, 8)
            else:
                self.assertEqual(size_in_100kb, 17)
if __name__ == "__main__":
    # Executed as a script: hand control to TensorFlow's test runner.
    tf.test.main()