diff --git a/Snakefile b/Snakefile
index 1e9a1fe8..44e9b5bf 100644
--- a/Snakefile
+++ b/Snakefile
@@ -416,7 +416,8 @@ for provider in config["ALL_CLEANING_INDIVIDUAL"]["PROVIDERS"].keys():
 for provider in config["ALL_CLEANING_OVERALL"]["PROVIDERS"].keys():
     if config["ALL_CLEANING_OVERALL"]["PROVIDERS"][provider]["COMPUTE"]:
         if provider == "STRAW":
-            files_to_compute.extend(expand("data/processed/features/all_participants/all_sensor_features_cleaned_" + provider.lower() +"_py.csv"))
+            for target in config["PARAMS_FOR_ANALYSIS"]["TARGET"]["ALL_LABELS"]:
+                files_to_compute.extend(expand("data/processed/features/all_participants/all_sensor_features_cleaned_" + provider.lower() +"_py_(" + target + ").csv"))
         else:
             files_to_compute.extend(expand("data/processed/features/all_participants/all_sensor_features_cleaned_" + provider.lower() +"_R.csv"))
@@ -430,11 +431,8 @@ if config["PARAMS_FOR_ANALYSIS"]["BASELINE"]["COMPUTE"]:
 # Targets (labels)
 if config["PARAMS_FOR_ANALYSIS"]["TARGET"]["COMPUTE"]:
     files_to_compute.extend(expand("data/processed/models/individual_model/{pid}/input.csv", pid=config["PIDS"]))
-    files_to_compute.extend(expand("data/processed/models/population_model/input.csv"))
-    # files_to_compute.extend(expand("data/processed/models/individual_model/{pid}/z_input.csv", pid=config["PIDS"]))
-    # files_to_compute.extend(expand("data/processed/models/population_model/z_input.csv"))
-
-#files_to_compute.extend(expand("data/processed/models/individual_model/{pid}/output_{cv_method}/baselines.csv", pid=config["PIDS"], cv_method=config["PARAMS_FOR_ANALYSIS"]["CV_METHODS"]))
+    for target in config["PARAMS_FOR_ANALYSIS"]["TARGET"]["ALL_LABELS"]:
+        files_to_compute.extend(expand("data/processed/models/population_model/input_" + target + ".csv"))

 rule all:
     input:
diff --git a/config.yaml b/config.yaml
index e50a5565..1b6fe038 100644
--- a/config.yaml
+++ b/config.yaml
@@ -729,3 +729,4 @@ PARAMS_FOR_ANALYSIS:
   TARGET:
     COMPUTE: True
     LABEL: PANAS_negative_affect_mean
+    ALL_LABELS: [PANAS_positive_affect_mean, PANAS_negative_affect_mean, "JCQ_job_demand_mean", "JCQ_job_control_mean", "JCQ_supervisor_support_mean", "JCQ_coworker_support_mean"]
diff --git a/rules/features.smk b/rules/features.smk
index 6aa2c150..2638a8f3 100644
--- a/rules/features.smk
+++ b/rules/features.smk
@@ -1008,8 +1008,9 @@ rule clean_sensor_features_for_all_participants:
         provider = lambda wildcards: config["ALL_CLEANING_OVERALL"]["PROVIDERS"][wildcards.provider_key.upper()],
         provider_key = "{provider_key}",
         script_extension = "{script_extension}",
-        sensor_key = "all_cleaning_overall"
+        sensor_key = "all_cleaning_overall",
+        target = "{target}"
     output:
-        "data/processed/features/all_participants/all_sensor_features_cleaned_{provider_key}_{script_extension}.csv"
+        "data/processed/features/all_participants/all_sensor_features_cleaned_{provider_key}_{script_extension}_({target}).csv"
     script:
         "../src/features/entry.{params.script_extension}"
diff --git a/rules/models.smk b/rules/models.smk
index cc9b406d..2875297b 100644
--- a/rules/models.smk
+++ b/rules/models.smk
@@ -40,33 +40,13 @@ rule select_target:

 rule merge_features_and_targets_for_population_model:
     input:
-        cleaned_sensor_features = "data/processed/features/all_participants/all_sensor_features_cleaned_straw_py.csv",
+        cleaned_sensor_features = "data/processed/features/all_participants/all_sensor_features_cleaned_straw_py_({target}).csv",
         demographic_features = expand("data/processed/features/{pid}/baseline_features.csv", pid=config["PIDS"]),
     params:
-        target_variable=config["PARAMS_FOR_ANALYSIS"]["TARGET"]["LABEL"]
+        target_variable="{target}"
     output:
-        "data/processed/models/population_model/input.csv"
+        "data/processed/models/population_model/input_{target}.csv"
     script:
         "../src/models/merge_features_and_targets_for_population_model.py"
