Clean features and create input files based on all possible targets.
parent
1e38d9bf1e
commit
001d400729
10
Snakefile
10
Snakefile
|
@ -416,7 +416,8 @@ for provider in config["ALL_CLEANING_INDIVIDUAL"]["PROVIDERS"].keys():
|
|||
for provider in config["ALL_CLEANING_OVERALL"]["PROVIDERS"].keys():
|
||||
if config["ALL_CLEANING_OVERALL"]["PROVIDERS"][provider]["COMPUTE"]:
|
||||
if provider == "STRAW":
|
||||
files_to_compute.extend(expand("data/processed/features/all_participants/all_sensor_features_cleaned_" + provider.lower() +"_py.csv"))
|
||||
for target in config["PARAMS_FOR_ANALYSIS"]["TARGET"]["ALL_LABELS"]:
|
||||
files_to_compute.extend(expand("data/processed/features/all_participants/all_sensor_features_cleaned_" + provider.lower() +"_py_(" + target + ").csv"))
|
||||
else:
|
||||
files_to_compute.extend(expand("data/processed/features/all_participants/all_sensor_features_cleaned_" + provider.lower() +"_R.csv"))
|
||||
|
||||
|
@ -430,11 +431,8 @@ if config["PARAMS_FOR_ANALYSIS"]["BASELINE"]["COMPUTE"]:
|
|||
# Targets (labels)
|
||||
if config["PARAMS_FOR_ANALYSIS"]["TARGET"]["COMPUTE"]:
|
||||
files_to_compute.extend(expand("data/processed/models/individual_model/{pid}/input.csv", pid=config["PIDS"]))
|
||||
files_to_compute.extend(expand("data/processed/models/population_model/input.csv"))
|
||||
# files_to_compute.extend(expand("data/processed/models/individual_model/{pid}/z_input.csv", pid=config["PIDS"]))
|
||||
# files_to_compute.extend(expand("data/processed/models/population_model/z_input.csv"))
|
||||
|
||||
#files_to_compute.extend(expand("data/processed/models/individual_model/{pid}/output_{cv_method}/baselines.csv", pid=config["PIDS"], cv_method=config["PARAMS_FOR_ANALYSIS"]["CV_METHODS"]))
|
||||
for target in config["PARAMS_FOR_ANALYSIS"]["TARGET"]["ALL_LABELS"]:
|
||||
files_to_compute.extend(expand("data/processed/models/population_model/input_" + target + ".csv"))
|
||||
|
||||
rule all:
|
||||
input:
|
||||
|
|
|
@ -729,3 +729,4 @@ PARAMS_FOR_ANALYSIS:
|
|||
TARGET:
|
||||
COMPUTE: True
|
||||
LABEL: PANAS_negative_affect_mean
|
||||
ALL_LABELS: [PANAS_positive_affect_mean, PANAS_negative_affect_mean, "JCQ_job_demand_mean", "JCQ_job_control_mean", "JCQ_supervisor_support_mean", "JCQ_coworker_support_mean"]
|
||||
|
|
|
@ -1008,8 +1008,9 @@ rule clean_sensor_features_for_all_participants:
|
|||
provider = lambda wildcards: config["ALL_CLEANING_OVERALL"]["PROVIDERS"][wildcards.provider_key.upper()],
|
||||
provider_key = "{provider_key}",
|
||||
script_extension = "{script_extension}",
|
||||
sensor_key = "all_cleaning_overall"
|
||||
sensor_key = "all_cleaning_overall",
|
||||
target = "{target}"
|
||||
output:
|
||||
"data/processed/features/all_participants/all_sensor_features_cleaned_{provider_key}_{script_extension}.csv"
|
||||
"data/processed/features/all_participants/all_sensor_features_cleaned_{provider_key}_{script_extension}_({target}).csv"
|
||||
script:
|
||||
"../src/features/entry.{params.script_extension}"
|
||||
|
|
|
@ -40,33 +40,13 @@ rule select_target:
|
|||
|
||||
rule merge_features_and_targets_for_population_model:
|
||||
input:
|
||||
cleaned_sensor_features = "data/processed/features/all_participants/all_sensor_features_cleaned_straw_py.csv",
|
||||
cleaned_sensor_features = "data/processed/features/all_participants/all_sensor_features_cleaned_straw_py_({target}).csv",
|
||||
demographic_features = expand("data/processed/features/{pid}/baseline_features.csv", pid=config["PIDS"]),
|
||||
params:
|
||||
target_variable=config["PARAMS_FOR_ANALYSIS"]["TARGET"]["LABEL"]
|
||||
target_variable="{target}"
|
||||
output:
|
||||
"data/processed/models/population_model/input.csv"
|
||||
"data/processed/models/population_model/input_{target}.csv"
|
||||
script:
|
||||
"../src/models/merge_features_and_targets_for_population_model.py"
|
||||
|
||||
# rule select_target:
|
||||
# input:
|
||||
# cleaned_sensor_features = "data/processed/features/{pid}/all_sensor_features_cleaned_straw_py.csv"
|
||||
# params:
|
||||
# target_variable = config["PARAMS_FOR_ANALYSIS"]["TARGET"]["LABEL"]
|
||||
# output:
|
||||
# "data/processed/models/individual_model/{pid}/input.csv"
|
||||
# script:
|
||||
# "../src/models/select_targets.py"
|
||||
|
||||
# rule merge_features_and_targets_for_population_model:
|
||||
# input:
|
||||
# cleaned_sensor_features = "data/processed/features/all_participants/all_sensor_features_cleaned_straw_py.csv",
|
||||
# demographic_features = expand("data/processed/features/{pid}/baseline_features.csv", pid=config["PIDS"]),
|
||||
# params:
|
||||
# target_variable=config["PARAMS_FOR_ANALYSIS"]["TARGET"]["LABEL"]
|
||||
# output:
|
||||
# "data/processed/models/population_model/input.csv"
|
||||
# script:
|
||||
# "../src/models/merge_features_and_targets_for_population_model.py"
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ import seaborn as sns
|
|||
sys.path.append('/rapids/')
|
||||
from src.features import empatica_data_yield as edy
|
||||
|
||||
def straw_cleaning(sensor_data_files, provider):
|
||||
def straw_cleaning(sensor_data_files, provider, target):
|
||||
|
||||
features = pd.read_csv(sensor_data_files["sensor_data"][0])
|
||||
|
||||
|
@ -25,7 +25,7 @@ def straw_cleaning(sensor_data_files, provider):
|
|||
|
||||
# (1) FILTER_OUT THE ROWS THAT DO NOT HAVE THE TARGET COLUMN AVAILABLE
|
||||
if config['PARAMS_FOR_ANALYSIS']['TARGET']['COMPUTE']:
|
||||
target = config['PARAMS_FOR_ANALYSIS']['TARGET']['LABEL'] # get target label from config
|
||||
# target = config['PARAMS_FOR_ANALYSIS']['TARGET']['LABEL'] # get target label from config
|
||||
features = features[features['phone_esm_straw_' + target].notna()].reset_index(drop=True)
|
||||
|
||||
graph_bf_af(features, "2target_rows_after")
|
||||
|
@ -170,9 +170,9 @@ def straw_cleaning(sensor_data_files, provider):
|
|||
upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(np.bool))
|
||||
to_drop = [column for column in upper.columns if any(upper[column] > drop_corr_features["CORR_THRESHOLD"])]
|
||||
|
||||
sns.heatmap(corr_matrix, cmap="YlGnBu")
|
||||
plt.savefig(f'correlation_matrix.png', bbox_inches='tight')
|
||||
plt.close()
|
||||
# sns.heatmap(corr_matrix, cmap="YlGnBu")
|
||||
# plt.savefig(f'correlation_matrix.png', bbox_inches='tight')
|
||||
# plt.close()
|
||||
|
||||
s = corr_matrix.unstack()
|
||||
so = s.sort_values(ascending=False)
|
||||
|
@ -194,7 +194,6 @@ def straw_cleaning(sensor_data_files, provider):
|
|||
if features.isna().any().any():
|
||||
raise ValueError("There are still some NaNs present in the dataframe. Please check for implementation errors.")
|
||||
|
||||
sys.exit()
|
||||
return features
|
||||
|
||||
def impute(df, method='zero'):
|
||||
|
|
|
@ -13,7 +13,10 @@ calc_windows = True if (provider.get("WINDOWS", False) and provider["WINDOWS"].g
|
|||
|
||||
if sensor_key == "all_cleaning_individual" or sensor_key == "all_cleaning_overall":
|
||||
# Data cleaning
|
||||
sensor_features = run_provider_cleaning_script(provider, provider_key, sensor_key, sensor_data_files)
|
||||
if "overall" in sensor_key:
|
||||
sensor_features = run_provider_cleaning_script(provider, provider_key, sensor_key, sensor_data_files, snakemake.params["target"])
|
||||
else:
|
||||
sensor_features = run_provider_cleaning_script(provider, provider_key, sensor_key, sensor_data_files)
|
||||
else:
|
||||
# Extract sensor features
|
||||
del sensor_data_files["time_segments_labels"]
|
||||
|
|
|
@ -160,12 +160,16 @@ def fetch_provider_features(provider, provider_key, sensor_key, sensor_data_file
|
|||
|
||||
return sensor_features
|
||||
|
||||
def run_provider_cleaning_script(provider, provider_key, sensor_key, sensor_data_files):
|
||||
def run_provider_cleaning_script(provider, provider_key, sensor_key, sensor_data_files, target=False):
|
||||
from importlib import import_module, util
|
||||
print("{} Processing {} {}".format(rapids_log_tag, sensor_key, provider_key))
|
||||
|
||||
cleaning_module = import_path(provider["SRC_SCRIPT"])
|
||||
cleaning_function = getattr(cleaning_module, provider_key.lower() + "_cleaning")
|
||||
sensor_features = cleaning_function(sensor_data_files, provider)
|
||||
|
||||
if target:
|
||||
sensor_features = cleaning_function(sensor_data_files, provider, target)
|
||||
else:
|
||||
sensor_features = cleaning_function(sensor_data_files, provider)
|
||||
|
||||
return sensor_features
|
Loading…
Reference in New Issue