import os
import tensorflow as tf
import numpy as np
import pandas as pd
from itertools import product
# Shorthand for TF-GAN from tf.contrib (tf.contrib only exists in TensorFlow
# 1.x); img_grid_summary below uses tfgan.eval.image_grid.
tfgan = tf.contrib.gan
def bytes_feature(value):
    """Convert value to a TF bytes feature.

    Used during serialization of features.

    Args:
        value: bytes, list/tuple of bytes, or np.ndarray. An array is stored
            as a single bytes entry holding its raw buffer in C order; its
            dtype and shape are not stored, so whoever parses the feature
            must know them (presumably recorded elsewhere -- confirm).

    Returns:
        tf.train.Feature wrapping a BytesList.

    Raises:
        ValueError: if value has another type, or a list/tuple contains a
            non-bytes element.
    """
    if isinstance(value, (list, tuple)):
        if not all(isinstance(x, bytes) for x in value):
            raise ValueError("bytes_feature expects list of bytes")
    elif isinstance(value, bytes):
        value = [value]
    elif isinstance(value, np.ndarray):
        # tobytes() (default order="C") yields exactly the bytes of the former
        # flatten().tostring(); ndarray.tostring() is deprecated in NumPy.
        value = [value.tobytes()]
    else:
        raise ValueError("bytes_feature got bad value")
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
def int64_feature(value):
    """Convert value to a TF int64 feature.

    Used during serialization of features.

    Args:
        value: int, numpy integer scalar, list/tuple of ints (Python or
            numpy integers), or np.ndarray whose dtype name contains "int".

    Returns:
        tf.train.Feature wrapping an Int64List.

    Raises:
        ValueError: if value (or any of its elements) is not an integer.
    """
    if isinstance(value, np.ndarray):
        if "int" not in value.dtype.name:
            raise ValueError("int64_feature requires dtype int")
        # tolist() on an integer array yields plain Python ints.
        value = value.ravel().tolist()
    elif isinstance(value, (list, tuple)):
        # Accept numpy integers inside sequences too, consistent with the
        # scalar branch below.
        if not all(isinstance(x, (int, np.integer)) for x in value):
            raise ValueError("int64_feature expects all elements as int")
        value = [int(x) for x in value]
    elif isinstance(value, (int, np.integer)):
        value = [int(value)]
    else:
        raise ValueError("int64_feature got bad value")
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def float32_feature(value):
    """Convert value to a TF float32 feature.

    Used during serialization of features.

    Args:
        value: float, numpy floating scalar, list/tuple of floats (Python or
            numpy floating), or np.ndarray (each element converted with
            float()).

    Returns:
        tf.train.Feature wrapping a FloatList.

    Raises:
        ValueError: if value has another type, or a list/tuple contains a
            non-float element (Python ints are rejected as well).
    """
    if isinstance(value, np.ndarray):
        value = [float(x) for x in value.ravel()]
    elif isinstance(value, (list, tuple)):
        # np.float32 is not a subclass of float, so check np.floating too,
        # consistent with the scalar branch below.
        if not all(isinstance(x, (float, np.floating)) for x in value):
            raise ValueError("float32_feature expects all elements as float")
        value = [float(x) for x in value]
    elif isinstance(value, (float, np.floating)):
        value = [float(value)]
    else:
        raise ValueError("float32_feature got bad value")
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
class PrintOnce(object):
    """Print a formatted message at most once per instance.

    Args:
        message: Format string; the arguments passed to ``print`` are
            substituted into it via ``str.format``.
    """

    def __init__(self, message):
        super(PrintOnce, self).__init__()
        self.message = message
        # Set to True once the message has been printed.
        self._printed = False

    def print(self, *args, **kwargs):
        """Format and print the message on the first call only.

        Subsequent calls on the same instance are no-ops. If formatting
        raises, nothing is printed and a later call will try again.
        """
        if self._printed:
            return
        text = self.message.format(*args, **kwargs)
        print(text)
        self._printed = True
