diff --git a/machine_learning/helper.py b/machine_learning/helper.py index 7f50e9a..32ecc98 100644 --- a/machine_learning/helper.py +++ b/machine_learning/helper.py @@ -11,6 +11,7 @@ from sklearn import ( svm, ) from sklearn.dummy import DummyClassifier, DummyRegressor +from sklearn.metrics import confusion_matrix from sklearn.model_selection import ( BaseCrossValidator, LeaveOneGroupOut, @@ -418,6 +419,12 @@ def run_all_regression_models( return scores +def confusion_matrix_scorer(clf, X, y): + y_pred = clf.predict(X) + cm = confusion_matrix(y, y_pred) + return {"tn": cm[0, 0], "fp": cm[0, 1], "fn": cm[1, 0], "tp": cm[1, 1]} + + def run_all_classification_models( data_x: pd.DataFrame, data_y: pd.DataFrame,