import os
import sys

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import LeaveOneGroupOut, StratifiedKFold


class CrossValidation():
    """Create cross-validation splits from a DataFrame with 'target' and 'pid' columns.

    Supported ``cv_method`` values:

    * ``'logo'``      -- leave-one-group-out, grouped by ``pid``.
    * ``'half_logo'`` -- leave-one-group-out, where every ``pid`` is split into
      a first and a second half (by row order) and each half is its own group.
    * ``'5kfold'``    -- shuffled, stratified 5-fold on ``target``.

    Attributes set by the class:
        cv_method: The validated method name.
        cv: The scikit-learn splitter instance matching ``cv_method``.
        random_state: Seed forwarded to the shuffled splitter ('5kfold' only).
        data: The frame the splits were prepared from (``None`` until set).
        X, y, groups: Features, target and group labels (``None`` until set).
    """

    # Helper columns created for 'half_logo'; none of them is a real feature.
    _HALF_LOGO_HELPERS = ["pid_index", "pid_count", "pid_half"]

    def __init__(self, data=None, cv_method='logo', random_state=None):
        """Select the splitting strategy and optionally prepare the data.

        Args:
            data (DataFrame, optional): Pandas DataFrame with 'target' and
                'pid' columns; every other column is treated as a feature.
                Defaults to None.
            cv_method (str, optional): One of 'logo', 'half_logo', '5kfold'.
                Defaults to 'logo'.
            random_state (int, optional): Seed for the shuffled '5kfold'
                splitter so that splits can be reproduced. Ignored by the
                group-based methods. Defaults to None (non-deterministic
                shuffle, the historical behaviour).

        Raises:
            ValueError: If cv_method is not one of the supported values.
        """
        # Define the state up front so the getters and get_splits() behave
        # predictably even when no data has been supplied yet.
        self.data = None
        self.X = None
        self.y = None
        self.groups = None

        self.initialize_cv_method(cv_method, random_state=random_state)
        if data is not None:
            self.prepare_data(data)

    def prepare_data(self, data):
        """Derive X, y and groups from ``data`` according to ``cv_method``.

        For 'half_logo' the helper columns 'pid_index', 'pid_count' and
        'pid_half' are computed on a copy of ``data`` (the caller's frame is
        left untouched); the augmented copy is what ``get_data()`` returns.
        For the other methods ``data`` itself is stored unchanged.

        Args:
            data (DataFrame): Pandas DataFrame with 'target' and 'pid' columns
                and the features as the remaining columns.

        Raises:
            KeyError: If the 'target' or 'pid' column is missing.
        """
        if self.cv_method == "half_logo":
            # Work on a copy: adding helper columns must not mutate the
            # caller's DataFrame.
            frame = data.copy()
            per_pid = frame.groupby("pid")
            frame["pid_index"] = per_pid.cumcount()
            frame["pid_count"] = per_pid["pid"].transform("count")
            # position / count lies in [0, 1); adding 1 and rounding labels
            # the first half of each pid's rows 1 and the second half 2.
            frame["pid_index"] = (frame["pid_index"] / frame["pid_count"] + 1).round()
            # Cast pid to str so numeric identifiers work as well as strings.
            frame["pid_half"] = (
                frame["pid"].astype(str) + "_" + frame["pid_index"].astype(int).astype(str)
            )
            non_feature_columns = ["target", "pid"] + self._HALF_LOGO_HELPERS
            group_labels = frame["pid_half"]
        else:
            # 'logo' and '5kfold' need no helper columns; keep the frame as-is.
            frame = data
            non_feature_columns = ["target", "pid"]
            group_labels = frame["pid"]

        self.data = frame
        self.X = frame.drop(non_feature_columns, axis=1)
        self.y = frame["target"]
        self.groups = group_labels

    def initialize_cv_method(self, cv_method, random_state=None):
        """Validate ``cv_method`` and create the matching splitter.

        Args:
            cv_method (str): The type of cross-validation method to use;
                options are 'logo', 'half_logo' and '5kfold'.
            random_state (int, optional): Seed for the shuffled '5kfold'
                splitter. Defaults to None.

        Raises:
            ValueError: If cv_method is not in the list of available methods.
        """
        # Validate before touching any state so a bad value cannot leave the
        # instance with a cv_method that does not match its splitter.
        if cv_method not in ["logo", "half_logo", "5kfold"]:
            raise ValueError("Invalid cv_method input. Correct values are: 'logo', 'half_logo', '5kfold'")

        self.cv_method = cv_method
        self.random_state = random_state
        if cv_method == "5kfold":
            self.cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=random_state)
        else:
            # 'logo' and 'half_logo' differ only in how the groups are built.
            self.cv = LeaveOneGroupOut()

    def get_splits(self):
        """Return a generator of (train_indices, test_indices) pairs.

        Returns:
            generator: The result of the splitter's ``split(X, y, groups)``.

        Raises:
            ValueError: If no data has been set, or the data is empty.
        """
        if self.data is None or self.data.empty:
            raise ValueError("No data has been set. Use 'prepare_data(data)' method to set the data.")
        return self.cv.split(self.X, self.y, self.groups)

    def get_data(self):
        """Return the stored data.

        Returns:
            DataFrame or None: The frame the splits were prepared from
            (including the helper columns for 'half_logo'), or None if no
            data has been set.
        """
        return self.data

    def get_x_y_groups(self):
        """Return the features, target and group labels.

        Returns:
            tuple: ``(X, y, groups)``; each element is None until data has
            been prepared.
        """
        return self.X, self.y, self.groups

    def get_train_test_sets(self, split):
        """Materialise the train and test sets for one split.

        Args:
            split (tuple of indices): One item produced by ``get_splits()``:
                positional train indices first, positional test indices second.

        Returns:
            tuple: ``(train_X, train_y, test_X, test_y)`` selected
            positionally (``iloc``) from X and y.
        """
        train_rows, test_rows = split[0], split[1]
        return (
            self.X.iloc[train_rows],
            self.y.iloc[train_rows],
            self.X.iloc[test_rows],
            self.y.iloc[test_rows],
        )