class CrossValidation:
    """Build cross-validation splits from a DataFrame with 'target' and 'pid' columns.

    Three strategies are supported, selected by ``cv_method``:

    * ``'logo'``: leave-one-group-out, one group per ``pid`` value.
    * ``'half_logo'``: leave-one-group-out, where the rows of every ``pid`` are
      split (in row order) into a first and a second half and each half is its
      own group.
    * ``'5kfold'``: stratified 5-fold with shuffling.
    """

    # The only values accepted for cv_method.
    _CV_METHODS = ("logo", "half_logo", "5kfold")

    def __init__(self, data=None, cv_method='logo'):
        """Select the splitting strategy and optionally prepare the data.

        Args:
            data (DataFrame, optional): Pandas DataFrame with 'target' and 'pid'
                columns and the features as the remaining columns. Defaults to None.
            cv_method (str, optional): One of 'logo', 'half_logo' and '5kfold'.
                Defaults to 'logo'.

        Raises:
            ValueError: If cv_method is not one of the supported methods.
        """
        # Defined up front so that get_splits() and the getters have something
        # to inspect before prepare_data() has been called.
        self.data = None
        self.X = None
        self.y = None
        self.groups = None

        self.initialize_cv_method(cv_method)

        if data is not None:
            self.prepare_data(data)

    def prepare_data(self, data):
        """Split the data into features, target and groups for the chosen cv_method.

        For 'half_logo' the helper columns 'pid_index', 'pid_count' and 'pid_half'
        are added to a copy of ``data`` (the frame passed in is left untouched);
        that copy is what get_data() returns afterwards. None of the helper
        columns end up in the feature matrix.

        Args:
            data (DataFrame): Pandas DataFrame with 'target' and 'pid' columns
                and the features as the remaining columns.
        """
        if self.cv_method == "half_logo":
            # Work on a copy: the helper columns must not leak into the caller's frame.
            data = data.copy()
            data["pid_index"] = data.groupby("pid").cumcount()
            data["pid_count"] = data.groupby("pid")["pid"].transform("count")

            # Relative position within the pid lies in [0, 1); adding 1 and
            # rounding yields 1 for the first half of the rows and 2 for the second.
            data["pid_index"] = (data["pid_index"] / data["pid_count"] + 1).round()
            # astype(str) lets numeric participant ids work as well as string ones.
            data["pid_half"] = (
                data["pid"].astype(str) + "_" + data["pid_index"].astype(int).astype(str)
            )

            helper_columns = ["pid_index", "pid_count", "pid_half"]
            group_column = "pid_half"
        else:
            # 'logo' and '5kfold' use the raw participant id as the group label.
            helper_columns = []
            group_column = "pid"

        self.data = data
        self.X = data.drop(columns=["target", "pid", *helper_columns])
        self.y = data["target"]
        self.groups = data[group_column]

    def initialize_cv_method(self, cv_method):
        """Validate cv_method and create the matching scikit-learn splitter.

        Args:
            cv_method (str): One of 'logo', 'half_logo' and '5kfold'.

        Raises:
            ValueError: If cv_method is not one of the supported methods.
        """
        # Reject bad input before touching any state.
        if cv_method not in self._CV_METHODS:
            raise ValueError("Invalid cv_method input. Correct values are: 'logo', 'half_logo', '5kfold'")

        # Imported here so that the class carries its own scikit-learn dependency.
        from sklearn.model_selection import LeaveOneGroupOut, StratifiedKFold

        self.cv_method = cv_method
        if cv_method == "5kfold":
            self.cv = StratifiedKFold(n_splits=5, shuffle=True)
        else:
            self.cv = LeaveOneGroupOut()

    def get_splits(self):
        """Return a generator over the (train_indices, test_indices) splits.

        Raises:
            ValueError: If no data has been set, or the data set is empty.
        """
        if self.data is None or self.data.empty:
            raise ValueError("No data has been set. Use 'prepare_data(data)' method to set the data.")
        return self.cv.split(self.X, self.y, self.groups)

    def get_data(self):
        """Return the prepared DataFrame, or None if no data has been set.

        Returns:
            Pandas DataFrame: The data held by the instance; for 'half_logo'
            this includes the added helper columns.
        """
        return self.data

    def get_x_y_groups(self):
        """Return the feature matrix, the target and the group labels.

        Returns:
            tuple: (X, y, groups), where X is a DataFrame and y and groups are
            Series; all three are None if no data has been set.
        """
        return self.X, self.y, self.groups

    def get_train_test_sets(self, split):
        """Select the train and test rows for one split.

        Args:
            split (tuple of indices): One item produced by get_splits(), i.e.
                (train_indices, test_indices) as positional row indices.

        Returns:
            tuple: train_X, train_y, test_X, test_y selected by position.
        """
        train_idx, test_idx = split[0], split[1]
        return (
            self.X.iloc[train_idx],
            self.y.iloc[train_idx],
            self.X.iloc[test_idx],
            self.y.iloc[test_idx],
        )