2022-07-20 15:51:22 +02:00
import pandas as pd
import numpy as np
import math , sys
2022-09-01 12:33:36 +02:00
import yaml
2022-08-31 12:18:50 +02:00
from sklearn . impute import KNNImputer
2022-09-12 15:44:17 +02:00
from sklearn . preprocessing import StandardScaler
2022-09-01 12:33:36 +02:00
import matplotlib . pyplot as plt
import seaborn as sns
2022-07-20 15:51:22 +02:00
2022-09-27 11:54:15 +02:00
sys . path . append ( ' /rapids/ ' )
2022-09-26 17:54:00 +02:00
from src . features import empatica_data_yield as edy
2022-09-01 12:33:36 +02:00
def straw_cleaning(sensor_data_files, provider):
    """Clean the STRAW feature table read from the RAPIDS sensor data CSV.

    Steps: drop rows without the target label, drop rows with low phone/E4
    data yield, drop rows/columns with too many NaNs, impute contextually and
    then with KNN, optionally standardize, drop zero-variance and highly
    correlated columns. ``phone_esm_straw*`` (target) columns are always kept.

    Args:
        sensor_data_files: mapping whose ``"sensor_data"`` entry is a sequence;
            its first element is the path of the features CSV.
        provider: provider configuration mapping. Keys read here:
            PHONE_DATA_YIELD_FEATURE, PHONE_DATA_YIELD_RATIO_THRESHOLD,
            EMPATICA_DATA_YIELD_RATIO_THRESHOLD, ROWS_NAN_THRESHOLD,
            COLS_NAN_THRESHOLD, STANDARDIZATION, COLS_VAR_THRESHOLD and
            DROP_HIGHLY_CORRELATED_FEATURES (a mapping with COMPUTE,
            MIN_OVERLAP_FOR_CORR_THRESHOLD and CORR_THRESHOLD).

    Returns:
        The cleaned ``pandas.DataFrame``.

    Raises:
        KeyError: if a data-yield column needed by an enabled yield threshold
            is missing.
        ValueError: if NaN values remain after all cleaning steps.
    """
    features = pd.read_csv(sensor_data_files["sensor_data"][0])

    # NOTE(review): config.yaml is resolved relative to the current working directory.
    with open("config.yaml", "r") as stream:
        config = yaml.load(stream, Loader=yaml.FullLoader)

    excluded_columns = [
        "local_segment",
        "local_segment_label",
        "local_segment_start_datetime",
        "local_segment_end_datetime",
    ]

    # (1) FILTER OUT THE ROWS THAT DO NOT HAVE THE TARGET COLUMN AVAILABLE
    if config["PARAMS_FOR_ANALYSIS"]["TARGET"]["COMPUTE"]:
        target = config["PARAMS_FOR_ANALYSIS"]["TARGET"]["LABEL"]  # target label from config
        features = features[features["phone_esm_straw_" + target].notna()].reset_index(drop=True)

    # (2.1) QUALITY CHECK (DATA YIELD COLUMNS): delete rows where E4 or phone data is low quality.
    # Assumes the 4th "_"-separated token of PHONE_DATA_YIELD_FEATURE is the yield unit — confirm.
    phone_data_yield_unit = provider["PHONE_DATA_YIELD_FEATURE"].split("_")[3].lower()
    phone_data_yield_column = "phone_data_yield_rapids_ratiovalidyielded" + phone_data_yield_unit

    features = edy.calculate_empatica_data_yield(features)

    # Each yield column is only required when its threshold is enabled.
    if provider["PHONE_DATA_YIELD_RATIO_THRESHOLD"]:
        if phone_data_yield_column not in features.columns:
            raise KeyError(f"RAPIDS provider needs to clean the selected event features based on {phone_data_yield_column} column, please set config[PHONE_DATA_YIELD][PROVIDERS][RAPIDS][COMPUTE] to True and include 'ratiovalidyielded{phone_data_yield_unit}' in [FEATURES].")
        features = features[
            features[phone_data_yield_column] >= provider["PHONE_DATA_YIELD_RATIO_THRESHOLD"]
        ].reset_index(drop=True)

    if provider["EMPATICA_DATA_YIELD_RATIO_THRESHOLD"]:
        if "empatica_data_yield" not in features.columns:
            raise KeyError("RAPIDS provider needs to clean the selected event features based on empatica_data_yield column, but it is missing after calculate_empatica_data_yield.")
        features = features[
            features["empatica_data_yield"] >= provider["EMPATICA_DATA_YIELD_RATIO_THRESHOLD"]
        ].reset_index(drop=True)

    # (2.2) DOES A ROW CONSIST OF ENOUGH NON-NAN VALUES?
    min_count = math.ceil((1 - provider["ROWS_NAN_THRESHOLD"]) * features.shape[1])  # minimal non-NaN values per row
    # Reset the index so later whole-frame assignments (e.g. KNN output) align row-by-row.
    features = features.dropna(axis=0, thresh=min_count).reset_index(drop=True)

    # (3) REMOVE COLS IF THEIR NAN THRESHOLD IS PASSED
    # (all-NaN columns are dropped here; use <= instead of < to keep them)
    esm_cols = features.loc[:, features.columns.str.startswith("phone_esm_straw")]  # target (esm) columns
    features = features.loc[:, features.isna().sum() < provider["COLS_NAN_THRESHOLD"] * features.shape[0]].copy()
    # Preserve esm cols if deleted (has to come after drop cols operations)
    for esm in esm_cols:
        if esm not in features:
            features[esm] = esm_cols[esm]

    # (4) CONTEXTUAL IMPUTATION
    graph_bf_af(features, "contextual_imputation_before")
    # Impute selected phone "time of first/last event" features with a high number
    high_number_markers = (
        "timeoffirstuse",
        "timeoflastuse",
        "timefirstcall",
        "timelastcall",
        "timefirstmessages",
        "timelastmessages",
        "firstuseafter",
    )
    impute_w_hn = [col for col in features.columns if any(marker in col for marker in high_number_markers)]
    if impute_w_hn:
        features[impute_w_hn] = impute(features[impute_w_hn], method="high_number")
    # Impute phone locations with median
    impute_locations = [col for col in features.columns if "phone_locations_" in col]
    if impute_locations:
        features[impute_locations] = impute(features[impute_locations], method="median")
    # Impute remaining phone features with 0 (target esm columns are not features, leave them to KNN)
    impute_rest = [
        col for col in features.columns
        if "phone_" in col and not col.startswith("phone_esm_straw")
    ]
    if impute_rest:
        features[impute_rest] = impute(features[impute_rest], method="zero")
    graph_bf_af(features, "contextual_imputation_after")

    # (5) STANDARDIZATION
    if provider["STANDARDIZATION"]:
        to_scale = ~features.columns.isin(excluded_columns)
        features.loc[:, to_scale] = StandardScaler().fit_transform(features.loc[:, to_scale])

    # (6) IMPUTATION: IMPUTE DATA WITH KNN METHOD
    impute_cols = [col for col in features.columns if col not in excluded_columns]
    features[impute_cols] = impute(features[impute_cols], method="knn")

    # (7) REMOVE COLS WHERE VARIANCE IS 0
    esm_cols = features.loc[:, features.columns.str.startswith("phone_esm_straw")]
    if provider["COLS_VAR_THRESHOLD"]:
        # numeric_only skips the string local_segment* columns.
        stds = features.std(numeric_only=True)
        features = features.drop(columns=stds[stds == 0].index)

    # (8) DROP HIGHLY CORRELATED FEATURES
    drop_corr_features = provider["DROP_HIGHLY_CORRELATED_FEATURES"]
    # If only a small number of segments (rows) is present, do not execute the correlation check.
    if drop_corr_features["COMPUTE"] and features.shape[0] > 5:
        numerical = features.select_dtypes(include=np.number)
        # Remove columns where the NaN count threshold is passed
        valid_features = numerical.loc[
            :, numerical.isna().sum() < drop_corr_features["MIN_OVERLAP_FOR_CORR_THRESHOLD"] * numerical.shape[0]
        ]
        corr_matrix = valid_features.corr().abs()
        # Keep only the strict upper triangle so each pair is considered once.
        upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
        to_drop = [
            column for column in upper.columns
            if (upper[column] > drop_corr_features["CORR_THRESHOLD"]).any()
        ]
        features = features.drop(columns=to_drop)

    # Preserve esm cols if deleted by steps 7/8 (has to come after drop cols operations)
    for esm in esm_cols:
        if esm not in features:
            features[esm] = esm_cols[esm]

    # (9) VERIFY IF THERE ARE ANY NANS LEFT IN THE DATAFRAME
    nan_mask = features.isna().any()
    if nan_mask.any():
        raise ValueError(f"NaN values remain after cleaning in columns: {features.columns[nan_mask].tolist()}")

    return features
