Define a confusion matrix scorer.
parent
aca84b214d
commit
bc78a1d498
|
@ -11,6 +11,7 @@ from sklearn import (
|
||||||
svm,
|
svm,
|
||||||
)
|
)
|
||||||
from sklearn.dummy import DummyClassifier, DummyRegressor
|
from sklearn.dummy import DummyClassifier, DummyRegressor
|
||||||
|
from sklearn.metrics import confusion_matrix
|
||||||
from sklearn.model_selection import (
|
from sklearn.model_selection import (
|
||||||
BaseCrossValidator,
|
BaseCrossValidator,
|
||||||
LeaveOneGroupOut,
|
LeaveOneGroupOut,
|
||||||
|
@ -418,6 +419,12 @@ def run_all_regression_models(
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
def confusion_matrix_scorer(clf, X, y):
|
||||||
|
y_pred = clf.predict(X)
|
||||||
|
cm = confusion_matrix(y, y_pred)
|
||||||
|
return {"tn": cm[0, 0], "fp": cm[0, 1], "fn": cm[1, 0], "tp": cm[1, 1]}
|
||||||
|
|
||||||
|
|
||||||
def run_all_classification_models(
|
def run_all_classification_models(
|
||||||
data_x: pd.DataFrame,
|
data_x: pd.DataFrame,
|
||||||
data_y: pd.DataFrame,
|
data_y: pd.DataFrame,
|
||||||
|
|
Loading…
Reference in New Issue