Source code for utils

import os

import tensorflow as tf
import numpy as np
import pandas as pd

from itertools import product

# Short module alias for TF-GAN; used below by img_grid_summary
# (tfgan.eval.image_grid).
# NOTE(review): tf.contrib only exists in TensorFlow 1.x, so this line
# presumably ties the module to TF 1.x — confirm the supported TF version.
tfgan = tf.contrib.gan

def bytes_feature(value):
    """Convert value to a TF bytes feature.

    Used during serialization of features.

    Args:
        value: bytes, a list/tuple of bytes, or an np.ndarray. An array is
            serialized as a single bytes string holding its raw C-order
            buffer (dtype and shape are not stored).

    Returns:
        tf.train.Feature: Feature with a populated ``bytes_list``.

    Raises:
        ValueError: If value, or any element of a list/tuple, has an
            unsupported type.
    """
    if isinstance(value, (list, tuple)):
        if not all(isinstance(x, bytes) for x in value):
            raise ValueError("bytes_feature expects list of bytes")
    elif isinstance(value, bytes):
        value = [value]
    elif isinstance(value, np.ndarray):
        # tobytes(order="C") yields exactly the bytes of flatten().tostring();
        # tostring is a deprecated alias of tobytes.
        value = [value.tobytes(order="C")]
    else:
        raise ValueError("bytes_feature got bad value")
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
def int64_feature(value):
    """Convert value to a TF int64 feature.

    Used during serialization of features.

    Args:
        value: int, NumPy integer scalar, a list/tuple of those, or an
            np.ndarray with an integer dtype (flattened in C order).

    Returns:
        tf.train.Feature: Feature with a populated ``int64_list``.

    Raises:
        ValueError: If value has an unsupported type, an array has a
            non-integer dtype, or a list/tuple holds a non-integer element.
    """
    if isinstance(value, np.ndarray):
        if not np.issubdtype(value.dtype, np.integer):
            raise ValueError("int64_feature requires dtype int")
        value = [int(x) for x in value.flatten()]
    elif isinstance(value, (list, tuple)):
        # Accept NumPy integer scalars inside sequences too, consistent with
        # the scalar branch below; normalize everything to Python ints.
        if not all(isinstance(x, (int, np.integer)) for x in value):
            raise ValueError("int64_feature expects all elements as int")
        value = [int(x) for x in value]
    elif isinstance(value, (int, np.integer)):
        value = [int(value)]
    else:
        raise ValueError("int64_feature got bad value")
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def float32_feature(value):
    """Convert value to a TF float32 feature.

    Used during serialization of features.

    Args:
        value: float, NumPy floating scalar, a list/tuple of those, or an
            np.ndarray (flattened in C order, each entry passed to float()).

    Returns:
        tf.train.Feature: Feature with a populated ``float_list``.

    Raises:
        ValueError: If value has an unsupported type or a list/tuple holds
            a non-float element (Python ints are rejected).
    """
    if isinstance(value, np.ndarray):
        value = [float(x) for x in value.flatten()]
    elif isinstance(value, (list, tuple)):
        # np.float32/np.float16 are not subclasses of float; accept them in
        # sequences just like the scalar branch does.
        if not all(isinstance(x, (float, np.floating)) for x in value):
            raise ValueError("float32_feature expects all elements as float")
        value = [float(x) for x in value]
    elif isinstance(value, (float, np.floating)):
        value = [float(value)]
    else:
        raise ValueError("float32_feature got bad value")
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
class PrintOnce:
    """Print a formatted message at most once per instance.

    Args:
        message: Format string, filled in via ``str.format`` with the
            arguments of the first successful ``print`` call.
    """

    def __init__(self, message):
        super().__init__()
        self.message = message
        self._printed = False

    def print(self, *args, **kwargs):
        """Print ``message.format(*args, **kwargs)`` on the first call only.

        Subsequent calls do nothing, whatever arguments they receive.
        """
        if self._printed:
            return
        # Inside a method, the name ``print`` resolves to the builtin, not to
        # this method (class scope is not enclosing for functions).
        print(self.message.format(*args, **kwargs))
        self._printed = True
