Define a confusion matrix scorer.
parent
aca84b214d
commit
bc78a1d498
|
@ -11,6 +11,7 @@ from sklearn import (
|
|||
svm,
|
||||
)
|
||||
from sklearn.dummy import DummyClassifier, DummyRegressor
|
||||
from sklearn.metrics import confusion_matrix
|
||||
from sklearn.model_selection import (
|
||||
BaseCrossValidator,
|
||||
LeaveOneGroupOut,
|
||||
|
@ -418,6 +419,12 @@ def run_all_regression_models(
|
|||
return scores
|
||||
|
||||
|
||||
def confusion_matrix_scorer(clf, X, y):
|
||||
y_pred = clf.predict(X)
|
||||
cm = confusion_matrix(y, y_pred)
|
||||
return {"tn": cm[0, 0], "fp": cm[0, 1], "fn": cm[1, 0], "tp": cm[1, 1]}
|
||||
|
||||
|
||||
def run_all_classification_models(
|
||||
data_x: pd.DataFrame,
|
||||
data_y: pd.DataFrame,
|
||||
|
|
Loading…
Reference in New Issue