stress_at_work_analysis/machine_learning/cross_validation.py

121 lines
4.5 KiB
Python

import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import LeaveOneGroupOut, StratifiedKFold
class CrossValidation():
    """Create cross-validation splits for a DataFrame with 'target' and 'pid' columns.

    Supported ``cv_method`` values:

    * ``'logo'``      -- LeaveOneGroupOut, grouped by ``pid``.
    * ``'half_logo'`` -- LeaveOneGroupOut, grouped by ``pid_half``: the rows of every
      ``pid`` are divided (in row order) into a first and a second part, labelled
      ``<pid>_1`` and ``<pid>_2``.
    * ``'5kfold'``    -- StratifiedKFold with 5 shuffled splits.
    """

    # Single source of truth for the accepted cv_method values.
    _CV_METHODS = ("logo", "half_logo", "5kfold")

    def __init__(self, data=None, cv_method='logo'):
        """Initialize the splitter and optionally prepare the data if supplied.

        Args:
            data (DataFrame, optional): Pandas DataFrame with target, pid columns and
                other features as columns. Defaults to None.
            cv_method (str, optional): Cross-validation method; options are 'logo',
                'half_logo' and '5kfold'. Defaults to 'logo'.

        Raises:
            ValueError: If cv_method is not one of the supported methods.
        """
        # Define the data attributes up front so that get_splits() can raise the
        # documented ValueError (instead of an AttributeError) when no data was set.
        self.data = None
        self.X = None
        self.y = None
        self.groups = None

        self.initialize_cv_method(cv_method)
        if data is not None:
            self.prepare_data(data)

    def prepare_data(self, data):
        """Derive X, y and groups from ``data`` according to the configured cv_method.

        For 'half_logo' the helper columns 'pid_index', 'pid_count' and 'pid_half' are
        computed on a copy of ``data`` (the caller's frame is left untouched); the
        augmented copy is what ``get_data()`` returns afterwards. None of the helper
        columns is kept in the feature matrix X.

        Args:
            data (DataFrame): Pandas DataFrame with target, pid columns and other
                features as columns.

        Raises:
            ValueError: If the instance holds an unsupported cv_method.
        """
        if self.cv_method == "half_logo":
            # Work on a copy: adding helper columns to the caller's frame would make
            # them show up as features if the same frame is prepared again later.
            data = data.copy()
            data['pid_index'] = data.groupby('pid').cumcount()
            data['pid_count'] = data.groupby('pid')['pid'].transform('count')
            # Relative position within the pid lies in [0, 1); after +1 and rounding
            # it becomes 1 (first part) or 2 (second part).
            data["pid_index"] = (data['pid_index'] / data['pid_count'] + 1).round()
            # astype(str) keeps non-string pids (e.g. integers) working.
            data["pid_half"] = data["pid"].astype(str) + "_" + data["pid_index"].astype(int).astype(str)
            helper_columns = ["pid_index", "pid_count", "pid_half"]
            group_column = "pid_half"
        elif self.cv_method in ("logo", "5kfold"):
            helper_columns = []
            group_column = "pid"
        else:
            raise ValueError("Invalid cv_method input. Correct values are: 'logo', 'half_logo', '5kfold'")

        self.data = data
        self.X = data.drop(["target", "pid"] + helper_columns, axis=1)
        self.y = data["target"]
        self.groups = data[group_column]

    def initialize_cv_method(self, cv_method):
        """Store cv_method and create the matching scikit-learn splitter in ``self.cv``.

        Args:
            cv_method (str): The type of cross-validation method to use; options are
                'logo', 'half_logo' and '5kfold'.

        Raises:
            ValueError: If cv_method is not in the list of available methods.
        """
        self.cv_method = cv_method
        if self.cv_method not in self._CV_METHODS:
            raise ValueError("Invalid cv_method input. Correct values are: 'logo', 'half_logo', '5kfold'")

        if self.cv_method == "5kfold":
            # NOTE(review): shuffle=True without a random_state -- folds presumably
            # differ between runs; confirm whether reproducibility is required.
            self.cv = StratifiedKFold(n_splits=5, shuffle=True)
        else:
            # 'logo' and 'half_logo' differ only in the groups column built by
            # prepare_data(), not in the splitter itself.
            self.cv = LeaveOneGroupOut()

    def get_splits(self):
        """Return the result of ``self.cv.split(X, y, groups)`` for the prepared data.

        Returns:
            The iterable of splits produced by the underlying splitter; each item can
            be passed to ``get_train_test_sets``.

        Raises:
            ValueError: If no data has been set, or the data set is empty.
        """
        if self.data is None or self.data.empty:
            raise ValueError("No data has been set. Use 'prepare_data(data)' method to set the data.")
        return self.cv.split(self.X, self.y, self.groups)

    def get_data(self):
        """Data getter.

        Returns:
            Pandas DataFrame: The data held by the instance (None if no data was set).
        """
        return self.data

    def get_x_y_groups(self):
        """X, y, and groups getter.

        Returns:
            tuple: ``(X, y, groups)`` as derived by ``prepare_data``.
        """
        return self.X, self.y, self.groups

    def get_train_test_sets(self, split):
        """Select the train and test rows for one split.

        Args:
            split (tuple of indices): One item of the splits returned by ``get_splits``;
                ``split[0]`` holds the positional train indices and ``split[1]`` the
                positional test indices.

        Returns:
            tuple: ``(train_X, train_y, test_X, test_y)`` selected positionally (iloc).
        """
        train_idx, test_idx = split[0], split[1]
        return self.X.iloc[train_idx], self.y.iloc[train_idx], self.X.iloc[test_idx], self.y.iloc[test_idx]