# Source code for datasets.mnist.serialize

import os
import numpy as np
import tensorflow as tf

from sacred import Experiment
from ingredient_wrapper import Ingredient
from .get_raw import mnist_raw_ingred, get_raw
from datasets.serialize import serializer_ingred, serialize_npz
from handlers import  Int64Handler, BytesHandler


# Ingredient for MNIST serialization; it is composed of the raw-MNIST
# ingredient and the generic serializer ingredient imported above.
mnist_serializer_ingred = Ingredient(
    "mnist_serializer",
    ingredients=[mnist_raw_ingred, serializer_ingred],
)


@mnist_serializer_ingred.config
def serializer_updates():
    """Config values for writing the MNIST records."""
    # NOTE(review): assuming Sacred config-scope semantics, every local
    # assigned here becomes a config entry of this ingredient, so the names
    # below must not be renamed -- confirm against ingredient_wrapper.
    record_dir = "datasets/mnist/records"
    # Filename template with "{split}" and "{idx}" placeholders.
    record_pattern = "mnist_{split}_{idx}.tfrecord"

    # Presumably the number of samples in each split; "eval" is 0 here.
    split_to_size = {"train": 60000, "eval": 0, "test": 10000}
    samples_per_record = 60000
    compression = "GZIP"

@mnist_serializer_ingred.config
def config(npz_file):
    """Config values describing the MNIST samples and their per-key handlers."""
    # NOTE(review): `npz_file` is not used in this body; presumably it is
    # declared so that an already-configured value is required/injected --
    # confirm against ingredient_wrapper before removing it.

    img_shape = [28, 28]
    img_dtype = "uint8"

    # Human-readable description for each key ("X" and "y").
    keys_to_descriptions = {"X": "Digit Image",
                            "y": "Digit Label, from 0 to 9."}

    # TODO: Turn classes into serialized description and parse later
    keys_to_handlers = {"X": BytesHandler(),
                        "y": Int64Handler()}


@mnist_serializer_ingred.command
def serialize(npz_file, xy_to_key):
    """Serialize the MNIST ``.npz`` file, fetching the raw data first if needed.

    Args:
        npz_file: Path checked with ``os.path.exists`` and then forwarded to
            ``serialize_npz``.
        xy_to_key: Forwarded unchanged to ``serialize_npz``.
    """
    # When the npz file is not on disk yet, obtain the raw data before
    # attempting to serialize it.
    if not os.path.exists(npz_file):
        get_raw()
    serialize_npz(npz_file=npz_file, xy_to_key=xy_to_key)
if __name__ == '__main__':
    # Script entry point: wrap the serializer ingredient in an experiment
    # whose main function runs the `serialize` command.
    ex = Experiment("Serialize_MNIST", ingredients=[mnist_serializer_ingred])

    @ex.main
    def main():
        serialize()

    ex.run_commandline()