import os
import numpy as np
import tensorflow as tf
import pandas as pd
from utils import PrintOnce, mkdir_and_join, to_int_size
from ingredient_wrapper import Ingredient
# Ingredient named "serializer"; its `config` and `capture` decorators are
# applied to the functions defined below in this module.
serializer_ingred = Ingredient("serializer")
@serializer_ingred.config
def config():
    # Default values for the serializer ingredient.
    # NOTE(review): presumably each local assigned here becomes a config entry
    # that is injected into the `@serializer_ingred.capture` functions by
    # parameter name -- confirm against `ingredient_wrapper.Ingredient`.
    record_dir = "" # Directory where records are saved
    # File name template for a record; `get_writer` fills `{split}` and `{idx}`
    # through `str.format`.
    record_pattern = "{split}_{idx}.tfrecord"
    # Number of samples written to each full record. `write_records` divides
    # by this value, so the `None` placeholder has to be overridden.
    samples_per_record = None
    # Number of samples to serialize for each split. `write_records` indexes
    # this mapping with "train", "eval" and "test" and does arithmetic on the
    # values, so the `None` placeholders have to be overridden.
    split_to_size = {
        "train": None,
        "eval": None,
        "test": None
    }
    # Looked up by name on `tf.python_io.TFRecordCompressionType` in
    # `get_writer`.
    compression = "NONE" # NONE, GZIP, ZLIB
@serializer_ingred.capture
def serialize_npz(npz_file,
                  split_to_size,
                  xy_to_key):
    """Create TFRecords from NPZ.

    Args:
        npz_file: File (path or file object) passed to ``np.load``; expected
            to be an ``.npz`` archive whose arrays are looked up by key.
        split_to_size: Mapping from split name to the number of rows to emit
            for that split.
        xy_to_key: Mapping with the entries ``"x"`` and ``"y"``. Each value is
            a format string that receives ``split`` and yields the key of the
            corresponding array inside the archive.

    The rows are handed to ``write_records`` through a per-split generator.
    NOTE(review): the remaining arguments of ``write_records`` are presumably
    injected by the ingredient's ``capture`` decorator -- confirm.
    """
    # An opened .npz archive holds a file handle until it is closed. The
    # context manager releases it once `write_records` has returned; the
    # generators below are consumed entirely inside that call.
    with np.load(npz_file) as data:

        def rows_gen(split):
            """Yield ``{"X": ..., "y": ...}`` rows for one split."""
            try:
                X = data[xy_to_key["x"].format(split=split)]
                y = data[xy_to_key["y"].format(split=split)]
            except KeyError:
                # A split missing from the archive only produces a message;
                # the generator then ends without yielding any row.
                print("{} split not in data.".format(split))
            else:
                n_samples = X.shape[0]
                assert len(y) == n_samples
                assert split_to_size[split] <= n_samples
                for i in range(split_to_size[split]):
                    yield {"X": X[i], "y": y[i]}

        write_records(rows_gen=rows_gen)
@serializer_ingred.capture
def write_records(record_dir,
                  record_pattern,
                  keys_to_handlers,
                  samples_per_record,
                  split_to_size,
                  rows_gen):
    """Write the "train", "eval" and "test" splits as TFRecord files.

    Args:
        record_dir: Directory part handed to ``mkdir_and_join``.
        record_pattern: File name template handed to ``mkdir_and_join``; the
            joined result is later formatted with ``split`` and ``idx``.
        keys_to_handlers: Mapping forwarded to ``write_samples``.
        samples_per_record: Number of samples stored in each full record.
        split_to_size: Mapping from split name to the number of samples to
            write for that split.
        rows_gen: Callable taking a split name and returning an iterator of
            rows.

    Every split is written as ``size // samples_per_record`` full records,
    followed by one extra record holding the remaining samples when the
    division is not exact. The size of each full record is reported through
    ``PrintOnce``; the size of the remainder record is not reported.
    """
    path_template = mkdir_and_join(record_dir, record_pattern)
    size_report = PrintOnce("TFRecord has size {} MB")

    for split in ("train", "eval", "test"):
        rows = rows_gen(split)
        full_records, leftover = div_mod(split_to_size[split],
                                         samples_per_record)

        for record_idx in range(full_records):
            writer, path = get_writer(path_template, split, record_idx)
            write_samples(samples_per_record, rows, keys_to_handlers, writer)
            size_report.print(os.path.getsize(path) // 10**6)

        if leftover <= 0:
            continue

        # The remainder record takes the index right after the last full one.
        writer, _ = get_writer(path_template, split, full_records)
        write_samples(leftover, rows, keys_to_handlers, writer)
@serializer_ingred.capture
def get_writer(record_pattern, split, idx, compression):
    """Open a TFRecord writer for one record file.

    Args:
        record_pattern: Path template formatted with ``split`` and ``idx``.
        split: Split name inserted into the path.
        idx: Record index inserted into the path.
        compression: Name of an attribute of
            ``tf.python_io.TFRecordCompressionType``.

    Returns:
        A ``(writer, record_path)`` tuple: the ``TFRecordWriter`` and the path
        it writes to.
    """
    target_path = record_pattern.format(split=split, idx=idx)
    compression_type = getattr(tf.python_io.TFRecordCompressionType,
                               compression)
    options = tf.python_io.TFRecordOptions(compression_type)
    record_writer = tf.python_io.TFRecordWriter(target_path, options)
    return record_writer, target_path
def write_samples(n_samples, rgen, keys_to_handlers, writer):
    """Write the next ``n_samples`` rows of ``rgen`` and close ``writer``.

    Args:
        n_samples: Number of rows to pull from ``rgen``.
        rgen: Iterator of rows; advanced with ``next``.
        keys_to_handlers: Mapping forwarded to ``write_row``.
        writer: Object with a ``close`` method; it is also forwarded to
            ``write_row``.

    Raises:
        StopIteration: If ``rgen`` runs out before ``n_samples`` rows were
            read. Any other exception raised while writing a row propagates
            as well.
    """
    try:
        for _ in range(n_samples):
            write_row(next(rgen), keys_to_handlers, writer)
    finally:
        # Close on every path so a failing row does not leave the writer open.
        writer.close()
def features_from_row(row, keys_to_handlers):
    """Build a ``tf.train.Features`` message for one row.

    Args:
        row: The row handed to each handler's ``handle`` method.
        keys_to_handlers: Mapping from key to handler. Only handlers whose
            ``delegate_to`` attribute is ``None`` are used.

    Returns:
        ``tf.train.Features`` holding every entry returned by the used
        handlers; entries produced later replace earlier ones with the same
        name.
    """
    collected = {}
    for key, handler in keys_to_handlers.items():
        if handler.delegate_to is not None:
            continue
        for name, feature in handler.handle(row, key).items():
            collected[name] = feature
    return tf.train.Features(feature=collected)
def write_row(row, keys_to_handlers, writer):
    """Serialize one row as a ``tf.train.Example`` and write it.

    Args:
        row: The row converted by ``features_from_row``.
        keys_to_handlers: Mapping forwarded to ``features_from_row``.
        writer: Object whose ``write`` method receives the serialized example.
    """
    example = tf.train.Example(
        features=features_from_row(row, keys_to_handlers))
    serialized = example.SerializeToString()
    writer.write(serialized)
def div_mod(x, y):
    """Return the pair ``(x // y, x % y)``."""
    quotient, remainder = divmod(x, y)
    return quotient, remainder