From d263b325645f14e344e35b1b30133710279c3fd4 Mon Sep 17 00:00:00 2001 From: Primoz Date: Thu, 19 Jan 2023 09:26:55 +0100 Subject: [PATCH] Temp: remove stratified logo from ml pipeline. --- exploration/ml_pipeline_classification.py | 39 +++++++++++-------- .../event_stressful_detection_5fold.csv | 29 ++++++++++++++ .../event_stressful_detection_logo.csv | 29 ++++++++++++++ 3 files changed, 81 insertions(+), 16 deletions(-) create mode 100644 presentation/event_stressful_detection_5fold.csv create mode 100644 presentation/event_stressful_detection_logo.csv diff --git a/exploration/ml_pipeline_classification.py b/exploration/ml_pipeline_classification.py index 3deae61..131c901 100644 --- a/exploration/ml_pipeline_classification.py +++ b/exploration/ml_pipeline_classification.py @@ -46,9 +46,9 @@ import machine_learning.helper # # %% jupyter={"source_hidden": false, "outputs_hidden": false} nteract={"transient": {"deleting": false}} -cv_method_str = '5kfold' # logo, half_logo, 5kfold # Cross-validation method (could be regarded as a hyperparameter) +cv_method_str = 'logo' # logo, half_logo, 5kfold # Cross-validation method (could be regarded as a hyperparameter) n_sl = 3 # Number of largest/smallest accuracies (of particular CV) outputs -undersampling = True # (bool) If True this will train and test data on balanced dataset (using undersampling method) +undersampling = False # (bool) If True this will train and test data on balanced dataset (using undersampling method) # %% jupyter={"source_hidden": false, "outputs_hidden": false} model_input = pd.read_csv("../data/stressfulness_event_with_target_0_ver2/input_appraisal_stressfulness_event_mean.csv") @@ -72,20 +72,27 @@ model_input['target'].value_counts() # %% jupyter={"source_hidden": false, "outputs_hidden": false} # UnderSampling if undersampling: - model_input_new = pd.DataFrame(columns=model_input.columns) - for pid in model_input["pid"].unique(): - stress = model_input[(model_input["pid"] == pid) & (model_input['target'] == 1)] - no_stress = model_input[(model_input["pid"] == pid) & (model_input['target'] == 0)] - if (len(stress) == 0): - continue - if (len(no_stress) == 0): - continue - model_input_new = pd.concat([model_input_new, stress], axis=0) + no_stress = model_input[model_input['target'] == 0] + stress = model_input[model_input['target'] == 1] - no_stress = no_stress.sample(n=min(len(stress), len(no_stress))) - # In case there are more stress samples than no_stress, take all instances of no_stress. - model_input_new = pd.concat([model_input_new, no_stress], axis=0) - model_input = model_input_new + no_stress = no_stress.sample(n=len(stress)) + model_input = pd.concat([stress,no_stress], axis=0) + +# model_input_new = pd.DataFrame(columns=model_input.columns) +# for pid in model_input["pid"].unique(): +# stress = model_input[(model_input["pid"] == pid) & (model_input['target'] == 1)] +# no_stress = model_input[(model_input["pid"] == pid) & (model_input['target'] == 0)] +# if (len(stress) == 0): +# continue +# if (len(no_stress) == 0): +# continue +# model_input_new = pd.concat([model_input_new, stress], axis=0) + +# no_stress = no_stress.sample(n=min(len(stress), len(no_stress))) +# # In case there are more stress samples than no_stress, take all instances of no_stress. +# model_input_new = pd.concat([model_input_new, no_stress], axis=0) +# model_input = model_input_new +# model_input_new = pd.concat([model_input_new, no_stress], axis=0) # %% jupyter={"source_hidden": false, "outputs_hidden": false} @@ -170,7 +177,7 @@ final_scores = machine_learning.helper.run_all_classification_models(imputer.fit # %% final_scores.index.name = "metric" final_scores = final_scores.set_index(["method", final_scores.index]) -final_scores.to_csv("../presentation/event_stressful_detection_5fold.csv") +final_scores.to_csv(f"../presentation/event_stressful_detection_{cv_method_str}.csv") # %% [markdown] # ### Logistic Regression diff --git a/presentation/event_stressful_detection_5fold.csv b/presentation/event_stressful_detection_5fold.csv new file mode 100644 index 0000000..d005efe --- /dev/null +++ b/presentation/event_stressful_detection_5fold.csv @@ -0,0 +1,29 @@ +method,metric,max,mean +Dummy,test_accuracy,0.8557046979865772,0.8548446932649828 +Dummy,test_average_precision,0.1457286432160804,0.14515530673501736 +Dummy,test_recall,0.0,0.0 +Dummy,test_f1,0.0,0.0 +logistic_reg,test_accuracy,0.8640939597315436,0.8504895843872606 +logistic_reg,test_average_precision,0.44363425265068757,0.37511495347389834 +logistic_reg,test_recall,0.3023255813953488,0.24266238973536486 +logistic_reg,test_f1,0.3909774436090226,0.318943511424051 +svc,test_accuracy,0.8557046979865772,0.8548446932649828 +svc,test_average_precision,0.44514416839823046,0.4068200938341621 +svc,test_recall,0.0,0.0 +svc,test_f1,0.0,0.0 +gaussian_naive_bayes,test_accuracy,0.7684563758389261,0.7479123806954234 +gaussian_naive_bayes,test_average_precision,0.2534828030085334,0.23379392278901853 +gaussian_naive_bayes,test_recall,0.42528735632183906,0.3924619085805935 +gaussian_naive_bayes,test_f1,0.34285714285714286,0.3107236284017699 +stochastic_gradient_descent,test_accuracy,0.8576214405360134,0.7773610783222601 +stochastic_gradient_descent,test_average_precision,0.3813093757959869,0.3617503752215592 +stochastic_gradient_descent,test_recall,0.686046511627907,0.2822507350975675 +stochastic_gradient_descent,test_f1,0.3652173913043478,0.21849107443075583 +random_forest,test_accuracy,0.9110738255033557,0.9011129472867694 +random_forest,test_average_precision,0.6998372262021191,0.6619275281099584 +random_forest,test_recall,0.4069767441860465,0.35356856455493185 +random_forest,test_f1,0.5691056910569107,0.5078402513053142 +xgboost,test_accuracy,0.9128978224455612,0.9007711937764886 +xgboost,test_average_precision,0.7366643049075349,0.698622165966308 +xgboost,test_recall,0.5287356321839081,0.44346431435445066 +xgboost,test_f1,0.638888888888889,0.5633957169928393 diff --git a/presentation/event_stressful_detection_logo.csv b/presentation/event_stressful_detection_logo.csv new file mode 100644 index 0000000..6874e7f --- /dev/null +++ b/presentation/event_stressful_detection_logo.csv @@ -0,0 +1,29 @@ +method,metric,max,mean +Dummy,test_accuracy,1.0,0.8524114578096439 +Dummy,test_average_precision,0.7,0.14758854219035614 +Dummy,test_recall,0.0,0.0 +Dummy,test_f1,0.0,0.0 +logistic_reg,test_accuracy,0.9824561403508771,0.8445351955631311 +logistic_reg,test_average_precision,1.0,0.44605167668563583 +logistic_reg,test_recall,1.0,0.25353566685532386 +logistic_reg,test_f1,0.823529411764706,0.27951926390778625 +svc,test_accuracy,1.0,0.8524114578096439 +svc,test_average_precision,0.9612401707068228,0.44179454944271934 +svc,test_recall,0.0,0.0 +svc,test_f1,0.0,0.0 +gaussian_naive_bayes,test_accuracy,0.9,0.7491301746887129 +gaussian_naive_bayes,test_average_precision,0.9189430193277607,0.2833170327386991 +gaussian_naive_bayes,test_recall,1.0,0.3743761174081108 +gaussian_naive_bayes,test_f1,0.7000000000000001,0.2698456659235668 +stochastic_gradient_descent,test_accuracy,1.0,0.7926428596764739 +stochastic_gradient_descent,test_average_precision,1.0,0.4421948838597582 +stochastic_gradient_descent,test_recall,1.0,0.30156420704502945 +stochastic_gradient_descent,test_f1,0.8148148148148148,0.24088393234361388 +random_forest,test_accuracy,1.0,0.8722158105763481 +random_forest,test_average_precision,1.0,0.49817066323226833 +random_forest,test_recall,1.0,0.18161263127840668 +random_forest,test_f1,1.0,0.2508096532365307 +xgboost,test_accuracy,1.0,0.8812627400277729 +xgboost,test_average_precision,1.0,0.5505695112208401 +xgboost,test_recall,1.0,0.2896161238315027 +xgboost,test_f1,0.9411764705882353,0.36887408735855665