diff --git a/src/models/modeling_utils.py b/src/models/modeling_utils.py index 84a66188..1ba17c50 100644 --- a/src/models/modeling_utils.py +++ b/src/models/modeling_utils.py @@ -49,7 +49,10 @@ def getMetrics(pred_y, pred_y_prob, true_y): metrics = {} # metrics for all categories metrics["accuracy"] = accuracy_score(true_y, pred_y) - metrics["auc"] = roc_auc_score(true_y, pred_y_prob) + try: + metrics["auc"] = roc_auc_score(true_y, pred_y_prob) + except: + metrics["auc"] = None metrics["kappa"] = cohen_kappa_score(true_y, pred_y) # metrics for label 0 metrics["precision0"] = precision_score(true_y, pred_y, average=None, labels=[0,1], zero_division=0)[0]