diff --git a/Snakefile b/Snakefile index c504f623..06d323a6 100644 --- a/Snakefile +++ b/Snakefile @@ -14,7 +14,9 @@ rule all: days_before_surgery = config["METRICS_FOR_ANALYSIS"]["DAYS_BEFORE_SURGERY"], days_after_discharge= config["METRICS_FOR_ANALYSIS"]["DAYS_AFTER_DISCHARGE"], days_in_hospital= config["METRICS_FOR_ANALYSIS"]["DAYS_IN_HOSPITAL"]), - + expand("data/processed/{pid}/targets_{summarised}.csv", + pid = config["PIDS"], + summarised = config["METRICS_FOR_ANALYSIS"]["SUMMARISED"]), # Feature extraction expand("data/raw/{pid}/{sensor}_raw.csv", pid=config["PIDS"], sensor=config["SENSORS"]), expand("data/raw/{pid}/{sensor}_raw.csv", pid=config["PIDS"], sensor=config["FITBIT_TABLE"]), diff --git a/rules/mystudy.snakefile b/rules/mystudy.snakefile index 6d226335..466b87ad 100644 --- a/rules/mystudy.snakefile +++ b/rules/mystudy.snakefile @@ -9,3 +9,13 @@ rule days_to_analyse: "data/interim/{pid}/days_to_analyse_{days_before_surgery}_{days_in_hospital}_{days_after_discharge}.csv" script: "../src/models/select_days_to_analyse.py" + +rule get_targets: + input: + participant_info = "data/raw/{pid}/" + config["METRICS_FOR_ANALYSIS"]["GROUNDTRUTH_TABLE"] + "_raw.csv" + params: + summarised = "{summarised}" + output: + "data/processed/{pid}/targets_{summarised}.csv" + script: + "../src/models/get_targets.py" diff --git a/src/models/get_targets.py b/src/models/get_targets.py new file mode 100644 index 00000000..9a1629b7 --- /dev/null +++ b/src/models/get_targets.py @@ -0,0 +1,16 @@ +import pandas as pd + +participant_info = pd.read_csv(snakemake.input["participant_info"]) +summarised = snakemake.params["summarised"] +pid = snakemake.input["participant_info"].split("/")[2] + +targets = pd.DataFrame({"pid": [pid], "target": [None]}) +if summarised == "summarised": + if not participant_info.empty: + cesds = participant_info.loc[0, ["preop_cesd_total", "inpatient_cesd_total", "postop_cesd_total", "3month_cesd_total"]] + # targets: 1 => 50% (ceiling) or more of available CESD scores were 16 or higher; 0 => otherwise + threshold_num = (cesds.count() + 1) // 2 + threshold_cesd = 16 + target = 1 if cesds.apply(lambda x : 1 if x >= threshold_cesd else 0).sum() >= threshold_num else 0 + targets.loc[0, "target"] = target +targets.to_csv(snakemake.output[0], index=False)