import os
import tensorflow as tf
from sacred import Experiment
from shutil import rmtree
from .serialize import brain3D_serializer_ingred
class TestBrainSerializer(tf.test.TestCase):
    """End-to-end test of the ``brain3D_serializer.serialize`` command.

    The command is run once for the whole class in ``setUpClass`` using the
    settings in ``config_updates``.  The individual tests then inspect the
    returned run object and the record files expected under ``record_dir``.
    ``tearDownClass`` removes the ``record_dir`` tree again.
    """

    @classmethod
    def setUpClass(cls):
        """Build the experiment and run the serialize command once."""
        # unittest does not chain class-level fixtures automatically; call up
        # explicitly so that any setup done by the base classes still runs.
        super(TestBrainSerializer, cls).setUpClass()
        cls.ex = Experiment("test_brain3D_serializer",
                            ingredients=[brain3D_serializer_ingred])
        cls.config_updates = {
            "brain3D_serializer": {
                "raw_dir": "datasets/brain/test_data",
                "csv_file": "datasets/brain/test_data/meta_data.csv",
                "record_dir": "datasets/brain/test_records",
                "record_pattern": "brain_{split}_{idx}.tfrecord",
                "samples_per_record": 2,
                "split_to_size": {
                    "train": 0.6,
                    "eval": 0.2,
                    "test": 0.2
                },
                "compression": "GZIP"
            }
        }
        cls.run_object = cls.ex.run("brain3D_serializer.serialize",
                                    config_updates=cls.config_updates)

    def record_path_iterator(self):
        """Yield the path of every record file the tests expect to find.

        Each path is built by filling ``record_pattern`` with one expected
        ``(split, idx)`` pair and joining the result onto ``record_dir``.
        """
        serializer_config = self.config_updates["brain3D_serializer"]
        # NOTE(review): this shard list presumably follows from the amount of
        # test data combined with samples_per_record / split_to_size -- confirm
        # against the fixture before changing either.
        expected_shards = [("train", "0"), ("train", "1"), ("train", "2"),
                           ("eval", "0"), ("test", "0")]
        for split, idx in expected_shards:
            file_name = serializer_config["record_pattern"].format(
                split=split, idx=idx)
            yield os.path.join(serializer_config["record_dir"], file_name)

    @classmethod
    def tearDownClass(cls):
        """Delete the ``record_dir`` tree if it exists."""
        test_output_path = (cls.config_updates["brain3D_serializer"]
                            ["record_dir"])
        try:
            if os.path.exists(test_output_path):
                rmtree(test_output_path)
        finally:
            # Run the base-class teardown even when the cleanup above fails.
            super(TestBrainSerializer, cls).tearDownClass()

    def test_config_update(self):
        """Every requested config value must appear in the run's config."""
        applied = self.run_object.config["brain3D_serializer"]
        for key, val in self.config_updates["brain3D_serializer"].items():
            self.assertEqual(
                applied[key], val,
                msg="config key {!r} was not applied".format(key))

    def test_dataset_dir_exists(self):
        """The configured ``record_dir`` must exist after the run."""
        self.assertTrue(
            os.path.exists(
                self.config_updates["brain3D_serializer"]["record_dir"]))

    def test_records_exist(self):
        """Every expected record file must exist on disk."""
        for record_path in self.record_path_iterator():
            self.assertTrue(
                os.path.exists(record_path),
                msg="missing record file: {}".format(record_path))

    def test_records_size(self):
        """Every record file's size, floor-divided by 10**6, must equal 3."""
        for record_path in self.record_path_iterator():
            self.assertEqual(
                os.path.getsize(record_path) // 10**6, 3,
                msg="unexpected size for {}".format(record_path))
# Run the TensorFlow test runner when this module is executed as a script.
if __name__ == '__main__':
    tf.test.main()