From 8a532fa95a067882b1169e2bfe580d909b6dad92 Mon Sep 17 00:00:00 2001 From: Primoz Date: Thu, 23 Feb 2023 10:41:36 +0100 Subject: [PATCH] Add a ML pipeline script to develop a whole pipeline. --- exploration/ml_pipeline.py | 49 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 exploration/ml_pipeline.py diff --git a/exploration/ml_pipeline.py b/exploration/ml_pipeline.py new file mode 100644 index 0000000..eeaa9b3 --- /dev/null +++ b/exploration/ml_pipeline.py @@ -0,0 +1,49 @@ +# --- +# jupyter: +# jupytext: +# formats: ipynb,py:percent +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.13.0 +# kernelspec: +# display_name: straw2analysis +# language: python +# name: straw2analysis +# --- + +# %% +import sys, os + +import numpy as np +import matplotlib.pyplot as plt +import pandas as pd + +nb_dir = os.path.split(os.getcwd())[0] +if nb_dir not in sys.path: + sys.path.append(nb_dir) + +from machine_learning.cross_validation import CrossValidation +from machine_learning.preprocessing import Preprocessing + +# %% +df = pd.read_csv("../data/stressfulness_event_with_speech/input_appraisal_stressfulness_event_mean.csv") +index_columns = ["local_segment", "local_segment_label", "local_segment_start_datetime", "local_segment_end_datetime"] +df.set_index(index_columns, inplace=True) + +cv = CrossValidation(data=df, cv_method="logo") + +categorical_columns = ["gender", "startlanguage", "mostcommonactivity", "homelabel"] +interval_feature_list, other_feature_list = [], [] + +print(df.columns.tolist()) + +for split in cv.get_splits(): + train_X, train_y, test_X, test_y = cv.get_train_test_sets(split) + pre = Preprocessing(train_X, train_y, test_X, test_y) + pre.one_hot_encode_train_and_test_sets(categorical_columns) + train_X, train_y, test_X, test_y = pre.get_train_test_sets() + break + +# %%