Source code for handlers

import os

import tensorflow as tf
import pandas as pd 
import nibabel as nib
import numpy as np
import nilearn as nl

import scipy

import utils


class FeatureHandler(object):
    """Base class for turning one entry of a data row into a tf feature.

    Handlers are used when serializing data into tfrecords (``handle``) and
    their ``shape``, ``dtype`` and ``isfixedlen`` properties are used again
    when parsing those tfrecords.

    This base implementation is a placeholder: ``handle``, ``dtype`` and
    ``isfixedlen`` all return ``None`` and are meant to be overridden.

    Attributes:
        delegate_to (str): Name of the handler that manages the feature.
        shape (list): List of int describing the shape of the feature.
    """

    def __init__(self, delegate_to, shape):
        super(FeatureHandler, self).__init__()
        self.shape = shape
        self.delegate_to = delegate_to

    def handle(self, row, key):
        """Convert ``row[key]`` into a tf feature.

        Args:
            row (dict): Dictionary containing all features of one sample.
            key (str): Name of the feature to convert.

        Returns:
            dict: Mapping of feature key(s) to tf features; ``None`` here.
        """
        return None

    @property
    def dtype(self):
        """tf data type used when parsing the feature (``None`` here)."""
        return None

    @property
    def isfixedlen(self):
        """Whether the feature has a fixed length when parsed (``None`` here)."""
        return None
class Int64Handler(FeatureHandler):
    """Handler for integer features.

    Serializes ``row[key]`` with ``utils.int64_feature`` and reports
    ``tf.int64`` / fixed length for parsing.
    """

    def __init__(self, delegate_to=None, shape=None):
        # Build a fresh list per instance: a ``[]`` default would be a single
        # list object shared by every handler created without a shape.
        if shape is None:
            shape = []
        super(Int64Handler, self).__init__(delegate_to, shape=shape)

    def handle(self, row, key):
        """Wrap ``row[key]`` with ``utils.int64_feature``.

        Args:
            row (dict): Dictionary containing all features of one sample.
            key (str): Name of the feature to convert.

        Returns:
            dict: ``{key: feature}``.
        """
        return {key: utils.int64_feature(row[key])}

    @property
    def dtype(self):
        """tf data type used when parsing: ``tf.int64``."""
        return tf.int64

    @property
    def isfixedlen(self):
        """Always ``True``: the feature is parsed as fixed length."""
        return True
class Float32Handler(FeatureHandler):
    """Handler for float features.

    Serializes ``row[key]`` with ``utils.float32_feature`` and reports
    ``tf.float32`` / fixed length for parsing.
    """

    def __init__(self, delegate_to=None, shape=None):
        # Build a fresh list per instance: a ``[]`` default would be a single
        # list object shared by every handler created without a shape.
        if shape is None:
            shape = []
        super(Float32Handler, self).__init__(delegate_to, shape=shape)

    def handle(self, row, key):
        """Wrap ``row[key]`` with ``utils.float32_feature``.

        Args:
            row (dict): Dictionary containing all features of one sample.
            key (str): Name of the feature to convert.

        Returns:
            dict: ``{key: feature}``.
        """
        return {key: utils.float32_feature(row[key])}

    @property
    def dtype(self):
        """tf data type used when parsing: ``tf.float32``."""
        return tf.float32

    @property
    def isfixedlen(self):
        """Always ``True``: the feature is parsed as fixed length."""
        return True
class BytesHandler(FeatureHandler):
    """Handler for byte features.

    Serializes ``row[key]`` with ``utils.bytes_feature`` and reports
    ``tf.string`` / fixed length for parsing.
    """

    def __init__(self, delegate_to=None, shape=None):
        # Build a fresh list per instance: a ``[]`` default would be a single
        # list object shared by every handler created without a shape.
        if shape is None:
            shape = []
        super(BytesHandler, self).__init__(delegate_to, shape=shape)

    def handle(self, row, key):
        """Wrap ``row[key]`` with ``utils.bytes_feature``.

        Args:
            row (dict): Dictionary containing all features of one sample.
            key (str): Name of the feature to convert.

        Returns:
            dict: ``{key: feature}``.
        """
        return {key: utils.bytes_feature(row[key])}

    @property
    def dtype(self):
        """tf data type used when parsing: ``tf.string``."""
        return tf.string

    @property
    def isfixedlen(self):
        """Always ``True``: the feature is parsed as fixed length."""
        return True