def mkdir_and_join(record_dir, record_pattern):
    """Ensure record_dir exists and join it with record_pattern.

    Args:
        record_dir (str): Path to record directory. Missing parent
            directories are created as well. An empty string means the
            current directory; nothing is created in that case.
        record_pattern (str): Record file pattern

    Returns:
        (str): Full record file pattern, including directory
    """
    if record_dir and not os.path.exists(record_dir):
        # makedirs creates intermediate directories; exist_ok tolerates a
        # concurrent creator winning the race after the exists() check.
        os.makedirs(record_dir, exist_ok=True)
    return os.path.join(record_dir, record_pattern)
def to_int_size(split_to_size, n_samples):
    """Express split sizes as absolute sample counts.

    Args:
        split_to_size (dict): Mapping of split name to size. Values are
            presumably either absolute counts (int) or fractions of
            n_samples -- confirm against callers.
        n_samples (int): Total number of samples.

    Returns:
        dict: split_to_size itself (same object) when every value is an
        int; otherwise a new dict in which every value, including any ints,
        is replaced by int(value * n_samples), i.e. truncated.
    """
    has_non_int = any(not isinstance(size, int) for size in split_to_size.values())
    if not has_non_int:
        return split_to_size
    return {name: int(share * n_samples) for name, share in split_to_size.items()}
def flip_idx(series):
    """Return the index of the first flip from "1" to "0".

    The series is expected to be binary (0/1 or bool) and to start with a
    block of "1". The index of the last element of that block is returned.
    If the series starts with "0", has fewer than two elements, or never
    flips from "1" to "0", -1 is returned.

    Args:
        series: list, tuple, np.array, or pd.Series (positional order is
            used; a pd.Series index is ignored).

    Examples:
        >>> flip_idx([1,1,1,0,0])
        2
        >>> flip_idx([1,1,0,1,1])
        1
        >>> flip_idx([1,1,1,1,1])  # No flip!
        -1
        >>> flip_idx([0,0,1,1,1])  # No flip "1" to "0"!
        -1

    Returns:
        int: Index of the last "1" before flipping to "0",
        -1 if no flip from "1" to "0".
    """
    values = np.asarray(series)
    if len(values) < 2 or values[0] == 0:
        return -1
    # Compare rather than subtract: subtraction wraps around for unsigned
    # dtypes (uint8: 0 - 1 == 255) and is not defined for bool arrays.
    flips = np.flatnonzero((values[:-1] == 1) & (values[1:] == 0))
    return int(flips[0]) if flips.size else -1
def is_one_after(series, idx):
    """Check if all elements are "1" after idx.

    Args:
        series: list, tuple, np.array, or pd.Series (positional order is
            used; a pd.Series index is ignored).
        idx (int): Last index before check (not included). Expected to be
            >= -1; -1 checks the whole series.

    Returns:
        bool: True, if all elements after idx equal 1 (vacuously True when
        there are no elements after idx).

    Examples:
        >>> is_one_after([0,0,1,1,1], 1)
        True
        >>> is_one_after([0,0,1,1,0], 1)
        False
    """
    tail = np.asarray(series)[idx + 1:]
    # Elementwise equality instead of comparing the sum to the length,
    # which accepted non-binary tails such as [2, 0].
    return bool(np.all(tail == 1))
def hasflip(df, sort_key="age", from_key="mci", to_key="ad", within=None):
    """Detect a binary flip between two columns.

    Rows are sorted by sort_key. A flip occurs when from_key starts with a
    block of 1s that is followed by a 0 (see flip_idx), and to_key is 1 in
    every row from that first 0 onward. The to_key values before the flip
    are not checked.

    Args:
        df: Dataframe
        sort_key (str): Column used to order the rows.
        from_key (str): Column expected to flip from 1 to 0.
        to_key (str): Column expected to be 1 from the flip onward.
        within: If not None, the flip only counts when the sort_key value
            of the row where from_key first becomes 0 exceeds the smallest
            sort_key value by at most within.

    Returns:
        bool: True, if flip occurs.

    Examples:
        >>> w = pd.DataFrame({"age": [1,2,3,4], "mci": [1,1,0,0], "ad": [0,0,1,1]})
        >>> hasflip(w)
        True
        >>> x = pd.DataFrame({"age": [1,2,3,4], "mci": [1,1,1,1], "ad": [0,0,0,0]})
        >>> hasflip(x)
        False
        >>> y = pd.DataFrame({"age": [1,2,3,4], "mci": [1,1,0,1], "ad": [0,0,1,0]})
        >>> hasflip(y)
        False
        >>> z = pd.DataFrame({"age": [1,2,3,4], "mci": [1,1,1,0], "ad": [0,0,0,1]})
        >>> hasflip(z, within=2)
        False
    """
    df = df.sort_values(by=[sort_key])
    fidx = flip_idx(df[from_key])
    if fidx == -1:
        # No 1 -> 0 flip. Checked first so the iloc lookups below never run
        # on an empty frame.
        return False
    if within is not None:
        elapsed = df[sort_key].iloc[fidx + 1] - df[sort_key].iloc[0]
        if elapsed > within:
            return False
    return bool(is_one_after(df[to_key], fidx))
