From be0324fd01d70c58a9eefd84ccb23d06a42ab57c Mon Sep 17 00:00:00 2001
From: Primoz
Date: Mon, 28 Nov 2022 12:44:25 +0000
Subject: [PATCH] Fix some bugs and set categorical columns as categories
 dtypes.

---
 src/features/all_cleaning_overall/straw/main.py | 12 +++++++++++-
 src/models/helper.py                            |  4 ++--
 src/models/select_targets.py                    |  2 +-
 3 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/src/features/all_cleaning_overall/straw/main.py b/src/features/all_cleaning_overall/straw/main.py
index fb7b9344..1362358d 100644
--- a/src/features/all_cleaning_overall/straw/main.py
+++ b/src/features/all_cleaning_overall/straw/main.py
@@ -108,7 +108,7 @@ def straw_cleaning(sensor_data_files, provider, target):
     features[impute_w_sn2] = features[impute_w_sn2].fillna(1) # Special case of imputation - nominal/ordinal value
 
     impute_w_sn3 = [col for col in features.columns if "loglocationvariance" in col]
-    features[impute_w_sn2] = features[impute_w_sn2].fillna(-1000000) # Special case of imputation - loglocation
+    features[impute_w_sn3] = features[impute_w_sn3].fillna(-1000000) # Special case of imputation - loglocation
 
     # Impute location features
     impute_locations = [col for col in features \
@@ -218,6 +218,16 @@ def straw_cleaning(sensor_data_files, provider, target):
 
     graph_bf_af(features, "10correlation_drop")
 
+    # Transform categorical columns to category dtype
+
+    cat1 = [col for col in features.columns if "mostcommonactivity" in col]
+    if cat1: # Transform columns to category dtype (mostcommonactivity)
+        features[cat1] = features[cat1].astype(int).astype('category')
+
+    cat2 = [col for col in features.columns if "homelabel" in col]
+    if cat2: # Transform columns to category dtype (homelabel)
+        features[cat2] = features[cat2].astype(int).astype('category')
+
     # (10) VERIFY IF THERE ARE ANY NANS LEFT IN THE DATAFRAME
     if features.isna().any().any():
         raise ValueError("There are still some NaNs present in the dataframe. Please check for implementation errors.")
diff --git a/src/models/helper.py b/src/models/helper.py
index 3b90f52d..2e007810 100644
--- a/src/models/helper.py
+++ b/src/models/helper.py
@@ -9,8 +9,8 @@ def retain_target_column(df_input: pd.DataFrame, target_variable_name: str):
     esm_names = column_names[esm_names_index]
     target_variable_index = esm_names.str.contains(target_variable_name)
     if all(~target_variable_index):
-        warnings.warn(f"The requested target (, {target_variable_name} ,)cannot be found in the dataset. Please check the names of phone_esm_ columns in z_all_sensor_features_cleaned_straw_py.csv")
-        return False
+        warnings.warn(f"The requested target (, {target_variable_name} ,)cannot be found in the dataset. Please check the names of phone_esm_ columns in cleaned python file")
+        return None
 
     sensor_features_plus_target = df_input.drop(esm_names, axis=1)
     sensor_features_plus_target["target"] = df_input[esm_names[target_variable_index]]
diff --git a/src/models/select_targets.py b/src/models/select_targets.py
index 6c29aed7..c6abe687 100644
--- a/src/models/select_targets.py
+++ b/src/models/select_targets.py
@@ -7,7 +7,7 @@ target_variable_name = snakemake.params["target_variable"]
 
 model_input = retain_target_column(cleaned_sensor_features, target_variable_name)
 
-if not model_input:
+if model_input is None:
     pd.DataFrame().to_csv(snakemake.output[0])
 else:
     model_input.to_csv(snakemake.output[0], index=False)
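
Note on the helper.py / select_targets.py change: returning None instead of False and checking "model_input is None" avoids truth-testing a pandas DataFrame, which raises a ValueError ("The truth value of a DataFrame is ambiguous") as soon as retain_target_column() returns a real DataFrame. The snippet below is a minimal, self-contained sketch of that behaviour and of the .astype(int).astype('category') pattern used in main.py; retain_target_column_stub() and the toy column values are made up for illustration and are not part of the patch.

    import pandas as pd

    def retain_target_column_stub(found: bool):
        # Hypothetical stand-in for retain_target_column(): a DataFrame when
        # the target exists, None otherwise, mirroring the patched helper.
        if not found:
            return None
        return pd.DataFrame({"feature": [0.1, 0.2], "target": [1, 0]})

    model_input = retain_target_column_stub(found=True)

    # The old "if not model_input:" check breaks once a DataFrame is returned,
    # because pandas refuses to coerce a DataFrame to a single boolean.
    try:
        if not model_input:
            pass
    except ValueError as err:
        print("truth-testing a DataFrame fails:", err)

    # The explicit None check from the patch handles both outcomes safely.
    if model_input is None:
        print("no target found; an empty CSV would be written")
    else:
        print("target found;", len(model_input), "rows would be written")

    # Same idea as the main.py change: cast float-coded labels to int, then to
    # the pandas 'category' dtype.
    df = pd.DataFrame({"mostcommonactivity": [1.0, 3.0, 1.0]})
    df["mostcommonactivity"] = df["mostcommonactivity"].astype(int).astype("category")
    print(df.dtypes)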