class NiftiHandler(FeatureHandler):
    """Handler for nifti images.

    Loads ``<img_folder>/<row[img_key]>.nii.gz``, optionally crops it with
    ``img_slice``, resizes it to ``img_shape`` if its shape differs, and
    returns the raw image bytes as a bytes feature under ``key``. The target
    shape is added as an int64 feature under ``key + "/shape"``.

    Attributes:
        img_folder (str): Path to folder containing the nifti images.
        img_shape (list of int): Desired shape for the images.
        img_dtype (type): numpy type looked up from the ``img_dtype`` name
            given to the constructor (e.g. ``"float32"``).
        img_slice: Arguments unpacked into ``utils.slice_from`` to crop the
            image, or ``None`` to keep the full image.
        img_key (str): Key in the data row that holds the image id.
    """

    def __init__(self, img_folder, img_shape, img_dtype, img_slice=None,
                 img_key="image"):
        super(NiftiHandler, self).__init__(delegate_to=None, shape=[])
        self.img_folder = img_folder
        self.img_shape = img_shape
        self.img_dtype = getattr(np, img_dtype)
        self.img_key = img_key
        self.img_slice = img_slice

    def handle(self, row, key):
        """Load, crop/resize and serialize the image referenced by ``row``.

        Args:
            row (dict): Dictionary containing all features; ``row[img_key]``
                is the image id.
            key (str): Feature name to store the image bytes under.

        Returns:
            dict: ``{key: image bytes feature,
            key + "/shape": img_shape int64 feature}``.
        """
        img_id = row[self.img_key]
        img_path = os.path.join(self.img_folder, str(img_id) + ".nii.gz")
        # ``get_data()`` is deprecated (removed in nibabel 5); ``dataobj`` is
        # the documented replacement and still applies any scaling.
        img = np.asanyarray(nib.load(img_path).dataobj).astype(self.img_dtype)
        if self.img_slice is not None:
            img = img[utils.slice_from(*self.img_slice)]
        # Normalize to a list of ints so a tuple ``img_shape`` does not
        # compare unequal to an identical image shape and force a resize.
        target_shape = [int(dim) for dim in self.img_shape]
        if target_shape != list(img.shape):
            print("Warning: Resizing image {} from {} to {}.".format(
                img_id, img.shape, self.img_shape))
            img = self._resize(img, target_shape)
        img_feature = {}
        img_feature[key] = utils.bytes_feature(img.tobytes())
        img_feature[key + "/shape"] = utils.int64_feature(self.img_shape)
        return img_feature

    @staticmethod
    def _resize(img, target_shape):
        """Linearly resample ``img`` to ``target_shape``, keeping its dtype.

        Replaces ``scipy.misc.imresize``, which was removed from SciPy and
        converted images to uint8, so the bytes no longer matched img_dtype.

        Raises:
            ValueError: If ``target_shape`` and ``img`` differ in rank.
        """
        from scipy import ndimage

        if len(target_shape) != img.ndim:
            raise ValueError(
                "Cannot resize image of shape {} to shape {}: rank "
                "mismatch.".format(img.shape, target_shape))
        factors = [float(t) / s for t, s in zip(target_shape, img.shape)]
        # order=1 is linear interpolation, the counterpart of imresize's
        # default "bilinear" mode.
        return ndimage.zoom(img, factors, order=1).astype(img.dtype, copy=False)

    @property
    def dtype(self):
        """tf data type used when parsing: ``tf.string`` (raw bytes)."""
        return tf.string

    @property
    def isfixedlen(self):
        """Always ``True``: the feature is parsed as fixed length."""
        return True
def features_from_handlers(keys_to_handlers):
    """Build the TF parsing spec from a mapping of keys to handlers.

    For each handler, ``isfixedlen`` selects between a ``tf.FixedLenFeature``
    (built from the handler's ``shape`` and ``dtype``) and a
    ``tf.VarLenFeature`` (built from ``dtype`` only).

    Args:
        keys_to_handlers (dict): Mapping of feature keys to handlers.

    Returns:
        dict: Mapping of feature keys to TF feature descriptions.
    """
    def _to_feature(handler):
        # Variable-length features only need the dtype; fixed-length ones
        # also need the static shape.
        if not handler.isfixedlen:
            return tf.VarLenFeature(handler.dtype)
        return tf.FixedLenFeature(handler.shape, handler.dtype)

    return {key: _to_feature(handler)
            for key, handler in keys_to_handlers.items()}