-
-# rule select_target:
-#     input:
-#         cleaned_sensor_features = "data/processed/features/{pid}/all_sensor_features_cleaned_straw_py.csv"
-#     params:
-#         target_variable = config["PARAMS_FOR_ANALYSIS"]["TARGET"]["LABEL"]
-#     output:
-#         "data/processed/models/individual_model/{pid}/input.csv"
-#     script:
-#         "../src/models/select_targets.py"
-
-# rule merge_features_and_targets_for_population_model:
-#     input:
-#         cleaned_sensor_features = "data/processed/features/all_participants/all_sensor_features_cleaned_straw_py.csv",
-#         demographic_features = expand("data/processed/features/{pid}/baseline_features.csv", pid=config["PIDS"]),
-#     params:
-#         target_variable=config["PARAMS_FOR_ANALYSIS"]["TARGET"]["LABEL"]
-#     output:
-#         "data/processed/models/population_model/input.csv"
-#     script:
-#         "../src/models/merge_features_and_targets_for_population_model.py"
diff --git a/src/features/all_cleaning_overall/straw/main.py b/src/features/all_cleaning_overall/straw/main.py
index 9324424d..71151608 100644
--- a/src/features/all_cleaning_overall/straw/main.py
+++ b/src/features/all_cleaning_overall/straw/main.py
@@ -10,7 +10,7 @@ import seaborn as sns
 sys.path.append('/rapids/')
 from src.features import empatica_data_yield as edy

-def straw_cleaning(sensor_data_files, provider):
+def straw_cleaning(sensor_data_files, provider, target):

     features = pd.read_csv(sensor_data_files["sensor_data"][0])

@@ -25,7 +25,7 @@ def straw_cleaning(sensor_data_files, provider):

     # (1) FILTER_OUT THE ROWS THAT DO NOT HAVE THE TARGET COLUMN AVAILABLE
     if config['PARAMS_FOR_ANALYSIS']['TARGET']['COMPUTE']:
-        target = config['PARAMS_FOR_ANALYSIS']['TARGET']['LABEL'] # get target label from config
+        # target = config['PARAMS_FOR_ANALYSIS']['TARGET']['LABEL'] # get target label from config
         features = features[features['phone_esm_straw_' + target].notna()].reset_index(drop=True)

     graph_bf_af(features, "2target_rows_after")
@@ -170,9 +170,9 @@ def straw_cleaning(sensor_data_files, provider):
         upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(np.bool))
         to_drop = [column for column in upper.columns if any(upper[column] > drop_corr_features["CORR_THRESHOLD"])]

-        sns.heatmap(corr_matrix, cmap="YlGnBu")
-        plt.savefig(f'correlation_matrix.png', bbox_inches='tight')
-        plt.close()
+        # sns.heatmap(corr_matrix, cmap="YlGnBu")
+        # plt.savefig(f'correlation_matrix.png', bbox_inches='tight')
+        # plt.close()

         s = corr_matrix.unstack()
         so = s.sort_values(ascending=False)
@@ -194,7 +194,6 @@ def straw_cleaning(sensor_data_files, provider):

     if features.isna().any().any():
         raise ValueError("There are still some NaNs present in the dataframe. Please check for implementation errors.")
-        sys.exit()

     return features

 def impute(df, method='zero'):
diff --git a/src/features/entry.py b/src/features/entry.py
index 288ba168..2b995fc7 100644
--- a/src/features/entry.py
+++ b/src/features/entry.py
@@ -13,7 +13,10 @@ calc_windows = True if (provider.get("WINDOWS", False) and provider["WINDOWS"].g

 if sensor_key == "all_cleaning_individual" or sensor_key == "all_cleaning_overall": # Data cleaning
-    sensor_features = run_provider_cleaning_script(provider, provider_key, sensor_key, sensor_data_files)
+    if "overall" in sensor_key:
+        sensor_features = run_provider_cleaning_script(provider, provider_key, sensor_key, sensor_data_files, snakemake.params["target"])
+    else:
+        sensor_features = run_provider_cleaning_script(provider, provider_key, sensor_key, sensor_data_files)
 else: # Extract sensor features
     del sensor_data_files["time_segments_labels"]
diff --git a/src/features/utils/utils.py b/src/features/utils/utils.py
index 7303ac86..8a4d2130 100644
--- a/src/features/utils/utils.py
+++ b/src/features/utils/utils.py
@@ -160,12 +160,16 @@ def fetch_provider_features(provider, provider_key, sensor_key, sensor_data_file

     return sensor_features

-def run_provider_cleaning_script(provider, provider_key, sensor_key, sensor_data_files):
+def run_provider_cleaning_script(provider, provider_key, sensor_key, sensor_data_files, target=False):
     from importlib import import_module, util
     print("{} Processing {} {}".format(rapids_log_tag, sensor_key, provider_key))
     cleaning_module = import_path(provider["SRC_SCRIPT"])
     cleaning_function = getattr(cleaning_module, provider_key.lower() + "_cleaning")
-    sensor_features = cleaning_function(sensor_data_files, provider)
+
+    if target:
+        sensor_features = cleaning_function(sensor_data_files, provider, target)
+    else:
+        sensor_features = cleaning_function(sensor_data_files, provider)

     return sensor_features
\ No newline at end of file