# Source code for datasets.brain.input_fn

import os

import tensorflow as tf

from ingredient_wrapper import Ingredient
from .serialize import (brain2D_serializer_ingred, brain3D_serializer_ingred,
    serialize)
from parsers import ReshapeParser, IdentityParser, DecodeRawParser
from input_fn import input_fn_ingred, input_fn, any_record_exists


# Feeder ingredients, one per data dimensionality. Each bundles the shared
# ``input_fn_ingred`` with the serializer ingredient of matching
# dimensionality (2D or 3D).
brain2D_input_fn_ingred = Ingredient(
    "brain2D_feeder",
    ingredients=[input_fn_ingred, brain2D_serializer_ingred],
)

brain3D_input_fn_ingred = Ingredient(
    "brain3D_feeder",
    ingredients=[input_fn_ingred, brain3D_serializer_ingred],
)

# Config scope for the 2D brain feeder: declares how each serialized feature
# is parsed and which keys are used as features vs. labels.
# NOTE(review): if ``Ingredient`` wraps sacred's Ingredient, every local name
# assigned below becomes a config entry (and nearby comments may become its
# documentation), so renaming/adding locals changes the config. Sacred also
# regex-parses the ``def`` line, so keep it on one line without annotations.
@brain2D_input_fn_ingred.config
def config2D(img_shape, img_dtype):
    # ``img_shape`` / ``img_dtype`` are presumably supplied by the serializer
    # ingredient's config -- TODO confirm. ``img_shape`` must be a list: it is
    # concatenated with list literals below (a tuple would raise TypeError).

    # Parser for each serialized feature key.
    keys_to_parsers = {
        # The parent DecodeRawParser(img_dtype) presumably decodes the raw
        # bytes before reshaping to [-1] + img_shape + [1]; the leading -1 is
        # presumably an inferred batch dim and the trailing 1 a single
        # channel -- TODO confirm against parsers.ReshapeParser.
        "image": ReshapeParser([-1] + img_shape + [1],
            parent=DecodeRawParser(img_dtype)),
        "image/shape": IdentityParser(),
        "age": IdentityParser()
    }

    # Keys treated as model features and as labels, respectively
    # (presumably consumed by input_fn -- verify).
    feature_keys = ["image"]
    label_keys = ["age"]


# Config scope for the 3D brain feeder. Its body is identical to config2D;
# the only difference in the resulting image shape comes from ``img_shape``.
# NOTE(review): if ``Ingredient`` wraps sacred's Ingredient, every local name
# assigned below becomes a config entry (and nearby comments may become its
# documentation), so renaming/adding locals changes the config. Sacred also
# regex-parses the ``def`` line, so keep it on one line without annotations.
@brain3D_input_fn_ingred.config
def config3D(img_shape, img_dtype):
    # ``img_shape`` / ``img_dtype`` are presumably supplied by the serializer
    # ingredient's config -- TODO confirm. ``img_shape`` must be a list: it is
    # concatenated with list literals below (a tuple would raise TypeError).

    # Parser for each serialized feature key.
    keys_to_parsers = {
        # The parent DecodeRawParser(img_dtype) presumably decodes the raw
        # bytes before reshaping to [-1] + img_shape + [1]; the leading -1 is
        # presumably an inferred batch dim and the trailing 1 a single
        # channel -- TODO confirm against parsers.ReshapeParser.
        "image": ReshapeParser([-1] + img_shape + [1],
            parent=DecodeRawParser(img_dtype)),
        "image/shape": IdentityParser(),
        "age": IdentityParser()
    }

    # Keys treated as model features and as labels, respectively
    # (presumably consumed by input_fn -- verify).
    feature_keys = ["image"]
    label_keys = ["age"]


@brain2D_input_fn_ingred.capture
@brain3D_input_fn_ingred.capture
def brain_input_fn():
    """Return ``input_fn()``, calling ``serialize()`` first if no record exists.

    When ``any_record_exists()`` is falsy, ``serialize()`` is invoked before
    the input function is built, so a first call presumably writes the
    serialized records as a side effect -- TODO confirm against
    ``.serialize.serialize``.

    Returns:
        Whatever ``input_fn()`` returns (not visible from this module).
    """
    # Create the serialized records lazily, only when none are present yet.
    if not any_record_exists():
        serialize()
    return input_fn()