import os
import tensorflow as tf
import pandas as pd
import nibabel as nib
import numpy as np
import nilearn as nl
import scipy
import utils
class FeatureHandler(object):
    """Common interface for turning row data into tensorflow features.

    Handlers are used while data is serialized into tfrecords. Their
    properties (``shape``, ``dtype``, ``isfixedlen``) are also consulted
    when tfrecords are parsed.

    Attributes:
        delegate_to (str): Name of handler that manages the feature.
        shape (list): List of int containing the shape of the feature.
    """

    def __init__(self, delegate_to, shape):
        super(FeatureHandler, self).__init__()
        self.shape = shape
        self.delegate_to = delegate_to

    def handle(self, row, key):
        """Convert one entry of a data dict into tf feature(s).

        The base implementation does nothing and returns ``None``;
        subclasses provide the actual conversion.

        Args:
            row (dict): Dictionary containing all features.
            key (str): Name of the feature to be handled.

        Returns:
            dict: Mapping of key to tf feature (in subclasses).
        """
        return None

    @property
    def dtype(self):
        """Tf data type of the feature, used when parsing a tfrecord.

        The base implementation returns ``None``.
        """
        return None

    @property
    def isfixedlen(self):
        """Whether the handler works on fixed length features.

        Used when parsing a tfrecord. The base implementation returns
        ``None``.
        """
        return None
class Int64Handler(FeatureHandler):
    """Handler for integer features.

    Args:
        delegate_to (str, optional): Name of handler that manages the
            feature. Defaults to None.
        shape (list, optional): Shape of the feature. Defaults to a new
            empty list.
    """

    def __init__(self, delegate_to=None, shape=None):
        # A ``[]`` default is evaluated once and would be stored on every
        # handler created without an explicit shape, so mutating one
        # handler's ``shape`` would silently change all the others.
        if shape is None:
            shape = []
        super(Int64Handler, self).__init__(delegate_to, shape=shape)

    def handle(self, row, key):
        """Convert ``row[key]`` with ``utils.int64_feature``.

        Args:
            row (dict): Dictionary containing all features.
            key (str): Name of the feature to be handled.

        Returns:
            dict: ``{key: feature}`` for the requested key.
        """
        return {key: utils.int64_feature(row[key])}

    @property
    def dtype(self):
        """Tf data type used when parsing: ``tf.int64``."""
        return tf.int64

    @property
    def isfixedlen(self):
        """Always True: the feature is parsed as fixed length."""
        return True
class Float32Handler(FeatureHandler):
    """Handler for float features.

    Args:
        delegate_to (str, optional): Name of handler that manages the
            feature. Defaults to None.
        shape (list, optional): Shape of the feature. Defaults to a new
            empty list.
    """

    def __init__(self, delegate_to=None, shape=None):
        # A ``[]`` default is evaluated once and would be stored on every
        # handler created without an explicit shape, so mutating one
        # handler's ``shape`` would silently change all the others.
        if shape is None:
            shape = []
        super(Float32Handler, self).__init__(delegate_to, shape=shape)

    def handle(self, row, key):
        """Convert ``row[key]`` with ``utils.float32_feature``.

        Args:
            row (dict): Dictionary containing all features.
            key (str): Name of the feature to be handled.

        Returns:
            dict: ``{key: feature}`` for the requested key.
        """
        return {key: utils.float32_feature(row[key])}

    @property
    def dtype(self):
        """Tf data type used when parsing: ``tf.float32``."""
        return tf.float32

    @property
    def isfixedlen(self):
        """Always True: the feature is parsed as fixed length."""
        return True
class BytesHandler(FeatureHandler):
    """Handler for byte features.

    Args:
        delegate_to (str, optional): Name of handler that manages the
            feature. Defaults to None.
        shape (list, optional): Shape of the feature. Defaults to a new
            empty list.
    """

    def __init__(self, delegate_to=None, shape=None):
        # A ``[]`` default is evaluated once and would be stored on every
        # handler created without an explicit shape, so mutating one
        # handler's ``shape`` would silently change all the others.
        if shape is None:
            shape = []
        super(BytesHandler, self).__init__(delegate_to, shape=shape)

    def handle(self, row, key):
        """Convert ``row[key]`` with ``utils.bytes_feature``.

        Args:
            row (dict): Dictionary containing all features.
            key (str): Name of the feature to be handled.

        Returns:
            dict: ``{key: feature}`` for the requested key.
        """
        return {key: utils.bytes_feature(row[key])}

    @property
    def dtype(self):
        """Tf data type used when parsing: ``tf.string``."""
        return tf.string

    @property
    def isfixedlen(self):
        """Always True: the feature is parsed as fixed length."""
        return True
