diff --git a/machine_learning/helper.py b/machine_learning/helper.py index 4d4e48c..20ac7eb 100644 --- a/machine_learning/helper.py +++ b/machine_learning/helper.py @@ -73,7 +73,7 @@ def insert_row(df, row): return pd.concat([df, pd.DataFrame([row], columns=df.columns)], ignore_index=True) -def prepare_regression_model_input(model_input, cv_method="logo"): +def prepare_sklearn_data_format(model_input, cv_method="logo"): index_columns = [ "local_segment", "local_segment_label", @@ -82,13 +82,7 @@ def prepare_regression_model_input(model_input, cv_method="logo"): ] model_input.set_index(index_columns, inplace=True) - if cv_method == "logo": - data_x, data_y, data_groups = ( - model_input.drop(["target", "pid"], axis=1), - model_input["target"], - model_input["pid"], - ) - else: + if cv_method == "half_logo": model_input["pid_index"] = model_input.groupby("pid").cumcount() model_input["pid_count"] = model_input.groupby("pid")["pid"].transform("count") @@ -104,6 +98,19 @@ def prepare_regression_model_input(model_input, cv_method="logo"): model_input["target"], model_input["pid_half"], ) + else: + data_x, data_y, data_groups = ( + model_input.drop(["target", "pid"], axis=1), + model_input["target"], + model_input["pid"], + ) + return data_x, data_y, data_groups + + +def prepare_regression_model_input(model_input, cv_method="logo"): + data_x, data_y, data_groups = prepare_sklearn_data_format( + model_input, cv_method=cv_method + ) categorical_feature_colnames = [ "gender",