def impute(df, method='zero', *, high_number=1000000, n_neighbors=3):
    """Return a copy of *df* with missing values filled using *method*.

    Args:
        df: DataFrame to impute.
        method: one of ``'zero'``, ``'high_number'``, ``'mean'``,
            ``'median'`` or ``'knn'``.
        high_number: fill value used by the ``'high_number'`` method.
        n_neighbors: number of neighbours used by the ``'knn'`` method.

    Returns:
        A new DataFrame with the same columns and index as *df*.

    Raises:
        KeyError: if *method* is not one of the supported methods.
    """
    def k_nearest(frame):
        imputer = KNNImputer(n_neighbors=n_neighbors)
        # Keep the input index so assigning the result back aligns with the caller's rows.
        # NOTE(review): per scikit-learn docs KNNImputer drops all-NaN columns by default,
        # which would make this constructor fail on a column-count mismatch — verify inputs.
        return pd.DataFrame(imputer.fit_transform(frame), columns=frame.columns, index=frame.index)

    # Lazy dispatch: only the requested strategy is evaluated.
    strategies = {
        'zero': lambda: df.fillna(0),
        'high_number': lambda: df.fillna(high_number),
        'mean': lambda: df.fillna(df.mean()),
        'median': lambda: df.fillna(df.median()),
        'knn': lambda: k_nearest(df),
    }
    return strategies[method]()
2022-09-28 12:02:47 +02:00
def graph_bf_af(features, phase_name):
    """Save a heatmap of the NaN pattern of *features* as ``features_nans_<phase_name>.png``.

    Diagnostic helper: it also prints *features* to stdout, and the PNG is
    written relative to the current working directory.

    Args:
        features: DataFrame whose missing values are plotted.
        phase_name: label inserted into the output file name.
    """
    sns.set(rc={"figure.figsize": (16, 8)})
    print(features)
    # Draw on a dedicated figure and close it afterwards: without this, every call
    # draws on the implicit current figure, so plots overlay and figures accumulate.
    fig, ax = plt.subplots()
    sns.heatmap(features.isna(), cbar=False, ax=ax)
    fig.savefig(f'features_nans_{phase_name}.png', bbox_inches='tight')
    plt.close(fig)
2022-07-20 15:51:22 +02:00