def spans(df, key, mode, span):
    """Check if a numeric column spans a certain range.

    The column range is compared as ``max <op> min + span``.

    Args:
        df (Dataframe): Dataframe containing samples
        key (str): Numeric column in df
        mode (str): one of "at_least" (>=), "at_most" (<=), "more_than" (>),
            "less_than" (<).
        span (int, float): Numeric range that is tested.

    Returns:
        bool: True, if numeric span of key complies with mode.

    Raises:
        ValueError: if mode is not one of the supported modes.

    Examples:
        >>> x = pd.DataFrame({"age": [1,2,3,4]})
        >>> spans(x, "age", "at_least", 2)
        True
        >>> spans(x, "age", "at_most", 2)
        False
        >>> spans(x, "age", "more_than", 2)
        True
        >>> spans(x, "age", "less_than", 2)
        False
    """
    MODES = ["at_least", "at_most", "more_than", "less_than"]
    if mode not in MODES:
        raise ValueError("mode not in {}".format(MODES))
    # min()/max() do not depend on row order, so no sorting is needed.
    column = df[key]
    highest = column.max()
    limit = column.min() + span
    if mode == "at_least":
        result = highest >= limit
    elif mode == "at_most":
        result = highest <= limit
    elif mode == "more_than":
        result = highest > limit
    else:  # "less_than"
        result = highest < limit
    # Return a plain Python bool rather than numpy.bool_.
    return bool(result)
def filter_groups(df, class_to_filter, group_key=None):
    """Separate dataframe into classes defined by filters.

    Add a "class" column to df, indicating the class defined by the filter
    functions. df itself is not modified.

    Args:
        df: Dataframe containing samples
        class_to_filter (dict): Mapping of class names to
            filter functions. Filter functions take a Dataframe
            as input and return bool, which indicates if sample
            belongs to class.
        group_key (str): Group df by group_key, if not None. If None, the
            whole df is treated as a single group.

    Returns:
        Dataframe: Rows of df that pass each class filter, concatenated in
        the order of class_to_filter and extended by a column "class". A
        row passing several filters appears once per class. With an empty
        class_to_filter, an empty frame with a "class" column is returned.

    Examples:
        >>> df = pd.DataFrame({"subject": [1,1,2,2,3,3], "gender": [1,1,0,0,0,1]})
        >>> class_to_filter = {"male": lambda x: x["gender"].all(),
        ...                    "female": lambda x: not x["gender"].any(),
        ...                    "trans": lambda x: not x["gender"].all() and x["gender"].any()}
        >>> filter_groups(df, class_to_filter, "subject")
           subject  gender   class
        0        1       1    male
        1        1       1    male
        2        2       0  female
        3        2       0  female
        4        3       0   trans
        5        3       1   trans
    """
    grouped = df.groupby(group_key) if group_key is not None else None
    cls_list = []
    for cls_key, cls_filter in class_to_filter.items():
        if grouped is not None:
            cls_df = grouped.filter(cls_filter)
        elif cls_filter(df):
            cls_df = df
        else:
            cls_df = df.iloc[0:0]
        # assign() returns a new frame, so no chained-assignment warning can
        # occur and the global pandas options stay untouched.
        cls_list.append(cls_df.assign(**{"class": cls_key}))
    if not cls_list:
        # pd.concat raises on an empty list.
        return df.iloc[0:0].assign(**{"class": []})
    return pd.concat(cls_list)
