Clean features and create input files based on all possible targets.

Primoz 2022-10-06 14:28:12 +00:00
parent 1e38d9bf1e
commit 001d400729
7 changed files with 26 additions and 40 deletions

View File

@@ -416,7 +416,8 @@ for provider in config["ALL_CLEANING_INDIVIDUAL"]["PROVIDERS"].keys():
 for provider in config["ALL_CLEANING_OVERALL"]["PROVIDERS"].keys():
     if config["ALL_CLEANING_OVERALL"]["PROVIDERS"][provider]["COMPUTE"]:
         if provider == "STRAW":
-            files_to_compute.extend(expand("data/processed/features/all_participants/all_sensor_features_cleaned_" + provider.lower() +"_py.csv"))
+            for target in config["PARAMS_FOR_ANALYSIS"]["TARGET"]["ALL_LABELS"]:
+                files_to_compute.extend(expand("data/processed/features/all_participants/all_sensor_features_cleaned_" + provider.lower() +"_py_(" + target + ").csv"))
         else:
             files_to_compute.extend(expand("data/processed/features/all_participants/all_sensor_features_cleaned_" + provider.lower() +"_R.csv"))
@@ -430,11 +431,8 @@ if config["PARAMS_FOR_ANALYSIS"]["BASELINE"]["COMPUTE"]:
 # Targets (labels)
 if config["PARAMS_FOR_ANALYSIS"]["TARGET"]["COMPUTE"]:
     files_to_compute.extend(expand("data/processed/models/individual_model/{pid}/input.csv", pid=config["PIDS"]))
-    files_to_compute.extend(expand("data/processed/models/population_model/input.csv"))
-    # files_to_compute.extend(expand("data/processed/models/individual_model/{pid}/z_input.csv", pid=config["PIDS"]))
-    # files_to_compute.extend(expand("data/processed/models/population_model/z_input.csv"))
-    #files_to_compute.extend(expand("data/processed/models/individual_model/{pid}/output_{cv_method}/baselines.csv", pid=config["PIDS"], cv_method=config["PARAMS_FOR_ANALYSIS"]["CV_METHODS"]))
+    for target in config["PARAMS_FOR_ANALYSIS"]["TARGET"]["ALL_LABELS"]:
+        files_to_compute.extend(expand("data/processed/models/population_model/input_" + target + ".csv"))
 
 rule all:
     input:
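For orientation, here is a minimal sketch of what the two new loops contribute to files_to_compute, assuming only the first two labels from the ALL_LABELS list added in config.yaml below (the paths and the parenthesised filename suffix are taken from the diff above):

    # Sketch only: a trimmed stand-in for the real config dictionary.
    config = {"PARAMS_FOR_ANALYSIS": {"TARGET": {"ALL_LABELS": [
        "PANAS_positive_affect_mean", "PANAS_negative_affect_mean"]}}}

    files_to_compute = []
    for target in config["PARAMS_FOR_ANALYSIS"]["TARGET"]["ALL_LABELS"]:
        # One cleaned-features file and one population-model input per label.
        files_to_compute.append("data/processed/features/all_participants/"
                                "all_sensor_features_cleaned_straw_py_(" + target + ").csv")
        files_to_compute.append("data/processed/models/population_model/input_" + target + ".csv")

    for f in files_to_compute:
        print(f)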

View File

@@ -729,3 +729,4 @@ PARAMS_FOR_ANALYSIS:
   TARGET:
     COMPUTE: True
     LABEL: PANAS_negative_affect_mean
+    ALL_LABELS: [PANAS_positive_affect_mean, PANAS_negative_affect_mean, "JCQ_job_demand_mean", "JCQ_job_control_mean", "JCQ_supervisor_support_mean", "JCQ_coworker_support_mean"]

View File

@@ -1008,8 +1008,9 @@ rule clean_sensor_features_for_all_participants:
         provider = lambda wildcards: config["ALL_CLEANING_OVERALL"]["PROVIDERS"][wildcards.provider_key.upper()],
         provider_key = "{provider_key}",
         script_extension = "{script_extension}",
-        sensor_key = "all_cleaning_overall"
+        sensor_key = "all_cleaning_overall",
+        target = "{target}"
     output:
-        "data/processed/features/all_participants/all_sensor_features_cleaned_{provider_key}_{script_extension}.csv"
+        "data/processed/features/all_participants/all_sensor_features_cleaned_{provider_key}_{script_extension}_({target}).csv"
     script:
         "../src/features/entry.{params.script_extension}"

View File

@@ -40,33 +40,13 @@ rule select_target:
 rule merge_features_and_targets_for_population_model:
     input:
-        cleaned_sensor_features = "data/processed/features/all_participants/all_sensor_features_cleaned_straw_py.csv",
+        cleaned_sensor_features = "data/processed/features/all_participants/all_sensor_features_cleaned_straw_py_({target}).csv",
         demographic_features = expand("data/processed/features/{pid}/baseline_features.csv", pid=config["PIDS"]),
     params:
-        target_variable=config["PARAMS_FOR_ANALYSIS"]["TARGET"]["LABEL"]
+        target_variable="{target}"
     output:
-        "data/processed/models/population_model/input.csv"
+        "data/processed/models/population_model/input_{target}.csv"
     script:
         "../src/models/merge_features_and_targets_for_population_model.py"
-
-# rule select_target:
-#     input:
-#         cleaned_sensor_features = "data/processed/features/{pid}/all_sensor_features_cleaned_straw_py.csv"
-#     params:
-#         target_variable = config["PARAMS_FOR_ANALYSIS"]["TARGET"]["LABEL"]
-#     output:
-#         "data/processed/models/individual_model/{pid}/input.csv"
-#     script:
-#         "../src/models/select_targets.py"
-
-# rule merge_features_and_targets_for_population_model:
-#     input:
-#         cleaned_sensor_features = "data/processed/features/all_participants/all_sensor_features_cleaned_straw_py.csv",
-#         demographic_features = expand("data/processed/features/{pid}/baseline_features.csv", pid=config["PIDS"]),
-#     params:
-#         target_variable=config["PARAMS_FOR_ANALYSIS"]["TARGET"]["LABEL"]
-#     output:
-#         "data/processed/models/population_model/input.csv"
-#     script:
-#         "../src/models/merge_features_and_targets_for_population_model.py"

View File

@@ -10,7 +10,7 @@ import seaborn as sns
 sys.path.append('/rapids/')
 from src.features import empatica_data_yield as edy
 
-def straw_cleaning(sensor_data_files, provider):
+def straw_cleaning(sensor_data_files, provider, target):
 
     features = pd.read_csv(sensor_data_files["sensor_data"][0])
@@ -25,7 +25,7 @@ def straw_cleaning(sensor_data_files, provider):
     # (1) FILTER_OUT THE ROWS THAT DO NOT HAVE THE TARGET COLUMN AVAILABLE
     if config['PARAMS_FOR_ANALYSIS']['TARGET']['COMPUTE']:
-        target = config['PARAMS_FOR_ANALYSIS']['TARGET']['LABEL'] # get target label from config
+        # target = config['PARAMS_FOR_ANALYSIS']['TARGET']['LABEL'] # get target label from config
         features = features[features['phone_esm_straw_' + target].notna()].reset_index(drop=True)
 
     graph_bf_af(features, "2target_rows_after")
@@ -170,9 +170,9 @@ def straw_cleaning(sensor_data_files, provider):
         upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(np.bool))
         to_drop = [column for column in upper.columns if any(upper[column] > drop_corr_features["CORR_THRESHOLD"])]
 
-        sns.heatmap(corr_matrix, cmap="YlGnBu")
-        plt.savefig(f'correlation_matrix.png', bbox_inches='tight')
-        plt.close()
+        # sns.heatmap(corr_matrix, cmap="YlGnBu")
+        # plt.savefig(f'correlation_matrix.png', bbox_inches='tight')
+        # plt.close()
 
         s = corr_matrix.unstack()
         so = s.sort_values(ascending=False)
@@ -194,7 +194,6 @@ def straw_cleaning(sensor_data_files, provider):
     if features.isna().any().any():
         raise ValueError("There are still some NaNs present in the dataframe. Please check for implementation errors.")
-        sys.exit()
 
     return features
 
 def impute(df, method='zero'):
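straw_cleaning now receives the target as an argument instead of reading it from the config, and uses it to drop rows without a label. A self-contained pandas sketch of that filtering idiom (the frame, the second column name, and all values are made up for illustration; the phone_esm_straw_ prefix is the one used in the diff):

    import pandas as pd

    target = "PANAS_negative_affect_mean"
    features = pd.DataFrame({
        "phone_esm_straw_" + target: [2.0, None, 3.5],
        "some_sensor_feature": [0.1, 0.2, 0.3],
    })

    # Same idiom as in straw_cleaning: keep only rows where the target is present.
    features = features[features["phone_esm_straw_" + target].notna()].reset_index(drop=True)
    print(len(features))  # -> 2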

View File

@@ -13,7 +13,10 @@ calc_windows = True if (provider.get("WINDOWS", False) and provider["WINDOWS"].g
 if sensor_key == "all_cleaning_individual" or sensor_key == "all_cleaning_overall":
     # Data cleaning
-    sensor_features = run_provider_cleaning_script(provider, provider_key, sensor_key, sensor_data_files)
+    if "overall" in sensor_key:
+        sensor_features = run_provider_cleaning_script(provider, provider_key, sensor_key, sensor_data_files, snakemake.params["target"])
+    else:
+        sensor_features = run_provider_cleaning_script(provider, provider_key, sensor_key, sensor_data_files)
 else:
     # Extract sensor features
     del sensor_data_files["time_segments_labels"]
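entry.py runs under Snakemake's script: directive, which injects a snakemake object, so the new branch can read the target param declared on the rule above. A tiny mock of that access pattern (the SimpleNamespace and dict stand in for the injected object; label value is an example):

    from types import SimpleNamespace

    # Stand-in for the object Snakemake injects into scripts run via `script:`.
    snakemake = SimpleNamespace(params={"target": "PANAS_negative_affect_mean"})

    sensor_key = "all_cleaning_overall"
    if "overall" in sensor_key:
        # Mirrors the diff: only the overall-cleaning rule defines a target param.
        print("cleaning with target:", snakemake.params["target"])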

View File

@@ -160,12 +160,16 @@ def fetch_provider_features(provider, provider_key, sensor_key, sensor_data_file
     return sensor_features
 
-def run_provider_cleaning_script(provider, provider_key, sensor_key, sensor_data_files):
+def run_provider_cleaning_script(provider, provider_key, sensor_key, sensor_data_files, target=False):
     from importlib import import_module, util
     print("{} Processing {} {}".format(rapids_log_tag, sensor_key, provider_key))
 
     cleaning_module = import_path(provider["SRC_SCRIPT"])
     cleaning_function = getattr(cleaning_module, provider_key.lower() + "_cleaning")
-    sensor_features = cleaning_function(sensor_data_files, provider)
+    if target:
+        sensor_features = cleaning_function(sensor_data_files, provider, target)
+    else:
+        sensor_features = cleaning_function(sensor_data_files, provider)
 
     return sensor_features
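The target parameter defaults to False so the individual-cleaning call site, which passes no target, keeps its old behaviour. A minimal runnable sketch of the same dispatch pattern, with simplified signatures and a stub standing in for the function normally resolved via getattr:

    def straw_cleaning(sensor_data_files, provider, target=None):
        # Stub for the dynamically resolved provider cleaning function.
        return "cleaned (target={})".format(target)

    def run_provider_cleaning_script(sensor_data_files, provider, target=False):
        # Mirrors the diff: only forward target when one was supplied.
        if target:
            return straw_cleaning(sensor_data_files, provider, target)
        return straw_cleaning(sensor_data_files, provider)

    print(run_provider_cleaning_script({}, {}))                                # cleaned (target=None)
    print(run_provider_cleaning_script({}, {}, "PANAS_negative_affect_mean"))  # cleaned (target=PANAS_negative_affect_mean)

Note that because the guard tests truthiness rather than target is not False, an empty-string target would silently take the no-target path.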