def mkdir_and_join(record_dir, record_pattern):
    """Ensure the record directory exists and build the full record pattern.

    Args:
        record_dir (str): Path to record directory. Missing parent
            directories are created as well; "" means the current directory.
        record_pattern (str): Record file pattern.

    Returns:
        (str): Full record file pattern, including directory.

    Raises:
        FileExistsError: If record_dir exists but is not a directory.
    """
    if record_dir:
        # exist_ok avoids the race of a separate exists() check, and makedirs
        # (unlike mkdir) also creates missing parent directories.
        os.makedirs(record_dir, exist_ok=True)
    return os.path.join(record_dir, record_pattern)
def to_int_size(split_to_size, n_samples):
    """Convert split sizes into absolute sample counts.

    Args:
        split_to_size (dict): Mapping of split name to size. Integer sizes
            (Python or NumPy ints) are absolute counts and kept as-is; any
            other size is treated as a fraction of n_samples.
        n_samples (int): Total number of samples that fractions refer to.

    Returns:
        dict: Mapping of split name to int count. Fractions are converted
        with int(size * n_samples), i.e. truncated. If every size already is
        an integer, split_to_size itself is returned unchanged.

    Examples:
        >>> to_int_size({"train": 0.8, "test": 0.2}, 10)
        {'train': 8, 'test': 2}
        >>> to_int_size({"train": 0.5, "test": 3}, 10)
        {'train': 5, 'test': 3}
    """
    def _is_count(size):
        # np.integer is not a subclass of int, so check it explicitly.
        return isinstance(size, (int, np.integer))

    if all(_is_count(size) for size in split_to_size.values()):
        return split_to_size
    # Convert per value so absolute counts in a mixed mapping are not
    # mistakenly multiplied by n_samples.
    return {
        key: int(val) if _is_count(val) else int(val * n_samples)
        for key, val in split_to_size.items()
    }
def flip_idx(series):
    """Return index of first flip from "1" to "0".

    It is assumed that the series starts with a block of "1"; the index of
    the last element of that block is returned, otherwise -1.

    Args:
        series: list, tuple, np.array (including boolean arrays), or pd.Series.
            A Series is read positionally; its index labels are ignored.

    Examples:
        >>> flip_idx([1,1,1,0,0])
        2
        >>> flip_idx([1,1,0,1,1])
        1
        >>> flip_idx([1,1,1,1,1])  # No flip!
        -1
        >>> flip_idx([0,0,1,1,1])  # No flip "1" to "0"!
        -1

    Returns:
        int: Index of the last "1" before flipping to "0", -1 if no flip
        from "1" to "0".
    """
    values = np.asarray(series)
    if values.dtype == bool:
        # NumPy forbids `-` on boolean arrays; 0/1 integers give the same result.
        values = values.astype(np.int8)
    if ((values == 1).all() or (values == 0).all() or values[0] == 0
            or len(values) == 1):
        return -1
    # For a 0/1 series starting with 1, the pairwise difference is 1 exactly
    # at a 1->0 step; argmax returns the first such position.
    return int((values[:-1] - values[1:]).argmax())
def is_one_after(series, idx):
    """Check if all elements are "1" after idx.

    Args:
        series: list, tuple, np.array, or pd.Series. A Series is read
            positionally; its index labels are ignored.
        idx (int): Last index before check (not included). If nothing
            follows idx, the result is vacuously True.

    Returns:
        bool: True, if all elements after idx are "1".

    Examples:
        >>> is_one_after([0,0,1,1,1], 1)
        True
        >>> is_one_after([0,0,1,1,0], 1)
        False
    """
    values = np.asarray(series)
    # Compare element-wise instead of summing: a sum also matches
    # non-binary tails such as [2, 0, 1].
    return bool((values[idx + 1:] == 1).all())