def build_pairs(df, cond=lambda x, y: True, no_duplicate=None):
    """Cartesian product of a dataframe with itself.

    Produce all ordered combinations (row_i, row_j) of df x df, including
    i == j, for which cond is true.

    Args:
        df: Dataframe
        cond: Function of two arguments returning a truth value. Each
            argument is a one-row Dataframe built from a row of df with an
            extra helper column "dummy" (empty string). The same row frame
            is passed for every pair that row takes part in, so cond should
            not modify its arguments.
        no_duplicate (str or list of str): Columns that should appear only
            once in the result: the "_0" copy is renamed to the plain column
            name and the "_1" copy is dropped.

    Returns:
        DataFrame: Containing all pairs of df x df where cond was true, with
        columns suffixed "_0" (first row) and "_1" (second row); None if
        cond was never true.

    Examples:
        >>> df = pd.DataFrame({"age": [1,2,3], "gender": ["m", "f", "m"]})
        >>> cond = lambda x,y: y["age"].values > x["age"].values
        >>> build_pairs(df, cond)
          age_0 gender_0 age_1 gender_1
        0     1        m     2        f
        1     1        m     3        m
        2     2        f     3        m
    """
    if no_duplicate is None:
        no_duplicate = []
    elif isinstance(no_duplicate, str):
        # A single column name; iterating the str would split it into chars.
        no_duplicate = [no_duplicate]
    # Build each one-row frame once instead of once per pair.
    rows = []
    for _, row in df.iterrows():
        frame = row.to_frame().T
        frame["dummy"] = ""  # constant join key, turns merge into a cross join
        rows.append(frame)
    pairs = [
        left.merge(right, on="dummy", how="outer", suffixes=("_0", "_1"))
        for left, right in product(rows, repeat=2)
        if cond(left, right)
    ]
    if not pairs:
        return None
    pairs = pd.concat(pairs, ignore_index=True).drop(columns=["dummy"])
    if no_duplicate:
        # NOTE(review): index=str converts the index labels to strings; kept
        # as-is because existing callers may rely on it.
        pairs = pairs.rename(index=str,
                             columns={key + "_0": key for key in no_duplicate})
        pairs = pairs.drop(columns=[key + "_1" for key in no_duplicate])
    return pairs
def aggregate(df, class_to_filter, group_key=None, builder=None):
    """Aggregate class samples from dataframe.

    Group rows of df by group_key, and use class filters to annotate each
    row (see filter_groups). If builder is provided, group the annotated
    rows again by class and group_key (by class only when group_key is
    None), and apply builder to each group.

    Args:
        df: Dataframe
        class_to_filter (dict): Mapping of class names to
            filter functions. Filter functions take a Dataframe
            as input and return bool, which indicates if sample
            belongs to class.
        group_key (str): Which column to use for grouping.
        builder: Function that accepts Dataframe, and returns
            a processed Dataframe.

    Returns:
        Dataframe: Aggregated Dataframe, where each row
        represents one sample. The column "class" indicates
        the class of the sample.
    """
    class_df = filter_groups(df, class_to_filter, group_key)
    if builder is not None:
        keys = ["class"] if group_key is None else ["class", group_key]
        class_df = class_df.groupby(keys, group_keys=False).apply(builder)
        # If builder's output carries group_key both as an index level and
        # as a column, drop the index level to avoid ambiguity.
        if (group_key is not None and
                group_key in class_df.index.names and
                group_key in class_df.columns):
            class_df = class_df.reset_index(level=group_key, drop=True)
    return class_df
def ispowerof2(x):
    """Check if x is a power of 2.

    Args:
        x (int): Integer to test.

    Returns:
        bool: True if x == 2**k for some integer k >= 0. Zero and negative
        numbers are not powers of 2.

    Examples:
        >>> ispowerof2(64)
        True
        >>> ispowerof2(63)
        False
        >>> ispowerof2(0)
        False
    """
    # A positive power of 2 has exactly one bit set, so clearing its lowest
    # set bit (x & (x - 1)) leaves 0. The x > 0 guard is needed because
    # 0 & -1 == 0 would otherwise report 0 as a power of 2.
    return x > 0 and (x & (x - 1)) == 0
