diff --git a/src/features/all_cleaning_overall/straw/main.py b/src/features/all_cleaning_overall/straw/main.py index b0fc74a2..5fba061e 100644 --- a/src/features/all_cleaning_overall/straw/main.py +++ b/src/features/all_cleaning_overall/straw/main.py @@ -96,18 +96,29 @@ def straw_cleaning(sensor_data_files, provider, target): impute_w_sn3 = [col for col in features.columns if "loglocationvariance" in col] features[impute_w_sn2] = features[impute_w_sn2].fillna(-1000000) # Special case of imputation - loglocation - # Impute selected phone features with 0 + impute ESM features with 0 + # Impute location features + impute_locations = [col for col in features \ + if col.startswith('phone_locations_doryab_') and + 'radiusgyration' not in col + ] + + # Impute selected phone, location, and esm features with 0 impute_zero = [col for col in features if \ col.startswith('phone_applications_foreground_rapids_') or + col.startswith('phone_activity_recognition_') or col.startswith('phone_battery_rapids_') or col.startswith('phone_bluetooth_rapids_') or col.startswith('phone_light_rapids_') or col.startswith('phone_calls_rapids_') or col.startswith('phone_messages_rapids_') or col.startswith('phone_screen_rapids_') or - col.startswith('phone_wifi_visible')] - - features[impute_zero+list(esm_cols.columns)] = features[impute_zero+list(esm_cols.columns)].fillna(0) + col.startswith('phone_bluetooth_doryab_') or + col.startswith('phone_wifi_visible') + ] + + features[impute_zero+impute_locations+list(esm_cols.columns)] = features[impute_zero+impute_locations+list(esm_cols.columns)].fillna(0) + + pd.set_option('display.max_rows', None) graph_bf_af(features, "4context_imp") @@ -138,15 +149,14 @@ def straw_cleaning(sensor_data_files, provider, target): if features.empty: return pd.DataFrame(columns=excluded_columns) - - # (7) STANDARDIZATION TODO: exclude nominal features from standardization - + # (7) STANDARDIZATION if provider["STANDARDIZATION"]: + nominal_cols = [col for col in features.columns if "mostcommonactivity" in col or "homelabel" in col] # Excluded nominal features # Expected warning within this code block with warnings.catch_warnings(): warnings.simplefilter("ignore", category=RuntimeWarning) - features.loc[:, ~features.columns.isin(excluded_columns + ["pid"])] = \ - features.loc[:, ~features.columns.isin(excluded_columns)].groupby('pid').transform(lambda x: StandardScaler().fit_transform(x.values[:,np.newaxis]).ravel()) + features.loc[:, ~features.columns.isin(excluded_columns + ["pid"] + nominal_cols)] = \ + features.loc[:, ~features.columns.isin(excluded_columns + nominal_cols)].groupby('pid').transform(lambda x: StandardScaler().fit_transform(x.values[:,np.newaxis]).ravel()) graph_bf_af(features, "8standardization")