Clean features and create input files based on all possible targets.

Primoz 2022-10-06 14:28:12 +00:00
parent 1e38d9bf1e
commit 001d400729
7 changed files with 26 additions and 40 deletions


@@ -416,7 +416,8 @@ for provider in config["ALL_CLEANING_INDIVIDUAL"]["PROVIDERS"].keys():
 for provider in config["ALL_CLEANING_OVERALL"]["PROVIDERS"].keys():
     if config["ALL_CLEANING_OVERALL"]["PROVIDERS"][provider]["COMPUTE"]:
         if provider == "STRAW":
-            files_to_compute.extend(expand("data/processed/features/all_participants/all_sensor_features_cleaned_" + provider.lower() +"_py.csv"))
+            for target in config["PARAMS_FOR_ANALYSIS"]["TARGET"]["ALL_LABELS"]:
+                files_to_compute.extend(expand("data/processed/features/all_participants/all_sensor_features_cleaned_" + provider.lower() +"_py_(" + target + ").csv"))
         else:
             files_to_compute.extend(expand("data/processed/features/all_participants/all_sensor_features_cleaned_" + provider.lower() +"_R.csv"))
@@ -430,11 +431,8 @@ if config["PARAMS_FOR_ANALYSIS"]["BASELINE"]["COMPUTE"]:
 # Targets (labels)
 if config["PARAMS_FOR_ANALYSIS"]["TARGET"]["COMPUTE"]:
     files_to_compute.extend(expand("data/processed/models/individual_model/{pid}/input.csv", pid=config["PIDS"]))
-    files_to_compute.extend(expand("data/processed/models/population_model/input.csv"))
-    # files_to_compute.extend(expand("data/processed/models/individual_model/{pid}/z_input.csv", pid=config["PIDS"]))
-    # files_to_compute.extend(expand("data/processed/models/population_model/z_input.csv"))
-    #files_to_compute.extend(expand("data/processed/models/individual_model/{pid}/output_{cv_method}/baselines.csv", pid=config["PIDS"], cv_method=config["PARAMS_FOR_ANALYSIS"]["CV_METHODS"]))
+    for target in config["PARAMS_FOR_ANALYSIS"]["TARGET"]["ALL_LABELS"]:
+        files_to_compute.extend(expand("data/processed/models/population_model/input_" + target + ".csv"))

 rule all:
     input:


@@ -729,3 +729,4 @@ PARAMS_FOR_ANALYSIS:
   TARGET:
     COMPUTE: True
     LABEL: PANAS_negative_affect_mean
+    ALL_LABELS: [PANAS_positive_affect_mean, PANAS_negative_affect_mean, "JCQ_job_demand_mean", "JCQ_job_control_mean", "JCQ_supervisor_support_mean", "JCQ_coworker_support_mean"]
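
A minimal sketch (not from the commit; assumes PyYAML, whereas Snakemake parses config.yaml itself) of how the new ALL_LABELS entry appears once the file is loaded into the nested config dict used by the Snakefile:

    import yaml

    with open("config.yaml") as f:
        config = yaml.safe_load(f)

    # the Snakefile loops shown above iterate over this list of six labels
    for target in config["PARAMS_FOR_ANALYSIS"]["TARGET"]["ALL_LABELS"]:
        print(target)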


@@ -1008,8 +1008,9 @@ rule clean_sensor_features_for_all_participants:
         provider = lambda wildcards: config["ALL_CLEANING_OVERALL"]["PROVIDERS"][wildcards.provider_key.upper()],
         provider_key = "{provider_key}",
         script_extension = "{script_extension}",
-        sensor_key = "all_cleaning_overall"
+        sensor_key = "all_cleaning_overall",
+        target = "{target}"
     output:
-        "data/processed/features/all_participants/all_sensor_features_cleaned_{provider_key}_{script_extension}.csv"
+        "data/processed/features/all_participants/all_sensor_features_cleaned_{provider_key}_{script_extension}_({target}).csv"
     script:
         "../src/features/entry.{params.script_extension}"
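
As a rough sketch (not part of the commit), Snakemake binds the new {target} wildcard from the requested output name, and params: target = "{target}" then hands that value to the entry script; the equivalent matching in plain Python would look roughly like:

    import re

    # hypothetical example of a requested output file
    requested = "data/processed/features/all_participants/all_sensor_features_cleaned_straw_py_(PANAS_negative_affect_mean).csv"
    pattern = (r"data/processed/features/all_participants/"
               r"all_sensor_features_cleaned_(?P<provider_key>.+)_(?P<script_extension>.+)_\((?P<target>.+)\)\.csv")
    m = re.match(pattern, requested)
    print(m.group("provider_key"), m.group("script_extension"), m.group("target"))
    # straw py PANAS_negative_affect_mean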


@@ -40,33 +40,13 @@ rule select_target:
 rule merge_features_and_targets_for_population_model:
     input:
-        cleaned_sensor_features = "data/processed/features/all_participants/all_sensor_features_cleaned_straw_py.csv",
+        cleaned_sensor_features = "data/processed/features/all_participants/all_sensor_features_cleaned_straw_py_({target}).csv",
         demographic_features = expand("data/processed/features/{pid}/baseline_features.csv", pid=config["PIDS"]),
     params:
-        target_variable=config["PARAMS_FOR_ANALYSIS"]["TARGET"]["LABEL"]
+        target_variable="{target}"
     output:
-        "data/processed/models/population_model/input.csv"
+        "data/processed/models/population_model/input_{target}.csv"
     script:
         "../src/models/merge_features_and_targets_for_population_model.py"
-
-# rule select_target:
-#     input:
-#         cleaned_sensor_features = "data/processed/features/{pid}/all_sensor_features_cleaned_straw_py.csv"
-#     params:
-#         target_variable = config["PARAMS_FOR_ANALYSIS"]["TARGET"]["LABEL"]
-#     output:
-#         "data/processed/models/individual_model/{pid}/input.csv"
-#     script:
-#         "../src/models/select_targets.py"
-
-# rule merge_features_and_targets_for_population_model:
-#     input:
-#         cleaned_sensor_features = "data/processed/features/all_participants/all_sensor_features_cleaned_straw_py.csv",
-#         demographic_features = expand("data/processed/features/{pid}/baseline_features.csv", pid=config["PIDS"]),
-#     params:
-#         target_variable=config["PARAMS_FOR_ANALYSIS"]["TARGET"]["LABEL"]
-#     output:
-#         "data/processed/models/population_model/input.csv"
-#     script:
-#         "../src/models/merge_features_and_targets_for_population_model.py"


@@ -10,7 +10,7 @@ import seaborn as sns
 sys.path.append('/rapids/')
 from src.features import empatica_data_yield as edy

-def straw_cleaning(sensor_data_files, provider):
+def straw_cleaning(sensor_data_files, provider, target):

     features = pd.read_csv(sensor_data_files["sensor_data"][0])

@@ -25,7 +25,7 @@ def straw_cleaning(sensor_data_files, provider):
     # (1) FILTER_OUT THE ROWS THAT DO NOT HAVE THE TARGET COLUMN AVAILABLE
     if config['PARAMS_FOR_ANALYSIS']['TARGET']['COMPUTE']:
-        target = config['PARAMS_FOR_ANALYSIS']['TARGET']['LABEL'] # get target label from config
+        # target = config['PARAMS_FOR_ANALYSIS']['TARGET']['LABEL'] # get target label from config
         features = features[features['phone_esm_straw_' + target].notna()].reset_index(drop=True)

     graph_bf_af(features, "2target_rows_after")

@@ -170,9 +170,9 @@ def straw_cleaning(sensor_data_files, provider):
         upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(np.bool))
         to_drop = [column for column in upper.columns if any(upper[column] > drop_corr_features["CORR_THRESHOLD"])]

-        sns.heatmap(corr_matrix, cmap="YlGnBu")
-        plt.savefig(f'correlation_matrix.png', bbox_inches='tight')
-        plt.close()
+        # sns.heatmap(corr_matrix, cmap="YlGnBu")
+        # plt.savefig(f'correlation_matrix.png', bbox_inches='tight')
+        # plt.close()

         s = corr_matrix.unstack()
         so = s.sort_values(ascending=False)

@@ -194,7 +194,6 @@ def straw_cleaning(sensor_data_files, provider):
     if features.isna().any().any():
         raise ValueError("There are still some NaNs present in the dataframe. Please check for implementation errors.")
-        sys.exit()

     return features

 def impute(df, method='zero'):
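
A short illustration (argument values made up) of what the signature change means for callers: the target now arrives per run as an argument instead of being re-read from config inside the function:

    # before: the function looked the label up in config itself
    # features = straw_cleaning(sensor_data_files, provider)
    # after: the caller supplies one of the ALL_LABELS values explicitly
    features = straw_cleaning(sensor_data_files, provider, "PANAS_negative_affect_mean")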


@@ -13,7 +13,10 @@ calc_windows = True if (provider.get("WINDOWS", False) and provider["WINDOWS"].g
 if sensor_key == "all_cleaning_individual" or sensor_key == "all_cleaning_overall":
     # Data cleaning
-    sensor_features = run_provider_cleaning_script(provider, provider_key, sensor_key, sensor_data_files)
+    if "overall" in sensor_key:
+        sensor_features = run_provider_cleaning_script(provider, provider_key, sensor_key, sensor_data_files, snakemake.params["target"])
+    else:
+        sensor_features = run_provider_cleaning_script(provider, provider_key, sensor_key, sensor_data_files)
 else:
     # Extract sensor features
     del sensor_data_files["time_segments_labels"]


@@ -160,12 +160,16 @@ def fetch_provider_features(provider, provider_key, sensor_key, sensor_data_file
     return sensor_features

-def run_provider_cleaning_script(provider, provider_key, sensor_key, sensor_data_files):
+def run_provider_cleaning_script(provider, provider_key, sensor_key, sensor_data_files, target=False):
     from importlib import import_module, util
     print("{} Processing {} {}".format(rapids_log_tag, sensor_key, provider_key))

     cleaning_module = import_path(provider["SRC_SCRIPT"])
     cleaning_function = getattr(cleaning_module, provider_key.lower() + "_cleaning")
-    sensor_features = cleaning_function(sensor_data_files, provider)
+    if target:
+        sensor_features = cleaning_function(sensor_data_files, provider, target)
+    else:
+        sensor_features = cleaning_function(sensor_data_files, provider)

     return sensor_features
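
A hedged sketch (argument values illustrative only) of the two call paths entry.py now takes into this helper: the overall-cleaning rule passes along the target it received as a param, while individual cleaning keeps the old four-argument call and target stays at its False default:

    # overall cleaning: target comes from the rule's params (the {target} wildcard)
    sensor_features = run_provider_cleaning_script(provider, "straw", "all_cleaning_overall", sensor_data_files, "PANAS_negative_affect_mean")

    # individual cleaning: unchanged call, target=False keeps the old behaviour
    sensor_features = run_provider_cleaning_script(provider, "straw", "all_cleaning_individual", sensor_data_files)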