def hasflip(df, sort_key="age", from_key="mci", to_key="ad", within=None):
    """Detect binary flip between two Series.

    After sorting df by sort_key, check that from_key starts as a block of
    1s that flips to 0 (see flip_idx), and that every to_key value from the
    first row after that block onward is 1 (see is_one_after).

    Args:
        df: Dataframe containing sort_key, from_key and to_key.
        sort_key (str): Column that orders the rows. Default: "age".
        from_key (str): Column expected to flip from 1 to 0. Default: "mci".
        to_key (str): Column expected to be 1 after the flip. Default: "ad".
        within: If not None, the flip only counts when the sort_key value
            of the first row after the flip exceeds the sort_key value of
            the first row by at most `within`.

    Returns:
        bool: True, if flip occurs.

    Examples:
        >>> w = pd.DataFrame({"age": [1,2,3,4], "mci": [1,1,0,0], "ad": [0,0,1,1]})
        >>> hasflip(w)
        True
        >>> x = pd.DataFrame({"age": [1,2,3,4], "mci": [1,1,1,1], "ad": [0,0,0,0]})
        >>> hasflip(x)
        False
        >>> y = pd.DataFrame({"age": [1,2,3,4], "mci": [1,1,0,1], "ad": [0,0,1,0]})
        >>> hasflip(y)
        False
        >>> z = pd.DataFrame({"age": [1,2,3,4], "mci": [1,1,1,0], "ad": [0,0,0,1]})
        >>> hasflip(z, within=2)
        False
    """
    # A stable sort keeps rows with equal sort_key in input order, so ties
    # do not depend on the sorting algorithm.
    df = df.sort_values(by=[sort_key], kind="mergesort")
    fidx = flip_idx(df[from_key])
    # Check for "no flip" first: the `within` lookup below needs fidx >= 0
    # and would raise IndexError on an empty frame.
    if fidx == -1:
        return False
    if within is not None:
        elapsed = df[sort_key].iloc[fidx + 1] - df[sort_key].iloc[0]
        if elapsed > within:
            return False
    return bool(is_one_after(df[to_key], fidx))
def spans(df, key, mode, span):
    """Check if a numeric column spans a certain range.

    The span of the column is max(df[key]) - min(df[key]).

    Args:
        df (Dataframe): Dataframe containing samples
        key (str): Numeric column in df
        mode (str): one of "at_least", "at_most", "more_than", "less_than".
        span (int, float): Numeric range that is tested.

    Returns:
        bool: True, if numeric span of key complies with mode.

    Raises:
        ValueError: If mode is not one of the supported modes.

    Examples:
        >>> x = pd.DataFrame({"age": [1,2,3,4]})
        >>> spans(x, "age", "at_least", 2)
        True
        >>> spans(x, "age", "at_most", 2)
        False
        >>> spans(x, "age", "more_than", 2)
        True
        >>> spans(x, "age", "less_than", 2)
        False
    """
    import operator

    comparisons = {
        "at_least": operator.ge,
        "at_most": operator.le,
        "more_than": operator.gt,
        "less_than": operator.lt,
    }
    MODES = list(comparisons)
    if mode not in MODES:
        raise ValueError("mode not in {}".format(MODES))
    # min/max need no sorting; bool() turns np.bool_ into a plain bool.
    column = df[key]
    return bool(comparisons[mode](column.max(), column.min() + span))
def filter_groups(df, class_to_filter, group_key=None):
    """Separate dataframe into classes defined by filter.

    Add a class column to df, indicating class defined by filter functions.
    The input df is not modified.

    Args:
        df: Dataframe containing samples
        class_to_filter (dict): Mapping of class names to filter functions.
            Filter functions take a Dataframe as input and return bool,
            which indicates if sample belongs to class.
        group_key (str): Group df by group_key, if not None. If None, the
            whole df is treated as a single group.

    Returns:
        Dataframe: Rows of every matching class, concatenated in the order
        of class_to_filter, extended by a column "class". A row appears once
        per class whose filter accepts its group.

    Raises:
        ValueError: If class_to_filter is empty (pd.concat has nothing to
            concatenate).

    Examples:
        >>> df = pd.DataFrame({"subject": [1,1,2,2,3,3], "gender": [1,1,0,0,0,1]})
        >>> class_to_filter = {"male": lambda x: x["gender"].all(),
        ...                    "female": lambda x: not x["gender"].any(),
        ...                    "trans": lambda x: not x["gender"].all() and x["gender"].any()}
        >>> filter_groups(df, class_to_filter, "subject")
           subject  gender  class
        0        1       1   male
        1        1       1   male
        2        2       0 female
        3        2       0 female
        4        3       0  trans
        5        3       1  trans
    """
    grouped = df.groupby(group_key) if group_key is not None else None
    cls_list = []
    for cls_key, cls_filter in class_to_filter.items():
        if grouped is not None:
            cls_df = grouped.filter(cls_filter)
        else:
            # Without grouping, the whole frame is one group, mirroring
            # GroupBy.filter: keep all rows or none.
            cls_df = df if cls_filter(df) else df.iloc[0:0]
        # assign returns a new frame, so no chained-assignment warning can
        # arise and no global pandas option has to be changed.
        cls_list.append(cls_df.assign(**{"class": cls_key}))
    return pd.concat(cls_list)
