diff --git a/src/features/all_cleaning_individual/straw/main.py b/src/features/all_cleaning_individual/straw/main.py index 0b0259f7..5557a2dc 100644 --- a/src/features/all_cleaning_individual/straw/main.py +++ b/src/features/all_cleaning_individual/straw/main.py @@ -17,12 +17,13 @@ def straw_cleaning(sensor_data_files, provider): with open('config.yaml', 'r') as stream: config = yaml.load(stream, Loader=yaml.FullLoader) + excluded_columns = ['local_segment', 'local_segment_label', 'local_segment_start_datetime', 'local_segment_end_datetime'] + # (1) FILTER_OUT THE ROWS THAT DO NOT HAVE THE TARGET COLUMN AVAILABLE if config['PARAMS_FOR_ANALYSIS']['TARGET']['COMPUTE']: target = config['PARAMS_FOR_ANALYSIS']['TARGET']['LABEL'] # get target label from config - features = features[features['phone_esm_straw_' + target].notna()].reset_index() + features = features[features['phone_esm_straw_' + target].notna()].reset_index(drop=True) - # TODO: reorder the cleaning steps so it makes sense for the analysis # TODO: add conditions that differentiates cleaning steps for standardized and nonstandardized features, for this # the snakemake rules will also have to come with additional parameter (in rules/features.smk) @@ -64,10 +65,10 @@ def straw_cleaning(sensor_data_files, provider): if provider["DATA_YIELD_RATIO_THRESHOLD"]: features = features[features[data_yield_column] >= provider["DATA_YIELD_RATIO_THRESHOLD"]] - # (3) REMOVE COLS IF THEIR NAN THRESHOLD IS PASSED (should be <= if even all NaN columns must be preserved) + # (3) REMOVE COLS IF THEIR NAN THRESHOLD IS PASSED (should be <= if even all NaN columns must be preserved - this solution now drops columns with all NaN rows) features = features.loc[:, features.isna().sum() < provider["COLS_NAN_THRESHOLD"] * features.shape[0]] - # (4) REMOVE COLS WHERE VARIANCE IS 0 TODO: preveri za local_segment stolpce + # (4) REMOVE COLS WHERE VARIANCE IS 0 if provider["COLS_VAR_THRESHOLD"]: features.drop(features.std()[features.std() == 
0].index.values, axis=1, inplace=True) @@ -91,31 +92,35 @@ def straw_cleaning(sensor_data_files, provider): features.drop(to_drop, axis=1, inplace=True) - # Remove rows if threshold of NaN values is passed + # (6) Remove rows if threshold of NaN values is passed min_count = math.ceil((1 - provider["ROWS_NAN_THRESHOLD"]) * features.shape[1]) # minimal not nan values in row features.dropna(axis=0, thresh=min_count, inplace=True) + sns.set(rc={"figure.figsize":(16, 8)}) sns.heatmap(features.isna(), cbar=False) plt.savefig(f'features_nans_bf_knn.png', bbox_inches='tight') - ## STANDARDIZATION - should it happen before or after kNN imputation? - # TODO: check if there are additional columns that need to be excluded from the standardization - excluded_columns = ['local_segment', 'local_segment_label', 'local_segment_start_datetime', 'local_segment_end_datetime'] + ## (7) STANDARDIZATION if provider["STANDARDIZATION"]: features.loc[:, ~features.columns.isin(excluded_columns)] = StandardScaler().fit_transform(features.loc[:, ~features.columns.isin(excluded_columns)]) - # KNN IMPUTATION + # (8) KNN IMPUTATION impute_cols = [col for col in features.columns if col not in excluded_columns] features[impute_cols] = impute(features[impute_cols], method="knn") + # (9) STANDARDIZATION AGAIN + + if provider["STANDARDIZATION"]: + features.loc[:, ~features.columns.isin(excluded_columns)] = StandardScaler().fit_transform(features.loc[:, ~features.columns.isin(excluded_columns)]) + sns.set(rc={"figure.figsize":(16, 8)}) sns.heatmap(features.isna(), cbar=False) plt.savefig(f'features_nans_af_knn.png', bbox_inches='tight') - # VERIFY IF THERE ARE ANY NANS LEFT IN THE DATAFRAME + # (10) VERIFY IF THERE ARE ANY NANS LEFT IN THE DATAFRAME if features.isna().any().any(): raise ValueError diff --git a/tests/scripts/missing_vals.py b/tests/scripts/missing_vals.py index acbae0bb..41cf4709 100644 --- a/tests/scripts/missing_vals.py +++ b/tests/scripts/missing_vals.py @@ -3,8 +3,8 @@ import seaborn
as sns import matplotlib.pyplot as plt -participant = "p031" -all_sensors = ["eda", "bvp", "ibi", "temp", "acc"] +participant = "p01" +all_sensors = ["eda", "ibi", "temp", "acc"] for sensor in all_sensors: