"""Target-column helper extracted from ``src/models/select_targets.py``."""

import pandas as pd


def retain_target_column(df_input: pd.DataFrame, target_variable_name: str) -> pd.DataFrame:
    """Drop all ``phone_esm_straw*`` columns except the requested target.

    Columns coming from phone_esm are not features for our purposes, so every
    column whose name starts with ``"phone_esm_straw"`` is removed.  The one
    ESM column whose name contains ``target_variable_name`` is added back at
    the very end of the data frame under the name ``"target"``.

    The input data frame is not modified; a new data frame is returned.

    Args:
        df_input: Data frame holding sensor features and phone_esm columns.
        target_variable_name: Pattern identifying the target among the
            ``phone_esm_straw*`` columns.  It is matched with
            ``Index.str.contains``, i.e. as a substring / regular expression,
            so it must identify exactly one column.

    Returns:
        The non-ESM columns of ``df_input`` followed by a ``"target"`` column.

    Raises:
        ValueError: If no ESM column matches ``target_variable_name``, or if
            more than one does.

    Example:
        Usage in ``select_targets.py``::

            model_input = retain_target_column(cleaned_sensor_features, target_variable_name)
    """
    column_names = df_input.columns
    # Find all columns coming from phone_esm; all of them are dropped below.
    esm_names = column_names[column_names.str.startswith("phone_esm_straw")]
    target_matches = esm_names[esm_names.str.contains(target_variable_name)]

    if len(target_matches) == 0:
        # Build ONE message string: passing several arguments to ValueError
        # makes the error print as a tuple instead of a readable sentence.
        raise ValueError(
            f"The requested target ({target_variable_name}) cannot be found in the dataset. "
            "Please check the names of phone_esm_ columns in all_sensor_features_cleaned_rapids.csv"
        )
    if len(target_matches) > 1:
        # Several matching columns cannot be stored in a single "target"
        # column; fail with an explicit message rather than a pandas one.
        raise ValueError(
            f"The requested target ({target_variable_name}) matches more than one "
            f"phone_esm_ column: {list(target_matches)}. Please use a more specific name."
        )

    sensor_features_plus_target = df_input.drop(columns=esm_names)
    # Keep only one phone_esm column: append it at the very end as "target".
    sensor_features_plus_target["target"] = df_input[target_matches[0]]
    return sensor_features_plus_target