import os
import glob
import tensorflow as tf
from ingredient_wrapper import Ingredient
from parsers import parse_serial_batch
from datasets.serialize import serializer_ingred
# Ingredient named "input_fn"; the config function and the captured helper
# functions in this module are all registered on it through its decorators.
input_fn_ingred = Ingredient("input_fn")
# Default settings for the "input_fn" ingredient.
# NOTE(review): the local names assigned below presumably become the config
# entries handed to the `@input_fn_ingred.capture` functions in this module,
# so they must not be renamed -- confirm against `ingredient_wrapper.Ingredient`.
@input_fn_ingred.config
def config():
    split = "train"  # one of "train", "eval", "test"
    batch_size = 1
    n_epochs = 100
    random_seed = 1
    buffer_n_batches = 1  # buffer size for shuffle_repeat, prefetch, and RecordReader
    # Derived value: buffer size in elements (number of batches times batch size).
    buffer_size = buffer_n_batches * batch_size
    num_parallel = 1  # parallel threads for map, and RecordReader
    feature_keys = []  # List of key strings to include in feature dict
    label_keys = []  # List of key strings to include in label dict
    keys_to_parsers = {}  # Dict mapping feature/label keys to parsers
@input_fn_ingred.capture
def map_fn(feature_keys, label_keys, keys_to_parsers, keys_to_handlers):
    """Build the parsing function used with ``tf.data.Dataset.map()``.

    Args:
        feature_keys (list of str): Keys which should be included in feature_dict
        label_keys (list of str): Keys which should be included in label_dict
        keys_to_parsers (dict): Mapping of keys to parser instances
        keys_to_handlers (dict): Mapping of keys to handler instances

    Returns:
        fn (function): Function which maps a serial batch to (features, labels)
    """

    def fn(serial_batch):
        """Parse a serial batch and split the result into features and labels.

        This function is used as argument for tf.data.Dataset.map()

        Args:
            serial_batch: Serialized batch as contained in tf.data.Dataset
                after batching.

        Returns:
            tuple: of feature_dict and label_dict
        """
        parsed = parse_serial_batch(serial_batch,
                                    keys_to_parsers,
                                    keys_to_handlers)
        # Both dicts are cut from the same parsed result, so a key listed in
        # feature_keys and in label_keys appears in both outputs.
        feature_dict = {key: parsed[key] for key in feature_keys}
        label_dict = {key: parsed[key] for key in label_keys}
        return (feature_dict, label_dict)

    return fn
@input_fn_ingred.capture
def shuffle_repeat_prefetch(dataset, buffer_size, n_epochs, random_seed):
    """Shuffle, repeat, and prefetch a dataset.

    Args:
        dataset (tf.data.Dataset): Dataset to be processed
        buffer_size (int): Number of objects to be buffered, and prefetched;
            the same value is used for the shuffle buffer and for prefetch
        n_epochs (int): Number of dataset repetitions
        random_seed: control randomness i.e. reproducibility

    Returns:
        dataset (tf.data.Dataset): Processed dataset
    """
    # NOTE: tf.contrib exists only in TensorFlow 1.x; it was removed in 2.x.
    # The call is kept unchanged so behavior on the TF version this module
    # targets stays exactly as before.
    shuffle_and_repeat = tf.contrib.data.shuffle_and_repeat(
        buffer_size=buffer_size,
        count=n_epochs,
        seed=random_seed)
    dataset = dataset.apply(shuffle_and_repeat)
    return dataset.prefetch(buffer_size)
@input_fn_ingred.capture
def any_record_exists(split, record_dir, record_pattern):
    """Indicate whether any record of the given pattern exists.

    Args:
        split (str): Split descriptor ("train", "eval", "test")
        record_dir (str): Directory where records are contained
        record_pattern (str): Pattern for record files; formatted with the
            keyword arguments ``split`` and ``idx``

    Returns:
        bool: True, if at least one record with record_pattern exists in record_dir.
    """
    # idx is filled with "*" so the glob matches every record index of the split.
    search_pattern = os.path.join(
        record_dir, record_pattern.format(split=split, idx="*"))
    # iglob is lazy: stop at the first match instead of listing every file.
    return next(glob.iglob(search_pattern), None) is not None
@input_fn_ingred.capture
def dataset_from_records(split, record_dir, record_pattern, random_seed):
    """Load all records of a split as a dataset of file names.

    Args:
        split (str): Split descriptor ("train", "eval", "test")
        record_dir (str): Directory where records are contained
        record_pattern (str): Pattern for record files; formatted with the
            keyword arguments ``split`` and ``idx``
        random_seed: control randomness i.e. reproducibility; forwarded as
            ``seed`` to ``tf.data.Dataset.list_files``

    Returns:
        dataset (tf.data.Dataset): List of record_file strings
    """
    # idx is filled with "*" so the pattern matches every record index of the split.
    file_pattern = os.path.join(
        record_dir, record_pattern.format(split=split, idx="*"))
    return tf.data.Dataset.list_files(file_pattern, seed=random_seed)