import os
import numpy as np
import tensorflow as tf
from sacred import Experiment
from ingredient_wrapper import Ingredient
from .get_raw import mnist_raw_ingred, get_raw
from datasets.serialize import serializer_ingred, serialize_npz
from handlers import Int64Handler, BytesHandler
# Ingredient for serializing MNIST. It nests the raw-data ingredient and the
# generic serializer ingredient. NOTE(review): ``ingredient_wrapper.Ingredient``
# is presumably a thin wrapper over sacred's Ingredient — confirm.
mnist_serializer_ingred = Ingredient(
    "mnist_serializer",
    ingredients=[
        mnist_raw_ingred,
        serializer_ingred,
    ],
)
# Config scope holding the MNIST record-writing settings.
# NOTE(review): assuming ``ingredient_wrapper.Ingredient`` follows sacred's
# config-scope semantics, each local assignment below becomes a config entry
# and the comment line nearest above it is shown as that entry's description.
@mnist_serializer_ingred.config
def serializer_updates():
    # Directory for the generated .tfrecord files (see record_pattern).
    record_dir = "datasets/mnist/records"
    # File-name template; {split} and {idx} are presumably filled by the serializer.
    record_pattern = "mnist_{split}_{idx}.tfrecord"
    # Number of examples per split; the eval split is empty (0).
    split_to_size = {"train": 60000, "eval": 0, "test": 10000}
    # Presumably the max examples per record file (60000 = all of train) — confirm.
    samples_per_record = 60000
    # Compression name; presumably forwarded to the TFRecord writer — confirm.
    compression = "GZIP"
# Config scope describing the stored MNIST features: the image shape/dtype and,
# for each feature key ("X" = image, "y" = label), a text description and a
# handler object.
# NOTE(review): ``npz_file`` is not used in the body; presumably it is declared
# so this scope depends on that already-defined config value — confirm.
@mnist_serializer_ingred.config
def config(npz_file):
    # Shape of a single digit image (28 x 28).
    img_shape = [28, 28]
    # Dtype name of the image data.
    img_dtype = "uint8"
    # Human-readable description for each feature key.
    keys_to_descriptions = {"X": "Digit Image",
                            "y": "Digit Label, from 0 to 9."}
    # TODO: Turn classes into serialized description and parse later
    # Handler instance per feature key: BytesHandler for "X", Int64Handler for "y".
    keys_to_handlers = {"X": BytesHandler(),
                        "y": Int64Handler()}
@mnist_serializer_ingred.command
def serialize(npz_file, xy_to_key):
    """Serialize the MNIST ``.npz`` archive via ``serialize_npz``.

    If ``npz_file`` does not exist yet, ``get_raw()`` is called first
    (presumably to fetch/build the raw archive) before serializing.

    Args:
        npz_file: Path of the MNIST ``.npz`` archive; injected from config.
        xy_to_key: Forwarded unchanged to ``serialize_npz``; injected from
            config. Presumably maps archive arrays to record keys — confirm.
    """
    # Only produce the raw data when it is missing, so reruns skip that step.
    if not os.path.exists(npz_file):
        # NOTE(review): assumes get_raw() writes to this same ``npz_file``
        # path (both values come from config) — confirm.
        get_raw()
    serialize_npz(npz_file=npz_file, xy_to_key=xy_to_key)
if __name__ == '__main__':
    # When run as a script, wrap the serializer ingredient in a small
    # Experiment and hand control to its command-line runner.
    serialize_mnist_ex = Experiment(
        "Serialize_MNIST",
        ingredients=[mnist_serializer_ingred],
    )

    @serialize_mnist_ex.main
    def main():
        # ``serialize`` receives npz_file / xy_to_key from the config.
        serialize()

    serialize_mnist_ex.run_commandline()