From a2401b5e367e0ffeb663230ef7b60e4e6791cd79 Mon Sep 17 00:00:00 2001 From: junos Date: Fri, 19 May 2023 01:34:34 +0200 Subject: [PATCH] Add multiclass scoring. --- machine_learning/helper.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/machine_learning/helper.py b/machine_learning/helper.py index 280d58c..7f50e9a 100644 --- a/machine_learning/helper.py +++ b/machine_learning/helper.py @@ -424,7 +424,14 @@ def run_all_classification_models( data_groups: pd.DataFrame, cross_validator: BaseCrossValidator, ): - metrics = ["accuracy", "average_precision", "recall", "f1"] + data_y_value_counts = data_y.value_counts() + if len(data_y_value_counts) == 1: + raise (ValueError("There is only one unique value in data_y.")) + if len(data_y_value_counts) == 2: + metrics = ["accuracy", "average_precision", "recall", "f1"] + else: + metrics = ["accuracy", "precision_micro", "recall_micro", "f1_micro"] + test_metrics = ["test_" + metric for metric in metrics] scores = pd.DataFrame(columns=["method", "test_metric", "max", "mean"])