class NiftiHandler(FeatureHandler):
    """Handler for nifti images.

    Loads the nifti image ``<img_folder>/<img_id>.nii.gz``, optionally
    slices it, resamples it to ``img_shape`` when the shapes differ, and
    returns its raw bytes as a byte feature. Moreover, it adds
    ``img_shape`` as an int feature under ``key + "/shape"``.

    Attributes:
        img_folder (str): Path to folder containing the nifti images.
        img_shape (list of int): Desired shape for the images.
        img_dtype: Numpy type looked up by name, i.e.
            ``getattr(np, img_dtype)`` of the constructor argument.
        img_key (str): Key of the row entry that holds the image id.
        img_slice: Arguments forwarded to ``utils.slice_from`` to build
            the index applied to the image, or None to keep the whole
            image.
    """

    def __init__(self, img_folder, img_shape, img_dtype, img_slice=None,
                 img_key="image"):
        super(NiftiHandler, self).__init__(delegate_to=None, shape=[])
        self.img_folder = img_folder
        self.img_shape = img_shape
        self.img_dtype = getattr(np, img_dtype)
        self.img_key = img_key
        self.img_slice = img_slice

    def handle(self, row, key):
        """Load the image referenced by ``row`` and convert it to features.

        Args:
            row (dict): Dictionary containing all features; the image id
                is read from ``row[self.img_key]``.
            key (str): Name under which the image feature is stored.

        Returns:
            dict: ``key`` mapped to the image bytes feature and
            ``key + "/shape"`` mapped to the int feature of ``img_shape``.

        Raises:
            ValueError: If the image has to be resized but its number of
                dimensions differs from ``len(img_shape)``.
        """
        # Imported here so the submodule is guaranteed to be loaded; a
        # bare ``import scipy`` does not promise ``scipy.ndimage``.
        from scipy import ndimage

        img_id = row[self.img_key]
        img_path = os.path.join(self.img_folder, str(img_id) + ".nii.gz")
        # ``np.asanyarray(img.dataobj)`` is nibabel's documented
        # replacement for the deprecated ``get_data()``.
        img = np.asanyarray(nib.load(img_path).dataobj)
        img = img.astype(self.img_dtype)
        if self.img_slice is not None:
            img = img[utils.slice_from(*self.img_slice)]

        # Compare as lists so a tuple ``img_shape`` equal to the image
        # shape does not trigger a needless resize.
        target_shape = list(self.img_shape)
        if target_shape != list(img.shape):
            print("Warning: Resizing image {} from {} to {}.".format(
                img_id, img.shape, self.img_shape))
            if len(target_shape) != img.ndim:
                raise ValueError(
                    "Cannot resize image {} of shape {} to shape {}: "
                    "number of dimensions differs.".format(
                        img_id, img.shape, self.img_shape))
            # ``scipy.misc.imresize`` was a 2-D, PIL based helper that
            # byte-scaled to uint8 and has been removed from SciPy.
            # ``ndimage.zoom`` resamples arrays of any rank and keeps the
            # input dtype; order=1 is linear interpolation.
            factors = [float(new) / old
                       for new, old in zip(target_shape, img.shape)]
            img = ndimage.zoom(img, factors, order=1)

        img_feature = {}
        # ``tobytes`` is the supported spelling of the removed
        # ``ndarray.tostring``; the produced bytes are the same.
        img_feature[key] = utils.bytes_feature(img.tobytes())
        img_feature[key + "/shape"] = utils.int64_feature(self.img_shape)
        return img_feature

    @property
    def dtype(self):
        """Tf data type used when parsing: ``tf.string``."""
        return tf.string

    @property
    def isfixedlen(self):
        """Always True: the feature is parsed as fixed length."""
        return True
def features_from_handlers(keys_to_handlers):
    """Build the TF parsing features described by a set of handlers.

    Used for parsing of serialized data: each handler's ``isfixedlen``,
    ``shape`` and ``dtype`` are turned into a ``tf.FixedLenFeature`` or a
    ``tf.VarLenFeature``.

    Args:
        keys_to_handlers (dict): Mapping of keys to handlers.

    Returns:
        dict: Mapping of keys to TF features.
    """
    def _as_feature(handler):
        # Variable length features carry no shape, only a dtype.
        if not handler.isfixedlen:
            return tf.VarLenFeature(handler.dtype)
        return tf.FixedLenFeature(handler.shape, handler.dtype)

    return {
        name: _as_feature(handler)
        for name, handler in keys_to_handlers.items()
    }