From bc78a1d498bff2ab53d92f8e92c1c74cc3657b81 Mon Sep 17 00:00:00 2001 From: junos Date: Wed, 31 May 2023 16:05:39 +0200 Subject: [PATCH] Define a confusion matrix scorer. --- machine_learning/helper.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/machine_learning/helper.py b/machine_learning/helper.py index 7f50e9a..32ecc98 100644 --- a/machine_learning/helper.py +++ b/machine_learning/helper.py @@ -11,6 +11,7 @@ from sklearn import ( svm, ) from sklearn.dummy import DummyClassifier, DummyRegressor +from sklearn.metrics import confusion_matrix from sklearn.model_selection import ( BaseCrossValidator, LeaveOneGroupOut, @@ -418,6 +419,12 @@ def run_all_regression_models( return scores +def confusion_matrix_scorer(clf, X, y): + y_pred = clf.predict(X) + cm = confusion_matrix(y, y_pred) + return {"tn": cm[0, 0], "fp": cm[0, 1], "fn": cm[1, 0], "tp": cm[1, 1]} + + def run_all_classification_models( data_x: pd.DataFrame, data_y: pd.DataFrame,