def build_pairs(df, cond=lambda x, y: True, no_duplicate=None):
    """Cartesian Product of dataframe.

    Produce all combinations of df x df where cond is true.

    Args:
        df: Dataframe
        cond: function with two arguments, returning bool. Each argument is
            a one-row DataFrame (with the row's original index label and an
            extra "dummy" column). Row frames are shared between calls, so
            cond must not modify them.
        no_duplicate (list of str): Keys which should not be duplicated: the
            "_0" copy is kept under the plain key and the "_1" copy is dropped.

    Returns:
        DataFrame: Containing all pairs of df x df, where cond was true, with
        columns suffixed "_0" (first row) and "_1" (second row). None if no
        pair satisfies cond.

    Examples:
        >>> df = pd.DataFrame({"age": [1,2,3], "gender": ["m", "f", "m"]})
        >>> cond = lambda x,y: y["age"].values > x["age"].values
        >>> build_pairs(df, cond)
           age_0 gender_0  age_1 gender_1
        0      1        m      2        f
        1      1        m      3        m
        2      2        f      3        m
    """
    if no_duplicate is None:
        no_duplicate = []
    # Slice every row once. iloc[[i]] keeps the column dtypes, unlike
    # row.to_frame().T, which turns every column into object dtype.
    rows = [df.iloc[[i]].assign(dummy="") for i in range(len(df))]
    pairs = []
    for img1, img2 in product(rows, repeat=2):
        if cond(img1, img2):
            # Merging on the constant "dummy" column joins the two rows side by side.
            pairs.append(img1.merge(img2, on="dummy", how="outer",
                                    suffixes=("_0", "_1")))
    if not pairs:
        return None
    pairs = pd.concat(pairs, ignore_index=True).drop(columns=["dummy"])
    if no_duplicate:
        pairs = pairs.rename(columns={key + "_0": key for key in no_duplicate})
        pairs = pairs.drop(columns=[key + "_1" for key in no_duplicate])
    return pairs
def aggregate(df, class_to_filter, group_key=None, builder=None):
    """Aggregate class samples from dataframe.

    Group rows of df by group_key, and use class filters to annotate each
    row. If builder is provided, group annotated rows again by class (and
    group_key, if given), and apply builder to each group.

    Args:
        df: Dataframe
        class_to_filter (dict): Mapping of class names to filter functions.
            Filter functions take a Dataframe as input and return bool,
            which indicates if sample belongs to class.
        group_key (str): Which columns to use for grouping.
        builder: Function that accepts Dataframe, and returns a processed
            Dataframe.

    Returns:
        Dataframe: Aggregated Dataframe, where each row represents one
        sample. The column "class" indicates the class of the sample.
    """
    class_df = filter_groups(df, class_to_filter, group_key)
    if builder is not None:
        # Grouping by ["class", None] would fail, so only add group_key
        # when one was given.
        by = ["class"] if group_key is None else ["class", group_key]
        class_df = class_df.groupby(by, group_keys=False).apply(builder)
        # If builder moved group_key into the index while it also remains a
        # column, drop the redundant index level.
        if (group_key is not None
                and group_key in class_df.index.names
                and group_key in class_df.columns):
            class_df = class_df.reset_index(group_key, drop=True)
    return class_df
def ispowerof2(x):
    """Check if x is a power of 2.

    Args:
        x (int): Integer to test; NumPy integer scalars work as well.

    Returns:
        bool: True if x == 2**k for some k >= 0.

    Examples:
        >>> ispowerof2(64)
        True
        >>> ispowerof2(63)
        False
        >>> ispowerof2(0)
        False
    """
    # x & (x - 1) clears the lowest set bit, so it is 0 only when x has a
    # single set bit. 0 passes that bit test too and must be excluded.
    return bool(x > 0 and (x & (x - 1)) == 0)
