diff --git a/config.yaml b/config.yaml index fb717f5f..61344d29 100644 --- a/config.yaml +++ b/config.yaml @@ -710,7 +710,8 @@ ALL_CLEANING_OVERALL: COMPUTE: True MIN_OVERLAP_FOR_CORR_THRESHOLD: 0.5 CORR_THRESHOLD: 0.95 - STANDARDIZATION: False + STANDARDIZATION: True + TARGET_STANDARDIZATION: False SRC_SCRIPT: src/features/all_cleaning_overall/straw/main.py diff --git a/src/features/all_cleaning_overall/straw/main.py b/src/features/all_cleaning_overall/straw/main.py index 1362358d..197c285d 100644 --- a/src/features/all_cleaning_overall/straw/main.py +++ b/src/features/all_cleaning_overall/straw/main.py @@ -169,8 +169,12 @@ def straw_cleaning(sensor_data_files, provider, target): # Expected warning within this code block with warnings.catch_warnings(): warnings.simplefilter("ignore", category=RuntimeWarning) - features.loc[:, ~features.columns.isin(excluded_columns + ["pid"] + nominal_cols)] = \ - features.loc[:, ~features.columns.isin(excluded_columns + nominal_cols)].groupby('pid').transform(lambda x: StandardScaler().fit_transform(x.values[:,np.newaxis]).ravel()) + if provider["TARGET_STANDARDIZATION"]: + features.loc[:, ~features.columns.isin(excluded_columns + ["pid"] + nominal_cols)] = \ + features.loc[:, ~features.columns.isin(excluded_columns + nominal_cols)].groupby('pid').transform(lambda x: StandardScaler().fit_transform(x.values[:,np.newaxis]).ravel()) + else: + features.loc[:, ~features.columns.isin(excluded_columns + ["pid"] + nominal_cols + ['phone_esm_straw_' + target])] = \ + features.loc[:, ~features.columns.isin(excluded_columns + nominal_cols + ['phone_esm_straw_' + target])].groupby('pid').transform(lambda x: StandardScaler().fit_transform(x.values[:,np.newaxis]).ravel()) graph_bf_af(features, "8standardization")