import pandas as pd
from sacred import Experiment
from datasets.aggregate import meta_csv_aggregator_ingred
from utils import spans, hasflip, build_pairs, aggregate
# Sacred experiment for the brain meta-data CSV aggregation; the meta-CSV
# aggregator ingredient is attached so its configuration is included too.
brain_csv_aggregation = Experiment(
    "brain_csv_aggregation",
    ingredients=[meta_csv_aggregator_ingred],
)
@brain_csv_aggregation.config
def config():
    # Config scope of the aggregation experiment. Every variable assigned here
    # becomes a config entry and is injected by name into the captured
    # `aggregate_protos` (meta_csv_file, proto_csv_file, class_to_filter,
    # builder, patient_id_key), so these names are part of the interface.
    meta_csv_file = "datasets/brain/test_data/meta_data_1200.csv"  # input CSV
    D = 3  # conversion window
    patient_id_key = "patient_id"  # group key handed to `aggregate`
    builder = None  # per-group builder for `aggregate`; None for "HC_AD"
    protos = "HC_AD"  # prototype set to build: "pairs", "HC_AD" or "spMCI"
    pair_span = 2  # target age difference between the two samples of a pair
    pair_slack = 0.25  # tolerated deviation from pair_span
    pair_low = pair_span - pair_slack
    pair_high = pair_span + pair_slack

    if protos == "pairs":
        # Only t0 < t1 pairs: with the default span/slack pair_low > 0, so the
        # second sample (y) must be strictly older than the first (x).
        proto_csv_file = "datasets/brain/test_data/pair_protos.csv"

        def class_to_filter_fn(D):
            # D is a parameter so each lambda captures the value passed in.
            # NOTE(review): `spans(..., "at_most", D)` presumably requires the
            # patient's ages to span at most D -- confirm against utils.spans.
            return {
                "HC": lambda x: x["HC"].all(),
                "AD": lambda x: x["AD"].all(),
                "MCI": lambda x: (
                    spans(x, "age", "at_most", D) and x["MCI"].all()
                ),
            }

        class_to_filter = class_to_filter_fn(D)

        def builder_fn(pair_low, pair_high, patient_id_key):
            # Keep (x, y) pairs whose age difference lies strictly inside
            # (pair_low, pair_high). The chained comparison on `.values`
            # arrays only evaluates cleanly for single-element arrays.
            # NOTE(review): presumably build_pairs passes one-row frames as
            # x and y -- confirm.
            return lambda df: build_pairs(
                df,
                cond=lambda x, y: (
                    pair_low < y["age"].values - x["age"].values < pair_high
                ),
                no_duplicate=["class", patient_id_key],
            )

        builder = builder_fn(pair_low, pair_high, patient_id_key)
    elif protos == "HC_AD":
        proto_csv_file = "datasets/brain/test_data/HC_AD_protos.csv"
        # Each class keeps groups whose rows all carry the matching flag.
        class_to_filter = {
            "HC": lambda x: x["HC"].all(),
            "AD": lambda x: x["AD"].all(),
        }
    elif protos == "spMCI":
        proto_csv_file = "datasets/brain/test_data/spMCI_protos.csv"

        def class_to_filter_fn(D):
            # NOTE(review): presumably sMCI = stays MCI over more than D years
            # and pMCI = MCI->AD flip within D years -- confirm against
            # utils.spans / utils.hasflip.
            return {
                "sMCI": lambda x: (
                    spans(x, "age", "more_than", D) and x["MCI"].all()
                ),
                "pMCI": lambda x: (
                    spans(x, "age", "more_than", D)
                    and hasflip(x, sort_key="age", from_key="MCI",
                                to_key="AD", within=D)
                ),
            }

        class_to_filter = class_to_filter_fn(D)
        # Keep the row with the smallest age (first one on ties). idxmin()
        # returns an index *label*, so label-based .loc is required
        # (.ix was removed in pandas 1.0).
        builder = lambda df: df.loc[df["age"].idxmin()]
    else:
        # Without a known prototype set, proto_csv_file and class_to_filter
        # would be left undefined; fail early with an explicit message.
        raise ValueError(
            "Unknown protos {!r}; expected 'pairs', 'HC_AD' or 'spMCI'".format(
                protos)
        )
@brain_csv_aggregation.automain
def aggregate_protos(
    meta_csv_file,
    proto_csv_file,
    class_to_filter,
    builder,
    patient_id_key,
    _log,
):
    """Aggregate the meta-data CSV into prototypes and write them to CSV.

    Every argument is filled in from the experiment configuration.

    Args:
        meta_csv_file: Path of the input CSV, read with ``pd.read_csv``.
        proto_csv_file: Output path; written with ``index=False``.
        class_to_filter: Mapping of class label to group predicate,
            forwarded to ``aggregate``.
        builder: Optional per-group builder forwarded to ``aggregate``.
        patient_id_key: Column used as ``aggregate``'s group key and for
            counting distinct patients in the result.
        _log: Logger supplied by Sacred.
    """
    df = pd.read_csv(meta_csv_file)
    protos = aggregate(
        df,
        class_to_filter,
        group_key=patient_id_key,
        builder=builder,
    )
    protos.to_csv(proto_csv_file, index=False)
    # Report how many prototype rows were produced and how many distinct
    # patients they come from.
    _log.info("n_samples = %d", protos.shape[0])
    _log.info("n_unique_patients = %d",
              protos.groupby(patient_id_key).ngroups)