diff --git a/machine_learning/helper.py b/machine_learning/helper.py index 280d58c..7f50e9a 100644 --- a/machine_learning/helper.py +++ b/machine_learning/helper.py @@ -424,7 +424,14 @@ def run_all_classification_models( data_groups: pd.DataFrame, cross_validator: BaseCrossValidator, ): - metrics = ["accuracy", "average_precision", "recall", "f1"] + data_y_value_counts = data_y.value_counts() + if len(data_y_value_counts) == 1: + raise (ValueError("There is only one unique value in data_y.")) + if len(data_y_value_counts) == 2: + metrics = ["accuracy", "average_precision", "recall", "f1"] + else: + metrics = ["accuracy", "precision_micro", "recall_micro", "f1_micro"] + test_metrics = ["test_" + metric for metric in metrics] scores = pd.DataFrame(columns=["method", "test_metric", "max", "mean"])