Define a confusion matrix scorer.

master
junos 2023-05-31 16:05:39 +02:00
parent aca84b214d
commit bc78a1d498
1 changed files with 7 additions and 0 deletions

View File

@ -11,6 +11,7 @@ from sklearn import (
svm, svm,
) )
from sklearn.dummy import DummyClassifier, DummyRegressor from sklearn.dummy import DummyClassifier, DummyRegressor
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import ( from sklearn.model_selection import (
BaseCrossValidator, BaseCrossValidator,
LeaveOneGroupOut, LeaveOneGroupOut,
@ -418,6 +419,12 @@ def run_all_regression_models(
return scores return scores
def confusion_matrix_scorer(clf, X, y):
y_pred = clf.predict(X)
cm = confusion_matrix(y, y_pred)
return {"tn": cm[0, 0], "fp": cm[0, 1], "fn": cm[1, 0], "tp": cm[1, 1]}
def run_all_classification_models( def run_all_classification_models(
data_x: pd.DataFrame, data_x: pd.DataFrame,
data_y: pd.DataFrame, data_y: pd.DataFrame,