Source code for input_fn

import os
import glob
import tensorflow as tf

from ingredient_wrapper import Ingredient
from parsers import parse_serial_batch
from datasets.serialize import serializer_ingred


# Ingredient named "input_fn"; its `.config` and `.capture` decorators are
# applied to the configuration and pipeline functions defined in this module.
input_fn_ingred = Ingredient("input_fn")

@input_fn_ingred.config
def config():
    """Default configuration of the input_fn ingredient

    NOTE(review): the local names assigned here match the parameter names
    of the functions decorated with ``input_fn_ingred.capture``, so they
    are presumably injected into those functions by name (Sacred-style
    config scope) -- confirm against ``ingredient_wrapper`` before
    renaming any of them.
    """
    split = "train" # one of "train", "eval", "test"
    batch_size = 1
    n_epochs = 100
    random_seed = 1

    buffer_n_batches = 1 # buffer size for shuffle_repeat, prefetch, and RecordReader
    buffer_size = buffer_n_batches * batch_size

    num_parallel = 1 # parallel threads for map, and RecordReader

    feature_keys = [] # List of key strings to include in feature dict
    label_keys = [] # List of key strings to include in label dict
    keys_to_parsers = {} # Dict mapping feature/label keys to parsers


@input_fn_ingred.capture
def map_fn(feature_keys, label_keys, keys_to_parsers, keys_to_handlers):
    """Build the batch-parsing function handed to tf.data.Dataset.map()

    Args:
        feature_keys (list of str): Keys to copy into the feature dict
        label_keys (list of str): Keys to copy into the label dict
        keys_to_parsers (dict): Mapping of keys to parser instances
        keys_to_handlers (dict): Mapping of keys to handler instances

    Returns:
        fn (function): Maps a serial batch to a (features, labels) tuple
    """
    def fn(serial_batch):
        """Parse one serialized batch and split it into features and labels

        Args:
            serial_batch: Serialized batch as contained in tf.data.Dataset
                after batching.

        Returns:
            tuple: (feature_dict, label_dict)
        """
        parsed = parse_serial_batch(
            serial_batch, keys_to_parsers, keys_to_handlers)

        # Features are selected first, then labels; a key may appear in both.
        selected_features = dict((name, parsed[name]) for name in feature_keys)
        selected_labels = dict((name, parsed[name]) for name in label_keys)
        return (selected_features, selected_labels)

    return fn
@input_fn_ingred.capture
def shuffle_repeat_prefetch(dataset, buffer_size, n_epochs, random_seed):
    """Apply shuffle+repeat, then prefetch, to a dataset

    Args:
        dataset (tf.data.Dataset): Dataset to be processed
        buffer_size (int): Number of elements used for the shuffle buffer
            and for prefetching
        n_epochs (int): Number of dataset repetitions
        random_seed: Seed forwarded to the shuffle transformation

    Returns:
        dataset (tf.data.Dataset): Processed dataset
    """
    # The fused transformation shuffles and repeats in a single step.
    shuffle_and_repeat = tf.contrib.data.shuffle_and_repeat(
        buffer_size=buffer_size,
        count=n_epochs,
        seed=random_seed)
    return dataset.apply(shuffle_and_repeat).prefetch(buffer_size)
@input_fn_ingred.capture
def any_record_exists(split, record_dir, record_pattern):
    """Tell whether at least one record file matches the pattern

    Args:
        split (str): Split descriptor ("train", "eval", "test")
        record_dir (str): Directory where records are contained
        record_pattern (str): Pattern for record files; formatted with the
            ``split`` and ``idx`` fields

    Returns:
        bool: True, if at least one file in record_dir matches the pattern.
    """
    # idx is replaced by a wildcard so that every record index matches.
    file_name_glob = record_pattern.format(split=split, idx="*")
    matches = glob.glob(os.path.join(record_dir, file_name_glob))
    return bool(matches)
@input_fn_ingred.capture
def dataset_from_records(split, record_dir, record_pattern, random_seed):
    """Create a dataset of the record file names for a split

    Args:
        split (str): Split descriptor ("train", "eval", "test")
        record_dir (str): Directory where records are contained
        record_pattern (str): Pattern for record files; formatted with the
            ``split`` and ``idx`` fields
        random_seed: Seed forwarded to tf.data.Dataset.list_files

    Returns:
        dataset (tf.data.Dataset): Dataset of matching record file names
    """
    # idx is replaced by a wildcard so that every record index matches.
    file_name_glob = record_pattern.format(split=split, idx="*")
    full_glob = os.path.join(record_dir, file_name_glob)
    return tf.data.Dataset.list_files(full_glob, seed=random_seed)
@serializer_ingred.capture
@input_fn_ingred.capture
def input_fn(split, batch_size, buffer_size, num_parallel, compression, random_seed):
    """Generic input function for use with tf.estimator.Estimator

    Returns a input_fn as consumed by e.g. Estimator.train(input_fn())

    Args:
        split (str): Split descriptor ("train", "eval", "test")
        batch_size (int): Batch size
        buffer_size (int): Number of objects that are buffered and prefetched
        num_parallel (int): Number of parallel threads for map, and
            RecordReader
        compression (str): Which compression was used during serialization
            ("NONE", "GZIP", "ZLIB"); presumably supplied by
            serializer_ingred
        random_seed (int): control randomness i.e. reproducibility

    Ingredient functions:
        map_fn(): from input_fn_ingred, defines parsing of features.

    Returns:
        input_fn (function): Input function as consumed by
            Estimator.train/evaluate/predict
    """
    def fn():
        """Input function for Estimator

        Performs generic processing (reading, batching, mapping), followed
        by shuffle+repeat+prefetch.

        Returns:
            dataset (tf.data.Dataset)
        """
        dataset = dataset_from_records(split, random_seed=random_seed)

        # NOTE(review): TFRecordDataset documents buffer_size as the number
        # of bytes in the read buffer, while the same value is used as an
        # element count for prefetch below -- confirm this is intended.
        dataset = tf.data.TFRecordDataset(
            dataset,
            buffer_size=buffer_size,
            num_parallel_reads=num_parallel,
            # TFRecordDataset expects "" for uncompressed records.
            compression_type="" if compression == "NONE" else compression)
        dataset = dataset.prefetch(buffer_size)
        dataset = dataset.batch(batch_size, drop_remainder=True)
        dataset = dataset.map(map_fn(), num_parallel)

        # random_seed must be passed by keyword: the second positional
        # parameter of shuffle_repeat_prefetch is buffer_size, so a
        # positional seed would silently be used as the buffer size.
        dataset = shuffle_repeat_prefetch(dataset, random_seed=random_seed)
        return dataset

    return fn