def approx_square(x):
    """Calculate an even grid height and width.

    Args:
        x (int): Must be a positive power of 2.

    Returns:
        list: [h, w] of Python ints, both powers of 2, such that x == h * w,
        h <= w, and h and w are as close to sqrt(x) as possible.

    Raises:
        ValueError: if x is not a positive power of 2.

    Examples:
        >>> approx_square(32)
        [4, 8]
        >>> approx_square(64)
        [8, 8]
        >>> approx_square(128)
        [8, 16]
    """
    if x < 1 or not ispowerof2(x):
        raise ValueError("Expected power of 2, got {}".format(x))
    # Exact integer arithmetic instead of float log2: x == 2**exponent.
    exponent = int(x).bit_length() - 1
    h = 2 ** (exponent // 2)
    w = 2 ** (exponent - exponent // 2)
    return [h, w]
def grid_size_from(tensor, axis=0):
    """Calculate an approximately square grid size from an axis length.

    Assumes that the length of axis is a power of 2 (see approx_square).

    Args:
        tensor: Typically a batch of images. Anything with a ``shape``
            works: TF1 shapes hold Dimension objects (``.value``), while
            TF2 / numpy shapes hold plain ints.
        axis (int): Which axis to use for approximate grid size. Default: 0

    Returns:
        list: Approximate grid size [h, w]
    """
    dim = tensor.shape[axis]
    # Same unwrapping as tf.compat.dimension_value: Dimension -> its value,
    # anything else unchanged.
    length = getattr(dim, "value", dim)
    return approx_square(length)
def img_grid_summary(name, tensor):
    """Add an image summary that tiles a batch of images into one grid.

    Batches with fewer than 8 images are laid out as a single row.
    Larger batches use grid_size_from, which requires the batch size to be
    a power of 2.

    Args:
        name (str): Name for image grid summary
        tensor: Image tensor with shape [batch_size, img_height, img_width,
            channels]; all four dimensions must be statically known.

    Returns: None
    """
    shape = tensor.shape
    batch_size = shape[0].value
    if batch_size >= 8:
        grid_size = grid_size_from(tensor)
    else:
        grid_size = [1, batch_size]
    image_hw = [shape[1].value, shape[2].value]
    channels = shape[3].value
    grid = tfgan.eval.image_grid(tensor, grid_size, image_hw, channels)
    tf.summary.image(name, grid, max_outputs=1)
def slice_from(axis, position, n_dims=3):
    """Build an index tuple that selects one position along one axis.

    Args:
        axis (int): Slicing axis
        position (int): Position of slice along axis
        n_dims (int): Number of dimensions of the array to index. Default: 3

    Returns:
        tuple: Full slices on every axis except ``axis``, which holds
        ``position``; usable directly as ``array[slice_from(...)]``.

    Examples:
        >>> slice_from(0, 2)
        (2, slice(None, None, None), slice(None, None, None))
        >>> slice_from(0, 2, n_dims=2)
        (2, slice(None, None, None))
    """
    index = [slice(None)] * n_dims
    index[axis] = position
    return tuple(index)
def scale(x, newmin=-1, newmax=1):
    """Linearly rescale x so its minimum maps to newmin and maximum to newmax.

    Args:
        x: list, np.ndarray, or a TensorFlow-compatible value. Lists and
            arrays are handled with numpy; anything else with
            tf.reduce_max / tf.reduce_min.
        newmin: Value the minimum of x is mapped to. Default: -1
        newmax: Value the maximum of x is mapped to. Default: 1

    Returns:
        array: Scaled values with the same shape as x (np.ndarray for list
        or ndarray input). If all entries of x are equal, the division by
        (max - min) is a division by zero.

    Examples:
        >>> scale([-2, 0, 2])
        array([-1.,  0.,  1.])
        >>> scale([0, 1, 4])
        array([-1. , -0.5,  1. ])
    """
    if isinstance(x, list):
        # Lists need converting: list - number raises TypeError.
        x = np.asarray(x)
    if isinstance(x, np.ndarray):
        oldmax = np.max(x)
        oldmin = np.min(x)
    else:
        oldmax = tf.reduce_max(x)
        oldmin = tf.reduce_min(x)
    return newmin + (x - oldmin) / (oldmax - oldmin) * (newmax - newmin)