Temp: remove stratified logo from ml pipeline.

ml_pipeline
Primoz 2023-01-19 09:26:55 +01:00
parent ad2fab133f
commit d263b32564
3 changed files with 81 additions and 16 deletions

View File

@ -46,9 +46,9 @@ import machine_learning.helper
# #
# %% jupyter={"source_hidden": false, "outputs_hidden": false} nteract={"transient": {"deleting": false}} # %% jupyter={"source_hidden": false, "outputs_hidden": false} nteract={"transient": {"deleting": false}}
cv_method_str = '5kfold' # logo, half_logo, 5kfold # Cross-validation method (could be regarded as a hyperparameter) cv_method_str = 'logo' # logo, half_logo, 5kfold # Cross-validation method (could be regarded as a hyperparameter)
n_sl = 3 # Number of largest/smallest accuracies (of particular CV) outputs n_sl = 3 # Number of largest/smallest accuracies (of particular CV) outputs
undersampling = True # (bool) If True this will train and test data on balanced dataset (using undersampling method) undersampling = False # (bool) If True this will train and test data on balanced dataset (using undersampling method)
# %% jupyter={"source_hidden": false, "outputs_hidden": false} # %% jupyter={"source_hidden": false, "outputs_hidden": false}
model_input = pd.read_csv("../data/stressfulness_event_with_target_0_ver2/input_appraisal_stressfulness_event_mean.csv") model_input = pd.read_csv("../data/stressfulness_event_with_target_0_ver2/input_appraisal_stressfulness_event_mean.csv")
@ -72,20 +72,27 @@ model_input['target'].value_counts()
# %% jupyter={"source_hidden": false, "outputs_hidden": false} # %% jupyter={"source_hidden": false, "outputs_hidden": false}
# UnderSampling # UnderSampling
if undersampling: if undersampling:
model_input_new = pd.DataFrame(columns=model_input.columns) no_stress = model_input[model_input['target'] == 0]
for pid in model_input["pid"].unique(): stress = model_input[model_input['target'] == 1]
stress = model_input[(model_input["pid"] == pid) & (model_input['target'] == 1)]
no_stress = model_input[(model_input["pid"] == pid) & (model_input['target'] == 0)]
if (len(stress) == 0):
continue
if (len(no_stress) == 0):
continue
model_input_new = pd.concat([model_input_new, stress], axis=0)
no_stress = no_stress.sample(n=min(len(stress), len(no_stress))) no_stress = no_stress.sample(n=len(stress))
# In case there are more stress samples than no_stress, take all instances of no_stress. model_input = pd.concat([stress,no_stress], axis=0)
model_input_new = pd.concat([model_input_new, no_stress], axis=0)
model_input = model_input_new # model_input_new = pd.DataFrame(columns=model_input.columns)
# for pid in model_input["pid"].unique():
# stress = model_input[(model_input["pid"] == pid) & (model_input['target'] == 1)]
# no_stress = model_input[(model_input["pid"] == pid) & (model_input['target'] == 0)]
# if (len(stress) == 0):
# continue
# if (len(no_stress) == 0):
# continue
# model_input_new = pd.concat([model_input_new, stress], axis=0)
# no_stress = no_stress.sample(n=min(len(stress), len(no_stress)))
# # In case there are more stress samples than no_stress, take all instances of no_stress.
# model_input_new = pd.concat([model_input_new, no_stress], axis=0)
# model_input = model_input_new
# model_input_new = pd.concat([model_input_new, no_stress], axis=0)
# %% jupyter={"source_hidden": false, "outputs_hidden": false} # %% jupyter={"source_hidden": false, "outputs_hidden": false}
@ -170,7 +177,7 @@ final_scores = machine_learning.helper.run_all_classification_models(imputer.fit
# %% # %%
final_scores.index.name = "metric" final_scores.index.name = "metric"
final_scores = final_scores.set_index(["method", final_scores.index]) final_scores = final_scores.set_index(["method", final_scores.index])
final_scores.to_csv("../presentation/event_stressful_detection_5fold.csv") final_scores.to_csv(f"../presentation/event_stressful_detection_{cv_method_str}.csv")
# %% [markdown] # %% [markdown]
# ### Logistic Regression # ### Logistic Regression

View File

@ -0,0 +1,29 @@
method,metric,max,mean
Dummy,test_accuracy,0.8557046979865772,0.8548446932649828
Dummy,test_average_precision,0.1457286432160804,0.14515530673501736
Dummy,test_recall,0.0,0.0
Dummy,test_f1,0.0,0.0
logistic_reg,test_accuracy,0.8640939597315436,0.8504895843872606
logistic_reg,test_average_precision,0.44363425265068757,0.37511495347389834
logistic_reg,test_recall,0.3023255813953488,0.24266238973536486
logistic_reg,test_f1,0.3909774436090226,0.318943511424051
svc,test_accuracy,0.8557046979865772,0.8548446932649828
svc,test_average_precision,0.44514416839823046,0.4068200938341621
svc,test_recall,0.0,0.0
svc,test_f1,0.0,0.0
gaussian_naive_bayes,test_accuracy,0.7684563758389261,0.7479123806954234
gaussian_naive_bayes,test_average_precision,0.2534828030085334,0.23379392278901853
gaussian_naive_bayes,test_recall,0.42528735632183906,0.3924619085805935
gaussian_naive_bayes,test_f1,0.34285714285714286,0.3107236284017699
stochastic_gradient_descent,test_accuracy,0.8576214405360134,0.7773610783222601
stochastic_gradient_descent,test_average_precision,0.3813093757959869,0.3617503752215592
stochastic_gradient_descent,test_recall,0.686046511627907,0.2822507350975675
stochastic_gradient_descent,test_f1,0.3652173913043478,0.21849107443075583
random_forest,test_accuracy,0.9110738255033557,0.9011129472867694
random_forest,test_average_precision,0.6998372262021191,0.6619275281099584
random_forest,test_recall,0.4069767441860465,0.35356856455493185
random_forest,test_f1,0.5691056910569107,0.5078402513053142
xgboost,test_accuracy,0.9128978224455612,0.9007711937764886
xgboost,test_average_precision,0.7366643049075349,0.698622165966308
xgboost,test_recall,0.5287356321839081,0.44346431435445066
xgboost,test_f1,0.638888888888889,0.5633957169928393
1 method metric max mean
2 Dummy test_accuracy 0.8557046979865772 0.8548446932649828
3 Dummy test_average_precision 0.1457286432160804 0.14515530673501736
4 Dummy test_recall 0.0 0.0
5 Dummy test_f1 0.0 0.0
6 logistic_reg test_accuracy 0.8640939597315436 0.8504895843872606
7 logistic_reg test_average_precision 0.44363425265068757 0.37511495347389834
8 logistic_reg test_recall 0.3023255813953488 0.24266238973536486
9 logistic_reg test_f1 0.3909774436090226 0.318943511424051
10 svc test_accuracy 0.8557046979865772 0.8548446932649828
11 svc test_average_precision 0.44514416839823046 0.4068200938341621
12 svc test_recall 0.0 0.0
13 svc test_f1 0.0 0.0
14 gaussian_naive_bayes test_accuracy 0.7684563758389261 0.7479123806954234
15 gaussian_naive_bayes test_average_precision 0.2534828030085334 0.23379392278901853
16 gaussian_naive_bayes test_recall 0.42528735632183906 0.3924619085805935
17 gaussian_naive_bayes test_f1 0.34285714285714286 0.3107236284017699
18 stochastic_gradient_descent test_accuracy 0.8576214405360134 0.7773610783222601
19 stochastic_gradient_descent test_average_precision 0.3813093757959869 0.3617503752215592
20 stochastic_gradient_descent test_recall 0.686046511627907 0.2822507350975675
21 stochastic_gradient_descent test_f1 0.3652173913043478 0.21849107443075583
22 random_forest test_accuracy 0.9110738255033557 0.9011129472867694
23 random_forest test_average_precision 0.6998372262021191 0.6619275281099584
24 random_forest test_recall 0.4069767441860465 0.35356856455493185
25 random_forest test_f1 0.5691056910569107 0.5078402513053142
26 xgboost test_accuracy 0.9128978224455612 0.9007711937764886
27 xgboost test_average_precision 0.7366643049075349 0.698622165966308
28 xgboost test_recall 0.5287356321839081 0.44346431435445066
29 xgboost test_f1 0.638888888888889 0.5633957169928393

View File

@ -0,0 +1,29 @@
method,metric,max,mean
Dummy,test_accuracy,1.0,0.8524114578096439
Dummy,test_average_precision,0.7,0.14758854219035614
Dummy,test_recall,0.0,0.0
Dummy,test_f1,0.0,0.0
logistic_reg,test_accuracy,0.9824561403508771,0.8445351955631311
logistic_reg,test_average_precision,1.0,0.44605167668563583
logistic_reg,test_recall,1.0,0.25353566685532386
logistic_reg,test_f1,0.823529411764706,0.27951926390778625
svc,test_accuracy,1.0,0.8524114578096439
svc,test_average_precision,0.9612401707068228,0.44179454944271934
svc,test_recall,0.0,0.0
svc,test_f1,0.0,0.0
gaussian_naive_bayes,test_accuracy,0.9,0.7491301746887129
gaussian_naive_bayes,test_average_precision,0.9189430193277607,0.2833170327386991
gaussian_naive_bayes,test_recall,1.0,0.3743761174081108
gaussian_naive_bayes,test_f1,0.7000000000000001,0.2698456659235668
stochastic_gradient_descent,test_accuracy,1.0,0.7926428596764739
stochastic_gradient_descent,test_average_precision,1.0,0.4421948838597582
stochastic_gradient_descent,test_recall,1.0,0.30156420704502945
stochastic_gradient_descent,test_f1,0.8148148148148148,0.24088393234361388
random_forest,test_accuracy,1.0,0.8722158105763481
random_forest,test_average_precision,1.0,0.49817066323226833
random_forest,test_recall,1.0,0.18161263127840668
random_forest,test_f1,1.0,0.2508096532365307
xgboost,test_accuracy,1.0,0.8812627400277729
xgboost,test_average_precision,1.0,0.5505695112208401
xgboost,test_recall,1.0,0.2896161238315027
xgboost,test_f1,0.9411764705882353,0.36887408735855665
1 method metric max mean
2 Dummy test_accuracy 1.0 0.8524114578096439
3 Dummy test_average_precision 0.7 0.14758854219035614
4 Dummy test_recall 0.0 0.0
5 Dummy test_f1 0.0 0.0
6 logistic_reg test_accuracy 0.9824561403508771 0.8445351955631311
7 logistic_reg test_average_precision 1.0 0.44605167668563583
8 logistic_reg test_recall 1.0 0.25353566685532386
9 logistic_reg test_f1 0.823529411764706 0.27951926390778625
10 svc test_accuracy 1.0 0.8524114578096439
11 svc test_average_precision 0.9612401707068228 0.44179454944271934
12 svc test_recall 0.0 0.0
13 svc test_f1 0.0 0.0
14 gaussian_naive_bayes test_accuracy 0.9 0.7491301746887129
15 gaussian_naive_bayes test_average_precision 0.9189430193277607 0.2833170327386991
16 gaussian_naive_bayes test_recall 1.0 0.3743761174081108
17 gaussian_naive_bayes test_f1 0.7000000000000001 0.2698456659235668
18 stochastic_gradient_descent test_accuracy 1.0 0.7926428596764739
19 stochastic_gradient_descent test_average_precision 1.0 0.4421948838597582
20 stochastic_gradient_descent test_recall 1.0 0.30156420704502945
21 stochastic_gradient_descent test_f1 0.8148148148148148 0.24088393234361388
22 random_forest test_accuracy 1.0 0.8722158105763481
23 random_forest test_average_precision 1.0 0.49817066323226833
24 random_forest test_recall 1.0 0.18161263127840668
25 random_forest test_f1 1.0 0.2508096532365307
26 xgboost test_accuracy 1.0 0.8812627400277729
27 xgboost test_average_precision 1.0 0.5505695112208401
28 xgboost test_recall 1.0 0.2896161238315027
29 xgboost test_f1 0.9411764705882353 0.36887408735855665