"""Input function for the MNIST dataset (module ``datasets.mnist.input_fn``)."""

import os

import tensorflow as tf

from ingredient_wrapper import Ingredient
from .serialize import mnist_serializer_ingred, serialize
from parsers import ReshapeParser, IdentityParser, DecodeRawParser
from input_fn import input_fn_ingred, input_fn, any_record_exists


# Ingredient named "mnist_feeder", built on top of the generic input_fn
# ingredient and the MNIST serializer ingredient (passed as sub-ingredients).
mnist_input_fn_ingred = Ingredient(
    "mnist_feeder",
    ingredients=[input_fn_ingred, mnist_serializer_ingred],
)

@mnist_input_fn_ingred.config
def config(img_shape, img_dtype):
    """Config scope for the ``mnist_feeder`` ingredient.

    NOTE(review): assumes ``ingredient_wrapper.Ingredient`` follows the
    sacred ``Ingredient`` API. Under that API, every local variable assigned
    in this body becomes a config entry, and the parameters are filled from
    config values that are already available (presumably ``img_shape`` and
    ``img_dtype`` come from ``mnist_serializer_ingred``). Confirm this. If it
    holds, do not introduce helper locals here (each would become a config
    entry) and do not annotate the signature.
    """

    # Per-record-key parsers. "X" uses a ReshapeParser that wraps
    # DecodeRawParser(img_dtype), with target shape [-1] + img_shape + [1]
    # (presumably (batch, *img_shape, 1 channel); verify against `parsers`).
    # img_shape must be a list, since it is concatenated with list literals.
    # "y" uses IdentityParser (presumably a pass-through; see `parsers`).
    keys_to_parsers = {
        "X": ReshapeParser([-1]+img_shape+[1], DecodeRawParser(img_dtype)),
        "y": IdentityParser()
    }

    # Record keys used as features and as labels, respectively
    # (presumably consumed by input_fn; confirm).
    feature_keys = ["X"]
    label_keys = ["y"]


@mnist_input_fn_ingred.capture
def mnist_input_fn():
    """Return the result of ``input_fn()``, serializing the data first if needed.

    If ``any_record_exists()`` reports that no records exist, ``serialize()``
    is called before building the input function (presumably this writes the
    MNIST records that ``input_fn`` then reads; confirm against
    ``.serialize``).

    Returns:
        Whatever ``input_fn()`` returns.
    """
    if not any_record_exists():
        # Lazily create the serialized dataset on first use.
        serialize()
    return input_fn()