diff --git a/machine_learning/helper.py b/machine_learning/helper.py index 4742cca..dd57393 100644 --- a/machine_learning/helper.py +++ b/machine_learning/helper.py @@ -66,8 +66,8 @@ def construct_full_path(folder: Path, filename_prefix: str, data_type: str) -> P def insert_row(df, row): return pd.concat([df, pd.DataFrame([row], columns=df.columns)], ignore_index=True) -def run_all_models(input_csv): - # Prepare data +def prepare_model_input(input_csv): + model_input = pd.read_csv(input_csv) index_columns = ["local_segment", "local_segment_label", "local_segment_start_datetime", "local_segment_end_datetime"] @@ -75,9 +75,11 @@ def run_all_models(input_csv): data_x, data_y, data_groups = model_input.drop(["target", "pid"], axis=1), model_input["target"], model_input["pid"] - categorical_feature_colnames = ["gender", "startlanguage"] - additional_categorical_features = [col for col in data_x.columns if "mostcommonactivity" in col or "homelabel" in col] - categorical_feature_colnames += additional_categorical_features + categorical_feature_colnames = ["gender", "startlanguage", "limesurvey_demand_control_ratio_quartile"] + #TODO: check whether limesurvey_demand_control_ratio_quartile NaNs could be replaced meaningfully + #additional_categorical_features = [col for col in data_x.columns if "mostcommonactivity" in col or "homelabel" in col] + #TODO: check if mostcommonactivity is indeed a categorical features after aggregating + #categorical_feature_colnames += additional_categorical_features categorical_features = data_x[categorical_feature_colnames].copy() mode_categorical_features = categorical_features.mode().iloc[0] # fillna with mode @@ -91,6 +93,13 @@ def run_all_models(input_csv): train_x = pd.concat([numerical_features, categorical_features], axis=1) + return train_x, data_y, data_groups + + +def run_all_models(input_csv): + # Prepare data + train_x, data_y, data_groups = prepare_model_input(input_csv) + # Prepare cross validation logo = LeaveOneGroupOut() logo.get_n_splits(