2022-11-21 14:47:19 +01:00
|
|
|
# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.14.5
#   kernelspec:
#     display_name: straw2analysis
#     language: python
#     name: straw2analysis
# ---
|
|
|
|
|
2023-05-11 16:51:38 +02:00
|
|
|
# %% jupyter={"outputs_hidden": false, "source_hidden": false}
|
2022-11-21 14:47:19 +01:00
|
|
|
# %matplotlib inline
|
|
|
|
import os
import sys

import pandas as pd
from IPython.core.interactiveshell import InteractiveShell

# Display the value of every top-level expression in a cell, not only the
# last one (several cells below rely on this to show intermediate results).
InteractiveShell.ast_node_interactivity = "all"

# Make the parent of the current working directory importable BEFORE the
# project-local import below. Previously this ran after the import, where it
# could no longer help `machine_learning` resolve.
# NOTE(review): presumably the notebook is run from a sub-directory of the
# repository root that contains `machine_learning` - confirm.
nb_dir = os.path.split(os.getcwd())[0]
if nb_dir not in sys.path:
    sys.path.append(nb_dir)

# Must stay below the sys.path setup; do not let an import sorter hoist it.
from machine_learning.helper import (  # noqa: E402
    impute_encode_categorical_features,
    prepare_cross_validator,
    prepare_sklearn_data_format,
    run_all_classification_models,
)
|
|
|
|
|
2023-01-04 21:25:12 +01:00
|
|
|
|
2023-05-10 23:17:44 +02:00
|
|
|
# %%
|
|
|
|
# Cross-validation method (could be regarded as a hyperparameter).
# Options: logo, half_logo, 5kfold
CV_METHOD = "logo"

# Number of largest/smallest accuracies (of a particular CV) to output.
N_SL = 3

# (bool) If True, train and test on a balanced dataset
# (obtained with the undersampling method).
UNDERSAMPLING = False
|
2022-11-21 14:47:19 +01:00
|
|
|
|
2023-05-11 16:51:38 +02:00
|
|
|
# %% jupyter={"outputs_hidden": false, "source_hidden": false}
|
2023-05-10 23:17:44 +02:00
|
|
|
# Location of the pre-computed model input table (absolute, machine-specific).
input_csv = (
    "E:/STRAWresults/20230415/stress_event/input_appraisal_stressfulness_event_mean.csv"
)
model_input = pd.read_csv(input_csv)

# Optional filters, currently disabled.
# Drop every column whose name matches 'empatica_temperature':
# model_input =
# model_input[model_input.columns.drop(
# list(model_input.filter(regex='empatica_temperature'))
# )]
# Keep only rows whose 'local_segment' contains "daily":
# model_input = model_input[model_input['local_segment'].str.contains("daily")]

# %% jupyter={"outputs_hidden": false, "source_hidden": false}
# Distribution of the raw target values, before any binning.
model_input["target"].value_counts()
|
2022-11-22 14:31:49 +01:00
|
|
|
|
2023-05-11 16:51:38 +02:00
|
|
|
# %% jupyter={"outputs_hidden": false, "source_hidden": false}
|
2022-12-13 17:01:46 +01:00
|
|
|
# Binarize the target column into 0 ("low") / 1 ("high").
# bins = [-10, 0, 10]  # bins for z-scored targets
bins = [-1, 0, 4]  # bins for stressfulness (0-4) target
# With right=True the intervals are (-1, 0] -> "low" and (0, 4] -> "high".
# A three-class variant would use labels ['low', 'medium', 'high'].
model_input["target"], edges = pd.cut(
    model_input["target"], bins=bins, labels=["low", "high"], retbins=True, right=True
)
model_input["target"].value_counts(), edges

# model_input = model_input[model_input['target'] != "medium"]

# pd.cut returns NaN for targets that are missing or fall outside the bins.
# Previously those rows were stringified to "nan" and then labelled 1 ("high")
# by an "everything that is not 'low'" rule; drop them explicitly instead.
unbinned = model_input["target"].isna()
if unbinned.any():
    print(
        f"Dropping {int(unbinned.sum())} row(s) with a missing or "
        f"out-of-range target (bins={bins})."
    )
    model_input = model_input.loc[~unbinned].copy()

# Explicit label -> integer mapping; every remaining value is "low" or "high".
model_input["target"] = model_input["target"].astype(str).map({"low": 0, "high": 1})

model_input["target"].value_counts()
|
2022-11-21 14:47:19 +01:00
|
|
|
|
2023-05-11 16:51:38 +02:00
|
|
|
# %% jupyter={"outputs_hidden": false, "source_hidden": false}
|
2022-12-13 17:01:46 +01:00
|
|
|
# UnderSampling: balance the two classes by randomly discarding rows of the
# larger class until both classes have the same number of rows.
if UNDERSAMPLING:
    no_stress = model_input[model_input["target"] == 0]
    stress = model_input[model_input["target"] == 1]

    # Fixed seed so that repeated runs select the same rows.
    undersampling_seed = 42

    # Sample from whichever class is larger. Sampling class 0 unconditionally
    # raises ValueError when class 0 has fewer rows than class 1.
    if len(no_stress) >= len(stress):
        no_stress = no_stress.sample(n=len(stress), random_state=undersampling_seed)
    else:
        stress = stress.sample(n=len(no_stress), random_state=undersampling_seed)

    model_input = pd.concat([stress, no_stress], axis=0)
|
2022-11-21 14:47:19 +01:00
|
|
|
|
|
|
|
|
2023-05-11 16:51:38 +02:00
|
|
|
# %% jupyter={"outputs_hidden": false, "source_hidden": false}
|
2023-05-10 23:17:44 +02:00
|
|
|
# Run the project helper that imputes/encodes categorical features.
# NOTE(review): exact imputation and encoding rules live in
# machine_learning.helper - confirm there.
model_input_encoded = impute_encode_categorical_features(model_input)

# %%
# Split the encoded table into features, target and grouping labels, then
# build the cross-validator for the configured CV_METHOD.
sklearn_inputs = prepare_sklearn_data_format(model_input_encoded, CV_METHOD)
data_x, data_y, data_groups = sklearn_inputs
cross_validator = prepare_cross_validator(data_x, data_y, data_groups, CV_METHOD)

# %%
# Quick sanity checks on the target passed to the models.
data_y.head()

# %%
data_y.tail()

# %%
data_y.shape
|
|
|
|
# %%
|
|
|
|
# Evaluate all classification models with the prepared cross-validator.
scores = run_all_classification_models(data_x, data_y, data_groups, cross_validator)

# %%
# Save the scores next to the other presentation outputs; the file name
# carries the cross-validation method so runs do not overwrite each other.
scores_csv = (
    "../presentation/appraisal_stressfulness_event_classification_"
    f"{CV_METHOD}.csv"
)
scores.to_csv(scores_csv, index=False)
|