diff --git a/src/features/all_cleaning_individual/straw/main.py b/src/features/all_cleaning_individual/straw/main.py index 19ae71f5..5a2bca7d 100644 --- a/src/features/all_cleaning_individual/straw/main.py +++ b/src/features/all_cleaning_individual/straw/main.py @@ -100,14 +100,6 @@ def straw_cleaning(sensor_data_files, provider): sns.set(rc={"figure.figsize":(16, 8)}) sns.heatmap(features.isna(), cbar=False) plt.savefig(f'features_nans_bf_knn.png', bbox_inches='tight') - - # KNN IMPUTATION - features = impute(features, method="knn") - - sns.set(rc={"figure.figsize":(16, 8)}) - sns.heatmap(features.isna(), cbar=False) - plt.savefig(f'features_nans_af_knn.png', bbox_inches='tight') - ## STANDARDIZATION - should it happen before or after kNN imputation? # TODO: check if there are additional columns that need to be excluded from the standardization @@ -116,6 +108,15 @@ def straw_cleaning(sensor_data_files, provider): features.loc[:, ~features.columns.isin(excluded_columns)] = StandardScaler().fit_transform(features.loc[:, ~features.columns.isin(excluded_columns)]) + # KNN IMPUTATION + impute_cols = [col for col in features.columns if col not in ['local_segment', 'local_segment_label', 'local_segment_start_datetime', 'local_segment_end_datetime']] + features[impute_cols] = impute(features[impute_cols], method="knn") + + + sns.set(rc={"figure.figsize":(16, 8)}) + sns.heatmap(features.isna(), cbar=False) + plt.savefig(f'features_nans_af_knn.png', bbox_inches='tight') + # VERIFY IF THERE ARE ANY NANS LEFT IN THE DATAFRAME if features.isna.any().any(): raise ValueError