Add a CrossValidation module with all the required methods.
parent
f69cb25266
commit
9ed863b7a1
|
@ -0,0 +1,121 @@
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from sklearn.model_selection import LeaveOneGroupOut, StratifiedKFold
|
||||||
|
|
||||||
|
class CrossValidation():
|
||||||
|
"""This code implements a CrossValidation class for creating cross validation splits.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def __init__(self, data=None, cv_method='logo'):
|
||||||
|
"""This method initializes the cv_method argument and optionally prepares the data if supplied.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cv_method (str, optional): String of cross validation method; options are 'logo', 'half_logo' and '5kfold'.
|
||||||
|
Defaults to 'logo'.
|
||||||
|
data (DataFrame, optional): Pandas DataFrame with target, pid columns and other features as columns.
|
||||||
|
Defaults to None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.initialize_cv_method(cv_method)
|
||||||
|
|
||||||
|
if data is not None:
|
||||||
|
self.prepare_data(data)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_data(self, data):
|
||||||
|
"""Prepares the data ready to be passed to the cross-validation algorithm, depending on the cv_method type.
|
||||||
|
For example, if cv_method is set to 'half_logo' new columns 'pid_index', 'pid_count', 'pid_half'
|
||||||
|
are added and used in the process.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data (_type_): Pandas DataFrame with target, pid columns and other features as columns.
|
||||||
|
"""
|
||||||
|
self.data = data
|
||||||
|
if self.cv_method == "logo":
|
||||||
|
data_X, data_y, data_groups = data.drop(["target", "pid"], axis=1), data["target"], data["pid"]
|
||||||
|
|
||||||
|
elif self.cv_method == "half_logo":
|
||||||
|
data['pid_index'] = data.groupby('pid').cumcount()
|
||||||
|
data['pid_count'] = data.groupby('pid')['pid'].transform('count')
|
||||||
|
|
||||||
|
data["pid_index"] = (data['pid_index'] / data['pid_count'] + 1).round()
|
||||||
|
data["pid_half"] = data["pid"] + "_" + data["pid_index"].astype(int).astype(str)
|
||||||
|
|
||||||
|
data_X, data_y, data_groups = data.drop(["target", "pid", "pid_index", "pid_half"], axis=1), data["target"], data["pid_half"]
|
||||||
|
|
||||||
|
elif self.cv_method == "5kfold":
|
||||||
|
data_X, data_y, data_groups = data.drop(["target", "pid"], axis=1), data["target"], data["pid"]
|
||||||
|
|
||||||
|
self.X, self.y, self.groups = data_X, data_y, data_groups
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_cv_method(self, cv_method):
|
||||||
|
"""Initializes the given cv_method type. Depending on the type, the appropriate splitting technique is used.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cv_method (str): The type of cross-validation method to use; options are 'logo', 'half_logo' and '5kfold'.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If cv_method is not in the list of available methods, it raises an ValueError.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.cv_method = cv_method
|
||||||
|
if self.cv_method not in ["logo", "half_logo", "5kfold"]:
|
||||||
|
raise ValueError("Invalid cv_method input. Correct values are: 'logo', 'half_logo', '5kfold'")
|
||||||
|
|
||||||
|
if self.cv_method in ["logo", "half_logo"]:
|
||||||
|
self.cv = LeaveOneGroupOut()
|
||||||
|
elif self.cv_method == "5kfold":
|
||||||
|
self.cv = StratifiedKFold(n_splits=5, shuffle=True)
|
||||||
|
|
||||||
|
|
||||||
|
def get_splits(self):
|
||||||
|
"""Returns a generator object containing the cross-validation splits.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: Raises ValueError if no data has been set.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not self.data.empty:
|
||||||
|
return self.cv.split(self.X, self.y, self.groups)
|
||||||
|
else:
|
||||||
|
raise ValueError("No data has been set. Use 'prepare_data(data)' method to set the data.")
|
||||||
|
|
||||||
|
|
||||||
|
def get_data(self):
|
||||||
|
"""data getter
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Pandas DataFrame: Returns the data from the class instance.
|
||||||
|
"""
|
||||||
|
return self.data
|
||||||
|
|
||||||
|
|
||||||
|
def get_x_y_groups(self):
|
||||||
|
"""X, y, and groups data getter
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Pandas DataFrame: Returns the data from the class instance.
|
||||||
|
"""
|
||||||
|
return self.X, self.y, self.groups
|
||||||
|
|
||||||
|
|
||||||
|
def get_train_test_sets(self, split):
|
||||||
|
"""Gets train and test sets, dependent on the split parameter. This method can be used in a specific splitting context,
|
||||||
|
where by index we can get train and test sets.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
split (tuple of indices): It represents one iteration of the split generator (see get_splits method).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple of Pandas DataFrames: This method returns train_X, train_y, test_X, test_y, with correctly indexed rows by split param.
|
||||||
|
"""
|
||||||
|
return self.X.iloc[split[0]], self.y.iloc[split[0]], self.X.iloc[split[1]], self.y.iloc[split[1]]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue