128 lines
4.7 KiB
Python
128 lines
4.7 KiB
Python
import os
|
|
import sys
|
|
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
import pandas as pd
|
|
|
|
from sklearn.model_selection import LeaveOneGroupOut, StratifiedKFold
|
|
|
|
class CrossValidation:
    """Create cross-validation splits for a DataFrame with ``target`` and ``pid`` columns.

    Supported ``cv_method`` values:

    * ``'logo'`` -- leave-one-group-out with one group per ``pid``.
    * ``'half_logo'`` -- leave-one-group-out where each ``pid``'s rows are
      divided (in their existing row order) into a first and a second half,
      and each half forms its own group (``"<pid>_1"`` / ``"<pid>_2"``).
    * ``'5kfold'`` (alias ``'Stratified5kfold'``) -- shuffled, stratified
      5-fold split on ``target``; no groups are used.
    """

    # cv_method values that use LeaveOneGroupOut.
    _LOGO_METHODS = ("logo", "half_logo")
    # cv_method values that use StratifiedKFold. 'Stratified5kfold' is the
    # spelling the original branches compared against, so it is kept as an
    # alias of the documented '5kfold'.
    _KFOLD_METHODS = ("5kfold", "Stratified5kfold")
    # Helper columns added by the 'half_logo' preparation; never used as features.
    _HALF_LOGO_HELPERS = ["pid_index", "pid_count", "pid_half"]

    def __init__(self, data=None, cv_method='logo', random_state=None):
        """Initialize the cv_method and optionally prepare the data if supplied.

        Args:
            data (DataFrame, optional): Pandas DataFrame with ``target`` and ``pid``
                columns and the other features as columns. Defaults to None.
            cv_method (str, optional): Cross-validation method; one of 'logo',
                'half_logo' or '5kfold' ('Stratified5kfold' is accepted as an
                alias of '5kfold'). Defaults to 'logo'.
            random_state (int, optional): Seed forwarded to StratifiedKFold so
                the shuffled k-fold splits are reproducible. Defaults to None
                (non-deterministic shuffling, as before).

        Raises:
            ValueError: If cv_method is not one of the supported methods.
        """
        # Define the data attributes up front so get_splits can report a
        # missing dataset with a ValueError instead of an AttributeError.
        self.data = None
        self.X, self.y, self.groups = None, None, None
        self.random_state = random_state

        self.initialize_cv_method(cv_method)

        if data is not None:
            self.prepare_data(data)

    def prepare_data(self, data):
        """Prepare the data for the cross-validation algorithm, depending on cv_method.

        For 'half_logo', the helper columns 'pid_index', 'pid_count' and
        'pid_half' are added to a copy of ``data`` (the caller's DataFrame is
        not modified). None of them is included in the feature matrix X.

        If cv_method is changed later via initialize_cv_method, call this
        method again so X, y and groups match the new method.

        Args:
            data (DataFrame): Pandas DataFrame with ``target`` and ``pid`` columns
                and the other features as columns.

        Raises:
            ValueError: If self.cv_method is not a supported method.
        """
        if self.cv_method == "half_logo":
            # Work on a copy so the caller's DataFrame is not mutated.
            data = data.copy()
            # 0-based position of each row within its pid, and that pid's row count.
            data["pid_index"] = data.groupby("pid").cumcount()
            data["pid_count"] = data.groupby("pid")["pid"].transform("count")
            # position / count lies in [0, 1). Adding 1 and rounding maps the
            # first half of each pid's rows to 1 and the second half to 2
            # (pandas rounds ties such as 1.5 half-to-even, i.e. to 2).
            data["pid_index"] = (data["pid_index"] / data["pid_count"] + 1).round()
            # astype(str) allows non-string pids (e.g. integers) to be concatenated.
            data["pid_half"] = (
                data["pid"].astype(str) + "_" + data["pid_index"].astype(int).astype(str)
            )
            groups = data["pid_half"]
            # Drop every helper column; previously 'pid_count' leaked into X.
            drop_cols = ["target", "pid"] + self._HALF_LOGO_HELPERS
        elif self.cv_method == "logo":
            groups = data["pid"]
            drop_cols = ["target", "pid"]
        elif self.cv_method in self._KFOLD_METHODS:
            groups = None
            drop_cols = ["target", "pid"]
        else:
            raise ValueError("Invalid cv_method input. Correct values are: 'logo', 'half_logo', '5kfold'")

        self.data = data
        self.X = data.drop(drop_cols, axis=1)
        self.y = data["target"]
        self.groups = groups

    def initialize_cv_method(self, cv_method):
        """Validate cv_method and create the matching splitter.

        'logo' and 'half_logo' use LeaveOneGroupOut. '5kfold' (or its alias
        'Stratified5kfold') uses a shuffled StratifiedKFold with 5 splits,
        seeded with ``self.random_state`` when that attribute is set.

        Args:
            cv_method (str): The type of cross-validation method to use; options
                are 'logo', 'half_logo' and '5kfold'.

        Raises:
            ValueError: If cv_method is not in the list of available methods.
        """
        # Validate before assigning so a bad value leaves the object unchanged.
        if cv_method not in self._LOGO_METHODS + self._KFOLD_METHODS:
            raise ValueError("Invalid cv_method input. Correct values are: 'logo', 'half_logo', '5kfold'")

        self.cv_method = cv_method

        if cv_method in self._LOGO_METHODS:
            self.cv = LeaveOneGroupOut()
        else:
            self.cv = StratifiedKFold(
                n_splits=5,
                shuffle=True,
                random_state=getattr(self, "random_state", None),
            )

    def get_splits(self):
        """Return a generator of (train_indices, test_indices) cross-validation splits.

        Raises:
            ValueError: If no (or empty) data has been set.
        """
        data = getattr(self, "data", None)
        if data is None or data.empty:
            raise ValueError("No data has been set. Use 'prepare_data(data)' method to set the data.")
        return self.cv.split(self.X, self.y, self.groups)

    def get_data(self):
        """Return the prepared DataFrame.

        Returns:
            Pandas DataFrame or None: The data stored by prepare_data. For
            'half_logo' it includes the added helper columns. Returns None if
            no data has been prepared.
        """
        return self.data

    def get_x_y_groups(self):
        """Return the features, target and groups used for splitting.

        Returns:
            tuple: ``(X, y, groups)``. X is the feature DataFrame, y is the
            ``target`` Series, and groups is the grouping Series ('pid' for
            'logo', 'pid_half' for 'half_logo') or None for k-fold.
        """
        return self.X, self.y, self.groups

    def get_train_test_sets(self, split):
        """Get the train and test sets for one split.

        Args:
            split (tuple of indices): One item yielded by the split generator
                (see get_splits), i.e. ``(train_indices, test_indices)``.

        Returns:
            tuple: ``(train_X, train_y, test_X, test_y)``, selected positionally
            (``iloc``) by the split indices.
        """
        train_idx, test_idx = split[0], split[1]
        return self.X.iloc[train_idx], self.y.iloc[train_idx], self.X.iloc[test_idx], self.y.iloc[test_idx]

    def get_groups_sets(self, split):
        """Get the train and test group labels for one split.

        Args:
            split (tuple of indices): ``(train_indices, test_indices)`` as yielded
                by get_splits.

        Returns:
            tuple: ``(train_groups, test_groups)`` selected positionally, or
            ``(None, None)`` when no groups are used (k-fold).
        """
        if self.groups is None:
            return None, None
        return self.groups.iloc[split[0]], self.groups.iloc[split[1]]
|
|
|
|
|