def approx_square(x):
    """Calculate even grid height and width.

    Assumes that x is a power of 2!

    Args:
        x (int): Assumed to be a power of 2

    Returns:
        list: [h, w] of Python ints, both powers of 2, such that h * w == x,
        h <= w, and both as close to sqrt(x) as possible.

    Raises:
        ValueError: If x is not a positive power of 2.

    Examples:
        >>> approx_square(32)
        [4, 8]
        >>> approx_square(64)
        [8, 8]
        >>> approx_square(128)
        [8, 16]
    """
    if not ispowerof2(x) or x < 1:
        raise ValueError("Expected power of 2, got {}".format(x))
    # x == 2**exponent; split the exponent as evenly as possible. Integer
    # arithmetic avoids float log2 rounding and NumPy scalar return types.
    exponent = int(x).bit_length() - 1
    h = 2 ** (exponent // 2)
    w = 2 ** (exponent - exponent // 2)
    return [h, w]
def grid_size_from(tensor, axis=0):
    """Approximate grid size [h, w] for the static length of one tensor axis.

    The static length of `axis` must be a power of 2; otherwise
    approx_square raises ValueError.

    Args:
        tensor: Tensor whose static shape is read, typically a batch of images.
        axis (int): Axis whose length is split into a grid. Default: 0

    Returns:
        list: [h, w] as computed by approx_square.
    """
    dim = tensor.shape[axis]
    # .value unwraps the static dimension object (presumably a TF 1.x
    # Dimension) into its plain length.
    return approx_square(dim.value)
def img_grid_summary(name, tensor):
    """Create an image grid summary.

    Args:
        name (str): Name for image grid summary
        tensor: Image tensor with static shape
            [batch_size, img_height, img_width, channels].

    Returns:
        None
    """
    # Read the static shape once; unpacking also checks the tensor is rank 4.
    batch_size, img_h, img_w, num_channels = tensor.shape.as_list()
    if batch_size < 8 or not ispowerof2(batch_size):
        # approx_square only handles powers of 2, so small or other batch
        # sizes are laid out as a single row instead of crashing.
        grid_size = [1, batch_size]
    else:
        grid_size = grid_size_from(tensor)
    grid = tfgan.eval.image_grid(tensor, grid_size, [img_h, img_w],
                                 num_channels)
    tf.summary.image(name, grid, max_outputs=1)
def slice_from(axis, position, n_dims=3):
    """Build an index tuple that picks `position` along `axis`.

    Every other axis receives a full slice, e.g. slice_from(1, 4) gives
    (slice(None), 4, slice(None)), which indexes like arr[:, 4, :].

    Args:
        axis (int): Slicing axis (negative values count from the end).
        position (int): Position of slice along axis
        n_dims (int): Number of entries in the returned tuple. Default: 3

    Returns:
        tuple: Slice tuple for indexing into an array

    Examples:
        >>> slice_from(0, 2)
        (2, slice(None, None, None), slice(None, None, None))
        >>> slice_from(0, 2, n_dims=2)
        (2, slice(None, None, None))
    """
    index = [slice(None)] * n_dims
    index[axis] = position
    return tuple(index)
def scale(x, newmin=-1, newmax=1):
    """Linearly map the entries of x onto the range [newmin, newmax].

    The smallest entry of x maps to newmin and the largest to newmax.
    Lists and NumPy arrays are reduced with NumPy; any other input is
    reduced with tf.reduce_max / tf.reduce_min.

    Args:
        x (array): list, np.ndarray, or tensor to be scaled.
        newmin: Target value for the minimum of x. Default: -1
        newmax: Target value for the maximum of x. Default: 1

    Returns:
        array: Scaled values with the same shape as x.

    Examples:
        >>> scale([-2, 0, 2])
        array([-1.,  0.,  1.])
        >>> scale([0, 1, 4])
        array([-1. , -0.5,  1. ])
    """
    if isinstance(x, (list, np.ndarray)):
        hi, lo = np.max(x), np.min(x)
    else:
        hi, lo = tf.reduce_max(x), tf.reduce_min(x)
    old_range = hi - lo
    new_range = newmax - newmin
    # NOTE(review): if all entries of x are equal, old_range is 0 and the
    # division yields nan/inf — confirm callers never pass constant input.
    return newmin + (x - lo) / old_range * new_range