stress_at_work_analysis/exploration/ml_pipeline.py

71 lines
2.2 KiB
Python

# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.13.0
#   kernelspec:
#     display_name: straw2analysis
#     language: python
#     name: straw2analysis
# ---
# %%
import sys, os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
# Make the parent of the current working directory importable so that the
# `machine_learning` package imported below can be resolved when the notebook
# is run from its own sub-directory.
nb_dir = os.path.dirname(os.getcwd())
if nb_dir not in sys.path:
    # Append only once so re-running the cell does not grow sys.path.
    sys.path.append(nb_dir)
from machine_learning.cross_validation import CrossValidation
from machine_learning.preprocessing import Preprocessing
from machine_learning.feature_selection import FeatureSelection
# %%
# Load the feature table; the path is relative to the notebook's working
# directory.
df = pd.read_csv("../data/stressfulness_event_with_speech/input_appraisal_stressfulness_event_mean.csv")
index_columns = ["local_segment", "local_segment_label", "local_segment_start_datetime", "local_segment_end_datetime"]
df.set_index(index_columns, inplace=True)

# Create binary target. The bins are right-closed, i.e. (-1, 0] and (0, 4]:
# a stressfulness of exactly 0 becomes label 0, anything in (0, 4] becomes
# label 1. Values outside (-1, 4] become NaN.
bins = [-1, 0, 4]  # bins for stressfulness (0-4) target
df["target"], edges = pd.cut(df.target, bins=bins, labels=[0, 1], retbins=True, right=True)

# Impute missing values with the column median rounded to a whole number.
# numeric_only=True restricts the median to numeric columns; without it,
# recent pandas versions raise a TypeError as soon as a non-numeric column
# (e.g. the categorical `target` produced above) contains a NaN. Columns
# without a median are left untouched by fillna.
# NOTE(review): the medians are computed on the full table, before the
# cross-validation split below, so test rows influence the imputed values.
nan_cols = df.columns[df.isna().any()].tolist()
if nan_cols:
    medians = df[nan_cols].median(numeric_only=True).round(0)
    df[nan_cols] = df[nan_cols].fillna(medians)

# NOTE(review): "logo" presumably selects leave-one-group-out splitting;
# confirm against machine_learning.cross_validation.
cv = CrossValidation(data=df, cv_method="logo")

# Columns that are one-hot encoded inside the cross-validation loop.
categorical_columns = ["gender", "startlanguage", "mostcommonactivity", "homelabel"]
interval_feature_list, other_feature_list = [], []
# %%
# Run preprocessing and feature selection on the cross-validation splits.
for split in cv.get_splits():
    train_X, train_y, test_X, test_y = cv.get_train_test_sets(split)

    # One-hot encode the categorical columns and fetch the resulting
    # train/test sets back from the preprocessing helper.
    pre = Preprocessing(train_X, train_y, test_X, test_y)
    pre.one_hot_encode_train_and_test_sets(categorical_columns)
    train_X, train_y, test_X, test_y = pre.get_train_test_sets()

    # Debug aid: uncomment to restrict the run to the first 30 features.
    # train_X = train_X[train_X.columns[:30]]

    # Feature selection on train set
    # TODO: maybe implement GroupKFold instead of StratifiedKFold, so that
    # an individual pid appears either in the test set or in the train set.
    fs = FeatureSelection(train_X, train_y)
    selected_features = fs.select_features(n_min=20, n_max=60, n_not_improve=3)
    print(selected_features)
    print(len(selected_features))

    # Only the first split is processed; the loop stops here.
    # NOTE(review): presumably intentional for exploration - remove the
    # break to evaluate every split.
    break
# %%